diff --git a/paddle/fluid/operators/detection/multiclass_nms_op.cc b/paddle/fluid/operators/detection/multiclass_nms_op.cc index 7927410ef37862499aadf61d6e04c45af157f347..83cf6e5fd30f6bcad4870d1ebd18a50e21518dfe 100644 --- a/paddle/fluid/operators/detection/multiclass_nms_op.cc +++ b/paddle/fluid/operators/detection/multiclass_nms_op.cc @@ -93,7 +93,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel { // Here the box_dims[0] is not the real dimension of output. // It will be rewritten in the computing kernel. if (score_size == 3) { - ctx->SetOutputDim("Out", {box_dims[1], box_dims[2] + 2}); + ctx->SetOutputDim("Out", {-1, box_dims[2] + 2}); } else { ctx->SetOutputDim("Out", {-1, box_dims[2] + 2}); } @@ -545,11 +545,10 @@ class MultiClassNMS2Op : public MultiClassNMSOp { void InferShape(framework::InferShapeContext* ctx) const override { MultiClassNMSOp::InferShape(ctx); - auto box_dims = ctx->GetInputDim("BBoxes"); auto score_dims = ctx->GetInputDim("Scores"); auto score_size = score_dims.size(); if (score_size == 3) { - ctx->SetOutputDim("Index", {box_dims[1], 1}); + ctx->SetOutputDim("Index", {-1, 1}); } else { ctx->SetOutputDim("Index", {-1, 1}); }