104 lines
4.4 KiB
Python
104 lines
4.4 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
import cv2
|
|
from PIL import Image
|
|
import matplotlib.pyplot as plt
|
|
import datetime
|
|
from Load_process.file_processing import Process_File
|
|
|
|
class GradCAM:
    """Grad-CAM heatmap generator for a model and one of its layers.

    A forward hook stores the target layer's output and a backward hook
    stores the gradient of the selected class score with respect to that
    output. The two are combined into a per-image heatmap normalised to
    [0, 1].
    """

    def __init__(self, model, target_layer):
        """Move ``model`` to the available device and hook ``target_layer``.

        Args:
            model: The ``torch.nn.Module`` to explain.
            target_layer: A sub-module of ``model`` whose output feature maps
                are used to build the heatmaps.
        """
        self.model = model
        self.target_layer = target_layer
        self.activations = None
        self.gradients = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)  # Ensure model is on the correct device

        # ``register_backward_hook`` is deprecated; prefer the "full" variant
        # and only fall back to the legacy hook on old torch versions.
        register_backward = getattr(
            self.target_layer,
            "register_full_backward_hook",
            self.target_layer.register_backward_hook,
        )
        # Keep the handles so the hooks can be detached via ``remove_hooks``.
        self._hook_handles = [
            self.target_layer.register_forward_hook(self.save_activations),
            register_backward(self.save_gradients),
        ]

    def remove_hooks(self):
        """Detach the hooks this instance registered on the target layer."""
        for handle in getattr(self, "_hook_handles", []):
            handle.remove()
        self._hook_handles = []

    def Processing_Main(self, Test_Dataloader, File_Path):
        """Write a Grad-CAM overlay image for every sample of a dataloader.

        Overlays are saved as
        ``{File_Path}/class_{label}/batch_{batch_idx}_img_{i}.png`` through
        the project's ``Process_File`` helper.

        Args:
            Test_Dataloader: Iterable yielding ``(images, labels)`` batches.
                ``labels`` may be 2-D (one-hot / per-class scores, reduced
                with ``argmax`` over dim 1) or 1-D class indices.
            File_Path: Root output directory.
        """
        File = Process_File()
        for batch_idx, (images, labels) in enumerate(Test_Dataloader):
            # Move data to device
            images = images.to(self.device, dtype=torch.float32)

            # Get ground-truth class indices
            labels = torch.as_tensor(labels)
            if labels.dim() > 1:
                labels = torch.argmax(labels, dim=1)
            label_classes = labels.detach().cpu().numpy().astype(np.int64)

            # Generate Grad-CAM heatmaps for the entire batch
            heatmaps = self.generate(images, label_classes)

            # Process each image in the batch
            for i, (class_idx, heatmap) in enumerate(zip(label_classes, heatmaps)):
                overlaid_image = self.overlay_heatmap(heatmap, images[i])

                # Create file path based on class
                path = f"{File_Path}/class_{class_idx}"
                File.JudgeRoot_MakeDir(path)
                File.Save_CV2_File(f"batch_{batch_idx}_img_{i}.png", path, overlaid_image)

    def save_activations(self, module, input, output):
        """Forward hook: remember the target layer's output (detached)."""
        self.activations = output.detach()

    def save_gradients(self, module, grad_input, grad_output):
        """Backward hook: remember the gradient w.r.t. the layer's output."""
        self.gradients = grad_output[0].detach()

    def generate(self, input_images, class_indices=None):
        """Compute one Grad-CAM heatmap per image of a batch.

        Args:
            input_images: Batch tensor accepted by the model. The caller's
                tensor is not modified; a detached copy/view is used.
            class_indices: Optional sequence with one class index per image.
                When ``None`` the model's predicted class (``argmax`` over
                dim 1 of the output) is used.

        Returns:
            ``numpy.ndarray`` of stacked heatmaps, one per image, each with
            the spatial size of the target layer's output and values in
            [0, 1].
        """
        self.model.eval()

        # Work on a detached leaf so the caller's tensor keeps its own
        # ``requires_grad`` flag and non-leaf inputs do not raise.
        input_images = input_images.detach().to(self.device).requires_grad_(True)

        # Forward pass; grad mode is forced on in case the caller is inside
        # a ``torch.no_grad()`` block.
        with torch.enable_grad():
            outputs = self.model(input_images)

        if class_indices is None:
            class_indices = torch.argmax(outputs, dim=1).cpu().numpy()

        # Backward pass for each image in the batch
        batch_size = input_images.size(0)
        heatmaps = []
        for i in range(batch_size):
            self.model.zero_grad()
            # The graph is shared by all samples, so keep it alive until the
            # last backward call.
            keep_graph = i < batch_size - 1
            outputs[i, int(class_indices[i])].backward(retain_graph=keep_graph)
            # The hooks store whole-batch tensors; pick this sample's slice.
            heatmaps.append(self._compute_heatmap(i))

        return np.stack(heatmaps)

    def _compute_heatmap(self, index=0):
        """Build the heatmap of one sample from the stored hook tensors.

        Args:
            index: Position of the sample inside the hooked batch tensors.
                Defaults to 0 (the only sample the previous implementation
                ever read).

        Returns:
            2-D ``numpy.ndarray`` normalised to [0, 1].

        Raises:
            RuntimeError: If no forward/backward pass has populated the hooks.
        """
        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                "Grad-CAM hooks have no data; run a forward and backward pass first."
            )

        gradients = self.gradients[index]      # [C, H', W']
        activations = self.activations[index]  # [C, H', W']

        # Compute weights (global average pooling of gradients)
        weights = gradients.mean(dim=(1, 2), keepdim=True)  # [C, 1, 1]

        # Weighted sum over channels, keep only positive evidence
        grad_cam = F.relu((weights * activations).sum(dim=0))  # [H', W']
        grad_cam = grad_cam / (grad_cam.max() + 1e-8)  # Normalize to [0, 1]
        return grad_cam.cpu().numpy()

    def overlay_heatmap(self, heatmap, image, alpha=0.5):
        """Blend a jet-coloured heatmap over an image.

        Args:
            heatmap: 2-D array with values in [0, 1].
            image: Image tensor laid out as [C, H, W]. Pixel values are
                divided by 255 before blending, i.e. they are treated as
                being in the 0-255 range.
                NOTE(review): confirm this matches the dataloader; no
                mean/std de-normalisation is applied here.
            alpha: Weight of the heatmap in the blend; the image gets
                ``1 - alpha``.

        Returns:
            ``uint8`` array of shape [H, W, 3]. The colormap channels are in
            RGB order.
            NOTE(review): the array is later handed to ``Save_CV2_File``;
            verify whether that helper expects BGR.
        """
        height, width = image.shape[1], image.shape[2]

        # Resize heatmap to match input image spatial dimensions.
        # PIL's ``resize`` takes (width, height), not (height, width).
        heatmap = np.uint8(255 * heatmap)  # Scale to [0, 255]
        heatmap = Image.fromarray(heatmap).resize((width, height), Image.BILINEAR)
        heatmap = np.array(heatmap)
        heatmap = plt.cm.jet(heatmap)[:, :, :3]  # Apply colormap (jet), drop alpha

        # Convert image tensor to a channel-last numpy array
        image_np = image.detach().cpu().permute(1, 2, 0).numpy()  # [H, W, C]

        # Overlay
        overlay = alpha * heatmap + (1 - alpha) * image_np / 255.0
        overlay = np.clip(overlay, 0, 1) * 255
        return overlay.astype(np.uint8)  # Return uint8 for cv2