Stomach_Cancer_Pytorch/test_gradient_fix.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
測試梯度計算修復的腳本
用於驗證 RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn 是否已修復
"""
import torch
import torch.nn as nn
import numpy as np
from experiments.Models.Xception_Model_Modification import Xception
from Model_Loss.Loss import Entropy_Loss


def test_gradient_computation():
    """Check that gradient computation works end to end."""
    print("=== Testing the gradient-computation fix ===")

    # Select the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Build the model
    model = Xception().to(device)
    model.train()

    # Inspect the model parameters
    print("\n=== Model parameter check ===")
    total_params = 0
    trainable_params = 0
    for name, param in model.named_parameters():
        total_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
        else:
            print(f"❌ Parameter {name} does not require gradients!")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Trainable ratio: {trainable_params/total_params*100:.2f}%")
    if trainable_params == 0:
        print("❌ Error: no trainable parameters!")
        return False
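    # Note (added remark, not part of the original script's logic): if parameters show
    # up as frozen here, the usual remedy is to re-enable gradients before training,
    # e.g. for p in model.parameters(): p.requires_grad_(True)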

    # Create test data
    batch_size = 2
    input_images = torch.randn(batch_size, 3, 224, 224, device=device, requires_grad=True)

    # Create one-hot encoded labels
    labels_onehot = torch.zeros(batch_size, 3, device=device)
    labels_onehot[0, 1] = 1.0  # first sample belongs to class 1
    labels_onehot[1, 2] = 1.0  # second sample belongs to class 2

    print("\n=== Test data ===")
    print(f"Input shape: {input_images.shape}")
    print(f"Label shape: {labels_onehot.shape}")
    print(f"Input requires_grad: {input_images.requires_grad}")
    print(f"Label contents:\n{labels_onehot}")
    try:
        # Forward pass
        print("\n=== Forward pass ===")
        outputs = model(input_images)
        print(f"Output shape: {outputs.shape}")
        print(f"Output requires_grad: {outputs.requires_grad}")
        print(f"Output grad_fn: {outputs.grad_fn}")
        if outputs.grad_fn is None:
            print("❌ Error: output has no grad_fn")
            return False

        # Loss computation
        print("\n=== Loss computation ===")
        criterion = Entropy_Loss()
        loss = criterion(outputs, labels_onehot)
        print(f"Loss value: {loss.item():.6f}")
        print(f"Loss requires_grad: {loss.requires_grad}")
        print(f"Loss grad_fn: {loss.grad_fn}")
        if loss.grad_fn is None:
            print("❌ Error: loss has no grad_fn")
            return False

        # Backward pass
        print("\n=== Backward pass ===")
        loss.backward()
        print("✅ Backward pass completed successfully!")

        # Check gradients
        print("\n=== Gradient check ===")
        grad_count = 0
        for name, param in model.named_parameters():
            if param.grad is not None:
                grad_count += 1
                grad_norm = param.grad.norm().item()
                if grad_norm > 0:
                    print(f"{name}: gradient norm = {grad_norm:.6f}")
                else:
                    print(f"⚠️ {name}: gradient is zero")
            else:
                print(f"{name}: no gradient")
        print(f"\nNumber of parameters with gradients: {grad_count}")
        if grad_count == 0:
            print("❌ Error: no parameter received a gradient!")
            return False

        print("\n✅ All checks passed! The gradient-computation fix works!")
        return True
    except Exception as e:
        print(f"\n❌ Test failed: {str(e)}")
        print(f"Error type: {type(e).__name__}")
        return False


def test_losses_method():
    """Check the tensor handling used by the Losses method."""
    print("\n=== Testing the Losses method ===")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Mimic the logic of the Losses method
    def test_losses(predicts, labels):
        # Make sure the inputs are tensors on the correct device
        if not isinstance(predicts, torch.Tensor):
            predicts = torch.tensor(predicts, dtype=torch.float32, device=device, requires_grad=True)
        if not isinstance(labels, torch.Tensor):
            labels = torch.tensor(labels, dtype=torch.float32, device=device)
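        # Note (added remark): torch.tensor(...) copies the NumPy data into a new leaf
        # tensor, so gradients computed here do not propagate back to the original
        # NumPy array; that is sufficient for this loss-computation test.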
        # Make sure both tensors are on the same device
        predicts = predicts.to(device)
        labels = labels.to(device)
        print(f"Predicts: shape={predicts.shape}, requires_grad={predicts.requires_grad}, device={predicts.device}")
        print(f"Labels: shape={labels.shape}, requires_grad={labels.requires_grad}, device={labels.device}")
        criterion = Entropy_Loss()
        loss = criterion(predicts, labels)
        return loss

    # Exercise different input types
    batch_size = 2
    num_classes = 3

    # Test 1: tensor inputs
    print("\n--- Test 1: tensor inputs ---")
    predicts_tensor = torch.randn(batch_size, num_classes, device=device, requires_grad=True)
    labels_tensor = torch.zeros(batch_size, num_classes, device=device)
    labels_tensor[0, 1] = 1.0
    labels_tensor[1, 2] = 1.0
    try:
        loss1 = test_losses(predicts_tensor, labels_tensor)
        print(f"✅ Tensor input test passed, loss: {loss1.item():.6f}")
    except Exception as e:
        print(f"❌ Tensor input test failed: {e}")

    # Test 2: NumPy inputs
    print("\n--- Test 2: NumPy inputs ---")
    predicts_numpy = np.random.randn(batch_size, num_classes).astype(np.float32)
    labels_numpy = np.zeros((batch_size, num_classes), dtype=np.float32)
    labels_numpy[0, 1] = 1.0
    labels_numpy[1, 2] = 1.0
    try:
        loss2 = test_losses(predicts_numpy, labels_numpy)
        print(f"✅ NumPy input test passed, loss: {loss2.item():.6f}")
    except Exception as e:
        print(f"❌ NumPy input test failed: {e}")
if __name__ == "__main__":
print("開始測試梯度計算修復...")
# 測試梯度計算
gradient_test_passed = test_gradient_computation()
# 測試 Losses 方法
test_losses_method()
print("\n" + "="*50)
if gradient_test_passed:
print("🎉 梯度計算修復驗證成功!可以開始訓練了。")
else:
print("❌ 梯度計算仍有問題,需要進一步調試。")
print("="*50)