Files
Stomach_Cancer_Pytorch/debug_data_validation.py

115 lines
4.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import numpy as np
def validate_data_for_training(outputs, labels, num_classes):
    """
    Check whether a training batch is likely to trigger a CUDA device-side assert.

    Prints diagnostic information about ``outputs`` and ``labels``, then runs a
    series of sanity checks. The first failing check prints a message and the
    function returns ``False``. The start/end banners are always printed.

    Args:
        outputs: Model output tensor for a classification task, expected to be
            2-D ``(batch, num_classes)``.
        labels: Label tensor. A tensor with more than one dimension and more
            than one column is treated as one-hot encoded and reduced to class
            indices with ``argmax`` over dim 1. Anything else is treated as
            class indices and cast to ``long``.
        num_classes: Expected number of classes.

    Returns:
        bool: True if the data looks valid, False if it may cause an error.
    """
    print("=== 數據驗證開始 ===")
    try:
        # Basic information
        print(f"輸出形狀: {outputs.shape}")
        print(f"標籤形狀: {labels.shape}")
        print(f"預期類別數: {num_classes}")
        # Data types
        print(f"輸出數據類型: {outputs.dtype}")
        print(f"標籤數據類型: {labels.dtype}")
        # Devices
        print(f"輸出設備: {outputs.device}")
        print(f"標籤設備: {labels.device}")

        # Reject NaN / Inf values
        if torch.isnan(outputs).any():
            print("❌ 警告:輸出中包含 NaN 值!")
            return False
        if torch.isinf(outputs).any():
            print("❌ 警告:輸出中包含 Inf 值!")
            return False
        if torch.isnan(labels).any():
            print("❌ 警告:標籤中包含 NaN 值!")
            return False

        # Structural checks. The steps below index outputs.shape[1] and call
        # min()/max() on the labels. Those raise on wrong-rank or empty tensors
        # instead of reporting a validation failure, so guard them first.
        if outputs.dim() != 2:
            print(f"❌ 錯誤:輸出應為 2 維 (batch, num_classes),實際為 {outputs.dim()} 維")
            return False
        if labels.numel() == 0:
            print("❌ 錯誤:標籤為空!")
            return False

        # One-hot labels are converted to class indices
        one_hot = labels.dim() > 1 and labels.shape[1] > 1
        if one_hot:
            print("檢測到 one-hot 編碼標籤,轉換為索引...")
            _, label_indices = torch.max(labels, dim=1)
        else:
            label_indices = labels.long()
        print(f"標籤索引範圍: {label_indices.min().item()} - {label_indices.max().item()}")

        # Label indices must lie in [0, num_classes)
        if label_indices.min() < 0:
            print("❌ 錯誤:發現負數標籤索引!")
            return False
        if label_indices.max() >= num_classes:
            print(f"❌ 錯誤:標籤索引 {label_indices.max().item()} 超出類別數 {num_classes}")
            print("這會導致 CUDA device-side assert 錯誤")
            return False

        # Output width must match the number of classes
        if outputs.shape[1] != num_classes:
            print(f"❌ 警告:輸出維度 {outputs.shape[1]} 與預期類別數 {num_classes} 不匹配!")
            return False
        # A one-hot row wider than num_classes can still have an in-range argmax
        # (e.g. [1, 0, 0, 0] with 3 classes), so check the width explicitly.
        if one_hot and labels.shape[1] != num_classes:
            print(f"❌ 錯誤:one-hot 標籤寬度 {labels.shape[1]} 與預期類別數 {num_classes} 不匹配!")
            return False
        # Exactly one label per sample in the batch
        if label_indices.numel() != outputs.shape[0]:
            print(f"❌ 錯誤:標籤數量 {label_indices.numel()} 與批次大小 {outputs.shape[0]} 不匹配!")
            return False

        print("✅ 數據驗證通過!")
        return True
    finally:
        # Always close the banner, including on early returns and exceptions.
        print("=== 數據驗證結束 ===")
def debug_loss_calculation(outputs, labels, num_classes):
    """
    Debug the loss computation for one batch.

    Converts both inputs to float32 tensors, reduces one-hot labels to class
    indices (printing them), and then runs ``validate_data_for_training`` on
    the converted tensors. This is intended to mimic the preprocessing of the
    project's ``Entropy_Loss``. NOTE(review): that loss is not visible here;
    confirm the two stay in sync.

    Args:
        outputs: Model outputs, as a tensor or anything ``torch.as_tensor``
            accepts.
        labels: One-hot labels (more than one dimension and more than one
            column) or class indices.
        num_classes: Expected number of classes.

    Returns:
        bool: The validation result, or False if an exception was raised while
        preparing or validating the data.
    """
    print("\n=== 損失計算調試 ===")
    try:
        # Mimic the Entropy_Loss preprocessing
        outputs_tensor = torch.as_tensor(outputs, dtype=torch.float32)
        labels_tensor = torch.as_tensor(labels, dtype=torch.float32)
        print(f"轉換後輸出形狀: {outputs_tensor.shape}")
        print(f"轉換後標籤形狀: {labels_tensor.shape}")
        # Convert to label indices. argmax over dim 1 only applies to one-hot
        # labels; on 1-D index labels it raises "dimension out of range". Use
        # the same one-hot detection as validate_data_for_training.
        if labels_tensor.dim() > 1 and labels_tensor.shape[1] > 1:
            _, labels_indices = torch.max(labels_tensor, dim=1)
        else:
            labels_indices = labels_tensor.long()
        print(f"標籤索引: {labels_indices}")
        # Check whether the loss computation would fail
        is_valid = validate_data_for_training(outputs_tensor, labels_tensor, num_classes)
    except Exception as e:
        # Top-level boundary of a debugging helper: report instead of crashing.
        print(f"❌ 損失計算調試過程中出現錯誤: {e}")
        return False

    if is_valid:
        print("✅ 損失計算應該不會出錯")
    else:
        print("❌ 損失計算可能會出錯")
    return is_valid
# Usage example
if __name__ == "__main__":
    # Demo: a 3-class classification problem with a batch of two samples.
    n_cls, n_batch = 3, 2

    demo_cases = (
        # Valid one-hot labels that match the 3 output classes.
        ("測試正常情況:", [[1, 0, 0], [0, 0, 1]]),
        # Label rows are 4 wide while the model emits only 3 logits, so the
        # argmax of the first row (index 3) is out of range.
        ("\n測試異常情況(標籤索引超出範圍):", [[0, 0, 0, 1], [1, 0, 0, 0]]),
    )

    for heading, label_rows in demo_cases:
        print(heading)
        logits = torch.randn(n_batch, n_cls)
        targets = torch.tensor(label_rows, dtype=torch.float32)
        debug_loss_calculation(logits, targets, n_cls)