# Stomach_Cancer_Pytorch/Training_Tools/Tools.py

# 121 lines
# 4.6 KiB
# Python

import pandas as pd
from torch.nn import functional
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler
import torchvision.transforms as transforms
class ListDataset(Dataset):
    """Dataset view over two parallel sequences: samples and their labels.

    When ``status`` is truthy, every fetched sample is passed through the
    callable returned by ``Image_generator("", "", 12).Generator_Content(5)``
    before it is handed back; otherwise the stored sample is returned as-is.
    """

    def __init__(self, data_list, labels_list, status):
        self.data, self.labels, self.status = data_list, labels_list, status

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at position ``idx``."""
        item = self.data[idx]
        if self.status:
            # Imported at call time, so the project module is only required
            # when the transform is actually requested.
            from Image_Process.Image_Generator import Image_generator

            # NOTE(review): the generator and its transform are rebuilt on
            # every fetch; confirm whether Generator_Content is safe to cache.
            augment = Image_generator("", "", 12).Generator_Content(5)
            item = augment(item)
        return item, self.labels[idx]
class Tool:
    """Holds dataset root paths, label names and one-hot labels for training.

    All values start out empty; call ``Set_Labels``, ``Set_Save_Roots`` and
    ``Set_OneHotEncording`` to populate them before using the getters.
    """

    def __init__(self) -> None:
        # Training roots.
        self.__ICG_Training_Root = ""
        self.__Normal_Training_Root = ""
        self.__Comprehensive_Training_Root = ""

        # Test roots.
        self.__ICG_Test_Data_Root = ""
        self.__Normal_Test_Data_Root = ""
        self.__Comprehensive_Testing_Root = ""

        # Validation roots.
        self.__ICG_Validation_Data_Root = ""
        self.__Normal_Validation_Data_Root = ""
        self.__Comprehensive_Validation_Root = ""

        # Image-generator output roots.
        self.__ICG_ImageGenerator_Data_Root = ""
        self.__Normal_ImageGenerator_Data_Root = ""
        self.__Comprehensive_Generator_Root = ""

        self.__Labels = []
        self.__OneHot_Encording = []

    def Set_Labels(self) -> None:
        """Store the fixed list of class label names."""
        self.__Labels = ["stomach_cancer_Crop", "Normal_Crop", "Have_Question_Crop"]

    def Set_Save_Roots(self) -> None:
        """Store the fixed relative paths of every dataset root."""
        self.__ICG_Training_Root = "../Dataset/Training/CA_ICG"
        self.__Normal_Training_Root = "../Dataset/Training/CA"
        self.__Comprehensive_Training_Root = "../Dataset/Training/Mixed"

        self.__ICG_Test_Data_Root = "../Dataset/Training/CA_ICG_TestData"
        self.__Normal_Test_Data_Root = "../Dataset/Training/Normal_TestData"
        self.__Comprehensive_Testing_Root = "../Dataset/Training/Comprehensive_TestData"

        self.__ICG_Validation_Data_Root = "../Dataset/Training/CA_ICG_ValidationData"
        self.__Normal_Validation_Data_Root = "../Dataset/Training/Normal_ValidationData"
        self.__Comprehensive_Validation_Root = "../Dataset/Training/Comprehensive_ValidationData"

        self.__ICG_ImageGenerator_Data_Root = "../Dataset/Training/ICG_ImageGenerator"
        self.__Normal_ImageGenerator_Data_Root = "../Dataset/Training/Normal_ImageGenerator"
        self.__Comprehensive_Generator_Root = "../Dataset/Training/Comprehensive_ImageGenerator"

    def Set_OneHotEncording(self, content) -> None:
        """Build and store one one-hot row per element of ``content``.

        Row ``i`` encodes index ``i``, so the stored tensor is equivalent to
        an ``n x n`` identity matrix with ``n = len(content)``.
        """
        class_count = len(content)
        # torch.arange yields the same int64 index tensor the hand-built
        # list produced, without the manual append loop.
        self.__OneHot_Encording = functional.one_hot(torch.arange(class_count), class_count)

    def Get_Data_Label(self) -> list:
        """Return the stored class label names."""
        return self.__Labels

    def Get_Save_Roots(self, choose) -> tuple:
        """Return the ``(training, test, validation)`` roots for one data kind.

        choose == 1 => ICG roots
        choose == 2 => Normal roots
        any other value => Comprehensive roots

        Exactly three paths are returned in every case.
        """
        if choose == 1:
            return self.__ICG_Training_Root, self.__ICG_Test_Data_Root, self.__ICG_Validation_Data_Root
        if choose == 2:
            return self.__Normal_Training_Root, self.__Normal_Test_Data_Root, self.__Normal_Validation_Data_Root
        return self.__Comprehensive_Training_Root, self.__Comprehensive_Testing_Root, self.__Comprehensive_Validation_Root

    def Get_Generator_Save_Roots(self, choose) -> str:
        """Return the single image-generator root for one data kind.

        choose == 1 => ICG root
        choose == 2 => Normal root
        any other value => Comprehensive root
        """
        if choose == 1:
            return self.__ICG_ImageGenerator_Data_Root
        if choose == 2:
            return self.__Normal_ImageGenerator_Data_Root
        return self.__Comprehensive_Generator_Root

    def Get_OneHot_Encording_Label(self):
        """Return the value stored by ``Set_OneHotEncording`` (``[]`` if unset)."""
        return self.__OneHot_Encording

    def Convert_Data_To_DataSet_And_Put_To_Dataloader(self, Datas : list, Labels : list, Batch_Size : int, status : bool = True, seed : int = 42):
        """Wrap ``Datas``/``Labels`` in a ``ListDataset`` and return a shuffling DataLoader.

        Args:
            Datas: samples, indexable by position.
            Labels: labels, parallel to ``Datas``.
            Batch_Size: number of samples per batch.
            status: forwarded to ``ListDataset``; when truthy each sample is
                transformed on fetch.
            seed: seed for the generator that drives shuffling.

        Returns:
            A ``DataLoader`` with ``shuffle=True``, ``num_workers=0`` and
            ``pin_memory=True``.
        """
        # Seeded generator so the shuffle order is reproducible across runs.
        shuffle_generator = torch.Generator()
        shuffle_generator.manual_seed(seed)

        list_dataset = ListDataset(Datas, Labels, status)

        # The generator must be handed to the DataLoader itself: combining an
        # explicit sampler with shuffle=True is rejected by DataLoader, and
        # without this argument the seed above would have no effect.
        return DataLoader(
            dataset=list_dataset,
            batch_size=Batch_Size,
            num_workers=0,
            pin_memory=True,
            shuffle=True,
            generator=shuffle_generator,
        )