Files
Stomach_Cancer_Pytorch/Training_Tools/PreProcess.py

61 lines
2.4 KiB
Python

from torch.utils.data import Dataset, DataLoader, RandomSampler, WeightedRandomSampler
import torchvision.transforms as transforms
import torch
class ListDataset(Dataset):
    """Map-style dataset over two parallel, index-aligned sequences.

    ``data_list[i]`` is paired with ``labels_list[i]``. When ``status`` is
    truthy, every fetched sample is additionally passed through the callable
    returned by ``Image_generator("", "", 12).Generator_Content(5)``.
    """

    def __init__(self, data_list, labels_list, status):
        """Keep references to the samples, their labels and the transform flag."""
        self.data, self.labels, self.status = data_list, labels_list, status

    def __len__(self):
        """Return the number of samples (the length of the data sequence)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` for position ``idx``.

        The sample is looked up first, optionally transformed, and only then
        is the label looked up.
        """
        item = self.data[idx]
        if self.status:
            # Imported lazily, so the project module is only required when
            # the transform is actually requested.
            from Image_Process.Image_Generator import Image_generator

            # NOTE(review): the generator and its transform are rebuilt on
            # every access; presumably an augmentation pipeline -- confirm
            # against Image_Process.Image_Generator before caching it.
            augment = Image_generator("", "", 12).Generator_Content(5)
            item = augment(item)
        return item, self.labels[idx]
class Training_Precesses:
    """Wrap in-memory training/testing lists into PyTorch ``DataLoader`` objects.

    A ``torch.Generator`` seeded with a fixed value is created at construction
    time; ``Combine_Signal_Dataset_To_DataLoader`` hands it to its sampler so
    that the sampling order is reproducible.
    """

    def __init__(self, Training_Datas, Training_Labels, Testing_Datas, Testing_Labels):
        """Store the four data/label collections and create the seeded generator."""
        self.Training_Datas = Training_Datas
        self.Training_Labels = Training_Labels
        self.Testing_Datas = Testing_Datas
        self.Testing_Labels = Testing_Labels

        seed = 42  # any integer works; it just has to stay fixed
        # Seeded random generator used by the sampler below.
        self.generator = torch.Generator()
        self.generator.manual_seed(seed)

    def Total_Data_Combine_To_DataLoader(self, Batch_Size):
        """Build loaders for the stored training and testing data.

        Args:
            Batch_Size: Batch size of the training loader. The testing loader
                always uses a batch size of 1.

        Returns:
            A ``(Training_DataLoader, Testing_DataLoader)`` tuple. Both loaders
            shuffle, run in the main process (``num_workers=0``) and request
            pinned memory.
        """
        # NOTE(review): both datasets are created with the default
        # ``status=True``, so the testing samples go through the same
        # per-sample transform as the training samples -- confirm intended.
        Training_Dataset = self.Convert_Data_To_DataSet(self.Training_Datas, self.Training_Labels)
        Testing_Dataset = self.Convert_Data_To_DataSet(self.Testing_Datas, self.Testing_Labels)

        Training_DataLoader = DataLoader(
            dataset=Training_Dataset,
            batch_size=Batch_Size,
            num_workers=0,
            pin_memory=True,
            shuffle=True,
        )
        # NOTE(review): the testing loader shuffles as well; kept as in the
        # original, verify that evaluation does not rely on sample order.
        Testing_DataLoader = DataLoader(
            dataset=Testing_Dataset,
            batch_size=1,
            num_workers=0,
            pin_memory=True,
            shuffle=True,
        )
        return Training_DataLoader, Testing_DataLoader

    def Combine_Signal_Dataset_To_DataLoader(self, datas: list, Labels: list, Batch_Size, status: bool = True):
        """Build one loader over ``datas``/``Labels`` with a seeded random order.

        Args:
            datas: Samples to serve.
            Labels: Labels aligned index-by-index with ``datas``.
            Batch_Size: Batch size of the returned loader.
            status: Forwarded to ``Convert_Data_To_DataSet``.

        Returns:
            A ``DataLoader`` whose order is drawn by a ``RandomSampler`` using
            this instance's seeded generator.
        """
        dataset = self.Convert_Data_To_DataSet(datas, Labels, status)
        # ``WeightedRandomSampler`` needs per-sample weights and a mandatory
        # ``num_samples``; calling it with just the dataset raised TypeError.
        # ``RandomSampler`` accepts a data source and a generator, giving a
        # seeded shuffle without replacement.
        sampler = RandomSampler(dataset, generator=self.generator)
        # ``sampler`` and ``shuffle`` are mutually exclusive, so no shuffle here.
        Dataloader = DataLoader(
            dataset=dataset,
            batch_size=Batch_Size,
            num_workers=0,
            pin_memory=True,
            sampler=sampler,
        )
        return Dataloader

    def Convert_Data_To_DataSet(self, Datas: list, Labels: list, status: bool = True):
        """Wrap parallel sample/label lists in a ``ListDataset``.

        Args:
            Datas: Samples to wrap.
            Labels: Labels aligned index-by-index with ``Datas``.
            status: Passed through as the ``ListDataset`` ``status`` flag.

        Returns:
            The newly created ``ListDataset``.
        """
        return ListDataset(Datas, Labels, status)