# NOTE: duplicated extraction metadata ("299 lines / 13 KiB / Python") removed.
import torch
|
|
import numpy as np
|
|
import cv2
|
|
import os
|
|
import torchvision
|
|
from Load_process.file_processing import Process_File
|
|
from pytorch_grad_cam import GradCAM
|
|
from pytorch_grad_cam.utils.image import show_cam_on_image
|
|
|
|
# Global buffer that accumulates per-layer attention matrices (Attention Rollout).
_attention_weights_store = []

# Global switch: weights are only recorded while this is True, so ordinary
# training/inference passes do not leak memory through the buffer.
_record_attention = False


def _hook_encoder_block_forward(self, input: torch.Tensor):
    """
    Replacement for torchvision.models.vision_transformer.EncoderBlock.forward
    that additionally captures the self-attention weight matrix.

    Mirrors the upstream implementation exactly (pre-norm attention block with
    residual connections), except that need_weights=True is forced so the
    attention matrix can be stored for Attention Rollout.
    """
    torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")

    normed = self.ln_1(input)
    # need_weights=True makes MultiheadAttention also return the attention matrix.
    attn_out, attn_weights = self.self_attention(normed, normed, normed, need_weights=True)

    # Record only while explicitly enabled, detached and on CPU, so that
    # training loops never accumulate GPU tensors here.
    # NOTE(review): with torch's default average_attn_weights=True the stored
    # shape is (batch, seq, seq), i.e. already head-averaged — confirm whether
    # per-head (batch, heads, seq, seq) weights were actually intended.
    if _record_attention and attn_weights is not None:
        _attention_weights_store.append(attn_weights.detach().cpu())

    # First residual branch: attention output (after dropout) + input.
    residual = input + self.dropout(attn_out)

    # Second residual branch: MLP over the normalized intermediate.
    mlp_out = self.mlp(self.ln_2(residual))
    return residual + mlp_out
|
class AttentionRollout:
    def __init__(self, model, attention_layer_name='encoder.ln', use_cuda=True):
        """
        Wrap *model* with a GradCAM instance targeting a ViT layer.

        Args:
            model: the (possibly DataParallel-wrapped) network to explain.
            attention_layer_name: kept for interface compatibility; the target
                layer is currently discovered structurally, not by this name.
            use_cuda: stored for callers; the actual device is taken from the
                model's parameters.
        """
        self.model = model
        self.use_cuda = use_cuda

        # Derive the device from the model itself; a parameter-less model
        # falls back to CPU.
        try:
            self.device = next(model.parameters()).device
        except StopIteration:
            self.device = torch.device("cpu")

        # Install the weight-capturing forward on torchvision's EncoderBlock.
        self._apply_monkey_patch()

        # Unwrap (Distributed)DataParallel so attribute lookups reach the real model.
        if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
            core = model.module
        else:
            core = model

        def locate_layer(candidate):
            """Structural search for a suitable GradCAM target inside a ViT."""
            # Fusion models (e.g. EfficientNetViTFusion) keep the ViT under '.vit'.
            if hasattr(candidate, 'vit'):
                candidate = candidate.vit

            # Preferred target: the encoder's final LayerNorm ("Norm layer").
            if hasattr(candidate, 'encoder') and hasattr(candidate.encoder, 'ln'):
                return candidate.encoder.ln

            # Otherwise target the last encoder block. We prefer its 'dropout'
            # module over 'self_attention' because MultiheadAttention returns a
            # (output, weights) tuple that breaks GradCAM's gradient handling,
            # while the dropout (even with p=0) yields a single tensor.
            if hasattr(candidate, 'encoder') and hasattr(candidate.encoder, 'layers'):
                final_block = candidate.encoder.layers[-1]
                if hasattr(final_block, 'dropout'):
                    return final_block.dropout
                if hasattr(final_block, 'self_attention'):
                    return final_block.self_attention

            return None

        # Search the unwrapped model first, then a common '.base_model' wrapper.
        chosen = locate_layer(core)
        if chosen is None and hasattr(core, 'base_model'):
            chosen = locate_layer(core.base_model)

        if chosen is None:
            # Unknown architecture: fall back to the last parameterized module
            # (iterating the unwrapped model avoids DataParallel containers).
            print("Warning: Could not automatically find 'encoder.ln'. Using the last module.")
            parameterized = [m for m in core.modules() if len(list(m.parameters())) > 0]
            if parameterized:
                chosen = parameterized[-1]
            else:
                chosen = list(core.modules())[-1]

        self.target_layers = [chosen]

        # reshape_transform is essential for ViT: it converts the 1D token
        # sequence back into a 2D spatial map for GradCAM.
        self.cam = GradCAM(model=self.model, target_layers=self.target_layers, reshape_transform=self.reshape_transform)
|
def _apply_monkey_patch(self):
|
|
# Patch EncoderBlock to capture weights
|
|
torchvision.models.vision_transformer.EncoderBlock.forward = _hook_encoder_block_forward
|
|
|
|
def compute_rollout(self, attention_weights_list):
|
|
# attention_weights_list: List of tensors [(B, H, N, N), ...]
|
|
# We process batch items individually or vectorized.
|
|
# Assuming all items in batch have same attention flow structure.
|
|
|
|
if not attention_weights_list:
|
|
return None
|
|
|
|
# Stack layers: (L, B, H, N, N)
|
|
# We only need the last rollout, but we need to compute recursively.
|
|
|
|
# Take the first batch item for visualization (since we process batch in loop usually)
|
|
# Or handle full batch. Let's handle full batch.
|
|
|
|
# (L, B, H, N, N)
|
|
# Check dimensions first
|
|
if len(attention_weights_list[0].shape) == 3:
|
|
# Missing Batch dimension? Or missing Head dimension?
|
|
# Usually (B, N, N) if weights averaged?
|
|
# Or (H, N, N) if batch=1 squeezed?
|
|
# MultiheadAttention output weights are (B, N, N) if average_attn_weights=True (default in some versions)
|
|
# OR (B, H, N, N) if average_attn_weights=False
|
|
pass
|
|
|
|
all_layers = torch.stack(attention_weights_list, dim=0)
|
|
|
|
if len(all_layers.shape) == 5:
|
|
L, B, H, N, _ = all_layers.shape
|
|
# Mean across heads: (L, B, N, N)
|
|
all_layers = all_layers.mean(dim=2)
|
|
elif len(all_layers.shape) == 4:
|
|
# Assumed (L, B, N, N) -> Already averaged or single head
|
|
L, B, N, _ = all_layers.shape
|
|
else:
|
|
print(f"Unexpected attention weight shape: {all_layers.shape}")
|
|
return None
|
|
|
|
# Initialize identity matrix for Rollout accumulation: (B, N, N)
|
|
# identity = torch.eye(N).to(all_layers.device).unsqueeze(0).repeat(B, 1, 1)
|
|
# result = identity
|
|
|
|
# Better: Start with Identity
|
|
result = torch.eye(N).to(all_layers.device).unsqueeze(0).repeat(B, 1, 1)
|
|
|
|
with torch.no_grad():
|
|
for i in range(L):
|
|
# Attention matrix at layer i
|
|
attn = all_layers[i] # (B, N, N)
|
|
|
|
# Add residual connection (identity) and re-normalize
|
|
# 0.5 * A + 0.5 * I
|
|
I = torch.eye(N).to(attn.device).unsqueeze(0)
|
|
a = (attn + I) / 2
|
|
a = a / a.sum(dim=-1, keepdim=True)
|
|
|
|
# Recursive multiplication
|
|
result = torch.matmul(a, result)
|
|
|
|
return result
|
|
|
|
def reshape_transform(self, tensor):
|
|
# Handle tuple output (e.g. from MultiheadAttention: (output, weights))
|
|
if isinstance(tensor, tuple):
|
|
tensor = tensor[0]
|
|
|
|
# ViT output is (B, N, C) where N = 1 (CLS) + H*W
|
|
# tensor is the output of the target layer.
|
|
|
|
# Handle 2D tensor case (fallback)
|
|
if len(tensor.shape) == 2:
|
|
return tensor.unsqueeze(-1).unsqueeze(-1)
|
|
|
|
# Discard CLS token (index 0)
|
|
# Assumes CLS token is at index 0.
|
|
# For some models it might be different, but standard ViT is 0.
|
|
result = tensor[:, 1:, :]
|
|
|
|
# (B, 196, 768) -> (B, 768, 196)
|
|
result = result.transpose(1, 2)
|
|
|
|
# Reshape to (B, C, H, W)
|
|
n_patches = result.size(2)
|
|
h = int(np.sqrt(n_patches))
|
|
w = h
|
|
result = result.reshape(tensor.size(0), result.size(1), h, w)
|
|
return result
|
|
|
|
    def Processing_Main(self, Test_Dataloader, File_Path):
        """
        Run Grad-CAM over every batch in Test_Dataloader and save one heatmap
        overlay image per input sample under File_Path.

        Args:
            Test_Dataloader: yields (images, labels, File_Name, File_Classes)
                tuples; File_Name/File_Classes are per-sample sequences used to
                build output paths. labels is unused here.
            File_Path: root directory for the saved visualizations.

        Side effects: toggles the module-level attention-recording flag,
        clears the global weight store, and writes image files to disk.
        """
        import gc
        File = Process_File()
        self.model.eval()

        global _attention_weights_store
        global _record_attention

        # Enable recording of attention weights while this method runs;
        # the outer finally below guarantees it is switched off again.
        _record_attention = True

        try:
            for batch_idx, (images, labels, File_Name, File_Classes) in enumerate(Test_Dataloader):
                images = images.to(self.device)

                # Clear stored weights before the forward pass so the store
                # only ever holds the current batch's layers.
                _attention_weights_store.clear()

                try:
                    # 1. Run Grad-CAM (also runs the forward pass and triggers
                    #    the monkey-patched hooks).
                    # targets=None automatically targets the highest scoring category.
                    grayscale_cam = self.cam(input_tensor=images, targets=None)

                    # 2. Attention-Rollout branch is intentionally disabled:
                    #    the Norm-layer Grad-CAM below is used instead.
                    rollout_masks = None
                    # if _attention_weights_store:
                    #     # (B, N, N)
                    #     full_rollout = self.compute_rollout(_attention_weights_store)
                    #     if full_rollout is not None:
                    #         # Extract CLS attention to other tokens: (B, N) -> take [:, 0, 1:]
                    #         # (B, N, N) -> row 0 is CLS attention to all.
                    #         cls_attn = full_rollout[:, 0, 1:]  # (B, N-1)
                    #
                    #         # Normalize
                    #         cls_attn = cls_attn / (cls_attn.max(dim=1, keepdim=True)[0] + 1e-8)
                    #         rollout_masks = cls_attn.cpu().numpy()

                    for i in range(images.size(0)):
                        # Option A: Use Grad-CAM (Feature-based)
                        mask_cam = grayscale_cam[i, :]

                        # Option B: Use Attention Rollout (Weight-based)
                        # User request: Switch from MLP (Rollout/Attn) to Norm layer (Grad-CAM)
                        # We disable the Rollout preference to use the Norm-based Grad-CAM.
                        mask = mask_cam

                        # if rollout_masks is not None:
                        #     mask_rollout_flat = rollout_masks[i]
                        #     # Reshape to square
                        #     side = int(np.sqrt(mask_rollout_flat.shape[0]))
                        #     mask_rollout = mask_rollout_flat.reshape(side, side)
                        #     # Resize to image size
                        #     mask_rollout = cv2.resize(mask_rollout, (images.size(3), images.size(2)))
                        #     mask = mask_rollout  # Prefer Rollout
                        # else:
                        #     mask = mask_cam  # Fallback

                        # Original image for visualization: (C, H, W) -> (H, W, C)
                        img_tensor = images[i].detach().cpu()
                        img_np = img_tensor.permute(1, 2, 0).numpy()

                        # show_cam_on_image expects floats in [0, 1]; min-max
                        # rescale (crude de-normalization) when out of range.
                        if img_np.max() > 1.0 or img_np.min() < 0.0:
                            img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)

                        visualization = show_cam_on_image(img_np, mask, use_rgb=True)

                        # Build the class-specific output directory and ensure it exists.
                        Save_File_Root = File.Make_Save_Root(FileName=File_Classes[i], File_root=File_Path)
                        File.JudgeRoot_MakeDir(Save_File_Root)

                        # Keep the original file name, defaulting to .png when
                        # the extension is missing or unrecognized.
                        file_name = os.path.basename(File_Name[i])
                        if not file_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                            file_name += ".png"

                        # Convert RGB to BGR for OpenCV saving.
                        visualization_bgr = cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR)
                        File.Save_CV2_File(file_name, Save_File_Root, visualization_bgr)

                except RuntimeError as e:
                    # Best-effort OOM recovery: skip this batch; anything else re-raises.
                    if "out of memory" in str(e):
                        print(f"WARNING: Out of Memory in Grad-CAM batch {batch_idx}. Skipping this batch.")
                        torch.cuda.empty_cache()
                    else:
                        raise e
                finally:
                    # Per-batch cleanup: drop large tensors (guarded, since an
                    # exception may have fired before they were bound), empty
                    # the weight store, and release cached GPU memory.
                    if 'grayscale_cam' in locals():
                        del grayscale_cam
                    if 'images' in locals():
                        del images
                    _attention_weights_store.clear()  # Clear for next batch
                    torch.cuda.empty_cache()
                    gc.collect()
        finally:
            # Disable recording so subsequent training steps do not leak
            # attention tensors into the global store.
            _record_attention = False
            _attention_weights_store.clear()