Files
Stomach_Cancer_Pytorch/Image_Process/Image_Generator.py

166 lines
9.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from model_data_processing.processing import make_label_list
from _validation.ValidationTheEnterData import validation_the_enter_data
from Load_process.file_processing import Process_File
from Load_process.LoadData import Load_Data_Prepare, Load_Data_Tools
from Training_Tools.PreProcess import Training_Precesses
from torchvision import transforms
class Image_generator():
'''製作資料強化'''
def __init__(self, Training_Root, Generator_Root, Labels, Image_Size, Class_Count) -> None:
self._validation = validation_the_enter_data()
self.stop = 0
self.Labels = Labels
self.Training_Root = Training_Root
self.Generator_Root = Generator_Root
self.Image_Size = Image_Size
self.Class_Count = Class_Count
pass
def Processing_Main(self):
data_size = 2712
File = Process_File()
Prepare = Load_Data_Prepare()
Load_Tool = Load_Data_Tools()
if not File.Judge_File_Exist(self.Generator_Root): # 檔案若不存在
# 確定我要多少個List
Prepare.Set_Data_Content([], len(self.Labels))
# 製作讀檔字典並回傳檔案路徑
Prepare.Set_Label_List(self.Labels)
Prepare.Set_Data_Dictionary(Prepare.Get_Label_List(), Prepare.Get_Data_Content(), len(self.Labels))
Original_Dict_Data_Root = Prepare.Get_Data_Dict()
get_all_original_image_data = Load_Tool.get_data_root(self.Training_Root, Original_Dict_Data_Root, Prepare.Get_Label_List())
# 儲存資料強化後資料
# 製作標準資料增強
'''
這裡我想要做的是依照paper上的資料強化IMAGE DATA COLLECTION AND IMPLEMENTATION OF DEEP LEARNING-BASED MODEL IN DETECTING MONKEYPOX DISEASE USING MODIFIED VGG16
產生出資料強化後的影像
'''
for i in range(1, 5, 1):
print(f"\nAugmentation {i} Generator image")
data_size = self.get_processing_Augmentation(get_all_original_image_data, i, data_size)
self.stop += data_size
else: # 若檔案存在
print("standard data and myself data are exist\n")
def get_processing_Augmentation(self, original_image_root : dict, Augment_choose, data_size):
Prepaer = Load_Data_Prepare()
self.get_data_roots = original_image_root # 要處理的影像路徑
Prepaer.Set_Label_List(self.Labels)
data_size = self.Generator_main(self.Generator_Root, Augment_choose, data_size) # 執行
return data_size
def Generator_main(self, save_roots, stardand, data_size):
'''
Parameter:
labels = 取得資料的標籤
save_root = 要儲存資料的地方
strardand = 要使用哪種Image Augmentation
'''
File = Process_File()
tool = Training_Precesses(self.Image_Size)
Classes = []
Transform = self.Generator_Content(stardand)
for label in self.Labels: # 分別對所有類別進行資料強化
Image_Roots = self.get_data_roots[label]
save_root = File.Make_Save_Root(label, save_roots) # 合併路徑
Classes = make_label_list(len(Image_Roots), "1")
Training_Dataset = tool.Setting_DataSet(Image_Roots, Classes, "Generator")
Training_DataLoader = tool.Dataloader_Sampler(Training_Dataset, 1, False)
if File.JudgeRoot_MakeDir(save_root): # 判斷要存的資料夾存不存在,不存在則創立
print("The file is exist.This Script is not creating new fold.")
for i in range(1, int(self.Class_Count / len(Image_Roots)) + 1, 1):
for batch_idx, (images, labels, File_Name, File_Classes) in enumerate(Training_DataLoader):
for j, img in enumerate(images):
# if i == self.stop:
# break
img = img.permute(2, 0, 1)
img = Transform(img)
# 轉換為 NumPy 陣列並從 BGR 轉為 RGB
img_np = img.numpy().transpose(1, 2, 0) # 轉回 HWC 格式
img_pil = transforms.ToPILImage()(img_np)
File.Save_PIL_File("image_" + label + str(data_size) + ".png", save_root, img_pil) # 存檔
data_size += 1
return data_size
def Generator_Content(self, judge): # 影像資料增強
'''
## Parameters:
<b>featurewise_center</b> : 布爾值。將輸入數據的均值設置為0逐特徵進行。<br/>
<b>samplewise_center</b> : 布爾值。將每個樣本的均值設置為0。<br/>
<b>featurewise_std_normalization</b> : Boolean. 布爾值。將輸入除以數據標準差,逐特徵進行。<br/>
<b>samplewise_std_normalization</b> : 布爾值。將每個輸入除以其標準差。<br/>
<b>zca_epsilon</b> : ZCA 白化的epsilon 值默認為1e-6。<br/>
<b>zca_whitening</b> : 布爾值。是否應用ZCA 白化。<br/>
<b>rotation_range</b> : 整數。隨機旋轉的度數範圍。<br/>
<b>width_shift_range</b> : 浮點數、一維數組或整數<br/>
float: 如果<1則是除以總寬度的值或者如果>=1則為像素值。
1-D 數組: 數組中的隨機元素。
int: 來自間隔 (-width_shift_range, +width_shift_range) 之間的整數個像素。
width_shift_range=2時可能值是整數[-1, 0, +1],與 width_shift_range=[-1, 0, +1] 相同;而 width_shift_range=1.0 時,可能值是 [-1.0, +1.0) 之間的浮點數。
<b>height_shift_range</b> : 浮點數、一維數組或整數<br/>
float: 如果<1則是除以總寬度的值或者如果>=1則為像素值。
1-D array-like: 數組中的隨機元素。
int: 來自間隔 (-height_shift_range, +height_shift_range) 之間的整數個像素。
height_shift_range=2時可能值是整數[-1, 0, +1],與 height_shift_range=[-1, 0, +1] 相同;而 height_shift_range=1.0 時,可能值是 [-1.0, +1.0) 之間的浮點數。
<b>shear_range</b> : 浮點數。剪切強度(以弧度逆時針方向剪切角度)。<br/>
<b>zoom_range</b> : 浮點數或[lower, upper]。隨機縮放範圍。如果是浮點數,[lower, upper] = [1-zoom_range, 1+zoom_range]。<br/>
<b>channel_shift_range</b> : 浮點數。隨機通道轉換的範圍。<br/>
<b>fill_mode</b> : {"constant", "nearest", "reflect" or "wrap"} 之一。默認為'nearest'。輸入邊界以外的點根據給定的模式填充:<br/>
'constant': kkkkkkkk|abcd|kkkkkkkk (cval=k)
'nearest': aaaaaaaa|abcd|dddddddd
'reflect': abcddcba|abcd|dcbaabcd
'wrap': abcdabcd|abcd|abcdabcd
<b>cval</b> : 浮點數或整數。用於邊界之外的點的值,當 fill_mode = "constant" 時。<br/>
<b>horizontal_flip</b> : 布爾值。隨機水平翻轉。<br/>
<b>vertical_flip</b> : 布爾值。隨機垂直翻轉。<br/>
<b>rescale</b> : 重縮放因子。默認為None。如果是None 或0不進行縮放否則將數據乘以所提供的值在應用任何其他轉換之前。<br/>
<b>preprocessing_function</b> : 應用於每個輸入的函數。這個函數會在任何其他改變之前運行。這個函數需要一個參數一張圖像秩為3 的Numpy 張量並且應該輸出一個同尺寸的Numpy 張量。<br/>
<b>data_format</b> : 圖像數據格式,{"channels_first", "channels_last"} 之一。"channels_last" 模式表示圖像輸入尺寸應該為(samples, height, width, channels)"channels_first" 模式表示輸入尺寸應該為(samples, channels, height, width)。默認為在Keras 配置文件 ~/.keras/keras.json 中的 image_data_format 值。如果你從未設置它,那它就是"channels_last"。<br/>
<b>validation_split</b> : 浮點數。Float. 保留用於驗證的圖像的比例嚴格在0和1之間。<br/>
<b>dtype</b> : 生成數組使用的數據類型。<br/>
'''
if judge == 1:
return transforms.Compose([
transforms.RandomRotation(30),
transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
])
elif judge == 2:
return transforms.Compose([
transforms.RandomRotation(180),
transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
])
elif judge == 3:
return transforms.Compose([
transforms.RandomRotation(45),
transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
transforms.RandomAffine(degrees=20, shear=0.2),
transforms.ColorJitter(brightness=0.2, contrast=0.2),
transforms.RandomHorizontalFlip(),
])
elif judge == 4:
return transforms.Compose([
transforms.RandomRotation(50),
transforms.RandomResizedCrop(224, scale=(0.75, 1.0)),
transforms.RandomAffine(degrees=30, shear=0.25),
transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.2),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
])