Stomach_Cancer_Pytorch/draw_tools/Grad_cam.py

152 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import datetime
from Load_process.file_processing import Process_File
class GradCAM:
    """Grad-CAM (gradient-weighted class activation mapping) for a PyTorch model.

    A forward hook on ``target_layer`` stores that layer's output, and a tensor
    hook on the same output stores its gradient during back-propagation.
    Call :meth:`remove_hooks` when the instance is no longer needed so the
    model is not left with dangling hooks.
    """

    def __init__(self, model, target_layer):
        """Initialise Grad-CAM.

        Args:
            model: trained model to explain (e.g. the project's ModifiedXception).
                It is put in eval mode and moved to CUDA when available,
                otherwise to CPU.
            target_layer: name of the layer to explain, as listed by
                ``model.named_modules()`` (e.g. ``'base_model'``).

        Raises:
            KeyError: if ``target_layer`` is not a module name of ``model``.
        """
        self.model = model
        self.target_layer = target_layer
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.eval()
        self.model.to(self.device)
        # Target-layer activations / gradients captured during the latest pass.
        self.features = None
        self.gradients = None
        # Handles of the registered hooks, kept so they can be removed later.
        self._hook_handles = []
        self._register_hooks()

    def _register_hooks(self):
        """Attach the hooks that capture the target layer's output and its gradient.

        The gradient is captured with a tensor hook on the layer output instead
        of the deprecated ``Module.register_backward_hook``, whose
        ``grad_output`` is documented as unreliable for modules built from
        several autograd operations (such as a whole backbone).

        Raises:
            KeyError: if ``self.target_layer`` is not found in the model.
        """
        def save_gradient(grad):
            self.gradients = grad.detach()

        def forward_hook(module, inputs, output):
            if isinstance(output, torch.Tensor):
                # Detach so the stored activations do not keep the graph alive.
                self.features = output.detach()
                # Only tensors that are part of the autograd graph accept hooks.
                if output.requires_grad:
                    output.register_hook(save_gradient)
            else:
                self.features = output

        modules = dict(self.model.named_modules())
        if self.target_layer not in modules:
            raise KeyError(
                f"Layer '{self.target_layer}' not found in model.named_modules()"
            )
        handle = modules[self.target_layer].register_forward_hook(forward_hook)
        self._hook_handles.append(handle)

    def remove_hooks(self):
        """Remove every hook this instance registered on the model (idempotent)."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def _compute_cam(self, input_image, target_class=None):
        """Run forward/backward and return the CAM at the target layer's resolution.

        Args:
            input_image: input tensor, expected shape ``[1, C, H, W]``; only the
                first sample of the batch is explained.
            target_class: class index to explain; ``None`` selects the class
                with the highest score for the first sample.

        Returns:
            ``(cam, class_idx)``: ``cam`` is a float32 array with the target
            layer's spatial size and values in ``[0, 1]`` (all zeros when no
            location has a positive contribution); ``class_idx`` is the class
            whose score was back-propagated.

        Raises:
            RuntimeError: if no activations/gradients were captured for the
                target layer (layer unused, non-tensor output, or no grad).
            ValueError: if the target layer output is not 4-D ``[N, C, H, W]``.
        """
        input_image = input_image.to(self.device)
        # Reset so results from a previous call can never be silently reused.
        self.features = None
        self.gradients = None

        output = self.model(input_image)
        if target_class is None:
            target_class = torch.argmax(output, dim=1)[0].item()

        self.model.zero_grad()
        # Back-propagate only the chosen class score of the first sample
        # (equivalent to a one-hot output gradient); no graph is retained.
        output[0, target_class].backward()

        if self.features is None or self.gradients is None:
            raise RuntimeError(
                f"No activations/gradients were captured for layer "
                f"'{self.target_layer}'; check that it is used in the forward "
                f"pass, returns a single tensor and requires grad."
            )
        if self.features.dim() != 4:
            raise ValueError(
                f"Grad-CAM expects a 4-D [N, C, H, W] output from "
                f"'{self.target_layer}', got shape {tuple(self.features.shape)}"
            )

        gradients = self.gradients[0].float().cpu().numpy()
        features = self.features[0].float().cpu().numpy()

        # Channel weights = spatially averaged gradients (global average pooling).
        weights = gradients.mean(axis=(1, 2))
        # Weighted sum over channels: (C,) x (C, H, W) -> (H, W).
        cam = np.tensordot(weights, features, axes=1)
        cam = np.maximum(cam, 0)  # ReLU: keep only positive evidence

        # Normalise to [0, 1]; an all-zero map stays zero instead of 0/0 = NaN.
        cam = cam - cam.min()
        peak = cam.max()
        if peak > 0:
            cam = cam / peak
        return cam.astype(np.float32), int(target_class)

    def _generate(self, input_image, target_class=None):
        """Return ``(cam resized to the input's H x W, class index used)``."""
        cam, class_idx = self._compute_cam(input_image, target_class)
        h, w = input_image.shape[2:]
        return cv2.resize(cam, (w, h)), class_idx

    def generate_cam(self, input_image, target_class=None):
        """Generate a Grad-CAM heatmap.

        Args:
            input_image: input image (torch.Tensor, shape ``[1, C, H, W]``).
            target_class: target class index; ``None`` uses the class with the
                highest predicted score.

        Returns:
            cam: float32 numpy array of shape ``(H, W)`` with values in ``[0, 1]``.
        """
        return self._generate(input_image, target_class)[0]

    def overlay_cam(self, original_image, cam, alpha=0.5):
        """Blend a Grad-CAM heatmap onto the original image.

        Args:
            original_image: original image as an array of shape ``[H, W, C]``
                (RGB, as displayed by matplotlib) or ``[H, W]`` grayscale.
                Values above 1 are assumed to be in the 0-255 range.
            cam: Grad-CAM heatmap with values in ``[0, 1]``; it is resized to
                the image size if the shapes differ.
            alpha: weight of the heatmap in the blend (the image gets ``1 - alpha``).

        Returns:
            overlay_img: RGB float image with values clipped to ``[0, 1]``.
        """
        image = np.asarray(original_image)
        if image.max() > 1:
            image = image / 255.0
        if image.ndim == 2:
            # Grayscale: add a channel axis so it broadcasts against the RGB heatmap.
            image = image[..., np.newaxis]

        h, w = image.shape[:2]
        cam = np.asarray(cam, dtype=np.float32)
        if cam.shape[:2] != (h, w):
            cam = cv2.resize(cam, (w, h))

        heatmap = cv2.applyColorMap(np.uint8(255 * np.clip(cam, 0, 1)), cv2.COLORMAP_JET)
        # applyColorMap returns BGR, but the result is displayed/saved with
        # matplotlib, which expects RGB; without this, red/blue are swapped.
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        heatmap = np.float32(heatmap) / 255

        overlay_img = heatmap * alpha + image * (1 - alpha)
        return np.clip(overlay_img, 0, 1)

    def visualize(self, input_image, original_image, target_class=None, File_Name=None, model_name=None):
        """Render the original image next to its Grad-CAM overlay and save it as PNG.

        The figure is written to the path returned by
        ``Process_File.Make_Save_Root("<model_name> <File_Name>.png", model_dir)``,
        where ``model_dir`` is ``'../Result/Grad-CAM( <today> )'`` (created via
        ``Process_File.JudgeRoot_MakeDir`` if needed).

        Args:
            input_image: model input (torch.Tensor, shape ``[1, C, H, W]``).
            original_image: original image (numpy array) to display and overlay.
            target_class: class index to explain; ``None`` uses the predicted class
                (the title then shows the class actually used).
            File_Name: base name of the saved file; when ``None`` a time-based
                name is generated instead of failing.
            model_name: prefix of the saved file name.
        """
        File = Process_File()
        cam, class_idx = self._generate(input_image, target_class)
        overlay = self.overlay_cam(original_image, cam)

        fig = plt.figure(figsize=(10, 5))
        try:
            plt.subplot(1, 2, 1)
            plt.imshow(original_image)
            plt.title('Original Image')
            plt.axis('off')

            plt.subplot(1, 2, 2)
            plt.imshow(overlay)
            plt.title(f'Grad-CAM (Class {class_idx})')
            plt.axis('off')

            model_dir = '../Result/Grad-CAM( ' + str(datetime.date.today()) + " )"
            File.JudgeRoot_MakeDir(model_dir)
            if File_Name is None:
                File_Name = datetime.datetime.now().strftime("%H%M%S_%f")
            modelfiles = File.Make_Save_Root(str(model_name) + " " + str(File_Name) + ".png", model_dir)
            fig.savefig(modelfiles)
        finally:
            # Close only this figure (not the caller's), even if saving failed.
            plt.close(fig)