"""Run Attention Rollout visualisation over a folder-per-class image dataset.

The script builds a dataset from ``<root>/<class_name>/<image>``, loads a
fine-tuned ``ModifiedViTBasic`` checkpoint, and hands a ``DataLoader`` to
``AttentionRollout.Processing_Main`` which writes its results to disk.
"""
import os
import sys

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Add project root to path so the project-local packages below can be
# imported when the script is launched from the repository root.
sys.path.append(os.getcwd())

from experiments.Models.pytorch_Model import ModifiedViTBasic
from draw_tools.Attention_Rollout import AttentionRollout

# File extensions (lower-case) that are treated as images.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')

# Defaults used when run_batch_rollout() is called without arguments.
DEFAULT_DATASET_ROOT = r"D:\Programing\stomach_cancer\Dataset\Testing_Crop"
DEFAULT_WEIGHT_PATH = r"D:\Programing\stomach_cancer\Result\save_the_best_model\Using pre-train ViTBasic to adds hsv_adaptive_histogram_equalization\best_model( 2025-12-23 )-fold2.pt"
DEFAULT_OUTPUT_DIR = r"D:\Programing\stomach_cancer\Result\Attention_Rollout_Result"

# Prefix that torch.nn.DataParallel prepends to every state_dict key.
_DATA_PARALLEL_PREFIX = 'module.'


class CustomImageDataset(Dataset):
    """Image dataset laid out as one sub-directory per class.

    Attributes:
        root_dir: Directory whose immediate sub-directories are the classes.
        transform: Optional callable applied to each loaded PIL image.
        samples: List of ``(image_path, class_name)`` tuples.
        classes: Sorted list of class (sub-directory) names.
        class_to_idx: Mapping from class name to its index in ``classes``.
    """

    def __init__(self, root_dir, transform=None):
        """Scan *root_dir* and index every image file found one level down.

        Args:
            root_dir: Directory containing one sub-directory per class.
            transform: Optional callable applied to each image in
                ``__getitem__``.

        Raises:
            OSError: If *root_dir* (or a class directory) cannot be scanned.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        self.classes = []

        # Sort directory entries so sample order does not depend on the
        # order in which the operating system happens to list them.
        with os.scandir(root_dir) as class_entries:
            class_dirs = sorted(
                (entry for entry in class_entries if entry.is_dir()),
                key=lambda entry: entry.name,
            )

        for class_dir in class_dirs:
            class_name = class_dir.name
            self.classes.append(class_name)
            with os.scandir(class_dir.path) as file_entries:
                image_paths = sorted(
                    entry.path
                    for entry in file_entries
                    if entry.is_file()
                    and entry.name.lower().endswith(IMAGE_EXTENSIONS)
                )
            self.samples.extend((path, class_name) for path in image_paths)

        self.classes.sort()
        self.class_to_idx = {
            cls_name: i for i, cls_name in enumerate(self.classes)
        }

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Index into ``self.samples``.

        Returns:
            A tuple ``(image, label, img_path, class_name)`` where *image* is
            the (optionally transformed) RGB image, *label* is the integer
            class index, *img_path* is the file path and *class_name* is the
            name of the class directory.
        """
        img_path, class_name = self.samples[idx]
        try:
            # Use a context manager so the underlying file handle is closed
            # as soon as the pixel data has been copied by convert().
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as e:
            # Deliberate best-effort: report the failure and substitute a
            # blank placeholder so one unreadable file does not abort the run.
            print(f"Error loading image {img_path}: {e}")
            image = Image.new('RGB', (224, 224))

        label = self.class_to_idx[class_name]

        if self.transform:
            image = self.transform(image)

        # Return format expected by AttentionRollout.Processing_Main:
        # (images, labels, File_Name, File_Classes)
        # Note: Processing_Main expects batches, so these will be batched by
        # the DataLoader. Per the original author's note, Processing_Main
        # applies os.path.basename() to File_Name, so a full path is fine.
        return image, label, img_path, class_name


def _strip_data_parallel_prefix(state_dict):
    """Return a copy of *state_dict* with any leading 'module.' removed from keys."""
    prefix_len = len(_DATA_PARALLEL_PREFIX)
    return {
        (key[prefix_len:] if key.startswith(_DATA_PARALLEL_PREFIX) else key): value
        for key, value in state_dict.items()
    }


def run_batch_rollout(
    dataset_root=DEFAULT_DATASET_ROOT,
    weight_path=DEFAULT_WEIGHT_PATH,
    output_dir=DEFAULT_OUTPUT_DIR,
    batch_size=16,
):
    """Run Attention Rollout on every image under *dataset_root*.

    Called with no arguments it uses the module-level default paths.

    Args:
        dataset_root: Directory with one sub-directory of images per class.
        weight_path: Checkpoint file; either a bare state_dict or a dict
            holding it under the ``'model_state_dict'`` key.
        output_dir: Directory passed to ``AttentionRollout.Processing_Main``.
        batch_size: Number of images per DataLoader batch.

    Returns:
        None. If the weights cannot be loaded, the error is printed and the
        function returns without processing any image.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Transforms
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Dataset & DataLoader
    dataset = CustomImageDataset(dataset_root, transform=transform)
    print(f"Found {len(dataset)} images across {len(dataset.classes)} classes: {dataset.classes}")

    # num_workers=0 for Windows compatibility/safety
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    # Load Model
    print("Loading model...")
    model = ModifiedViTBasic()

    try:
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(weight_path, map_location=device)

        # The weight file might be a full checkpoint or just the state_dict.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

        # Handle DataParallel keys if present
        model.load_state_dict(_strip_data_parallel_prefix(state_dict))
        print("Model weights loaded successfully.")
    except Exception as e:
        print(f"Error loading weights: {e}")
        return

    model.to(device)
    model.eval()

    # Run Attention Rollout
    print("Initializing AttentionRollout...")
    visualizer = AttentionRollout(model)

    print(f"Processing images... Saving to {output_dir}")
    visualizer.Processing_Main(dataloader, output_dir)
    print("Processing complete.")


if __name__ == "__main__":
    run_batch_rollout()