Files
Stomach_Cancer_Pytorch/test_processing_main.py

117 lines
4.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import torch.nn.functional as F
import sys
import os
import time
# 添加項目根目錄到路徑
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# 導入Loading_Config
from utils.Stomach_Config import Loading_Config
# 設置Segmentation_Plot_Image配置項
Loading_Config["Segmentation_Plot_Image"] = "./test_plot.png"
# 導入必要的模塊
from experiments.Models.GastroSegNet_Model import GastroSegNet
from experiments.Training.Segmentation_Block_Training import Segmentation_Block_Training_Step
def test_processing_main():
# 初始化模型
try:
model_step = Segmentation_Block_Training_Step()
print("初始化成功")
# 獲取設備
device = model_step.device
print(f"使用設備: {device}")
# 創建一個簡單的測試數據集
class SimpleDataset(torch.utils.data.Dataset):
def __init__(self, size=10):
self.size = size
# 將數據放在CPU上避免CUDA轉NumPy的問題
self.data = torch.rand(size, 3, 256, 256)
self.masks = torch.rand(size, 1, 256, 256) > 0.5
self.masks = self.masks.float()
self.labels = torch.zeros(size, 3) # 假設有3個類別
self.labels[:, 0] = 1 # 所有樣本都是第一個類別
print(f"創建了測試數據集,大小: {size},數據形狀: {self.data.shape},掩碼形狀: {self.masks.shape},標籤形狀: {self.labels.shape}")
def __len__(self):
return self.size
def __getitem__(self, idx):
return self.data[idx], self.masks[idx], f"sample_{idx}", self.labels[idx]
# 創建測試數據集
print("創建測試數據集和數據加載器...")
test_dataset = SimpleDataset(size=10) # 使用10個樣本以滿足KFold的要求
# 創建測試數據加載器
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=2, shuffle=False)
print("測試數據集和數據加載器創建完成")
# 創建一個自定義的損失計算函數,用於捕獲輸入形狀
def mock_losses(Segmentation_Output_Image, Segmentation_Mask_GroundTruth_Image):
print(f"Losses方法輸入形狀 - 預測: {Segmentation_Output_Image.shape}, 目標: {Segmentation_Mask_GroundTruth_Image.shape}")
# 創建一個可以進行反向傳播的損失張量
dummy_loss = Segmentation_Output_Image.mean()
print(f"創建了虛擬損失,值: {dummy_loss.item()}")
return dummy_loss
# 創建一個簡化版的Calculate_Progress_And_Timing函數避免遞歸調用
def mock_calculate_progress_and_timing(*args):
# 直接返回最後一個參數epoch_iterator
return args[-1]
# 保存原始的方法
original_losses = model_step.Losses
original_calculate_progress = model_step.Calculate_Progress_And_Timing
# 測試Processing_Main方法
try:
# 替換方法
model_step.Losses = mock_losses
model_step.Calculate_Progress_And_Timing = mock_calculate_progress_and_timing
# Loading_Config已在全局設置
# 調用Processing_Main方法設置return_processed_images=True並提供test_dataloader
print("\n開始測試Processing_Main方法...")
result = model_step.Processing_Main(test_dataset, return_processed_images=True, test_dataloader=test_loader)
# 恢復原始的方法
model_step.Losses = original_losses
model_step.Calculate_Progress_And_Timing = original_calculate_progress
# 檢查結果
print("Processing_Main方法測試成功")
print(f"結果類型: {type(result)}")
# 如果結果是元組,打印每個元素
if isinstance(result, tuple):
for i, item in enumerate(result):
if isinstance(item, torch.Tensor):
print(f"結果[{i}]: {type(item)} - 形狀: {item.shape}")
else:
print(f"結果[{i}]: {type(item)}")
return True
except Exception as e:
print(f"Processing_Main方法測試失敗: {str(e)}")
import traceback
traceback.print_exc()
return False
except Exception as e:
print(f"初始化過程中出錯: {str(e)}")
return False
if __name__ == "__main__":
success = test_processing_main()
if success:
print("\n所有測試通過!")
else:
print("\n測試失敗!")