未验证 提交 189b08dc 编写于 作者: W whs 提交者: GitHub

Make infer shape of pad2d support for input with negative dims in compile time. (#18695)

test=develop
上级 c457a69d
...@@ -495,13 +495,21 @@ class Pad2dOp : public framework::OperatorWithKernel { ...@@ -495,13 +495,21 @@ class Pad2dOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(paddings.size(), 4, PADDLE_ENFORCE_EQ(paddings.size(), 4,
"Size of paddings should be equal to 4."); "Size of paddings should be equal to 4.");
if (data_format == "NCHW") { if (data_format == "NCHW") {
out_dims[1] = x_dim[1]; out_dims[1] = x_dim[1]; // channel
out_dims[2] = x_dim[2] + paddings[0] + paddings[1]; // height out_dims[2] = ((!ctx->IsRuntime()) && (x_dim[2] < 0))
out_dims[3] = x_dim[3] + paddings[2] + paddings[3]; // width ? x_dim[2]
: (x_dim[2] + paddings[0] + paddings[1]); // height
out_dims[3] = ((!ctx->IsRuntime()) && (x_dim[3] < 0))
? x_dim[3]
: (x_dim[3] + paddings[2] + paddings[3]); // width
} else { // NHWC } else { // NHWC
out_dims[3] = x_dim[3]; out_dims[3] = x_dim[3]; // channel
out_dims[1] = x_dim[1] + paddings[0] + paddings[1]; out_dims[1] = ((!ctx->IsRuntime()) && (x_dim[1] < 0))
out_dims[2] = x_dim[2] + paddings[2] + paddings[3]; ? x_dim[1]
: (x_dim[1] + paddings[0] + paddings[1]); // height
out_dims[2] = ((!ctx->IsRuntime()) && (x_dim[2] < 0))
? x_dim[2]
: (x_dim[2] + paddings[2] + paddings[3]); // width
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册