diff --git a/paddle/fluid/operators/gru_op.cc b/paddle/fluid/operators/gru_op.cc index 752d706cbfab8eb3027fe9610c25b7400ecfed1d..7437d7bd2092044b6634aa720fbee1a02b630bcd 100644 --- a/paddle/fluid/operators/gru_op.cc +++ b/paddle/fluid/operators/gru_op.cc @@ -47,8 +47,11 @@ class GRUOp : public framework::OperatorWithKernel { auto weight_dims = ctx->GetInputDim("Weight"); int input_size = input_dims[1]; int frame_size = weight_dims[0]; - PADDLE_ENFORCE_EQ(input_size, frame_size * 3, - "The input_size must be 3 times of frame_size in GRUOp."); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ( + input_size, frame_size * 3, + "The input_size must be 3 times of frame_size in GRUOp."); + } PADDLE_ENFORCE_EQ( weight_dims[1], frame_size * 3, "The shape of Weight matrix must be [frame_size, frame_size * 3]."); diff --git a/paddle/fluid/operators/lstm_unit_op.cc b/paddle/fluid/operators/lstm_unit_op.cc index 0895c58f5f58afd444000ebeac7a92e3eb7778d3..47d695475c2e240d273fe873352cf5c213e2026e 100644 --- a/paddle/fluid/operators/lstm_unit_op.cc +++ b/paddle/fluid/operators/lstm_unit_op.cc @@ -34,10 +34,12 @@ class LstmUnitOp : public framework::OperatorWithKernel { auto c_prev_dims = ctx->GetInputDim("C_prev"); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); - PADDLE_ENFORCE_EQ(x_dims[0], c_prev_dims[0], - "Batch size of inputs and states must be equal"); - PADDLE_ENFORCE_EQ(x_dims[1], c_prev_dims[1] * 4, - "Dimension of FC should equal to prev state * 4"); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(x_dims[0], c_prev_dims[0], + "Batch size of inputs and states must be equal"); + PADDLE_ENFORCE_EQ(x_dims[1], c_prev_dims[1] * 4, + "Dimension of FC should equal to prev state * 4"); + } int b_size = c_prev_dims[0]; // batch size int s_dim = c_prev_dims[1]; // state dim