Files
Stomach_Cancer_Pytorch/Training_Tools/PreProcess.py

117 lines
4.1 KiB
Python

from torch.utils.data import Dataset, DataLoader, RandomSampler, WeightedRandomSampler, SubsetRandomSampler, Subset
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from PIL import Image
import torch
import numpy as np
import cv2
class ListDataset(Dataset):
    """Map-style dataset over parallel lists of image paths and labels.

    Each item is read from disk with PIL, converted to RGB, optionally passed
    through ``transform``, and returned as a ``torch.Tensor`` together with
    its label.

    Args:
        data_list: Sequence of image file paths; ``data_list[i]`` pairs with
            ``labels_list[i]``.
        labels_list: Sequence of labels, returned unchanged by ``__getitem__``.
        transform: Callable applied to the decoded PIL image, or the sentinel
            string ``"Generator"`` / ``None`` to skip transformation.
    """

    # Sentinel meaning "return the decoded image without any transform".
    _NO_TRANSFORM = "Generator"

    def __init__(self, data_list, labels_list, transform):
        self.data = data_list
        self.labels = labels_list
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, label)``.

        Raises:
            OSError: if the image file cannot be opened or decoded; the
                message names the index and path, and the original error is
                chained as the cause.
        """
        image_path = self.data[idx]
        try:
            # Read through a context manager so the handle is always closed;
            # convert() decodes the pixels, so the image stays usable after
            # the file is closed.
            with open(image_path, "rb") as file:
                image = Image.open(file).convert("RGB")
        except OSError as err:
            # Surface the real failure (with index/path) instead of continuing
            # with no image loaded.
            raise OSError(
                f"ListDataset: failed to load image #{idx} from {image_path!r}"
            ) from err

        skip_transform = self.transform is None or (
            isinstance(self.transform, str) and self.transform == self._NO_TRANSFORM
        )
        if not skip_transform:
            image = self.transform(image)

        # Transforms such as ToTensor already return a tensor; anything else
        # (e.g. a PIL RGB image, which becomes an (H, W, 3) uint8 array) is
        # converted here.
        if not isinstance(image, torch.Tensor):
            image = torch.tensor(np.array(image))

        label = self.labels[idx]
        return image, label
class Training_Precesses:
    """Factory for seeded samplers, data loaders and ``ListDataset`` objects.

    Args:
        ImageSize: Stored as ``self.ImageSize``. NOTE(review): no method in this
            class reads it; ``Setting_DataSet`` resizes to a fixed 256x256.
            Confirm whether it was meant to drive that size.
        seed: Seed for ``self.generator`` (defaults to 42). Samplers built from
            that generator are reproducible for a fresh instance.
    """

    def __init__(self, ImageSize, seed=42):
        self.ImageSize = ImageSize
        # A single seeded generator shared by the samplers created below.
        self.generator = torch.Generator()
        self.generator.manual_seed(seed)

    def Dataloader_Sampler(self, SubDataSet, Batch_Size, Sampler=True):
        """Wrap ``SubDataSet`` in a shuffling ``DataLoader``.

        Args:
            SubDataSet: Map-style dataset (e.g. ``ListDataset`` or a ``Subset``).
            Batch_Size: Number of samples per batch.
            Sampler: If True, shuffle with a ``RandomSampler`` driven by
                ``self.generator``. If False, use ``shuffle=True`` without a
                generator, so the order is not controlled by ``self.generator``.

        Returns:
            A single-process ``DataLoader`` with ``pin_memory=True``.
        """
        loader_kwargs = dict(
            dataset=SubDataSet,
            batch_size=Batch_Size,
            num_workers=0,
            pin_memory=True,
        )
        if Sampler:
            return DataLoader(
                sampler=self.Setting_RandomSampler_Content(SubDataSet),
                **loader_kwargs,
            )
        return DataLoader(shuffle=True, **loader_kwargs)

    def _collect_labels(self, data_set):
        """Return the labels of ``data_set``, resolving (nested) ``Subset`` views."""
        if isinstance(data_set, Subset):
            parent_labels = self._collect_labels(data_set.dataset)
            return [parent_labels[i] for i in data_set.indices]
        # Assumes a ListDataset-like object exposing a ``labels`` sequence.
        return data_set.labels

    def Setting_WeightedRandomSampler_Content(self, SubDataSet):
        """Build a class-balancing ``WeightedRandomSampler`` for ``SubDataSet``.

        Each sample is weighted by the inverse frequency of its class, so every
        class is drawn with equal expected probability. Sampling is with
        replacement, ``len(SubDataSet)`` draws per pass, using
        ``self.generator``.

        Args:
            SubDataSet: A dataset exposing ``labels`` or a (possibly nested)
                ``Subset`` of one. Labels are integer class ids or 2-D rows
                (one-hot), which are reduced with ``argmax``.

        Raises:
            ValueError: if the dataset has no labels.
        """
        labels = np.asarray(self._collect_labels(SubDataSet))
        if labels.size == 0:
            raise ValueError("Cannot build a WeightedRandomSampler for an empty dataset")
        if labels.ndim > 1:  # one-hot encoded -> class indices
            labels = np.argmax(labels, axis=1)

        class_counts = np.bincount(labels)
        # Class ids that never occur get weight 0 instead of 1/0 = inf; they
        # are never indexed below, so present-class weights are unchanged.
        class_weights = np.divide(
            1.0,
            class_counts,
            out=np.zeros(class_counts.shape, dtype=np.float64),
            where=class_counts > 0,
        )
        sample_weights = class_weights[labels]
        return WeightedRandomSampler(
            weights=sample_weights,
            num_samples=len(sample_weights),
            replacement=True,
            generator=self.generator,
        )

    def Setting_RandomSampler_Content(self, Dataset):
        """Return a ``RandomSampler`` over ``Dataset`` driven by ``self.generator``."""
        return RandomSampler(Dataset, generator=self.generator)

    def Setting_DataSet(self, Datas, Labels, transform=None):
        """Build a ``ListDataset`` using one of the preset preprocessing options.

        Args:
            Datas: Image file paths.
            Labels: Labels parallel to ``Datas``.
            transform: Selects the preprocessing:
                - ``None``: resize to 256x256 only.
                - ``"Transform"``: resize to 256x256, then ``ToTensor``.
                - ``"Generator"``: no preprocessing.
                - Any non-string value: used as the transform unchanged.

        Returns:
            The constructed ``ListDataset``.

        Raises:
            ValueError: if ``transform`` is an unrecognised string.
        """
        # Data preprocessing
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((256, 256)),
            ])
        elif isinstance(transform, str):
            if transform == "Transform":
                transform = transforms.Compose([
                    transforms.Resize((256, 256)),
                    transforms.ToTensor(),
                ])
            elif transform != "Generator":
                raise ValueError(
                    f"Unknown transform preset {transform!r}; expected None, "
                    "'Transform', 'Generator' or a callable"
                )
        return ListDataset(Datas, Labels, transform)