27 lines
1.2 KiB
Python
27 lines
1.2 KiB
Python
from torch import nn
|
||
from torch.nn import functional
|
||
import torch
|
||
|
||
|
||
class Entropy_Loss(nn.Module):
|
||
def __init__(self):
|
||
super(Entropy_Loss, self).__init__()
|
||
|
||
def forward(self, outputs, labels):
|
||
# 转换为张量
|
||
outputs_New = torch.as_tensor(outputs, dtype=torch.float32)
|
||
labels_New = torch.as_tensor(labels, dtype=torch.float32)
|
||
|
||
# 检查输出和标签的维度是否匹配
|
||
if outputs_New.shape[1] != labels_New.shape[1]:
|
||
# 如果维度不匹配,使用交叉熵损失函数
|
||
# 对于交叉熵损失,标签需要是类别索引而不是one-hot编码
|
||
# 将one-hot编码转换为类别索引
|
||
_, labels_indices = torch.max(labels_New, dim=1)
|
||
loss = functional.cross_entropy(outputs_New, labels_indices)
|
||
else:
|
||
# 如果维度匹配,始终使用binary_cross_entropy_with_logits
|
||
# 它会自动应用sigmoid函数,避免输入值超出[0,1]范围
|
||
loss = functional.binary_cross_entropy_with_logits(outputs_New, labels_New)
|
||
|
||
return torch.as_tensor(loss, dtype = torch.float32) |