Firsh Push at 20241207
This commit is contained in:
40
all_models_tools/all_model_tools.py
Normal file
40
all_models_tools/all_model_tools.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
|
||||
from keras.layers import GlobalAveragePooling2D, Dense, Reshape, Multiply
|
||||
from Load_process.file_processing import Process_File
|
||||
import datetime
|
||||
|
||||
def attention_block(input):
|
||||
channel = input.shape[-1]
|
||||
|
||||
GAP = GlobalAveragePooling2D()(input)
|
||||
|
||||
block = Dense(units = channel // 16, activation = "relu")(GAP)
|
||||
block = Dense(units = channel, activation = "sigmoid")(block)
|
||||
block = Reshape((1, 1, channel))(block)
|
||||
|
||||
block = Multiply()([input, block])
|
||||
|
||||
return block
|
||||
|
||||
def call_back(model_name, index):
|
||||
File = Process_File()
|
||||
|
||||
model_dir = '../Result/save_the_best_model/' + model_name
|
||||
File.JudgeRoot_MakeDir(model_dir)
|
||||
modelfiles = File.Make_Save_Root('best_model( ' + str(datetime.date.today()) + " )-" + str(index) + ".weights.h5", model_dir)
|
||||
|
||||
model_mckp = ModelCheckpoint(modelfiles, monitor='val_loss', save_best_only=True, save_weights_only = True, mode='auto')
|
||||
|
||||
earlystop = EarlyStopping(monitor='val_loss', patience=74, verbose=1) # 提早停止
|
||||
|
||||
reduce_lr = ReduceLROnPlateau(
|
||||
monitor = 'val_loss',
|
||||
factor = 0.94, # 學習率降低的量。 new_lr = lr * factor
|
||||
patience = 2, # 沒有改進的時期數,之後學習率將降低
|
||||
verbose = 0,
|
||||
mode = 'auto',
|
||||
min_lr = 0 # 學習率下限
|
||||
)
|
||||
|
||||
callbacks_list = [model_mckp, earlystop, reduce_lr]
|
||||
return callbacks_list
|
||||
Reference in New Issue
Block a user