# File: Stomach_Cancer_Pytorch/test_attention_rollout.py
#
# Source listing metadata: 96 lines, 3.2 KiB, Python.

import torch
import os
import sys
import cv2
import numpy as np
from torch.utils.data import DataLoader
# Add project root to path
sys.path.append(os.getcwd())
from experiments.Models.pytorch_Model import ModifiedViTBasic
from draw_tools.Attention_Rollout import AttentionRollout
_DEFAULT_MODEL_PATH = r"D:\Programing\stomach_cancer\Result\save_the_best_model\Using pre-train ViTBasic to adds hsv_adaptive_histogram_equalization\best_model( 2025-12-23 )-fold2.pt"


def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with one leading ``module.`` removed from each key.

    Checkpoints saved from a DataParallel/DistributedDataParallel wrapper carry
    this prefix; a bare ``ModifiedViTBasic`` expects the unprefixed names.
    Keys without the prefix are kept unchanged.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _load_image_tensor(image_path, size=(224, 224)):
    """Read *image_path* with OpenCV and return a ``(1, 3, H, W)`` float tensor scaled to [0, 1].

    Args:
        image_path: Path of the image file to read.
        size: ``(width, height)`` passed to ``cv2.resize``.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read the file. OpenCV returns
            ``None`` instead of raising, which would otherwise fail later inside
            ``cv2.resize`` with an unhelpful message.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"cv2.imread could not read image: {image_path}")
    img = cv2.resize(img, size)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR
    # NOTE(review): only scales to [0, 1]; confirm whether the project's
    # training pipeline also applied a mean/std normalization.
    img_tensor = torch.from_numpy(img).float() / 255.0
    return img_tensor.permute(2, 0, 1).unsqueeze(0)  # HWC -> (1, C, H, W)


class _SingleImageLoader:
    """Stand-in for a DataLoader that yields exactly one batch per iteration.

    The batch is ``(images, labels, File_Name, File_Classes)``:
      - labels is a dummy ``tensor([0])``;
      - File_Name is a list holding the absolute image path;
      - File_Classes is a list holding a placeholder class name.
    """

    def __init__(self, image_tensor, image_path, class_name="TestClass"):
        self.image_tensor = image_tensor
        self.image_path = image_path
        self.class_name = class_name

    def __iter__(self):
        yield (
            self.image_tensor,
            torch.tensor([0]),
            [os.path.abspath(self.image_path)],
            [self.class_name],
        )


def test_attention_rollout(model_path=_DEFAULT_MODEL_PATH,
                           image_path="a140.jpg",
                           output_dir="test_output_attention_rollout"):
    """Run AttentionRollout on a single image using a trained ModifiedViTBasic.

    Steps:
      1. Load the checkpoint at *model_path*, stripping any ``module.`` key prefix.
      2. Feed one image through ``AttentionRollout.Processing_Main``.
      3. Let ``Processing_Main`` write its output under *output_dir*.

    If *image_path* does not exist, a black 224x224 ``dummy_test.jpg`` is written
    to the current directory and used instead.

    All parameters default to the values this script originally hard-coded. Having
    defaults also stops pytest from trying to resolve them as fixtures.

    Args:
        model_path: Path to a saved ``state_dict`` for ``ModifiedViTBasic``.
        image_path: Image to visualize.
        output_dir: Directory handed to ``Processing_Main`` for its output.

    Returns:
        None. A missing checkpoint or incompatible weights are reported with
        ``print`` and cause an early return rather than an exception.
        NOTE(review): under pytest an early return counts as a pass.
    """
    # 1. Configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 2. Load model
    print("Loading model...")
    # Check the file first so the model is not built just to be discarded.
    if not os.path.exists(model_path):
        print(f"Error: Model file not found at {model_path}")
        return

    # ModifiedViTBasic does not take arguments in __init__
    model = ModifiedViTBasic()
    # torch.load unpickles the file; only point this at trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    try:
        model.load_state_dict(_strip_module_prefix(state_dict))
    except RuntimeError as e:
        # load_state_dict raises RuntimeError on missing/unexpected/mis-shaped keys.
        print(f"Error loading weights: {e}")
        return
    print("Model weights loaded successfully.")
    model.to(device)
    model.eval()

    # 3. Prepare input (fall back to a black dummy image if the test image is absent)
    if not os.path.exists(image_path):
        print(f"Error: Test image not found at {image_path}")
        cv2.imwrite("dummy_test.jpg", np.zeros((224, 224, 3), dtype=np.uint8))
        image_path = "dummy_test.jpg"
    img_tensor = _load_image_tensor(image_path)
    test_loader = _SingleImageLoader(img_tensor, image_path)

    # 4. Run AttentionRollout
    print("Running AttentionRollout...")
    visualizer = AttentionRollout(model, use_cuda=torch.cuda.is_available())
    # Check what target layer was selected
    print(f"Target layers: {visualizer.target_layers}")
    visualizer.Processing_Main(test_loader, output_dir)
    print(f"Done. Output saved to {output_dir}")
if __name__ == "__main__":
    # Run the attention-rollout check directly when executed as a script.
    test_attention_rollout()