import torch import sys import os # 添加項目根目錄到路徑 sys.path.append(os.path.dirname(os.path.abspath(__file__))) # 導入必要的模塊 from experiments.Models.GastroSegNet_Model import GastroSegNet from experiments.Training.Segmentation_Block_Training import Segmentation_Block_Training_Step def test_process_segmentation_output(): # 初始化模型 try: model_step = Segmentation_Block_Training_Step() print("初始化成功") # 創建測試數據 batch_size = 2 channels = 3 height = 256 width = 256 # 獲取設備 device = model_step.device print(f"使用設備: {device}") # 創建隨機輸入圖像並移到正確的設備上 input_images = torch.rand(batch_size, channels, height, width).to(device) # 創建隨機分割輸出 (模擬模型輸出)並移到正確的設備上 segmentation_output = torch.rand(batch_size, 1, height, width).to(device) # 測試處理方法 try: processed_images = model_step.process_segmentation_output(input_images, segmentation_output) print(f"處理後圖像形狀: {processed_images.shape}") print("處理成功") return True except Exception as e: print(f"處理過程中出錯: {str(e)}") return False except Exception as e: print(f"初始化過程中出錯: {str(e)}") return False if __name__ == "__main__": test_process_segmentation_output()