未验证 提交 de808d9c 编写于 作者: W wangguanzhong 提交者: GitHub

refine ppdet file name (#2527)

* refine ppdet file name

* unify bbox_utils
上级 e3ff4417
......@@ -21,7 +21,7 @@ import sys
import numpy as np
import itertools
from ppdet.metrics.post_process import get_det_res, get_seg_res, get_solov2_segm_res
from ppdet.metrics.json_results import get_det_res, get_seg_res, get_solov2_segm_res
from ppdet.metrics.map_utils import draw_pr_curve
from ppdet.utils.logger import setup_logger
......
......@@ -7,7 +7,6 @@ from . import losses
from . import architectures
from . import post_process
from . import layers
from . import utils
from .ops import *
from .backbones import *
......@@ -18,4 +17,3 @@ from .losses import *
from .architectures import *
from .post_process import *
from .layers import *
from .utils import *
......@@ -14,6 +14,8 @@
import math
import paddle
import paddle.nn.functional as F
import math
def bbox2delta(src_boxes, tgt_boxes, weights):
......@@ -111,6 +113,16 @@ def bbox_area(boxes):
def bbox_overlaps(boxes1, boxes2):
"""
Calculate overlaps between boxes1 and boxes2
Args:
boxes1 (Tensor): boxes with shape [M, 4]
boxes2 (Tensor): boxes with shape [N, 4]
Return:
overlaps (Tensor): overlaps between boxes1 and boxes2 with shape [M, N]
"""
area1 = bbox_area(boxes1)
area2 = bbox_area(boxes2)
......@@ -126,3 +138,125 @@ def bbox_overlaps(boxes1, boxes2):
(paddle.unsqueeze(area1, 1) + area2 - inter),
paddle.zeros_like(inter))
return overlaps
def xywh2xyxy(box):
x, y, w, h = box
x1 = x - w * 0.5
y1 = y - h * 0.5
x2 = x + w * 0.5
y2 = y + h * 0.5
return [x1, y1, x2, y2]
def make_grid(h, w, dtype):
yv, xv = paddle.meshgrid([paddle.arange(h), paddle.arange(w)])
return paddle.stack((xv, yv), 2).cast(dtype=dtype)
def decode_yolo(box, anchor, downsample_ratio):
"""decode yolo box
Args:
box (list): [x, y, w, h], all have the shape [b, na, h, w, 1]
anchor (list): anchor with the shape [na, 2]
downsample_ratio (int): downsample ratio, default 32
scale (float): scale, default 1.
Return:
box (list): decoded box, [x, y, w, h], all have the shape [b, na, h, w, 1]
"""
x, y, w, h = box
na, grid_h, grid_w = x.shape[1:4]
grid = make_grid(grid_h, grid_w, x.dtype).reshape((1, 1, grid_h, grid_w, 2))
x1 = (x + grid[:, :, :, :, 0:1]) / grid_w
y1 = (y + grid[:, :, :, :, 1:2]) / grid_h
anchor = paddle.to_tensor(anchor)
anchor = paddle.cast(anchor, x.dtype)
anchor = anchor.reshape((1, na, 1, 1, 2))
w1 = paddle.exp(w) * anchor[:, :, :, :, 0:1] / (downsample_ratio * grid_w)
h1 = paddle.exp(h) * anchor[:, :, :, :, 1:2] / (downsample_ratio * grid_h)
return [x1, y1, w1, h1]
def iou_similarity(box1, box2, eps=1e-9):
"""Calculate iou of box1 and box2
Args:
box1 (Tensor): box with the shape [N, M1, 4]
box2 (Tensor): box with the shape [N, M2, 4]
Return:
iou (Tensor): iou between box1 and box2 with the shape [N, M1, M2]
"""
box1 = box1.unsqueeze(2) # [N, M1, 4] -> [N, M1, 1, 4]
box2 = box2.unsqueeze(1) # [N, M2, 4] -> [N, 1, M2, 4]
px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
x1y1 = paddle.maximum(px1y1, gx1y1)
x2y2 = paddle.minimum(px2y2, gx2y2)
overlap = (x2y2 - x1y1).clip(0).prod(-1)
area1 = (px2y2 - px1y1).clip(0).prod(-1)
area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
union = area1 + area2 - overlap + eps
return overlap / union
def bbox_iou(box1, box2, giou=False, diou=False, ciou=False, eps=1e-9):
"""calculate the iou of box1 and box2
Args:
box1 (list): [x, y, w, h], all have the shape [b, na, h, w, 1]
box2 (list): [x, y, w, h], all have the shape [b, na, h, w, 1]
giou (bool): whether use giou or not, default False
diou (bool): whether use diou or not, default False
ciou (bool): whether use ciou or not, default False
eps (float): epsilon to avoid divide by zero
Return:
iou (Tensor): iou of box1 and box1, with the shape [b, na, h, w, 1]
"""
px1, py1, px2, py2 = box1
gx1, gy1, gx2, gy2 = box2
x1 = paddle.maximum(px1, gx1)
y1 = paddle.maximum(py1, gy1)
x2 = paddle.minimum(px2, gx2)
y2 = paddle.minimum(py2, gy2)
overlap = ((x2 - x1).clip(0)) * ((y2 - y1).clip(0))
area1 = (px2 - px1) * (py2 - py1)
area1 = area1.clip(0)
area2 = (gx2 - gx1) * (gy2 - gy1)
area2 = area2.clip(0)
union = area1 + area2 - overlap + eps
iou = overlap / union
if giou or ciou or diou:
# convex w, h
cw = paddle.maximum(px2, gx2) - paddle.minimum(px1, gx1)
ch = paddle.maximum(py2, gy2) - paddle.minimum(py1, gy1)
if giou:
c_area = cw * ch + eps
return iou - (c_area - union) / c_area
else:
# convex diagonal squared
c2 = cw**2 + ch**2 + eps
# center distance
rho2 = ((px1 + px2 - gx1 - gx2)**2 + (py1 + py2 - gy1 - gy2)**2) / 4
if diou:
return iou - rho2 / c2
else:
w1, h1 = px2 - px1, py2 - py1 + eps
w2, h2 = gx2 - gx1, gy2 - gy1 + eps
delta = paddle.atan(w1 / h1) - paddle.atan(w2 / h2)
v = (4 / math.pi**2) * paddle.pow(delta, 2)
alpha = v / (1 + eps - iou + v)
alpha.stop_gradient = True
return iou - (rho2 / c2 + v * alpha)
else:
return iou
......@@ -20,7 +20,7 @@ import paddle
import paddle.nn.functional as F
from ppdet.core.workspace import register, serializable
from .iou_loss import IouLoss
from ..utils import xywh2xyxy, bbox_iou, decode_yolo
from ..bbox_utils import xywh2xyxy, bbox_iou
@register
......
......@@ -19,7 +19,7 @@ from __future__ import print_function
import paddle
import paddle.nn.functional as F
from ppdet.core.workspace import register, serializable
from ..utils import xywh2xyxy, bbox_iou, decode_yolo
from ..bbox_utils import xywh2xyxy, bbox_iou
__all__ = ['IouLoss', 'GIoULoss']
......
......@@ -21,7 +21,7 @@ import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ..utils import decode_yolo, xywh2xyxy, iou_similarity
from ..bbox_utils import decode_yolo, xywh2xyxy, iou_similarity
__all__ = ['YOLOv3Loss']
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import bbox_util
from .bbox_util import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn.functional as F
import math
def xywh2xyxy(box):
x, y, w, h = box
x1 = x - w * 0.5
y1 = y - h * 0.5
x2 = x + w * 0.5
y2 = y + h * 0.5
return [x1, y1, x2, y2]
def make_grid(h, w, dtype):
yv, xv = paddle.meshgrid([paddle.arange(h), paddle.arange(w)])
return paddle.stack((xv, yv), 2).cast(dtype=dtype)
def decode_yolo(box, anchor, downsample_ratio):
"""decode yolo box
Args:
box (list): [x, y, w, h], all have the shape [b, na, h, w, 1]
anchor (list): anchor with the shape [na, 2]
downsample_ratio (int): downsample ratio, default 32
scale (float): scale, default 1.
Return:
box (list): decoded box, [x, y, w, h], all have the shape [b, na, h, w, 1]
"""
x, y, w, h = box
na, grid_h, grid_w = x.shape[1:4]
grid = make_grid(grid_h, grid_w, x.dtype).reshape((1, 1, grid_h, grid_w, 2))
x1 = (x + grid[:, :, :, :, 0:1]) / grid_w
y1 = (y + grid[:, :, :, :, 1:2]) / grid_h
anchor = paddle.to_tensor(anchor)
anchor = paddle.cast(anchor, x.dtype)
anchor = anchor.reshape((1, na, 1, 1, 2))
w1 = paddle.exp(w) * anchor[:, :, :, :, 0:1] / (downsample_ratio * grid_w)
h1 = paddle.exp(h) * anchor[:, :, :, :, 1:2] / (downsample_ratio * grid_h)
return [x1, y1, w1, h1]
def iou_similarity(box1, box2, eps=1e-9):
"""Calculate iou of box1 and box2
Args:
box1 (Tensor): box with the shape [N, M1, 4]
box2 (Tensor): box with the shape [N, M2, 4]
Return:
iou (Tensor): iou between box1 and box2 with the shape [N, M1, M2]
"""
box1 = box1.unsqueeze(2) # [N, M1, 4] -> [N, M1, 1, 4]
box2 = box2.unsqueeze(1) # [N, M2, 4] -> [N, 1, M2, 4]
px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
x1y1 = paddle.maximum(px1y1, gx1y1)
x2y2 = paddle.minimum(px2y2, gx2y2)
overlap = (x2y2 - x1y1).clip(0).prod(-1)
area1 = (px2y2 - px1y1).clip(0).prod(-1)
area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
union = area1 + area2 - overlap + eps
return overlap / union
def bbox_iou(box1, box2, giou=False, diou=False, ciou=False, eps=1e-9):
"""calculate the iou of box1 and box2
Args:
box1 (list): [x, y, w, h], all have the shape [b, na, h, w, 1]
box2 (list): [x, y, w, h], all have the shape [b, na, h, w, 1]
giou (bool): whether use giou or not, default False
diou (bool): whether use diou or not, default False
ciou (bool): whether use ciou or not, default False
eps (float): epsilon to avoid divide by zero
Return:
iou (Tensor): iou of box1 and box1, with the shape [b, na, h, w, 1]
"""
px1, py1, px2, py2 = box1
gx1, gy1, gx2, gy2 = box2
x1 = paddle.maximum(px1, gx1)
y1 = paddle.maximum(py1, gy1)
x2 = paddle.minimum(px2, gx2)
y2 = paddle.minimum(py2, gy2)
overlap = ((x2 - x1).clip(0)) * ((y2 - y1).clip(0))
area1 = (px2 - px1) * (py2 - py1)
area1 = area1.clip(0)
area2 = (gx2 - gx1) * (gy2 - gy1)
area2 = area2.clip(0)
union = area1 + area2 - overlap + eps
iou = overlap / union
if giou or ciou or diou:
# convex w, h
cw = paddle.maximum(px2, gx2) - paddle.minimum(px1, gx1)
ch = paddle.maximum(py2, gy2) - paddle.minimum(py1, gy1)
if giou:
c_area = cw * ch + eps
return iou - (c_area - union) / c_area
else:
# convex diagonal squared
c2 = cw**2 + ch**2 + eps
# center distance
rho2 = ((px1 + px2 - gx1 - gx2)**2 + (py1 + py2 - gy1 - gy2)**2) / 4
if diou:
return iou - rho2 / c2
else:
w1, h1 = px2 - px1, py2 - py1 + eps
w2, h2 = gx2 - gx1, gy2 - gy1 + eps
delta = paddle.atan(w1 / h1) - paddle.atan(w2 / h2)
v = (4 / math.pi**2) * paddle.pow(delta, 2)
alpha = v / (1 + eps - iou + v)
alpha.stop_gradient = True
return iou - (rho2 / c2 + v * alpha)
else:
return iou
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from .logger import setup_logger
logger = setup_logger(__name__)
__all__ = ["bbox_overlaps", "box_to_delta"]
def bbox_overlaps(boxes_1, boxes_2):
'''
bbox_overlaps
boxes_1: x1, y, x2, y2
boxes_2: x1, y, x2, y2
'''
assert boxes_1.shape[1] == 4 and boxes_2.shape[1] == 4
num_1 = boxes_1.shape[0]
num_2 = boxes_2.shape[0]
x1_1 = boxes_1[:, 0:1]
y1_1 = boxes_1[:, 1:2]
x2_1 = boxes_1[:, 2:3]
y2_1 = boxes_1[:, 3:4]
area_1 = (x2_1 - x1_1 + 1) * (y2_1 - y1_1 + 1)
x1_2 = boxes_2[:, 0].transpose()
y1_2 = boxes_2[:, 1].transpose()
x2_2 = boxes_2[:, 2].transpose()
y2_2 = boxes_2[:, 3].transpose()
area_2 = (x2_2 - x1_2 + 1) * (y2_2 - y1_2 + 1)
xx1 = np.maximum(x1_1, x1_2)
yy1 = np.maximum(y1_1, y1_2)
xx2 = np.minimum(x2_1, x2_2)
yy2 = np.minimum(y2_1, y2_2)
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (area_1 + area_2 - inter)
return ovr
def box_to_delta(ex_boxes, gt_boxes, weights):
""" box_to_delta """
ex_w = ex_boxes[:, 2] - ex_boxes[:, 0] + 1
ex_h = ex_boxes[:, 3] - ex_boxes[:, 1] + 1
ex_ctr_x = ex_boxes[:, 0] + 0.5 * ex_w
ex_ctr_y = ex_boxes[:, 1] + 0.5 * ex_h
gt_w = gt_boxes[:, 2] - gt_boxes[:, 0] + 1
gt_h = gt_boxes[:, 3] - gt_boxes[:, 1] + 1
gt_ctr_x = gt_boxes[:, 0] + 0.5 * gt_w
gt_ctr_y = gt_boxes[:, 1] + 0.5 * gt_h
dx = (gt_ctr_x - ex_ctr_x) / ex_w / weights[0]
dy = (gt_ctr_y - ex_ctr_y) / ex_h / weights[1]
dw = (np.log(gt_w / ex_w)) / weights[2]
dh = (np.log(gt_h / ex_h)) / weights[3]
targets = np.vstack([dx, dy, dw, dh]).transpose()
return targets
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import cv2
from .logger import setup_logger
logger = setup_logger(__name__)
__all__ = ['nms']
def box_flip(boxes, im_shape):
im_width = im_shape[0][1]
flipped_boxes = boxes.copy()
flipped_boxes[:, 0::4] = im_width - boxes[:, 2::4] - 1
flipped_boxes[:, 2::4] = im_width - boxes[:, 0::4] - 1
return flipped_boxes
def nms(dets, thresh):
"""Apply classic DPM-style greedy NMS."""
if dets.shape[0] == 0:
return dets[[], :]
scores = dets[:, 0]
x1 = dets[:, 1]
y1 = dets[:, 2]
x2 = dets[:, 3]
y2 = dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
order = scores.argsort()[::-1]
ndets = dets.shape[0]
suppressed = np.zeros((ndets), dtype=np.int)
# nominal indices
# _i, _j
# sorted indices
# i, j
# temp variables for box i's (the box currently under consideration)
# ix1, iy1, ix2, iy2, iarea
# variables for computing overlap with box j (lower scoring box)
# xx1, yy1, xx2, yy2
# w, h
# inter, ovr
for _i in range(ndets):
i = order[_i]
if suppressed[i] == 1:
continue
ix1 = x1[i]
iy1 = y1[i]
ix2 = x2[i]
iy2 = y2[i]
iarea = areas[i]
for _j in range(_i + 1, ndets):
j = order[_j]
if suppressed[j] == 1:
continue
xx1 = max(ix1, x1[j])
yy1 = max(iy1, y1[j])
xx2 = min(ix2, x2[j])
yy2 = min(iy2, y2[j])
w = max(0.0, xx2 - xx1 + 1)
h = max(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (iarea + areas[j] - inter)
if ovr >= thresh:
suppressed[j] = 1
keep = np.where(suppressed == 0)[0]
dets = dets[keep, :]
return dets
def soft_nms(dets, sigma, thres):
dets_final = []
while len(dets) > 0:
maxpos = np.argmax(dets[:, 0])
dets_final.append(dets[maxpos].copy())
ts, tx1, ty1, tx2, ty2 = dets[maxpos]
scores = dets[:, 0]
# force remove bbox at maxpos
scores[maxpos] = -1
x1 = dets[:, 1]
y1 = dets[:, 2]
x2 = dets[:, 3]
y2 = dets[:, 4]
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
xx1 = np.maximum(tx1, x1)
yy1 = np.maximum(ty1, y1)
xx2 = np.minimum(tx2, x2)
yy2 = np.minimum(ty2, y2)
w = np.maximum(0.0, xx2 - xx1 + 1)
h = np.maximum(0.0, yy2 - yy1 + 1)
inter = w * h
ovr = inter / (areas + areas[maxpos] - inter)
weight = np.exp(-(ovr * ovr) / sigma)
scores = scores * weight
idx_keep = np.where(scores >= thres)
dets[:, 0] = scores
dets = dets[idx_keep]
dets_final = np.array(dets_final).reshape(-1, 5)
return dets_final
def bbox_area(box):
w = box[2] - box[0] + 1
h = box[3] - box[1] + 1
return w * h
def bbox_overlaps(x, y):
N = x.shape[0]
K = y.shape[0]
overlaps = np.zeros((N, K), dtype=np.float32)
for k in range(K):
y_area = bbox_area(y[k])
for n in range(N):
iw = min(x[n, 2], y[k, 2]) - max(x[n, 0], y[k, 0]) + 1
if iw > 0:
ih = min(x[n, 3], y[k, 3]) - max(x[n, 1], y[k, 1]) + 1
if ih > 0:
x_area = bbox_area(x[n])
ua = x_area + y_area - iw * ih
overlaps[n, k] = iw * ih / ua
return overlaps
def box_voting(nms_dets, dets, vote_thresh):
top_dets = nms_dets.copy()
top_boxes = nms_dets[:, 1:]
all_boxes = dets[:, 1:]
all_scores = dets[:, 0]
top_to_all_overlaps = bbox_overlaps(top_boxes, all_boxes)
for k in range(nms_dets.shape[0]):
inds_to_vote = np.where(top_to_all_overlaps[k] >= vote_thresh)[0]
boxes_to_vote = all_boxes[inds_to_vote, :]
ws = all_scores[inds_to_vote]
top_dets[k, 1:] = np.average(boxes_to_vote, axis=0, weights=ws)
return top_dets
def get_nms_result(boxes,
scores,
config,
num_classes,
background_label=0,
labels=None):
has_labels = labels is not None
cls_boxes = [[] for _ in range(num_classes)]
start_idx = 1 if background_label == 0 else 0
for j in range(start_idx, num_classes):
inds = np.where(labels == j)[0] if has_labels else np.where(
scores[:, j] > config['score_thresh'])[0]
scores_j = scores[inds] if has_labels else scores[inds, j]
boxes_j = boxes[inds, :] if has_labels else boxes[inds, j * 4:(j + 1) *
4]
dets_j = np.hstack((scores_j[:, np.newaxis], boxes_j)).astype(
np.float32, copy=False)
if config.get('use_soft_nms', False):
nms_dets = soft_nms(dets_j, config['sigma'], config['nms_thresh'])
else:
nms_dets = nms(dets_j, config['nms_thresh'])
if config.get('enable_voting', False):
nms_dets = box_voting(nms_dets, dets_j, config['vote_thresh'])
#add labels
label = np.array([j for _ in range(len(nms_dets))])
nms_dets = np.hstack((label[:, np.newaxis], nms_dets)).astype(
np.float32, copy=False)
cls_boxes[j] = nms_dets
# Limit to max_per_image detections **over all classes**
image_scores = np.hstack(
[cls_boxes[j][:, 1] for j in range(start_idx, num_classes)])
if len(image_scores) > config['detections_per_im']:
image_thresh = np.sort(image_scores)[-config['detections_per_im']]
for j in range(start_idx, num_classes):
keep = np.where(cls_boxes[j][:, 1] >= image_thresh)[0]
cls_boxes[j] = cls_boxes[j][keep, :]
im_results = np.vstack(
[cls_boxes[j] for j in range(start_idx, num_classes)])
return im_results
def mstest_box_post_process(result, config, num_classes):
"""
Multi-scale Test
Only available for batch_size=1 now.
"""
post_bbox = {}
use_flip = False
ms_boxes = []
ms_scores = []
im_shape = result['im_shape'][0]
for k in result.keys():
if 'bbox' in k:
boxes = result[k][0]
boxes = np.reshape(boxes, (-1, 4 * num_classes))
scores = result['score' + k[4:]][0]
if 'flip' in k:
boxes = box_flip(boxes, im_shape)
use_flip = True
ms_boxes.append(boxes)
ms_scores.append(scores)
ms_boxes = np.concatenate(ms_boxes)
ms_scores = np.concatenate(ms_scores)
bbox_pred = get_nms_result(ms_boxes, ms_scores, config, num_classes)
post_bbox.update({'bbox': (bbox_pred, [[len(bbox_pred)]])})
if use_flip:
bbox = bbox_pred[:, 2:]
bbox_flip = np.append(
bbox_pred[:, :2], box_flip(bbox, im_shape), axis=1)
post_bbox.update({'bbox_flip': (bbox_flip, [[len(bbox_flip)]])})
return post_bbox
def mstest_mask_post_process(result, cfg):
mask_list = []
im_shape = result['im_shape'][0]
M = cfg.FPNRoIAlign['mask_resolution']
for k in result.keys():
if 'mask' in k:
masks = result[k][0]
if len(masks.shape) != 4:
masks = np.zeros((0, M, M))
mask_list.append(masks)
continue
if 'flip' in k:
masks = masks[:, :, :, ::-1]
mask_list.append(masks)
mask_pred = np.mean(mask_list, axis=0)
return {'mask': (mask_pred, [[len(mask_pred)]])}
def mask_encode(results, resolution, thresh_binarize=0.5):
import pycocotools.mask as mask_util
from ppdet.utils.coco_eval import expand_boxes
scale = (resolution + 2.0) / resolution
bboxes = results['bbox'][0]
masks = results['mask'][0]
lengths = results['mask'][1][0]
im_shapes = results['im_shape'][0]
segms = []
if bboxes.shape == (1, 1) or bboxes is None:
return segms
if len(bboxes.tolist()) == 0:
return segms
s = 0
# for each sample
for i in range(len(lengths)):
num = lengths[i]
im_shape = im_shapes[i]
bbox = bboxes[s:s + num][:, 2:]
clsid_scores = bboxes[s:s + num][:, 0:2]
mask = masks[s:s + num]
s += num
im_h = int(im_shape[0])
im_w = int(im_shape[1])
expand_bbox = expand_boxes(bbox, scale)
expand_bbox = expand_bbox.astype(np.int32)
padded_mask = np.zeros(
(resolution + 2, resolution + 2), dtype=np.float32)
for j in range(num):
xmin, ymin, xmax, ymax = expand_bbox[j].tolist()
clsid, score = clsid_scores[j].tolist()
clsid = int(clsid)
padded_mask[1:-1, 1:-1] = mask[j, clsid, :, :]
w = xmax - xmin + 1
h = ymax - ymin + 1
w = np.maximum(w, 1)
h = np.maximum(h, 1)
resized_mask = cv2.resize(padded_mask, (w, h))
resized_mask = np.array(
resized_mask > thresh_binarize, dtype=np.uint8)
im_mask = np.zeros((im_h, im_w), dtype=np.uint8)
x0 = min(max(xmin, 0), im_w)
x1 = min(max(xmax + 1, 0), im_w)
y0 = min(max(ymin, 0), im_h)
y1 = min(max(ymax + 1, 0), im_h)
im_mask[y0:y1, x0:x1] = resized_mask[(y0 - ymin):(y1 - ymin), (
x0 - xmin):(x1 - xmin)]
segm = mask_util.encode(
np.array(
im_mask[:, :, np.newaxis], order='F'))[0]
segms.append(segm)
return segms
def corner_post_process(results, config, num_classes):
detections = results['bbox'][0]
keep_inds = (detections[:, 1] > -1)
detections = detections[keep_inds]
labels = detections[:, 0]
scores = detections[:, 1]
boxes = detections[:, 2:6]
cls_boxes = get_nms_result(
boxes, scores, config, num_classes, background_label=-1, labels=labels)
results.update({'bbox': (cls_boxes, [[len(cls_boxes)]])})
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册