diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc
index 912ec79910301b67bc520b1aa78d3fa1fd165d1f..aecd3d430231855fa29cf7716eb636cdb28182ce 100644
--- a/paddle/fluid/operators/attention_lstm_op.cc
+++ b/paddle/fluid/operators/attention_lstm_op.cc
@@ -64,12 +64,19 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
 
   auto c_dims = ctx->GetInputDim("C0");
   PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
-  PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
+  if (ctx->IsRuntime()) {
+    PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D);
+  }
+
   if (ctx->HasInput("H0")) {
     auto h_dims = ctx->GetInputDim("H0");
-    PADDLE_ENFORCE(h_dims == c_dims,
-                   "The dimension of Input(H0) and Input(C0) "
-                   "should be the same.");
+    PADDLE_ENFORCE_EQ(h_dims.size(), 2UL, "Input(H0)'s rank must be 2.");
+    if (ctx->IsRuntime() ||
+        (framework::product(c_dims) > 0 && framework::product(h_dims) > 0)) {
+      PADDLE_ENFORCE(h_dims == c_dims,
+                     "The dimension of Input(H0) and Input(C0) "
+                     "should be the same.");
+    }
   }
 
   auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
@@ -79,6 +86,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
                     "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
   PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
                     "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
+
   if (ctx->HasInput("AttentionBias")) {
     auto atten_b_dims = ctx->GetInputDim("AttentionBias");
     PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
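
Note on the change (not part of the diff): the patch gates exact-value shape checks behind ctx->IsRuntime(), because during compile-time shape inference a dimension such as the batch size may still be -1 (unknown), and an unconditional PADDLE_ENFORCE_EQ would fail on it. framework::product(dims) > 0 is the usual test that every dimension is already known, since any -1 entry makes the product non-positive. What follows is a minimal standalone sketch of that pattern, not Paddle's real API: ShapeCtx, product(), and plain assert are simplified stand-ins for InferShapeContext, framework::product, and the PADDLE_ENFORCE* macros.

// Standalone sketch of the runtime-gated shape check pattern from the diff
// above. ShapeCtx, product(), and assert are stand-ins for Paddle's
// InferShapeContext, framework::product, and PADDLE_ENFORCE* (assumption:
// these simplified names are illustrative only, NOT the real API).
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

using Dims = std::vector<int64_t>;

// Mirrors framework::product: the result is non-positive whenever any
// dimension is still unknown (-1), so "product > 0" means "fully known".
int64_t product(const Dims& d) {
  return std::accumulate(d.begin(), d.end(), int64_t{1},
                         [](int64_t a, int64_t b) { return a * b; });
}

struct ShapeCtx {
  bool is_runtime;  // true once real input shapes have been bound
  bool IsRuntime() const { return is_runtime; }
};

// Checks C0/H0 the way the patched InferShape does: rank checks always run,
// but value comparisons are skipped while dimensions may still be -1.
void CheckCellAndHidden(const ShapeCtx& ctx, const Dims& c_dims,
                        const Dims& h_dims, int64_t D) {
  assert(c_dims.size() == 2 && "Input(C0)'s rank must be 2.");
  if (ctx.IsRuntime()) {
    assert(c_dims[1] == D && "C0 dims should be N x D.");
  }
  assert(h_dims.size() == 2 && "Input(H0)'s rank must be 2.");
  // Compare H0 against C0 only when both shapes are fully known.
  if (ctx.IsRuntime() || (product(c_dims) > 0 && product(h_dims) > 0)) {
    assert(h_dims == c_dims &&
           "The dimension of Input(H0) and Input(C0) should be the same.");
  }
}

int main() {
  // Compile time: batch dim unknown (-1); the N x D check must not fire.
  CheckCellAndHidden({false}, {-1, 32}, {-1, 32}, 32);
  // Runtime: shapes are concrete, so every check is enforced.
  CheckCellAndHidden({true}, {8, 32}, {8, 32}, 32);
  return 0;
}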