import torch import torch.nn as nn import sys import os # 添加當前目錄到系統路徑 sys.path.append(os.path.dirname(os.path.abspath(__file__))) # 導入自定義的 BinaryCrossEntropy 類 from Model_Loss.binary_cross_entropy import BinaryCrossEntropy def test_binary_cross_entropy(): print("開始測試 BinaryCrossEntropy 類...") # 創建損失函數實例 criterion = BinaryCrossEntropy() # 測試 1: 維度匹配的情況 (二元分類) print("\n測試 1: 維度匹配的情況 (二元分類)") predictions_match = torch.tensor([[0.7, 0.3], [0.2, 0.8]], dtype=torch.float32) targets_match = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32) try: loss_match = criterion(predictions_match, targets_match) print(f"維度匹配的損失值: {loss_match.item()}") print("測試 1 通過!") except Exception as e: print(f"測試 1 失敗: {e}") # 測試 2: 維度不匹配的情況 (多類別分類) print("\n測試 2: 維度不匹配的情況 (多類別分類)") predictions_mismatch = torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]], dtype=torch.float32) targets_mismatch = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32) try: loss_mismatch = criterion(predictions_mismatch, targets_mismatch) print(f"維度不匹配的損失值: {loss_mismatch.item()}") print("測試 2 通過!") except Exception as e: print(f"測試 2 失敗: {e}") # 測試 3: 與 PyTorch 內建函數比較 print("\n測試 3: 與 PyTorch 內建函數比較") predictions_compare = torch.tensor([[0.7, 0.3], [0.2, 0.8]], dtype=torch.float32) targets_compare = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32) try: # 自定義損失函數 loss_custom = criterion(predictions_compare, targets_compare) # PyTorch 內建損失函數 loss_pytorch = nn.BCEWithLogitsLoss()(predictions_compare, targets_compare) print(f"自定義損失值: {loss_custom.item()}") print(f"PyTorch 損失值: {loss_pytorch.item()}") print(f"差異: {abs(loss_custom.item() - loss_pytorch.item())}") if abs(loss_custom.item() - loss_pytorch.item()) < 1e-5: print("測試 3 通過!") else: print("測試 3 失敗: 損失值與 PyTorch 內建函數不一致") except Exception as e: print(f"測試 3 失敗: {e}") print("\n測試完成!") if __name__ == "__main__": test_binary_cross_entropy()