提交 6f0a40fa 编写于 作者: T Tao Luo

Fix conv_shift_op infershape

test=develop
上级 82cff5ec
...@@ -36,14 +36,17 @@ class ConvShiftOp : public framework::OperatorWithKernel { ...@@ -36,14 +36,17 @@ class ConvShiftOp : public framework::OperatorWithKernel {
auto y_dims = ctx->GetInputDim("Y"); auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(y_dims.size(), 2, "Input(Y)'s rank should be 2."); PADDLE_ENFORCE_EQ(y_dims.size(), 2, "Input(Y)'s rank should be 2.");
PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0], if (ctx->IsRuntime() || (x_dims[0] > 0 && y_dims[0] > 0))
"The 1st dimension of Input(X) and Input(Y) should " PADDLE_ENFORCE_EQ(x_dims[0], y_dims[0],
"be equal."); "The 1st dimension of Input(X) and Input(Y) should "
PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1, "be equal.");
"The 2nd dimension of Input(Y) should be odd."); if (ctx->IsRuntime() || y_dims[1] > 0)
PADDLE_ENFORCE_LE(y_dims[1], x_dims[1], PADDLE_ENFORCE_EQ(y_dims[1] % 2, 1,
"The 2nd dimension of Input(Y) should be less than or " "The 2nd dimension of Input(Y) should be odd.");
"equal to the 2nd dimension of Input(X)."); if (ctx->IsRuntime() || (x_dims[1] > 0 && y_dims[1] > 0))
PADDLE_ENFORCE_LE(y_dims[1], x_dims[1],
"The 2nd dimension of Input(Y) should be less than or "
"equal to the 2nd dimension of Input(X).");
ctx->ShareDim("X", /*->*/ "Out"); ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册