Files
Stomach_Cancer_Pytorch/test_loss_fix.py

166 lines
5.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
測試腳本:驗證損失函數修復
檢查數據類型處理和交叉熵損失計算
"""
import torch
import numpy as np
import sys
import os
# 添加項目路徑
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from Model_Loss.Loss import Entropy_Loss
def test_loss_function():
"""測試損失函數的數據類型處理"""
print("🧪 測試損失函數...")
# 創建測試數據
batch_size = 4
num_classes = 3
# 模擬模型輸出logits
outputs = torch.randn(batch_size, num_classes, requires_grad=True)
print(f"✅ 模型輸出形狀: {outputs.shape}, 類型: {outputs.dtype}, requires_grad: {outputs.requires_grad}")
# 測試案例1one-hot編碼標籤
labels_onehot = torch.zeros(batch_size, num_classes)
labels_onehot[0, 0] = 1 # 類別0
labels_onehot[1, 1] = 1 # 類別1
labels_onehot[2, 2] = 1 # 類別2
labels_onehot[3, 0] = 1 # 類別0
print(f"✅ One-hot標籤形狀: {labels_onehot.shape}, 類型: {labels_onehot.dtype}")
# 測試案例2類別索引標籤
labels_indices = torch.tensor([0, 1, 2, 0], dtype=torch.long)
print(f"✅ 索引標籤形狀: {labels_indices.shape}, 類型: {labels_indices.dtype}")
# 測試案例3numpy數組標籤
labels_numpy = np.array([0, 1, 2, 0])
print(f"✅ Numpy標籤形狀: {labels_numpy.shape}, 類型: {labels_numpy.dtype}")
# 創建損失函數
criterion = Entropy_Loss()
try:
# 測試one-hot編碼標籤
print("\n📊 測試one-hot編碼標籤...")
loss1 = criterion(outputs, labels_onehot)
print(f"✅ One-hot標籤損失: {loss1.item():.4f}, requires_grad: {loss1.requires_grad}")
# 測試梯度計算
loss1.backward(retain_graph=True)
print(f"✅ 梯度計算成功,輸出梯度範數: {outputs.grad.norm().item():.4f}")
outputs.grad.zero_() # 清零梯度
# 測試類別索引標籤
print("\n📊 測試類別索引標籤...")
loss2 = criterion(outputs, labels_indices)
print(f"✅ 索引標籤損失: {loss2.item():.4f}, requires_grad: {loss2.requires_grad}")
# 測試梯度計算
loss2.backward(retain_graph=True)
print(f"✅ 梯度計算成功,輸出梯度範數: {outputs.grad.norm().item():.4f}")
outputs.grad.zero_() # 清零梯度
# 測試numpy數組標籤
print("\n📊 測試numpy數組標籤...")
loss3 = criterion(outputs, labels_numpy)
print(f"✅ Numpy標籤損失: {loss3.item():.4f}, requires_grad: {loss3.requires_grad}")
# 測試梯度計算
loss3.backward()
print(f"✅ 梯度計算成功,輸出梯度範數: {outputs.grad.norm().item():.4f}")
print("\n🎉 所有損失函數測試通過!")
return True
except Exception as e:
print(f"❌ 損失函數測試失敗: {e}")
return False
def test_cuda_compatibility():
"""測試CUDA兼容性如果可用"""
if not torch.cuda.is_available():
print("⚠️ CUDA不可用跳過CUDA測試")
return True
print("\n🚀 測試CUDA兼容性...")
try:
device = torch.device('cuda')
# 創建CUDA張量
outputs = torch.randn(2, 3, device=device, requires_grad=True)
labels = torch.tensor([0, 1], device=device, dtype=torch.long)
criterion = Entropy_Loss()
loss = criterion(outputs, labels)
print(f"✅ CUDA損失計算: {loss.item():.4f}")
# 測試梯度
loss.backward()
print(f"✅ CUDA梯度計算成功梯度範數: {outputs.grad.norm().item():.4f}")
print("🎉 CUDA測試通過")
return True
except Exception as e:
print(f"❌ CUDA測試失敗: {e}")
return False
def test_edge_cases():
"""測試邊界情況"""
print("\n🔍 測試邊界情況...")
criterion = Entropy_Loss()
try:
# 測試單個樣本
outputs_single = torch.randn(1, 3, requires_grad=True)
labels_single = torch.tensor([1], dtype=torch.long)
loss = criterion(outputs_single, labels_single)
print(f"✅ 單樣本測試: {loss.item():.4f}")
loss.backward()
print(f"✅ 單樣本梯度計算成功,梯度範數: {outputs_single.grad.norm().item():.4f}")
# 測試大批次
outputs_large = torch.randn(100, 5, requires_grad=True)
labels_large = torch.randint(0, 5, (100,), dtype=torch.long)
loss = criterion(outputs_large, labels_large)
print(f"✅ 大批次測試: {loss.item():.4f}")
print("🎉 邊界情況測試通過!")
return True
except Exception as e:
print(f"❌ 邊界情況測試失敗: {e}")
return False
if __name__ == "__main__":
print("🔧 開始損失函數修復驗證...")
success = True
success &= test_loss_function()
success &= test_cuda_compatibility()
success &= test_edge_cases()
if success:
print("\n🎉 所有測試通過!損失函數修復成功。")
print("✅ RuntimeError: Expected floating point type for target with class probabilities, got Long 已修復")
else:
print("\n❌ 部分測試失敗,需要進一步檢查。")
print("\n📋 修復摘要:")
print("1. ✅ 修復了 Loss.py 中的 cross_entropy 調用")
print("2. ✅ 正確處理 one-hot 編碼和類別索引標籤")
print("3. ✅ 修復了 Model_Branch 中的參數傳遞")
print("4. ✅ 確保了梯度計算的連續性")