Files
Stomach_Cancer_Pytorch/Model_Loss/binary_cross_entropy.py

146 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import torch.nn as nn
import torch.nn.functional as F
class BinaryCrossEntropy(nn.Module):
"""
基本的二元交叉熵損失函數
"""
def __init__(self, reduction='mean'):
"""
初始化
Args:
reduction (str): 'mean', 'sum''none',指定如何減少損失
"""
super(BinaryCrossEntropy, self).__init__()
def forward(self, predictions, targets):
"""
計算二元交叉熵損失
Args:
predictions (torch.Tensor): 模型的預測輸出,形狀為 [batch_size, ...]
targets (torch.Tensor): 目標標籤,形狀與 predictions 相同
Returns:
torch.Tensor: 計算得到的損失值
"""
# 確保輸入是張量
predictions = torch.as_tensor(predictions, dtype=torch.float32)
targets = torch.as_tensor(targets, dtype=torch.float32)
Loss = nn.BCELoss()
return Loss(predictions, targets)
# # 檢查輸出和標籤的維度是否匹配
# if predictions.shape[1] != targets.shape[1]:
# # 如果維度不匹配,使用交叉熵損失函數
# # 對於交叉熵損失標籤需要是類別索引而不是one-hot編碼
# # 將one-hot編碼轉換為類別索引
# _, targets_indices = torch.max(targets, dim=1)
# return F.cross_entropy(predictions, targets_indices, reduction=self.reduction)
# else:
# # 如果維度匹配,使用二元交叉熵損失函數
# # 使用 PyTorch 內建的 binary_cross_entropy_with_logits 函數
# # 它會自動應用 sigmoid 函數,避免輸入值超出 [0,1] 範圍
# return F.binary_cross_entropy_with_logits(predictions, targets, reduction=self.reduction)
class WeightedBinaryCrossEntropy(nn.Module):
"""
帶權重的二元交叉熵損失函數
"""
def __init__(self, pos_weight=1.0, neg_weight=1.0, reduction='mean'):
"""
初始化
Args:
pos_weight (float): 正樣本的權重
neg_weight (float): 負樣本的權重
reduction (str): 'mean', 'sum''none',指定如何減少損失
"""
super(WeightedBinaryCrossEntropy, self).__init__()
self.pos_weight = pos_weight
self.neg_weight = neg_weight
self.reduction = reduction
def forward(self, predictions, targets):
"""
計算帶權重的二元交叉熵損失
Args:
predictions (torch.Tensor): 模型的預測輸出,形狀為 [batch_size, ...]
targets (torch.Tensor): 目標標籤,形狀與 predictions 相同
Returns:
torch.Tensor: 計算得到的損失值
"""
# 確保輸入是張量
predictions = torch.as_tensor(predictions, dtype=torch.float32)
targets = torch.as_tensor(targets, dtype=torch.float32)
# 使用 sigmoid 確保預測值在 [0,1] 範圍內
predictions = torch.sigmoid(predictions)
# 計算帶權重的二元交叉熵損失
loss = -self.pos_weight * targets * torch.log(predictions + 1e-7) - \
self.neg_weight * (1 - targets) * torch.log(1 - predictions + 1e-7)
# 根據 reduction 方式返回損失
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
else: # 'none'
return loss
class LabelSmoothingBCE(nn.Module):
"""
帶標籤平滑的二元交叉熵損失函數
"""
def __init__(self, smoothing=0.1, reduction='mean'):
"""
初始化
Args:
smoothing (float): 標籤平滑係數,範圍 [0, 1]
reduction (str): 'mean', 'sum''none',指定如何減少損失
"""
super(LabelSmoothingBCE, self).__init__()
self.smoothing = smoothing
self.reduction = reduction
def forward(self, predictions, targets):
"""
計算帶標籤平滑的二元交叉熵損失
Args:
predictions (torch.Tensor): 模型的預測輸出,形狀為 [batch_size, ...]
targets (torch.Tensor): 目標標籤,形狀與 predictions 相同
Returns:
torch.Tensor: 計算得到的損失值
"""
# 確保輸入是張量
predictions = torch.as_tensor(predictions, dtype=torch.float32)
targets = torch.as_tensor(targets, dtype=torch.float32)
# 應用標籤平滑
targets = targets * (1 - self.smoothing) + 0.5 * self.smoothing
# 使用 sigmoid 確保預測值在 [0,1] 範圍內
predictions = torch.sigmoid(predictions)
# 計算二元交叉熵損失
loss = -targets * torch.log(predictions + 1e-7) - (1 - targets) * torch.log(1 - predictions + 1e-7)
# 根據 reduction 方式返回損失
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
else: # 'none'
return loss