import os
import sys

import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader

# Make project-local packages importable. This appends the current working
# directory (not this file's directory), so the script is expected to be run
# from the project root.
sys.path.append(os.getcwd())

from experiments.Models.pytorch_Model import ModifiedViTBasic
from draw_tools.Attention_Rollout import AttentionRollout


def _strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from keys.

    Checkpoints saved from a model wrapped in ``DataParallel`` /
    ``DistributedDataParallel`` store parameters as ``module.<name>``, while the
    unwrapped model expects ``<name>``. Keys without the prefix are kept as-is.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _load_model(model_path, device):
    """Build ``ModifiedViTBasic``, load weights from ``model_path`` onto ``device``.

    Returns the model in eval mode on ``device``, or ``None`` (after printing
    the reason) if the checkpoint file is missing or its keys do not match
    the model.
    """
    # Check the file first so we don't construct the model just to discard it.
    if not os.path.exists(model_path):
        print(f"Error: Model file not found at {model_path}")
        return None

    # ModifiedViTBasic does not take arguments in __init__
    model = ModifiedViTBasic()
    state_dict = torch.load(model_path, map_location=device)
    try:
        model.load_state_dict(_strip_module_prefix(state_dict))
    except RuntimeError as e:
        # Raised on missing/unexpected keys or shape mismatches.
        print(f"Error loading weights: {e}")
        return None
    print("Model weights loaded successfully.")

    model.to(device)
    model.eval()
    return model


def _load_image_tensor(image_path, size=224):
    """Read ``image_path`` and return ``(tensor, path_used)``, or ``None`` on failure.

    If ``image_path`` does not exist, a black ``size`` x ``size`` image is
    written to ``dummy_test.jpg`` in the working directory and used instead.
    The returned tensor has shape ``(1, 3, size, size)``, RGB channel order,
    float values scaled to [0, 1].
    """
    if not os.path.exists(image_path):
        print(f"Warning: Test image not found at {image_path}; using a blank dummy image instead.")
        image_path = "dummy_test.jpg"
        cv2.imwrite(image_path, np.zeros((size, size, 3), dtype=np.uint8))

    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals an unreadable/corrupt file by returning None
        # instead of raising; without this check cv2.resize fails cryptically.
        print(f"Error: Could not decode image at {image_path}")
        return None

    img = cv2.resize(img, (size, size))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads images as BGR

    # NOTE(review): only scales pixel values to [0, 1]; no mean/std
    # normalization is applied. Confirm this matches the preprocessing used
    # when the model was trained.
    img_tensor = torch.from_numpy(img).float() / 255.0
    img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0)  # HWC -> (1, C, H, W)
    return img_tensor, image_path


class _SingleBatchLoader:
    """Minimal DataLoader stand-in that yields one fixed batch per iteration.

    The batch layout used by this script is
    ``(images, labels, File_Name, File_Classes)``: labels may be dummy values,
    ``File_Name`` is a list of image paths and ``File_Classes`` is a list of
    class names.
    """

    def __init__(self, batch):
        self._batch = batch

    def __iter__(self):
        yield self._batch


def test_attention_rollout():
    """Smoke-test AttentionRollout on a single image with a trained ModifiedViTBasic.

    Loads the configured checkpoint and prepares one image, falling back to a
    blank dummy image if the test image is missing. It then runs
    ``AttentionRollout.Processing_Main`` and writes the output to
    ``output_dir``. If the model or image cannot be loaded, it prints the
    reason and returns early.
    """
    # 1. Configuration
    model_path = r"D:\Programing\stomach_cancer\Result\save_the_best_model\Using pre-train ViTBasic to adds hsv_adaptive_histogram_equalization\best_model( 2025-12-23 )-fold2.pt"
    image_path = "a140.jpg"  # Use an existing image in root
    output_dir = "test_output_attention_rollout"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 2. Load Model
    print("Loading model...")
    model = _load_model(model_path, device)
    if model is None:
        return

    # 3. Prepare Data
    loaded = _load_image_tensor(image_path)
    if loaded is None:
        return
    img_tensor, image_path = loaded

    # Labels are dummy values; "TestClass" is a placeholder class name.
    test_loader = _SingleBatchLoader(
        (img_tensor, torch.tensor([0]), [os.path.abspath(image_path)], ["TestClass"])
    )

    # 4. Run AttentionRollout
    print("Running AttentionRollout...")
    # Derive use_cuda from the chosen device so model and visualizer agree.
    visualizer = AttentionRollout(model, use_cuda=device.type == "cuda")

    # Check what target layer was selected
    print(f"Target layers: {visualizer.target_layers}")

    visualizer.Processing_Main(test_loader, output_dir)
    print(f"Done. Output saved to {output_dir}")


if __name__ == "__main__":
    test_attention_rollout()