未验证 提交 28e5a3ab 编写于 作者: F Feng Ni 提交者: GitHub

fix RCNNBox, test=dygraph (#2215)

上级 839660fe
......@@ -318,7 +318,8 @@ class RCNNBox(object):
origin_shape = paddle.concat(origin_shape_list)
# [N, C*4]
# bbox_pred.shape: [N, C*4]
# C=num_classes in faster/mask rcnn(bbox_head), C=1 in cascade rcnn(cascade_head)
bbox = paddle.concat(roi)
if bbox.shape[0] == 0:
bbox = paddle.zeros([0, bbox_pred.shape[1]], dtype='float32')
......@@ -326,10 +327,9 @@ class RCNNBox(object):
bbox = delta2bbox(bbox_pred, bbox, self.prior_box_var)
scores = cls_prob[:, :-1]
# [N*C, 4]
bbox_num_class = bbox.shape[1] // 4
bbox = paddle.reshape(bbox, [-1, bbox_num_class, 4])
# bbox.shape: [N, C, 4]
# bbox.shape[1] must be equal to scores.shape[1]
bbox_num_class = bbox.shape[1]
if bbox_num_class == 1:
bbox = paddle.tile(bbox, [1, self.num_classes, 1])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册