import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms


class VGGPerceptualLoss(nn.Module):
    """
    Perceptual loss based on VGG19.

    Extracts features with a pretrained VGG19 network and computes the loss
    between predictions and targets in feature space. Because the VGG layers
    are registered in an nn.ModuleList, calling .to(device) on this module
    moves the whole loss network; no per-call device juggling is needed.
    """

    def __init__(self, feature_layers=(2, 7, 12, 21, 30), use_normalization=True):
        super(VGGPerceptualLoss, self).__init__()
        # Load the pretrained VGG19 feature extractor.
        # (On torchvision >= 0.13 you may prefer
        # models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.)
        vgg = models.vgg19(pretrained=True).features.eval()
        # Freeze the VGG parameters: the loss network is never trained.
        for param in vgg.parameters():
            param.requires_grad = False

        # Indices (within vgg.features) of the layers whose activations
        # are compared; keep only the layers up to the deepest one needed.
        self.feature_layers = sorted(feature_layers)
        self.vgg_layers = nn.ModuleList(vgg[: self.feature_layers[-1] + 1])

        # Optional ImageNet normalization, applied before feature extraction.
        self.use_normalization = use_normalization
        if use_normalization:
            self.normalize = transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            )

        # Per-layer loss weights; adjust to emphasize shallow or deep features.
        self.weights = [1.0] * len(self.feature_layers)

    def extract_features(self, x):
        """Run x through the truncated VGG and collect activations at the chosen layers."""
        # VGG expects inputs in [0, 1] before ImageNet normalization.
        x = torch.clamp(x, 0, 1)
        if self.use_normalization:
            x = self.normalize(x)

        features = []
        layer_idx = 0
        for i, layer in enumerate(self.vgg_layers):
            x = layer(x)
            # Record the activation when a target feature layer is reached.
            if layer_idx < len(self.feature_layers) and i == self.feature_layers[layer_idx]:
                features.append(x)
                layer_idx += 1
        return features

    def forward(self, pred, target):
        """
        Compute the perceptual loss.

        pred:   predicted images [B, C, H, W]
        target: target images    [B, C, H, W]
        """
        # Resize the prediction if the spatial sizes do not match.
        if pred.shape != target.shape:
            pred = F.interpolate(pred, size=target.shape[2:],
                                 mode='bilinear', align_corners=False)

        # VGG expects 3-channel inputs; replicate single-channel images.
        if pred.shape[1] == 1:
            pred = pred.repeat(1, 3, 1, 1)
        if target.shape[1] == 1:
            target = target.repeat(1, 3, 1, 1)

        pred_features = self.extract_features(pred)
        target_features = self.extract_features(target)

        # Weighted sum of per-layer MSE between the feature maps.
        perceptual_loss = 0.0
        for weight, pred_feat, target_feat in zip(self.weights, pred_features, target_features):
            perceptual_loss = perceptual_loss + weight * F.mse_loss(pred_feat, target_feat)
        return perceptual_loss
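

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal smoke test, assuming 256x256 RGB batches of random data; the
# batch size and image size below are arbitrary example values. The first
# run downloads the VGG19 ImageNet weights if they are not cached locally.
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    criterion = VGGPerceptualLoss().to(device)

    pred = torch.rand(2, 3, 256, 256, device=device, requires_grad=True)
    target = torch.rand(2, 3, 256, 256, device=device)

    loss = criterion(pred, target)
    loss.backward()  # gradients flow into pred, not into the frozen VGG
    print(f"perceptual loss: {loss.item():.4f}")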