101 lines
3.3 KiB
Python
101 lines
3.3 KiB
Python
from matplotlib import pyplot as plt
|
||
import seaborn as sns
|
||
import datetime
|
||
import matplotlib.figure as figure
|
||
import matplotlib.backends.backend_agg as agg
|
||
from Load_process.file_processing import Process_File
|
||
|
||
def _plot_train_val_curves(ax, train_values, val_values, ylabel, title):
    """Draw one training curve and one validation curve on ``ax``.

    Each series gets its own 1-based epoch axis, so the two series are
    allowed to have different lengths.
    """
    ax.plot(range(1, len(train_values) + 1), train_values, label='Train')
    ax.plot(range(1, len(val_values) + 1), val_values, label='Validation')
    ax.set_ylabel(ylabel)
    ax.set_xlabel('epoch')
    ax.legend(loc='upper left')
    ax.set_title(title)


def plot_history(Losses, Accuracys, Save_Root, File_Name):
    """Plot loss (and optionally accuracy) curves and save them as a PNG.

    The figure is a 1x2 grid: losses go in the left panel and accuracies in
    the right one. When ``Accuracys`` is ``None`` the right panel is simply
    not created.

    Args:
        Losses: Indexable pair; ``Losses[0]`` is the training-loss series
            and ``Losses[1]`` the validation-loss series.
        Accuracys: ``None`` to skip the accuracy panel, otherwise a pair
            laid out like ``Losses`` (training first, validation second).
        Save_Root: Output directory. It is handed to
            ``Process_File.JudgeRoot_MakeDir`` before saving (presumably
            that creates the directory when missing -- confirm against
            ``Process_File``).
        File_Name: Base name of the image; ``".png"`` is appended.

    Returns:
        None. The image is written to the path produced by
        ``Process_File.Make_Save_Root``.
    """
    File = Process_File()

    fig = plt.figure(figsize=(16, 4))
    try:
        _plot_train_val_curves(
            fig.add_subplot(1, 2, 1),
            Losses[0],
            Losses[1],
            ylabel='Losses',
            title='Model Loss',
        )

        if Accuracys is not None:
            _plot_train_val_curves(
                fig.add_subplot(1, 2, 2),
                Accuracys[0],
                Accuracys[1],
                ylabel='Accuracies',
                title='Model Accuracy',
            )

        File.JudgeRoot_MakeDir(Save_Root)
        modelfiles = File.Make_Save_Root(f"{str(File_Name)}.png", Save_Root)
        fig.savefig(modelfiles)
    finally:
        # Always release the figure, even when saving fails; otherwise
        # pyplot keeps it registered and it accumulates across calls.
        # "all" is kept (rather than closing only ``fig``) so the global
        # side effect callers may rely on stays the same.
        plt.close("all")
|
||
|
||
def draw_heatmap(matrix, Save_Root, File_Name, index):
    """Render ``matrix`` as an annotated heatmap and save it as a PNG.

    Intended for confusion matrices with more than two classes. The figure
    is built from ``matplotlib.figure.Figure`` and an Agg canvas directly,
    without going through ``pyplot``.

    Args:
        matrix: 2-D data accepted by ``seaborn.heatmap``. Integer data is
            annotated as integers; floating-point data (e.g. a normalized
            confusion matrix) is annotated with two decimals.
        Save_Root: Output directory. It is handed to
            ``Process_File.JudgeRoot_MakeDir`` before saving (presumably
            that creates the directory when missing -- confirm against
            ``Process_File``).
        File_Name: Used both in the plot title and in the output file name.
        index: Suffix of the output file name: ``"{File_Name}-{index}.png"``.

    Returns:
        None. The image is written to the path produced by
        ``Process_File.Make_Save_Root``.
    """
    # Local import: numpy is not imported at module level in this file.
    import numpy as np

    File = Process_File()

    # Build the figure and its canvas.
    fig = figure.Figure(figsize=(6, 4))
    canvas = agg.FigureCanvasAgg(fig)
    Ax = fig.add_subplot(111)

    # The 'd' format code only accepts integers; seaborn raises ValueError
    # when it is applied to float cells. Switch format only for float data
    # so every previously working input is annotated exactly as before.
    annotation_format = '.2f' if np.asarray(matrix).dtype.kind == 'f' else 'd'

    # cmap selects the colour palette of the heatmap.
    sns.heatmap(
        matrix,
        square=True,
        annot=True,
        fmt=annotation_format,
        linecolor='white',
        cmap="Purples",
        ax=Ax,
    )

    # Title and axis labels.
    Ax.set_title(f"{File_Name} confusion matrix")
    Ax.set_xlabel("X-Predict label of the model")
    Ax.set_ylabel("Y-True label of the model")

    File.JudgeRoot_MakeDir(Save_Root)
    modelfiles = File.Make_Save_Root(f"{File_Name}-{str(index)}.png", Save_Root)

    # Write the image to disk.
    canvas.print_figure(modelfiles)
|
||
|
||
def Confusion_Matrix_of_Two_Classification(Matrix, Save_Root, File_Name, index):
    """Plot a confusion matrix as a seaborn heatmap and save it as a PNG.

    Tick labels depend on the number of rows in ``Matrix``: a two-row
    matrix is labelled ``False``/``True``; any other size is labelled with
    the class indices ``0 .. n-1``.

    Args:
        Matrix: 2-D confusion matrix accepted by ``seaborn.heatmap``. Any
            object supporting ``len()`` works (array, nested list, ...).
        Save_Root: Output directory. It is handed to
            ``Process_File.JudgeRoot_MakeDir`` before saving (presumably
            that creates the directory when missing -- confirm against
            ``Process_File``).
        File_Name: Prefix of the output file name.
        index: Suffix of the output file name: ``"{File_Name}-{index}.png"``.

    Returns:
        None. The image is written to the path produced by
        ``Process_File.Make_Save_Root``.
    """
    File = Process_File()

    # Draw on a dedicated figure instead of whatever axes happen to be
    # current, so a figure left open elsewhere is never drawn over.
    fig, ax = plt.subplots()
    try:
        fx = sns.heatmap(Matrix, annot=True, cmap='turbo', ax=ax)

        # Title and axis labels.
        # NOTE(review): seaborn puts matrix rows on the y-axis and columns
        # on the x-axis. draw_heatmap in this file labels x as predicted
        # and y as true, which is the opposite of the labels below --
        # confirm which orientation the callers' matrices actually use.
        fx.set_title('Plotting Confusion Matrix using Seaborn\n\n')
        fx.set_xlabel('answer Values ')
        fx.set_ylabel('Predicted Values')

        # Pick tick labels from the matrix size: False/True for the binary
        # case, numeric class indices otherwise.
        n_classes = len(Matrix)
        if n_classes == 2:
            labels = ['False', 'True']
        else:
            labels = [str(i) for i in range(n_classes)]

        fx.xaxis.set_ticklabels(labels)
        fx.yaxis.set_ticklabels(labels)

        File.JudgeRoot_MakeDir(Save_Root)
        modelfiles = File.Make_Save_Root(f"{File_Name}-{str(index)}.png", Save_Root)

        fig.savefig(modelfiles)
    finally:
        # Always release the figure, even when plotting or saving fails.
        # "all" is kept (rather than closing only ``fig``) so the global
        # side effect callers may rely on stays the same.
        plt.close("all")