310 lines
13 KiB
Python
310 lines
13 KiB
Python
import sys
|
|
import os
|
|
import torch
|
|
import glob
|
|
import cv2
|
|
import numpy as np
|
|
from PIL import Image
|
|
from torchvision import transforms
|
|
from tqdm import tqdm
|
|
import matplotlib.pyplot as plt
|
|
|
|
# Add project root to path
|
|
sys.path.append(os.getcwd())
|
|
|
|
try:
|
|
from utils.Stomach_Config import Training_Config, Loading_Config, Save_Result_File_Config, Model_Config
|
|
from experiments.Models.ViT_Model import ViTBasic
|
|
from draw_tools.Attention_Rollout import AttentionRollout
|
|
from draw_tools.Transformer_specific_Grad_CAM_plus_plus import ViTGradCAMPlusPlus
|
|
from draw_tools.ViT_LRP import ViTLRP
|
|
except ImportError as e:
|
|
print(f"Import Error: {e}")
|
|
sys.exit(1)
|
|
|
|
def main():
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
print(f"Using device: {device}")
|
|
|
|
# 1. Locate Model Weights
|
|
# Config path: "../Result/save_the_best_model/{Experiment_Name}"
|
|
# We look for '2025-12-19' inside or related to it.
|
|
|
|
experiment_name = Training_Config.get("Three_Classes_Experiment_Name", "")
|
|
# Construct base path from Config
|
|
# Save_Result_File_Config["Three_Classes_Identification_Best_Model"] usually has "../"
|
|
base_model_path_cfg = Save_Result_File_Config["Three_Classes_Identification_Best_Model"]
|
|
|
|
# Resolve relative path
|
|
if base_model_path_cfg.startswith(".."):
|
|
# Assume cwd is Pytorch, so .. is parent of Pytorch
|
|
project_root = os.path.dirname(os.getcwd())
|
|
base_model_path = os.path.normpath(os.path.join(project_root, base_model_path_cfg[3:]))
|
|
else:
|
|
base_model_path = os.path.abspath(base_model_path_cfg)
|
|
|
|
print(f"Base Model Path from Config: {base_model_path}")
|
|
|
|
# Target directory for weights
|
|
# The files are directly in base_model_path with date in filename
|
|
target_date = "2025-12-19"
|
|
weight_dir = base_model_path # Files are here
|
|
|
|
print(f"Searching for weights in {weight_dir} with date {target_date}")
|
|
|
|
if not os.path.exists(weight_dir):
|
|
print(f"Directory {weight_dir} does not exist.")
|
|
# Fallback search logic could go here, but from previous step we know where they are
|
|
return
|
|
|
|
# Look for files matching *2025-12-19*.pt
|
|
weight_files = glob.glob(os.path.join(weight_dir, f"*{target_date}*.pt"))
|
|
if not weight_files:
|
|
print(f"No .pt files found in {weight_dir} matching *{target_date}*")
|
|
return
|
|
|
|
print(f"Found {len(weight_files)} models: {[os.path.basename(f) for f in weight_files]}")
|
|
|
|
# 2. Locate Test Data
|
|
test_data_cfg = Save_Result_File_Config["Testing_Image_Crop_And_Enhance"]
|
|
if test_data_cfg.startswith(".."):
|
|
project_root = os.path.dirname(os.getcwd())
|
|
test_data_root = os.path.normpath(os.path.join(project_root, test_data_cfg[3:]))
|
|
else:
|
|
test_data_root = os.path.abspath(test_data_cfg)
|
|
|
|
print(f"Test Data Root: {test_data_root}")
|
|
if not os.path.exists(test_data_root):
|
|
print("Test data directory does not exist.")
|
|
return
|
|
|
|
# 3. Initialize Model
|
|
# Assuming standard ViT-B/16 structure if weights contain 12 blocks
|
|
# Model_Config defaults might be for a smaller model, but saved weights indicate deeper model
|
|
img_size = 224
|
|
num_classes = len(Loading_Config["Training_Labels"])
|
|
|
|
print(f"Initializing ViTBasic with img_size={img_size}, num_classes={num_classes}, depth=12, heads=12")
|
|
|
|
model = ViTBasic(
|
|
img_size=img_size,
|
|
patch_size=16,
|
|
in_chans=3,
|
|
num_classes=num_classes,
|
|
embed_dim=768,
|
|
depth=12, # Force depth=12 to match saved weights
|
|
num_heads=12, # Force heads=12 to match standard ViT-B/16
|
|
mlp_ratio=4.0,
|
|
drop_rate=0.,
|
|
attn_drop_rate=0.,
|
|
pretrained=False
|
|
).to(device)
|
|
|
|
# Transform (Resize + ToTensor)
|
|
transform = transforms.Compose([
|
|
transforms.Resize((img_size, img_size)),
|
|
transforms.ToTensor()
|
|
])
|
|
|
|
# 4. Process
|
|
classes = Loading_Config["Training_Labels"]
|
|
|
|
# Create Output Directory
|
|
output_base = os.path.join(os.path.dirname(os.getcwd()), "Result", "Vis_Result", os.path.basename(weight_dir), target_date)
|
|
|
|
for weight_path in weight_files:
|
|
weight_name = os.path.basename(weight_path).replace('.pt', '')
|
|
print(f"\nProcessing Model: {weight_name}")
|
|
|
|
# Load Weights with Key Mapping
|
|
try:
|
|
state_dict = torch.load(weight_path, map_location=device)
|
|
|
|
# 1. Clean keys (remove module. and base_model.)
|
|
clean_state_dict = {}
|
|
for k, v in state_dict.items():
|
|
name = k
|
|
if name.startswith('module.'):
|
|
name = name[7:]
|
|
if name.startswith('base_model.'):
|
|
name = name[11:]
|
|
clean_state_dict[name] = v
|
|
|
|
# 2. Map Torchvision-style keys to ViTBasic keys
|
|
# ViTBasic has a helper for this, but we can do it explicitly here to be safe
|
|
# Mapping logic derived from ViTBasic.load_pretrained_weights
|
|
mapping = {
|
|
'conv_proj.weight': 'patch_embed.proj.weight',
|
|
'conv_proj.bias': 'patch_embed.proj.bias',
|
|
'class_token': 'cls_token',
|
|
'encoder.pos_embedding': 'pos_embed',
|
|
'encoder.ln.weight': 'norm.weight',
|
|
'encoder.ln.bias': 'norm.bias',
|
|
'heads.0.weight': 'head.weight', # Map head
|
|
'heads.0.bias': 'head.bias'
|
|
}
|
|
|
|
# Dynamic mapping for blocks
|
|
# clean_state_dict has 'encoder.layers.encoder_layer_X...'
|
|
# ViTBasic expects 'blocks.X...'
|
|
|
|
final_state_dict = {}
|
|
|
|
# Process all keys in clean_state_dict
|
|
for k, v in clean_state_dict.items():
|
|
new_key = None
|
|
|
|
# Check static mapping
|
|
if k in mapping:
|
|
new_key = mapping[k]
|
|
|
|
# Check block mapping
|
|
elif k.startswith('encoder.layers.encoder_layer_'):
|
|
# format: encoder.layers.encoder_layer_{i}.{suffix}
|
|
parts = k.split('.')
|
|
layer_idx = parts[2].split('_')[-1] # get i from encoder_layer_i
|
|
suffix = '.'.join(parts[3:])
|
|
|
|
# Map suffix
|
|
suffix_map = {
|
|
'ln_1.weight': 'norm1.weight',
|
|
'ln_1.bias': 'norm1.bias',
|
|
'self_attention.in_proj_weight': 'attn.in_proj_weight',
|
|
'self_attention.in_proj_bias': 'attn.in_proj_bias',
|
|
'self_attention.out_proj.weight': 'attn.out_proj.weight',
|
|
'self_attention.out_proj.bias': 'attn.out_proj.bias',
|
|
'ln_2.weight': 'norm2.weight',
|
|
'ln_2.bias': 'norm2.bias',
|
|
'mlp.0.weight': 'mlp.0.weight',
|
|
'mlp.0.bias': 'mlp.0.bias',
|
|
'mlp.3.weight': 'mlp.3.weight',
|
|
'mlp.3.bias': 'mlp.3.bias'
|
|
}
|
|
|
|
if suffix in suffix_map:
|
|
new_key = f"blocks.{layer_idx}.{suffix_map[suffix]}"
|
|
|
|
if new_key:
|
|
final_state_dict[new_key] = v
|
|
else:
|
|
# If key not mapped, check if it already matches ViTBasic (unlikely given the logs, but possible)
|
|
# or just ignore if it's not relevant
|
|
pass
|
|
|
|
msg = model.load_state_dict(final_state_dict, strict=False)
|
|
print(f" Load Status: {msg}")
|
|
if len(msg.missing_keys) > 0:
|
|
print(f" Missing keys: {len(msg.missing_keys)}")
|
|
# print(msg.missing_keys) # Uncomment to debug
|
|
|
|
model.eval()
|
|
model.requires_grad_(True) # Ensure model parameters require grad for LRP/GradCAM
|
|
|
|
# Set average_attn_weights = False for ViT-LRP to work (enable per-head gradients)
|
|
for module in model.modules():
|
|
if isinstance(module, torch.nn.MultiheadAttention):
|
|
module.average_attn_weights = False
|
|
|
|
except Exception as e:
|
|
print(f"Failed to load {weight_name}: {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
continue
|
|
|
|
# Initialize Visualization Tools
|
|
rollout_tool = VITAttentionRollout(model)
|
|
# Explicitly pass target layer (final norm) for GradCAM++
|
|
# In ViTBasic, final norm is model.norm
|
|
gradcam_tool = ViTGradCAMPlusPlus(model, target_layer=model.norm)
|
|
lrp_tool = ViTLRP(model)
|
|
|
|
for cls_idx, cls_name in enumerate(classes):
|
|
cls_dir = os.path.join(test_data_root, cls_name)
|
|
if not os.path.exists(cls_dir):
|
|
print(f"Class folder {cls_name} not found in {test_data_root}, skipping.")
|
|
continue
|
|
|
|
# Output path for this model and class
|
|
save_dir = os.path.join(output_base, weight_name, cls_name)
|
|
os.makedirs(save_dir, exist_ok=True)
|
|
|
|
image_files = [f for f in os.listdir(cls_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
|
|
print(f" Class {cls_name}: {len(image_files)} images")
|
|
|
|
for img_name in tqdm(image_files, desc=f" Processing {cls_name}"):
|
|
img_path = os.path.join(cls_dir, img_name)
|
|
|
|
try:
|
|
# Load Image
|
|
raw_img = Image.open(img_path).convert("RGB")
|
|
input_tensor = transform(raw_img).unsqueeze(0).to(device)
|
|
input_tensor.requires_grad = True # Ensure gradients are tracked
|
|
|
|
# 1. Inference (Test)
|
|
with torch.no_grad():
|
|
output = model(input_tensor)
|
|
pred_idx = output.argmax(dim=1).item()
|
|
|
|
# 2. Visualizations
|
|
|
|
# Attention Rollout
|
|
mask_rollout = rollout_tool(input_tensor)
|
|
|
|
# Grad-CAM++
|
|
mask_cam = gradcam_tool(input_tensor, class_idx=pred_idx)
|
|
|
|
# LRP
|
|
mask_lrp = lrp_tool(input_tensor, class_idx=pred_idx)
|
|
|
|
# 3. Save Results
|
|
# Resize masks to original image size for better visualization overlay
|
|
raw_np = np.array(raw_img)
|
|
|
|
def save_vis(mask, suffix):
|
|
if mask is None:
|
|
return
|
|
# Resize mask to image size
|
|
mask_resized = cv2.resize(mask[0], (raw_np.shape[1], raw_np.shape[0]))
|
|
|
|
# Apply colormap
|
|
heatmap = cv2.applyColorMap(np.uint8(255 * mask_resized), cv2.COLORMAP_JET)
|
|
heatmap = np.float32(heatmap) / 255
|
|
|
|
# Overlay
|
|
cam = heatmap + np.float32(raw_np) / 255
|
|
cam = cam / np.max(cam)
|
|
cam = np.uint8(255 * cam)
|
|
|
|
# Save
|
|
save_path = os.path.join(save_dir, f"{os.path.splitext(img_name)[0]}_{suffix}.jpg")
|
|
# Save side-by-side: Original | Heatmap | Overlay
|
|
# Or just Overlay. User said "產生Grad-CAM的影像後並存起來"
|
|
# I'll save the overlay.
|
|
cv2.imwrite(save_path, cam)
|
|
|
|
save_vis(mask_rollout, "Rollout")
|
|
save_vis(mask_cam, "GradCAMPlusPlus")
|
|
save_vis(mask_lrp, "LRP")
|
|
|
|
except Exception as e:
|
|
print(f"Error processing {img_name}: {e}")
|
|
continue
|
|
|
|
# Cleanup hooks
|
|
# Tools might need cleanup if they register hooks that persist (LRP/GradCAM do)
|
|
if hasattr(rollout_tool, 'remove_hooks'):
|
|
rollout_tool.remove_hooks()
|
|
if hasattr(gradcam_tool, 'remove_hooks'):
|
|
gradcam_tool.remove_hooks()
|
|
if hasattr(lrp_tool, 'remove_hooks'):
|
|
lrp_tool.remove_hooks()
|
|
|
|
del rollout_tool
|
|
del gradcam_tool
|
|
del lrp_tool
|
|
|
|
print("Processing Complete.")
|
|
|
|
# Script entry point: run the full visualization pipeline when executed directly.
if __name__ == "__main__":
    main()
|