Files
Stomach_Cancer_Pytorch/test_segmentation.py

48 lines
1.6 KiB
Python

import torch
import sys
import os
# 添加項目根目錄到路徑
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# 導入必要的模塊
from experiments.Models.GastroSegNet_Model import GastroSegNet
from experiments.Training.Segmentation_Block_Training import Segmentation_Block_Training_Step
def test_process_segmentation_output():
# 初始化模型
try:
model_step = Segmentation_Block_Training_Step()
print("初始化成功")
# 創建測試數據
batch_size = 2
channels = 3
height = 256
width = 256
# 獲取設備
device = model_step.device
print(f"使用設備: {device}")
# 創建隨機輸入圖像並移到正確的設備上
input_images = torch.rand(batch_size, channels, height, width).to(device)
# 創建隨機分割輸出 (模擬模型輸出)並移到正確的設備上
segmentation_output = torch.rand(batch_size, 1, height, width).to(device)
# 測試處理方法
try:
processed_images = model_step.process_segmentation_output(input_images, segmentation_output)
print(f"處理後圖像形狀: {processed_images.shape}")
print("處理成功")
return True
except Exception as e:
print(f"處理過程中出錯: {str(e)}")
return False
except Exception as e:
print(f"初始化過程中出錯: {str(e)}")
return False
if __name__ == "__main__":
test_process_segmentation_output()