From 320ce6c7cb8f633021b96be30559587211bb36bc Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Thu, 2 Dec 2021 21:18:48 +0800 Subject: [PATCH] support batch_size=2 in RCNN (#4787) --- ppdet/modeling/post_process.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/ppdet/modeling/post_process.py b/ppdet/modeling/post_process.py index f485abaf6..0c4717cb6 100644 --- a/ppdet/modeling/post_process.py +++ b/ppdet/modeling/post_process.py @@ -91,9 +91,23 @@ class BBoxPostProcess(nn.Layer): including labels, scores and bboxes. """ - if bboxes.shape[0] == 0: - bboxes = self.fake_bboxes - bbox_num = self.fake_bbox_num + bboxes_list = [] + bbox_num_list = [] + id_start = 0 + # add fake bbox when output is empty for each batch + for i in range(bbox_num.shape[0]): + if bbox_num[i] == 0: + bboxes_i = self.fake_bboxes + bbox_num_i = self.fake_bbox_num + id_start += 1 + else: + bboxes_i = bboxes[id_start:id_start + bbox_num[i], :] + bbox_num_i = bbox_num[i] + id_start += bbox_num[i] + bboxes_list.append(bboxes_i) + bbox_num_list.append(bbox_num_i) + bboxes = paddle.concat(bboxes_list) + bbox_num = paddle.concat(bbox_num_list) origin_shape = paddle.floor(im_shape / scale_factor + 0.5) @@ -156,6 +170,7 @@ class MaskPostProcess(object): """ Paste the mask prediction to the original image. """ + x0, y0, x1, y1 = paddle.split(boxes, 4, axis=1) masks = paddle.unsqueeze(masks, [0, 1]) img_y = paddle.arange(0, im_h, dtype='float32') + 0.5 -- GitLab