未验证 提交 42d12ea7 编写于 作者: W wangguanzhong 提交者: GitHub

[cherry-pick] support batch_size=2 in RCNN (#4788)

上级 472c288f
......@@ -91,9 +91,23 @@ class BBoxPostProcess(nn.Layer):
including labels, scores and bboxes.
"""
if bboxes.shape[0] == 0:
bboxes = self.fake_bboxes
bbox_num = self.fake_bbox_num
bboxes_list = []
bbox_num_list = []
id_start = 0
# add fake bbox when output is empty for each batch
for i in range(bbox_num.shape[0]):
if bbox_num[i] == 0:
bboxes_i = self.fake_bboxes
bbox_num_i = self.fake_bbox_num
id_start += 1
else:
bboxes_i = bboxes[id_start:id_start + bbox_num[i], :]
bbox_num_i = bbox_num[i]
id_start += bbox_num[i]
bboxes_list.append(bboxes_i)
bbox_num_list.append(bbox_num_i)
bboxes = paddle.concat(bboxes_list)
bbox_num = paddle.concat(bbox_num_list)
origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
......@@ -156,6 +170,7 @@ class MaskPostProcess(object):
"""
Paste the mask prediction to the original image.
"""
x0, y0, x1, y1 = paddle.split(boxes, 4, axis=1)
masks = paddle.unsqueeze(masks, [0, 1])
img_y = paddle.arange(0, im_h, dtype='float32') + 0.5
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册