117 lines
4.8 KiB
Python
117 lines
4.8 KiB
Python
import torch
|
||
import torch.nn.functional as F
|
||
import sys
|
||
import os
|
||
import time
|
||
|
||
# 添加項目根目錄到路徑
|
||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||
|
||
# 導入Loading_Config
|
||
from utils.Stomach_Config import Loading_Config
|
||
|
||
# 設置Segmentation_Plot_Image配置項
|
||
Loading_Config["Segmentation_Plot_Image"] = "./test_plot.png"
|
||
|
||
# 導入必要的模塊
|
||
from experiments.Models.GastroSegNet_Model import GastroSegNet
|
||
from experiments.Training.Segmentation_Block_Training import Segmentation_Block_Training_Step
|
||
|
||
def test_processing_main():
|
||
# 初始化模型
|
||
try:
|
||
model_step = Segmentation_Block_Training_Step()
|
||
print("初始化成功")
|
||
|
||
# 獲取設備
|
||
device = model_step.device
|
||
print(f"使用設備: {device}")
|
||
|
||
# 創建一個簡單的測試數據集
|
||
class SimpleDataset(torch.utils.data.Dataset):
|
||
def __init__(self, size=10):
|
||
self.size = size
|
||
# 將數據放在CPU上,避免CUDA轉NumPy的問題
|
||
self.data = torch.rand(size, 3, 256, 256)
|
||
self.masks = torch.rand(size, 1, 256, 256) > 0.5
|
||
self.masks = self.masks.float()
|
||
self.labels = torch.zeros(size, 3) # 假設有3個類別
|
||
self.labels[:, 0] = 1 # 所有樣本都是第一個類別
|
||
print(f"創建了測試數據集,大小: {size},數據形狀: {self.data.shape},掩碼形狀: {self.masks.shape},標籤形狀: {self.labels.shape}")
|
||
|
||
def __len__(self):
|
||
return self.size
|
||
|
||
def __getitem__(self, idx):
|
||
return self.data[idx], self.masks[idx], f"sample_{idx}", self.labels[idx]
|
||
|
||
# 創建測試數據集
|
||
print("創建測試數據集和數據加載器...")
|
||
test_dataset = SimpleDataset(size=10) # 使用10個樣本以滿足KFold的要求
|
||
|
||
# 創建測試數據加載器
|
||
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=2, shuffle=False)
|
||
print("測試數據集和數據加載器創建完成")
|
||
|
||
# 創建一個自定義的損失計算函數,用於捕獲輸入形狀
|
||
def mock_losses(Segmentation_Output_Image, Segmentation_Mask_GroundTruth_Image):
|
||
print(f"Losses方法輸入形狀 - 預測: {Segmentation_Output_Image.shape}, 目標: {Segmentation_Mask_GroundTruth_Image.shape}")
|
||
# 創建一個可以進行反向傳播的損失張量
|
||
dummy_loss = Segmentation_Output_Image.mean()
|
||
print(f"創建了虛擬損失,值: {dummy_loss.item()}")
|
||
return dummy_loss
|
||
|
||
# 創建一個簡化版的Calculate_Progress_And_Timing函數,避免遞歸調用
|
||
def mock_calculate_progress_and_timing(*args):
|
||
# 直接返回最後一個參數(epoch_iterator)
|
||
return args[-1]
|
||
|
||
# 保存原始的方法
|
||
original_losses = model_step.Losses
|
||
original_calculate_progress = model_step.Calculate_Progress_And_Timing
|
||
|
||
# 測試Processing_Main方法
|
||
try:
|
||
# 替換方法
|
||
model_step.Losses = mock_losses
|
||
model_step.Calculate_Progress_And_Timing = mock_calculate_progress_and_timing
|
||
|
||
# Loading_Config已在全局設置
|
||
|
||
# 調用Processing_Main方法,設置return_processed_images=True並提供test_dataloader
|
||
print("\n開始測試Processing_Main方法...")
|
||
result = model_step.Processing_Main(test_dataset, return_processed_images=True, test_dataloader=test_loader)
|
||
|
||
# 恢復原始的方法
|
||
model_step.Losses = original_losses
|
||
model_step.Calculate_Progress_And_Timing = original_calculate_progress
|
||
|
||
# 檢查結果
|
||
print("Processing_Main方法測試成功")
|
||
print(f"結果類型: {type(result)}")
|
||
|
||
# 如果結果是元組,打印每個元素
|
||
if isinstance(result, tuple):
|
||
for i, item in enumerate(result):
|
||
if isinstance(item, torch.Tensor):
|
||
print(f"結果[{i}]: {type(item)} - 形狀: {item.shape}")
|
||
else:
|
||
print(f"結果[{i}]: {type(item)}")
|
||
|
||
return True
|
||
except Exception as e:
|
||
print(f"Processing_Main方法測試失敗: {str(e)}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return False
|
||
|
||
except Exception as e:
|
||
print(f"初始化過程中出錯: {str(e)}")
|
||
return False
|
||
|
||
if __name__ == "__main__":
|
||
success = test_processing_main()
|
||
if success:
|
||
print("\n所有測試通過!")
|
||
else:
|
||
print("\n測試失敗!") |