未验证 提交 a7289a08 编写于 作者: W wangguanzhong 提交者: GitHub

fix no label training in bs2 (#3891)

上级 099d1276
...@@ -215,7 +215,8 @@ def generate_proposal_target(rpn_rois, ...@@ -215,7 +215,8 @@ def generate_proposal_target(rpn_rois,
if gt_bbox.shape[0] > 0: if gt_bbox.shape[0] > 0:
sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind) sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind)
else: else:
sampled_bbox = paddle.zeros([0, 4], dtype='float32') num = rois_per_image.shape[0]
sampled_bbox = paddle.zeros([num, 4], dtype='float32')
rois_per_image.stop_gradient = True rois_per_image.stop_gradient = True
sampled_gt_ind.stop_gradient = True sampled_gt_ind.stop_gradient = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册