diff --git a/paddle/fluid/operators/controlflow/logical_op.cc b/paddle/fluid/operators/controlflow/logical_op.cc index 2e7f3edd55c3353bacddec3dd4ffaba9e0208136..37a82a8067f84722fc37e2469c739faf25f7540b 100644 --- a/paddle/fluid/operators/controlflow/logical_op.cc +++ b/paddle/fluid/operators/controlflow/logical_op.cc @@ -71,8 +71,16 @@ class BinaryLogicalOpInferShape : public framework::InferShapeBase { "Input(Y) of %s operator must not be null", comment.type); auto dim_x = context->GetInputDim("X"); auto dim_y = context->GetInputDim("Y"); - PADDLE_ENFORCE_EQ(framework::product(dim_x), framework::product(dim_y), - "The number of elements in X and Y should be same"); + + int product_x = framework::product(dim_x); + int product_y = framework::product(dim_y); + bool check = context->IsRuntime() || (product_x >= 0 && product_y >= 0); + if (check) { + PADDLE_ENFORCE_EQ( + product_x, product_y, + "The number of elements in X and Y should be same, %d != %d", + product_x, product_y); + } context->SetOutputDim("Out", context->GetInputDim("X")); context->ShareLoD("X", "Out"); diff --git a/paddle/fluid/operators/gru_op.cc b/paddle/fluid/operators/gru_op.cc index 752d706cbfab8eb3027fe9610c25b7400ecfed1d..7437d7bd2092044b6634aa720fbee1a02b630bcd 100644 --- a/paddle/fluid/operators/gru_op.cc +++ b/paddle/fluid/operators/gru_op.cc @@ -47,8 +47,11 @@ class GRUOp : public framework::OperatorWithKernel { auto weight_dims = ctx->GetInputDim("Weight"); int input_size = input_dims[1]; int frame_size = weight_dims[0]; - PADDLE_ENFORCE_EQ(input_size, frame_size * 3, - "The input_size must be 3 times of frame_size in GRUOp."); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ( + input_size, frame_size * 3, + "The input_size must be 3 times of frame_size in GRUOp."); + } PADDLE_ENFORCE_EQ( weight_dims[1], frame_size * 3, "The shape of Weight matrix must be [frame_size, frame_size * 3]."); diff --git a/paddle/fluid/operators/lstm_unit_op.cc b/paddle/fluid/operators/lstm_unit_op.cc index 0895c58f5f58afd444000ebeac7a92e3eb7778d3..47d695475c2e240d273fe873352cf5c213e2026e 100644 --- a/paddle/fluid/operators/lstm_unit_op.cc +++ b/paddle/fluid/operators/lstm_unit_op.cc @@ -34,10 +34,12 @@ class LstmUnitOp : public framework::OperatorWithKernel { auto c_prev_dims = ctx->GetInputDim("C_prev"); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); - PADDLE_ENFORCE_EQ(x_dims[0], c_prev_dims[0], - "Batch size of inputs and states must be equal"); - PADDLE_ENFORCE_EQ(x_dims[1], c_prev_dims[1] * 4, - "Dimension of FC should equal to prev state * 4"); + if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(x_dims[0], c_prev_dims[0], + "Batch size of inputs and states must be equal"); + PADDLE_ENFORCE_EQ(x_dims[1], c_prev_dims[1] * 4, + "Dimension of FC should equal to prev state * 4"); + } int b_size = c_prev_dims[0]; // batch size int s_dim = c_prev_dims[1]; // state dim