未验证 提交 ff423730 编写于 作者: Z Zeng Jinle 提交者: GitHub

Merge pull request #14720 from sneaxiy/fix_seq_mask_op_infershape

Fix sequence_mask_op InferShape
...@@ -36,12 +36,10 @@ class SequenceMaskOp : public framework::OperatorWithKernel { ...@@ -36,12 +36,10 @@ class SequenceMaskOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist"); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist");
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist"); PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist");
auto maxlen = ctx->Attrs().Get<int>("maxlen"); int maxlen = ctx->Attrs().Get<int>("maxlen");
if (maxlen > 0) { // We can only infershape when maxlen > 0 auto dim = framework::vectorize2int(ctx->GetInputDim("X"));
auto dim = framework::vectorize2int(ctx->GetInputDim("X")); dim.push_back(maxlen > 0 ? maxlen : -1);
dim.push_back(maxlen); ctx->SetOutputDim("Y", framework::make_ddim(dim));
ctx->SetOutputDim("Y", framework::make_ddim(dim));
}
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册