48 lines
1.6 KiB
Python
48 lines
1.6 KiB
Python
import torch
|
|
import sys
|
|
import os
|
|
|
|
# 添加項目根目錄到路徑
|
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
# 導入必要的模塊
|
|
from experiments.Models.GastroSegNet_Model import GastroSegNet
|
|
from experiments.Training.Segmentation_Block_Training import Segmentation_Block_Training_Step
|
|
|
|
def test_process_segmentation_output():
|
|
# 初始化模型
|
|
try:
|
|
model_step = Segmentation_Block_Training_Step()
|
|
print("初始化成功")
|
|
|
|
# 創建測試數據
|
|
batch_size = 2
|
|
channels = 3
|
|
height = 256
|
|
width = 256
|
|
|
|
# 獲取設備
|
|
device = model_step.device
|
|
print(f"使用設備: {device}")
|
|
|
|
# 創建隨機輸入圖像並移到正確的設備上
|
|
input_images = torch.rand(batch_size, channels, height, width).to(device)
|
|
|
|
# 創建隨機分割輸出 (模擬模型輸出)並移到正確的設備上
|
|
segmentation_output = torch.rand(batch_size, 1, height, width).to(device)
|
|
|
|
# 測試處理方法
|
|
try:
|
|
processed_images = model_step.process_segmentation_output(input_images, segmentation_output)
|
|
print(f"處理後圖像形狀: {processed_images.shape}")
|
|
print("處理成功")
|
|
return True
|
|
except Exception as e:
|
|
print(f"處理過程中出錯: {str(e)}")
|
|
return False
|
|
except Exception as e:
|
|
print(f"初始化過程中出錯: {str(e)}")
|
|
return False
|
|
|
|
if __name__ == "__main__":
|
|
test_process_segmentation_output() |