未验证 提交 e7a6567b 编写于 作者: K Kaipeng Deng 提交者: GitHub

polish pool infer shape (#20038)

* fix pool infershape. test=develop

* fix unittest converage. test=develop

* fix format. test=develop
上级 fb2a9cdf
...@@ -78,8 +78,12 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -78,8 +78,12 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const {
output_shape.insert(output_shape.end(), ksize.begin(), ksize.end()); output_shape.insert(output_shape.end(), ksize.begin(), ksize.end());
} else { } else {
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
output_shape.push_back(PoolOutputSize( if (!ctx->IsRuntime() && in_x_dims[i + 2] <= 0) {
in_x_dims[i + 2], ksize[i], paddings[i], strides[i], ceil_mode)); output_shape.push_back(-1);
} else {
output_shape.push_back(PoolOutputSize(
in_x_dims[i + 2], ksize[i], paddings[i], strides[i], ceil_mode));
}
} }
} }
ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
......
...@@ -1354,6 +1354,25 @@ class TestBook(LayerTest): ...@@ -1354,6 +1354,25 @@ class TestBook(LayerTest):
return (layers.pool2d( return (layers.pool2d(
x, pool_size=[5, 3], pool_stride=[1, 2], pool_padding=(2, 1))) x, pool_size=[5, 3], pool_stride=[1, 2], pool_padding=(2, 1)))
def make_pool2d_infershape(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
theta = self._get_data("theta", shape=[2, 3], dtype='float32')
x = fluid.layers.affine_grid(theta, out_shape=[2, 3, 244, 244])
return (layers.pool2d(
x, pool_size=[5, 3], pool_stride=[1, 2], pool_padding=(2, 1)))
def make_pool3d(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
x = self._get_data(
name='x', shape=[3, 244, 244, 244], dtype='float32')
return (layers.pool3d(
x,
pool_size=[5, 3, 2],
pool_stride=[1, 2, 3],
pool_padding=(2, 1, 1)))
def make_adaptive_pool2d(self): def make_adaptive_pool2d(self):
with program_guard(fluid.default_main_program(), with program_guard(fluid.default_main_program(),
fluid.default_startup_program()): fluid.default_startup_program()):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册