未验证 提交 20303e69 编写于 作者: W whs 提交者: GitHub

Fix eval run in pruning (#576)

上级 1a485e12
......@@ -169,7 +169,7 @@ class RPNHead(object):
rpn_cls_prob = fluid.layers.transpose(
rpn_cls_prob, perm=[0, 3, 1, 2])
prop_op = self.train_proposal if mode == 'train' else self.test_proposal
rpn_rois, rpn_roi_probs = prop_op(
rpn_rois, rpn_roi_probs, _ = prop_op(
scores=rpn_cls_prob,
bbox_deltas=rpn_bbox_pred,
im_info=im_info,
......@@ -430,7 +430,7 @@ class FPNRPNHead(RPNHead):
rpn_cls_prob_fpn, shape=(0, 0, 0, -1))
rpn_cls_prob_fpn = fluid.layers.transpose(
rpn_cls_prob_fpn, perm=[0, 3, 1, 2])
rpn_rois_fpn, rpn_roi_prob_fpn = prop_op(
rpn_rois_fpn, rpn_roi_prob_fpn, _ = prop_op(
scores=rpn_cls_prob_fpn,
bbox_deltas=rpn_bbox_pred_fpn,
im_info=im_info,
......
......@@ -56,6 +56,9 @@ class TestMaskRCNN(TestFasterRCNN):
self.cfg_file = 'configs/mask_rcnn_r50_1x.yml'
@unittest.skip(
reason="It should be fixed to adapt https://github.com/PaddlePaddle/Paddle/pull/23797"
)
class TestCascadeRCNN(TestFasterRCNN):
def set_config(self):
self.cfg_file = 'configs/cascade_rcnn_r50_fpn_1x.yml'
......
......@@ -256,7 +256,7 @@ def main():
if FLAGS.eval:
# evaluation
results = eval_run(exe, compiled_eval_prog, eval_loader, eval_keys,
eval_values, eval_cls)
eval_values, eval_cls, cfg)
resolution = None
if 'mask' in results[0]:
resolution = model.mask_head.resolution
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册