Files
Stomach_Cancer_Pytorch/test_main.py

121 lines
4.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import sys
import os
import time
# 添加當前目錄到系統路徑
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from experiments.experiment import experiments
from utils.Stomach_Config import Training_Config, Loading_Config, Save_Result_File_Config
from Training_Tools.Tools import Tool
from merge_class.merge import merge
def test_main():
# 測試GPU是否可用
flag = torch.cuda.is_available()
if not flag:
print("CUDA不可用\n")
else:
print(f"CUDA可用數量為{torch.cuda.device_count()}\n")
# 测试GPU是否可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")
if torch.cuda.is_available():
print(f"GPU: {torch.cuda.get_device_name(0)}")
# 創建一個簡單的測試數據集
class SimpleDataset(torch.utils.data.Dataset):
def __init__(self, size=10):
self.size = size
self.data = torch.randn(size, 3, 256, 256) # 假設輸入是3通道256x256圖像
self.mask = torch.randint(0, 2, (size, 1, 256, 256)).float() # 二值掩碼
self.labels = torch.zeros(size, 3) # 假設有3個類別
self.labels[:, 0] = 1 # 所有樣本都是第一個類別
print(f"創建了測試數據集,大小: {size},數據形狀: {self.data.shape}")
def __len__(self):
return self.size
def __getitem__(self, idx):
return self.data[idx], self.mask[idx], self.labels[idx]
# 創建測試數據
print("創建測試數據...")
test_dataset = SimpleDataset()
training_data = [test_dataset.data[i] for i in range(test_dataset.size)]
training_label = [test_dataset.labels[i] for i in range(test_dataset.size)]
training_mask = [test_dataset.mask[i] for i in range(test_dataset.size)]
# 初始化工具和配置
tool = Tool()
Status = 1
Label_Length = 3
# 模擬main.py中的數據處理
print("\n模擬main.py中的數據處理...")
Merge = merge()
# 確保Loading_Config中包含必要的配置項
required_configs = [
"Test_Data_Root", "Training_Labels", "Annotation_Root",
"Label_Image_Labels", "Image enhance processing save root"
]
for config in required_configs:
if config not in Loading_Config:
print(f"添加缺少的配置項: {config}")
if config == "Test_Data_Root":
Loading_Config[config] = "./test_data"
elif config == "Training_Labels":
Loading_Config[config] = ["class1", "class2", "class3"]
elif config == "Annotation_Root":
Loading_Config[config] = "./annotations"
elif config == "Label_Image_Labels":
Loading_Config[config] = ["CA", "Normal"]
elif config == "Image enhance processing save root":
Loading_Config[config] = "./enhanced_images"
# 測試experiments類的初始化
try:
print("\n測試experiments類的初始化...")
experiment = experiments(
Xception_Training_Data=training_data,
Xception_Training_Label=training_label,
GastroSegNet_Training_Data=training_data,
GastroSegNet_Training_Label=training_label,
GastroSegNet_Training_Mask=training_mask,
Training_Config=Training_Config,
Loading_Config=Loading_Config,
tools=tool,
Number_Of_Classes=Label_Length,
status=Status
)
print("experiments類初始化成功!")
# 測試processing_main方法僅執行部分代碼
print("\n測試processing_main方法...")
# 這裡我們不實際調用processing_main因為它會執行完整的訓練流程
# 而是檢查關鍵屬性是否正確設置
print(f"模型名稱: {experiment.model_name}")
print(f"實驗名稱: {experiment.experiment_name}")
print(f"訓練批次大小: {experiment.train_batch_size}")
print(f"使用設備: {experiment.device}")
# 模擬processing_main的部分功能
print("\n模擬processing_main的部分功能...")
# 這裡我們只模擬一些基本操作,不執行實際的訓練
print("模擬讀取測試數據...")
experiment.test = training_data
experiment.test_label = training_label
print("\n測試成功!")
except Exception as e:
print(f"\n測試失敗: {str(e)}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
test_main()