"""Smoke test for draw_tools.Attention_Rollout against a stand-in ViT model."""

import os
import shutil

import numpy as np
import torch
import torch.nn as nn

from draw_tools.Attention_Rollout import AttentionRollout


class DummyViT(nn.Module):
    """Stand-in for a ViT backbone used by ModifiedViTBasic.

    Exposes an ``encoder.ln`` LayerNorm so AttentionRollout can locate the
    layer it targets, and emits a token sequence shaped like a real
    ViT-B/16 output.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Module()
        # The LayerNorm that VITAttentionRollout targets inside the encoder.
        self.encoder.ln = nn.LayerNorm(768)

    def forward(self, x):
        # x: [B, 3, 224, 224]. Fabricate [B, 197, 768] token features:
        # 197 tokens = 1 class token + 14x14 patches, hidden dim 768.
        tokens = torch.randn(x.shape[0], 197, 768, device=x.device)
        # Route through the targeted LayerNorm so hooks see activations.
        return self.encoder.ln(tokens)


class DummyModifiedViTBasic(nn.Module):
    """Stand-in for ModifiedViTBasic: dummy backbone plus a 3-way head."""

    def __init__(self):
        super().__init__()
        self.base_model = DummyViT()
        self.head = nn.Linear(768, 3)

    def forward(self, x):
        # Classify from the class token (sequence position 0).
        return self.head(self.base_model(x)[:, 0])


def test_vit_cam():
    """End-to-end check that AttentionRollout writes one image per input."""
    print("Setting up test environment...")

    # 1. Device selection.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # 2. Dummy model in eval mode.
    model = DummyModifiedViTBasic().to(device)
    model.eval()
    print("Model initialized.")

    # 3. Mock dataloader yielding (images, labels, File_Name, File_Classes).
    dummy_images = torch.randn(2, 3, 224, 224)
    dummy_labels = torch.tensor([0, 1])
    dummy_filenames = ["test_img_1.jpg", "test_img_2.jpg"]
    dummy_classes = ["class_A", "class_B"]
    dummy_loader = [(dummy_images, dummy_labels, dummy_filenames, dummy_classes)]

    # 4. Fresh output directory on every run.
    output_dir = "test_gradcam_output"
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)

    # 5. Drive AttentionRollout and verify its outputs.
    print("Initializing AttentionRollout...")
    try:
        # Wrap in DataParallel whenever CUDA is present — even with a single
        # GPU — to exercise AttentionRollout's handling of wrapped models.
        if torch.cuda.is_available():
            print(f"Wrapping model in DataParallel ({torch.cuda.device_count()} GPUs)")
            model = nn.DataParallel(model)

        vit_cam = AttentionRollout(model)
        print(f"Target layers found: {vit_cam.target_layers}")

        # 6. Run the processing entry point.
        print("Running Processing_Main...")
        vit_cam.Processing_Main(dummy_loader, output_dir)

        # 7. Expect one heatmap per input, filed under its class subfolder.
        expected_files = [
            os.path.join(output_dir, "class_A", "test_img_1.jpg"),
            os.path.join(output_dir, "class_B", "test_img_2.jpg"),
        ]
        missing = []
        for path in expected_files:
            if os.path.exists(path):
                print(f"Success: Generated {path}")
            else:
                print(f"Error: Missing {path}")
                missing.append(path)

        if not missing:
            print("\nTest PASSED! GradCAM images generated successfully.")
        else:
            print("\nTest FAILED! Some images were not generated.")
    except Exception as e:
        print(f"\nTest FAILED with exception: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    test_vit_cam()