Files
Stomach_Cancer_Pytorch/test_binary_cross_entropy.py

68 lines
2.5 KiB
Python

import torch
import torch.nn as nn
import sys
import os
# 添加當前目錄到系統路徑
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# 導入自定義的 BinaryCrossEntropy 類
from Model_Loss.binary_cross_entropy import BinaryCrossEntropy
def test_binary_cross_entropy():
print("開始測試 BinaryCrossEntropy 類...")
# 創建損失函數實例
criterion = BinaryCrossEntropy()
# 測試 1: 維度匹配的情況 (二元分類)
print("\n測試 1: 維度匹配的情況 (二元分類)")
predictions_match = torch.tensor([[0.7, 0.3], [0.2, 0.8]], dtype=torch.float32)
targets_match = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32)
try:
loss_match = criterion(predictions_match, targets_match)
print(f"維度匹配的損失值: {loss_match.item()}")
print("測試 1 通過!")
except Exception as e:
print(f"測試 1 失敗: {e}")
# 測試 2: 維度不匹配的情況 (多類別分類)
print("\n測試 2: 維度不匹配的情況 (多類別分類)")
predictions_mismatch = torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]], dtype=torch.float32)
targets_mismatch = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32)
try:
loss_mismatch = criterion(predictions_mismatch, targets_mismatch)
print(f"維度不匹配的損失值: {loss_mismatch.item()}")
print("測試 2 通過!")
except Exception as e:
print(f"測試 2 失敗: {e}")
# 測試 3: 與 PyTorch 內建函數比較
print("\n測試 3: 與 PyTorch 內建函數比較")
predictions_compare = torch.tensor([[0.7, 0.3], [0.2, 0.8]], dtype=torch.float32)
targets_compare = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32)
try:
# 自定義損失函數
loss_custom = criterion(predictions_compare, targets_compare)
# PyTorch 內建損失函數
loss_pytorch = nn.BCEWithLogitsLoss()(predictions_compare, targets_compare)
print(f"自定義損失值: {loss_custom.item()}")
print(f"PyTorch 損失值: {loss_pytorch.item()}")
print(f"差異: {abs(loss_custom.item() - loss_pytorch.item())}")
if abs(loss_custom.item() - loss_pytorch.item()) < 1e-5:
print("測試 3 通過!")
else:
print("測試 3 失敗: 損失值與 PyTorch 內建函數不一致")
except Exception as e:
print(f"測試 3 失敗: {e}")
print("\n測試完成!")
if __name__ == "__main__":
test_binary_cross_entropy()