# Stomach_Cancer_Pytorch/run_batch_attention_rollout.py
# (file-listing residue from the original paste: ~120 lines, 4.4 KiB, Python)
import torch
import os
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import sys
# Add project root to path
sys.path.append(os.getcwd())
from experiments.Models.pytorch_Model import ModifiedViTBasic
from draw_tools.Attention_Rollout import AttentionRollout
class CustomImageDataset(Dataset):
    """Folder-per-class image dataset.

    Every immediate subdirectory of ``root_dir`` is treated as one class and
    scanned (non-recursively) for image files. ``__getitem__`` returns
    ``(image, label, path, class_name)`` — the extra path/class fields are the
    format ``AttentionRollout.Processing_Main`` expects; it applies
    ``os.path.basename`` to the path, so the full path is fine.
    """

    # Extensions accepted as images (compared case-insensitively).
    _IMAGE_EXTS = ('.jpg', '.jpeg', '.png', '.bmp')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []   # list of (file_path, class_name) pairs
        self.classes = []
        # One class per immediate subdirectory of root_dir.
        for class_entry in os.scandir(root_dir):
            if not class_entry.is_dir():
                continue
            self.classes.append(class_entry.name)
            for item in os.scandir(class_entry.path):
                if item.is_file() and item.name.lower().endswith(self._IMAGE_EXTS):
                    self.samples.append((item.path, class_entry.name))
        # Sort so class indices are deterministic regardless of scan order.
        self.classes.sort()
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, class_name = self.samples[idx]
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a dummy image or handle error
            image = Image.new('RGB', (224, 224))
        label = self.class_to_idx[class_name]
        if self.transform:
            image = self.transform(image)
        # (images, labels, File_Name, File_Classes) — batched downstream by the
        # DataLoader; Processing_Main basenames File_Name itself.
        return image, label, img_path, class_name
def run_batch_rollout(
    dataset_root=r"D:\Programing\stomach_cancer\Dataset\Testing_Crop",
    weight_path=r"D:\Programing\stomach_cancer\Result\save_the_best_model\Using pre-train ViTBasic to adds hsv_adaptive_histogram_equalization\best_model( 2025-12-23 )-fold2.pt",
    output_dir=r"D:\Programing\stomach_cancer\Result\Attention_Rollout_Result",
    batch_size=16,
):
    """Run attention-rollout visualization over every image under ``dataset_root``.

    Args (all optional; defaults preserve the original hard-coded configuration):
        dataset_root: folder whose subdirectories are class folders of images.
        weight_path: ``.pt`` checkpoint — either a bare state_dict or a dict
            containing a ``'model_state_dict'`` key.
        output_dir: directory ``AttentionRollout.Processing_Main`` writes into.
        batch_size: DataLoader batch size.

    Returns None. Progress is printed; the function returns early (after
    printing the error) if the weights fail to load.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # ImageNet normalization stats — presumably what the pretrained ViT
    # backbone was trained with; confirm against the training pipeline.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    dataset = CustomImageDataset(dataset_root, transform=transform)
    print(f"Found {len(dataset)} images across {len(dataset.classes)} classes: {dataset.classes}")
    # num_workers=0 for Windows compatibility/safety.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    print("Loading model...")
    model = ModifiedViTBasic()
    try:
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoint files (weights_only=True is the safer option when
        # the checkpoint format allows it).
        checkpoint = torch.load(weight_path, map_location=device)
        # Accept either a full training checkpoint or just the state_dict.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        # Strip the 'module.' prefix left on keys by DataParallel training.
        state_dict = {
            (k[len('module.'):] if k.startswith('module.') else k): v
            for k, v in state_dict.items()
        }
        model.load_state_dict(state_dict)
        print("Model weights loaded successfully.")
    except Exception as e:
        # Best-effort script: report the failure and bail out rather than crash.
        print(f"Error loading weights: {e}")
        return

    model.to(device)
    model.eval()

    print("Initializing AttentionRollout...")
    visualizer = AttentionRollout(model)
    print(f"Processing images... Saving to {output_dir}")
    visualizer.Processing_Main(dataloader, output_dir)
    print("Processing complete.")
# Script entry point: run the rollout with the module's default configuration.
if __name__ == "__main__":
    run_batch_rollout()