未验证 提交 9d4722c8 编写于 作者: Z Zhang Zheng 提交者: GitHub

fix masked_select infer shape (#33167)

上级 47774d9c
...@@ -26,8 +26,9 @@ class MaskedSelectOp : public framework::OperatorWithKernel { ...@@ -26,8 +26,9 @@ class MaskedSelectOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "Input", "MaskedSelect"); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "Input", "MaskedSelect");
OP_INOUT_CHECK(ctx->HasInput("Mask"), "Input", "Mask", "MaskedSelect"); OP_INOUT_CHECK(ctx->HasInput("Mask"), "Input", "Mask", "MaskedSelect");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Out", "MaskedSelect"); OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Out", "MaskedSelect");
framework::DDim output_dims(ctx->GetInputDim("X"));
ctx->SetOutputDim("Y", output_dims); // output will only be a 1-D Tensor
ctx->SetOutputDim("Y", framework::make_ddim({-1}));
ctx->ShareLoD("X", /*->*/ "Y"); ctx->ShareLoD("X", /*->*/ "Y");
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册