from torch.utils.data import Dataset, DataLoader, RandomSampler, WeightedRandomSampler, SubsetRandomSampler, Subset from torchvision.datasets import ImageFolder import torchvision.transforms as transforms from PIL import Image import torch import numpy as np import cv2 class ListDataset(Dataset): def __init__(self, data_list, labels_list, Mask_List, transform): self.data = data_list self.labels = labels_list self.Mask_Truth_List = Mask_List self.transform = transform self.roots = [] def __len__(self): return len(self.data) def __getitem__(self, idx): Image_Root = self.data[idx] # Mask_Ground_Truth = None # if self.Mask_Truth_List is not None: # mask_path = self.Mask_Truth_List[idx] # if mask_path is not None: # 確保掩碼路徑不為None # try: # Mask_Ground_Truth = Image.open(mask_path).convert("RGB") # # 先不轉換為 tensor,等待 transform 處理完後再轉換 # except Exception as e: # print(e) Split_Roots = Image_Root.split("/") Split_Roots = Split_Roots[-1].split("\\") File_Name = Split_Roots[-1] classes = Split_Roots[-2] try: Images = Image.open(Image_Root).convert("RGB") except Exception as e: assert e is not None, f"Error loading image {Image_Root}: {e}" if self.transform != "Generator": Images = self.transform(Images) # if self.Mask_Truth_List is not None and Mask_Ground_Truth is not None and not isinstance(Mask_Ground_Truth, torch.Tensor): # Mask_Ground_Truth = self.transform(Mask_Ground_Truth) # # 確保 Images 是 tensor # if not isinstance(Images, torch.Tensor): # Images = torch.tensor(np.array(Images)) # # 確保 Mask_Ground_Truth 是 tensor # if self.Mask_Truth_List is not None and Mask_Ground_Truth is not None and not isinstance(Mask_Ground_Truth, torch.Tensor): # Mask_Ground_Truth = torch.tensor(np.array(Mask_Ground_Truth)) Images = torch.tensor(np.array(Images)) label = self.labels[idx] # if self.Mask_Truth_List is not None: # # 如果掩碼為None,創建一個與圖像相同大小的空掩碼 # if Mask_Ground_Truth is None: # if isinstance(Images, torch.Tensor): # # 創建與圖像相同大小的空掩碼張量 # Mask_Ground_Truth = torch.zeros_like(Images) # else: # # 如果圖像不是張量,創建一個空的PIL圖像 # Mask_Ground_Truth = Image.new('RGB', Images.size, (0, 0, 0)) # if self.transform != "Generator": # Mask_Ground_Truth = self.transform(Mask_Ground_Truth) # return Images, Mask_Ground_Truth, label, File_Name, classes # print(f"Dataset_Data: \n{sample}\n") return Images, label, File_Name, classes class Training_Precesses: def __init__(self, ImageSize): seed = 42 # Set an arbitrary integer as the seed self.ImageSize = ImageSize self.generator = torch.Generator() self.generator.manual_seed(seed) def Dataloader_Sampler(self, SubDataSet, Batch_Size, Sampler=True): if Sampler: Data_Loader = DataLoader( dataset=SubDataSet, batch_size=Batch_Size, num_workers=0, pin_memory=True, sampler=self.Setting_WeightedRandomSampler_Content(SubDataSet) ) else: Data_Loader = DataLoader( dataset=SubDataSet, batch_size=Batch_Size, num_workers=0, pin_memory=True ) return Data_Loader def Setting_WeightedRandomSampler_Content(self, SubDataSet): # Check if SubDataSet is a Subset or a full dataset if isinstance(SubDataSet, Subset): # Get the underlying dataset and subset indices base_dataset = SubDataSet.dataset subset_indices = SubDataSet.indices # Extract labels for the subset labels = [base_dataset.labels[i] for i in subset_indices] else: # Assume SubDataSet is a ListDataset or similar labels = SubDataSet.labels # Convert labels to class indices if they are one-hot encoded labels = np.array(labels) if labels.ndim > 1: # If one-hot encoded labels = np.argmax(labels, axis=1) # 確保標籤是整數類型 try: # 嘗試將標籤轉換為整數 labels = labels.astype(np.int64) except ValueError: # 如果標籤是字符串,先將其映射到整數 unique_labels = np.unique(labels) label_to_idx = {label: idx for idx, label in enumerate(unique_labels)} labels = np.array([label_to_idx[label] for label in labels]) # Count occurrences of each class class_counts = np.bincount(labels) class_weights = 1.0 / class_counts # Inverse frequency as weight sample_weights = class_weights[labels] # Assign weight to each sample return WeightedRandomSampler( weights=sample_weights, num_samples=len(sample_weights), replacement=True ) def Setting_RandomSampler_Content(self, Dataset): return RandomSampler(Dataset, generator = self.generator) def Setting_DataSet(self, Datas, Labels, Mask_List, transform = None): # 資料預處理 if transform == None: transform = transforms.Compose([ transforms.Resize((self.ImageSize, self.ImageSize)) ]) elif transform == "Transform": transform = transforms.Compose([ transforms.Resize((self.ImageSize, self.ImageSize)), transforms.ToTensor() ]) elif transform == "Generator": transform = "Generator" # Create Dataset list_dataset = ListDataset(Datas, Labels, Mask_List, transform) return list_dataset def Setting_SubsetRandomSampler_Content(self, SubDataSet): # Calculate subset indices (example: using a fraction of the dataset) dataset_size = len(SubDataSet) subset_size = int(0.8 * dataset_size) # Use 80% of the dataset as an example subset_indices = torch.randperm(dataset_size, generator=self.generator)[:subset_size] return SubsetRandomSampler(subset_indices, generator=self.generator)