Files

101 lines
3.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from matplotlib import pyplot as plt
import seaborn as sns
import datetime
import matplotlib.figure as figure
import matplotlib.backends.backend_agg as agg
from Load_process.file_processing import Process_File
def _plot_train_val(ax, series, ylabel, title):
    """Draw one train/validation curve pair onto *ax*.

    ``series[0]`` holds the training values and ``series[1]`` the validation
    values. The two sequences may have different lengths, so each one gets its
    own 1-based epoch axis instead of sharing a single x range (this is what
    avoids the dimension-mismatch error when plotting).
    """
    train_values, val_values = series[0], series[1]
    ax.plot(range(1, len(train_values) + 1), train_values, label='Train')
    ax.plot(range(1, len(val_values) + 1), val_values, label='Validation')
    ax.set_ylabel(ylabel)
    ax.set_xlabel('epoch')
    ax.legend(loc='upper left')
    ax.set_title(title)


def plot_history(Losses, Accuracys, Save_Root, File_Name):
    """Plot training/validation loss (and optionally accuracy) curves to a PNG.

    Args:
        Losses: Pair ``(train_losses, val_losses)`` of per-epoch sequences.
        Accuracys: Pair ``(train_acc, val_acc)`` of per-epoch sequences, or
            ``None`` to draw only the loss panel. In that case the right half
            of the 16x4 figure is left empty.
        Save_Root: Output directory, passed to
            ``Process_File.JudgeRoot_MakeDir`` before saving (presumably
            creating it if missing).
        File_Name: Base name of the image; ``.png`` is appended and the path is
            built by ``Process_File.Make_Save_Root``.
    """
    File = Process_File()
    plt.figure(figsize=(16, 4))
    try:
        _plot_train_val(plt.subplot(1, 2, 1), Losses, 'Losses', 'Model Loss')
        if Accuracys is not None:
            _plot_train_val(plt.subplot(1, 2, 2), Accuracys,
                            'Accuracies', 'Model Accuracy')
        File.JudgeRoot_MakeDir(Save_Root)
        modelfiles = File.Make_Save_Root(f"{str(File_Name)}.png", Save_Root)
        plt.savefig(modelfiles)
    finally:
        # Always release the pyplot figures, even when plotting or saving
        # fails. Otherwise the half-drawn figure stays "current" and later
        # pyplot calls in this process would draw onto it.
        plt.close("all")
def draw_heatmap(matrix, Save_Root, File_Name, index):  # confusion-matrix plot for more than two classes
    """Render *matrix* as an annotated heatmap and save it as a PNG.

    Uses an object-oriented ``Figure`` with an Agg canvas rather than pyplot,
    so the figure is not registered in pyplot's global figure list and needs
    no explicit close.

    Args:
        matrix: 2-D array-like of cell values, presumably integer counts.
            Annotations use ``fmt='d'``, which rejects float values.
        Save_Root: Output directory, passed to
            ``Process_File.JudgeRoot_MakeDir`` before saving.
        File_Name: Used in the plot title and as the file-name prefix.
        index: Appended to the file name as ``{File_Name}-{index}.png``.
    """
    File = Process_File()
    # Create the heatmap on a standalone figure/canvas pair.
    fig = figure.Figure(figsize=(6, 4))
    canvas = agg.FigureCanvasAgg(fig)
    Ax = fig.add_subplot(111)
    # `linecolor` only becomes visible when `linewidths` > 0 (seaborn's
    # default is 0), so give the white cell separators an explicit width.
    # `cmap` selects the colour palette.
    sns.heatmap(matrix, square=True, annot=True, fmt='d', linewidths=0.5,
                linecolor='white', cmap="Purples", ax=Ax)
    # Title and axis labels: columns are predictions, rows are true labels.
    Ax.set_title(f"{File_Name} confusion matrix")
    Ax.set_xlabel("X-Predict label of the model")
    Ax.set_ylabel("Y-True label of the model")
    # Save the image to disk.
    File.JudgeRoot_MakeDir(Save_Root)
    modelfiles = File.Make_Save_Root(f"{File_Name}-{str(index)}.png", Save_Root)
    canvas.print_figure(modelfiles)
def Confusion_Matrix_of_Two_Classification(Matrix, Save_Root, File_Name, index):
    """Render *Matrix* as an annotated confusion-matrix heatmap and save a PNG.

    Despite the name, any class count is handled. A 2x2 matrix gets
    ``False``/``True`` tick labels; larger ones get ``0 .. n-1``.

    Args:
        Matrix: 2-D array exposing ``.shape`` (e.g. a NumPy array). It is
            presumably square, since the same labels go on both axes. Integer
            matrices are annotated as whole numbers.
        Save_Root: Output directory, passed to
            ``Process_File.JudgeRoot_MakeDir`` before saving.
        File_Name: File-name prefix.
        index: Appended to the file name as ``{File_Name}-{index}.png``.
    """
    File = Process_File()
    n_classes = Matrix.shape[0]
    # Binary problems get False/True labels; multi-class problems use indices.
    if n_classes == 2:
        labels = ['False', 'True']
    else:
        labels = [str(i) for i in range(n_classes)]
    # Seaborn's default '.2g' annotation renders counts >= 100 in scientific
    # notation (e.g. '1.2e+02'), so use integer formatting for integer dtypes.
    dtype_kind = getattr(getattr(Matrix, 'dtype', None), 'kind', None)
    annot_fmt = 'd' if dtype_kind in ('i', 'u') else '.2g'
    # Draw on a fresh figure. Without `ax=`, seaborn would draw onto whatever
    # pyplot axes happens to be current (e.g. a figure left open elsewhere).
    fig, fx = plt.subplots()
    try:
        # Pass the labels to heatmap itself so every row/column gets one.
        # Overwriting them afterwards breaks once seaborn's "auto" mode thins
        # the ticks on large matrices.
        sns.heatmap(Matrix, annot=True, fmt=annot_fmt, cmap='turbo',
                    xticklabels=labels, yticklabels=labels, ax=fx)
        # Title and x/y axis labels of the plot.
        fx.set_title('Plotting Confusion Matrix using Seaborn\n\n')
        # NOTE(review): x is labelled as the answers and y as the predictions.
        # That is the opposite of draw_heatmap in this file (x = predicted,
        # y = true). Confirm against how callers orient Matrix.
        fx.set_xlabel('answer Values ')
        fx.set_ylabel('Predicted Values')
        File.JudgeRoot_MakeDir(Save_Root)
        modelfiles = File.Make_Save_Root(f"{File_Name}-{str(index)}.png", Save_Root)
        fig.savefig(modelfiles)
    finally:
        # Close the chart(s) even if drawing or saving failed.
        plt.close("all")