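"""Training, validation, and test-evaluation loops for a PyTorch image-classification model."""
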
from tqdm import tqdm

from torch.nn import functional
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torchmetrics.functional import auroc
import torch.optim as optim

from all_models_tools.all_model_tools import call_back
from Model_Loss.Loss import Entropy_Loss


class All_Step:

    def __init__(self, Training_Data_And_Label, Test_Data_And_Label, Validation_Data_And_Label, Model, Epoch, Number_Of_Classes):
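        """Store the data loaders, the model to train, and the basic training settings (device, epoch count, number of classes)."""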
        self.Training_Data_And_Label = Training_Data_And_Label
        self.Test_Data_And_Label = Test_Data_And_Label
        self.Validation_Data_And_Label = Validation_Data_And_Label

        self.Model = Model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.Epoch = Epoch
        self.Number_Of_Classes = Number_Of_Classes

    def Training_Step(self, model_name, counter):
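        """Train ``self.Model`` with SGD, validating after every epoch.

        Applies early stopping and learning-rate scheduling via ``call_back`` and
        returns the per-epoch training/validation losses and accuracies.
        """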
        # Define the optimizer; the weight_decay parameter adds L2 regularization
        Optimizer = optim.SGD(self.Model.parameters(), lr=0.045, momentum=0.9, weight_decay=0.1)
        model_path, early_stopping, scheduler = call_back(model_name, counter, Optimizer)

        criterion = Entropy_Loss()  # custom loss function

        train_losses = []
        val_losses = []
        train_accuracies = []
        val_accuracies = []

        for epoch in range(self.Epoch):
            self.Model.train()
            running_loss = 0.0
            all_train_preds = []
            all_train_labels = []

            epoch_iterator = tqdm(self.Training_Data_And_Label, desc="Training (Epoch %d)" % epoch)

            for inputs, labels in epoch_iterator:
                # labels = np.reshape(labels, (int(labels.shape[0]), 1))
                inputs, OneHot_labels = inputs.to(self.device), labels.to(self.device)
                # inputs, labels = inputs.cuda(), labels.cuda()

                Optimizer.zero_grad()
                outputs = self.Model(inputs)
                loss = criterion(outputs, OneHot_labels)
                loss.backward()
                Optimizer.step()
                running_loss += loss.item()

                # Collect training predictions and labels
                _, preds = torch.max(outputs, 1)
                all_train_preds.extend(preds.cpu().numpy())
                all_train_labels.extend(labels.cpu().numpy())

            Training_Loss = running_loss / len(self.Training_Data_And_Label)

            # all_train_labels = torch.FloatTensor(all_train_labels)
            # all_train_labels = torch.argmax(all_train_labels, 1)
            train_accuracy = accuracy_score(all_train_labels, all_train_preds)

            train_losses.append(Training_Loss)
            train_accuracies.append(train_accuracy)

            print(f"Epoch [{epoch+1}/{self.Epoch}], Loss: {Training_Loss:.4f}, Accuracy: {train_accuracy:0.2f}", end=' ')

            self.Model.eval()
            val_loss = 0.0
            all_val_preds = []
            all_val_labels = []

            with torch.no_grad():
                for inputs, labels in self.Validation_Data_And_Label:
                    inputs, OneHot_labels = inputs.to(self.device), labels.to(self.device)

                    outputs = self.Model(inputs)
                    loss = criterion(outputs, OneHot_labels)
                    val_loss += loss.item()

                    # Collect validation predictions and labels
                    _, preds = torch.max(outputs, 1)
                    all_val_preds.extend(preds.cpu().numpy())
                    all_val_labels.extend(labels.cpu().numpy())

            # Compute the validation loss and accuracy
            val_loss /= len(self.Validation_Data_And_Label)
            val_accuracy = accuracy_score(all_val_labels, all_val_preds)

            val_losses.append(val_loss)
            val_accuracies.append(val_accuracy)
            print(f"Epoch [{epoch+1}/{self.Epoch}], Loss: {val_loss:.4f}, Accuracy: {val_accuracy:0.2f}")

            early_stopping(val_loss, self.Model, model_path)
            if early_stopping.early_stop:
                print("Early stopping triggered. Training stopped.")
                break

            # Adjust the learning rate based on the validation loss
            scheduler.step(val_loss)

        return train_losses, val_losses, train_accuracies, val_accuracies

    def Evaluate_Model(self, cnn_model):
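        """Evaluate ``cnn_model`` on the test loader and return the test loss, classification metrics, and the label lists."""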
        # Test the model
        cnn_model.eval()
        criterion = Entropy_Loss()  # same custom loss as in training, used to accumulate the test loss
        True_Label, Predict_Label = [], []
        loss = 0.0
        with torch.no_grad():
            for images, labels in self.Test_Data_And_Label:
                images, OneHot_labels = images.to(self.device), labels.to(self.device)

                outputs = cnn_model(images)
                loss += criterion(outputs, OneHot_labels).item()
                _, predicted = torch.max(outputs, 1)
                Predict_Label.extend(predicted.cpu().numpy())
                True_Label.extend(labels.cpu().numpy())

        loss /= len(self.Test_Data_And_Label)

        accuracy = accuracy_score(True_Label, Predict_Label)
        precision = precision_score(True_Label, Predict_Label)
        recall = recall_score(True_Label, Predict_Label)
        # Binary AUROC (Stomach_Cancer vs. Normal); torchmetrics expects tensor inputs in (preds, target) order.
        # Hard 0/1 predictions are used here; class probabilities would give a finer-grained AUROC.
        AUC = auroc(torch.tensor(Predict_Label, dtype=torch.float), torch.tensor(True_Label), task="binary")
        f1 = f1_score(True_Label, Predict_Label)

        return loss, accuracy, precision, recall, AUC, f1, True_Label, Predict_Label
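

# Usage sketch (assumes `train_loader`, `val_loader`, and `test_loader` are
# torch.utils.data.DataLoader objects yielding (inputs, labels) batches, and
# `cnn_model` is a torch.nn.Module; these names are illustrative):
#
#     steps = All_Step(train_loader, test_loader, val_loader, cnn_model, Epoch=50, Number_Of_Classes=2)
#     train_losses, val_losses, train_accs, val_accs = steps.Training_Step("cnn_model", counter=0)
#     test_loss, acc, precision, recall, auc, f1, y_true, y_pred = steps.Evaluate_Model(cnn_model)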