from torch import nn
from torch.nn import functional
import torch


class Entropy_Loss(nn.Module):
    def __init__(self):
        super(Entropy_Loss, self).__init__()

    def forward(self, outputs, labels):
        # Convert inputs to float32 tensors
        outputs_New = torch.as_tensor(outputs, dtype=torch.float32)
        labels_New = torch.as_tensor(labels, dtype=torch.float32)

        # Check whether the class dimensions of outputs and labels match
        if outputs_New.shape[1] != labels_New.shape[1]:
            # Dimensions differ: fall back to cross-entropy loss.
            # cross_entropy expects class indices, not one-hot labels,
            # so reduce the one-hot encoding to indices via argmax.
            _, labels_indices = torch.max(labels_New, dim=1)
            loss = functional.cross_entropy(outputs_New, labels_indices)
        else:
            # Dimensions match: always use binary_cross_entropy_with_logits,
            # which applies sigmoid internally and therefore accepts raw
            # logits outside the [0, 1] range.
            loss = functional.binary_cross_entropy_with_logits(outputs_New, labels_New)

        return torch.as_tensor(loss, dtype=torch.float32)
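

# --- Usage sketch (illustrative only; the batch size, class counts, and label
# values below are assumptions, not part of the original module). It shows how
# each branch of Entropy_Loss.forward is reached: matching shapes take the
# binary_cross_entropy_with_logits path, mismatched shapes take cross_entropy.
if __name__ == "__main__":
    criterion = Entropy_Loss()

    # Matching shapes (4 samples, 3 classes): logits vs. one-hot labels,
    # so the binary_cross_entropy_with_logits branch runs.
    logits = torch.randn(4, 3)
    one_hot = functional.one_hot(torch.tensor([0, 2, 1, 0]), num_classes=3).float()
    print("BCE-with-logits path:", criterion(logits, one_hot).item())

    # Labels padded to 5 columns no longer match the 3-class logits, so the
    # cross_entropy branch runs after reducing the labels to class indices.
    padded = functional.one_hot(torch.tensor([0, 2, 1, 0]), num_classes=5).float()
    print("cross_entropy path:", criterion(logits, padded).item())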