diff --git a/paddle/fluid/operators/conv_shift_op.cc b/paddle/fluid/operators/conv_shift_op.cc index 08506ddd18ed35831702814e70962cb36ec958b1..fa4edb70b48e529102f11a1b0b9cac2110a33966 100644 --- a/paddle/fluid/operators/conv_shift_op.cc +++ b/paddle/fluid/operators/conv_shift_op.cc @@ -36,14 +36,17 @@ class ConvShiftOp : public framework::OperatorWithKernel { auto y_dims = ctx->GetInputDim("Y"); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(y_dims.size(), 2, "Input(Y)'s rank should be 2."); - PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0], - "The 1st dimension of Input(X) and Input(Y) should " - "be equal."); - PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1, - "The 2nd dimension of Input(Y) should be odd."); - PADDLE_ENFORCE_LE(y_dims[1], x_dims[1], - "The 2nd dimension of Input(Y) should be less than or " - "equal to the 2nd dimension of Input(X)."); + if (ctx->IsRuntime() || (x_dims[0] > 0 && y_dims[0] > 0)) + PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0], + "The 1st dimension of Input(X) and Input(Y) should " + "be equal."); + if (ctx->IsRuntime() || y_dims[1] > 0) + PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1, + "The 2nd dimension of Input(Y) should be odd."); + if (ctx->IsRuntime() || (x_dims[1] > 0 && y_dims[1] > 0)) + PADDLE_ENFORCE_LE(y_dims[1], x_dims[1], + "The 2nd dimension of Input(Y) should be less than or " + "equal to the 2nd dimension of Input(X)."); ctx->ShareDim("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out"); }