import math

import torch
import torch.nn as nn


def _as_box_batch(boxes):
    """Cast boxes to float and promote a single 1-D box to a batch of one."""
    boxes = boxes.float()
    return boxes.unsqueeze(0) if boxes.dim() == 1 else boxes


def _both_empty(pred_boxes, target_boxes):
    """Return True when neither box set contains any box."""
    return pred_boxes.shape[0] == 0 and target_boxes.shape[0] == 0


def _zero_loss(pred_boxes):
    """Return a scalar zero that stays attached to ``pred_boxes``' graph.

    Used for empty batches, where ``.mean()`` over zero elements would
    otherwise produce NaN.
    """
    return pred_boxes.sum() * 0.0


def _intersection_area(pred_boxes, target_boxes):
    """Return the element-wise overlap area of two corner-format box sets."""
    left = torch.max(pred_boxes[:, 0], target_boxes[:, 0])
    top = torch.max(pred_boxes[:, 1], target_boxes[:, 1])
    right = torch.min(pred_boxes[:, 2], target_boxes[:, 2])
    bottom = torch.min(pred_boxes[:, 3], target_boxes[:, 3])

    # Disjoint boxes give negative extents; clamp so their overlap is 0.
    overlap_w = torch.clamp(right - left, min=0)
    overlap_h = torch.clamp(bottom - top, min=0)
    return overlap_w * overlap_h


def _iou_and_union(pred_boxes, target_boxes, intersection, eps):
    """Return ``(iou, union)``; ``union`` already includes ``eps``."""
    pred_area = (pred_boxes[:, 2] - pred_boxes[:, 0]) * (pred_boxes[:, 3] - pred_boxes[:, 1])
    target_area = (target_boxes[:, 2] - target_boxes[:, 0]) * (target_boxes[:, 3] - target_boxes[:, 1])
    union = pred_area + target_area - intersection + eps
    return intersection / union, union


def _enclosing_box(pred_boxes, target_boxes):
    """Return ``(x1, y1, x2, y2)`` of the smallest box enclosing each pair."""
    return (
        torch.min(pred_boxes[:, 0], target_boxes[:, 0]),
        torch.min(pred_boxes[:, 1], target_boxes[:, 1]),
        torch.max(pred_boxes[:, 2], target_boxes[:, 2]),
        torch.max(pred_boxes[:, 3], target_boxes[:, 3]),
    )


def _enclosing_diagonal_sq(pred_boxes, target_boxes, eps):
    """Return the squared diagonal of the enclosing box, plus ``eps``."""
    x1, y1, x2, y2 = _enclosing_box(pred_boxes, target_boxes)
    return (x2 - x1) ** 2 + (y2 - y1) ** 2 + eps


def _center_distance_sq(pred_boxes, target_boxes):
    """Return the squared distance between the centres of each box pair."""
    delta_x = (pred_boxes[:, 0] + pred_boxes[:, 2]) / 2 - (target_boxes[:, 0] + target_boxes[:, 2]) / 2
    delta_y = (pred_boxes[:, 1] + pred_boxes[:, 3]) / 2 - (target_boxes[:, 1] + target_boxes[:, 3]) / 2
    return delta_x ** 2 + delta_y ** 2


class CIOULoss(nn.Module):
    """Complete Intersection over Union (CIoU) loss for box regression.

    CIoU combines three geometric factors:
      1. overlap area (IoU),
      2. distance between the box centres,
      3. aspect-ratio consistency.

    Args:
        eps: Small constant added to denominators and used as the lower
            clamp for widths/heights to avoid division by zero.
        box_format: ``"xyxy"`` (default) when boxes are given as
            (x1, y1, x2, y2), or ``"cxcywh"`` when they are given as
            (cx, cy, w, h). Ignored for mask inputs, which are always
            converted to corner boxes.

    Raises:
        ValueError: If ``box_format`` is not one of the supported values.
    """

    _BOX_FORMATS = ("xyxy", "cxcywh")

    def __init__(self, eps=1e-7, box_format="xyxy"):
        super().__init__()
        if box_format not in self._BOX_FORMATS:
            raise ValueError(
                f"box_format must be one of {self._BOX_FORMATS}, got {box_format!r}"
            )
        self.eps = eps
        self.box_format = box_format

    def forward(self, pred_boxes, target_boxes):
        """Compute the mean CIoU loss.

        Args:
            pred_boxes: Predicted boxes ``[N, 4]`` (a single ``[4]`` box is
                also accepted), or segmentation masks ``[B, 1, H, W]``.
            target_boxes: Ground-truth boxes or masks, same layout as
                ``pred_boxes``.

        Returns:
            Scalar tensor: mean of ``1 - CIoU`` over all pairs. An empty
            batch yields ``0`` instead of NaN. If the second dimension is
            not 4, a constant ``0.01`` is returned as a best-effort
            fallback; that constant carries no gradient.
        """
        fallback_device = pred_boxes.device

        # Mask inputs are reduced to their bounding boxes first.
        from_masks = pred_boxes.dim() == 4 and pred_boxes.shape[1] == 1
        if from_masks:
            pred_boxes = self._mask_to_boxes(pred_boxes)
            target_boxes = self._mask_to_boxes(target_boxes)
            # _mask_to_boxes never returns None here; the guard only protects
            # subclasses that override it to signal "no usable box".
            if pred_boxes is None or target_boxes is None:
                return torch.tensor(0.01, device=fallback_device)

        pred_boxes = _as_box_batch(pred_boxes)
        target_boxes = _as_box_batch(target_boxes)

        # Best-effort fallback for inputs that are not 4-coordinate boxes.
        if pred_boxes.shape[1] != 4 or target_boxes.shape[1] != 4:
            return torch.tensor(0.01, device=pred_boxes.device)

        # mean() over zero boxes is NaN, which would poison training.
        if _both_empty(pred_boxes, target_boxes):
            return _zero_loss(pred_boxes)

        # Boxes derived from masks are already in corner format.
        if not from_masks and self._is_center_format(pred_boxes, target_boxes):
            pred_boxes = self._center_to_corner(pred_boxes)
            target_boxes = self._center_to_corner(target_boxes)

        intersection = self._calculate_intersection(pred_boxes, target_boxes)
        iou, _ = _iou_and_union(pred_boxes, target_boxes, intersection, self.eps)

        # Centre-distance penalty, normalised by the enclosing-box diagonal.
        distance_penalty = _center_distance_sq(pred_boxes, target_boxes) / _enclosing_diagonal_sq(
            pred_boxes, target_boxes, self.eps
        )

        # Aspect-ratio term; sizes are clamped so the ratios stay finite.
        pred_w = torch.clamp(pred_boxes[:, 2] - pred_boxes[:, 0], min=self.eps)
        pred_h = torch.clamp(pred_boxes[:, 3] - pred_boxes[:, 1], min=self.eps)
        target_w = torch.clamp(target_boxes[:, 2] - target_boxes[:, 0], min=self.eps)
        target_h = torch.clamp(target_boxes[:, 3] - target_boxes[:, 1], min=self.eps)
        v = (4 / (math.pi ** 2)) * torch.pow(
            torch.atan(target_w / target_h) - torch.atan(pred_w / pred_h), 2
        )

        # alpha is a trade-off weight and is treated as a constant in backprop.
        with torch.no_grad():
            alpha = v / (1 - iou + v + self.eps)

        ciou = iou - distance_penalty - alpha * v
        return (1 - ciou).mean()

    def _is_center_format(self, pred_boxes, target_boxes):
        """Return True when boxes are configured as (cx, cy, w, h).

        The format cannot be inferred reliably from the values, so it is
        taken from the ``box_format`` constructor option. The arguments are
        kept only for signature compatibility.
        """
        # getattr keeps instances created before box_format existed working.
        return getattr(self, "box_format", "xyxy") == "cxcywh"

    def _center_to_corner(self, boxes):
        """Convert ``[N, 4]`` boxes from (cx, cy, w, h) to (x1, y1, x2, y2)."""
        cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
        half_w = w / 2
        half_h = h / 2
        return torch.stack([cx - half_w, cy - half_h, cx + half_w, cy + half_h], dim=1)

    def _mask_to_boxes(self, masks):
        """Convert segmentation masks to tight bounding boxes.

        A pixel is foreground when ``sigmoid(mask) > 0.5``. Boxes use pixel
        extents: x2/y2 are one past the last foreground column/row, so a
        single-pixel mask yields a 1x1 box rather than a zero-area one.
        This matches the ``[0, 0, 1, 1]`` box used for empty masks.

        The boxes are built from thresholded pixel indices, so no gradient
        flows from them back to the mask values.

        Args:
            masks: Tensor of shape ``[B, 1, H, W]``.

        Returns:
            Float tensor ``[B, 4]`` of (x1, y1, x2, y2) boxes.
        """
        foreground = torch.sigmoid(masks.float()) > 0.5
        boxes = torch.zeros(masks.size(0), 4, device=masks.device)

        for index, mask in enumerate(foreground[:, 0]):
            ys, xs = torch.nonzero(mask, as_tuple=True)
            if ys.numel() == 0:
                # No foreground at all: fall back to a small default box.
                boxes[index] = boxes.new_tensor([0.0, 0.0, 1.0, 1.0])
                continue
            extent = torch.stack([xs.min(), ys.min(), xs.max() + 1, ys.max() + 1])
            boxes[index] = extent.to(boxes.dtype)

        return boxes

    def _calculate_intersection(self, pred_boxes, target_boxes):
        """Return the overlap area of each predicted/target box pair."""
        return _intersection_area(pred_boxes, target_boxes)


class DIoULoss(nn.Module):
    """Distance Intersection over Union (DIoU) loss.

    A simplified CIoU that keeps the overlap and centre-distance terms and
    drops the aspect-ratio term. Boxes are (x1, y1, x2, y2).

    Args:
        eps: Small constant added to denominators to avoid division by zero.
    """

    def __init__(self, eps=1e-7):
        super().__init__()
        self.eps = eps

    def forward(self, pred_boxes, target_boxes):
        """Return the mean of ``1 - DIoU`` over all box pairs.

        Args:
            pred_boxes: Predicted boxes ``[N, 4]`` (or a single ``[4]`` box).
            target_boxes: Ground-truth boxes, same layout.

        Returns:
            Scalar tensor; ``0`` for an empty batch instead of NaN.
        """
        pred_boxes = _as_box_batch(pred_boxes)
        target_boxes = _as_box_batch(target_boxes)
        if _both_empty(pred_boxes, target_boxes):
            return _zero_loss(pred_boxes)

        intersection = self._calculate_intersection(pred_boxes, target_boxes)
        iou, _ = _iou_and_union(pred_boxes, target_boxes, intersection, self.eps)

        distance_penalty = _center_distance_sq(pred_boxes, target_boxes) / _enclosing_diagonal_sq(
            pred_boxes, target_boxes, self.eps
        )
        diou = iou - distance_penalty
        return (1 - diou).mean()

    def _calculate_intersection(self, pred_boxes, target_boxes):
        """Return the overlap area of each predicted/target box pair."""
        return _intersection_area(pred_boxes, target_boxes)


class GIoULoss(nn.Module):
    """Generalized Intersection over Union (GIoU) loss.

    Extends IoU with a penalty based on the area of the smallest enclosing
    box that is not covered by the union. Boxes are (x1, y1, x2, y2).

    Args:
        eps: Small constant added to denominators to avoid division by zero.
    """

    def __init__(self, eps=1e-7):
        super().__init__()
        self.eps = eps

    def forward(self, pred_boxes, target_boxes):
        """Return the mean of ``1 - GIoU`` over all box pairs.

        Args:
            pred_boxes: Predicted boxes ``[N, 4]`` (or a single ``[4]`` box).
            target_boxes: Ground-truth boxes, same layout.

        Returns:
            Scalar tensor; ``0`` for an empty batch instead of NaN.
        """
        pred_boxes = _as_box_batch(pred_boxes)
        target_boxes = _as_box_batch(target_boxes)
        if _both_empty(pred_boxes, target_boxes):
            return _zero_loss(pred_boxes)

        intersection = self._calculate_intersection(pred_boxes, target_boxes)
        iou, union = _iou_and_union(pred_boxes, target_boxes, intersection, self.eps)

        x1, y1, x2, y2 = _enclosing_box(pred_boxes, target_boxes)
        enclose_area = (x2 - x1) * (y2 - y1) + self.eps

        giou = iou - (enclose_area - union) / enclose_area
        return (1 - giou).mean()

    def _calculate_intersection(self, pred_boxes, target_boxes):
        """Return the overlap area of each predicted/target box pair."""
        return _intersection_area(pred_boxes, target_boxes)