# Batch attention-rollout visualization script for the stomach-cancer ViT model.
import os
import sys

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Make project-local packages importable when the script is run from the repo root.
# NOTE: this must run BEFORE the project imports below.
sys.path.append(os.getcwd())

from experiments.Models.pytorch_Model import ModifiedViTBasic
from draw_tools.Attention_Rollout import AttentionRollout
class CustomImageDataset(Dataset):
    """Image-folder dataset whose items also carry the file path and class name.

    Expects ``root_dir`` to contain one sub-directory per class, each holding
    image files. ``__getitem__`` returns ``(image, label, img_path, class_name)``
    — the extra path/class fields are the per-sample format that
    ``AttentionRollout.Processing_Main`` consumes after batching.
    """

    # Accepted image extensions (matched case-insensitively).
    _IMAGE_EXTS = ('.jpg', '.jpeg', '.png', '.bmp')

    def __init__(self, root_dir, transform=None):
        """Index every image file under each class sub-directory of ``root_dir``.

        Args:
            root_dir: Directory containing one sub-directory per class.
            transform: Optional callable applied to each loaded PIL image.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []  # list of (file_path, class_name) tuples
        self.classes = []

        # Walk class directories; `with` closes the scandir iterators promptly
        # instead of leaving OS directory handles open until GC.
        with os.scandir(root_dir) as dir_it:
            for entry in dir_it:
                if not entry.is_dir():
                    continue
                class_name = entry.name
                self.classes.append(class_name)
                with os.scandir(entry.path) as file_it:
                    for file in file_it:
                        if file.is_file() and file.name.lower().endswith(self._IMAGE_EXTS):
                            self.samples.append((file.path, class_name))

        # Sort so the class -> index mapping is deterministic across runs
        # and platforms (scandir order is filesystem-dependent).
        self.classes.sort()
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            (image, label, img_path, class_name) — the format expected by
            ``AttentionRollout.Processing_Main``, which batches these and
            applies ``os.path.basename`` to the path field, so a full path
            is fine here.
        """
        img_path, class_name = self.samples[idx]
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            # Best-effort recovery: report the failure and substitute a blank
            # image so one corrupt file does not abort the whole run.
            print(f"Error loading image {img_path}: {e}")
            image = Image.new('RGB', (224, 224))

        label = self.class_to_idx[class_name]

        if self.transform:
            image = self.transform(image)

        return image, label, img_path, class_name
def run_batch_rollout():
    """Run attention-rollout visualization over the whole test-crop dataset.

    Builds the dataset/dataloader, restores the trained ViT weights from disk,
    and hands every batch to ``AttentionRollout.Processing_Main`` for rendering
    into ``output_dir``. Prints progress and returns early if the checkpoint
    cannot be loaded.
    """
    # Hard-coded experiment paths (dataset, checkpoint, output directory).
    dataset_root = r"D:\Programing\stomach_cancer\Dataset\Testing_Crop"
    weight_path = r"D:\Programing\stomach_cancer\Result\save_the_best_model\Using pre-train ViTBasic to adds hsv_adaptive_histogram_equalization\best_model( 2025-12-23 )-fold2.pt"
    output_dir = r"D:\Programing\stomach_cancer\Result\Attention_Rollout_Result"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # ImageNet-style preprocessing to match the pretrained ViT backbone.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    dataset = CustomImageDataset(dataset_root, transform=transform)
    print(f"Found {len(dataset)} images across {len(dataset.classes)} classes: {dataset.classes}")

    dataloader = DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0)  # num_workers=0 for Windows compatibility/safety

    print("Loading model...")
    model = ModifiedViTBasic()
    try:
        # The saved file may be a full checkpoint dict or a bare state_dict.
        checkpoint = torch.load(weight_path, map_location=device)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

        # Strip DataParallel's "module." prefix from keys when present.
        state_dict = {
            (key[7:] if key.startswith('module.') else key): tensor
            for key, tensor in state_dict.items()
        }

        model.load_state_dict(state_dict)
        print("Model weights loaded successfully.")
    except Exception as e:
        print(f"Error loading weights: {e}")
        return

    model.to(device)
    model.eval()

    print("Initializing AttentionRollout...")
    visualizer = AttentionRollout(model)

    print(f"Processing images... Saving to {output_dir}")
    visualizer.Processing_Main(dataloader, output_dir)
    print("Processing complete.")
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    run_batch_rollout()