From f641a47bb1d09d227cca7bac5eb3233ff15e070a Mon Sep 17 00:00:00 2001 From: guosheng Date: Tue, 16 Apr 2019 13:43:40 +0800 Subject: [PATCH] Refine ENFORCE in infer_shape of gru_op and lstm_unit_op. test=develop --- paddle/fluid/operators/gru_op.cc | 7 +++++-- paddle/fluid/operators/lstm_unit_op.cc | 10 ++++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/gru_op.cc b/paddle/fluid/operators/gru_op.cc index 752d706cb..7437d7bd2 100644 --- a/paddle/fluid/operators/gru_op.cc +++ b/paddle/fluid/operators/gru_op.cc @@ -47,8 +47,11 @@ class GRUOp : public framework::OperatorWithKernel { auto weight_dims = ctx->GetInputDim("Weight"); int input_size = input_dims[1]; int frame_size = weight_dims[0]; - PADDLE_ENFORCE_EQ(input_size, frame_size * 3, - "The input_size must be 3 times of frame_size in GRUOp."); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ( + input_size, frame_size * 3, + "The input_size must be 3 times of frame_size in GRUOp."); + } PADDLE_ENFORCE_EQ( weight_dims[1], frame_size * 3, "The shape of Weight matrix must be [frame_size, frame_size * 3]."); diff --git a/paddle/fluid/operators/lstm_unit_op.cc b/paddle/fluid/operators/lstm_unit_op.cc index 0895c58f5..47d695475 100644 --- a/paddle/fluid/operators/lstm_unit_op.cc +++ b/paddle/fluid/operators/lstm_unit_op.cc @@ -34,10 +34,12 @@ class LstmUnitOp : public framework::OperatorWithKernel { auto c_prev_dims = ctx->GetInputDim("C_prev"); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); - PADDLE_ENFORCE_EQ(x_dims[0], c_prev_dims[0], - "Batch size of inputs and states must be equal"); - PADDLE_ENFORCE_EQ(x_dims[1], c_prev_dims[1] * 4, - "Dimension of FC should equal to prev state * 4"); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(x_dims[0], c_prev_dims[0], + "Batch size of inputs and states must be equal"); + PADDLE_ENFORCE_EQ(x_dims[1], c_prev_dims[1] * 4, + "Dimension of FC should equal to prev state * 4"); + } int b_size = c_prev_dims[0]; // batch size int s_dim = c_prev_dims[1]; // state dim -- GitLab