Files
Stomach_Cancer_Pytorch/main.py

157 lines
7.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from experiments.experiment import experiments
from Image_Process.Image_Generator import Image_generator
from Training_Tools.Tools import Tool
from model_data_processing.processing import make_label_list, Read_Image_Root_And_Image_Enhance
from Load_process.LoadData import Load_Data_Prepare
from Calculate_Process.Calculate import Calculate
from merge_class.merge import merge
from model_data_processing.processing_for_cut_image import Cut_Indepentend_Data
from Load_process.LoadData import Loding_Data_Root
from Load_process.file_processing import Process_File
from utils.Stomach_Config import Training_Config, Loading_Config
from Image_Process.Image_Mask_Ground_Truth_Processing import XMLAnnotationProcessor
import time
import torch
def select_device():
    """Report CUDA availability and return the ``torch.device`` to use.

    Prints the number of visible CUDA devices (or that CUDA is unavailable),
    then the selected device, then -- only when CUDA is available -- the name
    of GPU 0.

    Returns:
        torch.device: ``cuda`` when ``torch.cuda.is_available()`` is true,
        otherwise ``cpu``.
    """
    # Query availability once; the previous script checked it twice.
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        print(f"CUDA可用數量為{torch.cuda.device_count()}\n")
    else:
        print("CUDA不可用\n")

    device = torch.device('cuda' if cuda_available else 'cpu')
    print(f"使用设备: {device}")
    # torch.cuda.get_device_name(0) raises when no CUDA device exists, which
    # used to crash CPU-only runs right after reporting "CUDA不可用".
    if cuda_available:
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    return device


def merge_groups(merger, groups):
    """Combine ``groups`` left to right with ``merger.merge_all_image_data``.

    Computes ``merge(...merge(merge(g0, g1), g2)..., gN)``, which the script
    previously spelled out with hard-coded key indices.

    Args:
        merger: Object exposing ``merge_all_image_data(a, b)``.
        groups: Sequence of at least two collections to merge, in order.

    Returns:
        The result of the final ``merge_all_image_data`` call.

    Raises:
        ValueError: If fewer than two groups are supplied (the original
            indexing required at least two and failed obscurely otherwise).
    """
    if len(groups) < 2:
        raise ValueError(f"need at least two groups to merge, got {len(groups)}")
    merged = merger.merge_all_image_data(groups[0], groups[1])
    for group in groups[2:]:
        merged = merger.merge_all_image_data(merged, group)
    return merged


def main():
    """Load and prepare the training data, then run the configured experiment."""
    select_device()

    tool = Tool()
    Status = 1
    # Build the one-hot encodings for the training labels.
    tool.Set_OneHotEncording(Loading_Config["Training_Labels"])
    Encording_Label = tool.Get_OneHot_Encording_Label()
    Label_Length = len(Loading_Config["Training_Labels"])
    Prepare = Load_Data_Prepare()
    Indepentend = Cut_Indepentend_Data(Loading_Config["Train_Data_Root"], Loading_Config["Training_Labels"])
    Merge = merge()
    # NOTE(review): Calculate_Tool and file are unused on the active path.
    # They are kept because their constructors' side effects (if any) are not
    # visible here, and `file` is referenced by the disabled code below.
    Calculate_Tool = Calculate()
    file = Process_File()
    Train_Size = 0

    # Split off the independent test set. test_size is presumably the held-out
    # fraction -- confirm against Cut_Indepentend_Data.IndependentData_main.
    test_size = 0.2
    Indepentend.IndependentData_main(Loading_Config["Test_Data_Root"], test_size)

    # Disabled: XML annotation -> cropped mask ground-truth pipeline.
    # NOTE(review): if re-enabled, it adds the mask counts to Train_Size before
    # the original-data count below, so the printed total would include masks.
    # # 創建處理切割的部分
    # if not file.JudgeRoot_MakeDir(Loading_Config['Annotation_Training_Root']) and not file.JudgeRoot_MakeDir(Loading_Config['Annotation_Testing_Root']):
    # processor_train = XMLAnnotationProcessor(
    # dataset_root = Loading_Config["Train_Data_Root"],
    # )
    # processor_test = XMLAnnotationProcessor(
    # dataset_root = Loading_Config["Test_Data_Root"],
    # )
    # # 設定自訂樣式(可選)
    # processor_train.set_drawing_style(
    # box_color=(0, 0, 255), # 紅色邊界框
    # text_color=(255, 255, 255), # 白色文字
    # box_thickness=3,
    # font_scale=0.7
    # )
    # processor_test.set_drawing_style(
    # box_color=(0, 0, 255), # 紅色邊界框
    # text_color=(255, 255, 255), # 白色文字
    # box_thickness=3,
    # font_scale=0.7
    # )
    # print("XML標註處理器已準備就绪")
    # for Label_Iamge_List in Loading_Config["Label_Image_Labels"]:
    # if Label_Iamge_List == "CA":
    # Label = "stomach_cancer_Crop"
    # else:
    # Label = "Have_Question_Crop"
    # training_results = processor_train.process_multiple_xml(f"../Label_Image/{Label_Iamge_List}", Loading_Config['Annotation_Training_Root'], Label)
    # testing_results = processor_test.process_multiple_xml(f"../Label_Image/{Label_Iamge_List}", Loading_Config['Annotation_Testing_Root'], Label)
    # else:
    # print("Training and Testing annoation is exist!!!!")
    # # 讀取切割完成後的檔案
    # print("Mask Ground truth is Finished\n")
    # Mask_load = Loding_Data_Root(Loading_Config["XML_Loading_Label"], Loading_Config['Annotation_Training_Root'], "")
    # Mask_Data_Dict_Data = Mask_load.process_main(False)
    # Total_Size_Lists = []
    # print("Mask資料集總數")
    # for label in Loading_Config["XML_Loading_Label"]:
    # Train_Size += len(Mask_Data_Dict_Data[label])
    # Total_Size_Lists.append(len(Mask_Data_Dict_Data[label]))
    # print(f"Labels: {label}, 總數為: {len(Mask_Data_Dict_Data[label])}")
    # print("總共有 " + str(Train_Size) + " 筆資料")

    # Read the original dataset.
    load = Loding_Data_Root(Loading_Config["Training_Labels"], Loading_Config["Train_Data_Root"], Loading_Config["ImageGenerator_Data_Root"])
    Data_Dict_Data = load.process_main(False)

    # Disabled: offline image-enhancement pass over the training data.
    # # 製作資料增強資料
    # if not file.Judge_File_Exist(Loading_Config['Image enhance processing save root']):
    # for label in Loading_Config["Training_Labels"]:
    # Read_Image_Root_And_Image_Enhance(Data_Dict_Data[label], f"{Loading_Config['Image enhance processing save root']}/{label}")
    # tmp_load = Loding_Data_Root(Loading_Config["Training_Labels"], Loading_Config['Image enhance processing save root'], Loading_Config["ImageGenerator_Data_Root"])
    # Data_Dict_Data = tmp_load.process_main(False)

    # Report the per-label and total number of samples.
    Total_Size_List = []
    print("前處理後資料集總數")
    for label in Loading_Config["Training_Labels"]:
        label_size = len(Data_Dict_Data[label])
        Train_Size += label_size
        Total_Size_List.append(label_size)
        print(f"Labels: {label}, 總數為: {label_size}")
    print("總共有 " + str(Train_Size) + " 筆資料")

    # Build one label list per class, sized to that class's sample count.
    Classes = [
        make_label_list(Total_Size_List[index], encording)
        for index, encording in enumerate(Encording_Label)
    ]

    # Pack the images and labels into the final dict structure.
    Prepare.Set_Final_Dict_Data(Loading_Config["Training_Labels"], Data_Dict_Data, Classes, Label_Length)
    Final_Dict_Data = Prepare.Get_Final_Data_Dict()
    keys = list(Final_Dict_Data.keys())
    # Mask_Keys = list(Mask_Data_Dict_Data.keys())

    # The merge below assumes Final_Dict_Data lists the Label_Length image groups
    # first, followed by the Label_Length matching label groups (the original
    # hard-coded indexing relied on the same order).
    if len(keys) < 2 * Label_Length:
        raise ValueError(
            f"expected at least {2 * Label_Length} entries in the final data dict, got {len(keys)}"
        )
    # Merge all per-class image groups into one list, and all label groups into another.
    Training_Data = merge_groups(Merge, [Final_Dict_Data[key] for key in keys[:Label_Length]])
    Training_Label = merge_groups(Merge, [Final_Dict_Data[key] for key in keys[Label_Length:2 * Label_Length]])
    # Training_Mask_Data = Merge.merge_all_image_data(Mask_Data_Dict_Data[Mask_Keys[0]], Mask_Data_Dict_Data[Mask_Keys[1]]) # 將訓練資料合併成一個list
    # for i in range(2, len(Mask_Keys)):
    # Training_Mask_Data = Merge.merge_all_image_data(Training_Mask_Data, Mask_Data_Dict_Data[Mask_Keys[i]]) # 將訓練資料合併成一個list

    experiment = experiments(
        Xception_Training_Data=Training_Data,
        Xception_Training_Label=Training_Label,
        Xception_Training_Mask_Data=None,
        status=Status,
    )
    # perf_counter is monotonic, so the measured duration is immune to wall-clock changes.
    start = time.perf_counter()
    experiment.processing_main()  # Run training.
    end = time.perf_counter()
    print(f"\n\n\n訓練時間:{end - start}\n\n")


if __name__ == "__main__":
    main()