# GradCAM inference script: loads a trained ViT checkpoint and runs
# AttentionRollout over the configured test set, saving heatmaps to disk.
import datetime
import glob
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms

# Import project modules
from experiments.Models.pytorch_Model import ModifiedViTBasic
from draw_tools.Attention_Rollout import AttentionRollout
from utils.Stomach_Config import Loading_Config
from Load_process.Loading_Tools import Load_Data_Tools
from Training_Tools.PreProcess import ListDataset
|
def get_test_data():
    """Collect test image paths and integer labels from the configured test root.

    Returns:
        A pair ``(paths, labels)``: a flat list of image file paths and a
        parallel list of integer class indices (each index is the position of
        the image's class name within ``Training_Labels``).
    """
    test_root = Loading_Config["Test_Data_Root"]
    class_names = Loading_Config["Training_Labels"]

    print(f"Loading test data from: {test_root}")
    print(f"Labels: {class_names}")

    # Load_Data_Tools.get_data_root fills a dict of {label: [paths...]}.
    path_map = Load_Data_Tools().get_data_root(test_root, {}, class_names)

    # Flatten the per-class dict into parallel path/label lists for ListDataset.
    all_data_paths = []
    all_labels = []
    for class_index, class_name in enumerate(class_names):
        class_paths = path_map.get(class_name, [])
        print(f"Found {len(class_paths)} images for class '{class_name}'")
        all_data_paths.extend(class_paths)
        # Integer label = index of the class name in Training_Labels.
        all_labels.extend([class_index] * len(class_paths))

    return all_data_paths, all_labels
|
def _strip_dataparallel_prefix(state_dict):
    """Return a copy of *state_dict* with any 'module.' key prefix removed.

    Checkpoints saved from an ``nn.DataParallel``-wrapped model prefix every
    parameter key with 'module.'; a bare model cannot load such keys directly.
    Keys without the prefix are passed through unchanged.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def run_inference_gradcam():
    """Load the best ViT checkpoint and run AttentionRollout over the test set.

    Outputs are written to ``../Result/GradCAM_Inference/Test_Run_<today>/``
    (relative to the current working directory). Prints progress throughout
    and returns early if no test data or no weight file is found.
    """
    # Configuration: hard-coded path to the best fold-1 checkpoint.
    model_weight_path = r"D:\Programing\stomach_cancer\Result\save_the_best_model\Using pre-train ViTBasic to adds hsv_adaptive_histogram_equalization\best_model( 2025-12-22 )-fold1.pt"

    # Output directory, dated so repeated runs land in separate folders.
    today_str = str(datetime.date.today())
    output_dir = os.path.join("..", "Result", "GradCAM_Inference", f"Test_Run_{today_str}")
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_dir, exist_ok=True)
    print(f"Results will be saved to: {os.path.abspath(output_dir)}")

    # 1. Prepare Data
    data_paths, data_labels = get_test_data()
    if len(data_paths) == 0:
        print("Error: No testing data found!")
        return

    # Transform: Resize to 224x224 and ToTensor.
    # Normalization is presumably handled inside the model — this mirrors
    # PreProcess.py (Resize((224, 224)) + ToTensor()); TODO confirm.
    image_size = 224
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor()
    ])

    # Create Dataset; Mask_List is None for testing usually.
    test_dataset = ListDataset(image_size, data_paths, data_labels, None, transform)

    # Create DataLoader. shuffle=False keeps output order aligned with inputs.
    batch_size = 16  # Adjust based on GPU memory
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )

    # 2. Load Model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    model = ModifiedViTBasic()

    # Guard clause: bail out early if the checkpoint is missing.
    if not os.path.exists(model_weight_path):
        print(f"Error: Model weight file not found at {model_weight_path}")
        return

    print(f"Loading weights from: {model_weight_path}")
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoints (consider weights_only=True where the torch version allows).
    state_dict = torch.load(model_weight_path, map_location=device)

    # Handle DataParallel checkpoints (keys prefixed with 'module.').
    state_dict = _strip_dataparallel_prefix(state_dict)

    # load_state_dict raises RuntimeError on key/shape mismatch; fall back to
    # a non-strict load so partially matching checkpoints still run.
    try:
        model.load_state_dict(state_dict)
        print("Weights loaded successfully.")
    except RuntimeError as e:
        print(f"Error loading weights: {e}")
        print("Attempting to load with strict=False...")
        model.load_state_dict(state_dict, strict=False)

    # Move model to device and switch to inference mode.
    model = model.to(device)
    model.eval()

    # Wrap in DataParallel if multiple GPUs are available (matches training
    # environment). This also tests the fix in VITAttentionRollout for
    # DataParallel.
    if torch.cuda.device_count() > 1:
        print(f"Wrapping model in DataParallel ({torch.cuda.device_count()} GPUs)")
        model = nn.DataParallel(model)

    # 3. Run GradCAM
    print("Initializing AttentionRollout...")
    vit_cam = AttentionRollout(model)

    print("Starting GradCAM processing...")
    vit_cam.Processing_Main(test_loader, output_dir)

    print("Inference completed.")
|
# Script entry point: run the full GradCAM inference pipeline.
if __name__ == "__main__":
    run_inference_gradcam()