#!/usr/bin/env python3
"""
Test script: verify the loss-function fix.

Checks how ``Model_Loss.Loss.Entropy_Loss`` handles different label data
types (one-hot, class indices, NumPy arrays) and its cross-entropy loss /
gradient computation.
"""

import os
import sys

import numpy as np
import torch

# Add the directory containing this script to the import path so the project
# package ``Model_Loss`` (assumed to live next to this script) can be imported.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from Model_Loss.Loss import Entropy_Loss  # noqa: E402  (needs sys.path setup above)


def _check_loss_case(criterion, outputs, labels, header, loss_label):
    """Run one label-encoding case: forward, backward, and report.

    Prints ``header``, computes ``criterion(outputs, labels)``, back-propagates
    it into ``outputs.grad`` and then zeroes that gradient so the next case
    starts from a clean slate.

    Args:
        criterion: The loss callable under test.
        outputs: Leaf logits tensor created with ``requires_grad=True``.
        labels: Target labels in the encoding being exercised.
        header: Section title printed before the case runs.
        loss_label: Prefix used when printing the loss value.

    Returns:
        float: The scalar loss value.
    """
    print(header)
    loss = criterion(outputs, labels)
    print(f"✅ {loss_label}: {loss.item():.4f}, requires_grad: {loss.requires_grad}")

    # Each criterion call builds a fresh graph from the leaf `outputs`, so the
    # graph does not need to be retained after this backward pass.
    loss.backward()
    print(f"✅ 梯度計算成功,輸出梯度範數: {outputs.grad.norm().item():.4f}")
    outputs.grad.zero_()  # reset so gradients do not accumulate across cases
    return loss.item()


def test_loss_function():
    """Exercise Entropy_Loss with one-hot, class-index and NumPy labels.

    Builds a random ``(4, 3)`` logits tensor that requires grad and evaluates
    the criterion (forward + backward) against three encodings of the same
    targets ``[0, 1, 2, 0]``. Because the encodings describe identical
    targets, the three resulting losses are also checked for agreement.

    Returns:
        bool: ``True`` if every case runs and the losses agree, ``False``
        otherwise (the failure, and its traceback if an exception was
        raised, is printed).
    """
    import math

    print("🧪 測試損失函數...")

    batch_size = 4
    num_classes = 3

    # Simulated model output (logits); a leaf tensor so `.grad` gets populated.
    outputs = torch.randn(batch_size, num_classes, requires_grad=True)
    print(f"✅ 模型輸出形狀: {outputs.shape}, 類型: {outputs.dtype}, requires_grad: {outputs.requires_grad}")

    # Case 1: one-hot encoded float labels for targets 0, 1, 2, 0.
    labels_onehot = torch.zeros(batch_size, num_classes)
    labels_onehot[0, 0] = 1  # class 0
    labels_onehot[1, 1] = 1  # class 1
    labels_onehot[2, 2] = 1  # class 2
    labels_onehot[3, 0] = 1  # class 0
    print(f"✅ One-hot標籤形狀: {labels_onehot.shape}, 類型: {labels_onehot.dtype}")

    # Case 2: int64 class-index labels for the same targets.
    labels_indices = torch.tensor([0, 1, 2, 0], dtype=torch.long)
    print(f"✅ 索引標籤形狀: {labels_indices.shape}, 類型: {labels_indices.dtype}")

    # Case 3: NumPy integer array for the same targets.
    labels_numpy = np.array([0, 1, 2, 0])
    print(f"✅ Numpy標籤形狀: {labels_numpy.shape}, 類型: {labels_numpy.dtype}")

    criterion = Entropy_Loss()

    cases = [
        ("\n📊 測試one-hot編碼標籤...", "One-hot標籤損失", labels_onehot),
        ("\n📊 測試類別索引標籤...", "索引標籤損失", labels_indices),
        ("\n📊 測試numpy數組標籤...", "Numpy標籤損失", labels_numpy),
    ]

    try:
        losses = []
        for header, loss_label, labels in cases:
            losses.append(_check_loss_case(criterion, outputs, labels, header, loss_label))

        # All three encodings describe the same targets, so a correct fix must
        # produce (numerically) the same loss for each of them.
        reference = losses[0]
        if not all(math.isclose(value, reference, rel_tol=1e-5, abs_tol=1e-6) for value in losses[1:]):
            values = ", ".join(f"{value:.4f}" for value in losses)
            print(f"❌ 不同標籤格式的損失不一致: {values}")
            return False

        print("\n🎉 所有損失函數測試通過!")
        return True

    except Exception as e:
        import traceback

        print(f"❌ 損失函數測試失敗: {e}")
        traceback.print_exc()  # keep the stack trace for debugging
        return False


def test_cuda_compatibility():
    """Run an Entropy_Loss forward/backward pass on CUDA tensors, if available.

    Returns:
        bool: ``True`` when CUDA is unavailable (the test is skipped) or the
        CUDA forward/backward pass succeeds; ``False`` if it raises (the error
        and its traceback are printed).
    """
    if not torch.cuda.is_available():
        print("⚠️ CUDA不可用,跳過CUDA測試")
        return True

    print("\n🚀 測試CUDA兼容性...")

    try:
        device = torch.device('cuda')

        # Logits and class-index labels created directly on the GPU.
        outputs = torch.randn(2, 3, device=device, requires_grad=True)
        labels = torch.tensor([0, 1], device=device, dtype=torch.long)

        criterion = Entropy_Loss()
        loss = criterion(outputs, labels)

        print(f"✅ CUDA損失計算: {loss.item():.4f}")

        loss.backward()
        print(f"✅ CUDA梯度計算成功,梯度範數: {outputs.grad.norm().item():.4f}")

        print("🎉 CUDA測試通過!")
        return True

    except Exception as e:
        import traceback

        print(f"❌ CUDA測試失敗: {e}")
        traceback.print_exc()  # keep the stack trace for debugging
        return False


def test_edge_cases():
    """Check Entropy_Loss on a single-sample batch and on a larger batch.

    Both cases use int64 class-index labels and run a full forward and
    backward pass.

    Returns:
        bool: ``True`` if both cases succeed, ``False`` otherwise (the error
        and its traceback are printed).
    """
    print("\n🔍 測試邊界情況...")

    criterion = Entropy_Loss()

    try:
        # Batch containing a single sample.
        outputs_single = torch.randn(1, 3, requires_grad=True)
        labels_single = torch.tensor([1], dtype=torch.long)

        loss = criterion(outputs_single, labels_single)
        print(f"✅ 單樣本測試: {loss.item():.4f}")

        loss.backward()
        print(f"✅ 單樣本梯度計算成功,梯度範數: {outputs_single.grad.norm().item():.4f}")

        # Larger batch: 100 samples, 5 classes, random targets.
        outputs_large = torch.randn(100, 5, requires_grad=True)
        labels_large = torch.randint(0, 5, (100,), dtype=torch.long)

        loss = criterion(outputs_large, labels_large)
        print(f"✅ 大批次測試: {loss.item():.4f}")

        # Back-propagate here too so the large batch is verified as thoroughly
        # as the single-sample case.
        loss.backward()
        print(f"✅ 大批次梯度計算成功,梯度範數: {outputs_large.grad.norm().item():.4f}")

        print("🎉 邊界情況測試通過!")
        return True

    except Exception as e:
        import traceback

        print(f"❌ 邊界情況測試失敗: {e}")
        traceback.print_exc()  # keep the stack trace for debugging
        return False


if __name__ == "__main__":
    print("🔧 開始損失函數修復驗證...")

    # `&=` still evaluates every test, so all failures are reported even when
    # an earlier test fails.
    success = True
    success &= test_loss_function()
    success &= test_cuda_compatibility()
    success &= test_edge_cases()

    if success:
        print("\n🎉 所有測試通過!損失函數修復成功。")
        print("✅ RuntimeError: Expected floating point type for target with class probabilities, got Long 已修復")
    else:
        print("\n❌ 部分測試失敗,需要進一步檢查。")

    # Summary of the code changes this script is meant to verify.
    print("\n📋 修復摘要:")
    print("1. ✅ 修復了 Loss.py 中的 cross_entropy 調用")
    print("2. ✅ 正確處理 one-hot 編碼和類別索引標籤")
    print("3. ✅ 修復了 Model_Branch 中的參數傳遞")
    print("4. ✅ 確保了梯度計算的連續性")

    # Report the result through the exit status so shells / CI can detect a
    # failed verification (previously the script always exited with 0).
    sys.exit(0 if success else 1)