import torch import sys import os import time # 添加當前目錄到系統路徑 sys.path.append(os.path.dirname(os.path.abspath(__file__))) from experiments.experiment import experiments from utils.Stomach_Config import Training_Config, Loading_Config, Save_Result_File_Config from Training_Tools.Tools import Tool from merge_class.merge import merge def test_main(): # 測試GPU是否可用 flag = torch.cuda.is_available() if not flag: print("CUDA不可用\n") else: print(f"CUDA可用,數量為{torch.cuda.device_count()}\n") # 测试GPU是否可用 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"使用设备: {device}") if torch.cuda.is_available(): print(f"GPU: {torch.cuda.get_device_name(0)}") # 創建一個簡單的測試數據集 class SimpleDataset(torch.utils.data.Dataset): def __init__(self, size=10): self.size = size self.data = torch.randn(size, 3, 256, 256) # 假設輸入是3通道256x256圖像 self.mask = torch.randint(0, 2, (size, 1, 256, 256)).float() # 二值掩碼 self.labels = torch.zeros(size, 3) # 假設有3個類別 self.labels[:, 0] = 1 # 所有樣本都是第一個類別 print(f"創建了測試數據集,大小: {size},數據形狀: {self.data.shape}") def __len__(self): return self.size def __getitem__(self, idx): return self.data[idx], self.mask[idx], self.labels[idx] # 創建測試數據 print("創建測試數據...") test_dataset = SimpleDataset() training_data = [test_dataset.data[i] for i in range(test_dataset.size)] training_label = [test_dataset.labels[i] for i in range(test_dataset.size)] training_mask = [test_dataset.mask[i] for i in range(test_dataset.size)] # 初始化工具和配置 tool = Tool() Status = 1 Label_Length = 3 # 模擬main.py中的數據處理 print("\n模擬main.py中的數據處理...") Merge = merge() # 確保Loading_Config中包含必要的配置項 required_configs = [ "Test_Data_Root", "Training_Labels", "Annotation_Root", "Label_Image_Labels", "Image enhance processing save root" ] for config in required_configs: if config not in Loading_Config: print(f"添加缺少的配置項: {config}") if config == "Test_Data_Root": Loading_Config[config] = "./test_data" elif config == "Training_Labels": Loading_Config[config] = ["class1", "class2", "class3"] elif config == "Annotation_Root": Loading_Config[config] = "./annotations" elif config == "Label_Image_Labels": Loading_Config[config] = ["CA", "Normal"] elif config == "Image enhance processing save root": Loading_Config[config] = "./enhanced_images" # 測試experiments類的初始化 try: print("\n測試experiments類的初始化...") experiment = experiments( Xception_Training_Data=training_data, Xception_Training_Label=training_label, GastroSegNet_Training_Data=training_data, GastroSegNet_Training_Label=training_label, GastroSegNet_Training_Mask=training_mask, Training_Config=Training_Config, Loading_Config=Loading_Config, tools=tool, Number_Of_Classes=Label_Length, status=Status ) print("experiments類初始化成功!") # 測試processing_main方法(僅執行部分代碼) print("\n測試processing_main方法...") # 這裡我們不實際調用processing_main,因為它會執行完整的訓練流程 # 而是檢查關鍵屬性是否正確設置 print(f"模型名稱: {experiment.model_name}") print(f"實驗名稱: {experiment.experiment_name}") print(f"訓練批次大小: {experiment.train_batch_size}") print(f"使用設備: {experiment.device}") # 模擬processing_main的部分功能 print("\n模擬processing_main的部分功能...") # 這裡我們只模擬一些基本操作,不執行實際的訓練 print("模擬讀取測試數據...") experiment.test = training_data experiment.test_label = training_label print("\n測試成功!") except Exception as e: print(f"\n測試失敗: {str(e)}") import traceback traceback.print_exc() if __name__ == "__main__": test_main()