From f2d1cd119a5f64302c09d0ec4e1945836c770de4 Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Thu, 27 Feb 2020 16:25:13 +0800 Subject: [PATCH] fix lod level, test=develop (#22755) --- .../operators/detection/distribute_fpn_proposals_op.cc | 5 +++++ paddle/fluid/operators/detection/multiclass_nms_op.cc | 6 ++++++ .../unittests/white_list/compile_vs_runtime_white_list.py | 3 --- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc index 2976c3ff4c8..7df03990b73 100644 --- a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc +++ b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cc @@ -41,6 +41,11 @@ class DistributeFpnProposalsOp : public framework::OperatorWithKernel { } ctx->SetOutputsDim("MultiFpnRois", outs_dims); ctx->SetOutputDim("RestoreIndex", {-1, 1}); + if (!ctx->IsRuntime()) { + for (size_t i = 0; i < num_out_rois; ++i) { + ctx->SetLoDLevel("MultiFpnRois", ctx->GetLoDLevel("FpnRois"), i); + } + } } protected: diff --git a/paddle/fluid/operators/detection/multiclass_nms_op.cc b/paddle/fluid/operators/detection/multiclass_nms_op.cc index 9cdc46b4a26..0cfb79b358b 100644 --- a/paddle/fluid/operators/detection/multiclass_nms_op.cc +++ b/paddle/fluid/operators/detection/multiclass_nms_op.cc @@ -74,6 +74,9 @@ class MultiClassNMSOp : public framework::OperatorWithKernel { } else { ctx->SetOutputDim("Out", {-1, box_dims[2] + 2}); } + if (!ctx->IsRuntime()) { + ctx->SetLoDLevel("Out", std::max(ctx->GetLoDLevel("BBoxes"), 1)); + } } protected: @@ -493,6 +496,9 @@ class MultiClassNMS2Op : public MultiClassNMSOp { } else { ctx->SetOutputDim("Index", {-1, 1}); } + if (!ctx->IsRuntime()) { + ctx->SetLoDLevel("Index", std::max(ctx->GetLoDLevel("BBoxes"), 1)); + } } }; diff --git a/python/paddle/fluid/tests/unittests/white_list/compile_vs_runtime_white_list.py b/python/paddle/fluid/tests/unittests/white_list/compile_vs_runtime_white_list.py index 39db9f5476b..ee8202aa9f3 100644 --- a/python/paddle/fluid/tests/unittests/white_list/compile_vs_runtime_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/compile_vs_runtime_white_list.py @@ -30,11 +30,8 @@ COMPILE_RUN_OP_WHITE_LIST = [ 'rpn_target_assign', \ 'retinanet_target_assign', \ 'filter_by_instag', \ - 'multiclass_nms', \ - 'multiclass_nms2', \ 'im2sequence', \ 'generate_proposal_labels', \ - 'distribute_fpn_proposals', \ 'detection_map', \ 'locality_aware_nms', \ 'var_conv_2d' -- GitLab