# Stomach_Cancer_Pytorch/run_vit_vis.py
import sys
import os
import torch
import glob
import cv2
import numpy as np
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
# Add project root to path
sys.path.append(os.getcwd())
# Project-local imports. These require the project root on sys.path (added
# via sys.path.append(os.getcwd()) above); fail fast with a readable message
# instead of a bare traceback when the layout is wrong.
try:
    from utils.Stomach_Config import Training_Config, Loading_Config, Save_Result_File_Config, Model_Config
    from experiments.Models.ViT_Model import ViTBasic
    from draw_tools.Attention_Rollout import AttentionRollout
    from draw_tools.Transformer_specific_Grad_CAM_plus_plus import ViTGradCAMPlusPlus
    from draw_tools.ViT_LRP import ViTLRP
except ImportError as e:
    print(f"Import Error: {e}")
    sys.exit(1)
def _resolve_config_path(cfg_path):
    """Resolve a config path to an absolute path.

    Config paths usually start with "../" and are meant relative to the parent
    of the current working directory (the project root, assuming cwd is the
    Pytorch folder). Anything else is resolved against cwd.
    """
    if cfg_path.startswith(".."):
        project_root = os.path.dirname(os.getcwd())
        # [3:] strips the leading "../" — assumes a forward slash;
        # TODO(review): confirm behavior with Windows-style "..\\" paths.
        return os.path.normpath(os.path.join(project_root, cfg_path[3:]))
    return os.path.abspath(cfg_path)


def _clean_and_remap_state_dict(state_dict):
    """Translate a torchvision-style ViT checkpoint into ViTBasic key names.

    Two passes:
      1. strip 'module.' (DataParallel) and 'base_model.' wrapper prefixes;
      2. rename torchvision keys ('conv_proj.*', 'encoder.layers.encoder_layer_N.*',
         'heads.0.*', ...) to ViTBasic keys ('patch_embed.*', 'blocks.N.*', 'head.*').

    Keys matching neither mapping are dropped; the caller loads with
    strict=False, so dropped keys only show up as missing/unexpected counts.
    """
    # Pass 1: remove wrapper prefixes.
    clean_state_dict = {}
    for key, value in state_dict.items():
        name = key
        if name.startswith('module.'):
            name = name[7:]
        if name.startswith('base_model.'):
            name = name[11:]
        clean_state_dict[name] = value

    # Static (non-block) renames, derived from ViTBasic.load_pretrained_weights.
    mapping = {
        'conv_proj.weight': 'patch_embed.proj.weight',
        'conv_proj.bias': 'patch_embed.proj.bias',
        'class_token': 'cls_token',
        'encoder.pos_embedding': 'pos_embed',
        'encoder.ln.weight': 'norm.weight',
        'encoder.ln.bias': 'norm.bias',
        'heads.0.weight': 'head.weight',  # Map classification head
        'heads.0.bias': 'head.bias',
    }
    # Per-block suffix renames:
    # encoder.layers.encoder_layer_{i}.{suffix} -> blocks.{i}.{mapped suffix}
    suffix_map = {
        'ln_1.weight': 'norm1.weight',
        'ln_1.bias': 'norm1.bias',
        'self_attention.in_proj_weight': 'attn.in_proj_weight',
        'self_attention.in_proj_bias': 'attn.in_proj_bias',
        'self_attention.out_proj.weight': 'attn.out_proj.weight',
        'self_attention.out_proj.bias': 'attn.out_proj.bias',
        'ln_2.weight': 'norm2.weight',
        'ln_2.bias': 'norm2.bias',
        'mlp.0.weight': 'mlp.0.weight',
        'mlp.0.bias': 'mlp.0.bias',
        'mlp.3.weight': 'mlp.3.weight',
        'mlp.3.bias': 'mlp.3.bias',
    }

    # Pass 2: apply the renames.
    final_state_dict = {}
    for key, value in clean_state_dict.items():
        new_key = None
        if key in mapping:
            new_key = mapping[key]
        elif key.startswith('encoder.layers.encoder_layer_'):
            parts = key.split('.')
            layer_idx = parts[2].split('_')[-1]  # i from 'encoder_layer_i'
            suffix = '.'.join(parts[3:])
            if suffix in suffix_map:
                new_key = f"blocks.{layer_idx}.{suffix_map[suffix]}"
        if new_key:
            final_state_dict[new_key] = value
        # Unmapped keys are intentionally ignored.
    return final_state_dict


def _save_overlay(mask, suffix, raw_np, save_dir, img_name):
    """Overlay a saliency mask on the raw image and save it as a JPEG.

    mask is assumed array-like with a leading batch dim, i.e. mask[0] is
    (H, W) with values in [0, 1] — TODO(review): confirm against the
    draw_tools return types. No-op when the tool returned None.
    """
    if mask is None:
        return
    # Resize mask to the original image size for the overlay.
    mask_resized = cv2.resize(mask[0], (raw_np.shape[1], raw_np.shape[0]))
    heatmap = cv2.applyColorMap(np.uint8(255 * mask_resized), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    # NOTE(review): raw_np comes from PIL (RGB) while the colormap and
    # cv2.imwrite use BGR, so the base image's red/blue channels are swapped
    # in the saved overlay. Preserved as-is to keep output identical —
    # confirm whether this is intended.
    cam = heatmap + np.float32(raw_np) / 255
    cam = cam / np.max(cam)
    cam = np.uint8(255 * cam)
    save_path = os.path.join(save_dir, f"{os.path.splitext(img_name)[0]}_{suffix}.jpg")
    cv2.imwrite(save_path, cam)


def main():
    """Generate Attention-Rollout, Grad-CAM++ and LRP heatmap overlays.

    For every checkpoint matching '*2025-12-19*.pt' in the configured
    best-model directory: load and remap the weights into ViTBasic, run the
    three visualization tools over each test image, and write the overlays
    under ../Result/Vis_Result/<model_dir>/<date>/<model_name>/<class>/.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # 1. Locate model weights (config path is usually relative, e.g. "../Result/...").
    base_model_path = _resolve_config_path(
        Save_Result_File_Config["Three_Classes_Identification_Best_Model"])
    print(f"Base Model Path from Config: {base_model_path}")

    target_date = "2025-12-19"
    weight_dir = base_model_path  # Checkpoints sit directly here; date is in the filename.
    print(f"Searching for weights in {weight_dir} with date {target_date}")
    if not os.path.exists(weight_dir):
        print(f"Directory {weight_dir} does not exist.")
        return
    weight_files = glob.glob(os.path.join(weight_dir, f"*{target_date}*.pt"))
    if not weight_files:
        print(f"No .pt files found in {weight_dir} matching *{target_date}*")
        return
    print(f"Found {len(weight_files)} models: {[os.path.basename(f) for f in weight_files]}")

    # 2. Locate test data.
    test_data_root = _resolve_config_path(Save_Result_File_Config["Testing_Image_Crop_And_Enhance"])
    print(f"Test Data Root: {test_data_root}")
    if not os.path.exists(test_data_root):
        print("Test data directory does not exist.")
        return

    # 3. Initialize the model. Depth/heads are forced to standard ViT-B/16
    # geometry to match the saved weights, regardless of Model_Config defaults.
    img_size = 224
    num_classes = len(Loading_Config["Training_Labels"])
    print(f"Initializing ViTBasic with img_size={img_size}, num_classes={num_classes}, depth=12, heads=12")
    model = ViTBasic(
        img_size=img_size,
        patch_size=16,
        in_chans=3,
        num_classes=num_classes,
        embed_dim=768,
        depth=12,       # Force depth=12 to match saved weights
        num_heads=12,   # Force heads=12 to match standard ViT-B/16
        mlp_ratio=4.0,
        drop_rate=0.,
        attn_drop_rate=0.,
        pretrained=False
    ).to(device)

    # Resize + ToTensor only; no normalization — TODO(review): confirm this
    # matches the preprocessing used at training time.
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor()
    ])

    # 4. Process every checkpoint.
    classes = Loading_Config["Training_Labels"]
    output_base = os.path.join(os.path.dirname(os.getcwd()), "Result", "Vis_Result",
                               os.path.basename(weight_dir), target_date)

    for weight_path in weight_files:
        weight_name = os.path.basename(weight_path).replace('.pt', '')
        print(f"\nProcessing Model: {weight_name}")
        try:
            state_dict = torch.load(weight_path, map_location=device)
            final_state_dict = _clean_and_remap_state_dict(state_dict)
            msg = model.load_state_dict(final_state_dict, strict=False)
            print(f" Load Status: {msg}")
            if len(msg.missing_keys) > 0:
                print(f" Missing keys: {len(msg.missing_keys)}")
            model.eval()
            model.requires_grad_(True)  # LRP / Grad-CAM++ need parameter grads.
            # Per-head attention weights are required for ViT-LRP to work.
            for module in model.modules():
                if isinstance(module, torch.nn.MultiheadAttention):
                    module.average_attn_weights = False
        except Exception as e:
            print(f"Failed to load {weight_name}: {e}")
            import traceback
            traceback.print_exc()
            continue

        # Visualization tools. Bug fix: the class is imported as
        # AttentionRollout (see the draw_tools import at the top of the file);
        # the previous name VITAttentionRollout raised NameError at runtime.
        rollout_tool = AttentionRollout(model)
        # Explicitly target the final norm layer (model.norm) for Grad-CAM++.
        gradcam_tool = ViTGradCAMPlusPlus(model, target_layer=model.norm)
        lrp_tool = ViTLRP(model)

        for cls_name in classes:
            cls_dir = os.path.join(test_data_root, cls_name)
            if not os.path.exists(cls_dir):
                print(f"Class folder {cls_name} not found in {test_data_root}, skipping.")
                continue
            # Output path for this model and class.
            save_dir = os.path.join(output_base, weight_name, cls_name)
            os.makedirs(save_dir, exist_ok=True)
            image_files = [f for f in os.listdir(cls_dir)
                           if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
            print(f" Class {cls_name}: {len(image_files)} images")

            for img_name in tqdm(image_files, desc=f" Processing {cls_name}"):
                img_path = os.path.join(cls_dir, img_name)
                try:
                    raw_img = Image.open(img_path).convert("RGB")
                    input_tensor = transform(raw_img).unsqueeze(0).to(device)
                    input_tensor.requires_grad = True  # Ensure gradients are tracked.

                    # Plain inference to pick the target class for CAM/LRP.
                    with torch.no_grad():
                        output = model(input_tensor)
                        pred_idx = output.argmax(dim=1).item()

                    # Visualizations: rollout needs no target class; the
                    # gradient-based tools target the predicted class.
                    mask_rollout = rollout_tool(input_tensor)
                    mask_cam = gradcam_tool(input_tensor, class_idx=pred_idx)
                    mask_lrp = lrp_tool(input_tensor, class_idx=pred_idx)

                    raw_np = np.array(raw_img)
                    _save_overlay(mask_rollout, "Rollout", raw_np, save_dir, img_name)
                    _save_overlay(mask_cam, "GradCAMPlusPlus", raw_np, save_dir, img_name)
                    _save_overlay(mask_lrp, "LRP", raw_np, save_dir, img_name)
                except Exception as e:
                    print(f"Error processing {img_name}: {e}")
                    continue

        # Release hooks registered by the tools before loading the next model
        # (LRP/GradCAM register persistent hooks).
        if hasattr(rollout_tool, 'remove_hooks'):
            rollout_tool.remove_hooks()
        if hasattr(gradcam_tool, 'remove_hooks'):
            gradcam_tool.remove_hooks()
        if hasattr(lrp_tool, 'remove_hooks'):
            lrp_tool.remove_hooks()
        del rollout_tool
        del gradcam_tool
        del lrp_tool

    print("Processing Complete.")
# Script entry point.
if __name__ == "__main__":
    main()