import torch
import numpy as np


def _is_one_hot(labels, outputs):
    """Return True when *labels* should be treated as one-hot encoded.

    Labels are treated as one-hot only when they have the same rank as
    *outputs* and more than one entry along dim 1 (the class axis of
    *outputs*). Index labels have one dimension fewer than the outputs
    (e.g. ``(N,)`` for ``(N, C)`` outputs, or ``(N, H, W)`` for
    ``(N, C, H, W)`` outputs) and must not be argmaxed.
    """
    return labels.dim() == outputs.dim() and labels.dim() > 1 and labels.shape[1] > 1


def _to_label_indices(labels, outputs):
    """Convert *labels* to class indices: argmax over dim 1 for one-hot, else cast to long."""
    if _is_one_hot(labels, outputs):
        _, indices = torch.max(labels, dim=1)
        return indices
    return labels.long()


def validate_data_for_training(outputs, labels, num_classes):
    """Check whether training data would trigger a CUDA device-side assert.

    Prints a diagnostic report and stops at the first problem found.

    Args:
        outputs: Model output tensor; must have at least 2 dims with the
            class axis at dim 1 (e.g. ``(N, C)`` or ``(N, C, H, W)``).
        labels: Label tensor, either class indices (one dim fewer than
            *outputs*) or one-hot encoded (same rank as *outputs*).
        num_classes: Expected number of classes.

    Returns:
        bool: True if the data looks valid, False if it is likely to cause an error.
    """
    print("=== 數據驗證開始 ===")

    # Basic information
    print(f"輸出形狀: {outputs.shape}")
    print(f"標籤形狀: {labels.shape}")
    print(f"預期類別數: {num_classes}")

    # Data types
    print(f"輸出數據類型: {outputs.dtype}")
    print(f"標籤數據類型: {labels.dtype}")

    # Devices
    print(f"輸出設備: {outputs.device}")
    print(f"標籤設備: {labels.device}")

    # NaN / Inf checks
    if torch.isnan(outputs).any():
        print("❌ 警告:輸出中包含 NaN 值!")
        return False
    if torch.isinf(outputs).any():
        print("❌ 警告:輸出中包含 Inf 值!")
        return False
    if torch.isnan(labels).any():
        print("❌ 警告:標籤中包含 NaN 值!")
        return False
    # Inf labels would turn into garbage integers after .long().
    if torch.isinf(labels).any():
        print("❌ 警告:標籤中包含 Inf 值!")
        return False

    # Outputs need a class axis at dim 1; reading outputs.shape[1] on a
    # 1-D tensor would raise IndexError.
    if outputs.dim() < 2:
        print(f"❌ 錯誤:輸出至少需要 2 維 (N, C),實際為 {outputs.dim()} 維!")
        return False

    if _is_one_hot(labels, outputs):
        print("檢測到 one-hot 編碼標籤,轉換為索引...")
        # A wrong one-hot width is an error even when the argmax happens to
        # land inside [0, num_classes).
        if labels.shape[1] != num_classes:
            print(f"❌ 錯誤:one-hot 標籤寬度 {labels.shape[1]} 與類別數 {num_classes} 不匹配!")
            return False
        _, label_indices = torch.max(labels, dim=1)
    else:
        # .long() truncates toward zero, so e.g. -0.5 -> 0 would slip past
        # the negative-index check below; reject non-integral values instead.
        if labels.is_floating_point() and (labels != labels.trunc()).any():
            print("❌ 錯誤:標籤索引包含非整數值,轉換為 long 時會被截斷!")
            return False
        label_indices = labels.long()

    # min()/max() raise on an empty tensor.
    if label_indices.numel() == 0:
        print("❌ 錯誤:標籤為空,無法計算損失!")
        return False

    # Every sample in the batch needs exactly one label.
    if label_indices.dim() > 0 and label_indices.shape[0] != outputs.shape[0]:
        print(f"❌ 錯誤:標籤批次大小 {label_indices.shape[0]} 與輸出批次大小 {outputs.shape[0]} 不匹配!")
        return False

    min_index = label_indices.min().item()
    max_index = label_indices.max().item()
    print(f"標籤索引範圍: {min_index} - {max_index}")

    # Label indices must lie in [0, num_classes).
    if min_index < 0:
        print("❌ 錯誤:發現負數標籤索引!")
        return False

    if max_index >= num_classes:
        print(f"❌ 錯誤:標籤索引 {max_index} 超出類別數 {num_classes}!")
        print("這會導致 CUDA device-side assert 錯誤")
        return False

    # The output class dimension must match the number of classes.
    if outputs.shape[1] != num_classes:
        print(f"❌ 警告:輸出維度 {outputs.shape[1]} 與預期類別數 {num_classes} 不匹配!")
        return False

    print("✅ 數據驗證通過!")
    print("=== 數據驗證結束 ===")
    return True


def debug_loss_calculation(outputs, labels, num_classes):
    """Print a step-by-step trace of the loss-input preparation, then validate it.

    Best-effort diagnostic helper: any exception raised while converting or
    inspecting the data is caught and printed rather than propagated.

    Args:
        outputs: Model outputs (tensor or anything ``torch.as_tensor`` accepts).
        labels: Index or one-hot labels (tensor or array-like).
        num_classes: Expected number of classes.
    """
    print("\n=== 損失計算調試 ===")

    try:
        # Reproduce the conversion steps the original code attributes to
        # Entropy_Loss: coerce both inputs to float32 tensors.
        outputs_tensor = torch.as_tensor(outputs, dtype=torch.float32)
        labels_tensor = torch.as_tensor(labels, dtype=torch.float32)

        print(f"轉換後輸出形狀: {outputs_tensor.shape}")
        print(f"轉換後標籤形狀: {labels_tensor.shape}")

        # Use the validator's one-hot/index decision so 1-D index labels no
        # longer crash on torch.max(..., dim=1).
        labels_indices = _to_label_indices(labels_tensor, outputs_tensor)
        print(f"標籤索引: {labels_indices}")

        # Check whether this data would cause an error.
        if validate_data_for_training(outputs_tensor, labels_tensor, num_classes):
            print("✅ 損失計算應該不會出錯")
        else:
            print("❌ 損失計算可能會出錯")

    except Exception as e:
        # Deliberately broad: this is a diagnostic helper and must not abort the caller.
        print(f"❌ 損失計算調試過程中出現錯誤: {e}")


# Usage example
if __name__ == "__main__":
    # Example: a 3-class classification problem.
    num_classes = 3
    batch_size = 2

    # Normal case (one-hot labels)
    print("測試正常情況:")
    outputs_normal = torch.randn(batch_size, num_classes)
    labels_normal = torch.tensor([[1, 0, 0], [0, 0, 1]], dtype=torch.float32)
    debug_loss_calculation(outputs_normal, labels_normal, num_classes)

    # Normal case (class-index labels)
    print("\n測試索引標籤情況:")
    labels_index = torch.tensor([0, 2])
    debug_loss_calculation(outputs_normal, labels_index, num_classes)

    # Error case: label index out of range
    print("\n測試異常情況(標籤索引超出範圍):")
    outputs_error = torch.randn(batch_size, num_classes)
    labels_error = torch.tensor([[0, 0, 0, 1], [1, 0, 0, 0]], dtype=torch.float32)  # 4-class labels but only 3 output classes
    debug_loss_calculation(outputs_error, labels_error, num_classes)