Files
Stomach_Cancer_Pytorch/Model_Loss/Perceptual_Loss.py

117 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
class VGGPerceptualLoss(nn.Module):
    """Perceptual loss computed in the feature space of a pretrained VGG19.

    The ImageNet-pretrained VGG19 ``features`` stack is frozen and truncated
    after the deepest requested layer. The loss is the weighted sum, over the
    requested layers, of the MSE between the activations produced by ``pred``
    and by ``target``.

    Args:
        feature_layers: Indices into ``torchvision.models.vgg19().features``
            whose outputs are compared. Order does not matter; they are sorted
            by depth. The defaults (2, 7, 12, 21, 30) are Conv2d outputs taken
            *before* their ReLU; pass the following index (3, 8, ...) to
            compare post-ReLU activations instead.
        use_normalization: If True, apply ImageNet mean/std normalization to
            the inputs (after clamping them to [0, 1]) before VGG.
        layer_weights: Optional per-layer loss weights, aligned with
            ``feature_layers`` as passed. Defaults to 1.0 for every layer.

    Raises:
        ValueError: If ``layer_weights`` and ``feature_layers`` differ in
            length, or a layer index lies outside the VGG19 feature stack.
    """

    def __init__(self, feature_layers=(2, 7, 12, 21, 30), use_normalization=True,
                 layer_weights=None):
        super(VGGPerceptualLoss, self).__init__()

        layers = [int(idx) for idx in feature_layers]
        if layer_weights is None:
            weights = [1.0] * len(layers)
        else:
            weights = [float(w) for w in layer_weights]
            if len(weights) != len(layers):
                raise ValueError(
                    "layer_weights must have one entry per feature layer "
                    f"(got {len(weights)} weights for {len(layers)} layers)"
                )
        # Sort (layer, weight) pairs together so extraction follows network
        # depth while each weight stays attached to its layer.
        pairs = sorted(zip(layers, weights), key=lambda pair: pair[0])
        self.feature_layers = [layer for layer, _ in pairs]
        # Per-layer loss weights, aligned with self.feature_layers; may be
        # tuned after construction.
        self.weights = [weight for _, weight in pairs]

        vgg = self._load_vgg19_features()
        # Freeze VGG: it is a fixed feature extractor, never trained here.
        for param in vgg.parameters():
            param.requires_grad = False

        # Kept for backward compatibility; layers are actually moved to the
        # input's device lazily in extract_features().
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Keep only layers 0..max(feature_layers) of the VGG feature stack.
        self.vgg_layers = self._truncate(vgg, self.feature_layers)

        self.use_normalization = use_normalization
        if use_normalization:
            self.normalize = transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )

    @staticmethod
    def _load_vgg19_features():
        """Return the ImageNet-pretrained VGG19 ``features`` stack."""
        try:
            weights = models.VGG19_Weights.IMAGENET1K_V1
        except AttributeError:
            # torchvision < 0.13 has no weights enums; use the legacy flag.
            return models.vgg19(pretrained=True).features
        # IMAGENET1K_V1 is the checkpoint the legacy pretrained=True loaded.
        return models.vgg19(weights=weights).features

    @staticmethod
    def _truncate(features, feature_layers):
        """Slice ``features`` up to the deepest requested index, as a ModuleList.

        In-place ReLUs among the kept layers are switched to out-of-place.
        Otherwise a captured activation would be overwritten by the ReLU that
        follows it, so the "output of layer i" would silently become the
        output of layer i + 1.
        """
        if not feature_layers:
            return nn.ModuleList()
        first, last = min(feature_layers), max(feature_layers)
        if first < 0 or last >= len(features):
            raise ValueError(
                f"feature layer indices must lie in [0, {len(features) - 1}], "
                f"got {list(feature_layers)}"
            )
        kept = list(features)[: last + 1]
        for layer in kept:
            if isinstance(layer, nn.ReLU):
                layer.inplace = False
        return nn.ModuleList(kept)

    def extract_features(self, x):
        """Run ``x`` through the truncated VGG and collect the requested activations.

        Args:
            x: Image batch; assumed [B, 3, H, W] with values nominally in
                [0, 1]. Values outside [0, 1] are clamped.

        Returns:
            A list with one activation tensor per entry of
            ``self.feature_layers``, in the same order.
        """
        # Clamping unconditionally is value- and gradient-identical to
        # clamping only when out of range, and avoids two host syncs.
        x = torch.clamp(x, 0, 1)
        if self.use_normalization:
            x = self.normalize(x)

        # Module.to() is in place; it is a no-op when already on x's device.
        self.vgg_layers.to(x.device)

        wanted = set(self.feature_layers)
        captured = {}
        for idx, layer in enumerate(self.vgg_layers):
            x = layer(x)
            if idx in wanted:
                captured[idx] = x
        return [captured[idx] for idx in self.feature_layers]

    def forward(self, pred, target):
        """Compute the perceptual loss between ``pred`` and ``target``.

        Args:
            pred: Predicted images [B, C, H, W]. If its spatial size differs
                from ``target``'s, it is bilinearly resized to match.
            target: Target images [B, C, H, W].
            Single-channel (C == 1) inputs are repeated to 3 channels.

        Returns:
            A scalar tensor: sum over layers of ``weight * MSE(features)``.
            It is zero when no feature layers are configured.
        """
        # Only the spatial size matters for resizing; channels are handled below.
        if pred.shape[2:] != target.shape[2:]:
            pred = F.interpolate(pred, size=target.shape[2:], mode='bilinear', align_corners=False)

        if pred.shape[1] == 1:
            pred = pred.repeat(1, 3, 1, 1)
        if target.shape[1] == 1:
            target = target.repeat(1, 3, 1, 1)

        pred_features = self.extract_features(pred)
        target_features = self.extract_features(target)

        perceptual_loss = pred.new_zeros(())
        for weight, pred_feat, target_feat in zip(self.weights, pred_features, target_features):
            perceptual_loss = perceptual_loss + weight * F.mse_loss(pred_feat, target_feat)
        return perceptual_loss