117 lines
4.2 KiB
Python
117 lines
4.2 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from torchvision import models, transforms
|
||
|
||
class VGGPerceptualLoss(nn.Module):
    """VGG19-based perceptual loss.

    Both images are passed through the convolutional trunk (``features``) of an
    ImageNet-pretrained VGG19. The loss is the weighted sum of the mean-squared
    errors between their activations at the selected layer indices.

    Args:
        feature_layers: Indices into ``vgg19().features`` whose outputs are
            compared. Sorted internally; must be non-empty, unique,
            non-negative and within the network's depth.
        use_normalization: If True, inputs are normalized with the ImageNet
            mean/std after being clamped to ``[0, 1]``.
        weights: Optional per-layer loss weights, paired positionally with
            ``feature_layers`` as given. Defaults to 1.0 for every layer.

    Raises:
        ValueError: On an empty, duplicated, negative or out-of-range
            ``feature_layers`` entry, or a ``weights`` length mismatch.
    """

    def __init__(self, feature_layers=(2, 7, 12, 21, 30), use_normalization=True, weights=None):
        super(VGGPerceptualLoss, self).__init__()

        # Validate the configuration before loading the pretrained VGG.
        layers = [int(idx) for idx in feature_layers]
        if not layers:
            raise ValueError("feature_layers must contain at least one layer index")
        if len(set(layers)) != len(layers):
            raise ValueError("feature_layers must not contain duplicate indices")
        if min(layers) < 0:
            raise ValueError("feature_layers indices must be non-negative")

        if weights is None:
            layer_weights = [1.0] * len(layers)
        else:
            layer_weights = [float(w) for w in weights]
            if len(layer_weights) != len(layers):
                raise ValueError(
                    "expected {} weights (one per feature layer), got {}".format(
                        len(layers), len(layer_weights)))

        # Activations are captured in forward order, so keep the layers sorted.
        # Each weight stays attached to the layer it was given for.
        pairs = sorted(zip(layers, layer_weights))
        self.feature_layers = [idx for idx, _ in pairs]
        self.weights = [w for _, w in pairs]

        vgg = self._load_pretrained_vgg19_features()
        if self.feature_layers[-1] >= len(vgg):
            raise ValueError(
                "feature_layers index {} is out of range; VGG19 features has {} layers".format(
                    self.feature_layers[-1], len(vgg)))

        # Freeze VGG: it is used only as a fixed feature extractor.
        for param in vgg.parameters():
            param.requires_grad = False

        # Kept for backward compatibility only. The layers follow the input's
        # device at call time (see _move_layers_to).
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Keep only the layers up to and including the deepest selected one.
        self.vgg_layers = nn.ModuleList(vgg[: self.feature_layers[-1] + 1])
        # torchvision builds VGG with ReLU(inplace=True). An in-place ReLU right
        # after a selected layer would overwrite the activation already captured
        # for the loss. Use out-of-place ReLUs so each captured tensor really is
        # that layer's output.
        for layer in self.vgg_layers:
            if isinstance(layer, nn.ReLU):
                layer.inplace = False

        self.use_normalization = use_normalization
        if use_normalization:
            # ImageNet channel statistics expected by the pretrained weights.
            self.normalize = transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )

    @staticmethod
    def _load_pretrained_vgg19_features():
        """Return the ``features`` trunk of an ImageNet-pretrained VGG19."""
        weights_enum = getattr(models, 'VGG19_Weights', None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`.
            # IMAGENET1K_V1 is the checkpoint that `pretrained=True` selected.
            return models.vgg19(weights=weights_enum.IMAGENET1K_V1).features
        return models.vgg19(pretrained=True).features

    def _move_layers_to(self, device):
        """Move the VGG layers to ``device`` once, only if they are elsewhere."""
        first_param = next(self.vgg_layers.parameters(), None)
        if first_param is not None and first_param.device != device:
            # Module.to() moves parameters in place; no need to rebuild the list.
            self.vgg_layers.to(device)

    def extract_features(self, x):
        """Return the activations of ``x`` at each index in ``self.feature_layers``.

        ``x`` is clamped to ``[0, 1]`` and, if enabled, ImageNet-normalized
        first. Activations are returned in ascending layer order, matching
        ``self.weights``.
        """
        self._move_layers_to(x.device)

        # Clamping is a no-op for in-range values. Doing it unconditionally
        # matches the old conditional clamp while avoiding the device-to-host
        # syncs that x.min()/x.max() comparisons force.
        x = torch.clamp(x, 0, 1)

        if self.use_normalization:
            x = self.normalize(x)

        wanted = set(self.feature_layers)
        features = []
        for idx, layer in enumerate(self.vgg_layers):
            x = layer(x)
            if idx in wanted:
                features.append(x)
        return features

    def forward(self, pred, target):
        """Compute the perceptual loss between ``pred`` and ``target``.

        Args:
            pred: Predicted images, [B, C, H, W]. It is bilinearly resized to
                the target's spatial size if they differ.
            target: Target images, [B, C, H, W].

        Single-channel inputs are replicated to 3 channels for VGG.

        Returns:
            Scalar tensor: sum over selected layers of weight * MSE(features).
        """
        # Only the spatial size needs to match; channels are handled below.
        if pred.shape[-2:] != target.shape[-2:]:
            pred = F.interpolate(pred, size=target.shape[-2:], mode='bilinear', align_corners=False)

        # VGG expects 3 input channels.
        if pred.shape[1] == 1:
            pred = pred.repeat(1, 3, 1, 1)
        if target.shape[1] == 1:
            target = target.repeat(1, 3, 1, 1)

        pred_features = self.extract_features(pred)
        target_features = self.extract_features(target)

        return sum(
            weight * F.mse_loss(pred_feat, target_feat)
            for weight, pred_feat, target_feat in zip(self.weights, pred_features, target_features)
        )
|