未验证 提交 5c08d233 编写于 作者: W whs 提交者: GitHub

Fix infer_shape in pad2d_op (#16911)

test=develop
上级 391649e0
...@@ -480,8 +480,10 @@ class Pad2dOp : public framework::OperatorWithKernel { ...@@ -480,8 +480,10 @@ class Pad2dOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
paddings_dim.size(), 1, paddings_dim.size(), 1,
"Size of Input(Paddings)'s dimension should be equal to 1."); "Size of Input(Paddings)'s dimension should be equal to 1.");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(paddings_dim[0], 4, PADDLE_ENFORCE_EQ(paddings_dim[0], 4,
"Shape of Input(Paddings) should be equal to [4]."); "Shape of Input(Paddings) should be equal to [4].");
}
out_dims[1] = x_dim[1]; out_dims[1] = x_dim[1];
out_dims[2] = x_dim[2]; out_dims[2] = x_dim[2];
out_dims[3] = x_dim[3]; out_dims[3] = x_dim[3];
...@@ -501,12 +503,8 @@ class Pad2dOp : public framework::OperatorWithKernel { ...@@ -501,12 +503,8 @@ class Pad2dOp : public framework::OperatorWithKernel {
} }
ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
if (out_dims[0] == x_dim[0]) {
// Only pass LoD when the first dimension is equal between
// output and input.
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册