import datetime

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

from Load_process.file_processing import Process_File


class GradCAM:
    """Grad-CAM heatmap generator for one layer of a classification model.

    A forward hook stores the target layer's output and a backward hook
    stores the gradient flowing into that output. The heatmap of an image
    is the ReLU of the gradient-weighted sum of the activation channels.
    """

    def __init__(self, model, target_layer):
        """Move ``model`` to the available device and hook ``target_layer``.

        Args:
            model: The network to explain. It is moved to CUDA when
                available, otherwise to CPU.
            target_layer: Sub-module of ``model`` whose output is used for
                the heatmap. The heatmap code pools over dims 2 and 3, so
                the layer must produce a 4-D tensor.
        """
        self.model = model
        self.target_layer = target_layer
        self.activations = None
        self.gradients = None

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)  # Ensure model is on the correct device

        # Keep the handles so the hooks can be detached via remove_hooks().
        # NOTE(review): register_backward_hook is deprecated in favour of
        # register_full_backward_hook. It is kept because only grad_output is
        # consumed here and the full hook may conflict with in-place ops that
        # follow the target layer -- confirm against the model before switching.
        self._hook_handles = [
            self.target_layer.register_forward_hook(self.save_activations),
            self.target_layer.register_backward_hook(self.save_gradients),
        ]

    def remove_hooks(self):
        """Detach the forward/backward hooks installed by ``__init__``."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def Processing_Main(self, Test_Dataloader, File_Path):
        """Write a Grad-CAM overlay for every image yielded by the loader.

        Args:
            Test_Dataloader: Iterable yielding
                ``(images, labels, File_Name, File_Classes)`` batches.
                ``labels`` is reduced with ``argmax(dim=1)``, so it is
                assumed to be one-hot / per-class scores -- TODO confirm.
            File_Path: Root output directory. Each overlay is stored under
                ``{File_Path}/{File_Classes[i]}``.
        """
        File = Process_File()

        for batch_idx, (images, labels, File_Name, File_Classes) in enumerate(Test_Dataloader):
            images = images.to(self.device, dtype=torch.float32)

            # Ground-truth class index of every image in the batch.
            label_classes = torch.argmax(labels, dim=1).cpu().numpy()

            # One heatmap per image, in batch order.
            heatmaps = self.generate(images, label_classes)

            for i, heatmap in enumerate(heatmaps):
                overlaid_image = self.overlay_heatmap(heatmap, images[i], alpha=0.5)

                # One sub-directory per class.
                path = f"{File_Path}/{File_Classes[i]}"
                File.JudgeRoot_MakeDir(path)
                File.Save_CV2_File(f"batch_{batch_idx}_{File_Name[i]}", path, overlaid_image)

    def save_activations(self, module, input, output):
        """Forward hook: keep a detached reference to the layer output."""
        self.activations = output.detach()

    def save_gradients(self, module, grad_input, grad_output):
        """Backward hook: keep the gradient w.r.t. the layer output."""
        self.gradients = grad_output[0].detach()

    def generate(self, input_images, class_indices=None):
        """Compute one Grad-CAM heatmap per image of a batch.

        Args:
            input_images: Batch tensor fed to the model. It is not
                modified; a detached copy on ``self.device`` is used.
            class_indices: Per-image class index to explain. When ``None``
                the model's own arg-max prediction is used.

        Returns:
            ``numpy.ndarray`` stacking the per-image heatmaps along axis 0;
            each heatmap has the spatial size of the target layer output.
        """
        self.model.eval()

        # Work on a detached leaf so the caller's tensor is left untouched and
        # non-leaf inputs do not raise when gradients are enabled.
        input_images = input_images.detach().to(self.device).requires_grad_(True)

        outputs = self.model(input_images)
        if class_indices is None:
            class_indices = torch.argmax(outputs, dim=1).cpu().numpy()

        batch_size = input_images.size(0)
        heatmaps = []
        for i in range(batch_size):
            self.model.zero_grad()
            # The graph is only needed again while more images remain.
            keep_graph = i < batch_size - 1
            outputs[i, int(class_indices[i])].backward(retain_graph=keep_graph)
            # The backward above only produces gradients for sample i, so the
            # heatmap has to be read from that same batch position.
            heatmaps.append(self._compute_heatmap(i))

        return np.stack(heatmaps)

    def _compute_heatmap(self, index=0):
        """Build the heatmap of one batch element from the hooked tensors.

        Args:
            index: Position in the batch of the image whose score was just
                back-propagated. Defaults to 0.

        Returns:
            2-D ``numpy.ndarray`` scaled so its maximum is about 1.

        Raises:
            RuntimeError: If the hooks have not captured activations and
                gradients yet.
        """
        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM hooks captured no data; run a forward and backward "
                "pass through the target layer first."
            )

        gradients = self.gradients[index]      # [C, H', W']
        activations = self.activations[index]  # [C, H', W']

        # Channel weights: global average pooling of the gradients.
        weights = gradients.mean(dim=(1, 2), keepdim=True)  # [C, 1, 1]

        grad_cam = F.relu((weights * activations).sum(dim=0))  # [H', W']
        grad_cam = grad_cam / (grad_cam.max() + 1e-8)

        # Gaussian smoothing to reduce artifacts, then re-normalize because
        # blurring lowers the peak value.
        grad_cam_np = cv2.GaussianBlur(grad_cam.cpu().numpy(), (5, 5), 0)
        return grad_cam_np / (grad_cam_np.max() + 1e-8)

    def overlay_heatmap(self, heatmap, image, alpha=0.5):
        """Blend a viridis-coloured heatmap over an image.

        Args:
            heatmap: 2-D array with values in [0, 1].
            image: ``[C, H, W]`` image tensor.
            alpha: Weight of the heatmap; the image gets ``1 - alpha``.

        Returns:
            ``uint8`` array of shape ``[H, W, 3]``.
        """
        height, width = image.shape[1], image.shape[2]

        # Scale to [0, 255] and resize to the image size (cv2 takes (w, h)).
        heatmap = np.uint8(255 * heatmap)
        heatmap = cv2.resize(heatmap, (width, height), interpolation=cv2.INTER_CUBIC)
        heatmap = plt.cm.viridis(heatmap)[:, :, :3]  # drop the alpha channel

        image_np = image.detach().cpu().permute(1, 2, 0).numpy()  # [H, W, C]

        # Min-max rescale whenever the image leaves [0, 1]; this also covers
        # mean/std-normalised inputs whose values are negative.
        lo, hi = image_np.min(), image_np.max()
        if lo < 0.0 or hi > 1.0:
            span = hi - lo
            if span > 0:
                image_np = (image_np - lo) / span
            else:
                image_np = np.zeros_like(image_np)

        # NOTE(review): the colormap output is RGB; if Save_CV2_File writes with
        # cv2.imwrite (BGR) the channels would be swapped -- confirm.
        overlay = alpha * heatmap + (1 - alpha) * image_np
        overlay = np.clip(overlay, 0, 1) * 255
        return overlay.astype(np.uint8)  # Return uint8 for cv2