diff --git a/paddle/operators/sequence_reshape_op.cc b/paddle/operators/sequence_reshape_op.cc index 57cca13105537d88fe942b850cae10650d3096e2..d89a46a712c9c84a142e1e347219ed171556d761 100644 --- a/paddle/operators/sequence_reshape_op.cc +++ b/paddle/operators/sequence_reshape_op.cc @@ -30,8 +30,13 @@ class SequenceReshapeOp : public framework::OperatorWithKernel { auto x_numel = product(x_dims); PADDLE_ENFORCE_EQ(x_dims.size(), 2U, "Rank of Input(X) should be 2."); int new_dim = ctx->Attrs().Get("new_dim"); - ctx->SetOutputDim("Out", - {x_numel / new_dim, static_cast(new_dim)}); + if (ctx->IsRuntime()) { + ctx->SetOutputDim("Out", + {x_numel / new_dim, static_cast(new_dim)}); + } else { + // when compiling, the batch size is undetermined, just set to -1 + ctx->SetOutputDim("Out", {-1, static_cast(new_dim)}); + } } };