Files
Stomach_Cancer_Pytorch/draw_tools/Grad_cam.py
2025-11-07 21:03:13 +08:00

179 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import datetime
from Load_process.file_processing import Process_File
class GradCAM:
    """Grad-CAM heatmap generator for a PyTorch classification model.

    A forward hook and a backward hook on one target layer capture that
    layer's output activations and the gradients flowing into its output.
    ``generate`` turns them into per-sample heatmaps and ``overlay_heatmap``
    blends a heatmap onto the corresponding input image.

    The hooks stay registered on the model until ``remove_hooks()`` is
    called; the instance can also be used as a context manager to release
    them automatically.
    """

    def __init__(self, model, target_layer):
        """Attach Grad-CAM hooks to ``target_layer`` of ``model``.

        Args:
            model: Network to explain; may be wrapped in ``nn.DataParallel``.
                It is moved to CUDA when available, otherwise to the CPU.
            target_layer: Layer to inspect, given as an ``nn.Module``, an
                ``nn.Parameter`` owned by that layer, or a dotted attribute
                path such as ``'conv4.pointwise'``.

        Raises:
            AttributeError: If the parameter or attribute path is not found.
            TypeError: If ``target_layer`` has an unsupported type or does
                not resolve to an ``nn.Module``.
        """
        self.model = model
        # Unwrap DataParallel so layer lookups see the real backbone.
        self.backbone = model.module if isinstance(model, nn.DataParallel) else model
        self.target_layer = self._resolve_target_layer(target_layer)
        self.activations = None
        self.gradients = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

        # Keep the handles so the hooks can be detached again later.
        self._hook_handles = [
            self.target_layer.register_forward_hook(self.save_activations)
        ]
        # Prefer the full backward hook; the legacy one is deprecated.
        if hasattr(self.target_layer, "register_full_backward_hook"):
            backward_handle = self.target_layer.register_full_backward_hook(self.save_gradients)
        else:
            backward_handle = self.target_layer.register_backward_hook(self.save_gradients)
        self._hook_handles.append(backward_handle)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.remove_hooks()
        return False

    def remove_hooks(self):
        """Detach the hooks this instance registered on the target layer.

        Safe to call more than once.
        """
        for handle in getattr(self, "_hook_handles", []):
            handle.remove()
        self._hook_handles = []

    def _resolve_target_layer(self, target):
        """Return the ``nn.Module`` designated by ``target``.

        Accepts an ``nn.Module`` (returned as is), an ``nn.Parameter`` (the
        module that owns it is returned) or a dotted attribute path that is
        resolved against the backbone.
        """
        if isinstance(target, nn.Module):
            return target

        if isinstance(target, nn.Parameter):
            for name, param in self.backbone.named_parameters():
                if param is target:
                    # Strip the trailing '.weight' / '.bias'. The result is ''
                    # for a parameter owned directly by the backbone, which
                    # named_modules() maps to the backbone itself.
                    module_name = name.rpartition('.')[0]
                    return dict(self.backbone.named_modules())[module_name]
            raise AttributeError("Target parameter not found in model parameters.")

        if isinstance(target, str):
            obj = self.backbone
            for attr in target.split('.'):
                obj = getattr(obj, attr)
            if not isinstance(obj, nn.Module):
                # Fail here rather than later at hook registration.
                raise TypeError(
                    f"target_layer path '{target}' resolves to "
                    f"{type(obj).__name__}, not an nn.Module"
                )
            return obj

        raise TypeError("target_layer must be nn.Module, nn.Parameter, or str")

    def Processing_Main(self, Test_Dataloader, File_Path):
        """Run Grad-CAM over a dataloader and save one overlay per image.

        Args:
            Test_Dataloader: Iterable yielding
                ``(images, labels, File_Name, File_Classes)`` batches.
                ``labels`` is reduced with ``argmax`` over dim 1 to get the
                target class of each sample (presumably one-hot encoded;
                verify against the dataset).
            File_Path: Output root; each overlay is saved under
                ``{File_Path}/{class name}``.
        """
        File = Process_File()

        for batch_idx, (images, labels, File_Name, File_Classes) in enumerate(Test_Dataloader):
            images = images.to(self.device, dtype=torch.float32)
            labels = labels.to(self.device, dtype=torch.float32)

            # Explain the ground-truth class of every sample.
            label_classes = torch.argmax(labels, dim=1).cpu().numpy()
            heatmaps = self.generate(images, label_classes)

            for i in range(images.size(0)):
                overlaid_image = self.overlay_heatmap(heatmaps[i], images[i], alpha=0.5)

                path = f"{File_Path}/{File_Classes[i]}"
                File.JudgeRoot_MakeDir(path)
                File.Save_CV2_File(f"batch_{batch_idx}_{File_Name[i]}", path, overlaid_image)

    def save_activations(self, module, input, output):
        """Forward hook: keep a detached copy of the target layer's output."""
        self.activations = output.detach()

    def save_gradients(self, module, grad_input, grad_output):
        """Backward hook: keep the gradient w.r.t. the target layer's output."""
        self.gradients = grad_output[0].detach()

    def generate(self, input_images, class_indices=None):
        """Compute one Grad-CAM heatmap per sample of a batch.

        The model is switched to eval mode and left in that mode.

        Args:
            input_images: Batch tensor fed to the model (batch on dim 0). The
                caller's tensor is not modified.
            class_indices: Per-sample class index to explain. When ``None``,
                the model's own top prediction is used.

        Returns:
            ``np.ndarray`` stacking the per-sample heatmaps along axis 0.
        """
        self.model.eval()
        # Work on a detached leaf: this accepts tensors that already carry
        # autograd history and leaves the caller's tensor untouched.
        input_images = input_images.detach().requires_grad_(True)

        # Gradients are needed even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            outputs = self.model(input_images)

            if class_indices is None:
                class_indices = torch.argmax(outputs, dim=1).cpu().numpy()

            batch_size = input_images.size(0)
            heatmaps = []
            for i in range(batch_size):
                self.model.zero_grad()
                # The graph is shared by the whole batch, so keep it alive
                # until the last sample has been back-propagated.
                keep_graph = i < batch_size - 1
                outputs[i, class_indices[i]].backward(retain_graph=keep_graph)
                heatmaps.append(self._compute_heatmap(sample_index=i))

        return np.stack(heatmaps)

    def _compute_heatmap(self, sample_index):
        """Build the normalised, smoothed heatmap for one sample of the batch.

        Raises:
            RuntimeError: If the hooks have not captured activations and
                gradients yet.
        """
        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM hooks captured no activations/gradients; "
                "run a forward and backward pass through the target layer first."
            )

        gradients = self.gradients[sample_index]      # [C, H', W']
        activations = self.activations[sample_index]  # [C, H', W']

        # Channel weights: spatial average of the gradients.
        weights = torch.mean(gradients, dim=(1, 2), keepdim=True)  # [C, 1, 1]

        # Weighted sum over channels, keeping only positive evidence.
        grad_cam = F.relu(torch.sum(weights * activations, dim=0))  # [H', W']
        grad_cam = grad_cam / (grad_cam.max() + 1e-8)

        # Smooth to reduce blocky artifacts, then renormalise because the
        # blur lowers the peak value.
        grad_cam_np = grad_cam.detach().cpu().numpy()
        grad_cam_np = cv2.GaussianBlur(grad_cam_np, (5, 5), 0)
        grad_cam_np = grad_cam_np / (grad_cam_np.max() + 1e-8)
        return grad_cam_np

    def overlay_heatmap(self, heatmap, image, alpha=0.5):
        """Blend a viridis-coloured heatmap onto an image.

        Args:
            heatmap: 2-D array with values in [0, 1].
            image: Image tensor laid out as [C, H, W].
            alpha: Weight of the heatmap in the blend; the image gets
                ``1 - alpha``.

        Returns:
            ``uint8`` array of shape [H, W, 3]. The heatmap colours are in
            RGB order. NOTE(review): confirm that the saving code expects
            RGB rather than OpenCV's BGR.
        """
        height, width = image.shape[1], image.shape[2]

        heatmap_u8 = np.uint8(255 * heatmap)
        heatmap_u8 = cv2.resize(heatmap_u8, (width, height), interpolation=cv2.INTER_CUBIC)
        # Integer input indexes the colormap lookup table; drop the alpha channel.
        heatmap_rgb = plt.cm.viridis(heatmap_u8)[:, :, :3]

        image_np = image.detach().cpu().permute(1, 2, 0).numpy()  # [H, W, C]

        # Min-max rescale whenever the image leaves [0, 1]. This covers both
        # 0-255 inputs and standardized inputs with negative values, which
        # would otherwise just be clipped.
        lowest, highest = float(image_np.min()), float(image_np.max())
        if lowest < 0.0 or highest > 1.0:
            value_range = highest - lowest
            if value_range > 0:
                image_np = (image_np - lowest) / value_range
            else:
                # Constant image: avoid dividing by zero.
                image_np = np.zeros_like(image_np)

        overlay = alpha * heatmap_rgb + (1 - alpha) * image_np
        overlay = np.clip(overlay, 0, 1) * 255
        return overlay.astype(np.uint8)
def find_last_conv_layer(model):
    """Return the last ``nn.Conv2d`` yielded by ``model.modules()``.

    "Last" refers to the traversal order of ``model.modules()``, which is
    walked front to back here.

    Raises:
        RuntimeError: If the model contains no ``nn.Conv2d`` layer.
    """
    conv_layers = [layer for layer in model.modules() if isinstance(layer, nn.Conv2d)]
    if not conv_layers:
        raise RuntimeError("No nn.Conv2d layer found in the model for Grad-CAM.")
    return conv_layers[-1]
def run_grad_cam(model, dataloader, output_root):
    """Run Grad-CAM end to end on the last ``nn.Conv2d`` layer of ``model``.

    Args:
        model: Network to explain; moved to CUDA when available, otherwise
            to the CPU.
        dataloader: Iterable yielding
            ``(images, labels, file_names, file_classes)`` batches.
        output_root: Output root; each overlay is saved under
            ``{output_root}/{class name}``.

    Raises:
        RuntimeError: If the model contains no ``nn.Conv2d`` layer.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    target_layer = find_last_conv_layer(model)
    # GradCAM.Processing_Main already implements the per-batch
    # generate / overlay / save loop, so reuse it instead of duplicating it.
    GradCAM(model, target_layer).Processing_Main(dataloader, output_root)