Stomach_Cancer_Pytorch/main.py

from experiments.experiment import experiments
from Image_Process.load_and_ImageGenerator import Load_ImageGenerator
from Read_and_process_image.ReadAndProcess import Read_image_and_Process_image
from Training_Tools.Tools import Tool
from model_data_processing.processing import Balance_Process
from Load_process.LoadData import Load_Data_Prepare
from Calculate_Process.Calculate import Calculate
from merge_class.merge import merge
import time
import torch
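
# Entry point of the training pipeline: check GPU availability, set up labels,
# save roots and one-hot encodings via Tool, load the per-class image lists,
# merge them into single data/label lists, augment and normalize the images,
# and run the training experiment (experiment.processing_main).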
if __name__ == "__main__":
    # Check whether a GPU is available
    flag = torch.cuda.is_available()
    if not flag:
        print("CUDA is not available\n")
    else:
        print(f"Number of available CUDA devices: {torch.cuda.device_count()}\n")

    # Parameter setup
    tool = Tool()
    tool.Set_Labels()
    tool.Set_Save_Roots()
    Status = 1  # Decides which dataset to use
    Labels = tool.Get_Data_Label()
    Trainig_Root, Testing_Root = tool.Get_Save_Roots(Status)  # the standard roots
    Generator_Root = tool.Get_Generator_Save_Roots(Status)

    # Get the one-hot encoded label data
    tool.Set_OneHotEncording(Labels)
    Encording_Label = tool.Get_OneHot_Encording_Label()
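    # For reference: with three classes the one-hot encodings are vectors of the form
    # [1, 0, 0], [0, 1, 0], [0, 0, 1]; the exact container type depends on Tool's implementation.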
    Label_Length = len(Labels)

    Classification = 3  # Number of classes
    Model_Name = "Xception"  # Name recording which model is used (a pretrained model or a self-designed one)
    Experiment_Name = "Xception Skin to train Normal stomach cancer"
    Epoch = 10000
    Train_Batch_Size = 64
    Image_Size = 256

    Prepare = Load_Data_Prepare()
    loading_data = Load_ImageGenerator(Trainig_Root, Testing_Root, Generator_Root, Labels, Image_Size)
    experiment = experiments(Image_Size, Model_Name, Experiment_Name, Epoch, Train_Batch_Size, tool, Classification, Status)
    image_processing = Read_image_and_Process_image(Image_Size)
    Merge = merge()
    Calculate_Tool = Calculate()

    counter = 1
    Batch_Size = 128
    Train_Size = 0

    for Run_Range in range(0, counter, 1):  # Run training the specified number of times
        # Load the data
        Data_Dict_Data = loading_data.process_main(Label_Length)
        # Data_Dict_Data, Train_Size = Balance_Process(Data_Dict_Data, Labels)
        for label in Labels:
            Train_Size += len(Data_Dict_Data[label])
        print("There are " + str(Train_Size) + " samples in total")

        # Build one label list per class, each with the same length as the data
        Classes = []
        i = 0
        for encording in Encording_Label:
            Classes.append(image_processing.make_label_list(Train_Size, encording))
            i += 1

        # Pack the data into a dict
        Prepare.Set_Final_Dict_Data(Labels, Data_Dict_Data, Classes, Label_Length)
        Final_Dict_Data = Prepare.Get_Final_Data_Dict()
        keys = list(Final_Dict_Data.keys())
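        # Based on the indexing below, the first Label_Length keys are expected to hold the
        # per-class image lists and the next Label_Length keys the corresponding label lists.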
        training_data = Merge.merge_all_image_data(Final_Dict_Data[keys[0]], Final_Dict_Data[keys[1]])  # Merge the training data into a single list
        for i in range(2, Label_Length):
            training_data = Merge.merge_all_image_data(training_data, Final_Dict_Data[keys[i]])  # Merge the training data into a single list
        training_label = Merge.merge_all_image_data(Final_Dict_Data[keys[Label_Length]], Final_Dict_Data[keys[Label_Length + 1]])  # Merge the training labels into a single list
        for i in range(Label_Length + 2, 2 * Label_Length):
            training_label = Merge.merge_all_image_data(training_label, Final_Dict_Data[keys[i]])  # Merge the training labels into a single list

        start = time.time()
        trains_Data_Image = image_processing.Data_Augmentation_Image(training_data)  # Load the image files (with augmentation)
        Training_Data, Training_Label = image_processing.image_data_processing(trains_Data_Image, training_label)  # Normalize the loaded images and convert the labels to a numpy array
        # training_data = image_processing.normalization(training_data)
        # training_data = training_data.permute(0, 3, 1, 2)
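        # The two commented-out lines above sketch an alternative manual path: normalize the
        # raw images and permute them from NHWC to NCHW, the channel ordering PyTorch expects.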
        end = time.time()
        print("\n\n\nTime to load the training data (70000): %f\n\n" % (end - start))

        experiment.processing_main(Training_Data, Training_Label, Run_Range)  # Run the training procedure