diff --git a/ppdet/metrics/coco_utils.py b/ppdet/metrics/coco_utils.py index b2ae487e1507c430831cedaef68f28f2caa7f0d5..82d5cc85628a4d1a98507bbee0d1b3c5b6c13d2e 100644 --- a/ppdet/metrics/coco_utils.py +++ b/ppdet/metrics/coco_utils.py @@ -21,7 +21,7 @@ import sys import numpy as np import itertools -from ppdet.metrics.post_process import get_det_res, get_seg_res, get_solov2_segm_res +from ppdet.metrics.json_results import get_det_res, get_seg_res, get_solov2_segm_res from ppdet.metrics.map_utils import draw_pr_curve from ppdet.utils.logger import setup_logger diff --git a/ppdet/metrics/post_process.py b/ppdet/metrics/json_results.py similarity index 100% rename from ppdet/metrics/post_process.py rename to ppdet/metrics/json_results.py diff --git a/ppdet/modeling/__init__.py b/ppdet/modeling/__init__.py index 5171d205cf3992f70c3187eea595504215560ef2..8e8d41fa49f50e3bc256579fc27840ba6bda4c81 100644 --- a/ppdet/modeling/__init__.py +++ b/ppdet/modeling/__init__.py @@ -7,7 +7,6 @@ from . import losses from . import architectures from . import post_process from . import layers -from . import utils from .ops import * from .backbones import * @@ -18,4 +17,3 @@ from .losses import * from .architectures import * from .post_process import * from .layers import * -from .utils import * diff --git a/ppdet/modeling/bbox_utils.py b/ppdet/modeling/bbox_utils.py index 9d90831b6fcdcfccae89a0d30a8e3248f00b58bc..3308684aa6c483d407e70852e2fa4fdb7a8609f8 100644 --- a/ppdet/modeling/bbox_utils.py +++ b/ppdet/modeling/bbox_utils.py @@ -14,6 +14,8 @@ import math import paddle +import paddle.nn.functional as F +import math def bbox2delta(src_boxes, tgt_boxes, weights): @@ -111,6 +113,16 @@ def bbox_area(boxes): def bbox_overlaps(boxes1, boxes2): + """ + Calculate overlaps between boxes1 and boxes2 + + Args: + boxes1 (Tensor): boxes with shape [M, 4] + boxes2 (Tensor): boxes with shape [N, 4] + + Return: + overlaps (Tensor): overlaps between boxes1 and boxes2 with shape [M, N] + """ area1 = bbox_area(boxes1) area2 = bbox_area(boxes2) @@ -126,3 +138,125 @@ def bbox_overlaps(boxes1, boxes2): (paddle.unsqueeze(area1, 1) + area2 - inter), paddle.zeros_like(inter)) return overlaps + + +def xywh2xyxy(box): + x, y, w, h = box + x1 = x - w * 0.5 + y1 = y - h * 0.5 + x2 = x + w * 0.5 + y2 = y + h * 0.5 + return [x1, y1, x2, y2] + + +def make_grid(h, w, dtype): + yv, xv = paddle.meshgrid([paddle.arange(h), paddle.arange(w)]) + return paddle.stack((xv, yv), 2).cast(dtype=dtype) + + +def decode_yolo(box, anchor, downsample_ratio): + """decode yolo box + + Args: + box (list): [x, y, w, h], all have the shape [b, na, h, w, 1] + anchor (list): anchor with the shape [na, 2] + downsample_ratio (int): downsample ratio, default 32 + scale (float): scale, default 1. + + Return: + box (list): decoded box, [x, y, w, h], all have the shape [b, na, h, w, 1] + """ + x, y, w, h = box + na, grid_h, grid_w = x.shape[1:4] + grid = make_grid(grid_h, grid_w, x.dtype).reshape((1, 1, grid_h, grid_w, 2)) + x1 = (x + grid[:, :, :, :, 0:1]) / grid_w + y1 = (y + grid[:, :, :, :, 1:2]) / grid_h + + anchor = paddle.to_tensor(anchor) + anchor = paddle.cast(anchor, x.dtype) + anchor = anchor.reshape((1, na, 1, 1, 2)) + w1 = paddle.exp(w) * anchor[:, :, :, :, 0:1] / (downsample_ratio * grid_w) + h1 = paddle.exp(h) * anchor[:, :, :, :, 1:2] / (downsample_ratio * grid_h) + + return [x1, y1, w1, h1] + + +def iou_similarity(box1, box2, eps=1e-9): + """Calculate iou of box1 and box2 + + Args: + box1 (Tensor): box with the shape [N, M1, 4] + box2 (Tensor): box with the shape [N, M2, 4] + + Return: + iou (Tensor): iou between box1 and box2 with the shape [N, M1, M2] + """ + box1 = box1.unsqueeze(2) # [N, M1, 4] -> [N, M1, 1, 4] + box2 = box2.unsqueeze(1) # [N, M2, 4] -> [N, 1, M2, 4] + px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4] + gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4] + x1y1 = paddle.maximum(px1y1, gx1y1) + x2y2 = paddle.minimum(px2y2, gx2y2) + overlap = (x2y2 - x1y1).clip(0).prod(-1) + area1 = (px2y2 - px1y1).clip(0).prod(-1) + area2 = (gx2y2 - gx1y1).clip(0).prod(-1) + union = area1 + area2 - overlap + eps + return overlap / union + + +def bbox_iou(box1, box2, giou=False, diou=False, ciou=False, eps=1e-9): + """calculate the iou of box1 and box2 + + Args: + box1 (list): [x, y, w, h], all have the shape [b, na, h, w, 1] + box2 (list): [x, y, w, h], all have the shape [b, na, h, w, 1] + giou (bool): whether use giou or not, default False + diou (bool): whether use diou or not, default False + ciou (bool): whether use ciou or not, default False + eps (float): epsilon to avoid divide by zero + + Return: + iou (Tensor): iou of box1 and box1, with the shape [b, na, h, w, 1] + """ + px1, py1, px2, py2 = box1 + gx1, gy1, gx2, gy2 = box2 + x1 = paddle.maximum(px1, gx1) + y1 = paddle.maximum(py1, gy1) + x2 = paddle.minimum(px2, gx2) + y2 = paddle.minimum(py2, gy2) + + overlap = ((x2 - x1).clip(0)) * ((y2 - y1).clip(0)) + + area1 = (px2 - px1) * (py2 - py1) + area1 = area1.clip(0) + + area2 = (gx2 - gx1) * (gy2 - gy1) + area2 = area2.clip(0) + + union = area1 + area2 - overlap + eps + iou = overlap / union + + if giou or ciou or diou: + # convex w, h + cw = paddle.maximum(px2, gx2) - paddle.minimum(px1, gx1) + ch = paddle.maximum(py2, gy2) - paddle.minimum(py1, gy1) + if giou: + c_area = cw * ch + eps + return iou - (c_area - union) / c_area + else: + # convex diagonal squared + c2 = cw**2 + ch**2 + eps + # center distance + rho2 = ((px1 + px2 - gx1 - gx2)**2 + (py1 + py2 - gy1 - gy2)**2) / 4 + if diou: + return iou - rho2 / c2 + else: + w1, h1 = px2 - px1, py2 - py1 + eps + w2, h2 = gx2 - gx1, gy2 - gy1 + eps + delta = paddle.atan(w1 / h1) - paddle.atan(w2 / h2) + v = (4 / math.pi**2) * paddle.pow(delta, 2) + alpha = v / (1 + eps - iou + v) + alpha.stop_gradient = True + return iou - (rho2 / c2 + v * alpha) + else: + return iou diff --git a/ppdet/modeling/losses/iou_aware_loss.py b/ppdet/modeling/losses/iou_aware_loss.py index 2cc6f2a2c4077558c93ec55baa11f3d12a2e8476..f5599588c8815c5b5582ad6a9e180fff80c0ac49 100644 --- a/ppdet/modeling/losses/iou_aware_loss.py +++ b/ppdet/modeling/losses/iou_aware_loss.py @@ -20,7 +20,7 @@ import paddle import paddle.nn.functional as F from ppdet.core.workspace import register, serializable from .iou_loss import IouLoss -from ..utils import xywh2xyxy, bbox_iou, decode_yolo +from ..bbox_utils import xywh2xyxy, bbox_iou @register diff --git a/ppdet/modeling/losses/iou_loss.py b/ppdet/modeling/losses/iou_loss.py index 72613297d5df262519be019375bb2f9cf91aee0e..df1ef216012cd34d3eb16e900c05790eade26fe8 100644 --- a/ppdet/modeling/losses/iou_loss.py +++ b/ppdet/modeling/losses/iou_loss.py @@ -19,7 +19,7 @@ from __future__ import print_function import paddle import paddle.nn.functional as F from ppdet.core.workspace import register, serializable -from ..utils import xywh2xyxy, bbox_iou, decode_yolo +from ..bbox_utils import xywh2xyxy, bbox_iou __all__ = ['IouLoss', 'GIoULoss'] diff --git a/ppdet/modeling/losses/yolo_loss.py b/ppdet/modeling/losses/yolo_loss.py index 149139989a425fad61648c4ee8de43e2fbe7f798..e460d2e2866d26afa5ca49a4823f2a696ba2e947 100644 --- a/ppdet/modeling/losses/yolo_loss.py +++ b/ppdet/modeling/losses/yolo_loss.py @@ -21,7 +21,7 @@ import paddle.nn as nn import paddle.nn.functional as F from ppdet.core.workspace import register -from ..utils import decode_yolo, xywh2xyxy, iou_similarity +from ..bbox_utils import decode_yolo, xywh2xyxy, iou_similarity __all__ = ['YOLOv3Loss'] diff --git a/ppdet/modeling/utils/__init__.py b/ppdet/modeling/utils/__init__.py deleted file mode 100644 index e27f26a6f1254241a760af2b41a7eb26eb463ad6..0000000000000000000000000000000000000000 --- a/ppdet/modeling/utils/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from . import bbox_util - -from .bbox_util import * diff --git a/ppdet/modeling/utils/bbox_util.py b/ppdet/modeling/utils/bbox_util.py deleted file mode 100644 index 6ea3682b40ab3a48a04bdd78f64ca529dd2c9587..0000000000000000000000000000000000000000 --- a/ppdet/modeling/utils/bbox_util.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import paddle -import paddle.nn.functional as F -import math - - -def xywh2xyxy(box): - x, y, w, h = box - x1 = x - w * 0.5 - y1 = y - h * 0.5 - x2 = x + w * 0.5 - y2 = y + h * 0.5 - return [x1, y1, x2, y2] - - -def make_grid(h, w, dtype): - yv, xv = paddle.meshgrid([paddle.arange(h), paddle.arange(w)]) - return paddle.stack((xv, yv), 2).cast(dtype=dtype) - - -def decode_yolo(box, anchor, downsample_ratio): - """decode yolo box - - Args: - box (list): [x, y, w, h], all have the shape [b, na, h, w, 1] - anchor (list): anchor with the shape [na, 2] - downsample_ratio (int): downsample ratio, default 32 - scale (float): scale, default 1. - - Return: - box (list): decoded box, [x, y, w, h], all have the shape [b, na, h, w, 1] - """ - x, y, w, h = box - na, grid_h, grid_w = x.shape[1:4] - grid = make_grid(grid_h, grid_w, x.dtype).reshape((1, 1, grid_h, grid_w, 2)) - x1 = (x + grid[:, :, :, :, 0:1]) / grid_w - y1 = (y + grid[:, :, :, :, 1:2]) / grid_h - - anchor = paddle.to_tensor(anchor) - anchor = paddle.cast(anchor, x.dtype) - anchor = anchor.reshape((1, na, 1, 1, 2)) - w1 = paddle.exp(w) * anchor[:, :, :, :, 0:1] / (downsample_ratio * grid_w) - h1 = paddle.exp(h) * anchor[:, :, :, :, 1:2] / (downsample_ratio * grid_h) - - return [x1, y1, w1, h1] - - -def iou_similarity(box1, box2, eps=1e-9): - """Calculate iou of box1 and box2 - - Args: - box1 (Tensor): box with the shape [N, M1, 4] - box2 (Tensor): box with the shape [N, M2, 4] - - Return: - iou (Tensor): iou between box1 and box2 with the shape [N, M1, M2] - """ - box1 = box1.unsqueeze(2) # [N, M1, 4] -> [N, M1, 1, 4] - box2 = box2.unsqueeze(1) # [N, M2, 4] -> [N, 1, M2, 4] - px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4] - gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4] - x1y1 = paddle.maximum(px1y1, gx1y1) - x2y2 = paddle.minimum(px2y2, gx2y2) - overlap = (x2y2 - x1y1).clip(0).prod(-1) - area1 = (px2y2 - px1y1).clip(0).prod(-1) - area2 = (gx2y2 - gx1y1).clip(0).prod(-1) - union = area1 + area2 - overlap + eps - return overlap / union - - -def bbox_iou(box1, box2, giou=False, diou=False, ciou=False, eps=1e-9): - """calculate the iou of box1 and box2 - - Args: - box1 (list): [x, y, w, h], all have the shape [b, na, h, w, 1] - box2 (list): [x, y, w, h], all have the shape [b, na, h, w, 1] - giou (bool): whether use giou or not, default False - diou (bool): whether use diou or not, default False - ciou (bool): whether use ciou or not, default False - eps (float): epsilon to avoid divide by zero - - Return: - iou (Tensor): iou of box1 and box1, with the shape [b, na, h, w, 1] - """ - px1, py1, px2, py2 = box1 - gx1, gy1, gx2, gy2 = box2 - x1 = paddle.maximum(px1, gx1) - y1 = paddle.maximum(py1, gy1) - x2 = paddle.minimum(px2, gx2) - y2 = paddle.minimum(py2, gy2) - - overlap = ((x2 - x1).clip(0)) * ((y2 - y1).clip(0)) - - area1 = (px2 - px1) * (py2 - py1) - area1 = area1.clip(0) - - area2 = (gx2 - gx1) * (gy2 - gy1) - area2 = area2.clip(0) - - union = area1 + area2 - overlap + eps - iou = overlap / union - - if giou or ciou or diou: - # convex w, h - cw = paddle.maximum(px2, gx2) - paddle.minimum(px1, gx1) - ch = paddle.maximum(py2, gy2) - paddle.minimum(py1, gy1) - if giou: - c_area = cw * ch + eps - return iou - (c_area - union) / c_area - else: - # convex diagonal squared - c2 = cw**2 + ch**2 + eps - # center distance - rho2 = ((px1 + px2 - gx1 - gx2)**2 + (py1 + py2 - gy1 - gy2)**2) / 4 - if diou: - return iou - rho2 / c2 - else: - w1, h1 = px2 - px1, py2 - py1 + eps - w2, h2 = gx2 - gx1, gy2 - gy1 + eps - delta = paddle.atan(w1 / h1) - paddle.atan(w2 / h2) - v = (4 / math.pi**2) * paddle.pow(delta, 2) - alpha = v / (1 + eps - iou + v) - alpha.stop_gradient = True - return iou - (rho2 / c2 + v * alpha) - else: - return iou diff --git a/ppdet/utils/bbox_utils.py b/ppdet/utils/bbox_utils.py deleted file mode 100644 index 63c93976c1c63cf85e35a6ddc79ee32f0eb3b716..0000000000000000000000000000000000000000 --- a/ppdet/utils/bbox_utils.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from .logger import setup_logger -logger = setup_logger(__name__) - -__all__ = ["bbox_overlaps", "box_to_delta"] - - -def bbox_overlaps(boxes_1, boxes_2): - ''' - bbox_overlaps - boxes_1: x1, y, x2, y2 - boxes_2: x1, y, x2, y2 - ''' - assert boxes_1.shape[1] == 4 and boxes_2.shape[1] == 4 - - num_1 = boxes_1.shape[0] - num_2 = boxes_2.shape[0] - - x1_1 = boxes_1[:, 0:1] - y1_1 = boxes_1[:, 1:2] - x2_1 = boxes_1[:, 2:3] - y2_1 = boxes_1[:, 3:4] - area_1 = (x2_1 - x1_1 + 1) * (y2_1 - y1_1 + 1) - - x1_2 = boxes_2[:, 0].transpose() - y1_2 = boxes_2[:, 1].transpose() - x2_2 = boxes_2[:, 2].transpose() - y2_2 = boxes_2[:, 3].transpose() - area_2 = (x2_2 - x1_2 + 1) * (y2_2 - y1_2 + 1) - - xx1 = np.maximum(x1_1, x1_2) - yy1 = np.maximum(y1_1, y1_2) - xx2 = np.minimum(x2_1, x2_2) - yy2 = np.minimum(y2_1, y2_2) - - w = np.maximum(0.0, xx2 - xx1 + 1) - h = np.maximum(0.0, yy2 - yy1 + 1) - inter = w * h - - ovr = inter / (area_1 + area_2 - inter) - return ovr - - -def box_to_delta(ex_boxes, gt_boxes, weights): - """ box_to_delta """ - ex_w = ex_boxes[:, 2] - ex_boxes[:, 0] + 1 - ex_h = ex_boxes[:, 3] - ex_boxes[:, 1] + 1 - ex_ctr_x = ex_boxes[:, 0] + 0.5 * ex_w - ex_ctr_y = ex_boxes[:, 1] + 0.5 * ex_h - - gt_w = gt_boxes[:, 2] - gt_boxes[:, 0] + 1 - gt_h = gt_boxes[:, 3] - gt_boxes[:, 1] + 1 - gt_ctr_x = gt_boxes[:, 0] + 0.5 * gt_w - gt_ctr_y = gt_boxes[:, 1] + 0.5 * gt_h - - dx = (gt_ctr_x - ex_ctr_x) / ex_w / weights[0] - dy = (gt_ctr_y - ex_ctr_y) / ex_h / weights[1] - dw = (np.log(gt_w / ex_w)) / weights[2] - dh = (np.log(gt_h / ex_h)) / weights[3] - - targets = np.vstack([dx, dy, dw, dh]).transpose() - return targets diff --git a/ppdet/utils/post_process.py b/ppdet/utils/post_process.py deleted file mode 100644 index 45f9f9908af32f816c8965587d4e428b1b34ad5e..0000000000000000000000000000000000000000 --- a/ppdet/utils/post_process.py +++ /dev/null @@ -1,326 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import cv2 - -from .logger import setup_logger -logger = setup_logger(__name__) - -__all__ = ['nms'] - - -def box_flip(boxes, im_shape): - im_width = im_shape[0][1] - flipped_boxes = boxes.copy() - - flipped_boxes[:, 0::4] = im_width - boxes[:, 2::4] - 1 - flipped_boxes[:, 2::4] = im_width - boxes[:, 0::4] - 1 - return flipped_boxes - - -def nms(dets, thresh): - """Apply classic DPM-style greedy NMS.""" - if dets.shape[0] == 0: - return dets[[], :] - scores = dets[:, 0] - x1 = dets[:, 1] - y1 = dets[:, 2] - x2 = dets[:, 3] - y2 = dets[:, 4] - - areas = (x2 - x1 + 1) * (y2 - y1 + 1) - order = scores.argsort()[::-1] - - ndets = dets.shape[0] - suppressed = np.zeros((ndets), dtype=np.int) - - # nominal indices - # _i, _j - # sorted indices - # i, j - # temp variables for box i's (the box currently under consideration) - # ix1, iy1, ix2, iy2, iarea - - # variables for computing overlap with box j (lower scoring box) - # xx1, yy1, xx2, yy2 - # w, h - # inter, ovr - - for _i in range(ndets): - i = order[_i] - if suppressed[i] == 1: - continue - ix1 = x1[i] - iy1 = y1[i] - ix2 = x2[i] - iy2 = y2[i] - iarea = areas[i] - for _j in range(_i + 1, ndets): - j = order[_j] - if suppressed[j] == 1: - continue - xx1 = max(ix1, x1[j]) - yy1 = max(iy1, y1[j]) - xx2 = min(ix2, x2[j]) - yy2 = min(iy2, y2[j]) - w = max(0.0, xx2 - xx1 + 1) - h = max(0.0, yy2 - yy1 + 1) - inter = w * h - ovr = inter / (iarea + areas[j] - inter) - if ovr >= thresh: - suppressed[j] = 1 - keep = np.where(suppressed == 0)[0] - dets = dets[keep, :] - return dets - - -def soft_nms(dets, sigma, thres): - dets_final = [] - while len(dets) > 0: - maxpos = np.argmax(dets[:, 0]) - dets_final.append(dets[maxpos].copy()) - ts, tx1, ty1, tx2, ty2 = dets[maxpos] - scores = dets[:, 0] - # force remove bbox at maxpos - scores[maxpos] = -1 - x1 = dets[:, 1] - y1 = dets[:, 2] - x2 = dets[:, 3] - y2 = dets[:, 4] - areas = (x2 - x1 + 1) * (y2 - y1 + 1) - xx1 = np.maximum(tx1, x1) - yy1 = np.maximum(ty1, y1) - xx2 = np.minimum(tx2, x2) - yy2 = np.minimum(ty2, y2) - w = np.maximum(0.0, xx2 - xx1 + 1) - h = np.maximum(0.0, yy2 - yy1 + 1) - inter = w * h - ovr = inter / (areas + areas[maxpos] - inter) - weight = np.exp(-(ovr * ovr) / sigma) - scores = scores * weight - idx_keep = np.where(scores >= thres) - dets[:, 0] = scores - dets = dets[idx_keep] - dets_final = np.array(dets_final).reshape(-1, 5) - return dets_final - - -def bbox_area(box): - w = box[2] - box[0] + 1 - h = box[3] - box[1] + 1 - return w * h - - -def bbox_overlaps(x, y): - N = x.shape[0] - K = y.shape[0] - overlaps = np.zeros((N, K), dtype=np.float32) - for k in range(K): - y_area = bbox_area(y[k]) - for n in range(N): - iw = min(x[n, 2], y[k, 2]) - max(x[n, 0], y[k, 0]) + 1 - if iw > 0: - ih = min(x[n, 3], y[k, 3]) - max(x[n, 1], y[k, 1]) + 1 - if ih > 0: - x_area = bbox_area(x[n]) - ua = x_area + y_area - iw * ih - overlaps[n, k] = iw * ih / ua - return overlaps - - -def box_voting(nms_dets, dets, vote_thresh): - top_dets = nms_dets.copy() - top_boxes = nms_dets[:, 1:] - all_boxes = dets[:, 1:] - all_scores = dets[:, 0] - top_to_all_overlaps = bbox_overlaps(top_boxes, all_boxes) - for k in range(nms_dets.shape[0]): - inds_to_vote = np.where(top_to_all_overlaps[k] >= vote_thresh)[0] - boxes_to_vote = all_boxes[inds_to_vote, :] - ws = all_scores[inds_to_vote] - top_dets[k, 1:] = np.average(boxes_to_vote, axis=0, weights=ws) - - return top_dets - - -def get_nms_result(boxes, - scores, - config, - num_classes, - background_label=0, - labels=None): - has_labels = labels is not None - cls_boxes = [[] for _ in range(num_classes)] - start_idx = 1 if background_label == 0 else 0 - for j in range(start_idx, num_classes): - inds = np.where(labels == j)[0] if has_labels else np.where( - scores[:, j] > config['score_thresh'])[0] - scores_j = scores[inds] if has_labels else scores[inds, j] - boxes_j = boxes[inds, :] if has_labels else boxes[inds, j * 4:(j + 1) * - 4] - dets_j = np.hstack((scores_j[:, np.newaxis], boxes_j)).astype( - np.float32, copy=False) - if config.get('use_soft_nms', False): - nms_dets = soft_nms(dets_j, config['sigma'], config['nms_thresh']) - else: - nms_dets = nms(dets_j, config['nms_thresh']) - if config.get('enable_voting', False): - nms_dets = box_voting(nms_dets, dets_j, config['vote_thresh']) - #add labels - label = np.array([j for _ in range(len(nms_dets))]) - nms_dets = np.hstack((label[:, np.newaxis], nms_dets)).astype( - np.float32, copy=False) - cls_boxes[j] = nms_dets - # Limit to max_per_image detections **over all classes** - image_scores = np.hstack( - [cls_boxes[j][:, 1] for j in range(start_idx, num_classes)]) - if len(image_scores) > config['detections_per_im']: - image_thresh = np.sort(image_scores)[-config['detections_per_im']] - for j in range(start_idx, num_classes): - keep = np.where(cls_boxes[j][:, 1] >= image_thresh)[0] - cls_boxes[j] = cls_boxes[j][keep, :] - - im_results = np.vstack( - [cls_boxes[j] for j in range(start_idx, num_classes)]) - return im_results - - -def mstest_box_post_process(result, config, num_classes): - """ - Multi-scale Test - Only available for batch_size=1 now. - """ - post_bbox = {} - use_flip = False - ms_boxes = [] - ms_scores = [] - im_shape = result['im_shape'][0] - for k in result.keys(): - if 'bbox' in k: - boxes = result[k][0] - boxes = np.reshape(boxes, (-1, 4 * num_classes)) - scores = result['score' + k[4:]][0] - if 'flip' in k: - boxes = box_flip(boxes, im_shape) - use_flip = True - ms_boxes.append(boxes) - ms_scores.append(scores) - - ms_boxes = np.concatenate(ms_boxes) - ms_scores = np.concatenate(ms_scores) - bbox_pred = get_nms_result(ms_boxes, ms_scores, config, num_classes) - post_bbox.update({'bbox': (bbox_pred, [[len(bbox_pred)]])}) - if use_flip: - bbox = bbox_pred[:, 2:] - bbox_flip = np.append( - bbox_pred[:, :2], box_flip(bbox, im_shape), axis=1) - post_bbox.update({'bbox_flip': (bbox_flip, [[len(bbox_flip)]])}) - return post_bbox - - -def mstest_mask_post_process(result, cfg): - mask_list = [] - im_shape = result['im_shape'][0] - M = cfg.FPNRoIAlign['mask_resolution'] - for k in result.keys(): - if 'mask' in k: - masks = result[k][0] - if len(masks.shape) != 4: - masks = np.zeros((0, M, M)) - mask_list.append(masks) - continue - if 'flip' in k: - masks = masks[:, :, :, ::-1] - mask_list.append(masks) - - mask_pred = np.mean(mask_list, axis=0) - return {'mask': (mask_pred, [[len(mask_pred)]])} - - -def mask_encode(results, resolution, thresh_binarize=0.5): - import pycocotools.mask as mask_util - from ppdet.utils.coco_eval import expand_boxes - scale = (resolution + 2.0) / resolution - bboxes = results['bbox'][0] - masks = results['mask'][0] - lengths = results['mask'][1][0] - im_shapes = results['im_shape'][0] - segms = [] - if bboxes.shape == (1, 1) or bboxes is None: - return segms - if len(bboxes.tolist()) == 0: - return segms - - s = 0 - # for each sample - for i in range(len(lengths)): - num = lengths[i] - im_shape = im_shapes[i] - - bbox = bboxes[s:s + num][:, 2:] - clsid_scores = bboxes[s:s + num][:, 0:2] - mask = masks[s:s + num] - s += num - - im_h = int(im_shape[0]) - im_w = int(im_shape[1]) - expand_bbox = expand_boxes(bbox, scale) - expand_bbox = expand_bbox.astype(np.int32) - padded_mask = np.zeros( - (resolution + 2, resolution + 2), dtype=np.float32) - - for j in range(num): - xmin, ymin, xmax, ymax = expand_bbox[j].tolist() - clsid, score = clsid_scores[j].tolist() - clsid = int(clsid) - padded_mask[1:-1, 1:-1] = mask[j, clsid, :, :] - - w = xmax - xmin + 1 - h = ymax - ymin + 1 - w = np.maximum(w, 1) - h = np.maximum(h, 1) - resized_mask = cv2.resize(padded_mask, (w, h)) - resized_mask = np.array( - resized_mask > thresh_binarize, dtype=np.uint8) - im_mask = np.zeros((im_h, im_w), dtype=np.uint8) - - x0 = min(max(xmin, 0), im_w) - x1 = min(max(xmax + 1, 0), im_w) - y0 = min(max(ymin, 0), im_h) - y1 = min(max(ymax + 1, 0), im_h) - - im_mask[y0:y1, x0:x1] = resized_mask[(y0 - ymin):(y1 - ymin), ( - x0 - xmin):(x1 - xmin)] - segm = mask_util.encode( - np.array( - im_mask[:, :, np.newaxis], order='F'))[0] - segms.append(segm) - return segms - - -def corner_post_process(results, config, num_classes): - detections = results['bbox'][0] - keep_inds = (detections[:, 1] > -1) - detections = detections[keep_inds] - labels = detections[:, 0] - scores = detections[:, 1] - boxes = detections[:, 2:6] - cls_boxes = get_nms_result( - boxes, scores, config, num_classes, background_label=-1, labels=labels) - results.update({'bbox': (cls_boxes, [[len(cls_boxes)]])})