"""Smoke test for AttentionRollout using dummy ViT models."""
import os
import shutil

import numpy as np
import torch
import torch.nn as nn

from draw_tools.Attention_Rollout import AttentionRollout
# Dummy classes to mock the structure of ModifiedViTBasic
class DummyViT(nn.Module):
    """Stand-in ViT backbone exposing the ``encoder.ln`` LayerNorm hook target."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Module()
        # Emulate the LayerNorm layer targeted by VITAttentionRollout
        self.encoder.ln = nn.LayerNorm(768)

    def forward(self, x):
        """Return a fake token sequence of shape [B, 197, 768].

        x is expected to be an image batch of shape [B, 3, 224, 224];
        197 tokens = 1 class token + 14x14 patches.
        """
        fake_tokens = torch.randn(x.shape[0], 197, 768, device=x.device)
        # Route the sequence through the hooked LayerNorm so any hooks
        # registered on ``encoder.ln`` actually fire.
        return self.encoder.ln(fake_tokens)
class DummyModifiedViTBasic(nn.Module):
    """Minimal classifier mock: dummy ViT backbone + linear head on the class token."""

    def __init__(self):
        super().__init__()
        self.base_model = DummyViT()
        self.head = nn.Linear(768, 3)

    def forward(self, x):
        """Classify by projecting the backbone's class token to 3 logits."""
        tokens = self.base_model(x)
        # The class token sits at sequence position 0.
        return self.head(tokens[:, 0])
def test_vit_cam():
    """End-to-end smoke test: run AttentionRollout over a dummy model and
    fake dataloader, then check the expected CAM images were written."""
    print("Setting up test environment...")

    # 1. Pick GPU when available, otherwise fall back to CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # 2. Build the dummy model in eval mode.
    model = DummyModifiedViTBasic().to(device)
    model.eval()
    print("Model initialized.")

    # 3. Fake a dataloader; each batch is (images, labels, File_Name, File_Classes).
    images = torch.randn(2, 3, 224, 224)
    labels = torch.tensor([0, 1])
    file_names = ["test_img_1.jpg", "test_img_2.jpg"]
    file_classes = ["class_A", "class_B"]
    loader = [(images, labels, file_names, file_classes)]

    # 4. Start from a clean output directory.
    output_dir = "test_gradcam_output"
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)

    # 5. Build the rollout object, run it, and verify the written files.
    print("Initializing AttentionRollout...")
    try:
        # Wrap in DataParallel whenever CUDA exists (even a single GPU)
        # to exercise the multi-GPU code path.
        if torch.cuda.is_available():
            print(f"Wrapping model in DataParallel ({torch.cuda.device_count()} GPUs)")
            model = nn.DataParallel(model)

        vit_cam = AttentionRollout(model)
        print(f"Target layers found: {vit_cam.target_layers}")

        print("Running Processing_Main...")
        vit_cam.Processing_Main(loader, output_dir)

        # One image per (class, filename) pair is expected on disk.
        expected_files = [
            os.path.join(output_dir, cls, name)
            for cls, name in zip(file_classes, file_names)
        ]

        success = True
        for path in expected_files:
            if not os.path.exists(path):
                print(f"Error: Missing {path}")
                success = False
            else:
                print(f"Success: Generated {path}")

        if success:
            print("\nTest PASSED! GradCAM images generated successfully.")
        else:
            print("\nTest FAILED! Some images were not generated.")

    except Exception as e:
        # Harness-style reporting: surface the failure without crashing the script.
        print(f"\nTest FAILED with exception: {e}")
        import traceback
        traceback.print_exc()
if __name__ == "__main__":
|
|
test_vit_cam()
|