# Stomach_Cancer_Pytorch/draw_tools/Attention_Rollout.py
# Exported 2026-03-05 16:37:39 +08:00 (299 lines, 13 KiB, Python)
import torch
import numpy as np
import cv2
import os
import torchvision
from Load_process.file_processing import Process_File
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
# Module-level buffer that accumulates the attention matrices captured by the
# patched EncoderBlock.forward (consumed by the Attention Rollout computation).
_attention_weights_store = []
# Gate flag: attention matrices are appended to the store only while this is
# True, so ordinary training/inference passes do not leak memory.
_record_attention = False


def _hook_encoder_block_forward(self, input: torch.Tensor):
    """
    Monkey-patch replacement for
    torchvision.models.vision_transformer.EncoderBlock.forward that also
    captures the self-attention weight matrices.

    Behaviour is identical to the stock forward (pre-norm attention with a
    residual, then pre-norm MLP with a residual), except that
    ``need_weights=True`` is forced so the attention map is available and,
    when ``_record_attention`` is set, a detached CPU copy of the weights is
    appended to ``_attention_weights_store``.
    """
    torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
    normed = self.ln_1(input)
    # Force need_weights=True so MultiheadAttention returns the attention map.
    attn_out, attn_weights = self.self_attention(normed, normed, normed, need_weights=True)
    # Record only while the gate is open; detach and move to CPU so the stored
    # maps hold neither the autograd graph nor GPU memory.
    if _record_attention and attn_weights is not None:
        _attention_weights_store.append(attn_weights.detach().cpu())
    attn_out = self.dropout(attn_out)
    residual = attn_out + input
    mlp_out = self.mlp(self.ln_2(residual))
    return residual + mlp_out
class AttentionRollout:
    """Heatmap generator for torchvision Vision Transformer models.

    Despite the name, the active visualization path is Grad-CAM computed over
    the encoder's final LayerNorm (``encoder.ln``); the Attention Rollout
    computation (``compute_rollout``) is implemented but its use is commented
    out in ``Processing_Main``.  The constructor also monkey-patches
    torchvision's ``EncoderBlock.forward`` globally so raw attention matrices
    can still be captured into the module-level store.
    """

    def __init__(self, model, attention_layer_name='encoder.ln', use_cuda=True):
        # NOTE(review): `attention_layer_name` is accepted but never read —
        # the target layer is discovered structurally below.  `use_cuda` is
        # stored but otherwise unused in this class.
        self.model = model
        self.use_cuda = use_cuda
        # Infer the device from the model's parameters; parameter-less models
        # fall back to CPU.
        try:
            self.device = next(model.parameters()).device
        except StopIteration:
            self.device = torch.device("cpu")
        # Apply Monkey Patch to capture weights (global patch on torchvision's
        # EncoderBlock — affects every ViT instance in the process).
        self._apply_monkey_patch()
        # Handle DataParallel/DistributedDataParallel: unwrap to the real model.
        if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
            base_model_wrapper = model.module
        else:
            base_model_wrapper = model
        # Determine target layer for ViT
        target_layer = None

        # Helper function to find the Grad-CAM target layer inside a (possibly
        # wrapped) ViT.
        def get_target_layer(model_obj):
            # Handle Fusion Models (e.g. EfficientNetViTFusion) that expose the
            # transformer branch through a 'vit' attribute.
            if hasattr(model_obj, 'vit'):
                model_obj = model_obj.vit
            # Option 1: Target the final LayerNorm (user request: "Norm layer").
            # Prioritized so the heatmap is computed from encoder.ln output.
            if hasattr(model_obj, 'encoder') and hasattr(model_obj.encoder, 'ln'):
                return model_obj.encoder.ln
            # Option 2: Target the last Self-Attention layer.
            # torchvision ViT structure: encoder.layers[-1].self_attention.
            # Note: we target the 'dropout' layer immediately after
            # self_attention because MultiheadAttention returns a tuple
            # (output, weights), which breaks GradCAM's gradient hooks; the
            # dropout layer (even with p=0) returns a single tensor.
            if hasattr(model_obj, 'encoder') and hasattr(model_obj.encoder, 'layers'):
                last_layer = model_obj.encoder.layers[-1]
                if hasattr(last_layer, 'dropout'):
                    return last_layer.dropout
                elif hasattr(last_layer, 'self_attention'):
                    return last_layer.self_attention
            return None

        # Check directly on the (unwrapped) model first.
        target_layer = get_target_layer(base_model_wrapper)
        # Then check base_model_wrapper.base_model (common wrapper pattern).
        if target_layer is None and hasattr(base_model_wrapper, 'base_model'):
            target_layer = get_target_layer(base_model_wrapper.base_model)
        # Fallback to the last layer if the structure is unknown.
        if target_layer is None:
            print("Warning: Could not automatically find 'encoder.ln'. Using the last module.")
            # Use base_model_wrapper.modules() to avoid DataParallel wrapper
            # issues; filter for modules that own parameters to skip empty
            # containers.
            modules = [m for m in base_model_wrapper.modules() if len(list(m.parameters())) > 0]
            if modules:
                target_layer = modules[-1]
            else:
                target_layer = list(base_model_wrapper.modules())[-1]
        self.target_layers = [target_layer]
        # Initialize GradCAM.  reshape_transform is essential for ViT: it
        # converts the 1D token sequence into a 2D spatial map.
        self.cam = GradCAM(model=self.model, target_layers=self.target_layers, reshape_transform=self.reshape_transform)

    def _apply_monkey_patch(self):
        # Globally patch EncoderBlock.forward so every forward pass can expose
        # its attention weights (recording is gated by _record_attention).
        torchvision.models.vision_transformer.EncoderBlock.forward = _hook_encoder_block_forward

    def compute_rollout(self, attention_weights_list):
        """Compute Attention Rollout over the captured per-layer attention maps.

        attention_weights_list: list of tensors, one per encoder layer, either
        (B, H, N, N) (per-head) or (B, N, N) (head-averaged, the default
        returned by MultiheadAttention with average_attn_weights=True).

        Returns a (B, N, N) rollout matrix, or None when the input is empty or
        has an unexpected shape.
        """
        if not attention_weights_list:
            return None
        # Stack layers into (L, B, H, N, N) or (L, B, N, N).
        # Shape-probe branch kept from the original implementation: a 3-D
        # first element could be (B, N, N) head-averaged weights or a squeezed
        # batch; handled uniformly by the stacked-shape dispatch below.
        if len(attention_weights_list[0].shape) == 3:
            pass
        all_layers = torch.stack(attention_weights_list, dim=0)
        if len(all_layers.shape) == 5:
            L, B, H, N, _ = all_layers.shape
            # Mean across heads: (L, B, N, N)
            all_layers = all_layers.mean(dim=2)
        elif len(all_layers.shape) == 4:
            # Assumed (L, B, N, N) -> already head-averaged or single head.
            L, B, N, _ = all_layers.shape
        else:
            print(f"Unexpected attention weight shape: {all_layers.shape}")
            return None
        # Rollout accumulator starts as the identity: (B, N, N).
        result = torch.eye(N).to(all_layers.device).unsqueeze(0).repeat(B, 1, 1)
        with torch.no_grad():
            for i in range(L):
                # Attention matrix at layer i: (B, N, N).
                attn = all_layers[i]
                # Add the residual connection (identity) and re-normalize rows:
                # 0.5 * A + 0.5 * I, then row-stochastic normalization.
                I = torch.eye(N).to(attn.device).unsqueeze(0)
                a = (attn + I) / 2
                a = a / a.sum(dim=-1, keepdim=True)
                # Recursive multiplication: rollout_i = A_i @ rollout_{i-1}.
                result = torch.matmul(a, result)
        return result

    def reshape_transform(self, tensor):
        """Reshape a ViT token sequence (B, N, C) into a spatial map (B, C, H, W) for GradCAM."""
        # Handle tuple output (e.g. from MultiheadAttention: (output, weights)).
        if isinstance(tensor, tuple):
            tensor = tensor[0]
        # Fallback for 2-D tensors: add singleton spatial dims.
        if len(tensor.shape) == 2:
            return tensor.unsqueeze(-1).unsqueeze(-1)
        # Discard the CLS token; standard torchvision ViT places it at index 0.
        result = tensor[:, 1:, :]
        # (B, N-1, C) -> (B, C, N-1)
        result = result.transpose(1, 2)
        # Reshape the patch axis into a square grid.
        # NOTE(review): assumes the patch count is a perfect square — reshape
        # will raise otherwise; confirm for non-square inputs.
        n_patches = result.size(2)
        h = int(np.sqrt(n_patches))
        w = h
        result = result.reshape(tensor.size(0), result.size(1), h, w)
        return result

    def Processing_Main(self, Test_Dataloader, File_Path):
        """Generate and save a Grad-CAM heatmap image for every sample in Test_Dataloader.

        Expects the dataloader to yield (images, labels, File_Name, File_Classes)
        tuples; heatmaps are written under File_Path, grouped per class via
        Process_File.  Attention-weight recording is enabled only for the
        duration of this call.
        """
        import gc
        File = Process_File()
        self.model.eval()
        global _attention_weights_store
        global _record_attention
        # Enable recording of attention weights for the patched forward.
        _record_attention = True
        try:
            for batch_idx, (images, labels, File_Name, File_Classes) in enumerate(Test_Dataloader):
                images = images.to(self.device)
                # Clear stored weights before the forward pass so the store
                # only ever holds this batch's maps.
                _attention_weights_store.clear()
                try:
                    # 1. Run Grad-CAM (also runs the forward pass, triggering
                    #    the attention hooks).  targets=None targets the
                    #    highest-scoring category.
                    grayscale_cam = self.cam(input_tensor=images, targets=None)
                    # 2. Rollout path is currently disabled (see comments
                    #    below); kept for reference.
                    rollout_masks = None
                    # if _attention_weights_store:
                    #     # (B, N, N)
                    #     full_rollout = self.compute_rollout(_attention_weights_store)
                    #     if full_rollout is not None:
                    #         # Extract CLS attention to other tokens: row 0 is
                    #         # CLS attention to all tokens -> take [:, 0, 1:].
                    #         cls_attn = full_rollout[:, 0, 1:]  # (B, N-1)
                    #         # Normalize per sample.
                    #         cls_attn = cls_attn / (cls_attn.max(dim=1, keepdim=True)[0] + 1e-8)
                    #         rollout_masks = cls_attn.cpu().numpy()
                    for i in range(images.size(0)):
                        # Option A: Grad-CAM (feature-based) mask.
                        mask_cam = grayscale_cam[i, :]
                        # Option B: Attention Rollout (weight-based) mask.
                        # User request: use the Norm-layer Grad-CAM, so the
                        # Rollout preference below is disabled.
                        mask = mask_cam
                        # if rollout_masks is not None:
                        #     mask_rollout_flat = rollout_masks[i]
                        #     # Reshape to a square grid.
                        #     side = int(np.sqrt(mask_rollout_flat.shape[0]))
                        #     mask_rollout = mask_rollout_flat.reshape(side, side)
                        #     # Resize to the input image size.
                        #     mask_rollout = cv2.resize(mask_rollout, (images.size(3), images.size(2)))
                        #     mask = mask_rollout  # Prefer Rollout
                        # else:
                        #     mask = mask_cam  # Fallback
                        # Original image for visualization (HWC numpy).
                        img_tensor = images[i].detach().cpu()
                        img_np = img_tensor.permute(1, 2, 0).numpy()
                        # Min-max rescale to [0, 1] when values fall outside it
                        # (e.g. normalized inputs); show_cam_on_image expects
                        # a float image in [0, 1].
                        if img_np.max() > 1.0 or img_np.min() < 0.0:
                            img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)
                        visualization = show_cam_on_image(img_np, mask, use_rgb=True)
                        # Save under a per-class directory.
                        Save_File_Root = File.Make_Save_Root(FileName=File_Classes[i], File_root=File_Path)
                        File.JudgeRoot_MakeDir(Save_File_Root)
                        file_name = os.path.basename(File_Name[i])
                        if not file_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                            file_name += ".png"
                        # Convert RGB to BGR for OpenCV saving.
                        visualization_bgr = cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR)
                        File.Save_CV2_File(file_name, Save_File_Root, visualization_bgr)
                except RuntimeError as e:
                    # Skip OOM batches instead of aborting the whole run.
                    if "out of memory" in str(e):
                        print(f"WARNING: Out of Memory in Grad-CAM batch {batch_idx}. Skipping this batch.")
                        torch.cuda.empty_cache()
                    else:
                        raise e
                finally:
                    # Per-batch cleanup: drop large tensors and captured maps.
                    if 'grayscale_cam' in locals():
                        del grayscale_cam
                    if 'images' in locals():
                        del images
                    _attention_weights_store.clear()  # Clear for next batch
                    torch.cuda.empty_cache()
                    gc.collect()
        finally:
            # Disable recording to prevent memory leaks in subsequent training
            # steps, and drop any leftover maps.
            _record_attention = False
            _attention_weights_store.clear()