# vit_branch.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from typing import Tuple, Optional


class PatchEmbed(nn.Module):
    """Split a CNN feature map into patch tokens."""

    def __init__(self, in_chs: int = 728, embed_dim: int = 728, patch_size: int = 4):
        super().__init__()
        self.proj = nn.Conv2d(in_chs, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
        B, C, H, W = x.shape
        x = self.proj(x)                  # (B, D, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)  # (B, N, D)
        x = self.norm(x)
        return x, (H, W)


class PatchReconstruction(nn.Module):
    """Project patch tokens back to a (B, C, H, W) feature map."""

    def __init__(self, embed_dim: int = 728, out_chs: int = 728, patch_size: int = 4):
        super().__init__()
        self.patch_size = patch_size
        self.out_chs = out_chs
        self.proj = nn.Linear(embed_dim, patch_size * patch_size * out_chs)

    def forward(self, x: torch.Tensor, orig_size: Tuple[int, int]) -> torch.Tensor:
        B, N, _ = x.shape
        H, W = orig_size
        p = self.patch_size
        h = H // p
        w = W // p
        assert h * w == N, f"N mismatch: {h * w} vs {N}"
        x = self.proj(x)                          # (B, N, p*p*C)
        x = x.view(B, h, w, p * p, self.out_chs)  # (B, h, w, p*p, C)
        x = rearrange(x, 'b h w (p1 p2) c -> b c (h p1) (w p2)', p1=p, p2=p)
        return x


class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block: self-attention + MLP, each with a residual."""

    def __init__(self, dim: int, num_heads: int, mlp_ratio: float = 4.0,
                 qkv_bias: bool = True, drop: float = 0., attn_drop: float = 0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=attn_drop,
                                          bias=qkv_bias, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(drop),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        xn = self.norm1(x)  # normalize once and reuse for q, k, v
        y, _ = self.attn(xn, xn, xn, need_weights=False)
        x = x + y
        x = x + self.mlp(self.norm2(x))
        return x


class ViTBranch(nn.Module):
    """
    ViT branch: takes a CNN feature map and returns an enhanced feature map
    of the same spatial size.
    Intended use: fusion with the Xception Middle Flow.
    """

    def __init__(
        self,
        in_chs: int = 728,
        embed_dim: int = 728,
        patch_size: int = 4,
        depth: int = 3,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.patch_embed = PatchEmbed(in_chs, embed_dim, patch_size)
        self.patch_recon = PatchReconstruction(embed_dim, in_chs, patch_size)

        # Learnable pos embed (will be resized dynamically in forward).
        num_patches_approx = (299 // 16) ** 2
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches_approx, embed_dim, **factory_kwargs))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
            )
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)

    def _resize_pos_embed(self, pos_embed: torch.Tensor,
                          target_hw: Tuple[int, int]) -> torch.Tensor:
        """Resize the positional embedding to the current patch grid."""
        B = pos_embed.shape[0]
        sqrt_N = int(pos_embed.shape[1] ** 0.5)
        pos_embed = pos_embed.reshape(1, sqrt_N, sqrt_N, -1).permute(0, 3, 1, 2)
        pos_embed = F.interpolate(pos_embed, size=target_hw,
                                  mode='bilinear', align_corners=False)
        pos_embed = pos_embed.permute(0, 2, 3, 1).flatten(1, 2)
        return pos_embed.expand(B, -1, -1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        x, orig_size = self.patch_embed(x)  # (B, N, D), orig_size = (H, W)
        N = x.shape[1]

        # Resize pos_embed to match the actual patch grid.
        target_grid = (H // self.patch_embed.proj.stride[0],
                       W // self.patch_embed.proj.stride[1])
        if self.pos_embed.shape[1] != N:
            pos_embed = self._resize_pos_embed(self.pos_embed, target_grid)
        else:
            pos_embed = self.pos_embed.expand(B, -1, -1)
        x = x + pos_embed

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        x = self.patch_recon(x, orig_size)
        # When H or W is not divisible by patch_size, the patchify conv truncates the
        # border and the reconstruction comes back slightly smaller than the input;
        # resize so the branch output matches the input feature map as documented.
        if x.shape[-2:] != (H, W):
            x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False)
        return x
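

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module). It assumes
# the 728 x 19 x 19 feature map that Xception's Middle Flow produces for a
# 299 x 299 input; adjust the shape to your backbone. The residual fusion shown
# at the end is one plausible way to combine the two streams, not a fixed API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    branch = ViTBranch(in_chs=728, embed_dim=728, patch_size=4, depth=3, num_heads=8)
    feats = torch.randn(2, 728, 19, 19)   # dummy CNN feature map (B, C, H, W)
    enhanced = branch(feats)               # same spatial size as the input
    print(enhanced.shape)                  # torch.Size([2, 728, 19, 19])

    # Example fusion with the CNN stream (assumed residual add).
    fused = feats + enhanced
    print(fused.shape)                     # torch.Size([2, 728, 19, 19])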