121 lines
4.6 KiB
Python
121 lines
4.6 KiB
Python
import torch
|
||
import sys
|
||
import os
|
||
import time
|
||
|
||
# 添加當前目錄到系統路徑
|
||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||
|
||
from experiments.experiment import experiments
|
||
from utils.Stomach_Config import Training_Config, Loading_Config, Save_Result_File_Config
|
||
from Training_Tools.Tools import Tool
|
||
from merge_class.merge import merge
|
||
|
||
def test_main():
|
||
# 測試GPU是否可用
|
||
flag = torch.cuda.is_available()
|
||
if not flag:
|
||
print("CUDA不可用\n")
|
||
else:
|
||
print(f"CUDA可用,數量為{torch.cuda.device_count()}\n")
|
||
|
||
# 测试GPU是否可用
|
||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||
print(f"使用设备: {device}")
|
||
if torch.cuda.is_available():
|
||
print(f"GPU: {torch.cuda.get_device_name(0)}")
|
||
|
||
# 創建一個簡單的測試數據集
|
||
class SimpleDataset(torch.utils.data.Dataset):
|
||
def __init__(self, size=10):
|
||
self.size = size
|
||
self.data = torch.randn(size, 3, 256, 256) # 假設輸入是3通道256x256圖像
|
||
self.mask = torch.randint(0, 2, (size, 1, 256, 256)).float() # 二值掩碼
|
||
self.labels = torch.zeros(size, 3) # 假設有3個類別
|
||
self.labels[:, 0] = 1 # 所有樣本都是第一個類別
|
||
print(f"創建了測試數據集,大小: {size},數據形狀: {self.data.shape}")
|
||
|
||
def __len__(self):
|
||
return self.size
|
||
|
||
def __getitem__(self, idx):
|
||
return self.data[idx], self.mask[idx], self.labels[idx]
|
||
|
||
# 創建測試數據
|
||
print("創建測試數據...")
|
||
test_dataset = SimpleDataset()
|
||
training_data = [test_dataset.data[i] for i in range(test_dataset.size)]
|
||
training_label = [test_dataset.labels[i] for i in range(test_dataset.size)]
|
||
training_mask = [test_dataset.mask[i] for i in range(test_dataset.size)]
|
||
|
||
# 初始化工具和配置
|
||
tool = Tool()
|
||
Status = 1
|
||
Label_Length = 3
|
||
|
||
# 模擬main.py中的數據處理
|
||
print("\n模擬main.py中的數據處理...")
|
||
Merge = merge()
|
||
|
||
# 確保Loading_Config中包含必要的配置項
|
||
required_configs = [
|
||
"Test_Data_Root", "Training_Labels", "Annotation_Root",
|
||
"Label_Image_Labels", "Image enhance processing save root"
|
||
]
|
||
|
||
for config in required_configs:
|
||
if config not in Loading_Config:
|
||
print(f"添加缺少的配置項: {config}")
|
||
if config == "Test_Data_Root":
|
||
Loading_Config[config] = "./test_data"
|
||
elif config == "Training_Labels":
|
||
Loading_Config[config] = ["class1", "class2", "class3"]
|
||
elif config == "Annotation_Root":
|
||
Loading_Config[config] = "./annotations"
|
||
elif config == "Label_Image_Labels":
|
||
Loading_Config[config] = ["CA", "Normal"]
|
||
elif config == "Image enhance processing save root":
|
||
Loading_Config[config] = "./enhanced_images"
|
||
|
||
# 測試experiments類的初始化
|
||
try:
|
||
print("\n測試experiments類的初始化...")
|
||
experiment = experiments(
|
||
Xception_Training_Data=training_data,
|
||
Xception_Training_Label=training_label,
|
||
GastroSegNet_Training_Data=training_data,
|
||
GastroSegNet_Training_Label=training_label,
|
||
GastroSegNet_Training_Mask=training_mask,
|
||
Training_Config=Training_Config,
|
||
Loading_Config=Loading_Config,
|
||
tools=tool,
|
||
Number_Of_Classes=Label_Length,
|
||
status=Status
|
||
)
|
||
print("experiments類初始化成功!")
|
||
|
||
# 測試processing_main方法(僅執行部分代碼)
|
||
print("\n測試processing_main方法...")
|
||
# 這裡我們不實際調用processing_main,因為它會執行完整的訓練流程
|
||
# 而是檢查關鍵屬性是否正確設置
|
||
print(f"模型名稱: {experiment.model_name}")
|
||
print(f"實驗名稱: {experiment.experiment_name}")
|
||
print(f"訓練批次大小: {experiment.train_batch_size}")
|
||
print(f"使用設備: {experiment.device}")
|
||
|
||
# 模擬processing_main的部分功能
|
||
print("\n模擬processing_main的部分功能...")
|
||
# 這裡我們只模擬一些基本操作,不執行實際的訓練
|
||
print("模擬讀取測試數據...")
|
||
experiment.test = training_data
|
||
experiment.test_label = training_label
|
||
|
||
print("\n測試成功!")
|
||
|
||
except Exception as e:
|
||
print(f"\n測試失敗: {str(e)}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
|
||
if __name__ == "__main__":
|
||
test_main() |