# File-viewer header residue: 152 lines, 5.1 KiB, Python
import datetime

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

from Load_process.file_processing import Process_File


class GradCAM:
    """Grad-CAM (gradient-weighted class activation mapping) for a PyTorch model.

    A forward hook on the named target layer stores that layer's output (the
    feature maps) and registers a tensor hook that stores the gradient flowing
    back into it. The two are combined into a class-discriminative heatmap
    that can be overlaid on the original image.
    """

    def __init__(self, model, target_layer):
        """Initialise Grad-CAM.

        Args:
            model: Trained model (e.g. ModifiedXception). It is put in eval
                mode and moved to CUDA when available, otherwise to CPU.
            target_layer: Name of the sub-module to explain, as listed by
                ``model.named_modules()`` (e.g. ``'base_model'``).

        Raises:
            KeyError: If ``target_layer`` is not a sub-module name of ``model``.
        """
        self.model = model
        self.target_layer = target_layer
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.eval()
        self.model.to(self.device)

        # Feature maps and their gradients, filled in by the hooks.
        self.features = None
        self.gradients = None
        # Hook handles, kept so remove_hooks() can detach them again.
        self._hook_handles = []

        self._register_hooks()

    def _register_hooks(self):
        """Attach the activation/gradient hooks to the target layer.

        The gradient is captured with ``Tensor.register_hook`` on the layer's
        output rather than the deprecated ``Module.register_backward_hook``,
        which PyTorch documents as unreliable for modules that perform more
        than one operation.

        Raises:
            KeyError: If ``self.target_layer`` is not found in the model.
        """
        modules = dict(self.model.named_modules())
        if self.target_layer not in modules:
            raise KeyError(
                f"target_layer {self.target_layer!r} is not a sub-module of the model"
            )
        target_module = modules[self.target_layer]

        def save_gradient(grad):
            self.gradients = grad

        def forward_hook(module, inputs, output):
            self.features = output
            # A tensor produced without autograd cannot take a gradient hook.
            if isinstance(output, torch.Tensor) and output.requires_grad:
                output.register_hook(save_gradient)

        handle = target_module.register_forward_hook(forward_hook)
        self._hook_handles.append(handle)

    def remove_hooks(self):
        """Detach the hooks from the target layer; safe to call more than once."""
        for handle in getattr(self, '_hook_handles', []):
            handle.remove()
        self._hook_handles = []

    def _compute_cam(self, input_image, target_class=None):
        """Run the forward/backward pass and build the heatmap.

        Returns:
            (cam, target_class): ``cam`` is a float32 numpy array of shape
            (H, W) with values in [0, 1], where H, W are the spatial size of
            ``input_image``. ``target_class`` is the class index actually
            explained (the arg-max prediction when None was passed).

        Raises:
            RuntimeError: If no activation/gradient was captured for the layer.
            ValueError: If the target layer does not output (N, C, h, w) maps.
        """
        input_image = input_image.to(self.device)
        if input_image.dim() == 3:  # accept a single (C, H, W) image
            input_image = input_image.unsqueeze(0)
        # Detached copy that requires grad, so the graph reaches the target
        # layer even when the backbone is frozen; the caller's tensor is untouched.
        input_image = input_image.detach()
        if input_image.is_floating_point():
            input_image.requires_grad_(True)

        # Clear stale values so a missed hook is detected below.
        self.features = None
        self.gradients = None

        # Grad-CAM needs autograd even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            output = self.model(input_image)
            if target_class is None:
                target_class = torch.argmax(output, dim=1)[0].item()
            self.model.zero_grad()
            # Back-propagate the target class score of the first sample.
            output[0, target_class].backward()

        if not isinstance(self.features, torch.Tensor) or self.gradients is None:
            raise RuntimeError(
                f"No activations/gradients were captured for layer "
                f"{self.target_layer!r}; check that it takes part in the forward pass"
            )

        # (C, h, w) feature maps and their gradients for the first sample.
        features = self.features.detach()[0].float()
        gradients = self.gradients.detach()[0].float()
        if features.dim() != 3:
            raise ValueError(
                f"layer {self.target_layer!r} must output (N, C, h, w) feature maps, "
                f"got shape {tuple(self.features.shape)}"
            )

        # Global-average-pool the gradients: one importance weight per channel.
        weights = gradients.mean(dim=(1, 2))
        # Weighted sum of the feature maps; ReLU keeps only positive evidence.
        cam = torch.relu((weights[:, None, None] * features).sum(dim=0))

        # Normalise to [0, 1]; an all-zero map stays zero instead of 0/0 = NaN.
        cam = cam - cam.min()
        peak = cam.max()
        if peak > 0:
            cam = cam / peak

        # Upsample to the input size (bilinear, half-pixel like cv2 INTER_LINEAR).
        h, w = input_image.shape[-2:]
        cam = F.interpolate(cam[None, None], size=(h, w), mode='bilinear',
                            align_corners=False)[0, 0]
        return cam.cpu().numpy().astype(np.float32, copy=False), target_class

    def generate_cam(self, input_image, target_class=None):
        """Generate a Grad-CAM heatmap.

        Args:
            input_image: Input tensor of shape [1, C, H, W]; a single
                [C, H, W] image is also accepted.
            target_class: Class index to explain; if None, the class with the
                highest predicted score is used.

        Returns:
            cam: float32 numpy array of shape (H, W) with values in [0, 1]
            (all zeros when the layer shows no positive evidence).
        """
        cam, _ = self._compute_cam(input_image, target_class)
        return cam

    def overlay_cam(self, original_image, cam, alpha=0.5, use_rgb=True):
        """Blend a Grad-CAM heatmap onto the original image.

        Args:
            original_image: Original image (numpy array of shape [H, W, C] or
                [H, W], or anything ``np.asarray`` accepts, e.g. a PIL image).
                Values > 1 are treated as 0-255 and rescaled to [0, 1]. A single
                channel is repeated to 3 and an alpha channel is dropped.
            cam: Grad-CAM heatmap with values in [0, 1]; resized to the image
                size if it differs.
            alpha: Weight of the heatmap in the blend.
            use_rgb: Return the heatmap in RGB order (for matplotlib / PIL
                images). Pass False if ``original_image`` is BGR (OpenCV).

        Returns:
            overlay_img: float32 array of shape [H, W, 3] with values in [0, 1].
        """
        image = np.asarray(original_image, dtype=np.float32)
        if image.ndim == 2:
            image = image[..., None]
        if image.shape[-1] == 1:
            image = np.repeat(image, 3, axis=-1)
        elif image.shape[-1] == 4:
            image = image[..., :3]
        if image.max() > 1:
            image = image / 255.0

        # The CAM is sized like the model input, which may differ from the original.
        cam = np.asarray(cam, dtype=np.float32)
        if cam.shape != image.shape[:2]:
            cam = cv2.resize(cam, (image.shape[1], image.shape[0]))

        heatmap = cv2.applyColorMap(np.uint8(255 * np.clip(cam, 0, 1)), cv2.COLORMAP_JET)
        if use_rgb:
            # applyColorMap returns BGR; without this, matplotlib shows the jet map inverted.
            heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        heatmap = np.float32(heatmap) / 255

        overlay_img = heatmap * alpha + image * (1 - alpha)
        return np.clip(overlay_img, 0, 1)

    def visualize(self, input_image, original_image, target_class=None, File_Name=None, model_name=None):
        """Plot the original image next to its Grad-CAM overlay and save it as PNG.

        The figure is written to ``../Result/Grad-CAM( <today's date> )/``
        as ``"<model_name> <File_Name>.png"``, using the project's
        ``Process_File`` helpers to create the directory and build the path.

        Args:
            input_image: Model input tensor (see ``generate_cam``).
            original_image: Original image as a numpy array, shown as-is.
            target_class: Class index to explain; None uses the predicted class.
            File_Name: Base name of the saved file; "GradCAM" if None.
            model_name: Prefix for the saved file name.
        """
        File = Process_File()
        cam, target_class = self._compute_cam(input_image, target_class)
        overlay = self.overlay_cam(original_image, cam)

        plt.figure(figsize=(10, 5))
        try:
            plt.subplot(1, 2, 1)
            plt.imshow(original_image)
            plt.title('Original Image')
            plt.axis('off')

            plt.subplot(1, 2, 2)
            plt.imshow(overlay)
            plt.title(f'Grad-CAM (Class {target_class})')
            plt.axis('off')

            model_dir = '../Result/Grad-CAM( ' + str(datetime.date.today()) + " )"
            File.JudgeRoot_MakeDir(model_dir)
            file_name = File_Name if File_Name is not None else "GradCAM"
            modelfiles = File.Make_Save_Root(str(model_name) + " " + str(file_name) + ".png", model_dir)
            plt.savefig(modelfiles)
        finally:
            plt.close("all")  # always release figures, even if saving fails