diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index af0665f4a12acfa9fd9c0642da671af57f9e3f89..073c7fe75613b03b573fcb1d37e51b1c874bf4eb 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -78,8 +78,12 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const { output_shape.insert(output_shape.end(), ksize.begin(), ksize.end()); } else { for (size_t i = 0; i < ksize.size(); ++i) { - output_shape.push_back(PoolOutputSize( - in_x_dims[i + 2], ksize[i], paddings[i], strides[i], ceil_mode)); + if (!ctx->IsRuntime() && in_x_dims[i + 2] <= 0) { + output_shape.push_back(-1); + } else { + output_shape.push_back(PoolOutputSize( + in_x_dims[i + 2], ksize[i], paddings[i], strides[i], ceil_mode)); + } } } ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 664a295660f54376d96b0890022d611d34286b6d..3466a9b7829f597b8fc4b29ee56d79b7cfbf8f57 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -1354,6 +1354,25 @@ class TestBook(LayerTest): return (layers.pool2d( x, pool_size=[5, 3], pool_stride=[1, 2], pool_padding=(2, 1))) + def make_pool2d_infershape(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + theta = self._get_data("theta", shape=[2, 3], dtype='float32') + x = fluid.layers.affine_grid(theta, out_shape=[2, 3, 244, 244]) + return (layers.pool2d( + x, pool_size=[5, 3], pool_stride=[1, 2], pool_padding=(2, 1))) + + def make_pool3d(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + x = self._get_data( + name='x', shape=[3, 244, 244, 244], dtype='float32') + return (layers.pool3d( + x, + pool_size=[5, 3, 2], + pool_stride=[1, 2, 3], + pool_padding=(2, 1, 1))) + def make_adaptive_pool2d(self): with program_guard(fluid.default_main_program(), fluid.default_startup_program()):