195 lines
7.2 KiB
Python
195 lines
7.2 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
import numpy as np
|
||
import cv2
|
||
import matplotlib.pyplot as plt
|
||
from Load_process.file_processing import Process_File
|
||
|
||
class SaliencyMap:
    """Gradient-based saliency maps (vanilla, SmoothGrad, guided backprop) for a classifier model."""

    def __init__(self, model):
        """Wrap ``model``, move it to CUDA when available (otherwise CPU) and put it in eval mode."""
        self.model = model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()  # evaluation mode

    def Processing_Main(self, Test_Dataloader, File_Path):
        """Generate and save a saliency overlay for every image in a dataloader.

        Args:
            Test_Dataloader: iterable yielding ``(images, labels, File_Name, File_Classes)``
                batches. ``labels`` is reduced with ``argmax`` over dim 1, so it is
                presumably one-hot (or per-class scores) of shape ``(batch, num_classes)``
                -- TODO confirm against the dataset.
            File_Path: root output directory; each overlay is saved under
                ``f"{File_Path}/{File_Classes[i]}"``.
        """
        File = Process_File()

        for batch_idx, (images, labels, File_Name, File_Classes) in enumerate(Test_Dataloader):
            images = images.to(self.device, dtype=torch.float32)
            labels = labels.to(self.device, dtype=torch.float32)

            # Ground-truth class index per sample; used as the backprop target.
            label_classes = torch.argmax(labels, dim=1).cpu().numpy()

            for i in range(images.size(0)):
                image = images[i:i + 1]  # keep the batch dimension
                target_class = label_classes[i]

                saliency_map = self.generate_saliency(image, target_class)
                overlaid_image = self.overlay_saliency(saliency_map, image[0])

                path = f"{File_Path}/{File_Classes[i]}"
                # Project helper; presumably creates the directory when missing.
                File.JudgeRoot_MakeDir(path)
                File.Save_CV2_File(f"saliency_{batch_idx}_{File_Name[i]}", path, overlaid_image)

    def _input_gradient(self, image, target_class):
        """Backpropagate ``output[0, target_class]`` to the model input.

        Works on a detached copy of ``image``. The caller's tensor is never
        mutated, and gradients never accumulate across calls.

        Returns:
            Tensor shaped like ``image`` holding d output[0, target_class] / d image.
        """
        # A fresh leaf tensor guarantees ``.grad`` is populated, even when the caller
        # passed a non-leaf or a tensor that already requires grad.
        image = image.detach().clone().to(self.device).requires_grad_(True)

        # Gradients are needed even if the caller is inside ``torch.no_grad()``.
        with torch.enable_grad():
            output = self.model(image)
            self.model.zero_grad()  # clear parameter grads left by earlier calls
            one_hot = torch.zeros_like(output)
            one_hot[0, target_class] = 1
            output.backward(gradient=one_hot)

        return image.grad.detach()

    def generate_saliency(self, image, target_class):
        """Return a vanilla-gradient saliency map for a single image.

        Args:
            image: input batch; only sample 0 is explained. Assumed NCHW
                (dim 1 is reduced as the channel axis) -- TODO confirm.
            target_class: index of the output column to explain.

        Returns:
            2-D numpy array in [0, 1] built as follows: per-pixel max over channels
            of |gradient|, normalised, Gaussian-blurred (5x5) to reduce noise, then
            normalised again.
        """
        gradients = self._input_gradient(image, target_class)

        saliency, _ = torch.max(gradients.abs(), dim=1)
        saliency_np = self._normalize(saliency.cpu().numpy()[0])

        saliency_np = cv2.GaussianBlur(saliency_np, (5, 5), 0)
        return self._normalize(saliency_np)  # blur changes the range; renormalise

    def _normalize(self, x):
        """Min-max scale ``x`` to [0, 1]; the 1e-8 epsilon avoids division by zero."""
        return (x - x.min()) / (x.max() - x.min() + 1e-8)

    def overlay_saliency(self, saliency, image, alpha=0.5):
        """Blend a JET-coloured saliency heatmap onto an image.

        Args:
            saliency: 2-D float array in [0, 1], same spatial size as ``image``.
            image: CHW tensor (``permute(1, 2, 0)`` turns it into HWC).
            alpha: heatmap weight; the image gets ``1 - alpha``.

        Returns:
            uint8 3-channel image from ``cv2.addWeighted``.

        NOTE(review): ``cv2.applyColorMap`` produces BGR. The tensor's channel
        order is not visible here; if it is RGB, the blend mixes channel orders.
        """
        saliency_uint8 = np.uint8(255 * saliency)
        heatmap = cv2.applyColorMap(saliency_uint8, cv2.COLORMAP_JET)

        image_np = image.detach().cpu().permute(1, 2, 0).numpy()

        # Rescale whenever values fall outside [0, 1]. Mean/std-normalised inputs can
        # have negatives with max <= 1, which would otherwise wrap under np.uint8.
        lo, hi = image_np.min(), image_np.max()
        if hi > 1.0 or lo < 0.0:
            span = hi - lo
            image_np = (image_np - lo) / span if span > 0 else np.zeros_like(image_np)

        image_uint8 = np.uint8(255 * image_np)

        # Single-channel input: expand to 3 channels to match the heatmap.
        if image_uint8.ndim == 2 or image_uint8.shape[2] == 1:
            image_uint8 = cv2.cvtColor(image_uint8, cv2.COLOR_GRAY2BGR)

        return cv2.addWeighted(heatmap, alpha, image_uint8, 1 - alpha, 0)

    def generate_smooth_saliency(self, image, target_class, n_samples=20, noise_level=0.1):
        """Return a SmoothGrad saliency map.

        Averages input gradients over ``n_samples`` noisy copies of ``image``. The
        Gaussian noise std is ``noise_level * (max(image) - min(image))``.

        Returns:
            2-D numpy array in [0, 1]: per-pixel channel max of |mean gradient|.

        Raises:
            ValueError: if ``n_samples`` < 1.
        """
        if n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {n_samples}")

        base = image.detach()
        stdev = noise_level * (torch.max(base) - torch.min(base)).item()

        accumulated = None
        for _ in range(n_samples):
            noisy_image = base + torch.randn_like(base) * stdev
            gradients = self._input_gradient(noisy_image, target_class)
            accumulated = gradients if accumulated is None else accumulated + gradients

        avg_gradients = accumulated / n_samples

        saliency, _ = torch.max(avg_gradients.abs(), dim=1)
        return self._normalize(saliency.cpu().numpy()[0])

    def generate_guided_saliency(self, image, target_class):
        """Return a guided-backpropagation saliency map.

        Each ``nn.ReLU`` module gets a temporary full backward hook that lets only
        non-negative gradients flow back. Hooks are removed, and each ReLU's
        ``inplace`` flag is restored, even if backpropagation raises. ReLUs applied
        functionally (``F.relu`` / ``torch.relu``) are not affected.

        Returns:
            2-D numpy array in [0, 1]: per-pixel channel max of the positive
            part of the guided input gradient.
        """
        hook_handles = []
        saved_inplace = {}
        try:
            for module in self.model.modules():
                if isinstance(module, nn.ReLU):
                    # Backward hooks do not allow in-place modification of module
                    # inputs/outputs, so run hooked ReLUs out-of-place.
                    saved_inplace[module] = module.inplace
                    module.inplace = False
                    hook_handles.append(module.register_full_backward_hook(self._guided_relu_hook))

            gradients = self._input_gradient(image, target_class)
        finally:
            for handle in hook_handles:
                handle.remove()
            for module, flag in saved_inplace.items():
                module.inplace = flag

        # Keep only positive evidence, then take the channel max.
        saliency = torch.clamp(gradients, min=0)
        saliency, _ = torch.max(saliency, dim=1)
        return self._normalize(saliency.cpu().numpy()[0])

    def _guided_relu_hook(self, module, grad_input, grad_output):
        """Full-backward-hook adapter: clamp the gradient a ReLU passes to its input."""
        if grad_input[0] is None:
            return None
        return (self._guided_relu_backward(grad_input[0]),) + tuple(grad_input[1:])

    def _guided_relu_backward(self, grad_output):
        """Guided-ReLU rule: let only positive gradients flow through."""
        return torch.clamp(grad_output, min=0)