未验证 提交 64298778 编写于 作者: W whs 提交者: GitHub

Fix infer_shape in pad2d_op (#16831)

test=develop
...@@ -483,8 +483,10 @@ class Pad2dOp : public framework::OperatorWithKernel { ...@@ -483,8 +483,10 @@ class Pad2dOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
paddings_dim.size(), 1, paddings_dim.size(), 1,
"Size of Input(Paddings)'s dimension should be equal to 1."); "Size of Input(Paddings)'s dimension should be equal to 1.");
PADDLE_ENFORCE_EQ(paddings_dim[0], 4, if (ctx->IsRuntime()) {
"Shape of Input(Paddings) should be equal to [4]."); PADDLE_ENFORCE_EQ(paddings_dim[0], 4,
"Shape of Input(Paddings) should be equal to [4].");
}
out_dims[1] = x_dim[1]; out_dims[1] = x_dim[1];
out_dims[2] = x_dim[2]; out_dims[2] = x_dim[2];
out_dims[3] = x_dim[3]; out_dims[3] = x_dim[3];
...@@ -504,11 +506,7 @@ class Pad2dOp : public framework::OperatorWithKernel { ...@@ -504,11 +506,7 @@ class Pad2dOp : public framework::OperatorWithKernel {
} }
ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
if (out_dims[0] == x_dim[0]) { ctx->ShareLoD("X", /*->*/ "Out");
// Only pass LoD when the first dimension is equal between
// output and input.
ctx->ShareLoD("X", /*->*/ "Out");
}
} }
protected: protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册