Files
Stomach_Cancer_Pytorch/Training_Tools/PreProcess.py

164 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from torch.utils.data import Dataset, DataLoader, RandomSampler, WeightedRandomSampler, SubsetRandomSampler, Subset
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from PIL import Image
import torch
import numpy as np
import cv2
class ListDataset(Dataset):
    """Map-style dataset over a list of image file paths and their labels.

    Each item is returned as ``(Images, label, File_Name, classes)``.
    ``File_Name`` is the last component of the image path and ``classes`` is
    the name of its parent directory, so class names are taken from the
    folder layout ``.../<class>/<file>``.
    """

    def __init__(self, data_list, labels_list, Mask_List, transform):
        """Store the sample lists and the image transform.

        Args:
            data_list: Image file paths, one per sample.
            labels_list: Labels aligned index-for-index with ``data_list``.
            Mask_List: Optional mask paths aligned with ``data_list``. Kept on
                the instance as ``Mask_Truth_List`` but not read by
                ``__getitem__`` (mask loading is currently disabled).
            transform: Callable applied to the loaded PIL image, the string
                ``"Generator"`` to skip transformation, or ``None`` (also
                skips transformation).
        """
        self.data = data_list
        self.labels = labels_list
        self.Mask_Truth_List = Mask_List
        self.transform = transform
        self.roots = []  # kept for callers that read it; not populated here

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    @staticmethod
    def _split_root(image_root):
        """Return ``(file_name, class_name)`` parsed from an image path.

        ``/`` and ``\\`` are both accepted as separators, in any mix, so
        POSIX, Windows and mixed paths all work. When the path has no parent
        directory, ``class_name`` is ``""``.
        """
        parts = str(image_root).replace("\\", "/").split("/")
        file_name = parts[-1]
        class_name = parts[-2] if len(parts) >= 2 else ""
        return file_name, class_name

    def __getitem__(self, idx):
        """Load sample ``idx``, transform it, and return it with its metadata.

        Returns:
            ``(Images, label, File_Name, classes)``. ``Images`` is always a
            ``torch.Tensor``: either the transform's tensor output, or the
            (transformed) PIL image converted through ``np.array``.

        Raises:
            OSError: if the image file cannot be opened or decoded.
        """
        Image_Root = self.data[idx]
        File_Name, classes = self._split_root(Image_Root)

        try:
            # ``convert`` returns a new, fully loaded image, so the source
            # file can be closed as soon as the ``with`` block exits.
            with Image.open(Image_Root) as source:
                Images = source.convert("RGB")
        except Exception as e:
            # Re-raise with the offending path; continuing would leave
            # ``Images`` undefined.
            raise OSError(f"Error loading image {Image_Root}: {e}") from e

        if self.transform is not None and self.transform != "Generator":
            Images = self.transform(Images)

        # Transforms ending in ToTensor already yield a tensor; only PIL
        # images (resize-only or untransformed) need converting.
        if not isinstance(Images, torch.Tensor):
            Images = torch.tensor(np.array(Images))

        label = self.labels[idx]
        return Images, label, File_Name, classes
class Training_Precesses:
    """Build datasets, samplers and data loaders for training.

    All random samplers created here draw from one ``torch.Generator`` seeded
    with 42. Runs that create samplers in the same order therefore see the
    same sampling sequence.
    """

    def __init__(self, ImageSize):
        """
        Args:
            ImageSize: Side length used by ``Setting_DataSet`` to resize
                images to ``(ImageSize, ImageSize)``.
        """
        seed = 42  # fixed seed for reproducible sampling
        self.ImageSize = ImageSize
        self.generator = torch.Generator()
        self.generator.manual_seed(seed)

    def Dataloader_Sampler(self, SubDataSet, Batch_Size, Sampler=True):
        """Wrap ``SubDataSet`` in a ``DataLoader``.

        Args:
            SubDataSet: The dataset (or ``Subset``) to load from.
            Batch_Size: Samples per batch.
            Sampler: If true, draw samples with a class-balancing
                ``WeightedRandomSampler`` (see
                ``Setting_WeightedRandomSampler_Content``). Otherwise iterate
                the dataset in order.

        Returns:
            A ``torch.utils.data.DataLoader``.
        """
        loader_kwargs = dict(
            dataset=SubDataSet,
            batch_size=Batch_Size,
            num_workers=0,
            # Pinned memory only speeds up host-to-GPU copies; requesting it
            # without CUDA just emits a warning.
            pin_memory=torch.cuda.is_available(),
        )
        if Sampler:
            loader_kwargs["sampler"] = self.Setting_WeightedRandomSampler_Content(SubDataSet)
        return DataLoader(**loader_kwargs)

    @staticmethod
    def _collect_labels(SubDataSet):
        """Return the labels of ``SubDataSet`` in index order.

        ``Subset`` wrappers, including nested ones, are unwrapped down to a
        base dataset that exposes a ``labels`` attribute.
        """
        if isinstance(SubDataSet, Subset):
            parent_labels = Training_Precesses._collect_labels(SubDataSet.dataset)
            return [parent_labels[i] for i in SubDataSet.indices]
        return SubDataSet.labels

    def Setting_WeightedRandomSampler_Content(self, SubDataSet):
        """Build a sampler that draws each class with equal probability.

        Each sample is weighted by the inverse frequency of its class.
        ``num_samples`` equals the dataset length and sampling is with
        replacement.

        Labels may be:
        - integer class indices;
        - 2-D one-hot/score rows, reduced with ``argmax`` over axis 1;
        - non-numeric values such as strings, mapped to indices in sorted
          order.

        Raises:
            ValueError: if the dataset has no labels.
        """
        labels = np.asarray(self._collect_labels(SubDataSet))
        if labels.size == 0:
            raise ValueError("Cannot build a weighted sampler for an empty dataset")

        if labels.ndim > 1:  # one-hot encoded -> class index
            labels = np.argmax(labels, axis=1)

        try:
            labels = labels.astype(np.int64)
        except (ValueError, TypeError):
            # Non-numeric labels: replace each value by its position among
            # the sorted distinct values.
            _, labels = np.unique(labels, return_inverse=True)
            labels = labels.reshape(-1).astype(np.int64)

        class_counts = np.bincount(labels)
        # Inverse class frequency. Class indices with no samples get weight 0
        # instead of triggering a divide-by-zero.
        class_weights = np.zeros(class_counts.shape, dtype=np.float64)
        np.divide(1.0, class_counts, out=class_weights, where=class_counts > 0)
        sample_weights = class_weights[labels]

        return WeightedRandomSampler(
            weights=sample_weights,
            num_samples=len(sample_weights),
            replacement=True,
            generator=self.generator,  # same seeded RNG as the other samplers
        )

    def Setting_RandomSampler_Content(self, Dataset):
        """Return a ``RandomSampler`` over ``Dataset`` using the seeded generator."""
        return RandomSampler(Dataset, generator = self.generator)

    def Setting_DataSet(self, Datas, Labels, Mask_List, transform = None):
        """Create a ``ListDataset`` with a preprocessing pipeline.

        Args:
            Datas: Image file paths.
            Labels: Labels aligned with ``Datas``.
            Mask_List: Optional mask paths, forwarded to ``ListDataset``.
            transform: How images are preprocessed:
                - ``None``: resize to ``(ImageSize, ImageSize)``;
                - ``"Transform"``: resize, then ``ToTensor``;
                - ``"Generator"``, or any other callable: passed through
                  unchanged.
        """
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((self.ImageSize, self.ImageSize))
            ])
        elif transform == "Transform":
            transform = transforms.Compose([
                transforms.Resize((self.ImageSize, self.ImageSize)),
                transforms.ToTensor()
            ])
        return ListDataset(Datas, Labels, Mask_List, transform)

    def Setting_SubsetRandomSampler_Content(self, SubDataSet, fraction=0.8):
        """Return a sampler over a random ``fraction`` of ``SubDataSet``.

        The subset holds ``int(fraction * len(SubDataSet))`` indices, chosen
        with the seeded generator. The sampler then shuffles those indices
        using the same generator.

        Raises:
            ValueError: if ``fraction`` is outside ``[0, 1]``.
        """
        if not 0.0 <= fraction <= 1.0:
            raise ValueError(f"fraction must be in [0, 1], got {fraction}")
        dataset_size = len(SubDataSet)
        subset_size = int(fraction * dataset_size)
        subset_indices = torch.randperm(dataset_size, generator=self.generator)[:subset_size]
        return SubsetRandomSampler(subset_indices, generator=self.generator)