315 lines
12 KiB
Python
315 lines
12 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
import math
|
||
|
||
class CIOULoss(nn.Module):
    """Complete Intersection over Union (CIoU) loss for bounding-box regression.

    The loss combines three geometric terms:

    1. Overlap area (IoU).
    2. Squared distance between the box centres, normalised by the squared
       diagonal of the smallest enclosing box.
    3. Aspect-ratio consistency.

    Args:
        eps: Small constant added to denominators to avoid division by zero.
        box_format: Layout of ``[N, 4]`` box inputs. ``"xyxy"`` (default)
            means ``(x1, y1, x2, y2)`` corners; ``"cxcywh"`` means
            ``(cx, cy, w, h)`` and is converted to corners before the loss
            is computed.

    Raises:
        ValueError: If ``box_format`` is not one of the supported layouts.
    """

    _BOX_FORMATS = ("xyxy", "cxcywh")

    def __init__(self, eps=1e-7, box_format="xyxy"):
        super(CIOULoss, self).__init__()
        if box_format not in self._BOX_FORMATS:
            raise ValueError(
                f"box_format must be one of {self._BOX_FORMATS}, got {box_format!r}"
            )
        self.eps = eps
        self.box_format = box_format

    def forward(self, pred_boxes, target_boxes):
        """Compute the mean CIoU loss.

        Args:
            pred_boxes: Predicted boxes ``[N, 4]`` (or a single box ``[4]``)
                in the layout selected by ``box_format``, or segmentation
                masks ``[B, 1, H, W]``.
            target_boxes: Ground-truth boxes or masks with the same layout
                as ``pred_boxes``.

        Returns:
            Scalar tensor holding the mean of ``1 - CIoU`` over all boxes.
            A constant ``0.01`` is returned when the second dimension of
            either input is not 4, and ``0`` is returned when both inputs
            contain no boxes.
        """
        # Remember the device up front so the fallback below never has to
        # dereference a value that may be None.
        fallback_device = pred_boxes.device

        # Segmentation masks are reduced to one bounding box per sample.
        if pred_boxes.dim() == 4 and pred_boxes.shape[1] == 1:
            pred_boxes = self._mask_to_boxes(pred_boxes)
            target_boxes = self._mask_to_boxes(target_boxes)

            # Best-effort fallback if no boxes could be extracted.
            if pred_boxes is None or target_boxes is None:
                return torch.tensor(0.01, device=fallback_device)

        pred_boxes = pred_boxes.float()
        target_boxes = target_boxes.float()

        # A single box is promoted to a batch of one.
        if pred_boxes.dim() == 1:
            pred_boxes = pred_boxes.unsqueeze(0)
        if target_boxes.dim() == 1:
            target_boxes = target_boxes.unsqueeze(0)

        # Inputs without exactly four coordinates yield a small constant loss.
        if pred_boxes.shape[1] != 4 or target_boxes.shape[1] != 4:
            return torch.tensor(0.01, device=pred_boxes.device)

        # mean() over zero boxes would be NaN; return a zero that is still
        # connected to the prediction tensor instead.
        if pred_boxes.shape[0] == 0 and target_boxes.shape[0] == 0:
            return pred_boxes.sum() * 0.0

        if self._is_center_format(pred_boxes, target_boxes):
            pred_boxes = self._center_to_corner(pred_boxes)
            target_boxes = self._center_to_corner(target_boxes)

        # Overlap term.
        intersection = self._calculate_intersection(pred_boxes, target_boxes)

        pred_w = pred_boxes[:, 2] - pred_boxes[:, 0]
        pred_h = pred_boxes[:, 3] - pred_boxes[:, 1]
        target_w = target_boxes[:, 2] - target_boxes[:, 0]
        target_h = target_boxes[:, 3] - target_boxes[:, 1]

        union = pred_w * pred_h + target_w * target_h - intersection + self.eps
        iou = intersection / union

        # Squared diagonal of the smallest box enclosing both boxes.
        enclose_x1 = torch.min(pred_boxes[:, 0], target_boxes[:, 0])
        enclose_y1 = torch.min(pred_boxes[:, 1], target_boxes[:, 1])
        enclose_x2 = torch.max(pred_boxes[:, 2], target_boxes[:, 2])
        enclose_y2 = torch.max(pred_boxes[:, 3], target_boxes[:, 3])
        enclose_diagonal_sq = (
            (enclose_x2 - enclose_x1) ** 2 + (enclose_y2 - enclose_y1) ** 2 + self.eps
        )

        # Squared distance between the two box centres.
        pred_center_x = (pred_boxes[:, 0] + pred_boxes[:, 2]) / 2
        pred_center_y = (pred_boxes[:, 1] + pred_boxes[:, 3]) / 2
        target_center_x = (target_boxes[:, 0] + target_boxes[:, 2]) / 2
        target_center_y = (target_boxes[:, 1] + target_boxes[:, 3]) / 2
        center_distance_sq = (
            (pred_center_x - target_center_x) ** 2
            + (pred_center_y - target_center_y) ** 2
        )

        # Aspect-ratio term; sides are clamped so the ratios stay finite.
        pred_w = torch.clamp(pred_w, min=self.eps)
        pred_h = torch.clamp(pred_h, min=self.eps)
        target_w = torch.clamp(target_w, min=self.eps)
        target_h = torch.clamp(target_h, min=self.eps)
        v = (4 / (math.pi ** 2)) * torch.pow(
            torch.atan(target_w / target_h) - torch.atan(pred_w / pred_h), 2
        )

        # alpha is a weighting factor and is excluded from the gradient.
        with torch.no_grad():
            alpha = v / (1 - iou + v + self.eps)

        ciou = iou - (center_distance_sq / enclose_diagonal_sq) - alpha * v
        return (1 - ciou).mean()

    def _is_center_format(self, pred_boxes, target_boxes):
        """Return True when box inputs are laid out as ``(cx, cy, w, h)``.

        The layout is taken from the ``box_format`` option rather than
        guessed from the values. Instances that lack the attribute are
        treated as ``"xyxy"``.
        """
        return getattr(self, "box_format", "xyxy") == "cxcywh"

    def _center_to_corner(self, boxes):
        """Convert ``(cx, cy, w, h)`` boxes to ``(x1, y1, x2, y2)`` corners."""
        cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
        half_w = w / 2
        half_h = h / 2
        return torch.stack([cx - half_w, cy - half_h, cx + half_w, cy + half_h], dim=1)

    def _mask_to_boxes(self, masks):
        """Convert segmentation masks to one bounding box per sample.

        Each mask is passed through a sigmoid and thresholded at 0.5; the
        box spans the minimum and maximum indices of the resulting
        foreground pixels. A mask with no foreground pixel yields the
        default box ``[0, 0, 1, 1]``.

        Note: the boxes are built from integer pixel indices of a
        thresholded mask, so no gradient flows back to ``masks``.

        Args:
            masks: Segmentation masks ``[B, 1, H, W]``.

        Returns:
            Float tensor ``[B, 4]`` of ``(x1, y1, x2, y2)`` boxes.
        """
        batch_size = masks.size(0)
        device = masks.device

        binary_masks = (torch.sigmoid(masks) > 0.5).float()
        boxes = torch.zeros(batch_size, 4, device=device)

        for b in range(batch_size):
            ys, xs = torch.nonzero(binary_masks[b, 0], as_tuple=True)

            if ys.numel() == 0:
                # No foreground pixel: fall back to a small default box.
                boxes[b] = torch.tensor([0, 0, 1, 1], device=device)
                continue

            # NOTE(review): the max corner is the last foreground index
            # (inclusive), so a one-pixel-wide mask gives a zero-width box.
            # Confirm this convention is intended.
            boxes[b] = torch.stack([xs.min(), ys.min(), xs.max(), ys.max()]).to(boxes.dtype)

        return boxes

    def _calculate_intersection(self, pred_boxes, target_boxes):
        """Return the per-pair intersection area of two corner-format box sets."""
        x1 = torch.max(pred_boxes[:, 0], target_boxes[:, 0])
        y1 = torch.max(pred_boxes[:, 1], target_boxes[:, 1])
        x2 = torch.min(pred_boxes[:, 2], target_boxes[:, 2])
        y2 = torch.min(pred_boxes[:, 3], target_boxes[:, 3])

        # Non-overlapping boxes give negative extents; clamp them to zero.
        intersection_w = torch.clamp(x2 - x1, min=0)
        intersection_h = torch.clamp(y2 - y1, min=0)

        return intersection_w * intersection_h
|
||
|
||
|
||
class DIoULoss(nn.Module):
    """Distance Intersection over Union (DIoU) loss.

    A simplified variant of CIoU that only considers the overlap area and
    the distance between box centres.

    Args:
        eps: Small constant added to denominators to avoid division by zero.
    """

    def __init__(self, eps=1e-7):
        super(DIoULoss, self).__init__()
        self.eps = eps

    def forward(self, pred_boxes, target_boxes):
        """Compute the mean DIoU loss.

        Args:
            pred_boxes: Predicted boxes ``[N, 4]`` (or a single box ``[4]``)
                as ``(x1, y1, x2, y2)``.
            target_boxes: Ground-truth boxes with the same layout.

        Returns:
            Scalar tensor holding the mean of ``1 - DIoU`` over all boxes,
            or ``0`` when both inputs contain no boxes.
        """
        pred_boxes = pred_boxes.float()
        target_boxes = target_boxes.float()

        # A single box is promoted to a batch of one.
        if pred_boxes.dim() == 1:
            pred_boxes = pred_boxes.unsqueeze(0)
        if target_boxes.dim() == 1:
            target_boxes = target_boxes.unsqueeze(0)

        # mean() over zero boxes would be NaN; return a zero that is still
        # connected to the prediction tensor instead.
        if pred_boxes.shape[0] == 0 and target_boxes.shape[0] == 0:
            return pred_boxes.sum() * 0.0

        # Overlap term.
        intersection = self._calculate_intersection(pred_boxes, target_boxes)
        pred_area = (pred_boxes[:, 2] - pred_boxes[:, 0]) * (pred_boxes[:, 3] - pred_boxes[:, 1])
        target_area = (target_boxes[:, 2] - target_boxes[:, 0]) * (target_boxes[:, 3] - target_boxes[:, 1])
        union = pred_area + target_area - intersection + self.eps
        iou = intersection / union

        # Squared diagonal of the smallest box enclosing both boxes.
        enclose_x1 = torch.min(pred_boxes[:, 0], target_boxes[:, 0])
        enclose_y1 = torch.min(pred_boxes[:, 1], target_boxes[:, 1])
        enclose_x2 = torch.max(pred_boxes[:, 2], target_boxes[:, 2])
        enclose_y2 = torch.max(pred_boxes[:, 3], target_boxes[:, 3])
        enclose_diagonal_sq = (
            (enclose_x2 - enclose_x1) ** 2 + (enclose_y2 - enclose_y1) ** 2 + self.eps
        )

        # Squared distance between the two box centres.
        pred_center_x = (pred_boxes[:, 0] + pred_boxes[:, 2]) / 2
        pred_center_y = (pred_boxes[:, 1] + pred_boxes[:, 3]) / 2
        target_center_x = (target_boxes[:, 0] + target_boxes[:, 2]) / 2
        target_center_y = (target_boxes[:, 1] + target_boxes[:, 3]) / 2
        center_distance_sq = (
            (pred_center_x - target_center_x) ** 2
            + (pred_center_y - target_center_y) ** 2
        )

        diou = iou - (center_distance_sq / enclose_diagonal_sq)
        return (1 - diou).mean()

    def _calculate_intersection(self, pred_boxes, target_boxes):
        """Return the per-pair intersection area of two corner-format box sets."""
        x1 = torch.max(pred_boxes[:, 0], target_boxes[:, 0])
        y1 = torch.max(pred_boxes[:, 1], target_boxes[:, 1])
        x2 = torch.min(pred_boxes[:, 2], target_boxes[:, 2])
        y2 = torch.min(pred_boxes[:, 3], target_boxes[:, 3])

        # Non-overlapping boxes give negative extents; clamp them to zero.
        intersection_w = torch.clamp(x2 - x1, min=0)
        intersection_h = torch.clamp(y2 - y1, min=0)

        return intersection_w * intersection_h
|
||
|
||
|
||
class GIoULoss(nn.Module):
    """Generalized Intersection over Union (GIoU) loss.

    A generalisation of IoU that also penalises the part of the smallest
    enclosing box not covered by the union of the two boxes.

    Args:
        eps: Small constant added to denominators to avoid division by zero.
    """

    def __init__(self, eps=1e-7):
        super(GIoULoss, self).__init__()
        self.eps = eps

    def forward(self, pred_boxes, target_boxes):
        """Compute the mean GIoU loss.

        Args:
            pred_boxes: Predicted boxes ``[N, 4]`` (or a single box ``[4]``)
                as ``(x1, y1, x2, y2)``.
            target_boxes: Ground-truth boxes with the same layout.

        Returns:
            Scalar tensor holding the mean of ``1 - GIoU`` over all boxes,
            or ``0`` when both inputs contain no boxes.
        """
        pred_boxes = pred_boxes.float()
        target_boxes = target_boxes.float()

        # A single box is promoted to a batch of one.
        if pred_boxes.dim() == 1:
            pred_boxes = pred_boxes.unsqueeze(0)
        if target_boxes.dim() == 1:
            target_boxes = target_boxes.unsqueeze(0)

        # mean() over zero boxes would be NaN; return a zero that is still
        # connected to the prediction tensor instead.
        if pred_boxes.shape[0] == 0 and target_boxes.shape[0] == 0:
            return pred_boxes.sum() * 0.0

        # Overlap term.
        intersection = self._calculate_intersection(pred_boxes, target_boxes)
        pred_area = (pred_boxes[:, 2] - pred_boxes[:, 0]) * (pred_boxes[:, 3] - pred_boxes[:, 1])
        target_area = (target_boxes[:, 2] - target_boxes[:, 0]) * (target_boxes[:, 3] - target_boxes[:, 1])
        union = pred_area + target_area - intersection + self.eps
        iou = intersection / union

        # Area of the smallest box enclosing both boxes.
        enclose_x1 = torch.min(pred_boxes[:, 0], target_boxes[:, 0])
        enclose_y1 = torch.min(pred_boxes[:, 1], target_boxes[:, 1])
        enclose_x2 = torch.max(pred_boxes[:, 2], target_boxes[:, 2])
        enclose_y2 = torch.max(pred_boxes[:, 3], target_boxes[:, 3])
        enclose_area = (enclose_x2 - enclose_x1) * (enclose_y2 - enclose_y1) + self.eps

        # Penalise the fraction of the enclosing box outside the union.
        giou = iou - (enclose_area - union) / enclose_area
        return (1 - giou).mean()

    def _calculate_intersection(self, pred_boxes, target_boxes):
        """Return the per-pair intersection area of two corner-format box sets."""
        x1 = torch.max(pred_boxes[:, 0], target_boxes[:, 0])
        y1 = torch.max(pred_boxes[:, 1], target_boxes[:, 1])
        x2 = torch.min(pred_boxes[:, 2], target_boxes[:, 2])
        y2 = torch.min(pred_boxes[:, 3], target_boxes[:, 3])

        # Non-overlapping boxes give negative extents; clamp them to zero.
        intersection_w = torch.clamp(x2 - x1, min=0)
        intersection_h = torch.clamp(y2 - y1, min=0)

        return intersection_w * intersection_h