You need to sign in or sign up before continuing.
未验证 提交 42d12ea7 编写于 作者: W wangguanzhong 提交者: GitHub

[cherry-pick] support batch_size=2 in RCNN (#4788)

上级 472c288f
...@@ -91,9 +91,23 @@ class BBoxPostProcess(nn.Layer): ...@@ -91,9 +91,23 @@ class BBoxPostProcess(nn.Layer):
including labels, scores and bboxes. including labels, scores and bboxes.
""" """
if bboxes.shape[0] == 0: bboxes_list = []
bboxes = self.fake_bboxes bbox_num_list = []
bbox_num = self.fake_bbox_num id_start = 0
# add fake bbox when output is empty for each batch
for i in range(bbox_num.shape[0]):
if bbox_num[i] == 0:
bboxes_i = self.fake_bboxes
bbox_num_i = self.fake_bbox_num
id_start += 1
else:
bboxes_i = bboxes[id_start:id_start + bbox_num[i], :]
bbox_num_i = bbox_num[i]
id_start += bbox_num[i]
bboxes_list.append(bboxes_i)
bbox_num_list.append(bbox_num_i)
bboxes = paddle.concat(bboxes_list)
bbox_num = paddle.concat(bbox_num_list)
origin_shape = paddle.floor(im_shape / scale_factor + 0.5) origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
...@@ -156,6 +170,7 @@ class MaskPostProcess(object): ...@@ -156,6 +170,7 @@ class MaskPostProcess(object):
""" """
Paste the mask prediction to the original image. Paste the mask prediction to the original image.
""" """
x0, y0, x1, y1 = paddle.split(boxes, 4, axis=1) x0, y0, x1, y1 = paddle.split(boxes, 4, axis=1)
masks = paddle.unsqueeze(masks, [0, 1]) masks = paddle.unsqueeze(masks, [0, 1])
img_y = paddle.arange(0, im_h, dtype='float32') + 0.5 img_y = paddle.arange(0, im_h, dtype='float32') + 0.5
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册