未验证 提交 756af9ff 编写于 作者: W wangxinxin08 提交者: GitHub

modify infershape of multiclass nms (#40059)

* modify infershape of multiclass nms
上级 831b69d9
......@@ -93,7 +93,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
// Here the box_dims[0] is not the real dimension of output.
// It will be rewritten in the computing kernel.
if (score_size == 3) {
ctx->SetOutputDim("Out", {box_dims[1], box_dims[2] + 2});
ctx->SetOutputDim("Out", {-1, box_dims[2] + 2});
} else {
ctx->SetOutputDim("Out", {-1, box_dims[2] + 2});
}
......@@ -545,11 +545,10 @@ class MultiClassNMS2Op : public MultiClassNMSOp {
void InferShape(framework::InferShapeContext* ctx) const override {
MultiClassNMSOp::InferShape(ctx);
auto box_dims = ctx->GetInputDim("BBoxes");
auto score_dims = ctx->GetInputDim("Scores");
auto score_size = score_dims.size();
if (score_size == 3) {
ctx->SetOutputDim("Index", {box_dims[1], 1});
ctx->SetOutputDim("Index", {-1, 1});
} else {
ctx->SetOutputDim("Index", {-1, 1});
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册