# Stomach_Cancer_Keras/experiments/experiment.py
from all_models_tools.all_model_tools import call_back
from Read_and_process_image.ReadAndProcess import Read_image_and_Process_image
from draw_tools.draw import plot_history, Confusion_Matrix_of_Two_Classification
from keras import regularizers
from Load_process.Load_Indepentend import Load_Indepentend_Data
from _validation.ValidationTheEnterData import validation_the_enter_data
from keras.layers import GlobalAveragePooling2D, Dense, Dropout
from keras.applications import Xception
from Load_process.file_processing import Process_File
from merge_class.merge import merge
from draw_tools.Grad_cam import Grad_CAM
from sklearn.metrics import confusion_matrix
from keras.models import Model
from keras.optimizers import SGD
from Image_Process.Image_Generator import Image_generator
import pandas as pd
import keras
import numpy as np
import time

class experiments:
    def __init__(self, tools, status):
        '''
        Parameters / attributes:
        * validation_obj : validation helper object
        * cut_image : image-splitting (dataset loading) object
        * image_processing : object that reads image files and preprocesses them
        * merge : merge helper object
        * model_name : tells me which model is used (a pre-trained model or one I designed myself)
        * generator_batch_size : how many images the generator reads per batch (keeps the per-step image count low so GPU memory is not exhausted)
        * epoch : number of training epochs
        * train_batch_size : how many images form one training batch
        * experiment_name : name of this experiment
        * convolution_name : name of the last convolution layer of the pre-trained model (used for Grad-CAM)

        (A usage sketch is appended after the class definition.)
        '''
        self.Topic_Tool = tools
        self.validation_obj = validation_the_enter_data()  # validation helper object
        self.cut_image = Load_Indepentend_Data(self.Topic_Tool.Get_Data_Label(), self.Topic_Tool.Get_OneHot_Encording_Label())  # image-splitting / dataset-loading object
        self.image_processing = Read_image_and_Process_image()
        self.merge = merge()
        self.model_name = "Xception"  # tells me which model is used (pre-trained or self-designed)
        self.experiment_name = "Xception Skin to train Normal stomach cancer"
        # self.file_name = "Remove background of Chickenpox with normal image"
        self.generator_batch_size = 50
        self.epoch = 10000
        self.train_batch_size = 128
        self.layers = 1
        self.convolution_name = self.get_layer_name(self.model_name)
        self.Grad = ""
        self.Status = status

    def processing_main(self, train, train_label, counter):
        Train, Test, Validation = self.Topic_Tool.Get_Save_Roots(self.Status)  # change this when switching to a different dataset
        start = time.time()
        self.cut_image.process_main(Test, Validation)  # load and process the test and validation data
        end = time.time()
        print("Time to read the testing and validation data (154): %f\n" % (end - start))
        Generator = Image_generator("", "")
        # hand the processed test and validation data to this object's attributes
        self.test, self.test_label = self.cut_image.test, self.cut_image.test_label
        self.validation, self.validation_label = self.cut_image.validation, self.cut_image.validation_label
        self.Grad = Grad_CAM(self.Topic_Tool.Get_Data_Label(), self.test_label, self.experiment_name, self.convolution_name)
        cnn_model = self.construct_model()  # build the model
        # model_dir = '../save_the_best_model/Topic/Remove background with Normal image/best_model( 2023-10-17 )-2.h5'  # path of the saved weights; every model has its own weight file
        # if os.path.exists(model_dir):  # if the weight file exists
        #     cnn_model.load_weights(model_dir)  # load the saved weights
        #     print("Weights loaded\n")
        Optimizer = SGD(learning_rate = 0.045, momentum = 0.9)  # choose the optimizer and learning rate
        cnn_model.compile(
            loss = "binary_crossentropy",
            optimizer = Optimizer,
            metrics = [
                'accuracy',
                keras.metrics.Precision(name = 'precision'),
                keras.metrics.Recall(name = 'recall'),
                keras.metrics.AUC(name = 'AUC'),
            ]
        )
        train_data = Generator.Generator_Content(5)  # ImageGenerator object so training data is read in batches and GPU memory does not blow up
        cnn_model.summary()  # show the model architecture
        print("\n\n\nTime to read the training data (70000): %f\n\n" % (end - start))  # note: reuses the timer from the test/validation read above
        history = cnn_model.fit(
            train_data.flow(train, train_label, batch_size = self.generator_batch_size),
            # x = train,
            # y = train_label,
            epochs = self.epoch,
            batch_size = self.train_batch_size,
            validation_data = (self.validation, self.validation_label),
            callbacks = call_back(self.experiment_name, counter)  # build the callback list
            # callbacks = call_back("best_model", self.counter)  # build the callback list
        )
        Matrix = self.record_matrix_image(cnn_model, self.experiment_name, counter)  # record the confusion matrix
        loss, accuracy, precision, recall, AUC, = cnn_model.evaluate(self.test, self.test_label)  # evaluate on the test set
        # guard against a zero denominator
        if recall == 0 or precision == 0:
            f = 0
        else:
            # F-beta score with beta = 0.5: (1 + beta^2) * P * R / (beta^2 * P + R)
            # (a worked example is given by f_beta_sketch after the class)
            f = (1 + 0.5 * 0.5) * ((recall * precision) / (0.5 * 0.5 * precision + recall))
        print(self.record_everyTime_test_result(loss, accuracy, precision, recall, AUC, f, counter, self.experiment_name, Matrix))  # record this run's test results and write them out as a CSV file
        plot_history(history, "train" + str(counter), self.experiment_name)  # plot the training history and hand the figures off to be saved
        self.Grad.process_main(cnn_model, counter, self.test)
        return loss, accuracy, precision, recall, AUC, f

    def construct_model(self):
        '''Decide which model to use for this training run'''
        # Xception backbone pre-trained on ImageNet, without its classification top
        xception = Xception(include_top = False, weights = "imagenet", input_shape = (512, 512, 3))
        Flatten = GlobalAveragePooling2D()(xception.output)
        output = Dense(units = 1370, activation = "relu", kernel_regularizer = regularizers.L2())(Flatten)
        output = Dropout(0.6)(output)
        output = Dense(units = 2, activation = "softmax")(output)  # two-class softmax head
        cnn_model = Model(inputs = xception.input, outputs = output)
        return cnn_model

    def record_matrix_image(self, cnn_model : Model, model_name, index):
        '''Draw the confusion matrix (as a heat map)'''
        result = cnn_model.predict(self.test)  # predict on the test set
        result = np.argmax(result, axis = 1)  # convert the predictions from one-hot encoding to label encoding
        y_test = np.argmax(self.test_label, axis = 1)
        matrix = confusion_matrix(y_test, result, labels = [0, 1])  # sklearn expects (y_true, y_pred)
        Confusion_Matrix_of_Two_Classification(model_name, matrix, index)  # plot the confusion matrix
        return matrix

    def record_everyTime_test_result(self, loss, accuracy, precision, recall, auc, f, indexs, model_name, Matrix):
        '''Record the result of a single training run and write it out to a file'''
        File = Process_File()
        Dataframe = pd.DataFrame(
            {
                "model_name" : str(model_name),
                "loss" : "{:.2f}".format(loss),
                "precision" : "{:.2f}%".format(precision * 100),
                "recall" : "{:.2f}%".format(recall * 100),
                "accuracy" : "{:.2f}%".format(accuracy * 100),
                "f" : "{:.2f}%".format(f * 100),
                "AUC" : "{:.2f}%".format(auc * 100)
            }, index = [indexs])
        File.Save_CSV_File("train_result", Dataframe)
        # File.Save_TXT_File("Matrix_Result : " + str(Matrix), model_name + "_train" + str(indexs))
        return Dataframe

    def get_layer_name(self, model_name):
        '''Return the name of the last convolution layer of the given pre-trained model (used for Grad-CAM)'''
        if(self.validation_obj.validation_string(model_name, "VGG19")):
            return "block5_conv4"
        if(self.validation_obj.validation_string(model_name, "ResNet50")):
            return "conv5_block3_3_conv"
        if(self.validation_obj.validation_string(model_name, "Xception")):
            return "block14_sepconv2"
        if(self.validation_obj.validation_string(model_name, "DenseNet121")):
            return "conv5_block16_concat"
        if(self.validation_obj.validation_string(model_name, "InceptionResNetV2")):
            return "conv_7b"
        if(self.validation_obj.validation_string(model_name, "InceptionV3")):
            return "conv2d_93"
        if(self.validation_obj.validation_string(model_name, "MobileNet")):
            return "conv_pw_13"
        if(self.validation_obj.validation_string(model_name, "MobileNetV2")):
            return "Conv_1"
        if(self.validation_obj.validation_string(model_name, "NASNetLarge")):
            return "separable_conv_2_normal_left5_18"
        if(self.validation_obj.validation_string(model_name, "ResNet101")):
            return "conv5_block3_3_conv"
        if(self.validation_obj.validation_string(model_name, "ResNet101V2")):
            return "conv5_block3_3_conv"
        if(self.validation_obj.validation_string(model_name, "ResNet152")):
            return "conv5_block3_3_conv"
        if(self.validation_obj.validation_string(model_name, "ResNet152V2")):
            return "conv5_block3_out"
        if(self.validation_obj.validation_string(model_name, "ResNet50v2")):
            return "conv5_block3_out"
        if(self.validation_obj.validation_string(model_name, "VGG16")):
            return "block5_conv3"