68 lines
2.4 KiB
Python
Executable File
68 lines
2.4 KiB
Python
Executable File
from Load_process.file_processing import Process_File
|
|
from torchcam.methods import GradCAM
|
|
from torchvision.transforms.functional import to_pil_image
|
|
from matplotlib import pyplot as plt
|
|
import torch
|
|
import cv2
|
|
import numpy as np
|
|
import datetime
|
|
|
|
class Grad_CAM:
|
|
def __init__(self, Label, One_Hot, Experiment_Name, Layer_Name) -> None:
|
|
self.experiment_name = Experiment_Name
|
|
self.Layer_Name = Layer_Name
|
|
self.Label = Label
|
|
self.One_Hot_Label = One_Hot
|
|
self.Save_File_Name = self.Convert_One_Hot_To_int()
|
|
|
|
pass
|
|
|
|
def process_main(self, model, index, images):
|
|
cam_extractor = GradCAM(model, target_layer=self.Layer_Name)
|
|
|
|
for i, image in enumerate(images):
|
|
input_tensor = image.unsqueeze(0) # 在 PyTorch 中增加批次維度
|
|
heatmap = self.gradcam(input_tensor, model, cam_extractor)
|
|
self.plot_heatmap(heatmap, image, self.Save_File_Name[i], index, i)
|
|
pass
|
|
|
|
def Convert_One_Hot_To_int(self):
|
|
return [np.argmax(Label)for Label in self.One_Hot_Label]
|
|
|
|
|
|
def gradcam(self, Image, model, cam_extractor):
|
|
# 將模型設為評估模式
|
|
model.eval()
|
|
# 前向傳播並生成熱力圖
|
|
with torch.no_grad():
|
|
out = model(Image)
|
|
|
|
# 取得預測類別
|
|
pred_index = out.argmax(dim=1).item()
|
|
|
|
# 生成對應的 Grad-CAM 熱力圖
|
|
heatmap = cam_extractor(pred_index, out)
|
|
return heatmap[0].cpu().numpy()
|
|
|
|
def plot_heatmap(self, heatmap, img, Label, index, Title):
|
|
File = Process_File()
|
|
|
|
# 調整影像大小
|
|
img_path = cv2.resize(img.numpy().transpose(1, 2, 0), (512, 512))
|
|
heatmap = cv2.resize(heatmap, (512, 512))
|
|
heatmap = np.uint8(255 * heatmap)
|
|
|
|
img_path = cv2.cvtColor(img_path, cv2.COLOR_BGR2RGB)
|
|
|
|
# 顯示影像和熱力圖
|
|
fig, ax = plt.subplots()
|
|
ax.imshow(img_path, alpha=1)
|
|
ax.imshow(heatmap, cmap='jet', alpha=0.3)
|
|
|
|
save_root = '../Result/CNN_result_of_reading('+ str(datetime.date.today()) + ")/" + str(Label)
|
|
File.JudgeRoot_MakeDir(save_root)
|
|
save_root = File.Make_Save_Root(self.experiment_name + "-" + str(index) + "-" + str(Title) + ".png", save_root)
|
|
|
|
plt.savefig(save_root)
|
|
plt.close("all")
|
|
pass |