# 115 lines · 4.0 KiB · Python
import torch
|
||
import numpy as np
|
||
|
||
def validate_data_for_training(outputs, labels, num_classes):
    """Check whether training data is likely to trigger a CUDA device-side assert.

    Prints a diagnostic report (shapes, dtypes, devices, label index range)
    and stops at the first problem it finds.

    Args:
        outputs: Model output tensor. Expected to be batched with the class
            scores on dim 1, i.e. shape (batch, num_classes, ...).
        labels: Label tensor. Any tensor with ``dim() > 1`` and
            ``shape[1] > 1`` is treated as one-hot and reduced to class
            indices with an argmax over dim 1. Anything else is treated as
            class indices and cast to ``long``.
        num_classes: Expected number of classes.

    Returns:
        bool: True if the data looks valid. False if it is likely to cause
        an error: NaN/Inf values, empty labels, label indices outside
        ``[0, num_classes)``, outputs without a class dimension, a class-count
        mismatch, or a batch-size mismatch between outputs and labels.
    """
    print("=== 數據驗證開始 ===")

    # Basic information
    print(f"輸出形狀: {outputs.shape}")
    print(f"標籤形狀: {labels.shape}")
    print(f"預期類別數: {num_classes}")

    # Data types
    print(f"輸出數據類型: {outputs.dtype}")
    print(f"標籤數據類型: {labels.dtype}")

    # Devices (reported only, not validated)
    print(f"輸出設備: {outputs.device}")
    print(f"標籤設備: {labels.device}")

    # Reject NaN / Inf values
    if torch.isnan(outputs).any():
        print("❌ 警告:輸出中包含 NaN 值!")
        return False

    if torch.isinf(outputs).any():
        print("❌ 警告:輸出中包含 Inf 值!")
        return False

    if torch.isnan(labels).any():
        print("❌ 警告:標籤中包含 NaN 值!")
        return False

    # One-hot labels are converted to class indices
    if labels.dim() > 1 and labels.shape[1] > 1:
        print("檢測到 one-hot 編碼標籤,轉換為索引...")
        _, label_indices = torch.max(labels, dim=1)
    else:
        label_indices = labels.long()

    # min()/max() raise on a zero-element tensor, so reject empty labels
    # before computing the index range.
    if label_indices.numel() == 0:
        print("❌ 錯誤:標籤為空,無法驗證標籤索引!")
        return False

    print(f"標籤索引範圍: {label_indices.min().item()} - {label_indices.max().item()}")

    # Label indices must lie in [0, num_classes)
    if label_indices.min() < 0:
        print("❌ 錯誤:發現負數標籤索引!")
        return False

    if label_indices.max() >= num_classes:
        print(f"❌ 錯誤:標籤索引 {label_indices.max().item()} 超出類別數 {num_classes}!")
        print("這會導致 CUDA device-side assert 錯誤")
        return False

    # The class dimension (dim 1) must exist before it can be compared;
    # previously a 0-D/1-D output raised IndexError here.
    if outputs.dim() < 2:
        print(f"❌ 錯誤:輸出維度數 {outputs.dim()} 不足,預期形狀為 (batch, num_classes)!")
        return False

    # The output dimension must match the number of classes
    if outputs.shape[1] != num_classes:
        print(f"❌ 警告:輸出維度 {outputs.shape[1]} 與預期類別數 {num_classes} 不匹配!")
        return False

    # Every sample in the batch needs a label (dim 0 must agree).
    if label_indices.dim() == 0 or label_indices.shape[0] != outputs.shape[0]:
        print(f"❌ 錯誤:輸出批次大小 {outputs.shape[0]} 與標籤形狀 {tuple(label_indices.shape)} 不匹配!")
        return False

    print("✅ 數據驗證通過!")
    print("=== 數據驗證結束 ===")
    return True


def debug_loss_calculation(outputs, labels, num_classes):
    """Debug the loss-calculation inputs and validate them.

    Converts ``outputs`` and ``labels`` to float32 tensors, prints their
    shapes and the derived label indices, then runs
    ``validate_data_for_training`` on the converted tensors. Any exception
    raised during this process is caught and reported rather than propagated.

    Args:
        outputs: Model outputs (anything accepted by ``torch.as_tensor``).
        labels: One-hot labels (``dim() > 1`` and ``shape[1] > 1``) or
            class-index labels (anything else).
        num_classes: Expected number of classes.

    Returns:
        bool: The result of ``validate_data_for_training``, or False if an
        exception occurred. The original returned None, so callers that
        ignore the return value are unaffected.
    """
    print("\n=== 損失計算調試 ===")

    try:
        # Simulate the Entropy_Loss computation (per the original note;
        # Entropy_Loss itself is defined elsewhere).
        outputs_tensor = torch.as_tensor(outputs, dtype=torch.float32)
        labels_tensor = torch.as_tensor(labels, dtype=torch.float32)

        print(f"轉換後輸出形狀: {outputs_tensor.shape}")
        print(f"轉換後標籤形狀: {labels_tensor.shape}")

        # Only one-hot labels need an argmax. Calling torch.max(..., dim=1)
        # unconditionally raised IndexError for 1-D class-index labels,
        # which validate_data_for_training itself accepts.
        if labels_tensor.dim() > 1 and labels_tensor.shape[1] > 1:
            _, labels_indices = torch.max(labels_tensor, dim=1)
        else:
            labels_indices = labels_tensor.long()
        print(f"標籤索引: {labels_indices}")

        # Check whether this data would cause an error
        is_valid = validate_data_for_training(outputs_tensor, labels_tensor, num_classes)
    except Exception as e:
        # Debug helper: report the failure instead of crashing the caller.
        print(f"❌ 損失計算調試過程中出現錯誤: {e}")
        return False

    if is_valid:
        print("✅ 損失計算應該不會出錯")
    else:
        print("❌ 損失計算可能會出錯")
    return is_valid


# Usage example: a 3-class problem, run once with valid one-hot labels and
# once with one-hot labels that have more columns than the model has classes.
if __name__ == "__main__":
    num_classes = 3
    batch_size = 2

    # (heading printed before the case, label tensor for the case)
    demo_cases = (
        # Normal case
        (
            "測試正常情況:",
            torch.tensor([[1, 0, 0], [0, 0, 1]], dtype=torch.float32),
        ),
        # Error case: label index out of range -- 4-column one-hot labels,
        # but the model only outputs 3 class scores
        (
            "\n測試異常情況(標籤索引超出範圍):",
            torch.tensor([[0, 0, 0, 1], [1, 0, 0, 0]], dtype=torch.float32),
        ),
    )

    for heading, case_labels in demo_cases:
        print(heading)
        case_outputs = torch.randn(batch_size, num_classes)
        debug_loss_calculation(case_outputs, case_labels, num_classes)