Stomach_Cancer_Pytorch/Model_Loss/Loss.py

27 lines
1.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from torch import nn
from torch.nn import functional
import torch
class Entropy_Loss(nn.Module):
def __init__(self):
super(Entropy_Loss, self).__init__()
def forward(self, outputs, labels):
# 转换为张量
outputs_New = torch.as_tensor(outputs, dtype=torch.float32)
labels_New = torch.as_tensor(labels, dtype=torch.float32)
# 检查输出和标签的维度是否匹配
if outputs_New.shape[1] != labels_New.shape[1]:
# 如果维度不匹配,使用交叉熵损失函数
# 对于交叉熵损失标签需要是类别索引而不是one-hot编码
# 将one-hot编码转换为类别索引
_, labels_indices = torch.max(labels_New, dim=1)
loss = functional.cross_entropy(outputs_New, labels_indices)
else:
# 如果维度匹配始终使用binary_cross_entropy_with_logits
# 它会自动应用sigmoid函数避免输入值超出[0,1]范围
loss = functional.binary_cross_entropy_with_logits(outputs_New, labels_New)
return torch.as_tensor(loss, dtype = torch.float32)