import torch
import torch.nn as nn
import os
import glob
import datetime
from torch.utils.data import DataLoader
from torchvision import transforms

# Import project modules
from experiments.Models.pytorch_Model import ModifiedViTBasic
from draw_tools.Attention_Rollout import AttentionRollout
from utils.Stomach_Config import Loading_Config
from Load_process.Loading_Tools import Load_Data_Tools
from Training_Tools.PreProcess import ListDataset


# Checkpoint used when run_inference_gradcam() is called without an explicit
# path. This is a machine-specific absolute path; pass model_weight_path to
# override it on other machines.
DEFAULT_MODEL_WEIGHT_PATH = r"D:\Programing\stomach_cancer\Result\save_the_best_model\Using pre-train ViTBasic to adds hsv_adaptive_histogram_equalization\best_model( 2025-12-22 )-fold1.pt"

# Prefix that nn.DataParallel prepends to every key of its wrapped module's
# state_dict; it must be removed before loading into a bare model.
_DATAPARALLEL_PREFIX = "module."


def get_test_data():
    """
    Load testing data paths and labels similar to Load_process logic.

    Reads the test root and the ordered class names from ``Loading_Config``,
    collects the image paths per class via ``Load_Data_Tools.get_data_root``,
    and flattens them into two parallel lists.

    Returns:
        (all_data_paths, all_labels): ``all_labels[i]`` is the integer index
        of the class of ``all_data_paths[i]`` within
        ``Loading_Config["Training_Labels"]``.
    """
    test_root = Loading_Config["Test_Data_Root"]
    labels = Loading_Config["Training_Labels"]

    print(f"Loading test data from: {test_root}")
    print(f"Labels: {labels}")

    loading_tool = Load_Data_Tools()
    # get_data_root is given an empty dict and its return value is used as a
    # mapping {label: [paths...]} (see the .get() below).
    data_dict = loading_tool.get_data_root(test_root, {}, labels)

    # Flatten into lists for ListDataset
    all_data_paths = []
    all_labels = []
    for idx, label in enumerate(labels):
        paths = data_dict.get(label, [])
        print(f"Found {len(paths)} images for class '{label}'")
        all_data_paths.extend(paths)
        # Integer label = position of the class in Training_Labels, so the
        # ordering must match the one used during training.
        all_labels.extend([idx] * len(paths))

    return all_data_paths, all_labels


def _strip_dataparallel_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from each key."""
    prefix_len = len(_DATAPARALLEL_PREFIX)
    return {
        (key[prefix_len:] if key.startswith(_DATAPARALLEL_PREFIX) else key): value
        for key, value in state_dict.items()
    }


def _load_weights(model, model_weight_path, device):
    """
    Load the state_dict stored at *model_weight_path* into *model* in place.

    Keys saved from an ``nn.DataParallel`` wrapper are normalised first. If a
    strict load fails (``load_state_dict`` raises ``RuntimeError`` on missing,
    unexpected or shape-mismatched keys), a non-strict load is attempted, and
    the keys it skipped are reported instead of being silently ignored.
    """
    print(f"Loading weights from: {model_weight_path}")
    # NOTE(review): assumes the file holds a plain state_dict (not a full
    # pickled model or a wrapper dict such as {"state_dict": ...}) — confirm
    # against the training script that saved it.
    state_dict = torch.load(model_weight_path, map_location=device)
    state_dict = _strip_dataparallel_prefix(state_dict)

    try:
        model.load_state_dict(state_dict)
        print("Weights loaded successfully.")
    except RuntimeError as e:
        print(f"Error loading weights: {e}")
        print("Attempting to load with strict=False...")
        result = model.load_state_dict(state_dict, strict=False)
        # A non-strict load may leave parameters at their initial values;
        # surface that so the resulting heatmaps are not trusted blindly.
        if result.missing_keys:
            print(f"Warning: keys missing from checkpoint: {result.missing_keys}")
        if result.unexpected_keys:
            print(f"Warning: unexpected keys ignored: {result.unexpected_keys}")


def run_inference_gradcam(model_weight_path=DEFAULT_MODEL_WEIGHT_PATH, output_dir=None, batch_size=16):
    """
    Run AttentionRollout visualisation over the whole test set.

    Args:
        model_weight_path: Path to the saved ModifiedViTBasic state_dict.
        output_dir: Directory for the results. Defaults to
            ``../Result/GradCAM_Inference/Test_Run_<today>``. That default is
            resolved relative to the current working directory.
        batch_size: DataLoader batch size (adjust based on GPU memory).

    Returns:
        None. Prints an error and returns early if the weight file is missing
        or no test images are found.
    """
    # Fail fast: don't scan the dataset or create directories if there is
    # no checkpoint to run.
    if not os.path.exists(model_weight_path):
        print(f"Error: Model weight file not found at {model_weight_path}")
        return

    # Output directory
    if output_dir is None:
        today_str = str(datetime.date.today())
        output_dir = os.path.join("..", "Result", "GradCAM_Inference", f"Test_Run_{today_str}")
    os.makedirs(output_dir, exist_ok=True)  # no exists()/makedirs() race
    print(f"Results will be saved to: {os.path.abspath(output_dir)}")

    # 1. Prepare Data
    data_paths, data_labels = get_test_data()
    if not data_paths:
        print("Error: No testing data found!")
        return

    # Resize to 224x224 and convert to a tensor, mirroring PreProcess.py:
    # transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    # NOTE(review): no Normalize step — assumes normalisation is done inside the
    # model or was not used during training; confirm.
    image_size = 224
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])

    # Mask_List is None for testing usually
    test_dataset = ListDataset(image_size, data_paths, data_labels, None, transform)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        # Pinned host memory only speeds up host->GPU copies; skip it on CPU.
        pin_memory=device.type == 'cuda',
    )

    # 2. Load Model
    model = ModifiedViTBasic()
    _load_weights(model, model_weight_path, device)

    model = model.to(device)
    model.eval()

    # Wrap in DataParallel if multiple GPUs are available (matches training
    # environment). This also exercises AttentionRollout's DataParallel handling.
    if torch.cuda.device_count() > 1:
        print(f"Wrapping model in DataParallel ({torch.cuda.device_count()} GPUs)")
        model = nn.DataParallel(model)

    # 3. Run GradCAM
    print("Initializing AttentionRollout...")
    vit_cam = AttentionRollout(model)

    print("Starting GradCAM processing...")
    vit_cam.Processing_Main(test_loader, output_dir)

    print("Inference completed.")


if __name__ == "__main__":
    run_inference_gradcam()