未验证 提交 756af9ff 编写于 作者: W wangxinxin08 提交者: GitHub

modify infershape of multiclass nms (#40059)

* modify infershape of multiclass nms
上级 831b69d9
...@@ -93,7 +93,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel { ...@@ -93,7 +93,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
// Here the box_dims[0] is not the real dimension of output. // Here the box_dims[0] is not the real dimension of output.
// It will be rewritten in the computing kernel. // It will be rewritten in the computing kernel.
if (score_size == 3) { if (score_size == 3) {
ctx->SetOutputDim("Out", {box_dims[1], box_dims[2] + 2}); ctx->SetOutputDim("Out", {-1, box_dims[2] + 2});
} else { } else {
ctx->SetOutputDim("Out", {-1, box_dims[2] + 2}); ctx->SetOutputDim("Out", {-1, box_dims[2] + 2});
} }
...@@ -545,11 +545,10 @@ class MultiClassNMS2Op : public MultiClassNMSOp { ...@@ -545,11 +545,10 @@ class MultiClassNMS2Op : public MultiClassNMSOp {
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
MultiClassNMSOp::InferShape(ctx); MultiClassNMSOp::InferShape(ctx);
auto box_dims = ctx->GetInputDim("BBoxes");
auto score_dims = ctx->GetInputDim("Scores"); auto score_dims = ctx->GetInputDim("Scores");
auto score_size = score_dims.size(); auto score_size = score_dims.size();
if (score_size == 3) { if (score_size == 3) {
ctx->SetOutputDim("Index", {box_dims[1], 1}); ctx->SetOutputDim("Index", {-1, 1});
} else { } else {
ctx->SetOutputDim("Index", {-1, 1}); ctx->SetOutputDim("Index", {-1, 1});
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册