# Stomach_Cancer_Pytorch/experiments/model_evaluation.py
# NOTE: this file contains non-ASCII (Traditional Chinese) text in comments and strings.
"""
模型評估程式
用於載入最佳的分類模型權重和分割模型權重,
對三類資料(胃癌、正常、非胃癌有病)計算整體的Precision、Recall、Accuracy、F1-Score
"""
import torch
import torch.nn as nn
import numpy as np
import os
import glob
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
from tqdm import tqdm
import pandas as pd
from datetime import datetime
# 導入必要的模組
from experiments.Models.Xception_Model_Modification import Xception
from experiments.Models.GastroSegNet_Model import GastroSegNet
from Load_process.LoadData import Loding_Data_Root
from Training_Tools.PreProcess import Training_Precesses
from Load_process.LoadData import Load_Data_Prepare
from utils.Stomach_Config import Training_Config, Loading_Config
from model_data_processing.processing import make_label_list
from Training_Tools.Tools import Tool
from merge_class.merge import merge
class ModelEvaluator:
    """Evaluate a two-stage, three-class stomach image classifier.

    Two binary Xception classifiers (a "normal vs. abnormal" model and a
    "cancer vs. other disease" model) are combined into a single three-class
    decision, and overall / per-class Precision, Recall, Accuracy and F1 are
    computed and written to disk.

    Class indices used throughout (see ``self.class_labels``):
        0 -> stomach cancer, 1 -> normal, 2 -> diseased but not cancer.
    """

    def __init__(self, Normal_Model_Path, CA_Model_Path):
        """Set up the device, the test-data loader and the preprocessing pipeline.

        Args:
            Normal_Model_Path: Path to the weights of the "normal vs. abnormal"
                binary classifier.
            CA_Model_Path: Path to the weights of the "cancer vs. other disease"
                binary classifier.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"使用設備: {self.device}")
        # NOTE(review): Test_Data_Root is passed for both root arguments; confirm
        # what the third Loding_Data_Root parameter is meant to receive.
        self.data_loader = Loding_Data_Root(
            Loading_Config["Training_Labels"],
            Loading_Config["Test_Data_Root"],
            Loading_Config["Test_Data_Root"]
        )
        self.preprocessor = Training_Precesses(Training_Config["Image_Size"])
        # Class index -> display name. Assumed to follow the same index order as the
        # one-hot encoding built from Loading_Config["Training_Labels"] -- TODO confirm.
        self.class_labels = {
            0: "胃癌 (stomach_cancer_Crop)",
            1: "正常 (Normal_Crop)",
            2: "非胃癌有病 (Have_Question_Crop)"
        }
        self.normal_model_path = Normal_Model_Path
        self.ca_model_path = CA_Model_Path

    # ------------------------------------------------------------------ loading

    def _load_weights(self, model, model_path):
        """Load the state_dict stored at ``model_path`` into the *unwrapped* ``model``.

        Checkpoints written from an ``nn.DataParallel`` model prefix every key with
        ``"module."``. The prefix is stripped first, so the same file loads no
        matter how many GPUs were used when saving or are visible now.
        """
        # Local import keeps the file-level import block unchanged.
        from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

        checkpoint = torch.load(model_path, map_location=self.device)
        if isinstance(checkpoint, dict):
            consume_prefix_in_state_dict_if_present(checkpoint, "module.")
        model.load_state_dict(checkpoint)

    def _prepare_model(self, model, model_path):
        """Load weights into ``model``, move it to the device, wrap it for multi-GPU, set eval mode."""
        # Weights are loaded BEFORE wrapping so the expected key names never
        # depend on whether nn.DataParallel is in use.
        self._load_weights(model, model_path)
        model = model.to(self.device)
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model.eval()
        return model

    def load_classification_model(self, model_path, num_classes=2):
        """Build an Xception classifier and load its weights.

        Args:
            model_path: Path to a saved state_dict (plain or DataParallel-prefixed).
            num_classes: Accepted for API compatibility only. It is NOT passed to
                ``Xception()``, so the output size is whatever Xception defaults to.

        Returns:
            The model in eval mode on ``self.device`` (wrapped in ``nn.DataParallel``
            when more than one GPU is visible), or ``None`` if building/loading failed.
        """
        try:
            model = self._prepare_model(Xception(), model_path)
        except Exception as e:  # boundary: report the failure, caller checks for None
            print(f"載入分類模型失敗: {e}")
            return None
        print(f"成功載入分類模型: {model_path}")
        return model

    def load_segmentation_model(self, model_path):
        """Build a GastroSegNet (3 in / 3 out channels) and load its weights.

        Returns:
            The model in eval mode on ``self.device`` (wrapped in ``nn.DataParallel``
            when more than one GPU is visible), or ``None`` if building/loading failed.
        """
        try:
            model = self._prepare_model(GastroSegNet(in_channels=3, out_channels=3), model_path)
        except Exception as e:  # boundary: report the failure, caller checks for None
            print(f"載入分割模型失敗: {e}")
            return None
        print(f"成功載入分割模型: {model_path}")
        return model

    # --------------------------------------------------------------- prediction

    @staticmethod
    def _to_class_indices(labels):
        """Convert a batch of labels (one-hot rows or class indices) to a 1-D index array.

        Accepts a torch tensor, numpy array or nested list. Rows with more than
        one dimension are treated as one-hot and reduced with argmax.
        """
        if isinstance(labels, torch.Tensor):
            if labels.dim() > 1:
                return torch.argmax(labels, dim=1).cpu().numpy()
            return labels.cpu().numpy()
        labels = np.array(labels)
        return np.argmax(labels, axis=1) if labels.ndim > 1 else labels

    @staticmethod
    def _combine_binary_outputs(normal_outputs, ca_outputs):
        """Merge the two binary models' per-sample scores into three-class predictions.

        A sample is "normal" (1) when column 1 of the Normal model's output is
        strictly larger than column 0. Otherwise it is "cancer" (0) when column 0
        of the CA model's output is strictly larger than column 1, else
        "diseased, not cancer" (2). Ties therefore fall through to the later branch.

        Works on logits or probabilities alike, since only within-row comparisons are used.
        """
        is_normal = (normal_outputs[:, 1] > normal_outputs[:, 0]).cpu().numpy()
        is_cancer = (ca_outputs[:, 0] > ca_outputs[:, 1]).cpu().numpy()
        return np.where(is_normal, 1, np.where(is_cancer, 0, 2))

    def predict_three_class(self, normal_model, ca_model, test_data):
        """Run both binary models over ``test_data`` and derive three-class predictions.

        Args:
            normal_model: Binary model whose output column 1 scores "normal".
            ca_model: Binary model whose output column 0 scores "stomach cancer".
            test_data: Iterable of ``(images, labels, filename, class)`` batches.

        Returns:
            ``(predictions, true_labels)`` as numpy arrays of class indices.
        """
        all_predictions = []
        all_true_labels = []
        with torch.no_grad():
            for images, labels, _filename, _class in tqdm(test_data, desc="預測中"):
                images = images.to(self.device)
                all_true_labels.extend(self._to_class_indices(labels))
                # Both models see every image. The CA score only matters when the
                # Normal model rejects "normal", but batching both is simpler.
                normal_outputs = normal_model(images)
                ca_outputs = ca_model(images)
                all_predictions.extend(self._combine_binary_outputs(normal_outputs, ca_outputs))
        return np.array(all_predictions), np.array(all_true_labels)

    # ------------------------------------------------------------------ metrics

    def calculate_metrics(self, y_true, y_pred):
        """Compute overall and per-class classification metrics.

        Macro scores use sklearn's default label set (classes that occur in
        ``y_true`` or ``y_pred``). The per-class arrays, confusion matrix and text
        report are pinned to every class in ``self.class_labels``. That keeps their
        length/shape aligned with the class names even if a class is missing from
        this test set, where sklearn would otherwise shrink them or raise.

        Returns:
            Dict with keys ``accuracy``, ``precision_macro``, ``recall_macro``,
            ``f1_macro``, ``precision_per_class``, ``recall_per_class``,
            ``f1_per_class``, ``confusion_matrix`` and ``classification_report``.
        """
        class_ids = list(self.class_labels.keys())
        class_names = list(self.class_labels.values())

        accuracy = accuracy_score(y_true, y_pred)
        precision_macro = precision_score(y_true, y_pred, average='macro', zero_division=0)
        recall_macro = recall_score(y_true, y_pred, average='macro', zero_division=0)
        f1_macro = f1_score(y_true, y_pred, average='macro', zero_division=0)

        precision_per_class = precision_score(y_true, y_pred, labels=class_ids, average=None, zero_division=0)
        recall_per_class = recall_score(y_true, y_pred, labels=class_ids, average=None, zero_division=0)
        f1_per_class = f1_score(y_true, y_pred, labels=class_ids, average=None, zero_division=0)

        cm = confusion_matrix(y_true, y_pred, labels=class_ids)
        report = classification_report(
            y_true, y_pred, labels=class_ids, target_names=class_names, zero_division=0
        )
        return {
            'accuracy': accuracy,
            'precision_macro': precision_macro,
            'recall_macro': recall_macro,
            'f1_macro': f1_macro,
            'precision_per_class': precision_per_class,
            'recall_per_class': recall_per_class,
            'f1_per_class': f1_per_class,
            'confusion_matrix': cm,
            'classification_report': report
        }

    def save_results(self, metrics, results_dir="../Result/Model_Evaluation"):
        """Write the metrics from ``calculate_metrics`` to CSV/TXT files.

        Files written (overwritten on every run): ``overall_metrics.csv``,
        ``per_class_metrics.csv``, ``confusion_matrix.csv`` and
        ``classification_report.txt``.

        Args:
            metrics: Dict as returned by ``calculate_metrics``.
            results_dir: Output directory, created if missing. The default is
                relative to the current working directory.

        Returns:
            ``results_dir``.
        """
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        os.makedirs(results_dir, exist_ok=True)
        class_names = list(self.class_labels.values())

        overall_names = ['Accuracy', 'Precision (Macro)', 'Recall (Macro)', 'F1-Score (Macro)']
        overall_keys = ['accuracy', 'precision_macro', 'recall_macro', 'f1_macro']
        overall_df = pd.DataFrame({
            '指標': overall_names,
            '數值': [f"{metrics[key]:.4f}" for key in overall_keys],
        })
        # utf-8-sig adds a BOM so spreadsheet tools detect the Chinese headers correctly.
        overall_df.to_csv(os.path.join(results_dir, "overall_metrics.csv"), index=False, encoding='utf-8-sig')

        per_class_df = pd.DataFrame({
            '類別': class_names,
            'Precision': [f"{value:.4f}" for value in metrics['precision_per_class']],
            'Recall': [f"{value:.4f}" for value in metrics['recall_per_class']],
            'F1-Score': [f"{value:.4f}" for value in metrics['f1_per_class']],
        })
        per_class_df.to_csv(os.path.join(results_dir, "per_class_metrics.csv"), index=False, encoding='utf-8-sig')

        cm_df = pd.DataFrame(metrics['confusion_matrix'], index=class_names, columns=class_names)
        cm_df.to_csv(os.path.join(results_dir, "confusion_matrix.csv"), encoding='utf-8-sig')

        with open(os.path.join(results_dir, "classification_report.txt"), 'w', encoding='utf-8') as fh:
            fh.writelines([
                "三類胃癌數據分類評估報告\n",
                "=" * 50 + "\n\n",
                f"評估時間: {timestamp}\n\n",
                "整體性能指標:\n",
                f"Accuracy: {metrics['accuracy']:.4f}\n",
                f"Precision (Macro): {metrics['precision_macro']:.4f}\n",
                f"Recall (Macro): {metrics['recall_macro']:.4f}\n",
                f"F1-Score (Macro): {metrics['f1_macro']:.4f}\n\n",
                "詳細分類報告:\n",
                metrics['classification_report'],
            ])
        print(f"評估結果已保存至: {results_dir}")
        return results_dir

    # ------------------------------------------------------------------- driver

    def _prepare_test_loader(self, batch_size=32):
        """Load the test images, attach one-hot labels and wrap them in a DataLoader.

        Exceptions from the project data utilities propagate to the caller.
        """
        label_names = Loading_Config["Training_Labels"]
        num_labels = len(label_names)

        test_data_dict = self.data_loader.process_main(status=False)  # no ImageGenerator augmentation
        print("前處理後資料集總數")
        size_per_label = []
        for label in label_names:
            count = len(test_data_dict[label])
            size_per_label.append(count)
            print(f"Labels: {label}, 總數為: {count}")
        print("總共有 " + str(sum(size_per_label)) + " 筆資料")

        # Build one label list per class, as long as that class's image list.
        tool = Tool()
        tool.Set_OneHotEncording(label_names)
        encodings = tool.Get_OneHot_Encording_Label()
        classes = [make_label_list(size_per_label[index], encoding) for index, encoding in enumerate(encodings)]

        prepare = Load_Data_Prepare()
        prepare.Set_Final_Dict_Data(label_names, test_data_dict, classes, num_labels)
        final_dict = prepare.Get_Final_Data_Dict()
        keys = list(final_dict.keys())

        # The indexing below assumes the first num_labels keys hold per-class image
        # lists and the next num_labels keys hold the matching label lists.
        merger = merge()
        test_images = merger.merge_all_image_data(final_dict[keys[0]], final_dict[keys[1]])
        for index in range(2, num_labels):
            test_images = merger.merge_all_image_data(test_images, final_dict[keys[index]])
        test_labels = merger.merge_all_image_data(final_dict[keys[num_labels]], final_dict[keys[num_labels + 1]])
        for index in range(num_labels + 2, 2 * num_labels):
            test_labels = merger.merge_all_image_data(test_labels, final_dict[keys[index]])

        test_dataset = self.preprocessor.Setting_DataSet(
            test_images,
            test_labels,
            None,  # no masks
            "Transform"
        )
        test_loader = self.preprocessor.Dataloader_Sampler(test_dataset, Batch_Size=batch_size, Sampler=False)
        print(f"測試數據載入完成,共 {len(test_dataset)} 個樣本")
        return test_loader

    def run_evaluation(self, batch_size=32):
        """Run the full pipeline: load models, load data, predict, score, save.

        Args:
            batch_size: Test DataLoader batch size (default 32, as before).

        Returns:
            ``(metrics, results_dir)`` on success, or ``None`` if any step fails
            (the reason is printed).
        """
        print("開始模型評估...")
        print("=" * 50)

        print("1. 檢查模型路徑...")
        if not self.normal_model_path or not self.ca_model_path:
            print("錯誤: 必須指定Normal和CA分類模型權重路徑")
            return
        print(f"Normal分類模型: {self.normal_model_path}")
        print(f"CA分類模型: {self.ca_model_path}")

        print("\n2. 載入模型...")
        normal_model = self.load_classification_model(self.normal_model_path, num_classes=2)
        ca_model = self.load_classification_model(self.ca_model_path, num_classes=2)
        if normal_model is None or ca_model is None:
            print("錯誤: 模型載入失敗")
            return

        print("\n3. 準備測試數據...")
        try:
            test_loader = self._prepare_test_loader(batch_size)
        except Exception as e:  # boundary: report with traceback and abort
            print(f"測試數據準備失敗: {e}")
            import traceback
            traceback.print_exc()
            return

        print("\n4. 進行三分類預測...")
        predictions, true_labels = self.predict_three_class(normal_model, ca_model, test_loader)
        if len(true_labels) == 0:
            print("錯誤: 沒有可評估的測試樣本")
            return

        print("\n5. 計算評估指標...")
        metrics = self.calculate_metrics(true_labels, predictions)

        print("\n6. 評估結果:")
        print("=" * 50)
        print(f"整體準確率 (Accuracy): {metrics['accuracy']:.4f}")
        print(f"整體精確率 (Precision): {metrics['precision_macro']:.4f}")
        print(f"整體召回率 (Recall): {metrics['recall_macro']:.4f}")
        print(f"整體F1分數 (F1-Score): {metrics['f1_macro']:.4f}")
        print("\n各類別詳細指標:")
        per_class_rows = zip(
            self.class_labels.values(),
            metrics['precision_per_class'],
            metrics['recall_per_class'],
            metrics['f1_per_class'],
        )
        for name, precision, recall, f1 in per_class_rows:
            print(f"{name}:")
            print(f" Precision: {precision:.4f}")
            print(f" Recall: {recall:.4f}")
            print(f" F1-Score: {f1:.4f}")
        print("\n混淆矩陣:")
        print(metrics['confusion_matrix'])

        print("\n7. 保存評估結果...")
        results_dir = self.save_results(metrics)
        print("\n評估完成!")
        return metrics, results_dir