diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc
index 9c46832218db688547e019e5cbabd22f192d7b7a..7d599ffd6fc239dd36f1e7dcd4fea1af708beb80 100644
--- a/paddle/fluid/operators/attention_lstm_op.cc
+++ b/paddle/fluid/operators/attention_lstm_op.cc
@@ -54,18 +54,13 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
   auto w_dims = ctx->GetInputDim("LSTMWeight");
   const int D = w_dims[1] / 4;
   PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
-  if (ctx->IsRuntime()) {
-    PADDLE_ENFORCE_EQ(w_dims[0], D + M,
-                      "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D);
-  }
+  PADDLE_ENFORCE_EQ(w_dims[0], D + M,
+                    "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D);
 
   auto b_dims = ctx->GetInputDim("LSTMBias");
   PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
-  if (ctx->IsRuntime()) {
-    PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D);
-    PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.",
-                      4 * D);
-  }
+  PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D);
+  PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D);
 
   auto c_dims = ctx->GetInputDim("C0");
   PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
@@ -83,33 +78,27 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
   auto atten_w_dims = ctx->GetInputDim("AttentionWeight");
   PADDLE_ENFORCE_EQ(atten_w_dims.size(), 2,
                     "Input(AttentionWeight)'s rank must be 2.");
-  if (ctx->IsRuntime()) {
-    PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D,
-                      "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
-    PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
-                      "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
-  }
+  PADDLE_ENFORCE_EQ(atten_w_dims[0], M + D,
+                    "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
+  PADDLE_ENFORCE_EQ(atten_w_dims[1], 1,
+                    "AttentionWeight shapes must be (%d + %d) * 1.", M, D);
 
   if (ctx->HasInput("AttentionBias")) {
     auto atten_b_dims = ctx->GetInputDim("AttentionBias");
     PADDLE_ENFORCE_EQ(atten_b_dims.size(), 2,
                       "Input(AttentionBias)'s rank must be 2.");
-    if (ctx->IsRuntime()) {
-      PADDLE_ENFORCE_EQ(atten_b_dims[0], 1,
-                        "AttentionBias shapes must be 1 * 1.");
-      PADDLE_ENFORCE_EQ(atten_b_dims[1], 1,
-                        "AttentionBias shapes must be 1 * 1.");
-    }
+    PADDLE_ENFORCE_EQ(atten_b_dims[0], 1,
+                      "AttentionBias shapes must be 1 * 1.");
+    PADDLE_ENFORCE_EQ(atten_b_dims[1], 1,
+                      "AttentionBias shapes must be 1 * 1.");
   }
 
   if (ctx->HasInput("AttentionScalar")) {
     auto dims = ctx->GetInputDim("AttentionScalar");
     PADDLE_ENFORCE_EQ(dims.size(), 2,
                       "Input(AttentionScalar)'s rank must be 2.");
-    if (ctx->IsRuntime()) {
-      PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1.");
-      PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
-    }
+    PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalar shapes must be 1 * 1.");
+    PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalar shapes must be 1 * 1.");
   }
 
   if (ctx->HasInput("AttentionScalarBias")) {
@@ -119,12 +108,8 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
         "AttentionScalar should not be null when have AttentionScalarBias.");
     PADDLE_ENFORCE_EQ(dims.size(), 2,
                       "Input(AttentionScalarBias)'s rank must be 2.");
-    if (ctx->IsRuntime()) {
-      PADDLE_ENFORCE_EQ(dims[0], 1,
-                        "AttentionScalarBias shapes must be 1 * 1.");
-      PADDLE_ENFORCE_EQ(dims[1], 1,
-                        "AttentionScalarBias shapes must be 1 * 1.");
-    }
+    PADDLE_ENFORCE_EQ(dims[0], 1, "AttentionScalarBias shapes must be 1 * 1.");
+    PADDLE_ENFORCE_EQ(dims[1], 1, "AttentionScalarBias shapes must be 1 * 1.");
   }
 
   framework::DDim out_dims({x_dims[0], D});
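
The net effect of the patch is that the shape checks in AttentionLSTMOp::InferShape no longer sit behind ctx->IsRuntime(), so they apply whenever InferShape is invoked rather than only when the op executes. The standalone sketch below (not Paddle code: DDim, CheckLstmWeight, and the assert-based checks are illustrative stand-ins for framework::DDim and PADDLE_ENFORCE_EQ) mirrors the LSTMWeight check that the diff makes unconditional.

// Standalone sketch of the LSTMWeight check. The weight is expected to be
// (D + M) x 4D, where D is the hidden size (the columns pack the 4 LSTM
// gates) and M is the input feature size from Input(X).
#include <cassert>
#include <cstdint>
#include <vector>

using DDim = std::vector<int64_t>;  // stand-in for framework::DDim

void CheckLstmWeight(const DDim& w_dims, int64_t M) {
  assert(w_dims.size() == 2 && "Input(LSTMWeight)'s rank must be 2.");
  const int64_t D = w_dims[1] / 4;  // hidden size recovered from the 4D columns
  // Rows must cover the concatenated [x, h] input; with the IsRuntime()
  // guard removed, this runs unconditionally instead of only at execution.
  assert(w_dims[0] == D + M && "LSTMWeight dims should be (D + M) x 4D.");
}

int main() {
  const int64_t M = 8, D = 16;
  CheckLstmWeight({D + M, 4 * D}, M);  // well-formed shape: passes
  // CheckLstmWeight({D, 4 * D}, M);   // wrong row count: would trip the assert
  return 0;
}

Running the checks unconditionally surfaces a malformed weight shape at the earliest point InferShape sees it, instead of deferring the failure to execution time.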