diff --git a/draw_tools/Grad_cam.py b/draw_tools/Grad_cam.py
index fc57df0..f22288f 100644
--- a/draw_tools/Grad_cam.py
+++ b/draw_tools/Grad_cam.py
@@ -11,13 +11,15 @@ from Load_process.file_processing import Process_File
 class GradCAM:
     def __init__(self, model, target_layer):
         self.model = model
-        self.target_layer = target_layer
+        # If the model is wrapped in DataParallel, unwrap the real backbone
+        self.backbone = model.module if isinstance(model, nn.DataParallel) else model
+        self.target_layer = self._resolve_target_layer(target_layer)
         self.activations = None
         self.gradients = None
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.model.to(self.device)  # Ensure model is on the correct device

-        # Register hooks
+        # Register hooks on the resolved module
         self.target_layer.register_forward_hook(self.save_activations)
         # Use full backward hook if available to avoid deprecation issues
         if hasattr(self.target_layer, "register_full_backward_hook"):
@@ -25,6 +27,34 @@ class GradCAM:
         else:
             self.target_layer.register_backward_hook(self.save_gradients)

+    def _resolve_target_layer(self, target):
+        # Supports nn.Module / nn.Parameter / string paths
+        if isinstance(target, nn.Module):
+            return target
+        if isinstance(target, torch.nn.Parameter):
+            # First find the Parameter's name among the backbone's parameters
+            for name, param in self.backbone.named_parameters():
+                if param is target:
+                    # Strip the trailing .weight / .bias to get the parent module's name
+                    module_name = name.rsplit('.', 1)[0]
+                    # Try a fast match through named_modules first
+                    for mod_name, mod in self.backbone.named_modules():
+                        if mod_name == module_name:
+                            return mod
+                    # Fall back to attribute traversal
+                    obj = self.backbone
+                    for attr in module_name.split('.'):
+                        obj = getattr(obj, attr)
+                    return obj
+            raise AttributeError("Target parameter not found in model parameters.")
+        if isinstance(target, str):
+            # Allow the layer to be given as a string path, e.g. 'conv4.pointwise'
+            obj = self.backbone
+            for attr in target.split('.'):
+                obj = getattr(obj, attr)
+            return obj
+        raise TypeError("target_layer must be nn.Module, nn.Parameter, or str")
+
     def Processing_Main(self, Test_Dataloader, File_Path):
         File = Process_File()
         for batch_idx, (images, labels, File_Name, File_Classes) in enumerate(Test_Dataloader):
diff --git a/draw_tools/__pycache__/Grad_cam.cpython-313.pyc b/draw_tools/__pycache__/Grad_cam.cpython-313.pyc
index f7bb8dc..0e0ee4e 100644
Binary files a/draw_tools/__pycache__/Grad_cam.cpython-313.pyc and b/draw_tools/__pycache__/Grad_cam.cpython-313.pyc differ
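A note on the new resolver contract above: `_resolve_target_layer` accepts a module, a `Parameter`, or a dotted path. The following is a minimal, self-contained sketch of the same lookup logic, not repo code; `TinyNet` is a hypothetical stand-in model:

```python
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv4 = nn.Sequential()
        self.conv4.pointwise = nn.Conv2d(8, 16, 1, bias=False)

net = TinyNet()

# String path: attribute walk, exactly what the resolver's string branch does.
obj = net
for attr in 'conv4.pointwise'.split('.'):
    obj = getattr(obj, attr)
assert obj is net.conv4.pointwise

# Parameter: map the tensor back to its parent module by identity,
# as the resolver's nn.Parameter branch does.
target = net.conv4.pointwise.weight
for name, param in net.named_parameters():
    if param is target:
        module_name = name.rsplit('.', 1)[0]           # 'conv4.pointwise'
        parent = dict(net.named_modules())[module_name]
        assert parent is net.conv4.pointwise
```

This is what lets the training script register hooks on a proper module even when it only stores `conv4.pointwise.weight` as the target.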
diff --git a/experiments/Models/Xception_Model_Modification.py b/experiments/Models/Xception_Model_Modification.py
index af15832..9873b03 100644
--- a/experiments/Models/Xception_Model_Modification.py
+++ b/experiments/Models/Xception_Model_Modification.py
@@ -1,152 +1,379 @@
+"""
+Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch)
+
+@author: tstandley
+Adapted by cadene
+
+Creates an Xception Model as defined in:
+
+Francois Chollet
+Xception: Deep Learning with Depthwise Separable Convolutions
+https://arxiv.org/pdf/1610.02357.pdf
+
+These weights were ported from the Keras implementation.
+Achieves the following performance on the validation set:
+
+Loss: 0.9173  Prec@1: 78.892  Prec@5: 94.292
+
+REMEMBER to set your image size to 3x299x299 for both test and validation
+
+normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
+                                 std=[0.5, 0.5, 0.5])
+
+The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
+"""
 import torch.nn as nn
 import torch.nn.functional as F
-import torch
-
 from utils.Stomach_Config import Model_Config
+from einops import rearrange
+
+from timm.layers import create_classifier
+

 class SeparableConv2d(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True):
-        super(SeparableConv2d, self).__init__()
-        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride,
-                                   padding=padding, groups=in_channels, bias=bias)
-        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1,
-                                   padding=0, bias=bias)
-
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 1,
+        stride: int = 1,
+        padding: int = 0,
+        dilation: int = 1,
+        device=None,
+        dtype=None,
+    ):
+        dd = {'device': device, 'dtype': dtype}
+        super().__init__()
+
+        self.conv1 = nn.Conv2d(
+            in_channels,
+            in_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            groups=in_channels,
+            bias=False,
+            **dd,
+        )
+        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=False, **dd)
+
     def forward(self, x):
-        x = self.depthwise(x)
+        x = self.conv1(x)
         x = self.pointwise(x)
         return x
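For context on the `SeparableConv2d` rewrite above (renaming `depthwise` to the timm-style `conv1` and dropping the bias): the point of factoring a convolution into a depthwise and a pointwise step is the parameter saving. A standalone sketch, not repo code:

```python
# Parameter count of a 3x3 separable conv vs. a full 3x3 conv at 128->256 channels.
import torch
import torch.nn as nn

depthwise = nn.Conv2d(128, 128, 3, padding=1, groups=128, bias=False)  # 128*1*3*3
pointwise = nn.Conv2d(128, 256, 1, bias=False)                         # 256*128*1*1
full = nn.Conv2d(128, 256, 3, padding=1, bias=False)                   # 256*128*3*3

def count(m):
    return sum(p.numel() for p in m.parameters())

print(count(depthwise) + count(pointwise))  # 1152 + 32768 = 33920
print(count(full))                          # 294912, roughly 8.7x more

x = torch.randn(1, 128, 32, 32)
assert pointwise(depthwise(x)).shape == full(x).shape  # same output shape
```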
-class EntryFlow(nn.Module):
-    def __init__(self, in_channels=3):
-        super(EntryFlow, self).__init__()
-        self.conv1 = nn.Conv2d(in_channels, 32, 3, stride=2, padding=1, bias=False, dilation = 2)
-        self.bn1 = nn.BatchNorm2d(32)
-        self.conv2 = nn.Conv2d(32, 64, 3, padding=1, bias=False, dilation = 2)
-        self.bn2 = nn.BatchNorm2d(64)
-
-        self.conv3_residual = nn.Sequential(
-            SeparableConv2d(64, 128, 3, padding=1),
-            nn.BatchNorm2d(128),
-            nn.ReLU(inplace=False),  # modified here
-            SeparableConv2d(128, 128, 3, padding=1),
-            nn.BatchNorm2d(128),
-            nn.MaxPool2d(3, stride=2, padding=1)
-        )
-        self.conv3_shortcut = nn.Conv2d(64, 128, 1, stride=2, bias=False)
-        self.bn3 = nn.BatchNorm2d(128)
-
-        self.conv4_residual = nn.Sequential(
-            nn.ReLU(inplace=False),  # modified here
-            SeparableConv2d(128, 256, 3, padding=1),
-            nn.BatchNorm2d(256),
-            nn.ReLU(inplace=False),  # modified here
-            SeparableConv2d(256, 256, 3, padding=1),
-            nn.BatchNorm2d(256),
-            nn.MaxPool2d(3, stride=2, padding=1)
-        )
-        self.conv4_shortcut = nn.Conv2d(128, 256, 1, stride=2, bias=False)
-        self.bn4 = nn.BatchNorm2d(256)
-
-        self.conv5_residual = nn.Sequential(
-            nn.ReLU(inplace=False),  # modified here
-            SeparableConv2d(256, 728, 3, padding=1),
-            nn.BatchNorm2d(728),
-            nn.ReLU(inplace=False),  # modified here
-            SeparableConv2d(728, 728, 3, padding=1),
-            nn.BatchNorm2d(728),
-            nn.MaxPool2d(3, stride=2, padding=1)
-        )
-        self.conv5_shortcut = nn.Conv2d(256, 728, 1, stride=2, bias=False)
-        self.bn5 = nn.BatchNorm2d(728)
-
-    def forward(self, x):
-        x = F.relu(self.bn1(self.conv1(x)))
-        x = F.relu(self.bn2(self.conv2(x)))
-
-        residual = self.conv3_residual(x)
-        shortcut = self.conv3_shortcut(x)
-        x = F.relu(self.bn3(residual + shortcut))
-
-        residual = self.conv4_residual(x)
-        shortcut = self.conv4_shortcut(x)
-        x = F.relu(self.bn4(residual + shortcut))
-
-        residual = self.conv5_residual(x)
-        shortcut = self.conv5_shortcut(x)
-        x = F.relu(self.bn5(residual + shortcut))
+
+class Block(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        reps: int,
+        strides: int = 1,
+        start_with_relu: bool = True,
+        grow_first: bool = True,
+        device=None,
+        dtype=None,
+    ):
+        dd = {'device': device, 'dtype': dtype}
+        super().__init__()
+
+        if out_channels != in_channels or strides != 1:
+            self.skip = nn.Conv2d(in_channels, out_channels, 1, stride=strides, bias=False, **dd)
+            self.skipbn = nn.BatchNorm2d(out_channels, **dd)
+        else:
+            self.skip = None
+
+        rep = []
+        for i in range(reps):
+            if grow_first:
+                inc = in_channels if i == 0 else out_channels
+                outc = out_channels
+            else:
+                inc = in_channels
+                outc = in_channels if i < (reps - 1) else out_channels
+            rep.append(nn.ReLU(inplace=True))
+            rep.append(SeparableConv2d(inc, outc, 3, stride=1, padding=1, **dd))
+            rep.append(nn.BatchNorm2d(outc, **dd))
+
+        if not start_with_relu:
+            rep = rep[1:]
+        else:
+            rep[0] = nn.ReLU(inplace=False)
+
+        if strides != 1:
+            rep.append(nn.MaxPool2d(3, strides, 1))
+        self.rep = nn.Sequential(*rep)
+
+    def forward(self, inp):
+        x = self.rep(inp)
+
+        if self.skip is not None:
+            skip = self.skip(inp)
+            skip = self.skipbn(skip)
+        else:
+            skip = inp
+
+        x += skip
         return x
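The `(inc, outc)` bookkeeping in `Block` is easy to misread. Here is a self-contained sketch, plain Python mirroring the loop above and not repo code, of the channel plan for the two configurations this model uses:

```python
# Channel plan produced by the Block rep-loop, per repetition.
def channel_plan(in_channels, out_channels, reps, grow_first=True):
    plan = []
    for i in range(reps):
        if grow_first:
            inc = in_channels if i == 0 else out_channels
            outc = out_channels
        else:
            inc = in_channels
            outc = in_channels if i < (reps - 1) else out_channels
        plan.append((inc, outc))
    return plan

print(channel_plan(256, 728, 2))                     # [(256, 728), (728, 728)]
print(channel_plan(728, 1024, 2, grow_first=False))  # [(728, 728), (728, 1024)]
```

`grow_first=True` widens on the first rep (as in block3), while `grow_first=False` defers the widening to the last rep (as in block12).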
-class MiddleFlow(nn.Module):
-    def __init__(self):
-        super(MiddleFlow, self).__init__()
-        self.conv_residual = nn.Sequential(
-            nn.ReLU(inplace=False),  # modified here
-            SeparableConv2d(728, 728, 3, padding=1),
-            nn.BatchNorm2d(728),
-            nn.ReLU(inplace=False),  # modified here
-            SeparableConv2d(728, 728, 3, padding=1),
-            nn.BatchNorm2d(728),
-            nn.ReLU(inplace=False),  # modified here
-            SeparableConv2d(728, 728, 3, padding=1),
-            nn.BatchNorm2d(728)
-        )
-
-    def forward(self, x):
-        return self.conv_residual(x) + x
-
-class ExitFlow(nn.Module):
-    def __init__(self):
-        super(ExitFlow, self).__init__()
-        self.conv1_residual = nn.Sequential(
-            nn.ReLU(inplace=False),  # modified here
-            SeparableConv2d(728, 1024, 3, padding=1),
-            nn.BatchNorm2d(1024),
-            nn.ReLU(inplace=False),  # modified here
-            SeparableConv2d(1024, 1024, 3, padding=1),
-            nn.BatchNorm2d(1024),
-            nn.MaxPool2d(3, stride=2, padding=1)
-        )
-        self.conv1_shortcut = nn.Conv2d(728, 1024, 1, stride=2, bias=False)
-        self.bn1 = nn.BatchNorm2d(1024)
-
-        self.conv2 = nn.Sequential(
-            SeparableConv2d(1024, 1536, 3, padding=1),
-            nn.BatchNorm2d(1536),
-            nn.ReLU(inplace=False),  # modified here
-            SeparableConv2d(1536, 2048, 3, padding=1),
-            nn.BatchNorm2d(2048),
-            nn.ReLU(inplace=False)  # modified here
-        )
-        self.avgpool = nn.AdaptiveAvgPool1d(Model_Config["GPA Output Nodes"])
-        self.Hidden = nn.Linear(Model_Config["GPA Output Nodes"], Model_Config["Linear Hidden Nodes"])
-        self.fc = nn.Linear(Model_Config["Linear Hidden Nodes"], Model_Config["Output Linear Nodes"])
-        self.dropout = nn.Dropout(Model_Config["Dropout Rate"])
-
-    def forward(self, x):
-        residual = self.conv1_residual(x)
-        shortcut = self.conv1_shortcut(x)
-        x = F.relu(self.bn1(residual + shortcut))
-
-        x = self.conv2(x)
-        x = x.view(x.size(0), -1)
-        x = self.avgpool(x)
-        x = F.relu(self.Hidden(x))
-        x = self.dropout(x)
-        x = self.fc(x)
-        return x

 class Xception(nn.Module):
-    def __init__(self):
-        super(Xception, self).__init__()
-        self.entry_flow = EntryFlow(in_channels=3)  # default input channels: 3
-        self.middle_flow = nn.Sequential(*[MiddleFlow() for _ in range(8)])
-        self.exit_flow = ExitFlow()
-
-    def forward(self, x):
-        # standard forward pass
-        x = self.entry_flow(x)
-        x = self.middle_flow(x)
-        x = self.exit_flow(x)
+    """
+    Xception optimized for the ImageNet dataset, as specified in
+    https://arxiv.org/pdf/1610.02357.pdf
+    """
+
+    def __init__(
+        self,
+        num_classes: int = 1000,
+        in_chans: int = 3,
+        drop_rate: float = 0.,
+        global_pool: str = 'avg',
+        device=None,
+        dtype=None,
+    ):
+        """ Constructor
+        Args:
+            num_classes: number of classes
+        """
+        super().__init__()
+        dd = {'device': device, 'dtype': dtype}
+        self.drop_rate = drop_rate
+        self.global_pool = global_pool
+        self.num_classes = num_classes
+        self.num_features = self.head_hidden_size = 2048
+
+        self.conv1 = nn.Conv2d(in_chans, 32, 3, 2, 0, bias=False, **dd)
+        self.bn1 = nn.BatchNorm2d(32, **dd)
+        self.act1 = nn.ReLU(inplace=True)
+
+        self.conv2 = nn.Conv2d(32, 64, 3, bias=False, **dd)
+        self.bn2 = nn.BatchNorm2d(64, **dd)
+        self.act2 = nn.ReLU(inplace=True)
+
+        self.block1 = Block(64, 128, 2, 2, start_with_relu=False, **dd)
+        self.block2 = Block(128, 256, 2, 2, **dd)
+        self.block3 = Block(256, 728, 2, 2, **dd)
+
+        self.block4 = Block(728, 728, 3, 1, **dd)
+        self.block5 = Block(728, 728, 3, 1, **dd)
+        self.block6 = Block(728, 728, 3, 1, **dd)
+        self.block7 = Block(728, 728, 3, 1, **dd)
+
+        self.block8 = Block(728, 728, 3, 1, **dd)
+        self.block9 = Block(728, 728, 3, 1, **dd)
+        self.block10 = Block(728, 728, 3, 1, **dd)
+        self.block11 = Block(728, 728, 3, 1, **dd)
+
+        self.block12 = Block(728, 1024, 2, 2, grow_first=False, **dd)
+
+        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1, **dd)
+        self.bn3 = nn.BatchNorm2d(1536, **dd)
+        self.act3 = nn.ReLU(inplace=True)
+
+        self.conv4 = SeparableConv2d(1536, self.num_features, 3, 1, 1, **dd)
+        self.bn4 = nn.BatchNorm2d(self.num_features, **dd)
+        self.act4 = nn.ReLU(inplace=True)
+        self.feature_info = [
+            dict(num_chs=64, reduction=2, module='act2'),
+            dict(num_chs=128, reduction=4, module='block2.rep.0'),
+            dict(num_chs=256, reduction=8, module='block3.rep.0'),
+            dict(num_chs=728, reduction=16, module='block12.rep.0'),
+            dict(num_chs=2048, reduction=32, module='act4'),
+        ]
+
+        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool, **dd)
+        self.hidden_layer = nn.Linear(2048, Model_Config["Linear Hidden Nodes"])  # hidden layer; input size depends on the Xception feature size
+        self.output_layer = nn.Linear(Model_Config["Linear Hidden Nodes"], Model_Config["Output Linear Nodes"])  # output layer, sized by the number of classes
-        return x
\ No newline at end of file
+        # Activation and dropout
+        self.relu = nn.ReLU()
+        self.dropout = nn.Dropout(Model_Config["Dropout Rate"])
+
+        # #------- init weights --------
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
+        self.num_classes = num_classes
+        self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
+
+    def forward_features(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.act1(x)
+
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.act2(x)
+
+        x = self.block1(x)
+        x = self.block2(x)
+        x = self.block3(x)
+        x = self.block4(x)
+        x = self.block5(x)
+        x = self.block6(x)
+        x = self.block7(x)
+        x = self.block8(x)
+        x = self.block9(x)
+        x = self.block10(x)
+        x = self.block11(x)
+        x = self.block12(x)
+
+        x = self.conv3(x)
+        x = self.bn3(x)
+        x = self.act3(x)
+
+        x = self.conv4(x)
+        x = self.bn4(x)
+        x = self.act4(x)
+        return x
+
+    def forward_head(self, x):
+        x = self.global_pool(x)
+        return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+
+        x = self.dropout(x)  # Dropout
+        x = self.hidden_layer(x)
+        x = self.relu(x)  # hidden layer + ReLU
+        x = self.output_layer(x)  # output layer
+        return x
+
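On the new head wiring: with `num_classes=0` (which is how the training script constructs this model), timm's `create_classifier` is expected to return an `Identity` for `fc`, so `forward_head` yields the pooled 2048-d vector that feeds `hidden_layer`. A hedged, standalone sketch assuming that timm behaviour, not repo code:

```python
import torch
from timm.layers import create_classifier

# With num_classes=0 the classifier should be a no-op Identity.
global_pool, fc = create_classifier(2048, 0, pool_type='avg')
print(type(fc).__name__)             # expected: Identity

feat = torch.randn(2, 2048, 10, 10)  # forward_features output for a 299x299 input
pooled = global_pool(feat)
print(pooled.shape)                  # expected: torch.Size([2, 2048])
out = fc(pooled)                     # Identity, so still the pooled features
```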
+# class Residual(nn.Module):
+#     def __init__(self, fn):
+#         super().__init__()
+#         self.fn = fn
+
+#     def forward(self, x, **kwargs):
+#         return self.fn(x, **kwargs) + x
+
+# class PreNorm(nn.Module):
+#     def __init__(self, dim, fn):
+#         super().__init__()
+#         self.norm = nn.LayerNorm(dim)
+#         self.fn = fn
+
+#     def forward(self, x, **kwargs):
+#         return self.fn(self.norm(x), **kwargs)
+
+# class FeedForward(nn.Module):
+#     def __init__(self, dim, hidden_dim, dropout=0.0):
+#         super().__init__()
+#         self.net = nn.Sequential(
+#             nn.Linear(dim, hidden_dim),
+#             nn.GELU(),
+#             nn.Dropout(dropout),
+#             nn.Linear(hidden_dim, dim),
+#             nn.Dropout(dropout)
+#         )
+
+#     def forward(self, x):
+#         return self.net(x)
+
+# class Attention(nn.Module):
+#     def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
+#         super().__init__()
+#         inner_dim = dim_head * heads
+#         self.heads = heads
+#         self.scale = dim_head ** -0.5
+
+#         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
+#         self.to_out = nn.Sequential(
+#             nn.Linear(inner_dim, dim),
+#             nn.Dropout(dropout)
+#         )
+
+#     def forward(self, x, mask=None):
+#         b, n, _, h = *x.shape, self.heads
+#         qkv = self.to_qkv(x).chunk(3, dim=-1)
+#         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
+
+#         dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
+
+#         if mask is not None:
+#             mask = F.pad(mask.flatten(1), (1, 0), value=True)
+#             assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
+#             mask = mask[:, None, :] * mask[:, :, None]
+#             dots.masked_fill_(~mask, float('-inf'))
+#             del mask
+
+#         attn = dots.softmax(dim=-1)
+
+#         out = torch.einsum('bhij,bhjd->bhid', attn, v)
+#         out = rearrange(out, 'b h n d -> b n (h d)')
+#         out = self.to_out(out)
+#         return out
+
+# class Transformer(nn.Module):
+#     def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
+#         super().__init__()
+#         self.layers = nn.ModuleList([])
+#         for _ in range(depth):
+#             self.layers.append(nn.ModuleList([
+#                 Residual(PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout))),
+#                 Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)))
+#             ]))
+
+#     def forward(self, x, mask=None):
+#         for attn, ff in self.layers:
+#             x = attn(x, mask=mask)
+#             x = ff(x)
+#         return x
+
+# class ViT(nn.Module):
+#     def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0.):
+#         super().__init__()
+#         image_height, image_width = image_size if isinstance(image_size, tuple) else (image_size, image_size)
+#         patch_height, patch_width = patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
+
+#         assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
+
+#         num_patches = (image_height // patch_height) * (image_width // patch_width)
+#         patch_dim = channels * patch_height * patch_width
+#         assert pool in {'cls', 'mean'}, 'pool type must be either cls (class token) or mean (mean pooling)'
+
+#         self.to_patch_embedding = nn.Sequential(
+#             nn.Conv2d(channels, dim, kernel_size=(patch_height, patch_width), stride=(patch_height, patch_width)),
+#             nn.Flatten(2),
+#             nn.LayerNorm(dim),
+#         )
+
+#         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
+#         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
+#         self.dropout = nn.Dropout(emb_dropout)
+
+#         self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
+
+#         self.pool = pool
+#         self.to_latent = nn.Identity()
+
+#         self.mlp_head = nn.Sequential(
+#             nn.LayerNorm(dim),
+#             nn.Linear(dim, num_classes)
+#         )
+
+#     def forward(self, img, mask=None):
+#         x = self.to_patch_embedding(img)
+#         x = x.permute(0, 2, 1)
+#         b, n, _ = x.shape
+
+#         cls_tokens = self.cls_token.expand(b, -1, -1)
+#         x = torch.cat((cls_tokens, x), dim=1)
+#         x += self.pos_embedding[:, :(n + 1)]
+#         x = self.dropout(x)
+
+#         x = self.transformer(x, mask)
+
+#         x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
+
+#         x = self.to_latent(x)
+#         return self.mlp_head(x)
diff --git a/experiments/Models/__pycache__/Xception_Model_Modification.cpython-313.pyc b/experiments/Models/__pycache__/Xception_Model_Modification.cpython-313.pyc
index b2a6649..ca96391 100644
Binary files a/experiments/Models/__pycache__/Xception_Model_Modification.cpython-313.pyc and b/experiments/Models/__pycache__/Xception_Model_Modification.cpython-313.pyc differ
diff --git a/experiments/Models/__pycache__/pytorch_Model.cpython-313.pyc b/experiments/Models/__pycache__/pytorch_Model.cpython-313.pyc
index 1965f7c..6f5745f 100644
Binary files a/experiments/Models/__pycache__/pytorch_Model.cpython-313.pyc and b/experiments/Models/__pycache__/pytorch_Model.cpython-313.pyc differ
diff --git a/experiments/Models/pytorch_Model.py b/experiments/Models/pytorch_Model.py
index 4408bc7..ce32920 100644
--- a/experiments/Models/pytorch_Model.py
+++ b/experiments/Models/pytorch_Model.py
@@ -7,12 +7,13 @@ class ModifiedXception(nn.Module):
         super(ModifiedXception, self).__init__()

         # Load the pretrained Xception model and drop its final (fc) layer
-        self.base_model = timm.create_model(Model_Config["Model Name"], pretrained=True, num_classes = 3)
-        self.base_model.fc = nn.Identity()  # remove the original fully connected layer
+        self.base_model = timm.create_model(Model_Config["Model Name"], pretrained=True, num_classes = 0)
+        # self.base_model.fc = nn.Identity()  # remove the original fully connected layer

         # Add a global average pooling layer, a hidden layer, and an output layer
+        in_features = self.base_model.num_features  # fetches 2048 automatically
         self.global_avg_pool = nn.AdaptiveAvgPool1d(Model_Config["GPA Output Nodes"])  # global average pooling
-        self.hidden_layer = nn.Linear(Model_Config["GPA Output Nodes"], Model_Config["Linear Hidden Nodes"])  # hidden layer; input size depends on the Xception output size
+        self.hidden_layer = nn.Linear(in_features, Model_Config["Linear Hidden Nodes"])  # hidden layer; input size depends on the Xception output size
         self.output_layer = nn.Linear(Model_Config["Linear Hidden Nodes"], Model_Config["Output Linear Nodes"])  # output layer, sized by the number of classes

         # Activation and dropout
@@ -21,7 +22,7 @@ class ModifiedXception(nn.Module):

     def forward(self, x):
         x = self.base_model(x)  # Xception backbone
-        x = self.global_avg_pool(x)  # global average pooling
+        # x = self.global_avg_pool(x)  # global average pooling

         x = self.dropout(x)  # Dropout
         x = self.hidden_layer(x)
         x = self.relu(x)  # hidden layer + ReLU
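The `pytorch_Model.py` change follows the same idea: with `num_classes=0` the timm backbone returns pooled features directly, so the extra `AdaptiveAvgPool1d` becomes unnecessary and the hidden layer can be sized from `num_features`. A hedged sketch; the model name here is illustrative, since the repo reads it from `Model_Config["Model Name"]`:

```python
import timm
import torch

# num_classes=0 strips the classifier; forward() returns pooled features.
model = timm.create_model('xception', pretrained=False, num_classes=0)
x = torch.randn(1, 3, 299, 299)
feats = model(x)
print(model.num_features, feats.shape)  # expected: 2048 torch.Size([1, 2048])
```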
diff --git a/experiments/Training/Xception_Identification_Test.py b/experiments/Training/Xception_Identification_Test.py
index da8ac59..48b05a7 100644
--- a/experiments/Training/Xception_Identification_Test.py
+++ b/experiments/Training/Xception_Identification_Test.py
@@ -14,7 +14,7 @@ from Model_Loss.Loss import Entropy_Loss
 from merge_class.merge import merge
 from draw_tools.Saliency_Map import SaliencyMap
 from utils.Stomach_Config import Training_Config, Loading_Config, Save_Result_File_Config
-# from experiments.Models.Xception_Model_Modification import Xception
+from experiments.Models.Xception_Model_Modification import Xception
 from experiments.Models.pytorch_Model import ModifiedXception
 from Load_process.LoadData import Loding_Data_Root
 from Training_Tools.PreProcess import Training_Precesses
@@ -36,8 +36,6 @@ import os
 class Xception_Identification_Block_Training_Step(Loding_Data_Root, Training_Precesses):
     def __init__(self, Experiment_Name, Best_Model_Save_Root):
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-        self.Model = self.Construct_Identification_Model_CUDA()  # model variable
         self.train_subset = None  # subset of the training dataset
         self.val_subset = None  # subset of the validation dataset
         self.train_loader = None  # reader for the training DataLoader
@@ -51,7 +49,6 @@ class Xception_Identification_Block_Training_Step(Loding_Data_Root, Training_Pre
         self.Experiment_Name = Experiment_Name
         self.Number_Of_Classes = len(Loading_Config["Training_Labels"])
         self.Best_Model_Save_Root = Best_Model_Save_Root
-        self.Optimizer = optim.SGD(self.Model.parameters(), lr=0.045, momentum=0.9, weight_decay = Training_Config["weight_decay"])

         # Initialize the inherited objects
         Training_Precesses.__init__(self, Training_Config["Image_Size"])
@@ -77,6 +74,8 @@ class Xception_Identification_Block_Training_Step(Loding_Data_Root, Training_Pre
         # K-Fold loop
         kf = KFold(n_splits=5, shuffle=True, random_state=42)
         for fold, (train_idx, val_idx) in enumerate(kf.split(range(len(training_dataset)))):  # K-Fold cross-validation loop
+            Model = self.Construct_Identification_Model_CUDA()  # model variable
+            Optimizer = optim.SGD(Model.parameters(), lr=0.045, momentum=0.9, weight_decay = Training_Config["weight_decay"])
             print(f"\nStarting Fold {fold + 1}/5")

             # Create training and validation subsets for this fold
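Moving the model and optimizer construction inside the K-Fold loop is the right call: otherwise weights and SGD momentum leak from one fold into the next, and the folds stop being independent. A minimal sketch of the pattern, where `make_model` is a hypothetical stand-in for `Construct_Identification_Model_CUDA`:

```python
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import KFold

def make_model():
    return nn.Linear(10, 3)  # placeholder for the real Xception constructor

kf = KFold(n_splits=5, shuffle=True, random_state=42)
for fold, (train_idx, val_idx) in enumerate(kf.split(range(100))):
    model = make_model()  # fresh weights for this fold
    optimizer = optim.SGD(model.parameters(), lr=0.045, momentum=0.9)
    # ...train on train_idx, validate on val_idx...
```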
@@ -90,7 +89,7 @@ class Xception_Identification_Block_Training_Step(Loding_Data_Root, Training_Pre
             self.train_loader = Image_Enhance_Training_Data(Training_Loader = self.train_loader, Save_Root = f"{Loading_Config['Image enhance processing save root']}/{str(fold)}")

             # Model training and validation
-            model_path, Train_Losses, Validation_losses, Train_Accuracies, Validation_accuracies, best_val_loss = self.Training_And_Validation(fold)
+            model_path, Train_Losses, Validation_losses, Train_Accuracies, Validation_accuracies, best_val_loss = self.Training_And_Validation(Model, Optimizer, fold)

             # Store fold results
             all_fold_train_losses.append(Train_Losses)
@@ -113,7 +112,7 @@ class Xception_Identification_Block_Training_Step(Loding_Data_Root, Training_Pre
                 plot_history(Losses, Accuracies, f"{Save_Result_File_Config['Identification_Plot_Image']}/{self.Experiment_Name}", f"train-{str(fold)}")  # Plot the training history and hand the figure off for saving

             # Evaluation results
-            True_Label, Predict_Label, loss, accuracy, precision, recall, f1 = self.Evaluate_Model(self.Model, Test_Dataloader, fold, model_path)
+            True_Label, Predict_Label, loss, accuracy, precision, recall, f1 = self.Evaluate_Model(Model, Test_Dataloader, fold, model_path)

             # Record this run's results
             Calculate_Process.Append_numbers(loss, accuracy, precision, recall, f1)
@@ -121,7 +120,7 @@ class Xception_Identification_Block_Training_Step(Loding_Data_Root, Training_Pre
             print(self.record_everyTime_test_result(loss, accuracy, precision, recall, f1, fold, self.Experiment_Name))  # Record the post-training predictions and export them as a CSV file

             # Per-class evaluation with the identification model
-            Calculate_Tool = self.Evaluate_Per_Class_Metrics(self.Model, Test_Dataloader, Loading_Config["Training_Labels"], Calculate_Tool, model_path)
+            Calculate_Tool = self.Evaluate_Per_Class_Metrics(Model, Test_Dataloader, Loading_Config["Training_Labels"], Calculate_Tool, model_path)

             if best_val_loss < Best_Validation_Loss:
                 Best_Validation_Loss = best_val_loss
@@ -151,11 +150,11 @@ class Xception_Identification_Block_Training_Step(Loding_Data_Root, Training_Pre
         # Return the last fold's model path and the averaged metrics
         return Best_Model_Path

-    def Training_And_Validation(self, Fold):
+    def Training_And_Validation(self, Model, Optimizer, Fold):
         ''' The model's main training and validation routine '''
-        model_path, early_stopping, scheduler = call_back(self.Best_Model_Save_Root, f"fold{Fold}", self.Optimizer)
+        model_path, early_stopping, scheduler = call_back(self.Best_Model_Save_Root, f"fold{Fold}", Optimizer)

         # Lists to store metrics for this fold
         train_losses = []
@@ -165,7 +164,7 @@ class Xception_Identification_Block_Training_Step(Loding_Data_Root, Training_Pre
         # Epoch loop
         for epoch in range(self.Epoch):
-            self.Model.train()  # Start training
+            Model.train()  # Start training
             Training_Loss = 0.0
             All_Predict_List, All_Label_List = [], []
@@ -175,11 +174,13 @@ class Xception_Identification_Block_Training_Step(Loding_Data_Root, Training_Pre
             for inputs, labels, File_Name, File_Classes in epoch_iterator:
                 Total_Losses, Training_Loss, All_Predict_List, All_Label_List, Predict_Indexs, Truth_Indexs = self.Model_Branch(
+                    Model=Model,
                     Input_Images=inputs,
                     Labels=labels,
                     All_Predict_List=All_Predict_List,
                     All_Label_List=All_Label_List,
                     running_loss=Training_Loss,
+                    Optimizer=Optimizer,
                     status="Training"
                 )
                 self.Calculate_Progress_And_Timing(inputs, Predict_Indexs, Truth_Indexs, self.train_subset, Total_Losses, epoch_iterator, Start_Time)
@@ -187,7 +188,7 @@ class Xception_Identification_Block_Training_Step(Loding_Data_Root, Training_Pre
             train_losses, train_accuracies, Training_Loss, Train_accuracy = self.Calculate_Average_Scores(self.train_loader, Training_Loss, All_Predict_List, All_Label_List, train_losses, train_accuracies)

             # Validation step
-            self.Model.eval()
+            Model.eval()
             val_loss = 0.0
             all_val_preds = []
             all_val_labels = []
@@ -197,11 +198,13 @@ class Xception_Identification_Block_Training_Step(Loding_Data_Root, Training_Pre
             with torch.no_grad():
                 for inputs, labels, File_Name, File_Classes in epoch_iterator:
                     Validation_Total_Loss, val_loss, all_val_preds, all_val_labels, Predict_Indexs, Truth_Indexs = self.Model_Branch(
+                        Model=Model,
                         Input_Images=inputs,
                         Labels=labels,
                         All_Predict_List=all_val_preds,
                         All_Label_List=all_val_labels,
                         running_loss=val_loss,
+                        Optimizer=Optimizer,
                         status="Validation"
                     )
                     self.Calculate_Progress_And_Timing(inputs, Predict_Indexs, Truth_Indexs, self.val_subset, Validation_Total_Loss, epoch_iterator, start_Validation_time)
@@ -210,7 +213,7 @@ class Xception_Identification_Block_Training_Step(Loding_Data_Root, Training_Pre
             print(f"Training Loss: {Training_Loss:.4f}, Accuracy: {Train_accuracy:0.2f}, Validation Loss: {val_loss:.4f}, Accuracy: {val_accuracy:0.2f}\n")

             if epoch % 5 == 0:
-                Grad = GradCAM(self.Model, self.TargetLayer)
+                Grad = GradCAM(Model, self.TargetLayer)
                 Grad.Processing_Main(self.val_loader, f"{Save_Result_File_Config['GradCAM_Validation_Image_Save_Root']}/{self.Experiment_Name}/fold-{str(Fold)}/{str(epoch)}")

             # # Create a SaliencyMap instance
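A short reminder of why the validation step above pairs `Model.eval()` with `torch.no_grad()`: they do different jobs. Standalone sketch, not repo code:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Dropout(0.5))
model.eval()                  # BatchNorm uses running stats, Dropout is disabled
with torch.no_grad():         # autograd records nothing, saving memory and time
    out = model(torch.randn(8, 4))
print(out.requires_grad)      # False
```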
@@ -219,7 +222,7 @@ class Xception_Identification_Block_Training_Step(Loding_Data_Root, Training_Pre
 #             saliency_map.Processing_Main(self.val_loader, f"../Result/Saliency_Image/Validation/Saliency_Image({str(datetime.date.today())})/{self.Experiment_Name}/fold-{str(Fold)}/")

             # Early stopping
-            early_stopping(val_loss, self.Model, model_path)
+            early_stopping(val_loss, Model, model_path)
             best_val_loss = early_stopping.best_loss
             if early_stopping.early_stop:
                 print(f"Early stopping triggered in Fold {Fold + 1} at epoch {epoch + 1}")
@@ -233,20 +236,25 @@ class Xception_Identification_Block_Training_Step(Loding_Data_Root, Training_Pre

     def Construct_Identification_Model_CUDA(self):
         # Get the number of output nodes from Model_Config
-        Model = ModifiedXception()
+        # Model = ModifiedXception()
+        Model = Xception(num_classes=0)
         print(summary(Model))

         for name, parameters in Model.named_parameters():
             print(f"Layer Name: {name}, Parameters: {parameters.size()}")
-
-        self.TargetLayer = Model.base_model.conv4.pointwise

         # The summary call was commented out here to avoid the Mask-argument issue
         # Print the model structure directly
-        # print(f"Model structure: {Model}")
+        print(f"Model structure: {Model}")

-        # # Print the model parameters and their gradient status
-        # for name, parameters in Model.named_parameters():
-        #     print(f"Layer Name: {name}, Parameters: {parameters.size()}, requires_grad: {parameters.requires_grad}")
+        # Print the model parameters and their gradient status
+        for name, parameters in Model.named_parameters():
+            print(f"Layer Name: {name}, Parameters: {parameters.size()}, requires_grad: {parameters.requires_grad}")
+
+        self.TargetLayer = Model.conv4.pointwise.weight
+        # self.TargetLayer = Model.base_model.conv4.pointwise
+
+        # if name == "exit_flow.conv2.3.pointwise.bias":
+        #     self.TargetLayer = Model.exit_flow.conv2

         return self.Convert_Model_To_CUDA(Model)

@@ -256,21 +264,21 @@ class Xception_Identification_Block_Training_Step(Loding_Data_Root, Training_Pre
         return model

-    def Model_Branch(self, Input_Images, Labels, All_Predict_List : list, All_Label_List : list, running_loss, status):
+    def Model_Branch(self, Model, Input_Images, Labels, All_Predict_List : list, All_Label_List : list, running_loss, Optimizer, status):
         if status == "Training":
-            self.Optimizer.zero_grad()  # Zero the gradients to prevent accumulation
+            Optimizer.zero_grad()  # Zero the gradients to prevent accumulation

         # Move the tensors to the device while keeping gradient tracking
         Input_Images, Labels = Input_Images.to(self.device), Labels.to(self.device)

-        Predicts_Data = self.Model(Input_Images)
+        Predicts_Data = Model(Input_Images)

         # Use the raw Predict and Labels tensors when computing the loss (gradients preserved)
         Losses = self.Losses(Predicts_Data, Labels)

         if status == "Training":
             Losses.backward()
-            self.Optimizer.step()
+            Optimizer.step()

         running_loss += Losses.item()
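`Model_Branch` now receives `Model` and `Optimizer` explicitly, but the training-mode step order is unchanged: zero_grad, forward, loss, backward, step. The canonical sequence as a self-contained sketch, not repo code:

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 3)
optimizer = optim.SGD(model.parameters(), lr=0.045, momentum=0.9)
criterion = nn.CrossEntropyLoss()

inputs, labels = torch.randn(8, 4), torch.randint(0, 3, (8,))
optimizer.zero_grad()                        # clear gradients from the last step
loss = criterion(model(inputs), labels)      # forward pass and loss
loss.backward()                              # accumulate fresh gradients
optimizer.step()                             # apply the SGD update
print(loss.item())
```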
f"{Save_Result_File_Config['GradCAM_Test_Image_Save_Root']}/{self.Experiment_Name}/fold-{str(fold)}/") + Grad = GradCAM(cnn_model, self.TargetLayer) + Grad.Processing_Main(Test_Dataloader, f"{Save_Result_File_Config['GradCAM_Test_Image_Save_Root']}/{self.Experiment_Name}/fold-{str(index)}/") return True_Label, Predict_Label, loss, accuracy, precision, recall, f1 diff --git a/experiments/Training/__pycache__/Xception_Identification_Test.cpython-313.pyc b/experiments/Training/__pycache__/Xception_Identification_Test.cpython-313.pyc index bd6455e..9b1a43d 100644 Binary files a/experiments/Training/__pycache__/Xception_Identification_Test.cpython-313.pyc and b/experiments/Training/__pycache__/Xception_Identification_Test.cpython-313.pyc differ diff --git a/utils/__pycache__/Stomach_Config.cpython-313.pyc b/utils/__pycache__/Stomach_Config.cpython-313.pyc index ee2ee6a..74cd711 100644 Binary files a/utils/__pycache__/Stomach_Config.cpython-313.pyc and b/utils/__pycache__/Stomach_Config.cpython-313.pyc differ