import torch import torch.nn.functional as F import sys import os import time # 添加項目根目錄到路徑 sys.path.append(os.path.dirname(os.path.abspath(__file__))) # 導入Loading_Config from utils.Stomach_Config import Loading_Config # 設置Segmentation_Plot_Image配置項 Loading_Config["Segmentation_Plot_Image"] = "./test_plot.png" # 導入必要的模塊 from experiments.Models.GastroSegNet_Model import GastroSegNet from experiments.Training.Segmentation_Block_Training import Segmentation_Block_Training_Step def test_processing_main(): # 初始化模型 try: model_step = Segmentation_Block_Training_Step() print("初始化成功") # 獲取設備 device = model_step.device print(f"使用設備: {device}") # 創建一個簡單的測試數據集 class SimpleDataset(torch.utils.data.Dataset): def __init__(self, size=10): self.size = size # 將數據放在CPU上,避免CUDA轉NumPy的問題 self.data = torch.rand(size, 3, 256, 256) self.masks = torch.rand(size, 1, 256, 256) > 0.5 self.masks = self.masks.float() self.labels = torch.zeros(size, 3) # 假設有3個類別 self.labels[:, 0] = 1 # 所有樣本都是第一個類別 print(f"創建了測試數據集,大小: {size},數據形狀: {self.data.shape},掩碼形狀: {self.masks.shape},標籤形狀: {self.labels.shape}") def __len__(self): return self.size def __getitem__(self, idx): return self.data[idx], self.masks[idx], f"sample_{idx}", self.labels[idx] # 創建測試數據集 print("創建測試數據集和數據加載器...") test_dataset = SimpleDataset(size=10) # 使用10個樣本以滿足KFold的要求 # 創建測試數據加載器 test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=2, shuffle=False) print("測試數據集和數據加載器創建完成") # 創建一個自定義的損失計算函數,用於捕獲輸入形狀 def mock_losses(Segmentation_Output_Image, Segmentation_Mask_GroundTruth_Image): print(f"Losses方法輸入形狀 - 預測: {Segmentation_Output_Image.shape}, 目標: {Segmentation_Mask_GroundTruth_Image.shape}") # 創建一個可以進行反向傳播的損失張量 dummy_loss = Segmentation_Output_Image.mean() print(f"創建了虛擬損失,值: {dummy_loss.item()}") return dummy_loss # 創建一個簡化版的Calculate_Progress_And_Timing函數,避免遞歸調用 def mock_calculate_progress_and_timing(*args): # 直接返回最後一個參數(epoch_iterator) return args[-1] # 保存原始的方法 original_losses = model_step.Losses original_calculate_progress = model_step.Calculate_Progress_And_Timing # 測試Processing_Main方法 try: # 替換方法 model_step.Losses = mock_losses model_step.Calculate_Progress_And_Timing = mock_calculate_progress_and_timing # Loading_Config已在全局設置 # 調用Processing_Main方法,設置return_processed_images=True並提供test_dataloader print("\n開始測試Processing_Main方法...") result = model_step.Processing_Main(test_dataset, return_processed_images=True, test_dataloader=test_loader) # 恢復原始的方法 model_step.Losses = original_losses model_step.Calculate_Progress_And_Timing = original_calculate_progress # 檢查結果 print("Processing_Main方法測試成功") print(f"結果類型: {type(result)}") # 如果結果是元組,打印每個元素 if isinstance(result, tuple): for i, item in enumerate(result): if isinstance(item, torch.Tensor): print(f"結果[{i}]: {type(item)} - 形狀: {item.shape}") else: print(f"結果[{i}]: {type(item)}") return True except Exception as e: print(f"Processing_Main方法測試失敗: {str(e)}") import traceback traceback.print_exc() return False except Exception as e: print(f"初始化過程中出錯: {str(e)}") return False if __name__ == "__main__": success = test_processing_main() if success: print("\n所有測試通過!") else: print("\n測試失敗!")