164 lines
6.6 KiB
Python
164 lines
6.6 KiB
Python
from torch.utils.data import Dataset, DataLoader, RandomSampler, WeightedRandomSampler, SubsetRandomSampler, Subset
|
||
from torchvision.datasets import ImageFolder
|
||
import torchvision.transforms as transforms
|
||
from PIL import Image
|
||
import torch
|
||
import numpy as np
|
||
import cv2
|
||
|
||
class ListDataset(Dataset):
    """Map-style dataset backed by parallel lists of image paths and labels.

    Each item is loaded from disk on demand, converted to RGB, optionally
    transformed, and returned together with its label, file name and the
    name of its parent directory (used as the class name).

    Args:
        data_list: Sequence of image file paths. Both "/" and "\\" are
            accepted as path separators, including mixed usage.
        labels_list: Sequence of labels, index-aligned with ``data_list``.
        Mask_List: Optional sequence of mask paths. It is stored on the
            instance but is not read by ``__getitem__``.
        transform: A callable applied to the loaded PIL image, or the
            string ``"Generator"`` to skip transformation.
    """

    def __init__(self, data_list, labels_list, Mask_List, transform):
        self.data = data_list
        self.labels = labels_list
        self.Mask_Truth_List = Mask_List
        self.transform = transform
        self.roots = []

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.data)

    @staticmethod
    def _split_image_path(image_path):
        """Return ``(file_name, class_name)`` parsed from ``image_path``.

        Separators are normalised first so that POSIX, Windows and mixed
        paths all resolve to the same last two components. ``class_name``
        is the parent directory name, or ``""`` when the path has no
        parent component.
        """
        parts = image_path.replace("\\", "/").split("/")
        file_name = parts[-1]
        class_name = parts[-2] if len(parts) > 1 else ""
        return file_name, class_name

    def __getitem__(self, idx):
        """Load and return the sample at ``idx``.

        Returns:
            A tuple ``(Images, label, File_Name, classes)`` where ``Images``
            is a ``torch.Tensor``.

        Raises:
            RuntimeError: If the image cannot be opened or decoded. The
                original exception is attached as ``__cause__``.
        """
        Image_Root = self.data[idx]
        File_Name, classes = self._split_image_path(Image_Root)

        try:
            Images = Image.open(Image_Root).convert("RGB")
        except Exception as e:
            # Fail here with the offending path instead of continuing with
            # an unbound ``Images`` variable.
            raise RuntimeError(f"Error loading image {Image_Root}: {e}") from e

        # "Generator" is a sentinel meaning "leave the image untransformed".
        if self.transform is not None and self.transform != "Generator":
            Images = self.transform(Images)

        # Transforms ending in ToTensor already yield a tensor; only convert
        # when the pipeline left a PIL image / array behind.
        if not isinstance(Images, torch.Tensor):
            Images = torch.tensor(np.array(Images))

        label = self.labels[idx]
        return Images, label, File_Name, classes
|
||
|
||
class Training_Precesses:
    """Builds datasets, samplers and data loaders that share one seeded RNG.

    Args:
        ImageSize: Edge length (pixels) used for the square resize applied
            by the built-in transforms in ``Setting_DataSet``.
    """

    def __init__(self, ImageSize):
        seed = 42  # Arbitrary fixed seed so sampling is reproducible
        self.ImageSize = ImageSize
        self.generator = torch.Generator()
        self.generator.manual_seed(seed)

    def Dataloader_Sampler(self, SubDataSet, Batch_Size, Sampler=True):
        """Wrap ``SubDataSet`` in a ``DataLoader``.

        Args:
            SubDataSet: Dataset (or ``Subset``) to load from.
            Batch_Size: Number of samples per batch.
            Sampler: When true, draw samples with the class-balancing
                weighted sampler; otherwise use the loader's default order.

        Returns:
            The configured ``DataLoader``.
        """
        loader_kwargs = {
            "dataset": SubDataSet,
            "batch_size": Batch_Size,
            "num_workers": 0,
            "pin_memory": True,
        }
        if Sampler:
            loader_kwargs["sampler"] = self.Setting_WeightedRandomSampler_Content(SubDataSet)
        return DataLoader(**loader_kwargs)

    def Setting_WeightedRandomSampler_Content(self, SubDataSet):
        """Create a sampler that weights each sample by inverse class frequency.

        Labels are read from ``SubDataSet.labels`` or, for a ``Subset``, from
        the underlying dataset's ``labels`` at the subset indices. One-hot
        rows are reduced with ``argmax``; labels that cannot be cast to
        integers (e.g. strings) are mapped to indices of their sorted
        unique values.

        Raises:
            ValueError: If there are no labels to sample from.
        """
        if isinstance(SubDataSet, Subset):
            base_labels = SubDataSet.dataset.labels
            labels = [base_labels[i] for i in SubDataSet.indices]
        else:
            # Assume a ListDataset-like object exposing ``labels``.
            labels = SubDataSet.labels

        labels = np.array(labels)
        if labels.size == 0:
            raise ValueError("Cannot build a WeightedRandomSampler from an empty label list")
        if labels.ndim > 1:  # One-hot encoded -> class indices
            labels = np.argmax(labels, axis=1)

        try:
            labels = labels.astype(np.int64)
        except ValueError:
            # Non-numeric labels: replace each with the index of its value
            # in the sorted set of unique labels.
            _, labels = np.unique(labels, return_inverse=True)

        class_counts = np.bincount(labels)
        # Class indices that never occur have a count of zero; leave their
        # weight at zero instead of dividing by zero.
        class_weights = np.zeros(class_counts.shape, dtype=np.float64)
        present = class_counts > 0
        class_weights[present] = 1.0 / class_counts[present]
        sample_weights = class_weights[labels]

        # Use the shared seeded generator, like the other samplers here,
        # so weighted sampling is reproducible as well.
        return WeightedRandomSampler(
            weights=sample_weights,
            num_samples=len(sample_weights),
            replacement=True,
            generator=self.generator,
        )

    def Setting_RandomSampler_Content(self, Dataset):
        """Return a ``RandomSampler`` over ``Dataset`` using the seeded generator."""
        return RandomSampler(Dataset, generator=self.generator)

    def Setting_DataSet(self, Datas, Labels, Mask_List, transform=None):
        """Build a ``ListDataset`` with the requested preprocessing.

        Args:
            Datas: Sequence of image paths.
            Labels: Sequence of labels aligned with ``Datas``.
            Mask_List: Mask paths forwarded to ``ListDataset``.
            transform: ``None`` for resize only, ``"Transform"`` for resize
                followed by ``ToTensor``; any other value (including
                ``"Generator"`` or a custom callable) is passed through
                unchanged.

        Returns:
            The constructed ``ListDataset``.
        """
        target_size = (self.ImageSize, self.ImageSize)
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize(target_size),
            ])
        elif transform == "Transform":
            transform = transforms.Compose([
                transforms.Resize(target_size),
                transforms.ToTensor(),
            ])

        return ListDataset(Datas, Labels, Mask_List, transform)

    def Setting_SubsetRandomSampler_Content(self, SubDataSet):
        """Return a sampler over a random 80% subset of ``SubDataSet``.

        The subset is chosen with the seeded generator, and the sampler
        then visits the chosen indices in random order.
        """
        dataset_size = len(SubDataSet)
        subset_size = int(0.8 * dataset_size)
        permutation = torch.randperm(dataset_size, generator=self.generator)
        # Plain ints keep downstream indexing independent of tensor types.
        subset_indices = permutation[:subset_size].tolist()

        return SubsetRandomSampler(subset_indices, generator=self.generator)