Files
Stomach_Cancer_Pytorch/test_model_branch.py

91 lines
3.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import torch.nn.functional as F
import sys
import os
# 添加項目根目錄到路徑
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# 導入必要的模塊
from experiments.Models.GastroSegNet_Model import GastroSegNet
from experiments.Training.Segmentation_Block_Training import Segmentation_Block_Training_Step
def test_model_branch():
# 初始化模型
try:
model_step = Segmentation_Block_Training_Step()
print("初始化成功")
# 獲取設備
device = model_step.device
print(f"使用設備: {device}")
# 創建一個簡單的測試數據
batch_size = 2
input_images = torch.rand(batch_size, 3, 256, 256)
# 創建符合Model_Branch方法輸入要求的mask_gt
# 創建二值掩碼,形狀為[batch_size, 1, height, width]
mask_gt = torch.zeros(batch_size, 1, 256, 256)
# 在掩碼中央創建一個矩形區域(模擬分割目標)
for i in range(batch_size):
# 設置矩形區域為1白色
mask_gt[i, 0, 50:200, 50:200] = 1.0
# 初始化模型
model_step.Model = model_step.Construct_Segment_Model_CUDA()
print("模型初始化完成")
# 測試Model_Branch方法
try:
# 調用Model_Branch方法設置return_processed_image=True
print("\n開始測試Model_Branch方法...")
processed_images, seg_outputs = model_step.Model_Branch(input_images, mask_gt, return_processed_image=True)
# 檢查結果
print("Model_Branch方法測試成功")
print(f"處理後的圖像形狀: {processed_images.shape}")
print(f"分割輸出形狀: {seg_outputs.shape}")
# 測試Model_Branch方法設置return_processed_image=False
print("\n開始測試Model_Branch方法 (return_processed_image=False)...")
# 創建一個自定義的損失計算函數,用於捕獲輸入形狀
def mock_losses(Segmentation_Output_Image, Segmentation_Mask_GroundTruth_Image):
print(f"Losses方法輸入形狀 - 預測: {Segmentation_Output_Image.shape}, 目標: {Segmentation_Mask_GroundTruth_Image.shape}")
return torch.tensor(0.5, device=model_step.device)
# 保存原始的Losses方法
original_losses = model_step.Losses
try:
# 替換Losses方法
model_step.Losses = mock_losses
# 調用Model_Branch方法
loss = model_step.Model_Branch(input_images, mask_gt, return_processed_image=False)
print("Model_Branch方法測試成功 (return_processed_image=False)")
print(f"損失值: {loss}")
finally:
# 恢復原始的Losses方法
model_step.Losses = original_losses
return True
except Exception as e:
print(f"Model_Branch方法測試失敗: {str(e)}")
import traceback
traceback.print_exc()
return False
except Exception as e:
print(f"初始化過程中出錯: {str(e)}")
return False
if __name__ == "__main__":
success = test_model_branch()
if success:
print("\n所有測試通過!")
else:
print("\n測試失敗!")