Firsh Push at 20241207
This commit is contained in:
199
experiments/experiment.py
Normal file
199
experiments/experiment.py
Normal file
@@ -0,0 +1,199 @@
|
||||
from all_models_tools.all_model_tools import call_back
|
||||
from Read_and_process_image.ReadAndProcess import Read_image_and_Process_image
|
||||
from draw_tools.draw import plot_history, Confusion_Matrix_of_Two_Classification
|
||||
from keras import regularizers
|
||||
from Load_process.Load_Indepentend import Load_Indepentend_Data
|
||||
from _validation.ValidationTheEnterData import validation_the_enter_data
|
||||
from keras.layers import GlobalAveragePooling2D, Dense, Dropout
|
||||
from keras.applications import Xception
|
||||
from Load_process.file_processing import Process_File
|
||||
from merge_class.merge import merge
|
||||
from draw_tools.Grad_cam import Grad_CAM
|
||||
from sklearn.metrics import confusion_matrix
|
||||
from keras.models import Model
|
||||
from keras.optimizers import SGD
|
||||
from Image_Process.Image_Generator import Image_generator
|
||||
import pandas as pd
|
||||
import keras
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
class experiments():
|
||||
def __init__(self, tools, status):
|
||||
'''
|
||||
parmeter:
|
||||
* validation_obj : 驗證物件
|
||||
* cut_image : 切割影像物件
|
||||
* image_processing : 讀檔與讀檔資料處理物件
|
||||
* merge : 合併物件
|
||||
* model_name: 取名,告訴我我是用哪個模型(可能是預處理模型/自己設計的模型)
|
||||
* generator_batch_size: 每一批次要讀多少檔出來
|
||||
* epoch: 訓練次數
|
||||
* train_batch_size: 訓練時要多少批次的影像為1組
|
||||
* generator_batch_size: 減少圖片數量對GPU記憶體的用量, 減少張數用的
|
||||
* experiment_name : 本次實驗名稱
|
||||
* convolution_name : Pre-train model 的最後一層Convolotion的名稱
|
||||
'''
|
||||
|
||||
self.Topic_Tool = tools
|
||||
|
||||
self.validation_obj = validation_the_enter_data() # 呼叫驗證物件
|
||||
self.cut_image = Load_Indepentend_Data(self.Topic_Tool.Get_Data_Label(), self.Topic_Tool.Get_OneHot_Encording_Label()) # 呼叫切割影像物件
|
||||
self.image_processing = Read_image_and_Process_image()
|
||||
self.merge = merge()
|
||||
|
||||
self.model_name = "Xception" # 取名,告訴我我是用哪個模型(可能是預處理模型/自己設計的模型)
|
||||
self.experiment_name = "Xception Skin to train Normal stomach cancer"
|
||||
# self.file_name = "Remove background of Chickenpox with normal image"
|
||||
self.generator_batch_size = 50
|
||||
self.epoch = 10000
|
||||
self.train_batch_size = 128
|
||||
self.layers = 1
|
||||
self.convolution_name = self.get_layer_name(self.model_name)
|
||||
|
||||
self.Grad = ""
|
||||
self.Status = status
|
||||
|
||||
pass
|
||||
|
||||
def processing_main(self, train, train_label, counter):
|
||||
Train, Test, Validation = self.Topic_Tool.Get_Save_Roots(self.Status) # 要換不同資料集就要改
|
||||
|
||||
|
||||
start = time.time()
|
||||
self.cut_image.process_main(Test, Validation) # 呼叫處理test Data與Validation Data的function
|
||||
end = time.time()
|
||||
print("讀取testing與validation資料(154)執行時間:%f 秒\n" % (end - start))
|
||||
|
||||
Generator = Image_generator("", "")
|
||||
|
||||
# 將處理好的test Data 與 Validation Data 丟給這個物件的變數
|
||||
self.test, self.test_label = self.cut_image.test, self.cut_image.test_label
|
||||
self.validation, self.validation_label = self.cut_image.validation, self.cut_image.validation_label
|
||||
|
||||
self.Grad = Grad_CAM(self.Topic_Tool.Get_Data_Label(), self.test_label, self.experiment_name, self.convolution_name)
|
||||
|
||||
cnn_model = self.construct_model() # 呼叫讀取模型的function
|
||||
|
||||
# model_dir = '../save_the_best_model/Topic/Remove background with Normal image/best_model( 2023-10-17 )-2.h5' # 這是一個儲存模型權重的路徑,每一個模型都有一個自己權重儲存的檔
|
||||
# if os.path.exists(model_dir): # 如果這個檔案存在
|
||||
# cnn_model.load_weights(model_dir) # 將模型權重讀出來
|
||||
# print("讀出權重\n")
|
||||
|
||||
Optimizer = SGD(learning_rate = 0.045, momentum = 0.9) # 決定優化器與學習率
|
||||
|
||||
cnn_model.compile(
|
||||
loss = "binary_crossentropy",
|
||||
optimizer = Optimizer,
|
||||
metrics=
|
||||
[
|
||||
'accuracy',
|
||||
keras.metrics.Precision(name='precision'),
|
||||
keras.metrics.Recall(name='recall'),
|
||||
keras.metrics.AUC(name = 'AUC'),
|
||||
]
|
||||
)
|
||||
|
||||
train_data = Generator.Generator_Content(5) # 叫入ImageGeneratorr的物件,為了要讓我訓練時可以分批次讀取資料,GPU記憶體才不會爆
|
||||
cnn_model.summary() # 顯示模型架構
|
||||
print("\n\n\n讀取訓練資料(70000)執行時間:%f 秒\n\n" % (end - start))
|
||||
history = cnn_model.fit(
|
||||
train_data.flow(train, train_label, batch_size = self.generator_batch_size),
|
||||
# x = train,
|
||||
# y = train_label,
|
||||
epochs = self.epoch,
|
||||
batch_size = self.train_batch_size,
|
||||
validation_data = (self.validation, self.validation_label),
|
||||
callbacks = call_back(self.experiment_name, counter) # 呼叫 call back list
|
||||
# callbacks = call_back("best_model", self.counter) # 呼叫 call back list
|
||||
)
|
||||
|
||||
Matrix = self.record_matrix_image(cnn_model, self.experiment_name, counter) # 紀錄混淆矩陣的function
|
||||
loss, accuracy, precision, recall, AUC = cnn_model.evaluate(self.test, self.test_label) # 預測結果
|
||||
|
||||
# 防分母為0的時候
|
||||
if recall == 0 or precision == 0:
|
||||
f = 0
|
||||
else:
|
||||
f = (1 + 0.5 * 0.5) * ((recall * precision) / (0.5 * 0.5 * recall + precision))
|
||||
|
||||
print(self.record_everyTime_test_result(loss, accuracy, precision, recall, AUC, f, counter, self.experiment_name, Matrix)) # 紀錄當前訓練完之後的預測結果,並輸出成csv檔
|
||||
|
||||
plot_history(history, "train" + str(counter), self.experiment_name) # 將訓練結果化成圖,並將化出來的圖丟出去儲存
|
||||
self.Grad.process_main(cnn_model, counter, self.test)
|
||||
|
||||
return loss, accuracy, precision, recall, AUC, f
|
||||
|
||||
def construct_model(self):
|
||||
'''決定我這次訓練要用哪個model'''
|
||||
xception = Xception(include_top = False, weights = "imagenet", input_shape = (512, 512, 3))
|
||||
Flatten = GlobalAveragePooling2D()(xception.output)
|
||||
output = Dense(units = 1370, activation = "relu", kernel_regularizer = regularizers.L2())(Flatten)
|
||||
output = Dropout(0.6)(output)
|
||||
output = Dense(units = 2, activation = "softmax")(output)
|
||||
|
||||
cnn_model = Model(inputs = xception.input, outputs = output)
|
||||
|
||||
return cnn_model
|
||||
|
||||
def record_matrix_image(self, cnn_model : Model, model_name, index):
|
||||
'''劃出混淆矩陣(熱力圖)'''
|
||||
result = cnn_model.predict(self.test) # 利用predict function來預測結果
|
||||
result = np.argmax(result, axis = 1) # 將預測出來的結果從one-hot encoding轉成label-encoding
|
||||
y_test = np.argmax(self.test_label, axis = 1)
|
||||
matrix = confusion_matrix(result, y_test, labels = [0, 1]) # 丟入confusion matrix的function中,以形成混淆矩陣
|
||||
Confusion_Matrix_of_Two_Classification(model_name, matrix, index) # 呼叫畫出confusion matrix的function
|
||||
|
||||
return matrix.real
|
||||
|
||||
def record_everyTime_test_result(self, loss, accuracy, precision, recall, auc, f, indexs, model_name, Matrix):
|
||||
'''記錄我單次的訓練結果並將它輸出到檔案中'''
|
||||
File = Process_File()
|
||||
|
||||
Dataframe = pd.DataFrame(
|
||||
{
|
||||
"model_name" : str(model_name),
|
||||
"loss" : "{:.2f}".format(loss),
|
||||
"precision" : "{:.2f}%".format(precision * 100),
|
||||
"recall" : "{:.2f}%".format(recall * 100),
|
||||
"accuracy" : "{:.2f}%".format(accuracy * 100),
|
||||
"f" : "{:.2f}%".format(f * 100),
|
||||
"AUC" : "{:.2f}%".format(auc * 100)
|
||||
}, index = [indexs])
|
||||
File.Save_CSV_File("train_result", Dataframe)
|
||||
# File.Save_TXT_File("Matrix_Result : " + str(Matrix), model_name + "_train" + str(indexs))
|
||||
|
||||
return Dataframe
|
||||
|
||||
def get_layer_name(self, model_name):
|
||||
if(self.validation_obj.validation_string(model_name, "VGG19")):
|
||||
return "block5_conv4"
|
||||
if(self.validation_obj.validation_string(model_name, "ResNet50")):
|
||||
return "conv5_block3_3_conv"
|
||||
if(self.validation_obj.validation_string(model_name, "Xception")):
|
||||
return "block14_sepconv2"
|
||||
if(self.validation_obj.validation_string(model_name, "DenseNet121")):
|
||||
return "conv5_block16_concat"
|
||||
if(self.validation_obj.validation_string(model_name, "InceptionResNetV2")):
|
||||
return "conv_7b"
|
||||
if(self.validation_obj.validation_string(model_name, "InceptionV3")):
|
||||
return "conv2d_93"
|
||||
if(self.validation_obj.validation_string(model_name, "MobileNet")):
|
||||
return "conv_pw_13"
|
||||
if(self.validation_obj.validation_string(model_name, "MobileNetV2")):
|
||||
return "Conv_1"
|
||||
if(self.validation_obj.validation_string(model_name, "NASNetLarge")):
|
||||
return "separable_conv_2_normal_left5_18"
|
||||
if(self.validation_obj.validation_string(model_name, "ResNet101")):
|
||||
return "conv5_block3_3_conv"
|
||||
if(self.validation_obj.validation_string(model_name, "ResNet101V2")):
|
||||
return "conv5_block3_3_conv"
|
||||
if(self.validation_obj.validation_string(model_name, "ResNet152")):
|
||||
return "conv5_block3_3_conv"
|
||||
if(self.validation_obj.validation_string(model_name, "ResNet152V2")):
|
||||
return "conv5_block3_out"
|
||||
if(self.validation_obj.validation_string(model_name, "ResNet50v2")):
|
||||
return "conv5_block3_out"
|
||||
if(self.validation_obj.validation_string(model_name, "VGG16")):
|
||||
return "block5_conv3"
|
||||
|
||||
Reference in New Issue
Block a user