117 lines
4.1 KiB
Python
117 lines
4.1 KiB
Python
from torch.utils.data import Dataset, DataLoader, RandomSampler, WeightedRandomSampler, SubsetRandomSampler, Subset
|
|
from torchvision.datasets import ImageFolder
|
|
import torchvision.transforms as transforms
|
|
from PIL import Image
|
|
import torch
|
|
import numpy as np
|
|
import cv2
|
|
|
|
class ListDataset(Dataset):
    """Dataset backed by parallel lists of image file paths and labels.

    Each item is read from disk with PIL as an RGB image, optionally run
    through a transform, and returned as a torch tensor with its label.
    """

    def __init__(self, data_list, labels_list, transform):
        """
        Args:
            data_list: list of image file paths.
            labels_list: labels aligned with ``data_list`` (same length).
            transform: a callable transform, or the sentinel string
                ``"Generator"`` to skip transformation entirely.
        """
        self.data = data_list
        self.labels = labels_list
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load the image at ``idx`` and return ``(image_tensor, label)``.

        Raises:
            RuntimeError: if the image file cannot be opened/decoded.
        """
        image_path = self.data[idx]
        try:
            with open(image_path, 'rb') as file:
                # Read file as a colour (RGB) image.
                image = Image.open(file).convert("RGB")
        except Exception as e:
            # BUGFIX: the original printed the error and fell through,
            # which later raised an opaque NameError on the undefined
            # image variable. Re-raise with the offending path instead.
            raise RuntimeError(f"Failed to load image: {image_path}") from e

        # BUGFIX: original used `is not "Generator"` — identity comparison
        # with a string literal (SyntaxWarning, implementation-defined).
        # Equality is the intended check for the sentinel.
        if self.transform != "Generator":
            image = self.transform(image)

        # Works whether the transform produced a PIL image, an ndarray,
        # or a tensor: np.array normalises all of them.
        image = torch.tensor(np.array(image))
        label = self.labels[idx]
        return image, label
|
|
|
|
class Training_Precesses:
    """Factory for seeded DataLoaders, samplers, and list-backed datasets.

    A fixed-seed ``torch.Generator`` is created once so that random
    sampling order is reproducible across runs.
    """

    def __init__(self, ImageSize):
        """
        Args:
            ImageSize: target image size kept for callers (not used by the
                methods in this class directly).
        """
        seed = 42  # fixed seed so sampling order is reproducible
        self.ImageSize = ImageSize
        self.generator = torch.Generator()
        self.generator.manual_seed(seed)

    def Dataloader_Sampler(self, SubDataSet, Batch_Size, Sampler=True):
        """Build a DataLoader over ``SubDataSet``.

        Args:
            SubDataSet: dataset (anything with ``__len__``/``__getitem__``).
            Batch_Size: batch size.
            Sampler: if True, shuffle via the seeded RandomSampler;
                otherwise use ``shuffle=True`` (``sampler`` and ``shuffle``
                are mutually exclusive in DataLoader).

        Returns:
            torch.utils.data.DataLoader
        """
        # Shared kwargs factored out of the two branches (the original
        # duplicated them, plus carried a commented-out third copy).
        common = dict(
            dataset=SubDataSet,
            batch_size=Batch_Size,
            num_workers=0,
            pin_memory=True,
        )
        if Sampler:
            return DataLoader(
                sampler=self.Setting_RandomSampler_Content(SubDataSet),
                **common,
            )
        return DataLoader(shuffle=True, **common)

    def Setting_WeightedRandomSampler_Content(self, SubDataSet):
        """Build a WeightedRandomSampler weighting each sample by the
        inverse frequency of its class (balances imbalanced datasets).

        Args:
            SubDataSet: a ``Subset`` wrapping a dataset with a ``labels``
                attribute, or a dataset exposing ``labels`` directly.

        Returns:
            torch.utils.data.WeightedRandomSampler drawing
            ``len(SubDataSet)`` samples with replacement.
        """
        # Resolve labels whether we were given a Subset or a full dataset.
        if isinstance(SubDataSet, Subset):
            base_dataset = SubDataSet.dataset
            labels = [base_dataset.labels[i] for i in SubDataSet.indices]
        else:
            labels = SubDataSet.labels

        # Convert one-hot encoded labels to class indices.
        labels = np.array(labels)
        if labels.ndim > 1:
            labels = np.argmax(labels, axis=1)

        # Inverse class frequency as per-sample weight.
        # NOTE(review): a class index absent from `labels` but below the
        # max index yields a zero count and a divide-by-zero here —
        # presumably all classes are present; confirm against callers.
        class_counts = np.bincount(labels)
        class_weights = 1.0 / class_counts
        sample_weights = class_weights[labels]

        return WeightedRandomSampler(
            weights=sample_weights,
            num_samples=len(sample_weights),
            replacement=True,
        )

    def Setting_RandomSampler_Content(self, Dataset):
        """Return a RandomSampler driven by the shared seeded generator."""
        # NOTE(review): the parameter name shadows the imported `Dataset`
        # class; kept unchanged for backward compatibility with callers.
        return RandomSampler(Dataset, generator=self.generator)

    def Setting_DataSet(self, Datas, Labels, transform=None):
        """Create a ``ListDataset`` with a preset transform.

        Args:
            Datas: list of image file paths.
            Labels: labels aligned with ``Datas``.
            transform: None -> Resize(256) only; "Transform" ->
                Resize(256) + ToTensor; "Generator" -> sentinel passed
                through so the dataset skips transforming entirely.

        Returns:
            ListDataset
        """
        # Data preprocessing.
        # BUGFIX: `transform == None` replaced with the idiomatic and
        # correct identity check `transform is None`.
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((256, 256)),
            ])
        elif transform == "Transform":
            transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
            ])
        elif transform == "Generator":
            transform = "Generator"

        return ListDataset(Datas, Labels, transform)