diff --git a/paddle/fluid/operators/sequence_ops/sequence_mask_op.h b/paddle/fluid/operators/sequence_ops/sequence_mask_op.h index 18acb735cecabd1e01f7821c880fd8ed5e52971f..8fceed3558b4357b7863368c18add329ea9922b3 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_mask_op.h +++ b/paddle/fluid/operators/sequence_ops/sequence_mask_op.h @@ -36,12 +36,10 @@ class SequenceMaskOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist"); PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist"); - auto maxlen = ctx->Attrs().Get("maxlen"); - if (maxlen > 0) { // We can only infershape when maxlen > 0 - auto dim = framework::vectorize2int(ctx->GetInputDim("X")); - dim.push_back(maxlen); - ctx->SetOutputDim("Y", framework::make_ddim(dim)); - } + int maxlen = ctx->Attrs().Get("maxlen"); + auto dim = framework::vectorize2int(ctx->GetInputDim("X")); + dim.push_back(maxlen > 0 ? maxlen : -1); + ctx->SetOutputDim("Y", framework::make_ddim(dim)); } };