210 lines
9.7 KiB
Python
210 lines
9.7 KiB
Python
from tqdm import tqdm
|
|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
|
|
from torchmetrics.functional import auroc
|
|
from torch.nn import functional
|
|
|
|
from all_models_tools.all_model_tools import call_back
|
|
from Model_Loss.Loss import Entropy_Loss
|
|
from merge_class.merge import merge
|
|
from draw_tools.Grad_cam import GradCAM
|
|
|
|
import time
|
|
import torch.optim as optim
|
|
import numpy as np
|
|
import torch
|
|
import pandas as pd
|
|
import datetime
|
|
|
|
|
|
class All_Step:
    """Train/validate a classifier for one cross-validation fold and evaluate it on test data."""

    def __init__(self, Model, Epoch, Number_Of_Classes, Model_Name, Experiment_Name):
        """Store the model and run configuration.

        Args:
            Model: torch module trained by ``Training_Step`` (kept as ``self.Model``).
            Epoch: maximum number of epochs per ``Training_Step`` call.
            Number_Of_Classes: number of classes, used for one-hot encoding in
                ``Evaluate_Model``.
            Model_Name: stored on the instance (not used by the methods below).
            Experiment_Name: stored on the instance (not used by the methods below).
        """
        self.Model = Model
        # Use the GPU when available; batches and the model are moved to this device.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.Epoch = Epoch
        self.Number_Of_Classes = Number_Of_Classes

        self.Model_Name = Model_Name
        self.Experiment_Name = Experiment_Name

    @staticmethod
    def _batch_predictions(outputs, labels):
        """Return ``(predicted_indices, true_indices)`` for one batch as NumPy arrays.

        Predictions are the arg-max of ``outputs`` along dim 1 (first maximum wins);
        true indices are the arg-max of ``labels`` along axis 1, so ``labels`` is
        expected to hold one score/one-hot row per sample.
        """
        predicted = torch.max(outputs.detach(), dim=1).indices.cpu().numpy()
        true = np.argmax(labels.detach().cpu().numpy(), axis=1)
        return predicted, true

    @staticmethod
    def _progress_postfix(processed_samples, total_samples, elapsed_time, batch_accuracy, loss_value):
        """Build the tqdm postfix: ``done/total [elapsed<eta, rate it/s, acc=..., loss=...]``.

        ``rate`` is samples per second; both rate and ETA fall back to 0 when no
        time has elapsed yet.
        """
        samples_per_second = processed_samples / elapsed_time if elapsed_time > 0 else 0
        eta = (total_samples - processed_samples) / samples_per_second if samples_per_second > 0 else 0
        time_str = f"{int(elapsed_time//60):02d}:{int(elapsed_time%60):02d}<{int(eta//60):02d}:{int(eta%60):02d}"
        return (
            f"{processed_samples}/{total_samples} [{time_str}, {samples_per_second:.2f}it/s, "
            f"acc={batch_accuracy:.3f}, loss={loss_value:.3f}]"
        )

    def _train_one_epoch(self, train_loader, total_samples, desc, criterion, optimizer, merge_function):
        """Run one optimisation pass over ``train_loader``; return ``(mean batch loss, accuracy)``."""
        self.Model.train()
        running_loss = 0.0
        batch_preds, batch_labels = [], []
        processed_samples = 0
        start_time = time.time()

        epoch_iterator = tqdm(train_loader, desc=desc)
        for inputs, labels in epoch_iterator:
            inputs, labels = inputs.to(self.device), labels.to(self.device)

            optimizer.zero_grad()
            outputs = self.Model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            loss_value = loss.item()
            running_loss += loss_value

            preds, trues = self._batch_predictions(outputs, labels)
            batch_preds.append(preds)
            batch_labels.append(trues)

            processed_samples += inputs.size(0)
            # Batch accuracy: correctly predicted labels / number of labels in this batch.
            batch_accuracy = (preds == trues).mean()
            epoch_iterator.set_postfix_str(
                self._progress_postfix(
                    processed_samples, total_samples, time.time() - start_time, batch_accuracy, loss_value
                )
            )
        epoch_iterator.close()

        all_preds = merge_function.merge_data_main(batch_preds, 0, len(batch_preds))
        all_labels = merge_function.merge_data_main(batch_labels, 0, len(batch_labels))
        return running_loss / len(train_loader), accuracy_score(all_labels, all_preds)

    def _validate_one_epoch(self, val_loader, total_samples, desc, criterion, merge_function):
        """Evaluate ``self.Model`` on ``val_loader`` without gradients; return ``(mean batch loss, accuracy)``."""
        self.Model.eval()
        val_loss = 0.0
        batch_preds, batch_labels = [], []
        # Fresh counter: validation progress must not continue from the training count.
        processed_samples = 0
        start_time = time.time()

        epoch_iterator = tqdm(val_loader, desc=desc)
        with torch.no_grad():
            for inputs, labels in epoch_iterator:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.Model(inputs)
                loss_value = criterion(outputs, labels).item()
                val_loss += loss_value

                preds, trues = self._batch_predictions(outputs, labels)
                batch_preds.append(preds)
                batch_labels.append(trues)

                processed_samples += inputs.size(0)
                batch_accuracy = (preds == trues).mean()
                epoch_iterator.set_postfix_str(
                    self._progress_postfix(
                        processed_samples, total_samples, time.time() - start_time, batch_accuracy, loss_value
                    )
                )
        epoch_iterator.close()
        print("\n")

        all_preds = merge_function.merge_data_main(batch_preds, 0, len(batch_preds))
        all_labels = merge_function.merge_data_main(batch_labels, 0, len(batch_labels))
        return val_loss / len(val_loader), accuracy_score(all_labels, all_preds)

    def Training_Step(self, train_subset, val_subset, train_loader, val_loader, model_name, fold, TargetLayer):
        """Train ``self.Model`` for one cross-validation fold with per-epoch validation.

        Each epoch runs a training pass and a validation pass, prints a summary,
        writes Grad-CAM images when ``epoch % 10 == 0``, then consults the
        early-stopping callback and steps the scheduler returned by ``call_back``.

        Args:
            train_subset: training dataset; only its length is used (progress display).
            val_subset: validation dataset; only its length is used (progress display).
            train_loader: iterable of ``(inputs, labels)`` training batches.
            val_loader: iterable of ``(inputs, labels)`` validation batches (also fed to Grad-CAM).
            model_name: name prefix passed to ``call_back``.
            fold: fold index (displayed as ``fold + 1``, so presumably zero-based).
            TargetLayer: layer handed to ``GradCAM``.

        Returns:
            ``(model, model_path, train_losses, val_losses, train_accuracies,
            val_accuracies, Total_Epoch)``; ``Total_Epoch`` is the number of epochs
            actually run (fewer than ``self.Epoch`` if early stopping fired).
        """
        # NOTE(review): the model is NOT re-initialised here, so consecutive folds continue
        # from the previous fold's weights. Re-create the model before each call if folds
        # must be independent. Only the optimizer and callbacks are fresh per call.
        # Batches are moved to self.device, so the model must live there too (no-op if it already does).
        self.Model.to(self.device)
        Optimizer = optim.SGD(self.Model.parameters(), lr=0.045, momentum=0.9, weight_decay=0.01)
        model_path, early_stopping, scheduler = call_back(model_name, f"_fold{fold}", Optimizer)

        criterion = Entropy_Loss()  # Custom loss function
        Merge_Function = merge()

        # Per-epoch metrics for this fold
        train_losses, val_losses = [], []
        train_accuracies, val_accuracies = [], []

        # Sample counts for the progress bars (subset sizes, not DataLoader lengths)
        total_samples = len(train_subset)
        total_Validation_samples = len(val_subset)

        for epoch in range(self.Epoch):
            Training_Loss, train_accuracy = self._train_one_epoch(
                train_loader,
                total_samples,
                f"Fold {fold + 1}/5, Epoch [{epoch + 1}/{self.Epoch}]",
                criterion,
                Optimizer,
                Merge_Function,
            )
            train_losses.append(Training_Loss)
            train_accuracies.append(train_accuracy)

            val_loss, val_accuracy = self._validate_one_epoch(
                val_loader,
                total_Validation_samples,
                f"\tValidation-Fold {fold + 1}/5, Epoch [{epoch + 1}/{self.Epoch}]",
                criterion,
                Merge_Function,
            )
            val_losses.append(val_loss)
            val_accuracies.append(val_accuracy)

            print(f"Train Loss: {Training_Loss:.4f}, Accuracy: {train_accuracy:0.2f}, Validation Loss: {val_loss:.4f}, Accuracy: {val_accuracy:0.2f}\n")

            if epoch % 10 == 0:
                Grad = GradCAM(self.Model, TargetLayer)
                Grad.Processing_Main(val_loader, f"../Result/GradCAM_Image/Validation/GradCAM_Image({str(datetime.date.today())})/fold-{str(fold)}/")

            # Early stopping
            early_stopping(val_loss, self.Model, model_path)
            if early_stopping.early_stop:
                print(f"Early stopping triggered in Fold {fold + 1} at epoch {epoch + 1}")
                break

            # Learning rate adjustment
            scheduler.step(val_loss)

        # One loss entry is recorded per completed epoch (also correct when Epoch == 0).
        Total_Epoch = len(train_losses)
        return self.Model, model_path, train_losses, val_losses, train_accuracies, val_accuracies, Total_Epoch

    def Evaluate_Model(self, cnn_model, Test_Dataloader):
        """Evaluate ``cnn_model`` on ``Test_Dataloader`` and compute classification metrics.

        Args:
            cnn_model: torch module to evaluate (moved to ``self.device``, set to eval mode).
            Test_Dataloader: iterable of ``(images, labels)`` batches; labels are
                arg-maxed along dim 1, so presumably one-hot rows.

        Returns:
            ``(True_Label, Predict_Label, loss, accuracy, precision, recall, AUC, f1)``.
            ``True_Label`` / ``Predict_Label`` are lists with one NumPy array of class
            indices per batch (ground truth and predictions respectively); ``loss`` is
            the mean per-batch ``Entropy_Loss``; precision/recall/f1 are macro-averaged.

        Raises:
            ValueError: if ``Test_Dataloader`` yields no batches.
        """
        cnn_model.to(self.device)
        cnn_model.eval()
        criterion = Entropy_Loss()

        True_Label, Predict_Label = [], []
        True_Label_OneHot, Predict_Label_OneHot = [], []
        loss = 0.0

        with torch.no_grad():
            for images, labels in Test_Dataloader:
                images = torch.as_tensor(images).to(self.device)
                labels = torch.as_tensor(labels).to(self.device)
                outputs = cnn_model(images)
                loss += criterion(outputs, labels).item()

                Output_Indexs, True_Indexs = self._batch_predictions(outputs, labels)
                True_Label.append(True_Indexs)
                Predict_Label.append(Output_Indexs)

                # Keep every row of the batch, not just the first one.
                Predict_Label_OneHot.append(
                    functional.one_hot(torch.from_numpy(Output_Indexs), self.Number_Of_Classes).float()
                )
                True_Label_OneHot.append(labels.int().cpu())

        if not True_Label:
            raise ValueError("Test_Dataloader yielded no batches")

        loss /= len(True_Label)

        True_Label_OneHot = torch.cat(True_Label_OneHot)
        Predict_Label_OneHot = torch.cat(Predict_Label_OneHot)

        # sklearn expects flat 1-D label vectors, not lists of per-batch arrays.
        y_true = np.concatenate(True_Label)
        y_pred = np.concatenate(Predict_Label)

        accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, average="macro")
        recall = recall_score(y_true, y_pred, average="macro")
        # NOTE(review): AUROC is computed from one-hot *hard* predictions, which gives a single
        # operating point per class. Passing class scores (e.g. `outputs`) would yield a true
        # ROC AUC. Kept as-is so reported numbers stay comparable with earlier runs.
        AUC = auroc(Predict_Label_OneHot, True_Label_OneHot, num_labels=self.Number_Of_Classes, task="multilabel", average="macro")
        f1 = f1_score(y_true, y_pred, average="macro")

        return True_Label, Predict_Label, loss, accuracy, precision, recall, AUC, f1