From 74836ec7b7406bd26e2f52daa31f23478d265307 Mon Sep 17 00:00:00 2001 From: Bai Yifan Date: Fri, 21 Aug 2020 10:56:39 +0800 Subject: [PATCH] [2.0API]Add adaptive_avg_pool_2/3d (#26369) * add adaptive_avg_pool2d * add adaptive_avg_pool3d --- .../unittests/test_adaptive_avg_pool2d.py | 274 ++++++++++++++++ .../unittests/test_adaptive_avg_pool3d.py | 293 ++++++++++++++++++ python/paddle/nn/__init__.py | 2 + python/paddle/nn/functional/__init__.py | 2 + python/paddle/nn/functional/pooling.py | 261 +++++++++++++++- python/paddle/nn/layer/__init__.py | 2 + python/paddle/nn/layer/pooling.py | 196 ++++++++++++ 7 files changed, 1029 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/tests/unittests/test_adaptive_avg_pool2d.py create mode 100755 python/paddle/fluid/tests/unittests/test_adaptive_avg_pool3d.py create mode 100755 python/paddle/nn/layer/pooling.py diff --git a/python/paddle/fluid/tests/unittests/test_adaptive_avg_pool2d.py b/python/paddle/fluid/tests/unittests/test_adaptive_avg_pool2d.py new file mode 100644 index 00000000000..55c30e3d2ad --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_adaptive_avg_pool2d.py @@ -0,0 +1,274 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +from __future__ import division + +import unittest +import numpy as np + +import paddle.fluid.core as core +from op_test import OpTest +import paddle +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard + + +def adaptive_start_index(index, input_size, output_size): + return int(np.floor(index * input_size / output_size)) + + +def adaptive_end_index(index, input_size, output_size): + return int(np.ceil((index + 1) * input_size / output_size)) + + +def adaptive_pool2d_forward(x, output_size, data_format='NCHW', + pool_type="avg"): + + N = x.shape[0] + C, H, W = [x.shape[1], x.shape[2], x.shape[3]] if data_format == 'NCHW' \ + else [x.shape[3], x.shape[1], x.shape[2]] + + if (isinstance(output_size, int) or output_size == None): + H_out = output_size + W_out = output_size + output_size = [H_out, W_out] + else: + H_out, W_out = output_size + + if output_size[0] == None: + output_size[0] = H + H_out = H + if output_size[1] == None: + output_size[1] = W + W_out = W + + out = np.zeros((N, C, H_out, W_out)) if data_format=='NCHW' \ + else np.zeros((N, H_out, W_out, C)) + + for i in range(H_out): + in_h_start = adaptive_start_index(i, H, output_size[0]) + in_h_end = adaptive_end_index(i, H, output_size[0]) + + for j in range(W_out): + in_w_start = adaptive_start_index(j, W, output_size[1]) + in_w_end = adaptive_end_index(j, W, output_size[1]) + + if data_format == 'NCHW': + x_masked = x[:, :, in_h_start:in_h_end, in_w_start:in_w_end] + if pool_type == 'avg': + field_size = ( + (in_h_end - in_h_start) * (in_w_end - in_w_start)) + out[:, :, i, j] = np.sum(x_masked, axis=(2, 3)) / field_size + elif pool_type == 'max': + out[:, :, i, j] = np.max(x_masked, axis=(2, 3)) + elif data_format == 'NHWC': + x_masked = x[:, in_h_start:in_h_end, in_w_start:in_w_end, :] + if pool_type == 'avg': + field_size = ( + (in_h_end - in_h_start) * (in_w_end - in_w_start)) + out[:, i, j, :] = np.sum(x_masked, axis=(1, 2)) / field_size + elif pool_type == 'max': + out[:, i, j, :] = np.max(x_masked, axis=(1, 2)) + return out + + +class TestAdaptiveAvgPool2dAPI(unittest.TestCase): + def setUp(self): + self.x_np = np.random.random([2, 3, 7, 7]).astype("float32") + self.res_1_np = adaptive_pool2d_forward( + x=self.x_np, output_size=[3, 3], pool_type="avg") + + self.res_2_np = adaptive_pool2d_forward( + x=self.x_np, output_size=5, pool_type="avg") + + self.res_3_np = adaptive_pool2d_forward( + x=self.x_np, output_size=[2, 5], pool_type="avg") + + self.res_4_np = adaptive_pool2d_forward( + x=self.x_np, + output_size=[3, 3], + pool_type="avg", + data_format="NHWC") + + self.res_5_np = adaptive_pool2d_forward( + x=self.x_np, output_size=[None, 3], pool_type="avg") + + def test_static_graph(self): + for use_cuda in ([False, True] + if core.is_compiled_with_cuda() else [False]): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + paddle.enable_static() + x = paddle.data(name="x", shape=[2, 3, 7, 7], dtype="float32") + + out_1 = paddle.nn.functional.adaptive_avg_pool2d( + x=x, output_size=[3, 3]) + + out_2 = paddle.nn.functional.adaptive_avg_pool2d(x=x, output_size=5) + + out_3 = paddle.nn.functional.adaptive_avg_pool2d( + x=x, output_size=[2, 5]) + + out_4 = paddle.nn.functional.adaptive_avg_pool2d( + x=x, output_size=[3, 3], data_format="NHWC") + + out_5 = paddle.nn.functional.adaptive_avg_pool2d( + x=x, output_size=[None, 3]) + + exe = paddle.static.Executor(place=place) + [res_1, res_2, res_3, res_4, res_5] = exe.run( + fluid.default_main_program(), + feed={"x": self.x_np}, + fetch_list=[out_1, out_2, out_3, out_4, out_5]) + + assert np.allclose(res_1, self.res_1_np) + + assert np.allclose(res_2, self.res_2_np) + + assert np.allclose(res_3, self.res_3_np) + + assert np.allclose(res_4, self.res_4_np) + + assert np.allclose(res_5, self.res_5_np) + + def test_dynamic_graph(self): + for use_cuda in ([False, True] + if core.is_compiled_with_cuda() else [False]): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + paddle.disable_static(place=place) + x = paddle.to_variable(self.x_np) + + out_1 = paddle.nn.functional.adaptive_avg_pool2d( + x=x, output_size=[3, 3]) + + out_2 = paddle.nn.functional.adaptive_avg_pool2d(x=x, output_size=5) + + out_3 = paddle.nn.functional.adaptive_avg_pool2d( + x=x, output_size=[2, 5]) + + out_4 = paddle.nn.functional.adaptive_avg_pool2d( + x=x, output_size=[3, 3], data_format="NHWC") + + out_5 = paddle.nn.functional.adaptive_avg_pool2d( + x=x, output_size=[None, 3]) + + assert np.allclose(out_1.numpy(), self.res_1_np) + + assert np.allclose(out_2.numpy(), self.res_2_np) + + assert np.allclose(out_3.numpy(), self.res_3_np) + + assert np.allclose(out_4.numpy(), self.res_4_np) + + assert np.allclose(out_5.numpy(), self.res_5_np) + + +class TestAdaptiveAvgPool2dClassAPI(unittest.TestCase): + def setUp(self): + self.x_np = np.random.random([2, 3, 7, 7]).astype("float32") + self.res_1_np = adaptive_pool2d_forward( + x=self.x_np, output_size=[3, 3], pool_type="avg") + + self.res_2_np = adaptive_pool2d_forward( + x=self.x_np, output_size=5, pool_type="avg") + + self.res_3_np = adaptive_pool2d_forward( + x=self.x_np, output_size=[2, 5], pool_type="avg") + + self.res_4_np = adaptive_pool2d_forward( + x=self.x_np, + output_size=[3, 3], + pool_type="avg", + data_format="NHWC") + + self.res_5_np = adaptive_pool2d_forward( + x=self.x_np, output_size=[None, 3], pool_type="avg") + + def test_static_graph(self): + for use_cuda in ([False, True] + if core.is_compiled_with_cuda() else [False]): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + paddle.enable_static() + x = paddle.data(name="x", shape=[2, 3, 7, 7], dtype="float32") + + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2d(output_size=[3, 3]) + out_1 = adaptive_avg_pool(x=x) + + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2d(output_size=5) + out_2 = adaptive_avg_pool(x=x) + + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2d(output_size=[2, 5]) + out_3 = adaptive_avg_pool(x=x) + + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2d( + output_size=[3, 3], data_format="NHWC") + out_4 = adaptive_avg_pool(x=x) + + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2d( + output_size=[None, 3]) + out_5 = adaptive_avg_pool(x=x) + + exe = paddle.static.Executor(place=place) + [res_1, res_2, res_3, res_4, res_5] = exe.run( + fluid.default_main_program(), + feed={"x": self.x_np}, + fetch_list=[out_1, out_2, out_3, out_4, out_5]) + + assert np.allclose(res_1, self.res_1_np) + + assert np.allclose(res_2, self.res_2_np) + + assert np.allclose(res_3, self.res_3_np) + + assert np.allclose(res_4, self.res_4_np) + + assert np.allclose(res_5, self.res_5_np) + + def test_dynamic_graph(self): + for use_cuda in ([False, True] + if core.is_compiled_with_cuda() else [False]): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + paddle.disable_static(place=place) + x = paddle.to_variable(self.x_np) + + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2d(output_size=[3, 3]) + out_1 = adaptive_avg_pool(x=x) + + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2d(output_size=5) + out_2 = adaptive_avg_pool(x=x) + + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2d(output_size=[2, 5]) + out_3 = adaptive_avg_pool(x=x) + + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2d( + output_size=[3, 3], data_format="NHWC") + out_4 = adaptive_avg_pool(x=x) + + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2d( + output_size=[None, 3]) + out_5 = adaptive_avg_pool(x=x) + + assert np.allclose(out_1.numpy(), self.res_1_np) + + assert np.allclose(out_2.numpy(), self.res_2_np) + + assert np.allclose(out_3.numpy(), self.res_3_np) + + assert np.allclose(out_4.numpy(), self.res_4_np) + + assert np.allclose(out_5.numpy(), self.res_5_np) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_adaptive_avg_pool3d.py b/python/paddle/fluid/tests/unittests/test_adaptive_avg_pool3d.py new file mode 100755 index 00000000000..c04ee660667 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_adaptive_avg_pool3d.py @@ -0,0 +1,293 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +from __future__ import division + +import unittest +import numpy as np + +import paddle.fluid.core as core +from op_test import OpTest +import paddle +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard + + +def adaptive_start_index(index, input_size, output_size): + return int(np.floor(index * input_size / output_size)) + + +def adaptive_end_index(index, input_size, output_size): + return int(np.ceil((index + 1) * input_size / output_size)) + + +def adaptive_pool3d_forward(x, + output_size, + adaptive=True, + data_format='NCDHW', + pool_type='avg'): + + N = x.shape[0] + C, D, H, W = [x.shape[1], x.shape[2], x.shape[3], x.shape[4]] \ + if data_format == 'NCDHW' else [x.shape[4], x.shape[1], x.shape[2],x.shape[3]] + + if (isinstance(output_size, int) or output_size == None): + H_out = output_size + W_out = output_size + D_out = output_size + output_size = [D_out, H_out, W_out] + else: + D_out, H_out, W_out = output_size + + if output_size[0] == None: + output_size[0] = D + D_out = D + if output_size[1] == None: + output_size[1] = H + H_out = H + if output_size[2] == None: + output_size[2] = W + W_out = W + + out = np.zeros((N, C, D_out, H_out, W_out)) if data_format=='NCDHW' \ + else np.zeros((N, D_out, H_out, W_out, C)) + for k in range(D_out): + d_start = adaptive_start_index(k, D, output_size[0]) + d_end = adaptive_end_index(k, D, output_size[0]) + + for i in range(H_out): + h_start = adaptive_start_index(i, H, output_size[1]) + h_end = adaptive_end_index(i, H, output_size[1]) + + for j in range(W_out): + w_start = adaptive_start_index(j, W, output_size[2]) + w_end = adaptive_end_index(j, W, output_size[2]) + + if data_format == 'NCDHW': + x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start: + w_end] + if pool_type == 'avg': + field_size = (d_end - d_start) * (h_end - h_start) * ( + w_end - w_start) + out[:, :, k, i, j] = np.sum(x_masked, + axis=(2, 3, 4)) / field_size + elif pool_type == 'max': + out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4)) + + elif data_format == 'NDHWC': + x_masked = x[:, d_start:d_end, h_start:h_end, w_start: + w_end, :] + if pool_type == 'avg': + field_size = (d_end - d_start) * (h_end - h_start) * ( + w_end - w_start) + out[:, k, i, j, :] = np.sum(x_masked, + axis=(1, 2, 3)) / field_size + elif pool_type == 'max': + out[:, k, i, j, :] = np.max(x_masked, axis=(1, 2, 3)) + return out + + +class TestAdaptiveAvgPool3dAPI(unittest.TestCase): + def setUp(self): + self.x_np = np.random.random([2, 3, 5, 7, 7]).astype("float32") + self.res_1_np = adaptive_pool3d_forward( + x=self.x_np, output_size=[3, 3, 3], pool_type="avg") + + self.res_2_np = adaptive_pool3d_forward( + x=self.x_np, output_size=5, pool_type="avg") + + self.res_3_np = adaptive_pool3d_forward( + x=self.x_np, output_size=[2, 3, 5], pool_type="avg") + + self.res_4_np = adaptive_pool3d_forward( + x=self.x_np, + output_size=[3, 3, 3], + pool_type="avg", + data_format="NDHWC") + + self.res_5_np = adaptive_pool3d_forward( + x=self.x_np, output_size=[None, 3, None], pool_type="avg") + + def test_static_graph(self): + for use_cuda in ([False, True] + if core.is_compiled_with_cuda() else [False]): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + paddle.enable_static() + x = paddle.data(name="x", shape=[2, 3, 5, 7, 7], dtype="float32") + + out_1 = paddle.nn.functional.adaptive_avg_pool3d( + x=x, output_size=[3, 3, 3]) + + out_2 = paddle.nn.functional.adaptive_avg_pool3d(x=x, output_size=5) + + out_3 = paddle.nn.functional.adaptive_avg_pool3d( + x=x, output_size=[2, 3, 5]) + + out_4 = paddle.nn.functional.adaptive_avg_pool3d( + x=x, output_size=[3, 3, 3], data_format="NDHWC") + + out_5 = paddle.nn.functional.adaptive_avg_pool3d( + x=x, output_size=[None, 3, None]) + + exe = paddle.static.Executor(place=place) + [res_1, res_2, res_3, res_4, res_5] = exe.run( + fluid.default_main_program(), + feed={"x": self.x_np}, + fetch_list=[out_1, out_2, out_3, out_4, out_5]) + + assert np.allclose(res_1, self.res_1_np) + + assert np.allclose(res_2, self.res_2_np) + + assert np.allclose(res_3, self.res_3_np) + + assert np.allclose(res_4, self.res_4_np) + + assert np.allclose(res_5, self.res_5_np) + + def test_dynamic_graph(self): + for use_cuda in ([False, True] + if core.is_compiled_with_cuda() else [False]): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + paddle.disable_static(place=place) + x = paddle.to_variable(self.x_np) + + out_1 = paddle.nn.functional.adaptive_avg_pool3d( + x=x, output_size=[3, 3, 3]) + + out_2 = paddle.nn.functional.adaptive_avg_pool3d(x=x, output_size=5) + + out_3 = paddle.nn.functional.adaptive_avg_pool3d( + x=x, output_size=[2, 3, 5]) + + out_4 = paddle.nn.functional.adaptive_avg_pool3d( + x=x, output_size=[3, 3, 3], data_format="NDHWC") + + out_5 = paddle.nn.functional.adaptive_avg_pool3d( + x=x, output_size=[None, 3, None]) + + assert np.allclose(out_1.numpy(), self.res_1_np) + + assert np.allclose(out_2.numpy(), self.res_2_np) + + assert np.allclose(out_3.numpy(), self.res_3_np) + + assert np.allclose(out_4.numpy(), self.res_4_np) + + assert np.allclose(out_5.numpy(), self.res_5_np) + + +class TestAdaptiveAvgPool3dClassAPI(unittest.TestCase): + def setUp(self): + self.x_np = np.random.random([2, 3, 5, 7, 7]).astype("float32") + self.res_1_np = adaptive_pool3d_forward( + x=self.x_np, output_size=[3, 3, 3], pool_type="avg") + + self.res_2_np = adaptive_pool3d_forward( + x=self.x_np, output_size=5, pool_type="avg") + + self.res_3_np = adaptive_pool3d_forward( + x=self.x_np, output_size=[2, 3, 5], pool_type="avg") + + self.res_4_np = adaptive_pool3d_forward( + x=self.x_np, + output_size=[3, 3, 3], + pool_type="avg", + data_format="NDHWC") + + self.res_5_np = adaptive_pool3d_forward( + x=self.x_np, output_size=[None, 3, None], pool_type="avg") + + def test_static_graph(self): + for use_cuda in ([False, True] + if core.is_compiled_with_cuda() else [False]): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + paddle.enable_static() + x = paddle.data(name="x", shape=[2, 3, 5, 7, 7], dtype="float32") + + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool3d( + output_size=[3, 3, 3]) + out_1 = adaptive_avg_pool(x=x) + + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool3d(output_size=5) + out_2 = adaptive_avg_pool(x=x) + + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool3d( + output_size=[2, 3, 5]) + out_3 = adaptive_avg_pool(x=x) + + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool3d( + output_size=[3, 3, 3], data_format="NDHWC") + out_4 = adaptive_avg_pool(x=x) + + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool3d( + output_size=[None, 3, None]) + out_5 = adaptive_avg_pool(x=x) + + exe = paddle.static.Executor(place=place) + [res_1, res_2, res_3, res_4, res_5] = exe.run( + fluid.default_main_program(), + feed={"x": self.x_np}, + fetch_list=[out_1, out_2, out_3, out_4, out_5]) + + assert np.allclose(res_1, self.res_1_np) + + assert np.allclose(res_2, self.res_2_np) + + assert np.allclose(res_3, self.res_3_np) + + assert np.allclose(res_4, self.res_4_np) + + assert np.allclose(res_5, self.res_5_np) + + def test_dynamic_graph(self): + for use_cuda in ([False, True] + if core.is_compiled_with_cuda() else [False]): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + paddle.disable_static(place=place) + x = paddle.to_variable(self.x_np) + + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool3d( + output_size=[3, 3, 3]) + out_1 = adaptive_avg_pool(x=x) + + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool3d(output_size=5) + out_2 = adaptive_avg_pool(x=x) + + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool3d( + output_size=[2, 3, 5]) + out_3 = adaptive_avg_pool(x=x) + + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool3d( + output_size=[3, 3, 3], data_format="NDHWC") + out_4 = adaptive_avg_pool(x=x) + + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool3d( + output_size=[None, 3, None]) + out_5 = adaptive_avg_pool(x=x) + + assert np.allclose(out_1.numpy(), self.res_1_np) + + assert np.allclose(out_2.numpy(), self.res_2_np) + + assert np.allclose(out_3.numpy(), self.res_3_np) + + assert np.allclose(out_4.numpy(), self.res_4_np) + + assert np.allclose(out_5.numpy(), self.res_5_np) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index 84c466c977e..669b708ff17 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -87,6 +87,8 @@ from .layer.common import Embedding #DEFINE_ALIAS from .layer.common import Linear #DEFINE_ALIAS from .layer.common import Flatten #DEFINE_ALIAS from .layer.common import UpSample #DEFINE_ALIAS +from .layer.pooling import AdaptiveAvgPool2d #DEFINE_ALIAS +from .layer.pooling import AdaptiveAvgPool3d #DEFINE_ALIAS from .layer.conv import Conv2D #DEFINE_ALIAS from .layer.conv import Conv2DTranspose #DEFINE_ALIAS from .layer.conv import Conv3D #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 2ee23f2fea0..bf78762f7bd 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -160,6 +160,8 @@ from .pooling import pool2d #DEFINE_ALIAS from .pooling import pool3d #DEFINE_ALIAS from .pooling import adaptive_pool2d #DEFINE_ALIAS from .pooling import adaptive_pool3d #DEFINE_ALIAS +from .pooling import adaptive_avg_pool2d #DEFINE_ALIAS +from .pooling import adaptive_avg_pool3d #DEFINE_ALIAS # from .rnn import gru_unit #DEFINE_ALIAS # from .rnn import lstm #DEFINE_ALIAS # from .rnn import lstm_unit #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index 618145fb1fa..c396d00320a 100644 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -13,9 +13,268 @@ # limitations under the License. # TODO: define pooling functions +import paddle +from ...fluid import core from ...fluid.layers import pool2d #DEFINE_ALIAS from ...fluid.layers import pool3d #DEFINE_ALIAS from ...fluid.layers import adaptive_pool2d #DEFINE_ALIAS from ...fluid.layers import adaptive_pool3d #DEFINE_ALIAS +from ...fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype +from ...fluid.layers import utils +from ...fluid.layer_helper import LayerHelper +from ...fluid.framework import in_dygraph_mode -__all__ = ['pool2d', 'pool3d', 'adaptive_pool2d', 'adaptive_pool3d'] +__all__ = [ + 'pool2d', 'pool3d', 'adaptive_pool2d', 'adaptive_pool3d', + 'adaptive_avg_pool2d', 'adaptive_avg_pool3d' +] + + +def adaptive_avg_pool2d(x, output_size, data_format='NCHW', name=None): + """ + + This operation applies 2D adaptive avg pooling on input tensor. The h and w dimensions + of the output tensor are determined by the parameter output_size. + See more detail in :ref:`api_nn_pooling_AdaptiveAvgPool2d` . + + For avg adaptive pool2d: + + .. math:: + + hstart &= floor(i * H_{in} / H_{out}) + + hend &= ceil((i + 1) * H_{in} / H_{out}) + + wstart &= floor(j * W_{in} / W_{out}) + + wend &= ceil((j + 1) * W_{in} / W_{out}) + + Output(i ,j) &= \\frac{sum(Input[hstart:hend, wstart:wend])}{(hend - hstart) * (wend - wstart)} + + Args: + x (Tensor): The input tensor of adaptive avg pool2d operator, which is a 4-D tensor. + The data type can be float16, float32, float64, int32 or int64. + output_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, + it must contain two element, (H, W). H and W can be either a int, or None which means + the size will be the same as that of the input. + data_format (str): The data format of the input and output data. An optional string + from: "NCHW", "NHWC". The default is "NCHW". When it is "NCHW", the data is stored in + the order of: [batch_size, input_channels, input_height, input_width]. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + Tensor: The output tensor of avg adaptive pool2d result. The data type is same as input tensor. + + Raises: + ValueError: If `data_format` is not "NCHW" or "NHWC". + + Examples: + .. code-block:: python + + # adaptive avg pool2d + # suppose input data in shape of [N, C, H, W], `output_size` is [m, n], + # output shape is [N, C, m, n], adaptive pool divide H and W dimensions + # of input data into m * n grids averagely and performs poolings in each + # grid to get output. + # adaptive avg pool performs calculations as follow: + # + # for i in range(m): + # for j in range(n): + # hstart = floor(i * H / m) + # hend = ceil((i + 1) * H / m) + # wstart = floor(i * W / n) + # wend = ceil((i + 1) * W / n) + # output[:, :, i, j] = avg(input[:, :, hstart: hend, wstart: wend]) + # + import paddle + import numpy as np + paddle.disable_static() + input_data = np.random.rand(2, 3, 32, 32) + x = paddle.to_tensor(input_data) + # x.shape is [2, 3, 32, 32] + pool_out = paddle.nn.functional.adaptive_avg_pool2d( + x = x, + output_size=[3, 3]) + # pool_out.shape is [2, 3, 3, 3] + """ + if not in_dygraph_mode(): + check_variable_and_dtype( + x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], + 'adaptive_avg_pool2d') + check_type(data_format, 'data_format', str, 'adaptive_avg_pool2d') + + if data_format not in ["NCHW", "NHWC"]: + raise ValueError( + "Attr(data_format) should be 'NCHW' or 'NHWC'. Received " + "Attr(data_format): %s." % str(data_format)) + + if data_format == "NCHW": + in_h, in_w = x.shape[2:4] + else: + in_h, in_w = x.shape[1:3] + + if isinstance(output_size, int): + output_size = utils.convert_to_list(output_size, 2, 'output_size') + else: + if output_size[0] == None: + output_size[0] = in_h + if output_size[1] == None: + output_size[1] = in_w + + if in_dygraph_mode(): + output = core.ops.pool2d(x, 'pooling_type', 'avg', 'ksize', output_size, + 'global_pooling', False, 'adaptive', True, + 'data_format', data_format) + return output + + l_type = 'pool2d' + + helper = LayerHelper(l_type, **locals()) + dtype = helper.input_dtype() + pool_out = helper.create_variable_for_type_inference(dtype) + + outputs = {"Out": pool_out} + + helper.append_op( + type=l_type, + inputs={"X": x}, + outputs=outputs, + attrs={ + "pooling_type": "avg", + "ksize": output_size, + "adaptive": True, + "data_format": data_format, + }) + + return pool_out + + +def adaptive_avg_pool3d(x, output_size, data_format='NCDHW', name=None): + """ + + This operation applies 3D adaptive avg pooling on input tensor. The h and w dimensions + of the output tensor are determined by the parameter output_size. + See more detail in :ref:`api_nn_pooling_AdaptiveAvgPool3d` . + + For avg adaptive pool3d: + + .. math:: + + dstart &= floor(i * D_{in} / D_{out}) + + dend &= ceil((i + 1) * D_{in} / D_{out}) + + hstart &= floor(j * H_{in} / H_{out}) + + hend &= ceil((j + 1) * H_{in} / H_{out}) + + wstart &= floor(k * W_{in} / W_{out}) + + wend &= ceil((k + 1) * W_{in} / W_{out}) + + Output(i ,j, k) &= \\frac{sum(Input[dstart:dend, hstart:hend, wstart:wend])}{(dend - dstart) * (hend - hstart) * (wend - wstart)} + + Args: + x (Tensor): The input tensor of adaptive avg pool3d operator, which is a 5-D tensor. + The data type can be float16, float32, float64, int32 or int64. + output_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, + it must contain three elements, (D, H, W). D, H and W can be either a int, or None which means + the size will be the same as that of the input. + data_format (str): The data format of the input and output data. An optional string + from: "NCDHW", "NDHWC". The default is "NCDHW". When it is "NCDHW", the data is stored in + the order of: [batch_size, input_channels, input_depth, input_height, input_width]. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + Tensor: The output tensor of avg adaptive pool3d result. The data type is same as input tensor. + + Raises: + ValueError: If `data_format` is not "NCDHW" or "NDHWC". + + Examples: + .. code-block:: python + + # adaptive avg pool3d + # suppose input data in shape of [N, C, D, H, W], `output_size` is [l, m, n], + # output shape is [N, C, l, m, n], adaptive pool divide D, H and W dimensions + # of input data into l * m * n grids averagely and performs poolings in each + # grid to get output. + # adaptive avg pool performs calculations as follow: + # + # for i in range(l): + # for j in range(m): + # for k in range(n): + # dstart = floor(i * D / l) + # dend = ceil((i + 1) * D / l) + # hstart = floor(j * H / m) + # hend = ceil((j + 1) * H / m) + # wstart = floor(k * W / n) + # wend = ceil((k + 1) * W / n) + # output[:, :, i, j, k] = + # avg(input[:, :, dstart:dend, hstart: hend, wstart: wend]) + import paddle + import numpy as np + paddle.disable_static() + input_data = np.random.rand(2, 3, 8, 32, 32) + x = paddle.to_tensor(input_data) + # x.shape is [2, 3, 8, 32, 32] + pool_out = paddle.nn.functional.adaptive_avg_pool3d( + x = x, + output_size=[3, 3, 3]) + # pool_out.shape is [2, 3, 3, 3, 3] + """ + if not in_dygraph_mode(): + check_variable_and_dtype( + x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], + 'adaptive_avg_pool3d') + check_type(data_format, 'data_format', str, 'adaptive_avg_pool3d') + + if data_format not in ["NCDHW", "NDHWC"]: + raise ValueError( + "Attr(data_format) should be 'NCDHW' or 'NDHWC'. Received " + "Attr(data_format): %s." % str(data_format)) + + if data_format == "NCDHW": + in_l, in_h, in_w = x.shape[2:5] + else: + in_l, in_h, in_w = x.shape[1:4] + + if isinstance(output_size, int): + output_size = utils.convert_to_list(output_size, 3, 'output_size') + else: + if output_size[0] == None: + output_size[0] = in_l + if output_size[1] == None: + output_size[1] = in_h + if output_size[2] == None: + output_size[2] = in_w + + if in_dygraph_mode(): + output = core.ops.pool3d(x, 'pooling_type', 'avg', 'ksize', output_size, + 'global_pooling', False, 'adaptive', True, + 'data_format', data_format) + return output + + l_type = 'pool3d' + + helper = LayerHelper(l_type, **locals()) + dtype = helper.input_dtype() + pool_out = helper.create_variable_for_type_inference(dtype) + outputs = {"Out": pool_out} + + helper.append_op( + type=l_type, + inputs={"X": x}, + outputs=outputs, + attrs={ + "pooling_type": "avg", + "ksize": output_size, + "adaptive": True, + "data_format": data_format, + }) + + return pool_out diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 84d13c2211b..a89e7802830 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -52,6 +52,8 @@ from .common import Embedding #DEFINE_ALIAS from .common import Linear #DEFINE_ALIAS from .common import Flatten #DEFINE_ALIAS from .common import UpSample #DEFINE_ALIAS +from .pooling import AdaptiveAvgPool2d #DEFINE_ALIAS +from .pooling import AdaptiveAvgPool3d #DEFINE_ALIAS from .conv import Conv2D #DEFINE_ALIAS from .conv import Conv2DTranspose #DEFINE_ALIAS from .conv import Conv3D #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/pooling.py b/python/paddle/nn/layer/pooling.py new file mode 100755 index 00000000000..65ea6b0b05d --- /dev/null +++ b/python/paddle/nn/layer/pooling.py @@ -0,0 +1,196 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle + +from ...fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype +from ...fluid.layers import utils +from ...fluid.dygraph import layers +from ...fluid.layer_helper import LayerHelper +from .. import functional as F + +__all__ = [ + 'AdaptiveAvgPool2d', + 'AdaptiveAvgPool3d', +] + + +class AdaptiveAvgPool2d(layers.Layer): + """ + + This operation applies 2D adaptive avg pooling on input tensor. The h and w dimensions + of the output tensor are determined by the parameter output_size. + + For avg adaptive pool2d: + + .. math:: + + hstart &= floor(i * H_{in} / H_{out}) + + hend &= ceil((i + 1) * H_{in} / H_{out}) + + wstart &= floor(j * W_{in} / W_{out}) + + wend &= ceil((j + 1) * W_{in} / W_{out}) + + Output(i ,j) &= \\frac{sum(Input[hstart:hend, wstart:wend])}{(hend - hstart) * (wend - wstart)} + + + Parameters: + output_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, + it must contain two element, (H, W). H and W can be either a int, or None which means + the size will be the same as that of the input. + data_format (str): The data format of the input and output data. An optional string + from: "NCHW", "NHWC". The default is "NCHW". When it is "NCHW", the data is stored in + the order of: [batch_size, input_channels, input_height, input_width]. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Shape: + x (Tensor): The input tensor of adaptive avg pool2d operator, which is a 4-D tensor. The data type can be float16, float32, float64, int32 or int64. + output (Tensor): The output tensor of adaptive avg pool2d operator, which is a 4-D tensor. The data type is same as input x. + + Returns: + A callable object of AdaptiveAvgPool2d. + + Examples: + .. code-block:: python + + # adaptive avg pool2d + # suppose input data in shape of [N, C, H, W], `output_size` is [m, n], + # output shape is [N, C, m, n], adaptive pool divide H and W dimensions + # of input data into m * n grids averagely and performs poolings in each + # grid to get output. + # adaptive avg pool performs calculations as follow: + # + # for i in range(m): + # for j in range(n): + # hstart = floor(i * H / m) + # hend = ceil((i + 1) * H / m) + # wstart = floor(i * W / n) + # wend = ceil((i + 1) * W / n) + # output[:, :, i, j] = avg(input[:, :, hstart: hend, wstart: wend]) + # + import paddle + import numpy as np + paddle.disable_static() + input_data = np.random.rand(2, 3, 32, 32) + x = paddle.to_tensor(input_data) + # x.shape is [2, 3, 32, 32] + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2d(output_size=3) + pool_out = adaptive_avg_pool(x = x) + # pool_out.shape is [2, 3, 3, 3] + """ + + def __init__(self, output_size, data_format="NCHW", name=None): + super(AdaptiveAvgPool2d, self).__init__() + self._output_size = output_size + self._data_format = data_format + self._name = name + + def forward(self, x): + return F.adaptive_avg_pool2d( + x, + output_size=self._output_size, + data_format=self._data_format, + name=self._name) + + +class AdaptiveAvgPool3d(layers.Layer): + """ + + This operation applies 3D adaptive avg pooling on input tensor. The h and w dimensions + of the output tensor are determined by the parameter output_size. + + For avg adaptive pool3d: + + .. math:: + + dstart &= floor(i * D_{in} / D_{out}) + + dend &= ceil((i + 1) * D_{in} / D_{out}) + + hstart &= floor(j * H_{in} / H_{out}) + + hend &= ceil((j + 1) * H_{in} / H_{out}) + + wstart &= floor(k * W_{in} / W_{out}) + + wend &= ceil((k + 1) * W_{in} / W_{out}) + + Output(i ,j, k) &= \\frac{sum(Input[dstart:dend, hstart:hend, wstart:wend])}{(dend - dstart) * (hend - hstart) * (wend - wstart)} + + + Parameters: + output_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, + it must contain three elements, (D, H, W). D, H and W can be either a int, or None which means + the size will be the same as that of the input. + data_format (str): The data format of the input and output data. An optional string + from: "NCDHW", "NDHWC". The default is "NCDHW". When it is "NCDHW", the data is stored in + the order of: [batch_size, input_channels, input_depth, input_height, input_width]. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + Shape: + x (Tensor): The input tensor of adaptive avg pool3d operator, which is a 5-D tensor. The data type can be float16, float32, float64, int32 or int64. + output (Tensor): The output tensor of adaptive avg pool3d operator, which is a 5-D tensor. The data type is same as input x. + + Returns: + A callable object of AdaptiveAvgPool3d. + + Examples: + .. code-block:: python + + # adaptive avg pool3d + # suppose input data in shape of [N, C, D, H, W], `output_size` is [l, m, n], + # output shape is [N, C, l, m, n], adaptive pool divide D, H and W dimensions + # of input data into l * m * n grids averagely and performs poolings in each + # grid to get output. + # adaptive avg pool performs calculations as follow: + # + # for i in range(l): + # for j in range(m): + # for k in range(n): + # dstart = floor(i * D / l) + # dend = ceil((i + 1) * D / l) + # hstart = floor(j * H / m) + # hend = ceil((j + 1) * H / m) + # wstart = floor(k * W / n) + # wend = ceil((k + 1) * W / n) + # output[:, :, i, j, k] = + # avg(input[:, :, dstart:dend, hstart: hend, wstart: wend]) + import paddle + import numpy as np + paddle.disable_static() + input_data = np.random.rand(2, 3, 8, 32, 32) + x = paddle.to_tensor(input_data) + # x.shape is [2, 3, 8, 32, 32] + adaptive_avg_pool = paddle.nn.AdaptiveAvgPool3d(output_size=3) + pool_out = adaptive_avg_pool(x = x) + # pool_out = [2, 3, 3, 3, 3] + """ + + def __init__(self, output_size, data_format="NCDHW", name=None): + super(AdaptiveAvgPool3d, self).__init__() + self._output_size = output_size + self._data_format = data_format + self._name = name + + def forward(self, x): + return F.adaptive_avg_pool3d( + x, + output_size=self._output_size, + data_format=self._data_format, + name=self._name) -- GitLab