import torch import torch.nn.functional as F import sys import os # 添加項目根目錄到路徑 sys.path.append(os.path.dirname(os.path.abspath(__file__))) # 導入必要的模塊 from experiments.Models.GastroSegNet_Model import GastroSegNet from experiments.Training.Segmentation_Block_Training import Segmentation_Block_Training_Step def test_model_branch(): # 初始化模型 try: model_step = Segmentation_Block_Training_Step() print("初始化成功") # 獲取設備 device = model_step.device print(f"使用設備: {device}") # 創建一個簡單的測試數據 batch_size = 2 input_images = torch.rand(batch_size, 3, 256, 256) # 創建符合Model_Branch方法輸入要求的mask_gt # 創建二值掩碼,形狀為[batch_size, 1, height, width] mask_gt = torch.zeros(batch_size, 1, 256, 256) # 在掩碼中央創建一個矩形區域(模擬分割目標) for i in range(batch_size): # 設置矩形區域為1(白色) mask_gt[i, 0, 50:200, 50:200] = 1.0 # 初始化模型 model_step.Model = model_step.Construct_Segment_Model_CUDA() print("模型初始化完成") # 測試Model_Branch方法 try: # 調用Model_Branch方法,設置return_processed_image=True print("\n開始測試Model_Branch方法...") processed_images, seg_outputs = model_step.Model_Branch(input_images, mask_gt, return_processed_image=True) # 檢查結果 print("Model_Branch方法測試成功") print(f"處理後的圖像形狀: {processed_images.shape}") print(f"分割輸出形狀: {seg_outputs.shape}") # 測試Model_Branch方法,設置return_processed_image=False print("\n開始測試Model_Branch方法 (return_processed_image=False)...") # 創建一個自定義的損失計算函數,用於捕獲輸入形狀 def mock_losses(Segmentation_Output_Image, Segmentation_Mask_GroundTruth_Image): print(f"Losses方法輸入形狀 - 預測: {Segmentation_Output_Image.shape}, 目標: {Segmentation_Mask_GroundTruth_Image.shape}") return torch.tensor(0.5, device=model_step.device) # 保存原始的Losses方法 original_losses = model_step.Losses try: # 替換Losses方法 model_step.Losses = mock_losses # 調用Model_Branch方法 loss = model_step.Model_Branch(input_images, mask_gt, return_processed_image=False) print("Model_Branch方法測試成功 (return_processed_image=False)") print(f"損失值: {loss}") finally: # 恢復原始的Losses方法 model_step.Losses = original_losses return True except Exception as e: print(f"Model_Branch方法測試失敗: {str(e)}") import traceback traceback.print_exc() return False except Exception as e: print(f"初始化過程中出錯: {str(e)}") return False if __name__ == "__main__": success = test_model_branch() if success: print("\n所有測試通過!") else: print("\n測試失敗!")