From 10879a3cae4548665377f3d0fe2f20d754f575d4 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 15 Apr 2019 05:34:57 +0000 Subject: [PATCH] separate runtime infershape test=develop --- paddle/fluid/operators/attention_lstm_op.cc | 53 ++++++++++++++------- 1 file changed, 36 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index 912ec7991..9c4683221 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -54,17 +54,25 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { auto w_dims = ctx->GetInputDim("LSTMWeight"); const int D = w_dims[1] / 4; PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2."); - PADDLE_ENFORCE_EQ(w_dims[0], D + M, - "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(w_dims[0], D + M, + "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D); + } auto b_dims = ctx->GetInputDim("LSTMBias"); PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2."); - PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D); - PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D); + PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", + 4 * D); + } auto c_dims = ctx->GetInputDim("C0"); PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2."); - PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(c_dims[1], D, "C0 dims should be N x %d.", D); + } + if (ctx->HasInput("H0")) { auto h_dims = ctx->GetInputDim("H0"); PADDLE_ENFORCE(h_dims == c_dims, @@ -75,26 +83,33 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { auto atten_w_dims = ctx->GetInputDim("AttentionWeight"); PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2, "Input(AttentionWeight)'s rank must be 2."); - PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D, - "AttentionWeight shapes must be (%d + %d) * 1.", M, D); - PADDLE_ENFORCE_EQ(atten_w_dims[1], 1, - "AttentionWeight shapes must be (%d + %d) * 1.", M, D); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D, + "AttentionWeight shapes must be (%d + %d) * 1.", M, D); + PADDLE_ENFORCE_EQ(atten_w_dims[1], 1, + "AttentionWeight shapes must be (%d + %d) * 1.", M, D); + } + if (ctx->HasInput("AttentionBias")) { auto atten_b_dims = ctx->GetInputDim("AttentionBias"); PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2, "Input(AttentionBias)'s rank must be 2."); - PADDLE_ENFORCE_EQ(atten_b_dims[0], 1, - "AttentionBias shapes must be 1 * 1."); - PADDLE_ENFORCE_EQ(atten_b_dims[1], 1, - "AttentionBias shapes must be 1 * 1."); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(atten_b_dims[0], 1, + "AttentionBias shapes must be 1 * 1."); + PADDLE_ENFORCE_EQ(atten_b_dims[1], 1, + "AttentionBias shapes must be 1 * 1."); + } } if (ctx->HasInput("AttentionScalar")) { auto dims = ctx->GetInputDim("AttentionScalar"); PADDLE_ENFORCE_EQ(dims.size(), 2, "Input(AttentionScalar)'s rank must be 2."); - PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1."); - PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1."); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1."); + PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1."); + } } if (ctx->HasInput("AttentionScalarBias")) { @@ -104,8 +119,12 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { "AttentionScalar should not be null when have AttentionScalarBias."); PADDLE_ENFORCE_EQ(dims.size(), 2, "Input(AttentionScalarBias)'s rank must be 2."); - PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1."); - PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1."); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(dims[0], 1, + "AttentionScalarBias shapes must be 1 * 1."); + PADDLE_ENFORCE_EQ(dims[1], 1, + "AttentionScalarBias shapes must be 1 * 1."); + } } framework::DDim out_dims({x_dims[0], D}); -- GitLab