19 lines
621 B
Python
19 lines
621 B
Python
from torch import nn
|
|
from torch.nn import functional
|
|
import torch
|
|
import numpy as np
|
|
|
|
|
|
class Entropy_Loss(nn.Module):
    """Cross-entropy loss that accepts tensors, NumPy arrays or nested lists.

    ``forward`` converts both arguments to tensors and delegates to
    ``torch.nn.functional.cross_entropy`` with its default (mean) reduction.
    The target is interpreted according to its shape:

    * if ``labels`` has exactly the same shape as ``outputs`` it is treated
      as per-class probabilities (e.g. one-hot or soft labels) and cast to
      ``float32``;
    * otherwise it is treated as class indices and cast to ``int64``
      (``long``), the dtype ``cross_entropy`` requires for index targets.
    """

    def __init__(self):
        super().__init__()

    def forward(self, outputs, labels):
        """Return the mean cross-entropy between ``outputs`` and ``labels``.

        Args:
            outputs: Unnormalized logits (anything ``torch.as_tensor``
                accepts); converted to ``float32``. When a tensor is passed,
                the conversion is differentiable, so gradients flow back.
            labels: Class probabilities (same shape as ``outputs``) or class
                indices (any other shape). Float-typed indices are truncated
                to integers by the ``long`` cast.

        Returns:
            A scalar ``float32`` tensor on the same device as the converted
            ``outputs``.
        """
        outputs_new = torch.as_tensor(outputs, dtype=torch.float32)
        # Move labels to the logits' device so list / NumPy / CPU labels
        # can be combined with logits living on another device.
        labels_new = torch.as_tensor(labels, device=outputs_new.device)

        if labels_new.shape == outputs_new.shape:
            # Probability targets: cross_entropy requires a floating dtype.
            labels_new = labels_new.to(torch.float32)
        else:
            # Class-index targets: cross_entropy rejects anything but int64.
            labels_new = labels_new.long()

        # Inputs are float32, so the loss is already a float32 scalar.
        return functional.cross_entropy(outputs_new, labels_new)
|