From 756af9fff53245d264b7cc550e88e4360b9750e9 Mon Sep 17 00:00:00 2001 From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com> Date: Thu, 3 Mar 2022 14:11:42 +0800 Subject: [PATCH] modify infershape of multiclass nms (#40059) * modify infershape of multiclass nms --- paddle/fluid/operators/detection/multiclass_nms_op.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/detection/multiclass_nms_op.cc b/paddle/fluid/operators/detection/multiclass_nms_op.cc index 7927410ef37..83cf6e5fd30 100644 --- a/paddle/fluid/operators/detection/multiclass_nms_op.cc +++ b/paddle/fluid/operators/detection/multiclass_nms_op.cc @@ -93,7 +93,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel { // Here the box_dims[0] is not the real dimension of output. // It will be rewritten in the computing kernel. if (score_size == 3) { - ctx->SetOutputDim("Out", {box_dims[1], box_dims[2] + 2}); + ctx->SetOutputDim("Out", {-1, box_dims[2] + 2}); } else { ctx->SetOutputDim("Out", {-1, box_dims[2] + 2}); } @@ -545,11 +545,10 @@ class MultiClassNMS2Op : public MultiClassNMSOp { void InferShape(framework::InferShapeContext* ctx) const override { MultiClassNMSOp::InferShape(ctx); - auto box_dims = ctx->GetInputDim("BBoxes"); auto score_dims = ctx->GetInputDim("Scores"); auto score_size = score_dims.size(); if (score_size == 3) { - ctx->SetOutputDim("Index", {box_dims[1], 1}); + ctx->SetOutputDim("Index", {-1, 1}); } else { ctx->SetOutputDim("Index", {-1, 1}); } -- GitLab