179 lines
8.3 KiB
Python
179 lines
8.3 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
import numpy as np
|
||
import cv2
|
||
from PIL import Image
|
||
import matplotlib.pyplot as plt
|
||
import datetime
|
||
from Load_process.file_processing import Process_File
|
||
|
||
class GradCAM:
    """Grad-CAM heatmap generator bound to one layer of a PyTorch model.

    A forward hook stores the target layer's output and a backward hook stores
    the gradient with respect to that output. ``generate`` turns the pair into
    one normalised heatmap per image, and ``overlay_heatmap`` blends a heatmap
    onto its source image.
    """

    def __init__(self, model, target_layer):
        """Attach Grad-CAM hooks to ``target_layer`` of ``model``.

        Args:
            model: Network to explain; may be wrapped in ``nn.DataParallel``.
                It is moved to CUDA when available, otherwise to CPU.
            target_layer: Layer to hook. Accepted forms:
                - an ``nn.Module``;
                - one of that module's ``nn.Parameter`` objects;
                - a dotted attribute path relative to the unwrapped model,
                  e.g. ``'conv4.pointwise'``.

        Raises:
            TypeError: ``target_layer`` is not a Module, Parameter or str.
            AttributeError: The Parameter is not part of the model, or a path
                component does not exist.
        """
        self.model = model
        # When wrapped in DataParallel, resolve layers against the real backbone.
        self.backbone = model.module if isinstance(model, nn.DataParallel) else model
        self.target_layer = self._resolve_target_layer(target_layer)
        self.activations = None  # last forward output of target_layer, detached
        self.gradients = None  # last gradient w.r.t. that output, detached
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)  # keep the model on the same device as the inputs

        # NOTE(review): with multi-GPU nn.DataParallel the hooks presumably fire
        # once per replica. self.activations / self.gradients would then hold only
        # the slice of the batch from whichever replica finished last, and
        # per-sample indexing in _compute_heatmap would no longer line up.
        # Confirm that this class is only used on a single device.

        # Keep the handles so remove_hooks() can detach them. Otherwise every
        # GradCAM instance leaves its hooks on the model permanently.
        self._hook_handles = [self.target_layer.register_forward_hook(self.save_activations)]
        # Prefer the full backward hook to avoid deprecation issues. The legacy
        # register_backward_hook is only used when the full variant is missing.
        if hasattr(self.target_layer, "register_full_backward_hook"):
            handle = self.target_layer.register_full_backward_hook(self.save_gradients)
        else:
            handle = self.target_layer.register_backward_hook(self.save_gradients)
        self._hook_handles.append(handle)

    def remove_hooks(self):
        """Detach the hooks installed by ``__init__``. Safe to call repeatedly."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def _resolve_target_layer(self, target):
        """Return the ``nn.Module`` that ``target`` refers to.

        Resolution rules:
            - An ``nn.Module`` is returned unchanged.
            - An ``nn.Parameter`` is mapped to the module inside
              ``self.backbone`` that owns it.
            - A string is treated as a dotted attribute path walked from
              ``self.backbone``.

        Raises:
            AttributeError: The Parameter does not belong to ``self.backbone``,
                or a path component is missing.
            TypeError: ``target`` is of any other type.
        """
        if isinstance(target, nn.Module):
            return target

        if isinstance(target, nn.Parameter):
            modules = dict(self.backbone.named_modules())
            for name, param in self.backbone.named_parameters():
                if param is not target:
                    continue
                # Drop the trailing '.weight' / '.bias'. rpartition yields '' for
                # a parameter owned directly by the backbone, and '' is the name
                # named_modules() gives the backbone itself.
                module_name = name.rpartition('.')[0]
                if module_name in modules:
                    return modules[module_name]
                # Fallback: walk the attribute path explicitly.
                obj = self.backbone
                for attr in module_name.split('.'):
                    obj = getattr(obj, attr)
                return obj
            raise AttributeError("Target parameter not found in model parameters.")

        if isinstance(target, str):
            # Dotted path such as 'conv4.pointwise'. Numeric components like
            # 'features.3' also work, because submodules are reachable via getattr.
            obj = self.backbone
            for attr in target.split('.'):
                obj = getattr(obj, attr)
            return obj

        raise TypeError("target_layer must be nn.Module, nn.Parameter, or str")

    def Processing_Main(self, Test_Dataloader, File_Path):
        """Generate and save a Grad-CAM overlay for every image in a dataloader.

        Args:
            Test_Dataloader: Yields ``(images, labels, File_Name, File_Classes)``
                batches. The target class of each image is
                ``argmax(labels, dim=1)``. The labels are presumably one-hot;
                confirm against the dataset.
            File_Path: Root directory. Each overlay is saved through
                ``Process_File.Save_CV2_File`` as ``batch_{batch_idx}_{File_Name[i]}``
                inside ``{File_Path}/{File_Classes[i]}``.
        """
        File = Process_File()
        for batch_idx, (images, labels, File_Name, File_Classes) in enumerate(Test_Dataloader):
            images = images.to(self.device, dtype=torch.float32)  # expected [B, C, H, W]
            labels = labels.to(self.device, dtype=torch.float32)  # expected [B, num_classes]
            label_classes = torch.argmax(labels, dim=1).cpu().numpy()  # [B]

            # One heatmap per image, explaining its ground-truth class.
            heatmaps = self.generate(images, label_classes)

            for i in range(images.size(0)):
                overlaid_image = self.overlay_heatmap(heatmaps[i], images[i], alpha=0.5)
                path = f"{File_Path}/{File_Classes[i]}"
                File.JudgeRoot_MakeDir(path)
                File.Save_CV2_File(f"batch_{batch_idx}_{File_Name[i]}", path, overlaid_image)

    def save_activations(self, module, input, output):
        """Forward hook: keep a detached copy of the target layer's output."""
        self.activations = output.detach()

    def save_gradients(self, module, grad_input, grad_output):
        """Backward hook: keep the gradient w.r.t. the target layer's output."""
        self.gradients = grad_output[0].detach()

    def generate(self, input_images, class_indices=None):
        """Compute one Grad-CAM heatmap per image in the batch.

        Args:
            input_images: Batch tensor, presumably ``[B, C, H, W]`` and already on
                ``self.device``. The callers in this file move it there first.
                The caller's tensor is not modified.
            class_indices: Per-image target class indices (anything indexable by
                an int). ``None`` uses the model's argmax prediction.

        Returns:
            ``np.ndarray`` of shape ``[B, H', W']`` with values in [0, 1]. Here
            ``H' x W'`` is the spatial size of the target layer's output.
        """
        self.model.eval()
        # Work on a detached leaf copy. Setting requires_grad on the caller's
        # tensor would mutate it, and raises outright for non-leaf tensors.
        # Requiring grad on the input keeps a gradient path to the target layer
        # even when all model parameters are frozen.
        input_images = input_images.detach().requires_grad_(True)

        # Backprop must work even when the caller runs under torch.no_grad().
        with torch.enable_grad():
            outputs = self.model(input_images)  # [B, num_classes]
            if class_indices is None:
                class_indices = torch.argmax(outputs, dim=1).cpu().numpy()

            heatmaps = []
            batch_size = input_images.size(0)
            for i in range(batch_size):
                self.model.zero_grad()
                # Back-propagate only this image's target score. Keep the graph
                # alive for the remaining samples, but free it after the last one.
                outputs[i, class_indices[i]].backward(retain_graph=i < batch_size - 1)
                heatmaps.append(self._compute_heatmap(sample_index=i))

        return np.stack(heatmaps)  # [B, H', W']

    def _compute_heatmap(self, sample_index):
        """Build the Grad-CAM map for one sample from the stored hook tensors.

        Steps:
            1. Average the gradients over the spatial axes to get one weight per
               channel.
            2. Take the weighted sum of the activations over channels.
            3. Apply ReLU and scale so the maximum is 1.
            4. Apply a 5x5 Gaussian blur and rescale so the maximum is 1 again.
        """
        gradients = self.gradients[sample_index]  # [C, H', W']
        activations = self.activations[sample_index]  # [C, H', W']

        # Channel weights: global average pooling of the gradients.
        weights = torch.mean(gradients, dim=(1, 2), keepdim=True)  # [C, 1, 1]
        grad_cam = F.relu(torch.sum(weights * activations, dim=0))  # [H', W']
        grad_cam = grad_cam / (grad_cam.max() + 1e-8)

        # Smooth out blocky artifacts. Cast to float32 first, because half
        # precision tensors (e.g. under autocast) are not accepted by every cv2 op.
        grad_cam_np = grad_cam.detach().float().cpu().numpy()
        grad_cam_np = cv2.GaussianBlur(grad_cam_np, (5, 5), 0)
        return grad_cam_np / (grad_cam_np.max() + 1e-8)

    def overlay_heatmap(self, heatmap, image, alpha=0.5):
        """Blend a colour-mapped Grad-CAM heatmap over its source image.

        Args:
            heatmap: 2-D array with values in [0, 1], e.g. one entry of ``generate``.
            image: Image tensor in channels-first ``[C, H, W]`` layout. It is
                permuted with ``(1, 2, 0)`` below.
            alpha: Weight of the heatmap; the image gets ``1 - alpha``.

        Returns:
            ``uint8`` array of shape ``[H, W, 3]``.

        NOTE(review): the heatmap colours come from matplotlib and are RGB. If
        ``Save_CV2_File`` writes with ``cv2.imwrite``, which expects BGR, red and
        blue will appear swapped. Confirm, and convert with ``cv2.cvtColor`` if so.
        """
        # Resize to the image's spatial size. Bicubic interpolation gives
        # smoother blobs than the coarse feature-map grid.
        heatmap_u8 = np.uint8(255 * heatmap)
        heatmap_u8 = cv2.resize(heatmap_u8, (image.shape[2], image.shape[1]), interpolation=cv2.INTER_CUBIC)
        # Integer input indexes viridis' 256-entry lookup table directly.
        colored = plt.cm.viridis(heatmap_u8)[:, :, :3]

        image_np = image.detach().cpu().permute(1, 2, 0).numpy()  # [H, W, C]
        # Min-max rescale into [0, 1] for blending. This is not a denormalisation.
        # Standardised inputs can be negative even when max <= 1, so rescale
        # whenever either bound is exceeded. Guard against a constant image.
        lo, hi = float(image_np.min()), float(image_np.max())
        if lo < 0.0 or hi > 1.0:
            image_np = (image_np - lo) / max(hi - lo, 1e-8)

        overlay = alpha * colored + (1 - alpha) * image_np
        overlay = np.clip(overlay, 0, 1) * 255
        return overlay.astype(np.uint8)
|
||
|
||
def find_last_conv_layer(model):
    """Return the last ``nn.Conv2d`` in the order ``model.modules()`` yields.

    That order follows how submodules were registered, which is not
    necessarily the order in which ``forward`` executes them.

    Raises:
        RuntimeError: If ``model`` contains no ``nn.Conv2d``.
    """
    # Walk the modules back to front and stop at the first convolution found.
    for module in reversed(list(model.modules())):
        if isinstance(module, nn.Conv2d):
            return module
    raise RuntimeError("No nn.Conv2d layer found in the model for Grad-CAM.")
|
||
|
||
def run_grad_cam(model, dataloader, output_root):
    """Run Grad-CAM end to end and save one overlay image per input image.

    Hooks the last ``nn.Conv2d`` reported by ``find_last_conv_layer``, then
    hands the batch loop to ``GradCAM.Processing_Main``. That loop does the
    following:
        - moves images and labels to CUDA when available;
        - uses ``argmax(labels, dim=1)`` as the target class;
        - saves each overlay as ``batch_{batch_idx}_{file_name}`` inside
          ``{output_root}/{file_class}``.

    Args:
        model: Network to explain. ``GradCAM`` moves it to the selected device.
        dataloader: Yields ``(images, labels, file_names, file_classes)`` batches.
        output_root: Root directory for the per-class output folders.
    """
    target_layer = find_last_conv_layer(model)
    grad = GradCAM(model, target_layer)
    # Processing_Main performs exactly the per-batch loop this wrapper needs.
    # Reusing it keeps the two entry points from drifting apart.
    grad.Processing_Main(dataloader, output_root)