Stomach_Cancer_Pytorch/draw_tools/Grad_cam.py

67 lines
2.4 KiB
Python

from Load_process.file_processing import Process_File
from torchcam.methods import GradCAM
from torchvision.transforms.functional import to_pil_image
from matplotlib import pyplot as plt
import torch
import cv2
import numpy as np
import datetime
class Grad_CAM:
    """Generate Grad-CAM heatmaps for a model and save them as overlay images.

    Heatmaps come from torchcam's ``GradCAM`` extractor hooked on ``Layer``.
    Each one is drawn over its input image with matplotlib. The figure is
    written under ``../Result/CNN_result_of_reading(<today>)/<label>/``.
    """

    def __init__(self, Experiment_Name, Layer, Image_Size) -> None:
        """Store the plotting/extraction configuration.

        Args:
            Experiment_Name: prefix used in every saved file name.
            Layer: target layer handed to ``GradCAM(target_layer=...)``.
            Image_Size: side length, in pixels, that both the image and the
                heatmap are resized to before plotting.
        """
        self.experiment_name = Experiment_Name
        self.Layer = Layer
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.Image_Size = Image_Size

    def process_main(self, model, index, images):
        """Compute and save a Grad-CAM overlay for every item of ``images``.

        Args:
            model: the network to explain.
            index: identifier embedded in the saved file names.
            images: iterable yielding ``(image, label)`` pairs. The position
                of each pair in the iteration is used as the file-name
                ``Title``.
        """
        cam_extractor = GradCAM(model, target_layer=self.Layer)
        try:
            for i, (image, label) in enumerate(images):
                heatmap = self.gradcam(image, model, cam_extractor)
                self.plot_heatmap(heatmap, image, label, index, i)
        finally:
            # GradCAM attaches hooks to ``model``; detach them so they do not
            # outlive this call (also on error).
            cam_extractor.remove_hooks()

    def gradcam(self, Image, model, cam_extractor):
        """Return the Grad-CAM map for the model's predicted class(es).

        Args:
            Image: input tensor/array. It is moved to ``self.device``.
                NOTE(review): presumably already batched for ``model`` —
                confirm.
            model: the network that ``cam_extractor`` is hooked on.
            cam_extractor: a torchcam CAM extractor, called with
                ``class_idx=`` and ``scores=``.

        Returns:
            The first entry of the extractor's result, as a NumPy array.
        """
        Image = torch.as_tensor(Image).to(self.device)
        # Put the model in evaluation mode.
        model.eval()
        # Grad-CAM backpropagates the class score to the hooked layer. The
        # forward pass must therefore record an autograd graph; running it
        # under torch.no_grad() would leave nothing to backpropagate.
        with torch.enable_grad():
            out = model(Image)
            # Predicted class per row of the model output.
            _, predicted = torch.max(out, 1)
            # torchcam's class_idx is typed as int or list of ints, so pass a
            # plain Python list rather than a tensor.
            heatmap = cam_extractor(class_idx=predicted.tolist(), scores=out)
        return heatmap[0].detach().cpu().numpy()

    def plot_heatmap(self, heatmap, img, Label, index, Title):
        """Overlay ``heatmap`` on ``img`` and save the figure as a PNG.

        Args:
            heatmap: 2-D map with values presumably in [0, 1]; it is scaled
                by 255 and cast to uint8.
            img: image tensor, transposed (1, 2, 0) to channels-last.
                NOTE(review): assumes a single (C, H, W) image — confirm.
            Label: tensor whose argmax over axis 1 gives the class id(s) that
                name the output sub-directory.
            index: identifier embedded in the file name.
            Title: per-image identifier embedded in the file name.
        """
        File = Process_File()
        # Class id(s) from (presumably one-hot) label rows.
        Label = np.argmax(Label.cpu().numpy(), 1)
        size = (self.Image_Size, self.Image_Size)
        # Resize the image (as channels-last) and the heatmap to the same size.
        image_array = cv2.resize(img.detach().cpu().numpy().transpose(1, 2, 0), size)
        image_array = cv2.cvtColor(image_array, cv2.COLOR_BGR2RGB)
        heatmap = np.uint8(255 * cv2.resize(heatmap, size))
        # Draw the image with the semi-transparent heatmap on top.
        fig, ax = plt.subplots()
        ax.imshow(image_array, alpha=1)
        ax.imshow(heatmap, cmap='jet', alpha=0.3)
        save_root = '../Result/CNN_result_of_reading('+ str(datetime.date.today()) + ")/" + str(Label)
        File.JudgeRoot_MakeDir(save_root)
        save_root = File.Make_Save_Root(self.experiment_name + "-" + str(index) + "-" + str(Title) + ".png", save_root)
        # Save and close only this figure, leaving any other open figures alone.
        fig.savefig(save_root)
        plt.close(fig)