68 lines
2.5 KiB
Python
68 lines
2.5 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import sys
|
|
import os
|
|
|
|
# 添加當前目錄到系統路徑
|
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
# 導入自定義的 BinaryCrossEntropy 類
|
|
from Model_Loss.binary_cross_entropy import BinaryCrossEntropy
|
|
|
|
def test_binary_cross_entropy():
|
|
print("開始測試 BinaryCrossEntropy 類...")
|
|
|
|
# 創建損失函數實例
|
|
criterion = BinaryCrossEntropy()
|
|
|
|
# 測試 1: 維度匹配的情況 (二元分類)
|
|
print("\n測試 1: 維度匹配的情況 (二元分類)")
|
|
predictions_match = torch.tensor([[0.7, 0.3], [0.2, 0.8]], dtype=torch.float32)
|
|
targets_match = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32)
|
|
|
|
try:
|
|
loss_match = criterion(predictions_match, targets_match)
|
|
print(f"維度匹配的損失值: {loss_match.item()}")
|
|
print("測試 1 通過!")
|
|
except Exception as e:
|
|
print(f"測試 1 失敗: {e}")
|
|
|
|
# 測試 2: 維度不匹配的情況 (多類別分類)
|
|
print("\n測試 2: 維度不匹配的情況 (多類別分類)")
|
|
predictions_mismatch = torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]], dtype=torch.float32)
|
|
targets_mismatch = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32)
|
|
|
|
try:
|
|
loss_mismatch = criterion(predictions_mismatch, targets_mismatch)
|
|
print(f"維度不匹配的損失值: {loss_mismatch.item()}")
|
|
print("測試 2 通過!")
|
|
except Exception as e:
|
|
print(f"測試 2 失敗: {e}")
|
|
|
|
# 測試 3: 與 PyTorch 內建函數比較
|
|
print("\n測試 3: 與 PyTorch 內建函數比較")
|
|
predictions_compare = torch.tensor([[0.7, 0.3], [0.2, 0.8]], dtype=torch.float32)
|
|
targets_compare = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32)
|
|
|
|
try:
|
|
# 自定義損失函數
|
|
loss_custom = criterion(predictions_compare, targets_compare)
|
|
|
|
# PyTorch 內建損失函數
|
|
loss_pytorch = nn.BCEWithLogitsLoss()(predictions_compare, targets_compare)
|
|
|
|
print(f"自定義損失值: {loss_custom.item()}")
|
|
print(f"PyTorch 損失值: {loss_pytorch.item()}")
|
|
print(f"差異: {abs(loss_custom.item() - loss_pytorch.item())}")
|
|
|
|
if abs(loss_custom.item() - loss_pytorch.item()) < 1e-5:
|
|
print("測試 3 通過!")
|
|
else:
|
|
print("測試 3 失敗: 損失值與 PyTorch 內建函數不一致")
|
|
except Exception as e:
|
|
print(f"測試 3 失敗: {e}")
|
|
|
|
print("\n測試完成!")
|
|
|
|
if __name__ == "__main__":
|
|
test_binary_cross_entropy() |