From a4b6d5cb72185742916766ade2c8530180029f4d Mon Sep 17 00:00:00 2001 From: Feng Ni Date: Mon, 15 Feb 2021 15:04:35 +0800 Subject: [PATCH] fix RCNNBox, test=dygraph (#2216) --- dygraph/ppdet/modeling/layers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dygraph/ppdet/modeling/layers.py b/dygraph/ppdet/modeling/layers.py index 967d6283b..04dcbd3ff 100644 --- a/dygraph/ppdet/modeling/layers.py +++ b/dygraph/ppdet/modeling/layers.py @@ -318,7 +318,8 @@ class RCNNBox(object): origin_shape = paddle.concat(origin_shape_list) - # [N, C*4] + # bbox_pred.shape: [N, C*4] + # C=num_classes in faster/mask rcnn(bbox_head), C=1 in cascade rcnn(cascade_head) bbox = paddle.concat(roi) if bbox.shape[0] == 0: bbox = paddle.zeros([0, bbox_pred.shape[1]], dtype='float32') @@ -326,10 +327,9 @@ class RCNNBox(object): bbox = delta2bbox(bbox_pred, bbox, self.prior_box_var) scores = cls_prob[:, :-1] - # [N*C, 4] - - bbox_num_class = bbox.shape[1] // 4 - bbox = paddle.reshape(bbox, [-1, bbox_num_class, 4]) + # bbox.shape: [N, C, 4] + # bbox.shape[1] must be equal to scores.shape[1] + bbox_num_class = bbox.shape[1] if bbox_num_class == 1: bbox = paddle.tile(bbox, [1, self.num_classes, 1]) -- GitLab