91 lines
3.5 KiB
Python
91 lines
3.5 KiB
Python
import torch
|
||
import torch.nn.functional as F
|
||
import sys
|
||
import os
|
||
|
||
# 添加項目根目錄到路徑
|
||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||
|
||
# 導入必要的模塊
|
||
from experiments.Models.GastroSegNet_Model import GastroSegNet
|
||
from experiments.Training.Segmentation_Block_Training import Segmentation_Block_Training_Step
|
||
|
||
def test_model_branch():
|
||
# 初始化模型
|
||
try:
|
||
model_step = Segmentation_Block_Training_Step()
|
||
print("初始化成功")
|
||
|
||
# 獲取設備
|
||
device = model_step.device
|
||
print(f"使用設備: {device}")
|
||
|
||
# 創建一個簡單的測試數據
|
||
batch_size = 2
|
||
input_images = torch.rand(batch_size, 3, 256, 256)
|
||
|
||
# 創建符合Model_Branch方法輸入要求的mask_gt
|
||
# 創建二值掩碼,形狀為[batch_size, 1, height, width]
|
||
mask_gt = torch.zeros(batch_size, 1, 256, 256)
|
||
|
||
# 在掩碼中央創建一個矩形區域(模擬分割目標)
|
||
for i in range(batch_size):
|
||
# 設置矩形區域為1(白色)
|
||
mask_gt[i, 0, 50:200, 50:200] = 1.0
|
||
|
||
# 初始化模型
|
||
model_step.Model = model_step.Construct_Segment_Model_CUDA()
|
||
print("模型初始化完成")
|
||
|
||
# 測試Model_Branch方法
|
||
try:
|
||
# 調用Model_Branch方法,設置return_processed_image=True
|
||
print("\n開始測試Model_Branch方法...")
|
||
processed_images, seg_outputs = model_step.Model_Branch(input_images, mask_gt, return_processed_image=True)
|
||
|
||
# 檢查結果
|
||
print("Model_Branch方法測試成功")
|
||
print(f"處理後的圖像形狀: {processed_images.shape}")
|
||
print(f"分割輸出形狀: {seg_outputs.shape}")
|
||
|
||
# 測試Model_Branch方法,設置return_processed_image=False
|
||
print("\n開始測試Model_Branch方法 (return_processed_image=False)...")
|
||
|
||
# 創建一個自定義的損失計算函數,用於捕獲輸入形狀
|
||
def mock_losses(Segmentation_Output_Image, Segmentation_Mask_GroundTruth_Image):
|
||
print(f"Losses方法輸入形狀 - 預測: {Segmentation_Output_Image.shape}, 目標: {Segmentation_Mask_GroundTruth_Image.shape}")
|
||
return torch.tensor(0.5, device=model_step.device)
|
||
|
||
# 保存原始的Losses方法
|
||
original_losses = model_step.Losses
|
||
|
||
try:
|
||
# 替換Losses方法
|
||
model_step.Losses = mock_losses
|
||
|
||
# 調用Model_Branch方法
|
||
loss = model_step.Model_Branch(input_images, mask_gt, return_processed_image=False)
|
||
|
||
print("Model_Branch方法測試成功 (return_processed_image=False)")
|
||
print(f"損失值: {loss}")
|
||
finally:
|
||
# 恢復原始的Losses方法
|
||
model_step.Losses = original_losses
|
||
|
||
return True
|
||
except Exception as e:
|
||
print(f"Model_Branch方法測試失敗: {str(e)}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return False
|
||
|
||
except Exception as e:
|
||
print(f"初始化過程中出錯: {str(e)}")
|
||
return False
|
||
|
||
if __name__ == "__main__":
|
||
success = test_model_branch()
|
||
if success:
|
||
print("\n所有測試通過!")
|
||
else:
|
||
print("\n測試失敗!") |