# Source listing metadata: Stomach_Cancer_Pytorch/test_vit_cam.py (111 lines, 3.6 KiB, Python)

import torch
import torch.nn as nn
import os
import shutil
import numpy as np
from draw_tools.Attention_Rollout import AttentionRollout
# Dummy classes to mock the structure of ModifiedViTBasic
class DummyViT(nn.Module):
    """Minimal stand-in for a ViT backbone that emits random token features.

    The module ignores the pixel content of its input: only the batch size
    and device of ``x`` are used. It produces a random
    ``[batch, seq_len, hidden_dim]`` tensor and passes it through
    ``self.encoder.ln`` so that hooks registered on that layer fire.

    Args:
        hidden_dim: Size of the last (feature) dimension. Defaults to 768.
        seq_len: Number of tokens per sample. Defaults to 197, i.e.
            1 class token + 14x14 patches.
    """

    def __init__(self, hidden_dim=768, seq_len=197):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.seq_len = seq_len
        self.encoder = nn.Module()
        # Per the original author's note, this LayerNorm emulates the layer
        # targeted by the attention-rollout code (not verifiable from this
        # file alone).
        self.encoder.ln = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        """Return layer-normalised random features shaped like ViT output.

        Args:
            x: Input batch; the test feeds ``[B, 3, 224, 224]`` images, but
                only ``x.shape[0]`` and ``x.device`` are read.

        Returns:
            Tensor of shape ``[B, seq_len, hidden_dim]``.
        """
        batch_size = x.shape[0]
        # Simulate ViT output: [Batch, Seq_Len, Hidden_Dim]
        seq_output = torch.randn(
            batch_size, self.seq_len, self.hidden_dim, device=x.device
        )
        # Pass through the targeted layer
        return self.encoder.ln(seq_output)
class DummyModifiedViTBasic(nn.Module):
    """Mock classifier: a DummyViT backbone followed by a linear head.

    Exposes the backbone as ``base_model`` and the classifier as ``head``
    (768 input features, 3 output classes).
    """

    def __init__(self):
        super().__init__()
        self.base_model = DummyViT()
        self.head = nn.Linear(768, 3)

    def forward(self, x):
        """Classify a batch using the class token of the backbone output.

        Args:
            x: Input batch forwarded unchanged to ``self.base_model``.

        Returns:
            Output of ``self.head`` applied to the token at index 0 of each
            sequence, i.e. a ``[B, 3]`` tensor.
        """
        features = self.base_model(x)
        # Take class token (index 0) for classification
        cls_token = features[:, 0]
        return self.head(cls_token)
def test_vit_cam():
    """Smoke-test ``AttentionRollout`` end to end on a dummy ViT-like model.

    Builds a ``DummyModifiedViTBasic``, feeds a single fake batch through
    ``AttentionRollout.Processing_Main`` and verifies that one file per
    input image exists at ``test_gradcam_output/<class>/<file name>``.
    The output directory is recreated from scratch on every run and is left
    on disk afterwards.

    Raises:
        AssertionError: If any expected output file is missing.
        Exception: Any error raised while constructing or running
            ``AttentionRollout`` is printed with its traceback and then
            re-raised, so that pytest (and the process exit code when run
            as a script) reports the failure instead of silently passing.
    """
    print("Setting up test environment...")

    # 1. Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # 2. Initialize Dummy Model
    model = DummyModifiedViTBasic().to(device)
    model.eval()
    print("Model initialized.")

    # 3. Create dummy data loader
    # Dataloader yields: (images, labels, File_Name, File_Classes)
    batch_size = 2
    dummy_images = torch.randn(batch_size, 3, 224, 224)
    dummy_labels = torch.tensor([0, 1])
    dummy_filenames = ["test_img_1.jpg", "test_img_2.jpg"]
    dummy_classes = ["class_A", "class_B"]
    # A one-element list is enough to mock an iterable dataloader.
    dummy_loader = [(dummy_images, dummy_labels, dummy_filenames, dummy_classes)]

    # 4. Setup output directory (start from a clean slate)
    output_dir = "test_gradcam_output"
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)

    # 5. Initialize AttentionRollout and 6. run Processing_Main
    print("Initializing AttentionRollout...")
    try:
        # Wrap in DataParallel whenever CUDA is present (even with 1 GPU) to
        # exercise the wrapped-model code path.
        if torch.cuda.is_available():
            print(f"Wrapping model in DataParallel ({torch.cuda.device_count()} GPUs)")
            model = nn.DataParallel(model)
        vit_cam = AttentionRollout(model)
        print(f"Target layers found: {vit_cam.target_layers}")

        print("Running Processing_Main...")
        vit_cam.Processing_Main(dummy_loader, output_dir)
    except Exception as e:
        print(f"\nTest FAILED with exception: {e}")
        import traceback
        traceback.print_exc()
        # Re-raise: swallowing the error would make the test always pass.
        raise

    # 7. Verify outputs: one file per image, grouped by class directory.
    expected_files = [
        os.path.join(output_dir, class_name, file_name)
        for file_name, class_name in zip(dummy_filenames, dummy_classes)
    ]
    missing = []
    for f in expected_files:
        if os.path.exists(f):
            print(f"Success: Generated {f}")
        else:
            print(f"Error: Missing {f}")
            missing.append(f)

    if missing:
        print("\nTest FAILED! Some images were not generated.")
        raise AssertionError(f"Missing output files: {missing}")
    print("\nTest PASSED! GradCAM images generated successfully.")
if __name__ == "__main__":
    # Run the smoke test when this file is executed directly as a script.
    test_vit_cam()