From f3c6d19cd07782f7ca5d80de6d44841231194f48 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Mon, 5 Sep 2022 09:34:20 +0800 Subject: [PATCH] [OpAttr]ksize of pool2d support Tensor type of adaptive_avg_pool2d API (#45660) * [OpAttr]ksize of pool2d support Tensor type * fix unittest * add unittest --- paddle/fluid/operators/pool_op.cc | 17 ++-- paddle/phi/api/yaml/legacy_api.yaml | 8 +- paddle/phi/api/yaml/legacy_backward.yaml | 23 +++--- paddle/phi/infermeta/backward.cc | 17 ---- paddle/phi/infermeta/backward.h | 15 ---- paddle/phi/infermeta/unary.cc | 47 +++++++++++ paddle/phi/infermeta/unary.h | 14 ++++ paddle/phi/kernels/gpudnn/pool_grad_kernel.cu | 8 +- paddle/phi/kernels/gpudnn/pool_kernel.cu | 6 +- .../phi/kernels/impl/pool_grad_kernel_impl.h | 8 +- paddle/phi/kernels/impl/pool_kernel_impl.h | 6 +- paddle/phi/kernels/pool_grad_kernel.h | 9 ++- paddle/phi/kernels/pool_kernel.h | 5 +- paddle/phi/kernels/xpu/pool_kernel.cc | 5 +- python/paddle/fluid/layers/utils.py | 3 + .../unittests/test_adaptive_avg_pool2d.py | 79 +++++++++++++++++++ python/paddle/nn/functional/pooling.py | 19 +++-- 17 files changed, 210 insertions(+), 79 deletions(-) diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index ee494f7cb1..c5b1ce12f1 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -151,9 +151,8 @@ void Pool2dOpMaker::Make() { "(vector) The pooling window " "size(height, width) of the pooling operator. " "If global_pooling = true, ksize and paddings will " - "be ignored."); // TODO(Chengduo): Add checker. - // (Currently, - // TypedAttrChecker don't support vector type.) + "be ignored.") + .SupportTensor(); AddAttr( "global_pooling", "(bool) Whether to use the global pooling. " @@ -371,9 +370,7 @@ void Pool3dOpMaker::Make() { "(vector) The pooling window size(depth, height, " "width) of pooling operator. " "If global_pooling = true, ksize and paddings will " - "be ignored."); // TODO(Chengduo): Add checker. - // (Currently, - // TypedAttrChecker don't support vector type.) + "be ignored."); AddAttr( "global_pooling", "(bool) Whether to use the global pooling. " @@ -554,13 +551,13 @@ namespace ops = paddle::operators; DECLARE_INFER_SHAPE_FUNCTOR(pool2d, Pool2dInferShapeFunctor, - PD_INFER_META(phi::PoolInferMeta)); + PD_INFER_META(phi::Pool2DInferMeta)); DECLARE_INFER_SHAPE_FUNCTOR(pool2d_grad, Pool2dGradInferShapeFunctor, - PD_INFER_META(phi::PoolGradInferMeta)); + PD_INFER_META(phi::UnchangedInferMeta)); DECLARE_INFER_SHAPE_FUNCTOR(pool2d_double_grad, Pool2dDoubleGradInferShapeFunctor, - PD_INFER_META(phi::PoolInferMeta)); + PD_INFER_META(phi::Pool2DInferMeta)); REGISTER_OPERATOR( pool2d, @@ -584,7 +581,7 @@ DECLARE_INFER_SHAPE_FUNCTOR(pool3d, PD_INFER_META(phi::PoolInferMeta)); DECLARE_INFER_SHAPE_FUNCTOR(pool3d_grad, Pool3dGradInferShapeFunctor, - PD_INFER_META(phi::PoolGradInferMeta)); + PD_INFER_META(phi::UnchangedInferMeta)); REGISTER_OPERATOR( pool3d, diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index ff218b1756..4520c5ef37 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -1978,10 +1978,10 @@ backward : pixel_shuffle_grad - api : pool2d - args : (Tensor x, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) + args : (Tensor x, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) output : Tensor(out) infer_meta : - func : PoolInferMeta + func : Pool2DInferMeta kernel : func : pool2d use_gpudnn : true @@ -1989,10 +1989,10 @@ # Used in adaptive_avg_pool2d API - api : pool2d_gpudnn_unused - args : (Tensor x, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) + args : (Tensor x, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) output : Tensor(out) infer_meta : - func : PoolInferMeta + func : Pool2DInferMeta kernel : func : pool2d use_gpudnn : false diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index bcaf99a88b..f7e1db86ec 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -1748,32 +1748,34 @@ func : pixel_shuffle_grad - backward_api : pool2d_double_grad - forward : pool2d_grad(Tensor x, Tensor out, Tensor grad_out, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) -> Tensor(grad_x) - args : (Tensor grad_x_grad, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) + forward : pool2d_grad(Tensor x, Tensor out, Tensor grad_out, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) -> Tensor(grad_x) + args : (Tensor grad_x_grad, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) output : Tensor(grad_out_grad) infer_meta : - func : PoolInferMeta + func : Pool2DInferMeta kernel : func : pool2d_double_grad use_gpudnn : true - backward_api : pool2d_grad - forward : pool2d(Tensor x, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) -> Tensor(out) - args : (Tensor x, Tensor out, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) + forward : pool2d(Tensor x, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) -> Tensor(out) + args : (Tensor x, Tensor out, Tensor out_grad, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) output : Tensor(x_grad) infer_meta : - func : PoolGradInferMeta + func : UnchangedInferMeta + param: [x] kernel : func : pool2d_grad use_gpudnn : true backward : pool2d_double_grad - backward_api : pool2d_grad_gpudnn_unused - forward : pool2d_gpudnn_unused(Tensor x, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) -> Tensor(out) - args : (Tensor x, Tensor out, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) + forward : pool2d_gpudnn_unused(Tensor x, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) -> Tensor(out) + args : (Tensor x, Tensor out, Tensor out_grad, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) output : Tensor(x_grad) infer_meta : - func : PoolGradInferMeta + func : UnchangedInferMeta + param: [x] kernel : func : pool2d_grad use_gpudnn : false @@ -1783,7 +1785,8 @@ args : (Tensor x, Tensor out, Tensor out_grad, int[] kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) output : Tensor(x_grad) infer_meta : - func : PoolGradInferMeta + func : UnchangedInferMeta + param: [x] kernel : func : pool3d_grad use_gpudnn : true diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index a1c1a07861..82d83d4950 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -761,23 +761,6 @@ void PixelUnshuffleGradInferMeta(const MetaTensor& out_grad, x_grad->set_dtype(out_grad.dtype()); } -void PoolGradInferMeta(const MetaTensor& x, - const MetaTensor& out, - const MetaTensor& dout, - const std::vector& kernel_size, - const std::vector& strides, - const std::vector& paddings, - bool ceil_mode, - bool exclusive, - const std::string& data_format, - const std::string& pooling_type, - bool global_pooling, - bool adaptive, - const std::string& padding_algorithm, - MetaTensor* dx) { - dx->share_meta(x); -} - void PsroiPoolGradInferMeta(const MetaTensor& x, const MetaTensor& rois, const MetaTensor& rois_num, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 2a11986a39..0930358ad8 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -319,21 +319,6 @@ void PsroiPoolGradInferMeta(const MetaTensor& x, float spatial_scale, MetaTensor* dx); -void PoolGradInferMeta(const MetaTensor& x, - const MetaTensor& out, - const MetaTensor& dout, - const std::vector& kernel_size, - const std::vector& strides, - const std::vector& paddings, - bool ceil_mode, - bool exclusive, - const std::string& data_format, - const std::string& pooling_type, - bool global_pooling, - bool adaptive, - const std::string& padding_algorithm, - MetaTensor* dx); - void RealAndImagGradInferMeta(const MetaTensor& out_grad, MetaTensor* dx); void ReshapeDoubleGradInferMeta(const MetaTensor& out_grad, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 5583ba37a2..fdbf2c7d5f 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2452,6 +2452,53 @@ void PNormInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } +void Pool2DInferMeta(const MetaTensor& x, + const IntArray& kernel_size, + const std::vector& strides, + const std::vector& paddings, + bool ceil_mode, + bool exclusive, + const std::string& data_format, + const std::string& pooling_type, + bool global_pooling, + bool adaptive, + const std::string& padding_algorithm, + MetaTensor* out, + MetaConfig config) { + const bool channel_last = (config.is_run_mkldnn_kernel == false) && + (data_format == "NHWC" || data_format == "NDHWC"); + if (!config.is_runtime && kernel_size.FromTensor()) { + auto x_dims = x.dims(); + std::vector output_shape = std::move(phi::vectorize(x_dims)); + // set dims of HW -1 + output_shape[x_dims.size() - 2] = -1; + if (channel_last) { // for NHWC, NDHWC + output_shape[x_dims.size() - 3] = -1; + } else { // for NCHW + output_shape[x_dims.size() - 1] = -1; + } + out->set_dims(make_ddim(output_shape)); + out->share_lod(x); + out->set_dtype(x.dtype()); + } else { + std::vector kernel_size_val(kernel_size.GetData().begin(), + kernel_size.GetData().end()); + PoolInferMeta(x, + kernel_size_val, + strides, + paddings, + ceil_mode, + exclusive, + data_format, + pooling_type, + global_pooling, + adaptive, + padding_algorithm, + out, + config); + } +} + void PoolInferMeta(const MetaTensor& x, const std::vector& kernel_size, const std::vector& strides, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 031411f925..3273a04373 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -354,6 +354,20 @@ void PoolInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void Pool2DInferMeta(const MetaTensor& x, + const IntArray& kernel_size, + const std::vector& strides, + const std::vector& paddings, + bool ceil_mode, + bool exclusive, + const std::string& data_format, + const std::string& pooling_type, + bool global_pooling, + bool adaptive, + const std::string& padding_algorithm, + MetaTensor* out, + MetaConfig config = MetaConfig()); + void QrInferMeta(const MetaTensor& x, const std::string& mode, MetaTensor* q, diff --git a/paddle/phi/kernels/gpudnn/pool_grad_kernel.cu b/paddle/phi/kernels/gpudnn/pool_grad_kernel.cu index 4a7d417f2b..0798eab5e1 100644 --- a/paddle/phi/kernels/gpudnn/pool_grad_kernel.cu +++ b/paddle/phi/kernels/gpudnn/pool_grad_kernel.cu @@ -305,7 +305,7 @@ void Pool2dGradGPUDNNKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& out, const DenseTensor& dout, - const std::vector& kernel_size, + const IntArray& kernel_size, const std::vector& strides, const std::vector& paddings, bool ceil_mode, @@ -316,11 +316,13 @@ void Pool2dGradGPUDNNKernel(const Context& ctx, bool adaptive, const std::string& padding_algorithm, DenseTensor* dx) { + std::vector kernel_size_val(kernel_size.GetData().begin(), + kernel_size.GetData().end()); PoolGradRawGPUDNNKernel(ctx, x, out, dout, - kernel_size, + kernel_size_val, strides, paddings, exclusive, @@ -335,7 +337,7 @@ void Pool2dGradGPUDNNKernel(const Context& ctx, template void Pool2dDoubleGradGPUDNNKernel(const Context& ctx, const DenseTensor& x, - const std::vector& kernel_size, + const IntArray& kernel_size, const std::vector& strides, const std::vector& paddings, bool ceil_mode, diff --git a/paddle/phi/kernels/gpudnn/pool_kernel.cu b/paddle/phi/kernels/gpudnn/pool_kernel.cu index 9981ac5b53..16139d48b2 100644 --- a/paddle/phi/kernels/gpudnn/pool_kernel.cu +++ b/paddle/phi/kernels/gpudnn/pool_kernel.cu @@ -230,7 +230,7 @@ void PoolRawGPUDNNKernel(const Context& ctx, template void Pool2dGPUDNNKernel(const Context& ctx, const DenseTensor& x, - const std::vector& kernel_size, + const IntArray& kernel_size, const std::vector& strides, const std::vector& paddings, bool ceil_mode, @@ -241,9 +241,11 @@ void Pool2dGPUDNNKernel(const Context& ctx, bool adaptive, const std::string& padding_algorithm, DenseTensor* out) { + std::vector kernel_size_val(kernel_size.GetData().begin(), + kernel_size.GetData().end()); PoolRawGPUDNNKernel(ctx, x, - kernel_size, + kernel_size_val, strides, paddings, exclusive, diff --git a/paddle/phi/kernels/impl/pool_grad_kernel_impl.h b/paddle/phi/kernels/impl/pool_grad_kernel_impl.h index a816deaeb0..e53018f229 100644 --- a/paddle/phi/kernels/impl/pool_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/pool_grad_kernel_impl.h @@ -189,7 +189,7 @@ void Pool2dGradKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& out, const DenseTensor& dout, - const std::vector& kernel_size, + const IntArray& kernel_size, const std::vector& strides, const std::vector& paddings, bool ceil_mode, @@ -200,11 +200,13 @@ void Pool2dGradKernel(const Context& ctx, bool adaptive, const std::string& padding_algorithm, DenseTensor* dx) { + std::vector kernel_size_val(kernel_size.GetData().begin(), + kernel_size.GetData().end()); PoolGradRawKernel(ctx, x, out, dout, - kernel_size, + kernel_size_val, strides, paddings, exclusive, @@ -219,7 +221,7 @@ void Pool2dGradKernel(const Context& ctx, template void Pool2dDoubleGradKernel(const Context& ctx, const DenseTensor& x, - const std::vector& kernel_size, + const IntArray& kernel_size, const std::vector& strides, const std::vector& paddings, bool ceil_mode, diff --git a/paddle/phi/kernels/impl/pool_kernel_impl.h b/paddle/phi/kernels/impl/pool_kernel_impl.h index fb93fc1ce6..931a14d9fd 100644 --- a/paddle/phi/kernels/impl/pool_kernel_impl.h +++ b/paddle/phi/kernels/impl/pool_kernel_impl.h @@ -223,7 +223,7 @@ void MaxPoolWithIndexRawKernel(const Context& ctx, template void Pool2dKernel(const Context& ctx, const DenseTensor& x, - const std::vector& kernel_size, + const IntArray& kernel_size, const std::vector& strides, const std::vector& paddings, bool ceil_mode, @@ -234,9 +234,11 @@ void Pool2dKernel(const Context& ctx, bool adaptive, const std::string& padding_algorithm, DenseTensor* out) { + std::vector kernel_size_val(kernel_size.GetData().begin(), + kernel_size.GetData().end()); PoolRawKernel(ctx, x, - kernel_size, + kernel_size_val, strides, paddings, exclusive, diff --git a/paddle/phi/kernels/pool_grad_kernel.h b/paddle/phi/kernels/pool_grad_kernel.h index d26bee2eb2..64ad99a6d3 100644 --- a/paddle/phi/kernels/pool_grad_kernel.h +++ b/paddle/phi/kernels/pool_grad_kernel.h @@ -17,6 +17,7 @@ #include #include +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -26,7 +27,7 @@ void Pool2dGradKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& out, const DenseTensor& dout, - const std::vector& kernel_size, + const IntArray& kernel_size, const std::vector& strides, const std::vector& paddings, bool ceil_mode, @@ -43,7 +44,7 @@ void Pool2dGradGPUDNNKernel(const Context& ctx, const DenseTensor& x, const DenseTensor& out, const DenseTensor& dout, - const std::vector& kernel_size, + const IntArray& kernel_size, const std::vector& strides, const std::vector& paddings, bool ceil_mode, @@ -58,7 +59,7 @@ void Pool2dGradGPUDNNKernel(const Context& ctx, template void Pool2dDoubleGradKernel(const Context& ctx, const DenseTensor& x, - const std::vector& kernel_size, + const IntArray& kernel_size, const std::vector& strides, const std::vector& paddings, bool ceil_mode, @@ -73,7 +74,7 @@ void Pool2dDoubleGradKernel(const Context& ctx, template void Pool2dDoubleGradGPUDNNKernel(const Context& ctx, const DenseTensor& x, - const std::vector& kernel_size, + const IntArray& kernel_size, const std::vector& strides, const std::vector& paddings, bool ceil_mode, diff --git a/paddle/phi/kernels/pool_kernel.h b/paddle/phi/kernels/pool_kernel.h index b9a4c830fa..c1a7dd471a 100644 --- a/paddle/phi/kernels/pool_kernel.h +++ b/paddle/phi/kernels/pool_kernel.h @@ -17,6 +17,7 @@ #include #include +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -24,7 +25,7 @@ namespace phi { template void Pool2dKernel(const Context& ctx, const DenseTensor& x, - const std::vector& kernel_size, + const IntArray& kernel_size, const std::vector& strides, const std::vector& paddings, bool ceil_mode, @@ -39,7 +40,7 @@ void Pool2dKernel(const Context& ctx, template void Pool2dGPUDNNKernel(const Context& ctx, const DenseTensor& x, - const std::vector& kernel_size, + const IntArray& kernel_size, const std::vector& strides, const std::vector& paddings, bool ceil_mode, diff --git a/paddle/phi/kernels/xpu/pool_kernel.cc b/paddle/phi/kernels/xpu/pool_kernel.cc index bf60deb28d..2eb850b9a7 100644 --- a/paddle/phi/kernels/xpu/pool_kernel.cc +++ b/paddle/phi/kernels/xpu/pool_kernel.cc @@ -22,7 +22,7 @@ namespace phi { template void Pool2dKernel(const Context& ctx, const DenseTensor& x, - const std::vector& kernel_size_t, + const IntArray& kernel_size, const std::vector& strides, const std::vector& paddings_t, bool ceil_mode, @@ -36,7 +36,8 @@ void Pool2dKernel(const Context& ctx, using XPUType = typename XPUTypeTrait::Type; std::vector paddings(paddings_t); - std::vector kernel_size(kernel_size_t); + std::vector kernel_size_val(kernel_size.GetData().begin(), + kernel_size.GetData().end()); PADDLE_ENFORCE_EQ(kernel_size.size(), 2, diff --git a/python/paddle/fluid/layers/utils.py b/python/paddle/fluid/layers/utils.py index be8045b7bb..ad68366e1e 100644 --- a/python/paddle/fluid/layers/utils.py +++ b/python/paddle/fluid/layers/utils.py @@ -62,6 +62,9 @@ def convert_to_list(value, n, name, dtype=int): raise ValueError("The " + name + "'s length must be " + str(n) + ". Received: " + str(value)) for single_value in value_list: + assert not isinstance( + single_value, Variable + ), "Required numerical type with '%s', but received Tensor." % dtype try: dtype(single_value) except (ValueError, TypeError): diff --git a/python/paddle/fluid/tests/unittests/test_adaptive_avg_pool2d.py b/python/paddle/fluid/tests/unittests/test_adaptive_avg_pool2d.py index 2531834b21..82735b4006 100644 --- a/python/paddle/fluid/tests/unittests/test_adaptive_avg_pool2d.py +++ b/python/paddle/fluid/tests/unittests/test_adaptive_avg_pool2d.py @@ -15,6 +15,7 @@ from __future__ import print_function from __future__ import division +import os import unittest import numpy as np @@ -24,6 +25,8 @@ import paddle import paddle.fluid as fluid from paddle.fluid import Program, program_guard +from test_attribute_var import UnittestBase + def adaptive_start_index(index, input_size, output_size): return int(np.floor(index * input_size / output_size)) @@ -288,5 +291,81 @@ class TestAdaptiveAvgPool2DClassAPI(unittest.TestCase): assert np.allclose(out_5.numpy(), self.res_5_np) +class TestOutputSizeTensor(UnittestBase): + + def init_info(self): + self.shapes = [[1, 3, 6, 6]] + self.save_path = os.path.join(self.temp_dir.name, self.path_prefix()) + + def test_static(self): + paddle.enable_static() + main_prog = Program() + starup_prog = Program() + with program_guard(main_prog, starup_prog): + fc = paddle.nn.Linear(6, 6) + x = paddle.randn(self.shapes[0]) + x.stop_gradient = False + feat = fc(x) # [1,3,6,6] + + out1, out2 = self.call_func(feat) + + sgd = paddle.optimizer.SGD() + sgd.minimize(paddle.mean(out1 + out2)) + self.assertTrue(self.var_prefix() in str(main_prog)) + + exe = paddle.static.Executor() + exe.run(starup_prog) + res = exe.run(fetch_list=[out1, out2]) + np.testing.assert_allclose(res[0], res[1]) + paddle.static.save_inference_model(self.save_path, [x], + [out1, out2], exe) + # Test for Inference Predictor + infer_outs = self.infer_prog() + np.testing.assert_array_equal(infer_outs[0].shape, (1, 3, 3, 3)) + np.testing.assert_allclose(infer_outs[0], infer_outs[1]) + + def path_prefix(self): + return 'pool2d_tensor' + + def var_prefix(self): + return "Vars[" + + def call_func(self, x): + # list[Tensor] + output_size = [paddle.assign([3]), paddle.assign([3])] + out1 = paddle.nn.functional.adaptive_avg_pool2d(x=x, output_size=[3, 3]) + out2 = paddle.nn.functional.adaptive_avg_pool2d(x=x, + output_size=output_size) + return out1, out2 + + +class TestOutputSizeListTensor(TestOutputSizeTensor): + + def path_prefix(self): + return 'pool2d_tensors' + + def call_func(self, x): + # list[int, Tensor] + output_size = [paddle.assign([3]), 3] + out1 = paddle.nn.functional.adaptive_avg_pool2d(x=x, output_size=[3, 3]) + out2 = paddle.nn.functional.adaptive_avg_pool2d(x=x, + output_size=output_size) + return out1, out2 + + +class TestOutputSizeListTensor2(TestOutputSizeTensor): + + def path_prefix(self): + return 'pool2d_tensor2' + + def call_func(self, x): + # A Tensor + output_size = paddle.assign([3, 3]) + out1 = paddle.nn.functional.adaptive_avg_pool2d(x=x, output_size=[3, 3]) + out2 = paddle.nn.functional.adaptive_avg_pool2d(x=x, + output_size=output_size) + return out1, out2 + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index 4c847a4233..3194b5720b 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -1469,11 +1469,20 @@ def adaptive_avg_pool2d(x, output_size, data_format='NCHW', name=None): output_size = utils.convert_to_list(output_size, 2, 'output_size') else: output_size = list(output_size) - if output_size[0] == None: + if output_size[0] is None: output_size[0] = in_h - if output_size[1] == None: + if output_size[1] is None: output_size[1] = in_w + if _non_static_mode(): + output_size = [ + item.numpy().item(0) if isinstance(item, Variable) else item + for item in output_size + ] + # output_size support Variable in static mode + elif utils._contain_var(output_size): + output_size = utils._convert_to_tensor_list(output_size) + if in_dygraph_mode(): return _C_ops.pool2d_gpudnn_unused(x, output_size, [1, 1], [0, 0], False, True, data_format, 'avg', @@ -1585,11 +1594,11 @@ def adaptive_avg_pool3d(x, output_size, data_format='NCDHW', name=None): output_size = utils.convert_to_list(output_size, 3, 'output_size') else: output_size = list(output_size) - if output_size[0] == None: + if output_size[0] is None: output_size[0] = in_l - if output_size[1] == None: + if output_size[1] is None: output_size[1] = in_h - if output_size[2] == None: + if output_size[2] is None: output_size[2] = in_w if in_dynamic_mode(): -- GitLab