提交 6f0a40fa 编写于 作者: T Tao Luo

Fix conv_shift_op infershape

test=develop
上级 82cff5ec
...@@ -36,11 +36,14 @@ class ConvShiftOp : public framework::OperatorWithKernel { ...@@ -36,11 +36,14 @@ class ConvShiftOp : public framework::OperatorWithKernel {
auto y_dims = ctx->GetInputDim("Y"); auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(y_dims.size(), 2, "Input(Y)'s rank should be 2."); PADDLE_ENFORCE_EQ(y_dims.size(), 2, "Input(Y)'s rank should be 2.");
if (ctx->IsRuntime() || (x_dims[0] > 0 && y_dims[0] > 0))
PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0], PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
"The 1st dimension of Input(X) and Input(Y) should " "The 1st dimension of Input(X) and Input(Y) should "
"be equal."); "be equal.");
if (ctx->IsRuntime() || y_dims[1] > 0)
PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1, PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1,
"The 2nd dimension of Input(Y) should be odd."); "The 2nd dimension of Input(Y) should be odd.");
if (ctx->IsRuntime() || (x_dims[1] > 0 && y_dims[1] > 0))
PADDLE_ENFORCE_LE(y_dims[1], x_dims[1], PADDLE_ENFORCE_LE(y_dims[1], x_dims[1],
"The 2nd dimension of Input(Y) should be less than or " "The 2nd dimension of Input(Y) should be less than or "
"equal to the 2nd dimension of Input(X)."); "equal to the 2nd dimension of Input(X).");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册