166 lines
9.7 KiB
Python
166 lines
9.7 KiB
Python
from model_data_processing.processing import make_label_list
|
||
from _validation.ValidationTheEnterData import validation_the_enter_data
|
||
from Load_process.file_processing import Process_File
|
||
from Load_process.LoadData import Load_Data_Prepare, Load_Data_Tools
|
||
from Training_Tools.PreProcess import Training_Precesses
|
||
|
||
from torchvision import transforms
|
||
|
||
class Image_generator():
|
||
'''製作資料強化'''
|
||
def __init__(self, Training_Root, Generator_Root, Labels, Image_Size, Class_Count) -> None:
|
||
self._validation = validation_the_enter_data()
|
||
self.stop = 0
|
||
self.Labels = Labels
|
||
self.Training_Root = Training_Root
|
||
self.Generator_Root = Generator_Root
|
||
self.Image_Size = Image_Size
|
||
self.Class_Count = Class_Count
|
||
pass
|
||
|
||
def Processing_Main(self):
|
||
data_size = 2712
|
||
File = Process_File()
|
||
Prepare = Load_Data_Prepare()
|
||
Load_Tool = Load_Data_Tools()
|
||
|
||
if not File.Judge_File_Exist(self.Generator_Root): # 檔案若不存在
|
||
# 確定我要多少個List
|
||
Prepare.Set_Data_Content([], len(self.Labels))
|
||
|
||
# 製作讀檔字典並回傳檔案路徑
|
||
Prepare.Set_Label_List(self.Labels)
|
||
Prepare.Set_Data_Dictionary(Prepare.Get_Label_List(), Prepare.Get_Data_Content(), len(self.Labels))
|
||
Original_Dict_Data_Root = Prepare.Get_Data_Dict()
|
||
get_all_original_image_data = Load_Tool.get_data_root(self.Training_Root, Original_Dict_Data_Root, Prepare.Get_Label_List())
|
||
|
||
# 儲存資料強化後資料
|
||
# 製作標準資料增強
|
||
'''
|
||
這裡我想要做的是依照paper上的資料強化IMAGE DATA COLLECTION AND IMPLEMENTATION OF DEEP LEARNING-BASED MODEL IN DETECTING MONKEYPOX DISEASE USING MODIFIED VGG16
|
||
產生出資料強化後的影像
|
||
'''
|
||
for i in range(1, 5, 1):
|
||
print(f"\nAugmentation {i} Generator image")
|
||
data_size = self.get_processing_Augmentation(get_all_original_image_data, i, data_size)
|
||
self.stop += data_size
|
||
else: # 若檔案存在
|
||
print("standard data and myself data are exist\n")
|
||
|
||
def get_processing_Augmentation(self, original_image_root : dict, Augment_choose, data_size):
|
||
Prepaer = Load_Data_Prepare()
|
||
|
||
self.get_data_roots = original_image_root # 要處理的影像路徑
|
||
Prepaer.Set_Label_List(self.Labels)
|
||
data_size = self.Generator_main(self.Generator_Root, Augment_choose, data_size) # 執行
|
||
return data_size
|
||
|
||
def Generator_main(self, save_roots, stardand, data_size):
|
||
'''
|
||
Parameter:
|
||
labels = 取得資料的標籤
|
||
save_root = 要儲存資料的地方
|
||
strardand = 要使用哪種Image Augmentation
|
||
'''
|
||
File = Process_File()
|
||
tool = Training_Precesses(self.Image_Size)
|
||
Classes = []
|
||
Transform = self.Generator_Content(stardand)
|
||
|
||
for label in self.Labels: # 分別對所有類別進行資料強化
|
||
Image_Roots = self.get_data_roots[label]
|
||
save_root = File.Make_Save_Root(label, save_roots) # 合併路徑
|
||
|
||
Classes = make_label_list(len(Image_Roots), "1")
|
||
Training_Dataset = tool.Setting_DataSet(Image_Roots, Classes, "Generator")
|
||
Training_DataLoader = tool.Dataloader_Sampler(Training_Dataset, 1, False)
|
||
|
||
if File.JudgeRoot_MakeDir(save_root): # 判斷要存的資料夾存不存在,不存在則創立
|
||
print("The file is exist.This Script is not creating new fold.")
|
||
|
||
for i in range(1, int(self.Class_Count / len(Image_Roots)) + 1, 1):
|
||
for batch_idx, (images, labels, File_Name, File_Classes) in enumerate(Training_DataLoader):
|
||
for j, img in enumerate(images):
|
||
# if i == self.stop:
|
||
# break
|
||
|
||
img = img.permute(2, 0, 1)
|
||
img = Transform(img)
|
||
|
||
# 轉換為 NumPy 陣列並從 BGR 轉為 RGB
|
||
img_np = img.numpy().transpose(1, 2, 0) # 轉回 HWC 格式
|
||
|
||
img_pil = transforms.ToPILImage()(img_np)
|
||
File.Save_PIL_File("image_" + label + str(data_size) + ".png", save_root, img_pil) # 存檔
|
||
data_size += 1
|
||
|
||
return data_size
|
||
|
||
def Generator_Content(self, judge): # 影像資料增強
|
||
'''
|
||
## Parameters:
|
||
<b>featurewise_center</b> : 布爾值。將輸入數據的均值設置為0,逐特徵進行。<br/>
|
||
<b>samplewise_center</b> : 布爾值。將每個樣本的均值設置為0。<br/>
|
||
<b>featurewise_std_normalization</b> : Boolean. 布爾值。將輸入除以數據標準差,逐特徵進行。<br/>
|
||
<b>samplewise_std_normalization</b> : 布爾值。將每個輸入除以其標準差。<br/>
|
||
<b>zca_epsilon</b> : ZCA 白化的epsilon 值,默認為1e-6。<br/>
|
||
<b>zca_whitening</b> : 布爾值。是否應用ZCA 白化。<br/>
|
||
<b>rotation_range</b> : 整數。隨機旋轉的度數範圍。<br/>
|
||
<b>width_shift_range</b> : 浮點數、一維數組或整數<br/>
|
||
float: 如果<1,則是除以總寬度的值,或者如果>=1,則為像素值。
|
||
1-D 數組: 數組中的隨機元素。
|
||
int: 來自間隔 (-width_shift_range, +width_shift_range) 之間的整數個像素。
|
||
width_shift_range=2時,可能值是整數[-1, 0, +1],與 width_shift_range=[-1, 0, +1] 相同;而 width_shift_range=1.0 時,可能值是 [-1.0, +1.0) 之間的浮點數。
|
||
<b>height_shift_range</b> : 浮點數、一維數組或整數<br/>
|
||
float: 如果<1,則是除以總寬度的值,或者如果>=1,則為像素值。
|
||
1-D array-like: 數組中的隨機元素。
|
||
int: 來自間隔 (-height_shift_range, +height_shift_range) 之間的整數個像素。
|
||
height_shift_range=2時,可能值是整數[-1, 0, +1],與 height_shift_range=[-1, 0, +1] 相同;而 height_shift_range=1.0 時,可能值是 [-1.0, +1.0) 之間的浮點數。
|
||
<b>shear_range</b> : 浮點數。剪切強度(以弧度逆時針方向剪切角度)。<br/>
|
||
<b>zoom_range</b> : 浮點數或[lower, upper]。隨機縮放範圍。如果是浮點數,[lower, upper] = [1-zoom_range, 1+zoom_range]。<br/>
|
||
<b>channel_shift_range</b> : 浮點數。隨機通道轉換的範圍。<br/>
|
||
<b>fill_mode</b> : {"constant", "nearest", "reflect" or "wrap"} 之一。默認為'nearest'。輸入邊界以外的點根據給定的模式填充:<br/>
|
||
'constant': kkkkkkkk|abcd|kkkkkkkk (cval=k)
|
||
'nearest': aaaaaaaa|abcd|dddddddd
|
||
'reflect': abcddcba|abcd|dcbaabcd
|
||
'wrap': abcdabcd|abcd|abcdabcd
|
||
<b>cval</b> : 浮點數或整數。用於邊界之外的點的值,當 fill_mode = "constant" 時。<br/>
|
||
<b>horizontal_flip</b> : 布爾值。隨機水平翻轉。<br/>
|
||
<b>vertical_flip</b> : 布爾值。隨機垂直翻轉。<br/>
|
||
<b>rescale</b> : 重縮放因子。默認為None。如果是None 或0,不進行縮放,否則將數據乘以所提供的值(在應用任何其他轉換之前)。<br/>
|
||
<b>preprocessing_function</b> : 應用於每個輸入的函數。這個函數會在任何其他改變之前運行。這個函數需要一個參數:一張圖像(秩為3 的Numpy 張量),並且應該輸出一個同尺寸的Numpy 張量。<br/>
|
||
<b>data_format</b> : 圖像數據格式,{"channels_first", "channels_last"} 之一。"channels_last" 模式表示圖像輸入尺寸應該為(samples, height, width, channels),"channels_first" 模式表示輸入尺寸應該為(samples, channels, height, width)。默認為在Keras 配置文件 ~/.keras/keras.json 中的 image_data_format 值。如果你從未設置它,那它就是"channels_last"。<br/>
|
||
<b>validation_split</b> : 浮點數。Float. 保留用於驗證的圖像的比例(嚴格在0和1之間)。<br/>
|
||
<b>dtype</b> : 生成數組使用的數據類型。<br/>
|
||
'''
|
||
if judge == 1:
|
||
return transforms.Compose([
|
||
transforms.RandomRotation(30),
|
||
transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
|
||
transforms.RandomHorizontalFlip(),
|
||
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
|
||
])
|
||
elif judge == 2:
|
||
return transforms.Compose([
|
||
transforms.RandomRotation(180),
|
||
transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
|
||
transforms.RandomHorizontalFlip(),
|
||
transforms.RandomVerticalFlip(),
|
||
])
|
||
elif judge == 3:
|
||
return transforms.Compose([
|
||
transforms.RandomRotation(45),
|
||
transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
|
||
transforms.RandomAffine(degrees=20, shear=0.2),
|
||
transforms.ColorJitter(brightness=0.2, contrast=0.2),
|
||
transforms.RandomHorizontalFlip(),
|
||
])
|
||
elif judge == 4:
|
||
return transforms.Compose([
|
||
transforms.RandomRotation(50),
|
||
transforms.RandomResizedCrop(224, scale=(0.75, 1.0)),
|
||
transforms.RandomAffine(degrees=30, shear=0.25),
|
||
transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.2),
|
||
transforms.RandomHorizontalFlip(),
|
||
transforms.RandomVerticalFlip(),
|
||
]) |