371 lines
15 KiB
Python
371 lines
15 KiB
Python
"""
|
||
模型評估程式
|
||
用於載入最佳的分類模型權重和分割模型權重,
|
||
對三類資料(胃癌、正常、非胃癌有病)計算整體的Precision、Recall、Accuracy、F1-Score
|
||
"""
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import numpy as np
|
||
import os
|
||
import glob
|
||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
|
||
from tqdm import tqdm
|
||
import pandas as pd
|
||
from datetime import datetime
|
||
|
||
# 導入必要的模組
|
||
from experiments.Models.Xception_Model_Modification import Xception
|
||
from experiments.Models.GastroSegNet_Model import GastroSegNet
|
||
from Load_process.LoadData import Loding_Data_Root
|
||
from Training_Tools.PreProcess import Training_Precesses
|
||
from Load_process.LoadData import Load_Data_Prepare
|
||
from utils.Stomach_Config import Training_Config, Loading_Config
|
||
from model_data_processing.processing import make_label_list
|
||
from Training_Tools.Tools import Tool
|
||
from merge_class.merge import merge
|
||
|
||
class ModelEvaluator:
    """Evaluate a two-stage, three-class stomach classifier built from two binary models.

    Stage 1 (the "Normal" model) decides normal vs. not normal; stage 2 (the "CA"
    model) splits the remaining samples into stomach cancer vs. other disease.
    Accuracy / Precision / Recall / F1 are then computed on the combined
    three-class predictions.
    """

    def __init__(self, Normal_Model_Path, CA_Model_Path):
        """Set up the device, test-data loader, preprocessor and class-name map.

        Args:
            Normal_Model_Path: Path to the weights of the normal-vs-abnormal classifier.
            CA_Model_Path: Path to the weights of the cancer-vs-other-disease classifier.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"使用設備: {self.device}")

        # Data loader for the test set (Test_Data_Root is passed for both root arguments,
        # as in the original code).
        self.data_loader = Loding_Data_Root(
            Loading_Config["Training_Labels"],
            Loading_Config["Test_Data_Root"],
            Loading_Config["Test_Data_Root"],
        )

        # Image preprocessing / dataset construction helper.
        self.preprocessor = Training_Precesses(Training_Config["Image_Size"])

        # Class index -> display name.
        # NOTE(review): index i is assumed to match the i-th entry of
        # Loading_Config["Training_Labels"] (the one-hot order used for the true labels) — confirm.
        self.class_labels = {
            0: "胃癌 (stomach_cancer_Crop)",
            1: "正常 (Normal_Crop)",
            2: "非胃癌有病 (Have_Question_Crop)",
        }

        # Explicitly supplied model weight paths.
        self.normal_model_path = Normal_Model_Path
        self.ca_model_path = CA_Model_Path

    # ------------------------------------------------------------------ #
    # Model loading
    # ------------------------------------------------------------------ #
    @staticmethod
    def _strip_module_prefix(state_dict):
        """Remove a leading ``module.`` (added by nn.DataParallel) from state-dict keys, in place.

        Keys without the prefix are left untouched, so plain checkpoints pass through.
        Returns the same dict for convenience.
        """
        from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

        consume_prefix_in_state_dict_if_present(state_dict, "module.")
        return state_dict

    def _load_weights(self, model, model_path):
        """Move *model* to the device, load weights, wrap for multi-GPU and set eval mode.

        Weights are loaded into the bare model *before* any nn.DataParallel wrapping
        (after stripping a ``module.`` prefix). This way a checkpoint loads whether or not
        it was saved from a DataParallel model, regardless of how many GPUs are visible now.
        """
        model = model.to(self.device)
        checkpoint = torch.load(model_path, map_location=self.device)
        model.load_state_dict(self._strip_module_prefix(checkpoint))

        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)

        model.eval()
        return model

    def load_classification_model(self, model_path, num_classes=2):
        """Build an Xception classifier and load its weights.

        Args:
            model_path: Path to a saved state dict.
            num_classes: Currently unused — ``Xception()`` is built with its own defaults.
                Kept for interface compatibility.

        Returns:
            The model in eval mode, or None if building/loading failed (the error is printed).
        """
        try:
            model = self._load_weights(Xception(), model_path)
            print(f"成功載入分類模型: {model_path}")
            return model
        except Exception as e:
            # Best-effort: report and let the caller decide (run_evaluation checks for None).
            print(f"載入分類模型失敗: {e}")
            return None

    def load_segmentation_model(self, model_path):
        """Build a GastroSegNet (3 input / 3 output channels) and load its weights.

        Returns:
            The model in eval mode, or None if building/loading failed (the error is printed).
        """
        try:
            model = self._load_weights(GastroSegNet(in_channels=3, out_channels=3), model_path)
            print(f"成功載入分割模型: {model_path}")
            return model
        except Exception as e:
            print(f"載入分割模型失敗: {e}")
            return None

    # ------------------------------------------------------------------ #
    # Prediction
    # ------------------------------------------------------------------ #
    @staticmethod
    def _to_class_indices(labels):
        """Convert a batch of labels (one-hot or already class indices) to a numpy index array.

        Accepts a torch tensor or anything ``np.array`` accepts; a rank > 1 input is
        treated as one-hot and reduced with argmax over axis 1.
        """
        if isinstance(labels, torch.Tensor):
            if labels.dim() > 1:  # one-hot encoded
                return torch.argmax(labels, dim=1).cpu().numpy()
            return labels.cpu().numpy()  # already class indices

        labels = np.array(labels)
        if labels.ndim > 1:
            return np.argmax(labels, axis=1)
        return labels

    @staticmethod
    def _combine_binary_predictions(normal_outputs, ca_outputs):
        """Merge two binary model outputs into three-class predictions (vectorized).

        Rules (ties fall to the second branch, as with strict ``>``):
          * normal_outputs[:, 1] > normal_outputs[:, 0]        -> 1 (normal)
          * otherwise, ca_outputs[:, 0] > ca_outputs[:, 1]     -> 0 (stomach cancer)
          * otherwise                                          -> 2 (other disease)

        Args:
            normal_outputs: (batch, 2) scores from the Normal model.
            ca_outputs: (batch, 2) scores from the CA model.

        Returns:
            A (batch,) long tensor of class indices.
        """
        is_normal = normal_outputs[:, 1] > normal_outputs[:, 0]
        is_cancer = ca_outputs[:, 0] > ca_outputs[:, 1]

        predictions = torch.full(is_normal.shape, 2, dtype=torch.long, device=normal_outputs.device)
        predictions[is_cancer] = 0
        # Applied last so that "normal" takes precedence over the CA decision.
        predictions[is_normal] = 1
        return predictions

    def predict_three_class(self, normal_model, ca_model, test_data):
        """Run three-class prediction using two binary models.

        Logic:
          1. The Normal model decides whether a sample is normal.
          2. If not normal, the CA model decides stomach cancer vs. other disease.

        Args:
            normal_model: Binary classifier; output column 1 is treated as "normal".
            ca_model: Binary classifier; output column 0 is treated as "stomach cancer".
            test_data: Iterable of (images, labels, filenames, classes) batches.

        Returns:
            (predictions, true_labels) as numpy arrays of class indices.
        """
        all_predictions = []
        all_true_labels = []

        with torch.no_grad():
            for images, labels, _filenames, _classes in tqdm(test_data, desc="預測中"):
                images = images.to(self.device)
                all_true_labels.extend(self._to_class_indices(labels))

                normal_outputs = normal_model(images)
                ca_outputs = ca_model(images)

                batch_predictions = self._combine_binary_predictions(normal_outputs, ca_outputs)
                all_predictions.extend(batch_predictions.cpu().tolist())

        return np.array(all_predictions), np.array(all_true_labels)

    # ------------------------------------------------------------------ #
    # Metrics and reporting
    # ------------------------------------------------------------------ #
    def calculate_metrics(self, y_true, y_pred):
        """Compute overall (macro) and per-class metrics, confusion matrix and text report.

        Per-class metrics, the confusion matrix and the report are always computed over
        every class in ``self.class_labels``. This keeps their shapes stable even when a
        class is absent from both y_true and y_pred (otherwise sklearn shrinks the arrays
        and ``classification_report`` raises on the target_names length mismatch).
        """
        label_ids = list(self.class_labels.keys())
        label_names = list(self.class_labels.values())

        # Overall metrics
        accuracy = accuracy_score(y_true, y_pred)
        precision_macro = precision_score(y_true, y_pred, average='macro', zero_division=0)
        recall_macro = recall_score(y_true, y_pred, average='macro', zero_division=0)
        f1_macro = f1_score(y_true, y_pred, average='macro', zero_division=0)

        # Per-class metrics, one entry per class in self.class_labels order
        precision_per_class = precision_score(y_true, y_pred, labels=label_ids, average=None, zero_division=0)
        recall_per_class = recall_score(y_true, y_pred, labels=label_ids, average=None, zero_division=0)
        f1_per_class = f1_score(y_true, y_pred, labels=label_ids, average=None, zero_division=0)

        # Confusion matrix (rows: true, columns: predicted)
        cm = confusion_matrix(y_true, y_pred, labels=label_ids)

        # Text classification report
        report = classification_report(
            y_true, y_pred, labels=label_ids, target_names=label_names, zero_division=0
        )

        return {
            'accuracy': accuracy,
            'precision_macro': precision_macro,
            'recall_macro': recall_macro,
            'f1_macro': f1_macro,
            'precision_per_class': precision_per_class,
            'recall_per_class': recall_per_class,
            'f1_per_class': f1_per_class,
            'confusion_matrix': cm,
            'classification_report': report,
        }

    def save_results(self, metrics, results_dir="../Result/Model_Evaluation"):
        """Write metrics to CSV files and a text report.

        Args:
            metrics: Dict produced by ``calculate_metrics``.
            results_dir: Output directory. Created if missing; existing files are overwritten.
                The default is relative to the current working directory.

        Returns:
            The output directory path.
        """
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        os.makedirs(results_dir, exist_ok=True)
        label_names = list(self.class_labels.values())

        # Overall metrics
        overall_df = pd.DataFrame({
            '指標': ['Accuracy', 'Precision (Macro)', 'Recall (Macro)', 'F1-Score (Macro)'],
            '數值': [
                f"{metrics['accuracy']:.4f}",
                f"{metrics['precision_macro']:.4f}",
                f"{metrics['recall_macro']:.4f}",
                f"{metrics['f1_macro']:.4f}",
            ],
        })
        overall_df.to_csv(os.path.join(results_dir, "overall_metrics.csv"), index=False, encoding='utf-8-sig')

        # Per-class metrics
        per_class_df = pd.DataFrame({
            '類別': label_names,
            'Precision': [f"{p:.4f}" for p in metrics['precision_per_class']],
            'Recall': [f"{r:.4f}" for r in metrics['recall_per_class']],
            'F1-Score': [f"{f:.4f}" for f in metrics['f1_per_class']],
        })
        per_class_df.to_csv(os.path.join(results_dir, "per_class_metrics.csv"), index=False, encoding='utf-8-sig')

        # Confusion matrix
        cm_df = pd.DataFrame(metrics['confusion_matrix'], index=label_names, columns=label_names)
        cm_df.to_csv(os.path.join(results_dir, "confusion_matrix.csv"), encoding='utf-8-sig')

        # Detailed text report
        with open(os.path.join(results_dir, "classification_report.txt"), 'w', encoding='utf-8') as f:
            f.write("三類胃癌數據分類評估報告\n")
            f.write("=" * 50 + "\n\n")
            f.write(f"評估時間: {timestamp}\n\n")
            f.write("整體性能指標:\n")
            f.write(f"Accuracy: {metrics['accuracy']:.4f}\n")
            f.write(f"Precision (Macro): {metrics['precision_macro']:.4f}\n")
            f.write(f"Recall (Macro): {metrics['recall_macro']:.4f}\n")
            f.write(f"F1-Score (Macro): {metrics['f1_macro']:.4f}\n\n")
            f.write("詳細分類報告:\n")
            f.write(metrics['classification_report'])

        print(f"評估結果已保存至: {results_dir}")
        return results_dir

    # ------------------------------------------------------------------ #
    # Pipeline
    # ------------------------------------------------------------------ #
    def _prepare_test_loader(self, batch_size=32):
        """Load test images, attach one-hot labels, and build the dataset and DataLoader.

        Returns:
            (test_dataset, test_loader)
        """
        label_names = Loading_Config["Training_Labels"]
        n_labels = len(label_names)

        # Load test data (status=False: no ImageGenerator)
        test_data_dict = self.data_loader.process_main(status=False)

        size_per_label = []
        total_size = 0
        print("前處理後資料集總數")
        for label in label_names:
            count = len(test_data_dict[label])
            total_size += count
            size_per_label.append(count)
            print(f"Labels: {label}, 總數為: {count}")
        print("總共有 " + str(total_size) + " 筆資料")

        # One label list per class, as long as that class's image list.
        tool = Tool()
        tool.Set_OneHotEncording(label_names)
        one_hot_labels = tool.Get_OneHot_Encording_Label()
        classes = []
        for idx, encoding in enumerate(one_hot_labels):
            classes.append(make_label_list(size_per_label[idx], encoding))

        # Pack images and labels into a dict.
        prepare = Load_Data_Prepare()
        prepare.Set_Final_Dict_Data(label_names, test_data_dict, classes, n_labels)
        final_dict = prepare.Get_Final_Data_Dict()
        keys = list(final_dict.keys())

        # NOTE(review): relies on Get_Final_Data_Dict() key order — the first n_labels keys
        # hold image lists and the next n_labels keys hold the matching label lists. Confirm.
        merger = merge()
        test_images = merger.merge_all_image_data(final_dict[keys[0]], final_dict[keys[1]])
        for idx in range(2, n_labels):
            test_images = merger.merge_all_image_data(test_images, final_dict[keys[idx]])

        test_labels = merger.merge_all_image_data(final_dict[keys[n_labels]], final_dict[keys[n_labels + 1]])
        for idx in range(n_labels + 2, 2 * n_labels):
            test_labels = merger.merge_all_image_data(test_labels, final_dict[keys[idx]])

        # Build the dataset without masks, then an unsampled DataLoader.
        test_dataset = self.preprocessor.Setting_DataSet(test_images, test_labels, None, "Transform")
        test_loader = self.preprocessor.Dataloader_Sampler(test_dataset, Batch_Size=batch_size, Sampler=False)
        return test_dataset, test_loader

    def run_evaluation(self, batch_size=32):
        """Run the full evaluation: load models, prepare data, predict, score, print and save.

        Args:
            batch_size: Test DataLoader batch size.

        Returns:
            (metrics, results_dir) on success. None if a model path is missing, a model
            fails to load, or test-data preparation fails (the reason is printed).
        """
        print("開始模型評估...")
        print("=" * 50)

        # 1. Check model paths
        print("1. 檢查模型路徑...")
        if not self.normal_model_path or not self.ca_model_path:
            print("錯誤: 必須指定Normal和CA分類模型權重路徑")
            return

        print(f"Normal分類模型: {self.normal_model_path}")
        print(f"CA分類模型: {self.ca_model_path}")

        # 2. Load models
        print("\n2. 載入模型...")
        normal_model = self.load_classification_model(self.normal_model_path, num_classes=2)
        ca_model = self.load_classification_model(self.ca_model_path, num_classes=2)

        # Explicit None check: the loaders signal failure with None; module truthiness is not reliable.
        if normal_model is None or ca_model is None:
            print("錯誤: 模型載入失敗")
            return

        # 3. Prepare test data
        print("\n3. 準備測試數據...")
        try:
            test_dataset, test_loader = self._prepare_test_loader(batch_size)
            print(f"測試數據載入完成,共 {len(test_dataset)} 個樣本")
        except Exception as e:
            print(f"測試數據準備失敗: {e}")
            import traceback
            traceback.print_exc()
            return

        # 4. Predict
        print("\n4. 進行三分類預測...")
        predictions, true_labels = self.predict_three_class(normal_model, ca_model, test_loader)

        # 5. Compute metrics
        print("\n5. 計算評估指標...")
        metrics = self.calculate_metrics(true_labels, predictions)

        # 6. Show results
        print("\n6. 評估結果:")
        print("=" * 50)
        print(f"整體準確率 (Accuracy): {metrics['accuracy']:.4f}")
        print(f"整體精確率 (Precision): {metrics['precision_macro']:.4f}")
        print(f"整體召回率 (Recall): {metrics['recall_macro']:.4f}")
        print(f"整體F1分數 (F1-Score): {metrics['f1_macro']:.4f}")

        print("\n各類別詳細指標:")
        per_class = zip(metrics['precision_per_class'], metrics['recall_per_class'], metrics['f1_per_class'])
        for i, (precision, recall, f1) in enumerate(per_class):
            print(f"{self.class_labels[i]}:")
            print(f" Precision: {precision:.4f}")
            print(f" Recall: {recall:.4f}")
            print(f" F1-Score: {f1:.4f}")

        print(f"\n混淆矩陣:")
        print(metrics['confusion_matrix'])

        # 7. Save results
        print("\n7. 保存評估結果...")
        results_dir = self.save_results(metrics)

        print("\n評估完成!")
        return metrics, results_dir