From 9b30c51e3e9b42189e09112237136316be223f52 Mon Sep 17 00:00:00 2001 From: Hongyu Liu <43953930+phlrain@users.noreply.github.com> Date: Wed, 17 Apr 2019 09:54:26 +0800 Subject: [PATCH] Merge pull request #16861 from tensor-tang/refine/infershape separate runtime infershape --- paddle/fluid/operators/attention_lstm_op.cc | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index 912ec7991..aecd3d430 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -64,12 +64,19 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { auto c_dims = ctx->GetInputDim("C0"); PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2."); - PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D); + } + if (ctx->HasInput("H0")) { auto h_dims = ctx->GetInputDim("H0"); - PADDLE_ENFORCE(h_dims == c_dims, - "The dimension of Input(H0) and Input(C0) " - "should be the same."); + PADDLE_ENFORCE_EQ(h_dims.size(), 2UL, "Input(H0)'s rank must be 2."); + if (ctx->IsRuntime() || + (framework::product(c_dims) > 0 && framework::product(h_dims) > 0)) { + PADDLE_ENFORCE(h_dims == c_dims, + "The dimension of Input(H0) and Input(C0) " + "should be the same."); + } } auto atten_w_dims = ctx->GetInputDim("AttentionWeight"); @@ -79,6 +86,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { "AttentionWeight shapes must be (%d + %d) * 1.", M, D); PADDLE_ENFORCE_EQ(atten_w_dims[1], 1, "AttentionWeight shapes must be (%d + %d) * 1.", M, D); + if (ctx->HasInput("AttentionBias")) { auto atten_b_dims = ctx->GetInputDim("AttentionBias"); PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2, -- GitLab