import pandas as pd
from torch.nn import functional
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler
import torchvision.transforms as transforms


class ListDataset(Dataset):
    """In-memory ``Dataset`` over parallel lists of samples and labels.

    Args:
        data_list: Samples, indexed by position.
        labels_list: Labels, indexed by the same positions as ``data_list``.
        status: When truthy, each sample is passed through the transform
            returned by ``Image_generator("", "", 12).Generator_Content(5)``
            before being returned.
    """

    def __init__(self, data_list, labels_list, status):
        self.data = data_list
        self.labels = labels_list
        self.status = status

    def __len__(self):
        # Length is driven by the samples list only.
        # NOTE(review): labels_list is assumed to be at least as long as
        # data_list; this is not validated here.
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` at ``idx``, transforming the sample when ``status`` is set."""
        sample = self.data[idx]
        if self.status:
            # Project-local import kept inside the method as in the original
            # (presumably to avoid an import cycle or import-time cost — TODO confirm).
            # NOTE(review): a new Image_generator and transform are built for
            # every sample; hoisting them would be cheaper but could change
            # behavior if Generator_Content randomizes at construction time.
            from Image_Process.Image_Generator import Image_generator
            image_generator = Image_generator("", "", 12)
            transform = image_generator.Generator_Content(5)
            sample = transform(sample)
        label = self.labels[idx]
        return sample, label


class Tool:
    """Holds dataset label names, save-directory roots and a one-hot table.

    Every attribute starts empty; call ``Set_Labels``, ``Set_Save_Roots`` and
    ``Set_OneHotEncording`` to populate them before using the getters.
    """

    def __init__(self) -> None:
        self.__ICG_Training_Root = ""
        self.__Normal_Training_Root = ""
        self.__Comprehensive_Training_Root = ""

        self.__ICG_Test_Data_Root = ""
        self.__Normal_Test_Data_Root = ""
        self.__Comprehensive_Testing_Root = ""

        self.__ICG_Validation_Data_Root = ""
        self.__Normal_Validation_Data_Root = ""
        self.__Comprehensive_Validation_Root = ""

        self.__ICG_ImageGenerator_Data_Root = ""
        self.__Normal_ImageGenerator_Data_Root = ""
        self.__Comprehensive_Generator_Root = ""

        self.__Labels = []
        # Replaced by a tensor once Set_OneHotEncording is called.
        self.__OneHot_Encording = []

    def Set_Labels(self):
        """Set the class-folder names used as dataset labels."""
        self.__Labels = ["stomach_cancer_Crop", "Normal_Crop", "Have_Question_Crop"]

    def Set_Save_Roots(self):
        """Set the relative directory paths for every split and category.

        NOTE(review): the test and validation roots also live under
        ``../Dataset/Training`` — confirm this is intended.
        """
        self.__ICG_Training_Root = "../Dataset/Training/CA_ICG"
        self.__Normal_Training_Root = "../Dataset/Training/CA"
        self.__Comprehensive_Training_Root = "../Dataset/Training/Mixed"

        self.__ICG_Test_Data_Root = "../Dataset/Training/CA_ICG_TestData"
        self.__Normal_Test_Data_Root = "../Dataset/Training/Normal_TestData"
        self.__Comprehensive_Testing_Root = "../Dataset/Training/Comprehensive_TestData"

        self.__ICG_Validation_Data_Root = "../Dataset/Training/CA_ICG_ValidationData"
        self.__Normal_Validation_Data_Root = "../Dataset/Training/Normal_ValidationData"
        self.__Comprehensive_Validation_Root = "../Dataset/Training/Comprehensive_ValidationData"

        self.__ICG_ImageGenerator_Data_Root = "../Dataset/Training/ICG_ImageGenerator"
        self.__Normal_ImageGenerator_Data_Root = "../Dataset/Training/Normal_ImageGenerator"
        self.__Comprehensive_Generator_Root = "../Dataset/Training/Comprehensive_ImageGenerator"

    def Set_OneHotEncording(self, content):
        """Build a one-hot matrix with one row per element of ``content``.

        Row ``i`` is the one-hot vector for class index ``i``, so the stored
        result is a ``(len(content), len(content))`` int64 identity matrix.
        Only ``len(content)`` is used; the elements themselves are ignored.
        """
        indices = torch.arange(len(content))
        self.__OneHot_Encording = functional.one_hot(indices, len(content))

    def Get_Data_Label(self):
        """Return the label names set by ``Set_Labels``."""
        return self.__Labels

    def Get_Save_Roots(self, choose):
        """Return the ``(training, test, validation)`` roots for a category.

        Args:
            choose: ``1`` selects the ICG roots, ``2`` selects the Normal
                roots, any other value selects the Comprehensive roots.

        Returns:
            A 3-tuple of path strings: training, test, validation.
        """
        if choose == 1:
            return self.__ICG_Training_Root, self.__ICG_Test_Data_Root, self.__ICG_Validation_Data_Root
        elif choose == 2:
            return self.__Normal_Training_Root, self.__Normal_Test_Data_Root, self.__Normal_Validation_Data_Root
        return self.__Comprehensive_Training_Root, self.__Comprehensive_Testing_Root, self.__Comprehensive_Validation_Root

    def Get_Generator_Save_Roots(self, choose):
        """Return the image-generator output root for a category.

        Args:
            choose: ``1`` selects ICG, ``2`` selects Normal, any other value
                selects Comprehensive.

        Returns:
            A single path string.
        """
        if choose == 1:
            return self.__ICG_ImageGenerator_Data_Root
        elif choose == 2:
            return self.__Normal_ImageGenerator_Data_Root
        return self.__Comprehensive_Generator_Root

    def Get_OneHot_Encording_Label(self):
        """Return the one-hot table built by ``Set_OneHotEncording`` (``[]`` if never built)."""
        return self.__OneHot_Encording

    def Convert_Data_To_DataSet_And_Put_To_Dataloader(self, Datas: list, Labels: list, Batch_Size: int, status: bool = True, seed: int = 42):
        """Wrap ``Datas``/``Labels`` in a ``ListDataset`` and return a shuffling ``DataLoader``.

        Args:
            Datas: Samples.
            Labels: Labels aligned position-by-position with ``Datas``.
            Batch_Size: Number of samples per batch.
            status: Forwarded to ``ListDataset``; when True each sample is
                transformed on access.
            seed: Seed for the generator that drives shuffling, so the batch
                order is reproducible. Defaults to 42.

        Returns:
            A ``DataLoader`` with ``shuffle=True``, ``num_workers=0`` and
            ``pin_memory=True``.
        """
        generator = torch.Generator()
        generator.manual_seed(seed)

        list_dataset = ListDataset(Datas, Labels, status)

        # generator= must be passed explicitly: with shuffle=True and no
        # generator, the DataLoader's sampler draws from the global RNG and
        # the seed above would have no effect.
        return DataLoader(
            dataset=list_dataset,
            batch_size=Batch_Size,
            num_workers=0,
            pin_memory=True,
            shuffle=True,
            generator=generator,
        )