# Stomach_Cancer_Pytorch/draw_tools/Saliency_Map.py
import torch
import torch.nn as nn
import numpy as np
import cv2
import matplotlib.pyplot as plt
from Load_process.file_processing import Process_File

class SaliencyMap:
    def __init__(self, model):
        self.model = model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()  # switch the model to evaluation mode

    def Processing_Main(self, Test_Dataloader, File_Path):
        """Run the test dataloader through the model and save a saliency overlay for every image."""
        File = Process_File()
        for batch_idx, (images, labels, File_Name, File_Classes) in enumerate(Test_Dataloader):
            # Move the batch to the target device
            images = images.to(self.device, dtype=torch.float32)
            labels = labels.to(self.device, dtype=torch.float32)

            # Recover the ground-truth class indices from the one-hot labels
            label_classes = torch.argmax(labels, dim=1).cpu().numpy()

            # Generate a saliency map for every image in the batch
            for i in range(images.size(0)):
                # Keep the batch dimension so the model receives a 4D tensor
                image = images[i:i+1]
                target_class = label_classes[i]

                # Compute the saliency map for the true class
                saliency_map = self.generate_saliency(image, target_class)

                # Blend the saliency map with the original image
                overlaid_image = self.overlay_saliency(saliency_map, image[0])

                # Create the per-class output directory and save the result
                path = f"{File_Path}/{File_Classes[i]}"
                File.JudgeRoot_MakeDir(path)
                File.Save_CV2_File(f"saliency_{batch_idx}_{File_Name[i]}", path, overlaid_image)

    def generate_saliency(self, image, target_class):
        """Generate a vanilla gradient saliency map for a single image."""
        # The input must require gradients so we can backpropagate to the pixels
        image.requires_grad_(True)

        # Forward pass
        output = self.model(image)

        # Clear any gradients left over from a previous call
        self.model.zero_grad()

        # One-hot target so only the chosen class score is backpropagated
        one_hot = torch.zeros_like(output)
        one_hot[0, target_class] = 1

        # Backward pass
        output.backward(gradient=one_hot)

        # Gradient of the class score with respect to the input pixels
        gradients = image.grad.data

        # Saliency map: absolute gradients, maximum over the channel dimension
        saliency, _ = torch.max(gradients.abs(), dim=1)

        # Convert to numpy and normalize to [0, 1]
        saliency_np = saliency.cpu().numpy()[0]
        saliency_np = self._normalize(saliency_np)

        # Light Gaussian smoothing to suppress pixel-level noise
        saliency_np = cv2.GaussianBlur(saliency_np, (5, 5), 0)
        saliency_np = self._normalize(saliency_np)  # re-normalize after blurring
        return saliency_np

    def _normalize(self, x):
        """Normalize an array to the [0, 1] range."""
        # A small epsilon avoids division by zero for constant inputs
        return (x - x.min()) / (x.max() - x.min() + 1e-8)

    def overlay_saliency(self, saliency, image, alpha=0.5):
        """Overlay the saliency heatmap on the original image."""
        # Scale the saliency map to [0, 255]
        saliency_uint8 = np.uint8(255 * saliency)

        # Apply the JET colormap (OpenCV returns a BGR heatmap)
        heatmap = cv2.applyColorMap(saliency_uint8, cv2.COLORMAP_JET)

        # Convert the image tensor (C, H, W) to a numpy array (H, W, C)
        image_np = image.detach().cpu().permute(1, 2, 0).numpy()

        # Rescale to [0, 1] if the image is not already in that range
        if image_np.max() > 1.0:
            image_np = (image_np - image_np.min()) / (image_np.max() - image_np.min())

        # Convert the image to uint8
        image_uint8 = np.uint8(255 * image_np)

        # Expand single-channel images to 3 channels
        if len(image_uint8.shape) == 2 or image_uint8.shape[2] == 1:
            image_uint8 = cv2.cvtColor(image_uint8, cv2.COLOR_GRAY2BGR)

        # Blend the heatmap with the original image
        overlaid = cv2.addWeighted(heatmap, alpha, image_uint8, 1 - alpha, 0)
        return overlaid

    def generate_smooth_saliency(self, image, target_class, n_samples=20, noise_level=0.1):
        """Generate a smoother saliency map with the SmoothGrad technique."""
        # Noise standard deviation is a fraction of the input's dynamic range
        stdev = noise_level * (torch.max(image) - torch.min(image)).item()

        # Accumulate input gradients over the noisy samples
        accumulated_gradients = None

        for _ in range(n_samples):
            # Add Gaussian noise to a detached copy so the noisy input is a leaf tensor
            noisy_image = image.detach() + torch.randn_like(image) * stdev
            noisy_image.requires_grad_(True)

            # Forward pass
            output = self.model(noisy_image)

            # Backward pass for the target class
            self.model.zero_grad()
            one_hot = torch.zeros_like(output)
            one_hot[0, target_class] = 1
            output.backward(gradient=one_hot)

            # Accumulate the gradients with respect to the noisy input
            gradients = noisy_image.grad.data
            if accumulated_gradients is None:
                accumulated_gradients = gradients
            else:
                accumulated_gradients += gradients

        # Average the gradients over all noisy samples
        avg_gradients = accumulated_gradients / n_samples

        # Saliency map: absolute averaged gradients, maximum over channels
        saliency, _ = torch.max(avg_gradients.abs(), dim=1)

        # Convert to numpy and normalize
        saliency_np = saliency.cpu().numpy()[0]
        saliency_np = self._normalize(saliency_np)
        return saliency_np

    def generate_guided_saliency(self, image, target_class):
        """Generate a saliency map with Guided Backpropagation."""
        # Register backward hooks on every ReLU so that only positive gradients
        # are propagated. (The original assignment to module.backward had no
        # effect on autograd; the hooks assume the ReLUs are not in-place.)
        hooks = []
        for module in self.model.modules():
            if isinstance(module, nn.ReLU):
                hooks.append(module.register_full_backward_hook(self._guided_relu_hook))

        # Standard gradient computation with respect to the input
        image.requires_grad_(True)
        output = self.model(image)
        self.model.zero_grad()
        one_hot = torch.zeros_like(output)
        one_hot[0, target_class] = 1
        output.backward(gradient=one_hot)

        # Gradient of the class score with respect to the input pixels
        gradients = image.grad.data

        # Remove the hooks to restore the normal ReLU backward behaviour
        for hook in hooks:
            hook.remove()

        # Keep only positive gradients, then take the channel-wise maximum
        saliency = torch.clamp(gradients, min=0)
        saliency, _ = torch.max(saliency, dim=1)

        # Convert to numpy and normalize
        saliency_np = saliency.cpu().numpy()[0]
        saliency_np = self._normalize(saliency_np)
        return saliency_np

    def _guided_relu_hook(self, module, grad_input, grad_output):
        """Backward hook for Guided Backpropagation: pass only positive gradients."""
        return (torch.clamp(grad_input[0], min=0),)
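

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# The tiny CNN and the random input below are hypothetical placeholders for the
# project's real classifier and preprocessed test images; only generate_saliency
# and overlay_saliency are exercised here, because Processing_Main additionally
# depends on the project's DataLoader format and Process_File helpers.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical 2-class CNN standing in for the real stomach-cancer model
    demo_model = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 2),
    )

    saliency_tool = SaliencyMap(demo_model)

    # Random tensor standing in for a single preprocessed RGB test image
    demo_image = torch.rand(1, 3, 224, 224, device=saliency_tool.device)

    saliency_map = saliency_tool.generate_saliency(demo_image, target_class=0)
    overlay = saliency_tool.overlay_saliency(saliency_map, demo_image[0])
    cv2.imwrite("saliency_demo.png", overlay)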