97 lines
4.1 KiB
Python
97 lines
4.1 KiB
Python
from experiments.experiment import experiments
|
||
from Image_Process.load_and_ImageGenerator import Load_ImageGenerator
|
||
from Read_and_process_image.ReadAndProcess import Read_image_and_Process_image
|
||
from Training_Tools.Tools import Tool
|
||
from model_data_processing.processing import Balance_Process
|
||
from Load_process.LoadData import Load_Data_Prepare
|
||
from Calculate_Process.Calculate import Calculate
|
||
from merge_class.merge import merge
|
||
|
||
import time
|
||
import torch
|
||
|
||
if __name__ == "__main__":
    # Training entry point: report CUDA availability, configure the experiment,
    # load the per-label image data, merge it into one training set, and hand
    # it to `experiments.processing_main` once per run.

    # Report whether CUDA is available (informational only; nothing below
    # branches on it).
    if torch.cuda.is_available():
        print(f"CUDA可用,數量為{torch.cuda.device_count()}\n")
    else:
        print("CUDA不可用\n")

    # ---- Parameter setup ---------------------------------------------------
    tool = Tool()
    tool.Set_Labels()
    tool.Set_Save_Roots()

    Status = 1  # selects which dataset the save roots refer to
    Labels = tool.Get_Data_Label()
    Trainig_Root, Testing_Root = tool.Get_Save_Roots(Status)  # regular train/test roots
    Generator_Root = tool.Get_Generator_Save_Roots(Status)

    # One-hot encodings for the labels. Presumably one encoding per label, in
    # the same order as `Labels` -- TODO confirm in Training_Tools.Tools.
    tool.Set_OneHotEncording(Labels)
    Encording_Label = tool.Get_OneHot_Encording_Label()

    Label_Length = len(Labels)
    # Number of output classes handed to the experiment.
    # NOTE(review): hard-coded independently of Label_Length; keep the two in
    # sync when the label set changes.
    Classification = 3

    Model_Name = "Xception"  # records which model (pretrained or custom) is used
    Experiment_Name = "Xception Skin to train Normal stomach cancer"
    Epoch = 10000
    Train_Batch_Size = 64
    Image_Size = 256

    Prepare = Load_Data_Prepare()
    loading_data = Load_ImageGenerator(Trainig_Root, Testing_Root, Generator_Root, Labels, Image_Size)
    experiment = experiments(Image_Size, Model_Name, Experiment_Name, Epoch, Train_Batch_Size, tool, Classification, Status)
    image_processing = Read_image_and_Process_image(Image_Size)
    Merge = merge()
    Calculate_Tool = Calculate()  # unused below; kept in case its constructor has side effects

    counter = 1  # number of training runs to perform

    for Run_Range in range(counter):
        # Load the per-label image data for this run.
        Data_Dict_Data = loading_data.process_main(Label_Length)
        # Optional class balancing (also returns a size to use as Train_Size):
        # Data_Dict_Data, Train_Size = Balance_Process(Data_Dict_Data, Labels)

        # Count this run's samples. Recomputed per run so that with
        # counter > 1 the totals of earlier runs are not carried over.
        Train_Size = sum(len(Data_Dict_Data[label]) for label in Labels)
        print("總共有 " + str(Train_Size) + " 筆資料")

        # Build one label list per one-hot encoding, each sized by Train_Size.
        # NOTE(review): Train_Size here is the total over ALL labels, whereas
        # the Balance_Process path above supplies its own size. Confirm the
        # length make_label_list should receive when balancing is off.
        Classes = [image_processing.make_label_list(Train_Size, encording) for encording in Encording_Label]

        # Package the data and the label lists into a single dict.
        Prepare.Set_Final_Dict_Data(Labels, Data_Dict_Data, Classes, Label_Length)
        Final_Dict_Data = Prepare.Get_Final_Data_Dict()
        keys = list(Final_Dict_Data.keys())
        # Assumes the first Label_Length keys hold image data and the next
        # Label_Length keys hold the matching labels, in insertion order --
        # TODO confirm against Load_Data_Prepare.Set_Final_Dict_Data.
        data_keys = keys[:Label_Length]
        label_keys = keys[Label_Length:2 * Label_Length]

        # Fold every label's data (and labels) into one list each. Starting from
        # the first entry instead of a hard-coded pair keeps a single-label setup
        # from merging image data with label data.
        training_data = Final_Dict_Data[data_keys[0]]
        for key in data_keys[1:]:
            training_data = Merge.merge_all_image_data(training_data, Final_Dict_Data[key])

        training_label = Final_Dict_Data[label_keys[0]]
        for key in label_keys[1:]:
            training_label = Merge.merge_all_image_data(training_label, Final_Dict_Data[key])

        start = time.perf_counter()
        # Read the image files (presumably also augmenting them, judging by the
        # helper's name -- verify in Read_and_process_image).
        trains_Data_Image = image_processing.Data_Augmentation_Image(training_data)
        # Per the original author's note: normalizes the images and converts the
        # labels to numpy arrays.
        Training_Data, Training_Label = image_processing.image_data_processing(trains_Data_Image, training_label)
        end = time.perf_counter()
        # NOTE(review): "(70000)" in this message is hard-coded, not the real count.
        print("\n\n\n讀取訓練資料(70000)執行時間:%f 秒\n\n" % (end - start))

        # Run training for this run.
        experiment.processing_main(Training_Data, Training_Label, Run_Range)