Stomach_Cancer_Pytorch/draw_tools/Grad_cam.py
2024-11-19 12:16:22 +00:00

68 lines
2.4 KiB
Python
Executable File

from Load_process.file_processing import Process_File
from torchcam.methods import GradCAM
from torchvision.transforms.functional import to_pil_image
from matplotlib import pyplot as plt
import torch
import cv2
import numpy as np
import datetime
class Grad_CAM:
def __init__(self, Label, One_Hot, Experiment_Name, Layer_Name) -> None:
self.experiment_name = Experiment_Name
self.Layer_Name = Layer_Name
self.Label = Label
self.One_Hot_Label = One_Hot
self.Save_File_Name = self.Convert_One_Hot_To_int()
pass
def process_main(self, model, index, images):
cam_extractor = GradCAM(model, target_layer=self.Layer_Name)
for i, image in enumerate(images):
input_tensor = image.unsqueeze(0) # 在 PyTorch 中增加批次維度
heatmap = self.gradcam(input_tensor, model, cam_extractor)
self.plot_heatmap(heatmap, image, self.Save_File_Name[i], index, i)
pass
def Convert_One_Hot_To_int(self):
return [np.argmax(Label)for Label in self.One_Hot_Label]
def gradcam(self, Image, model, cam_extractor):
# 將模型設為評估模式
model.eval()
# 前向傳播並生成熱力圖
with torch.no_grad():
out = model(Image)
# 取得預測類別
pred_index = out.argmax(dim=1).item()
# 生成對應的 Grad-CAM 熱力圖
heatmap = cam_extractor(pred_index, out)
return heatmap[0].cpu().numpy()
def plot_heatmap(self, heatmap, img, Label, index, Title):
File = Process_File()
# 調整影像大小
img_path = cv2.resize(img.numpy().transpose(1, 2, 0), (512, 512))
heatmap = cv2.resize(heatmap, (512, 512))
heatmap = np.uint8(255 * heatmap)
img_path = cv2.cvtColor(img_path, cv2.COLOR_BGR2RGB)
# 顯示影像和熱力圖
fig, ax = plt.subplots()
ax.imshow(img_path, alpha=1)
ax.imshow(heatmap, cmap='jet', alpha=0.3)
save_root = '../Result/CNN_result_of_reading('+ str(datetime.date.today()) + ")/" + str(Label)
File.JudgeRoot_MakeDir(save_root)
save_root = File.Make_Save_Root(self.experiment_name + "-" + str(index) + "-" + str(Title) + ".png", save_root)
plt.savefig(save_root)
plt.close("all")
pass