From 5c10c574d8e894063d626a961130d8044fac2481 Mon Sep 17 00:00:00 2001 From: Tao Luo Date: Mon, 15 Apr 2019 17:42:33 +0800 Subject: [PATCH] Merge pull request #16854 from luotao1/conv_shift_infershape Fix conv_shift_op infershape --- paddle/fluid/operators/conv_shift_op.cc | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/conv_shift_op.cc b/paddle/fluid/operators/conv_shift_op.cc index 08506ddd1..fa4edb70b 100644 --- a/paddle/fluid/operators/conv_shift_op.cc +++ b/paddle/fluid/operators/conv_shift_op.cc @@ -36,14 +36,17 @@ class ConvShiftOp : public framework::OperatorWithKernel { auto y_dims = ctx->GetInputDim("Y"); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(y_dims.size(), 2, "Input(Y)'s rank should be 2."); - PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0], - "The 1st dimension of Input(X) and Input(Y) should " - "be equal."); - PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1, - "The 2nd dimension of Input(Y) should be odd."); - PADDLE_ENFORCE_LE(y_dims[1], x_dims[1], - "The 2nd dimension of Input(Y) should be less than or " - "equal to the 2nd dimension of Input(X)."); + if (ctx->IsRuntime() || (x_dims[0] > 0 && y_dims[0] > 0)) + PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0], + "The 1st dimension of Input(X) and Input(Y) should " + "be equal."); + if (ctx->IsRuntime() || y_dims[1] > 0) + PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1, + "The 2nd dimension of Input(Y) should be odd."); + if (ctx->IsRuntime() || (x_dims[1] > 0 && y_dims[1] > 0)) + PADDLE_ENFORCE_LE(y_dims[1], x_dims[1], + "The 2nd dimension of Input(Y) should be less than or " + "equal to the 2nd dimension of Input(X)."); ctx->ShareDim("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out"); } -- GitLab