# File: Stomach_Cancer_Pytorch/run_inference_gradcam.py
# (Repository-viewer metadata captured with this source: 139 lines, 4.6 KiB, Python.)

import torch
import torch.nn as nn
import os
import glob
import datetime
from torch.utils.data import DataLoader
from torchvision import transforms
# Import project modules
from experiments.Models.pytorch_Model import ModifiedViTBasic
from draw_tools.Attention_Rollout import AttentionRollout
from utils.Stomach_Config import Loading_Config
from Load_process.Loading_Tools import Load_Data_Tools
from Training_Tools.PreProcess import ListDataset
def get_test_data():
    """Gather every test image path and its integer class id.

    The test root directory and the ordered class names are read from
    ``Loading_Config`` (keys ``"Test_Data_Root"`` and ``"Training_Labels"``).
    The per-class file lists come from ``Load_Data_Tools.get_data_root``.
    A class's integer id is its position in ``Training_Labels``; a class with
    no entry in the returned mapping contributes zero images.

    Returns:
        (paths, labels): two parallel lists. Images are grouped class by class
        in ``Training_Labels`` order, and ``labels[i]`` is the class id of
        ``paths[i]``.
    """
    root_dir = Loading_Config["Test_Data_Root"]
    class_names = Loading_Config["Training_Labels"]
    print(f"Loading test data from: {root_dir}")
    print(f"Labels: {class_names}")

    # get_data_root is handed an empty dict and its return value is used as the
    # per-class mapping, expected to look like {label: [paths, ...]}.
    per_class = Load_Data_Tools().get_data_root(root_dir, {}, class_names)

    paths_out, labels_out = [], []
    for class_idx, class_name in enumerate(class_names):
        class_paths = per_class.get(class_name, [])
        print(f"Found {len(class_paths)} images for class '{class_name}'")
        paths_out += class_paths
        labels_out += [class_idx] * len(class_paths)
    return paths_out, labels_out
# Checkpoint used when run_inference_gradcam() is called without an explicit path.
# This is a machine-specific absolute path; pass ``model_weight_path=`` to override it.
DEFAULT_MODEL_WEIGHT_PATH = r"D:\Programing\stomach_cancer\Result\save_the_best_model\Using pre-train ViTBasic to adds hsv_adaptive_histogram_equalization\best_model( 2025-12-22 )-fold1.pt"

# Prefix that nn.DataParallel puts in front of every parameter name it saves.
_DATAPARALLEL_PREFIX = "module."


def _strip_dataparallel_prefix(state_dict):
    """Return a new dict in which a leading ``module.`` is removed from every key that has one."""
    cut = len(_DATAPARALLEL_PREFIX)
    return {
        (key[cut:] if key.startswith(_DATAPARALLEL_PREFIX) else key): value
        for key, value in state_dict.items()
    }


def _load_weights(model, weight_path, device):
    """Load the state dict stored at *weight_path* into *model* in place.

    A strict load is tried first. If it raises ``RuntimeError`` (the error
    ``load_state_dict`` uses for mismatched keys), the load is retried with
    ``strict=False``. The missing and unexpected keys are then printed, because
    a missing parameter keeps its constructor-initialised value and the
    resulting maps would not reflect the trained model.
    """
    print(f"Loading weights from: {weight_path}")
    # map_location remaps tensors saved on another device (e.g. GPU) onto `device`.
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    state_dict = _strip_dataparallel_prefix(torch.load(weight_path, map_location=device))
    try:
        model.load_state_dict(state_dict)
    except RuntimeError as e:
        print(f"Error loading weights: {e}")
        print("Attempting to load with strict=False...")
        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            print(f"WARNING: {len(incompatible.missing_keys)} missing keys "
                  f"(keep their initial values): {incompatible.missing_keys}")
        if incompatible.unexpected_keys:
            print(f"WARNING: {len(incompatible.unexpected_keys)} unexpected keys "
                  f"(ignored): {incompatible.unexpected_keys}")
    else:
        print("Weights loaded successfully.")


def run_inference_gradcam(model_weight_path=None, output_dir=None, batch_size=16):
    """Run AttentionRollout ("GradCAM") processing over the test set with a trained ViT.

    Args:
        model_weight_path: Path of the saved state dict to load into
            ``ModifiedViTBasic``. Defaults to ``DEFAULT_MODEL_WEIGHT_PATH``.
        output_dir: Directory handed to ``AttentionRollout.Processing_Main``.
            It is created if needed. Defaults to
            ``../Result/GradCAM_Inference/Test_Run_<YYYY-MM-DD>``, relative to
            the current working directory.
        batch_size: DataLoader batch size; adjust to the available GPU memory.

    Returns:
        None. Prints an error and returns early if the weight file does not
        exist or no test images are found.
    """
    if model_weight_path is None:
        model_weight_path = DEFAULT_MODEL_WEIGHT_PATH

    # Fail fast: check the checkpoint before scanning the dataset or building anything.
    if not os.path.isfile(model_weight_path):
        print(f"Error: Model weight file not found at {model_weight_path}")
        return

    # 1. Prepare data
    data_paths, data_labels = get_test_data()
    if not data_paths:
        print("Error: No testing data found!")
        return

    # Create the output directory only once there is something to write into it.
    if output_dir is None:
        today_str = datetime.date.today().isoformat()
        output_dir = os.path.join("..", "Result", "GradCAM_Inference", f"Test_Run_{today_str}")
    os.makedirs(output_dir, exist_ok=True)  # avoids the exists()/makedirs() race
    print(f"Results will be saved to: {os.path.abspath(output_dir)}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Resize + ToTensor only, mirroring PreProcess.py's transform. No Normalize
    # step is applied here; presumably normalisation is done inside the model
    # or not used in this pipeline (TODO confirm).
    image_size = 224
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])
    # The mask list argument is None for testing.
    test_dataset = ListDataset(image_size, data_paths, data_labels, None, transform)
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,  # iterate the dataset in its original order
        num_workers=0,
        # Pinned host memory only helps host->GPU copies; without CUDA it just warns.
        pin_memory=device.type == 'cuda',
    )

    # 2. Load model
    model = ModifiedViTBasic()
    _load_weights(model, model_weight_path, device)
    model = model.to(device)
    model.eval()

    # Wrap in DataParallel if multiple GPUs are available (matches the training
    # environment). This also exercises AttentionRollout's DataParallel handling.
    if torch.cuda.device_count() > 1:
        print(f"Wrapping model in DataParallel ({torch.cuda.device_count()} GPUs)")
        model = nn.DataParallel(model)

    # 3. Run GradCAM
    print("Initializing AttentionRollout...")
    vit_cam = AttentionRollout(model)
    print("Starting GradCAM processing...")
    vit_cam.Processing_Main(test_loader, output_dir)
    print("Inference completed.")
if __name__ == "__main__":
    # Runs only when this file is executed directly, not when it is imported.
    run_inference_gradcam()