Files
Stomach_Cancer_Pytorch/experiments/experiment.py

98 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from torchinfo import summary
from Training_Tools.PreProcess import Training_Precesses
from experiments.pytorch_Model import ModifiedXception
from experiments.Model_All_Step import All_Step
from Load_process.Load_Indepentend import Load_Indepentend_Data
from _validation.ValidationTheEnterData import validation_the_enter_data
import numpy as np
import torch
import torch.nn as nn
import time
class experiments():
def __init__(self, Image_Size, Model_Name, Experiment_Name, Epoch, Train_Batch_Size, tools, Number_Of_Classes, status):
'''
# 實驗物件
## 說明:
* 用於開始訓練pytorch的物件裡面分為數個方法負責處理實驗過程的種種
## parmeter:
* Topic_Tool: 讀取訓練、驗證、測試的資料集與Label等等的內容
* cut_image: 呼叫切割影像物件
* merge: 合併的物件
* model_name: 模型名稱,告訴我我是用哪個模型(可能是預處理模型/自己設計的模型)
* experiment_name: 實驗名稱
* epoch: 訓練次數
* train_batch_size: 訓練資料的batch
* convolution_name: Grad-CAM的最後一層的名稱
* Number_Of_Classes: Label的類別
* Status: 選擇現在資料集的狀態
* device: 決定使用GPU或CPU
## Method:
* processing_main: 實驗物件的進入點
* construct_model: 決定實驗用的Model
* Training_Step: 訓練步驟,開始進行訓練驗證的部分
* Evaluate_Model: 驗證模型的準確度
* record_matrix_image: 劃出混淆矩陣(熱力圖)
* record_everyTime_test_result: 記錄我單次的訓練結果並將它輸出到檔案中
'''
self.Topic_Tool = tools
self.validation_obj = validation_the_enter_data() # 呼叫驗證物件
self.cut_image = Load_Indepentend_Data(self.Topic_Tool.Get_Data_Label(), self.Topic_Tool.Get_OneHot_Encording_Label()) # 呼叫切割影像物件
self.model_name = Model_Name # 取名,告訴我我是用哪個模型(可能是預處理模型/自己設計的模型)
self.experiment_name = Experiment_Name
self.epoch = Epoch
self.train_batch_size = Train_Batch_Size
self.layers = 1
self.Number_Of_Classes = Number_Of_Classes
self.Image_Size = Image_Size
self.Grad = ""
self.Status = status
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pass
def processing_main(self, Training_Data, Training_Label, counter):
Train, Test = self.Topic_Tool.Get_Save_Roots(self.Status) # 要換不同資料集就要改
start = time.time()
self.cut_image.process_main(Test) # 呼叫處理test Data與Validation Data的function
end = time.time()
print("讀取testing與validation資料(154)執行時間:%f\n" % (end - start))
# 將處理好的test Data 與 Validation Data 丟給這個物件的變數
self.test, self.test_label = self.cut_image.test, self.cut_image.test_label
PreProcess = Training_Precesses(Training_Data, Training_Label, self.test, self.test_label)
cnn_model = self.construct_model() # 呼叫讀取模型的function
print(summary(cnn_model, input_size=(int(self.train_batch_size / 2), 3, self.Image_Size, self.Image_Size)))
for name, parameters in cnn_model.named_parameters():
print(f"Layer Name: {name}, Parameters: {parameters.size()}")
step = All_Step(PreProcess, self.train_batch_size, cnn_model, self.epoch, self.Number_Of_Classes, self.model_name, self.experiment_name)
print("\n\n\n讀取訓練資料(70000)執行時間:%f\n\n" % (end - start))
step.Training_Step(self.model_name, counter)
step.Evaluate_Model(cnn_model, self.model_name, counter)
# self.Grad.process_main(cnn_model, counter, Testing_Dataset)
pass
def construct_model(self):
'''決定我這次訓練要用哪個model'''
cnn_model = ModifiedXception(self.Number_Of_Classes)
if torch.cuda.device_count() > 1:
cnn_model = nn.DataParallel(cnn_model)
cnn_model = cnn_model.to(self.device)
return cnn_model