diff --git a/python/paddle/fluid/tests/unittests/test_pool1d_api.py b/python/paddle/fluid/tests/unittests/test_pool1d_api.py new file mode 100644 index 0000000000000000000000000000000000000000..b1a25ad3529e8b0a4126bc458838ecd876e5af30 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_pool1d_api.py @@ -0,0 +1,373 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import unittest +import numpy as np +from op_test import OpTest +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard +import paddle +import paddle.nn.functional as F +import paddle.fluid as fluid + + +def adaptive_start_index(index, input_size, output_size): + return int(np.floor(index * input_size / output_size)) + + +def adaptive_end_index(index, input_size, output_size): + return int(np.ceil((index + 1) * input_size / output_size)) + + +def max_pool1D_forward_naive(x, + ksize, + strides, + paddings, + global_pool=0, + ceil_mode=False, + exclusive=False, + adaptive=False, + data_type=np.float64): + N, C, L = x.shape + if global_pool == 1: + ksize = [L] + if adaptive: + L_out = ksize[0] + else: + L_out = (L - ksize[0] + 2 * paddings[0] + strides[0] - 1 + ) // strides[0] + 1 if ceil_mode else ( + L - ksize[0] + 2 * paddings[0]) // strides[0] + 1 + + out = np.zeros((N, C, L_out)) + for i in range(L_out): + if adaptive: + r_start = adaptive_start_index(i, L, ksize[0]) + r_end = adaptive_end_index(i, L, ksize[0]) + else: + r_start = np.max((i * strides[0] - paddings[0], 0)) + r_end = np.min((i * strides[0] + ksize[0] - paddings[0], L)) + x_masked = x[:, :, r_start:r_end] + + out[:, :, i] = np.max(x_masked, axis=(2)) + return out + + +def avg_pool1D_forward_naive(x, + ksize, + strides, + paddings, + global_pool=0, + ceil_mode=False, + exclusive=False, + adaptive=False, + data_type=np.float64): + N, C, L = x.shape + if global_pool == 1: + ksize = [L] + if adaptive: + L_out = ksize[0] + else: + L_out = (L - ksize[0] + 2 * paddings[0] + strides[0] - 1 + ) // strides[0] + 1 if ceil_mode else ( + L - ksize[0] + 2 * paddings[0]) // strides[0] + 1 + + out = np.zeros((N, C, L_out)) + for i in range(L_out): + if adaptive: + r_start = adaptive_start_index(i, L, ksize[0]) + r_end = adaptive_end_index(i, L, ksize[0]) + else: + r_start = np.max((i * strides[0] - paddings[0], 0)) + r_end = np.min((i * strides[0] + ksize[0] - paddings[0], L)) + x_masked = x[:, :, r_start:r_end] + + field_size = (r_end - r_start) \ + if (exclusive or adaptive) else (ksize[0]) + if data_type == np.int8 or data_type == np.uint8: + out[:, :, i] = (np.rint( + np.sum(x_masked, axis=(2, 3)) / field_size)).astype(data_type) + else: + out[:, :, i] = (np.sum(x_masked, axis=(2)) / + field_size).astype(data_type) + return out + + +class TestPool1d_API(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def check_avg_static_results(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input = fluid.data(name="input", shape=[2, 3, 32], dtype="float32") + result = F.avg_pool1d(input, kernel_size=2, stride=2, padding=0) + + input_np = np.random.random([2, 3, 32]).astype("float32") + result_np = avg_pool1D_forward_naive( + input_np, ksize=[2], strides=[2], paddings=[0], ceil_mode=False) + + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"input": input_np}, + fetch_list=[result]) + self.assertTrue(np.allclose(fetches[0], result_np)) + + def check_avg_dygraph_results(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + result = F.avg_pool1d(input, kernel_size=2, stride=2, padding=[0]) + + result_np = avg_pool1D_forward_naive( + input_np, ksize=[2], strides=[2], paddings=[0]) + + self.assertTrue(np.allclose(result.numpy(), result_np)) + + avg_pool1d_dg = paddle.nn.layer.AvgPool1d( + kernel_size=2, stride=None, padding=0) + result = avg_pool1d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + def check_max_static_results(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input = fluid.data(name="input", shape=[2, 3, 32], dtype="float32") + result = F.max_pool1d(input, kernel_size=2, stride=2, padding=[0]) + + input_np = np.random.random([2, 3, 32]).astype("float32") + result_np = max_pool1D_forward_naive( + input_np, ksize=[2], strides=[2], paddings=[0]) + + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"input": input_np}, + fetch_list=[result]) + self.assertTrue(np.allclose(fetches[0], result_np)) + + def check_max_dygraph_results(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + result = F.max_pool1d(input, kernel_size=2, stride=2, padding=0) + + result_np = max_pool1D_forward_naive( + input_np, ksize=[2], strides=[2], paddings=[0]) + + self.assertTrue(np.allclose(result.numpy(), result_np)) + + max_pool1d_dg = paddle.nn.layer.MaxPool1d( + kernel_size=2, stride=None, padding=0) + result = max_pool1d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + def check_adaptive_max_dygraph_results(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + result = F.adaptive_max_pool1d(input, output_size=16) + + result_np = max_pool1D_forward_naive( + input_np, ksize=[16], strides=[0], paddings=[0], adaptive=True) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + ada_max_pool1d_dg = paddle.nn.layer.AdaptiveMaxPool1d( + output_size=16) + result = ada_max_pool1d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + def check_adaptive_avg_dygraph_results(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + result = F.adaptive_avg_pool1d(input, output_size=16) + result_np = avg_pool1D_forward_naive( + input_np, ksize=[16], strides=[0], paddings=[0], adaptive=True) + + self.assertTrue(np.allclose(result.numpy(), result_np)) + + ada_max_pool1d_dg = paddle.nn.layer.AdaptiveAvgPool1d( + output_size=16) + result = ada_max_pool1d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + def check_adaptive_max_static_results(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input = fluid.data(name="input", shape=[2, 3, 32], dtype="float32") + result = F.adaptive_max_pool1d(input, output_size=16) + + input_np = np.random.random([2, 3, 32]).astype("float32") + result_np = max_pool1D_forward_naive( + input_np, ksize=[16], strides=[2], paddings=[0], adaptive=True) + + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"input": input_np}, + fetch_list=[result]) + self.assertTrue(np.allclose(fetches[0], result_np)) + + def check_adaptive_avg_static_results(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input = fluid.data(name="input", shape=[2, 3, 32], dtype="float32") + result = F.adaptive_avg_pool1d(input, output_size=16) + + input_np = np.random.random([2, 3, 32]).astype("float32") + result_np = avg_pool1D_forward_naive( + input_np, ksize=[16], strides=[2], paddings=[0], adaptive=True) + + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"input": input_np}, + fetch_list=[result]) + self.assertTrue(np.allclose(fetches[0], result_np)) + + def check_max_dygraph_padding_same(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + result = F.max_pool1d( + input, kernel_size=2, stride=2, padding="SAME") + + result_np = max_pool1D_forward_naive( + input_np, ksize=[2], strides=[2], paddings=[0]) + + self.assertTrue(np.allclose(result.numpy(), result_np)) + + def check_avg_dygraph_padding_same(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + result = F.avg_pool1d( + input, kernel_size=2, stride=2, padding="SAME") + + result_np = avg_pool1D_forward_naive( + input_np, ksize=[2], strides=[2], paddings=[0]) + + self.assertTrue(np.allclose(result.numpy(), result_np)) + + def test_pool1d(self): + for place in self.places: + + self.check_max_dygraph_results(place) + self.check_avg_dygraph_results(place) + self.check_max_static_results(place) + self.check_avg_static_results(place) + self.check_adaptive_max_dygraph_results(place) + self.check_adaptive_avg_dygraph_results(place) + self.check_adaptive_max_static_results(place) + self.check_adaptive_avg_static_results(place) + self.check_max_dygraph_padding_same(place) + self.check_avg_dygraph_padding_same(place) + + +class TestPool2dError_API(unittest.TestCase): + def test_error_api(self): + def run1(): + with fluid.dygraph.guard(): + input_np = np.random.uniform(-1, 1, + [2, 3, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + padding = [[2]] + res_pd = F.max_pool1d( + input_pd, kernel_size=2, stride=2, padding=padding) + + self.assertRaises(ValueError, run1) + + def run2(): + with fluid.dygraph.guard(): + input_np = np.random.uniform(-1, 1, + [2, 3, 32, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + padding = [[2]] + res_pd = F.max_pool1d( + input_pd, kernel_size=2, stride=2, padding=padding) + + self.assertRaises(ValueError, run2) + + def run3(): + with fluid.dygraph.guard(): + input_np = np.random.uniform(-1, 1, + [2, 3, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + padding = "padding" + res_pd = F.max_pool1d( + input_pd, kernel_size=2, stride=2, padding=padding) + + self.assertRaises(ValueError, run3) + + def run4(): + with fluid.dygraph.guard(): + input_np = np.random.uniform(-1, 1, + [2, 3, 32, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + padding = "VALID" + res_pd = F.max_pool1d( + input_pd, + kernel_size=2, + stride=2, + padding=padding, + ceil_mode=True) + + self.assertRaises(ValueError, run4) + + def run5(): + with fluid.dygraph.guard(): + input_np = np.random.uniform(-1, 1, + [2, 3, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + padding = "VALID" + res_pd = F.max_pool1d( + input_pd, + kernel_size=2, + stride=2, + padding=padding, + ceil_mode=True) + + self.assertRaises(ValueError, run5) + + def run6(): + with fluid.dygraph.guard(): + input_np = np.random.uniform(-1, 1, + [2, 3, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + padding = "VALID" + res_pd = F.avg_pool1d( + input_pd, + kernel_size=2, + stride=2, + padding=padding, + ceil_mode=True) + + self.assertRaises(ValueError, run6) + + def run7(): + with fluid.dygraph.guard(): + input_np = np.random.uniform(-1, 1, + [2, 3, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + padding = "paddle" + res_pd = F.avg_pool1d( + input_pd, + kernel_size=2, + stride=2, + padding=padding, + ceil_mode=True) + + self.assertRaises(ValueError, run7) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_pool2d_api.py b/python/paddle/fluid/tests/unittests/test_pool2d_api.py new file mode 100644 index 0000000000000000000000000000000000000000..73df0885d8fed4ddc4c03c91d2c331e72772e398 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_pool2d_api.py @@ -0,0 +1,375 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from test_pool2d_op import adaptive_start_index, adaptive_end_index, pool2D_forward_naive +import unittest +from op_test import OpTest +import numpy as np +import paddle.fluid.core as core +from paddle.nn.functional import * +import paddle.fluid as fluid +import paddle + + +class TestPool2d_API(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def check_avg_static_results(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input = fluid.data( + name="input", shape=[2, 3, 32, 32], dtype="float32") + result = avg_pool2d(input, kernel_size=2, stride=2, padding=0) + + input_np = np.random.random([2, 3, 32, 32]).astype("float32") + result_np = pool2D_forward_naive( + input_np, + ksize=[2, 2], + strides=[2, 2], + paddings=[0, 0], + pool_type='avg') + + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"input": input_np}, + fetch_list=[result]) + self.assertTrue(np.allclose(fetches[0], result_np)) + + def check_avg_dygraph_results(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + result = avg_pool2d(input, kernel_size=2, stride=2, padding=0) + + result_np = pool2D_forward_naive( + input_np, + ksize=[2, 2], + strides=[2, 2], + paddings=[0, 0], + pool_type='avg') + self.assertTrue(np.allclose(result.numpy(), result_np)) + + avg_pool2d_dg = paddle.nn.layer.AvgPool2d( + kernel_size=2, stride=2, padding=0) + result = avg_pool2d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + def check_max_static_results(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input = fluid.data( + name="input", shape=[2, 3, 32, 32], dtype="float32") + result = max_pool2d(input, kernel_size=2, stride=2, padding=0) + + input_np = np.random.random([2, 3, 32, 32]).astype("float32") + result_np = pool2D_forward_naive( + input_np, + ksize=[2, 2], + strides=[2, 2], + paddings=[0, 0], + pool_type='max') + + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"input": input_np}, + fetch_list=[result]) + self.assertTrue(np.allclose(fetches[0], result_np)) + + def check_max_dygraph_results(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + result = max_pool2d( + input, kernel_size=2, stride=2, padding=0, return_indices=False) + + result_np = pool2D_forward_naive( + input_np, + ksize=[2, 2], + strides=[2, 2], + paddings=[0, 0], + pool_type='max') + self.assertTrue(np.allclose(result.numpy(), result_np)) + + max_pool2d_dg = paddle.nn.layer.MaxPool2d( + kernel_size=2, stride=2, padding=0) + result = max_pool2d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + def check_max_dygraph_stride_is_none(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + result, indices = max_pool2d( + input, + kernel_size=2, + stride=None, + padding="SAME", + return_indices=True) + + result_np = pool2D_forward_naive( + input_np, + ksize=[2, 2], + strides=[2, 2], + paddings=[0, 0], + pool_type='max', + padding_algorithm="SAME") + self.assertTrue(np.allclose(result.numpy(), result_np)) + + max_pool2d_dg = paddle.nn.layer.MaxPool2d( + kernel_size=2, stride=2, padding=0) + result = max_pool2d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + def check_avg_dygraph_stride_is_none(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + result = avg_pool2d( + input, kernel_size=2, stride=None, padding="SAME") + + result_np = pool2D_forward_naive( + input_np, + ksize=[2, 2], + strides=[2, 2], + paddings=[0, 0], + pool_type='avg', + padding_algorithm="SAME") + self.assertTrue(np.allclose(result.numpy(), result_np)) + + avg_pool2d_dg = paddle.nn.layer.AvgPool2d( + kernel_size=2, stride=2, padding=0) + result = avg_pool2d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + def check_max_dygraph_padding(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + padding = [[0, 0], [0, 0], [0, 0], [0, 0]] + result = max_pool2d( + input, + kernel_size=2, + stride=2, + padding=padding, + return_indices=False) + + result_np = pool2D_forward_naive( + input_np, + ksize=[2, 2], + strides=[2, 2], + paddings=[0, 0], + pool_type='max') + self.assertTrue(np.allclose(result.numpy(), result_np)) + + max_pool2d_dg = paddle.nn.layer.MaxPool2d( + kernel_size=2, stride=2, padding=0) + result = max_pool2d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + def check_avg_divisor(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + padding = [[0, 0], [0, 0], [0, 0], [0, 0]] + result = avg_pool2d( + input, + kernel_size=2, + stride=2, + padding=padding, + divisor_override=4) + + result_np = pool2D_forward_naive( + input_np, + ksize=[2, 2], + strides=[2, 2], + paddings=[0, 0], + pool_type='avg') + self.assertTrue(np.allclose(result.numpy(), result_np)) + + avg_pool2d_dg = paddle.nn.layer.AvgPool2d( + kernel_size=2, stride=2, padding=0) + result = avg_pool2d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + def test_pool2d(self): + for place in self.places: + + self.check_max_dygraph_results(place) + self.check_avg_dygraph_results(place) + self.check_max_static_results(place) + self.check_avg_static_results(place) + self.check_max_dygraph_stride_is_none(place) + self.check_avg_dygraph_stride_is_none(place) + self.check_max_dygraph_padding(place) + self.check_avg_divisor(place) + + +class TestPool2dError_API(unittest.TestCase): + def test_error_api(self): + def run1(): + with fluid.dygraph.guard(): + input_np = np.random.uniform(-1, 1, + [2, 3, 32, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + padding = [[0, 1], [0, 0], [0, 0], [0, 0]] + res_pd = max_pool2d( + input_pd, kernel_size=2, stride=2, padding=padding) + + self.assertRaises(ValueError, run1) + + def run2(): + with fluid.dygraph.guard(): + input_np = np.random.uniform(-1, 1, + [2, 3, 32, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + padding = [[0, 1], [0, 0], [0, 0], [0, 0]] + res_pd = max_pool2d( + input_pd, + kernel_size=2, + stride=2, + padding=padding, + data_format='NHWC') + + self.assertRaises(ValueError, run2) + + def run3(): + with fluid.dygraph.guard(): + input_np = np.random.uniform(-1, 1, + [2, 3, 32, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + padding = "padding" + res_pd = max_pool2d( + input_pd, + kernel_size=2, + stride=2, + padding=padding, + data_format='NHWC') + + self.assertRaises(ValueError, run3) + + def run3_avg(): + with fluid.dygraph.guard(): + input_np = np.random.uniform(-1, 1, + [2, 3, 32, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + padding = "padding" + res_pd = avg_pool2d( + input_pd, + kernel_size=2, + stride=2, + padding=padding, + data_format='NHWC') + + self.assertRaises(ValueError, run3_avg) + + def run4(): + with fluid.dygraph.guard(): + input_np = np.random.uniform(-1, 1, + [2, 3, 32, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + padding = "VALID" + res_pd = max_pool2d( + input_pd, + kernel_size=2, + stride=2, + padding=padding, + ceil_mode=True, + data_format='NHWC') + + self.assertRaises(ValueError, run4) + + def run4_avg(): + with fluid.dygraph.guard(): + input_np = np.random.uniform(-1, 1, + [2, 3, 32, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + padding = "VALID" + res_pd = avg_pool2d( + input_pd, + kernel_size=2, + stride=2, + padding=padding, + ceil_mode=True, + data_format='NHWC') + + self.assertRaises(ValueError, run4_avg) + + def run5(): + with fluid.dygraph.guard(): + input_np = np.random.uniform(-1, 1, + [2, 3, 32, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + padding = "padding" + res_pd = avg_pool2d( + input_pd, + kernel_size=2, + stride=2, + padding=padding, + data_format='NHWC') + + self.assertRaises(ValueError, run5) + + def run6(): + with fluid.dygraph.guard(): + input_np = np.random.uniform(-1, 1, + [2, 3, 32, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + padding = "VALID" + res_pd = avg_pool2d( + input_pd, + kernel_size=2, + stride=2, + padding=padding, + ceil_mode=True, + data_format='NHWC') + + self.assertRaises(ValueError, run6) + + def run7(): + with fluid.dygraph.guard(): + input_np = np.random.uniform(-1, 1, + [2, 3, 32, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + padding = "VALID" + res_pd = avg_pool2d( + input_pd, + kernel_size=2, + stride=2, + padding=padding, + ceil_mode=False, + data_format='NNNN') + + self.assertRaises(ValueError, run7) + + def run8(): + with fluid.dygraph.guard(): + input_np = np.random.uniform(-1, 1, + [2, 3, 32, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + padding = "VALID" + res_pd = max_pool2d( + input_pd, + kernel_size=2, + stride=2, + padding=padding, + ceil_mode=False, + data_format='NNNN') + + self.assertRaises(ValueError, run8) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_pool3d_api.py b/python/paddle/fluid/tests/unittests/test_pool3d_api.py new file mode 100644 index 0000000000000000000000000000000000000000..cc078e9aae7aafe55e937b80270dd012fd64ff70 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_pool3d_api.py @@ -0,0 +1,341 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +from __future__ import division + +import unittest +import numpy as np +import paddle +import paddle.fluid.core as core +from op_test import OpTest +import paddle.fluid as fluid +from paddle.nn.functional import avg_pool3d, max_pool3d +from test_pool3d_op import adaptive_start_index, adaptive_end_index, pool3D_forward_naive + + +class TestPool3d_API(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def check_avg_static_results(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input = fluid.data( + name="input", shape=[2, 3, 32, 32, 32], dtype="float32") + result = avg_pool3d(input, kernel_size=2, stride=2, padding=0) + + input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32") + result_np = pool3D_forward_naive( + input_np, + ksize=[2, 2, 2], + strides=[2, 2, 2], + paddings=[0, 0, 0], + pool_type='avg') + + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"input": input_np}, + fetch_list=[result]) + self.assertTrue(np.allclose(fetches[0], result_np)) + + def check_avg_dygraph_results(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + result = avg_pool3d(input, kernel_size=2, stride=2, padding="SAME") + + result_np = pool3D_forward_naive( + input_np, + ksize=[2, 2, 2], + strides=[2, 2, 2], + paddings=[0, 0, 0], + pool_type='avg', + padding_algorithm="SAME") + + self.assertTrue(np.allclose(result.numpy(), result_np)) + + avg_pool3d_dg = paddle.nn.layer.AvgPool3d( + kernel_size=2, stride=None, padding="SAME") + result = avg_pool3d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + def check_max_static_results(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input = fluid.data( + name="input", shape=[2, 3, 32, 32, 32], dtype="float32") + result = max_pool3d(input, kernel_size=2, stride=2, padding=0) + + input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32") + result_np = pool3D_forward_naive( + input_np, + ksize=[2, 2, 2], + strides=[2, 2, 2], + paddings=[0, 0, 0], + pool_type='max') + + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"input": input_np}, + fetch_list=[result]) + self.assertTrue(np.allclose(fetches[0], result_np)) + + def check_max_dygraph_results(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + result = max_pool3d(input, kernel_size=2, stride=2, padding=0) + + result_np = pool3D_forward_naive( + input_np, + ksize=[2, 2, 2], + strides=[2, 2, 2], + paddings=[0, 0, 0], + pool_type='max') + + self.assertTrue(np.allclose(result.numpy(), result_np)) + max_pool3d_dg = paddle.nn.layer.MaxPool3d( + kernel_size=2, stride=None, padding=0) + result = max_pool3d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + def check_max_dygraph_stride_is_none(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + result, indices = max_pool3d( + input, + kernel_size=2, + stride=None, + padding="SAME", + return_indices=True) + + result_np = pool3D_forward_naive( + input_np, + ksize=[2, 2, 2], + strides=[2, 2, 2], + paddings=[0, 0, 0], + pool_type='max', + padding_algorithm="SAME") + + self.assertTrue(np.allclose(result.numpy(), result_np)) + max_pool3d_dg = paddle.nn.layer.MaxPool3d( + kernel_size=2, stride=2, padding=0) + result = max_pool3d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + def check_max_dygraph_padding(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + padding = [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]] + result = max_pool3d(input, kernel_size=2, stride=2, padding=padding) + + result_np = pool3D_forward_naive( + input_np, + ksize=[2, 2, 2], + strides=[2, 2, 2], + paddings=[0, 0, 0], + pool_type='max') + + self.assertTrue(np.allclose(result.numpy(), result_np)) + max_pool3d_dg = paddle.nn.layer.MaxPool3d( + kernel_size=2, stride=2, padding=0) + result = max_pool3d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + padding = [0, 0, 0, 0, 0, 0] + result = max_pool3d(input, kernel_size=2, stride=2, padding=padding) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + def check_avg_divisor(self, place): + with fluid.dygraph.guard(place): + input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32") + input = fluid.dygraph.to_variable(input_np) + padding = 0 + result = avg_pool3d( + input, + kernel_size=2, + stride=2, + padding=padding, + divisor_override=8) + + result_np = pool3D_forward_naive( + input_np, + ksize=[2, 2, 2], + strides=[2, 2, 2], + paddings=[0, 0, 0], + pool_type='avg') + + self.assertTrue(np.allclose(result.numpy(), result_np)) + avg_pool3d_dg = paddle.nn.layer.AvgPool3d( + kernel_size=2, stride=2, padding=0) + result = avg_pool3d_dg(input) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + padding = [0, 0, 0, 0, 0, 0] + result = avg_pool3d( + input, + kernel_size=2, + stride=2, + padding=padding, + divisor_override=8) + self.assertTrue(np.allclose(result.numpy(), result_np)) + + def test_pool3d(self): + for place in self.places: + + self.check_max_dygraph_results(place) + self.check_avg_dygraph_results(place) + self.check_max_static_results(place) + self.check_avg_static_results(place) + self.check_max_dygraph_stride_is_none(place) + self.check_max_dygraph_padding(place) + self.check_avg_divisor(place) + + +class TestPool3dError_API(unittest.TestCase): + def test_error_api(self): + def run1(): + with fluid.dygraph.guard(): + input_np = np.random.uniform( + -1, 1, [2, 3, 32, 32, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + padding = [[0, 1], [0, 0], [0, 0], [0, 0], [0, 0]] + res_pd = avg_pool3d( + input_pd, kernel_size=2, stride=2, padding=padding) + + self.assertRaises(ValueError, run1) + + def run2(): + with fluid.dygraph.guard(): + input_np = np.random.uniform( + -1, 1, [2, 3, 32, 32, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + padding = [[0, 1], [0, 0], [0, 0], [0, 0], [0, 0]] + res_pd = avg_pool3d( + input_pd, + kernel_size=2, + stride=2, + padding=padding, + data_format='NCDHW') + + self.assertRaises(ValueError, run2) + + def run3(): + with fluid.dygraph.guard(): + input_np = np.random.uniform( + -1, 1, [2, 3, 32, 32, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + padding = [[0, 1], [0, 0], [0, 0], [0, 0], [0, 0]] + res_pd = avg_pool3d( + input_pd, + kernel_size=2, + stride=2, + padding=padding, + data_format='NDHWC') + + self.assertRaises(ValueError, run3) + + def run4(): + with fluid.dygraph.guard(): + input_np = np.random.uniform( + -1, 1, [2, 3, 32, 32, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + res_pd = avg_pool3d( + input_pd, + kernel_size=2, + stride=2, + padding=0, + data_format='NNNN') + + self.assertRaises(ValueError, run4) + + def run5(): + with fluid.dygraph.guard(): + input_np = np.random.uniform( + -1, 1, [2, 3, 32, 32, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + res_pd = max_pool3d( + input_pd, + kernel_size=2, + stride=2, + padding=0, + data_format='NNNN') + + self.assertRaises(ValueError, run5) + + def run6(): + with fluid.dygraph.guard(): + input_np = np.random.uniform( + -1, 1, [2, 3, 32, 32, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + res_pd = avg_pool3d( + input_pd, + kernel_size=2, + stride=2, + padding="padding", + data_format='NNNN') + + self.assertRaises(ValueError, run6) + + def run7(): + with fluid.dygraph.guard(): + input_np = np.random.uniform( + -1, 1, [2, 3, 32, 32, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + res_pd = max_pool3d( + input_pd, + kernel_size=2, + stride=2, + padding="padding", + data_format='NNNN') + + self.assertRaises(ValueError, run7) + + def run8(): + with fluid.dygraph.guard(): + input_np = np.random.uniform( + -1, 1, [2, 3, 32, 32, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + res_pd = avg_pool3d( + input_pd, + kernel_size=2, + stride=2, + padding="VALID", + ceil_mode=True, + data_format='NNNN') + + self.assertRaises(ValueError, run8) + + def run9(): + with fluid.dygraph.guard(): + input_np = np.random.uniform( + -1, 1, [2, 3, 32, 32, 32]).astype(np.float32) + input_pd = fluid.dygraph.to_variable(input_np) + res_pd = max_pool3d( + input_pd, + kernel_size=2, + stride=2, + padding="VALID", + ceil_mode=True, + data_format='NNNN') + + self.assertRaises(ValueError, run9) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index a952cd587be839dda450610d87361b4729376313..53f5954279311b2765d1acb14ed66193c2c2f52d 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -25,6 +25,8 @@ from . import extension __all__ += extension.__all__ from . import common __all__ += common.__all__ +from . import pooling +__all__ += pooling.__all__ from . import loss __all__ += loss.__all__ from .activation import brelu #DEFINE_ALIAS @@ -166,10 +168,18 @@ from .norm import l2_normalize #DEFINE_ALIAS from .norm import lrn #DEFINE_ALIAS from .norm import normalize #DEFINE_ALIAS # from .norm import spectral_norm #DEFINE_ALIAS +from .pooling import max_pool1d #DEFINE_ALIAS +from .pooling import avg_pool1d #DEFINE_ALIAS +from .pooling import adaptive_max_pool1d #DEFINE_ALIAS +from .pooling import adaptive_avg_pool1d #DEFINE_ALIAS from .pooling import pool2d #DEFINE_ALIAS from .pooling import pool3d #DEFINE_ALIAS from .pooling import adaptive_pool2d #DEFINE_ALIAS from .pooling import adaptive_pool3d #DEFINE_ALIAS +from .pooling import avg_pool2d #DEFINE_ALIAS +from .pooling import max_pool2d #DEFINE_ALIAS +from .pooling import avg_pool3d #DEFINE_ALIAS +from .pooling import max_pool3d #DEFINE_ALIAS from .pooling import adaptive_avg_pool2d #DEFINE_ALIAS from .pooling import adaptive_avg_pool3d #DEFINE_ALIAS # from .rnn import gru_unit #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index c396d00320af3e41a4d26db7c98db9e2b14d6602..96d361e7ecf318df02417fdf6ed011a3f5ff328a 100644 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -13,23 +13,1208 @@ # limitations under the License. # TODO: define pooling functions -import paddle -from ...fluid import core from ...fluid.layers import pool2d #DEFINE_ALIAS from ...fluid.layers import pool3d #DEFINE_ALIAS from ...fluid.layers import adaptive_pool2d #DEFINE_ALIAS from ...fluid.layers import adaptive_pool3d #DEFINE_ALIAS -from ...fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype -from ...fluid.layers import utils -from ...fluid.layer_helper import LayerHelper -from ...fluid.framework import in_dygraph_mode +from ...fluid import core +from ...fluid.framework import in_dygraph_mode, convert_np_dtype_to_dtype_ +from ...fluid.layers import utils, LayerHelper +from ...fluid.data_feeder import check_type, check_variable_and_dtype, check_type, check_dtype, convert_dtype +from ...fluid.layers import unsqueeze, squeeze __all__ = [ - 'pool2d', 'pool3d', 'adaptive_pool2d', 'adaptive_pool3d', - 'adaptive_avg_pool2d', 'adaptive_avg_pool3d' + 'pool2d', + 'pool3d', + 'avg_pool1d', + 'max_pool1d', + 'adaptive_avg_pool1d', + 'adaptive_max_pool1d', + 'adaptive_avg_pool2d', + 'adaptive_avg_pool3d', + 'adaptive_pool2d', + 'adaptive_pool3d', + 'max_pool2d', + 'avg_pool2d', + 'max_pool3d', + 'avg_pool3d', ] +def check_input(x, dimension): + if len(x.shape) != dimension: + raise ValueError("Excepted Input X is 3-D tensor, but received {}-D {}". + format(len(x.shape), type(x))) + + +def check_instance(x, x_name, types=(int, float)): + + if not isinstance(x, types): + raise ValueError("Excepted {} type for {} but received type: {}. ". + format(types, x_name, type(x))) + + +def update_padding1d(padding, pool_type='avg'): + def is_list_or_tuple(ele): + if isinstance(ele, list) or isinstance(ele, tuple): + return True + return False + + if is_list_or_tuple(padding): + if padding.__len__() == 1 and not is_list_or_tuple(padding[0]): + return [0, padding[0]] + else: + raise ValueError( + "{}_pool1d() argument 'padding' should contain one int (got {})". + format(pool_type, padding.__len__())) + else: + padding = [0, padding] + + return padding + + +def update_padding2d(padding, data_format): + def is_list_or_tuple(ele): + if isinstance(ele, list) or isinstance(ele, tuple): + return True + return False + + if is_list_or_tuple(padding) and len(padding) == 4: + if is_list_or_tuple(padding[0]) and (data_format == "NCHW"): + if not (padding[0] == [0, 0] and padding[1] == [0, 0]): + raise ValueError( + "Non-zero pool_padding(%s) in the batch or channel dimensions " + "is not supported." % str(padding)) + padding = padding[2:4] + padding = [ele for a_list in padding for ele in a_list] + elif is_list_or_tuple(padding[0]) and (data_format == "NHWC"): + if not (padding[0] == [0, 0] and padding[3] == [0, 0]): + raise ValueError( + "Non-zero pool_padding(%s) in the batch or channel dimensions " + "is not supported." % str(padding)) + padding = padding[1:3] + padding = [ele for a_list in padding for ele in a_list] + padding = utils.convert_to_list(padding, 4, 'padding') + + if utils._is_symmetric_padding(padding, 2): + padding = [padding[0], padding[2]] + else: + padding = utils.convert_to_list(padding, 2, 'padding') + + return padding + + +def update_padding3d(padding, data_format): + def is_list_or_tuple(ele): + if isinstance(ele, (list, tuple)): + return True + return False + + if is_list_or_tuple(padding) and len(padding) == 5: + if is_list_or_tuple(padding[0]) and (data_format == "NCDHW"): + if not (padding[0] == [0, 0] and padding[1] == [0, 0]): + raise ValueError( + "Non-zero pool_padding(%s) in the batch or channel dimensions " + "is not supported." % str(padding)) + padding = padding[2:5] + padding = [ele for a_list in padding for ele in a_list] + elif is_list_or_tuple(padding[0]) and (data_format == "NDHWC"): + if not (padding[0] == [0, 0] and padding[4] == [0, 0]): + raise ValueError( + "Non-zero pool_padding(%s) in the batch or channel dimensions " + "is not supported." % str(padding)) + padding = padding[1:4] + padding = [ele for a_list in padding for ele in a_list] + padding = utils.convert_to_list(padding, 6, 'padding') + if utils._is_symmetric_padding(padding, 3): + padding = [padding[0], padding[2], padding[4]] + + elif is_list_or_tuple(padding) and len(padding) == 6: + padding = utils.convert_to_list(padding, 6, 'padding') + if utils._is_symmetric_padding(padding, 3): + padding = [padding[0], padding[2], padding[4]] + else: + padding = utils.convert_to_list(padding, 3, 'padding') + + return padding + + +def avg_pool1d(x, + kernel_size, + stride=None, + padding=0, + count_include_pad=True, + ceil_mode=False, + name=None): + """ + + This operation applies a 1D average pooling over an input signal composed + of several input planes, based on the input, output_size, return_indices parameters. + Input(X) and output(Out) are in NCL format, where N is batch + size, C is the number of channels, L is the length of the feature. + The output tensor shape will be [N, C, output_size]. + + The output value of the layer with input size (N, C, L), + output (N, C, L_{out}) and kernel_size k can be precisely described as + For average pool1d: + + .. math:: + + Output(N_i, C_i, l) &= mean(Input[N_i, C_i, stride \times l:stride \times l+k]) + + + Args: + x (Tensor): The input tensor of pooling operator which is a 3-D tensor with + shape [N, C, L]. where `N` is batch size, `C` is the number of channels, + `L` is the length of the feature. The data type if float32 or float64. + kernel_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, + it must contain one integers. + stride (int|list|tuple): The pool stride size. If pool stride size is a tuple or list, + it must contain one integers. + padding (string|int|list|tuple): The pool padding. If `pool_padding` is a string, either 'VALID' or + 'SAME' which is the padding algorithm. If pool padding size is a tuple or list, + it could be the following forms: `[pad_left, pad_right]`. If padding is non-zero, + then the input is implicitly zero-padded on both sides for padding number of points. + count_include_pad (bool): Whether to exclude padding points in average pooling + mode, default is `true`. + ceil_mode (bool): ${ceil_mode_comment}Whether to use the ceil function to calculate output height and width. + If it is set to False, the floor function will be used. Default False + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + Tensor: The output tensor of pooling result. The data type is same as input tensor. + + Raises: + ValueError: If `padding` is a string, but not "SAME" or "VALID". + ValueError: If `padding` is "VALID", but `ceil_mode` is True. + ValueError: If `padding` is a list or tuple but its length greater than 1. + ShapeError: If the input is not a 3-D. + ShapeError: If the output's shape calculated is not greater than 0. + + + Examples: + + .. code-block:: python + + import paddle + import paddle.nn.functional as F + paddle.disable_static() + + data = paddle.to_tensor(np.random.uniform(-1, 1, [1, 3, 32]).astype(np.float32)) + pool_out = F.avg_pool1d(data, kernel_size=2, stride=2, padding=0) + # pool_out shape: [1, 3, 16] + + """ + """NCL to NCHW""" + data_format = "NCHW" + check_variable_and_dtype(x, 'input', ['float32', 'float64'], 'avg_pool1d') + check_input(x, 3) + x = unsqueeze(x, [2]) + kernel_size = utils.convert_to_list(kernel_size, 1, 'pool_size') + kernel_size = [1] + kernel_size + if stride is None: + stride = kernel_size + else: + stride = utils.convert_to_list(stride, 1, 'pool_stride') + stride = [1] + stride + + padding_algorithm = "EXPLICIT" + if isinstance(padding, str): + padding = padding.upper() + if padding not in ["SAME", "VALID"]: + raise ValueError( + "Unknown Attr(padding): '%s'. It can only be 'SAME' or 'VALID'." + % str(padding)) + if padding == "VALID": + padding_algorithm = "VALID" + padding = [0] + if ceil_mode != False: + raise ValueError( + "When Attr(padding) is \"VALID\", Attr(ceil_mode) must be False. " + "Received ceil_mode: True.") + elif padding == "SAME": + padding_algorithm = "SAME" + padding = [0] + + padding = update_padding1d(padding, "avg") + + if in_dygraph_mode(): + output = core.ops.pool2d( + x, 'pooling_type', 'avg', 'ksize', kernel_size, 'global_pooling', + False, 'strides', stride, 'paddings', padding, 'padding_algorithm', + padding_algorithm, 'use_cudnn', not count_include_pad, 'ceil_mode', + ceil_mode, 'use_mkldnn', False, 'exclusive', True, 'data_format', + data_format) + return squeeze(output, [2]) + + op_type = 'pool2d' + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype() + pool_out = helper.create_variable_for_type_inference(dtype) + + helper.append_op( + type=op_type, + inputs={"X": x}, + outputs={"Out": pool_out}, + attrs={ + "pooling_type": 'avg', + "ksize": kernel_size, + "global_pooling": False, + "strides": stride, + "paddings": padding, + "padding_algorithm": padding_algorithm, + "use_cudnn": True, + "ceil_mode": ceil_mode, + "use_mkldnn": False, + "exclusive": not count_include_pad, + "data_format": data_format, + }) + + return squeeze(pool_out, [2]) + + +def max_pool1d(x, + kernel_size, + stride=None, + padding=0, + return_indices=False, + ceil_mode=False, + name=None): + """ + + Applies a 1D max pooling over an input signal composed of several input planes based + on the input, output_size, return_indices parameters. + Input(X) and output(Out) are in NCL format, where N is batch + size, C is the number of channels, L is the length of the feature. + + The output value of the layer with input size (N, C, L), + output (N, C, L_{out}) and kernel_size k can be precisely described as + For average pool1d: + + .. math:: + + Output(N_i, C_i, l) &= max(Input[N_i, C_i, stride \times l:stride \times l+k])} + + Args: + x (Tensor): The input tensor of pooling operator which is a 3-D tensor with + shape [N, C, L], where `N` is batch size, `C` is the number of channels, + `L` is the length of the feature. The data type if float32 or float64. + kernel_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, + it must contain one integers. + stride (int|list|tuple): The pool stride size. If pool stride size is a tuple or list, + it must contain one integers. + padding (string|int|list|tuple): The pool padding. If `pool_padding` is a string, either 'VALID' or + 'SAME' which is the padding algorithm. If pool padding size is a tuple or list, + it could be the following forms: `[pad_left, pad_right]`. + return_indices (bool): Whether return the max indices along with the outputs. default is `False`. + ceil_mode (bool): Whether to use the ceil function to calculate output height and width. False is the default. + If it is set to False, the floor function will be used. Default False. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + Tensor: The output tensor of pooling result. The data type is same as input tensor. + + Raises: + ValueError: If `padding` is a string, but not "SAME" or "VALID". + ValueError: If `padding` is "VALID", but `ceil_mode` is True. + ValueError: If `padding` is a list or tuple but its length greater than 1. + ShapeError: If the input is not a 3-D. + ShapeError: If the output's shape calculated is not greater than 0. + + + Examples: + + .. code-block:: python + + import paddle + import paddle.nn.functional as F + paddle.disable_static() + + data = paddle.to_tensor(np.random.uniform(-1, 1, [1, 3, 32]).astype(np.float32)) + pool_out = F.max_pool1d(data, kernel_size=2, stride=2, padding=0) + # pool_out shape: [1, 3, 16] + + pool_out, indices = F.max_pool1d(data, kernel_size=2, stride=2, padding=0, return_indices=True) + # pool_out shape: [1, 3, 16], indices shape: [1, 3, 16] + + """ + """NCL to NCHW""" + data_format = "NCHW" + check_variable_and_dtype(x, 'input', ['float32', 'float64'], 'max_pool1d') + check_input(x, 3) + x = unsqueeze(x, [2]) + kernel_size = [1] + utils.convert_to_list(kernel_size, 1, 'pool_size') + if stride is None: + stride = kernel_size + else: + stride = [1] + utils.convert_to_list(stride, 1, 'pool_stride') + + padding_algorithm = "EXPLICIT" + if isinstance(padding, str): + padding = padding.upper() + if padding not in ["SAME", "VALID"]: + raise ValueError( + "Unknown Attr(padding): '%s'. It can only be 'SAME' or 'VALID'." + % str(padding)) + if padding == "VALID": + padding_algorithm = "VALID" + padding = [0] + if ceil_mode != False: + raise ValueError( + "When Attr(padding) is \"VALID\", Attr(ceil_mode) must be False. " + "Received ceil_mode: True.") + elif padding == "SAME": + padding_algorithm = "SAME" + padding = [0] + + padding = update_padding1d(padding, 'max') + + if in_dygraph_mode(): + pool_out = core.ops.max_pool2d_with_index( + x, 'ksize', kernel_size, 'global_pooling', False, 'strides', stride, + 'paddings', padding, 'padding_algorithm', padding_algorithm, + 'use_cudnn', True, 'ceil_mode', ceil_mode, 'use_mkldnn', False, + 'exclusive', True, 'data_format', data_format) + return (squeeze(pool_out[0], [2]), squeeze( + pool_out[1], [2])) if return_indices else squeeze(pool_out[0], [2]) + + op_type = 'max_pool2d_with_index' + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype() + pool_out = helper.create_variable_for_type_inference(dtype) + mask = helper.create_variable_for_type_inference(dtype) + outputs = {"Out": pool_out, "Mask": mask} + + helper.append_op( + type=op_type, + inputs={"X": x}, + outputs=outputs, + attrs={ + "pooling_type": 'max', + "ksize": kernel_size, + "global_pooling": False, + "strides": stride, + "paddings": padding, + "padding_algorithm": padding_algorithm, + "use_cudnn": True, + "ceil_mode": ceil_mode, + "use_mkldnn": False, + "exclusive": True, + "data_format": data_format, + }) + + return (squeeze(pool_out, [2]), + squeeze(mask, [2])) if return_indices else squeeze(pool_out, [2]) + + +def adaptive_avg_pool1d(x, output_size, name=None): + """ + + This operation applies a 1D adaptive average pooling over an input signal composed + of several input planes, based on the input, output_size, return_indices parameters. + Input(X) and output(Out) are in NCL format, where N is batch + size, C is the number of channels, L is the length of the feature. + The output tensor shape will be [N, C, output_size]. + + For average adaptive pool1d: + + .. math:: + + lstart &= floor(i * L_{in} / L_{out}) + + lend &= ceil((i + 1) * L_{in} / L_{out}) + + Output(i) &= \\frac{sum(Input[lstart:lend])}{(lstart - lend)} + + Args: + x (Tensor): The input tensor of pooling operator, which is a 3-D tensor + with shape [N, C, L]. The format of input tensor is NCL, + where N is batch size, C is the number of channels, L is the + length of the feature. The data type is float32 or float64. + output_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, + it must contain one int. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + Tensor: The output tensor of adaptive average pooling result. The data type is same + as input tensor. + + Raises: + ValueError: 'output_size' should be a integer or list or tuple with length as 1. + + Examples: + .. code-block:: python + + # average adaptive pool1d + # suppose input data in shape of [N, C, L], `output_size` is m or [m], + # output shape is [N, C, m], adaptive pool divide L dimension + # of input data into m grids averagely and performs poolings in each + # grid to get output. + # adaptive max pool performs calculations as follow: + # + # for i in range(m): + # lstart = floor(i * L / m) + # lend = ceil((i + 1) * L / m) + # output[:, :, i] = sum(input[:, :, lstart: lend])/(lstart - lend) + # + import paddle + import paddle.nn.functional as F + paddle.disable_static() + + data = paddle.to_tensor(np.random.uniform(-1, 1, [1, 3, 32]).astype(np.float32)) + pool_out = F.adaptive_average_pool1d(data, output_size=16) + # pool_out shape: [1, 3, 16]) + """ + pool_type = 'avg' + check_variable_and_dtype(x, 'input', ['float32', 'float64'], + 'adaptive_pool2d') + check_input(x, 3) + check_type(output_size, 'pool_size', (int), 'adaptive_pool1d') + + pool_size = [1] + utils.convert_to_list(output_size, 1, 'pool_size') + + l_type = "pool2d" + x = unsqueeze(x, [2]) + if in_dygraph_mode(): + pool_out = core.ops.pool2d(x, 'pooling_type', pool_type, 'ksize', + pool_size, 'adaptive', True) + return squeeze(pool_out, [2]) + + helper = LayerHelper(l_type, **locals()) + dtype = helper.input_dtype() + pool_out = helper.create_variable_for_type_inference(dtype) + + outputs = {"Out": pool_out} + helper.append_op( + type=l_type, + inputs={"X": x}, + outputs=outputs, + attrs={ + "pooling_type": pool_type, + "ksize": pool_size, + "adaptive": True, + }) + + return squeeze(pool_out, [2]) + + +def adaptive_max_pool1d(x, output_size, return_indices=False, name=None): + """ + This operation applies a 1D adaptive max pooling over an input signal composed + of several input planes, based on the input, output_size, return_indices parameters. + Input(X) and output(Out) are in NCL format, where N is batch + size, C is the number of channels, L is the length of the feature. + The output tensor shape will be [N, C, output_size]. + + For max adaptive pool1d: + + .. math:: + + lstart &= floor(i * L_{in} / L_{out}) + + lend &= ceil((i + 1) * L_{in} / L_{out}) + + Output(i) &= max(Input[lstart:lend])} + + Args: + x (Tensor): The input tensor of pooling operator, which is a 3-D tensor + with shape [N, C, L]. The format of input tensor is NCL, + where N is batch size, C is the number of channels, L is the + length of the feature. The data type is float32 or float64. + output_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, + it must contain one int. + return_indices (bool): If true, the index of max pooling point will be returned along + with outputs. It cannot be set in average pooling type. Default False. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + Tensor: The output tensor of adaptive pooling result. The data type is same + as input tensor. + + Raises: + ValueError: 'output_size' should be a integer or list or tuple with length as 1. + + Examples: + .. code-block:: python + + # max adaptive pool1d + # suppose input data in shape of [N, C, L], `output_size` is m or [m], + # output shape is [N, C, m], adaptive pool divide L dimension + # of input data into m grids averagely and performs poolings in each + # grid to get output. + # adaptive max pool performs calculations as follow: + # + # for i in range(m): + # lstart = floor(i * L / m) + # lend = ceil((i + 1) * L / m) + # output[:, :, i] = max(input[:, :, lstart: lend]) + # + import paddle + import paddle.nn.functional as F + paddle.disable_static() + + data = paddle.to_tensor(np.random.uniform(-1, 1, [1, 3, 32]).astype(np.float32)) + pool_out = F.adaptive_max_pool1d(data, output_size=16) + # pool_out shape: [1, 3, 16]) + + pool_out, indices = F.adaptive_max_pool1d(data, output_size=16, return_indices=True) + # pool_out shape: [1, 3, 16] indices shape: [1, 3, 16] + + """ + pool_type = 'max' + check_variable_and_dtype(x, 'input', ['float32', 'float64'], + 'adaptive_max_pool1d') + check_input(x, 3) + check_type(output_size, 'pool_size', (int), 'adaptive_max_pool1d') + check_type(return_indices, 'return_indices', bool, 'adaptive_max_pool1d') + + pool_size = [1] + utils.convert_to_list(output_size, 1, 'pool_size') + + l_type = 'max_pool2d_with_index' + + x = unsqueeze(x, [2]) + if in_dygraph_mode(): + pool_out = core.ops.max_pool2d_with_index( + x, 'pooling_type', pool_type, 'ksize', pool_size, 'adaptive', True) + return (squeeze(pool_out[0], [2]), squeeze( + pool_out[1], [2])) if return_indices else squeeze(pool_out[0], [2]) + + helper = LayerHelper(l_type, **locals()) + dtype = helper.input_dtype() + pool_out = helper.create_variable_for_type_inference(dtype) + + mask = helper.create_variable_for_type_inference(dtype) + outputs = {"Out": pool_out, "Mask": mask} + + helper.append_op( + type=l_type, + inputs={"X": x}, + outputs=outputs, + attrs={ + "pooling_type": pool_type, + "ksize": pool_size, + "adaptive": True, + }) + + return (squeeze(pool_out, [2]), + squeeze(mask, [2])) if return_indices else squeeze(pool_out, [2]) + + +def max_pool2d(x, + kernel_size, + stride=None, + padding=0, + return_indices=False, + ceil_mode=False, + data_format="NCHW", + name=None): + """ + This operation applies 2D max pooling over input feature based on the input, + and kernel_size, stride, padding parameters. Input(X) and Output(Out) are + in NCHW format, where N is batch size, C is the number of channels, + H is the height of the feature, and W is the width of the feature. + + Example: + Input: + X shape: $(N, C, H_{in}, W_{in})$ + Attr: + kernel_size: ksize + stride: stride + + Output: + Out shape: $(N, C, H_{out}, W_{out})$ + $$ + out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, ksize[0] -1} \max_{n=0, \ldots, ksize[1]-1} \\ + & \text{input}(N_i, C_j, \text{stride[0]} \times h + m, + \text{stride[1]} \times w + n) + $$ + + Args: + x (Tensor): The input tensor of pooling operator which is a 4-D tensor with + shape [N, C, H, W]. The format of input tensor is `"NCHW"` or + `"NHWC"`, where `N` is batch size, `C` is the number of channels, + `H` is the height of the feature, and `W` is the width of the + feature. The data type if float32 or float64. + kernel_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, + it must contain two integers, (pool_size_Height, pool_size_Width). + Otherwise, the pool kernel size will be a square of an int. + stride (int|list|tuple): The pool stride size. If pool stride size is a tuple or list, + it must contain two integers, (pool_stride_Height, pool_stride_Width). + Otherwise, the pool stride size will be a square of an int. + padding (string|int|list|tuple): The pool padding. If `pool_padding` is a string, either 'VALID' or + 'SAME' which is the padding algorithm. If pool padding size is a tuple or list, + it could be in three forms: `[pad_height, pad_width]` or + `[pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`, and when `data_format` is `"NCHW"`, + `pool_padding` can be in the form `[[0,0], [0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`. + when `data_format` is `"NHWC"`, `pool_padding` can be in the form + `[[0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`. + Otherwise, the pool padding size will be a square of an int. + ceil_mode (bool): when True, will use `ceil` instead of `floor` to compute the output shape + return_indices (bool): Whether to return the max indices along with the outputs. + data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NDHW"`. + The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of: + `[batch_size, input_channels, input_height, input_width]`. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + Tensor: The output tensor of pooling result. The data type is same as input tensor. + Raises: + ValueError: If `padding` is a string, but not "SAME" or "VALID". + ValueError: If `padding` is "VALID", but `ceil_mode` is True. + ShapeError: If the output's shape calculated is not greater than 0. + Examples: + .. code-block:: python + import paddle + import paddle.nn.functional as F + import numpy as np + paddle.disable_static() + + # max pool2d + input = paddle.to_tensor(np.random.uniform(-1, 1, [1, 3, 32, 32]).astype(np.float32)) + output = F.max_pool2d(input, + kernel_size=2, + stride=2, padding=0) + # output.shape [1, 3, 16, 16] + + # for return_indices=True + output, max_indices = F.max_pool2d(input, + kernel_size=2, + stride=2, + padding=0, + return_indices=True) + # output.shape [1, 3, 16, 16], max_indices.shape [1, 3, 16, 16], + """ + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'max_pool2d') + kernel_size = utils.convert_to_list(kernel_size, 2, 'pool_size') + if stride is None: + stride = kernel_size + else: + stride = utils.convert_to_list(stride, 2, 'pool_stride') + + if data_format not in ["NCHW", "NHWC"]: + raise ValueError( + "Attr(data_format) should be 'NCHW' or 'NHWC'. Received " + "Attr(data_format): %s." % str(data_format)) + padding_algorithm = "EXPLICIT" + if isinstance(padding, str): + padding = padding.upper() + if padding not in ["SAME", "VALID"]: + raise ValueError( + "Unknown Attr(padding): '%s'. It can only be 'SAME' or 'VALID'." + % str(padding)) + if padding == "VALID": + padding_algorithm = "VALID" + padding = [0, 0] + if ceil_mode != False: + raise ValueError( + "When Attr(padding) is \"VALID\", Attr(ceil_mode) must be False. " + "Received ceil_mode: True.") + elif padding == "SAME": + padding_algorithm = "SAME" + padding = [0, 0] + + padding = update_padding2d(padding, data_format) + + if in_dygraph_mode(): + output = core.ops.max_pool2d_with_index( + x, 'ksize', kernel_size, 'global_pooling', False, 'strides', stride, + 'paddings', padding, 'padding_algorithm', padding_algorithm, + 'use_cudnn', True, 'ceil_mode', ceil_mode, 'use_mkldnn', False, + 'exclusive', True, 'data_format', data_format) + return output if return_indices else output[0] + + op_type = 'max_pool2d_with_index' + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype() + pool_out = helper.create_variable_for_type_inference(dtype) + mask = helper.create_variable_for_type_inference(dtype) + outputs = {"Out": pool_out, "Mask": mask} + + helper.append_op( + type=op_type, + inputs={"X": x}, + outputs=outputs, + attrs={ + "pooling_type": 'max', + "ksize": kernel_size, + "global_pooling": False, + "strides": stride, + "paddings": padding, + "padding_algorithm": padding_algorithm, + "use_cudnn": True, + "ceil_mode": ceil_mode, + "use_mkldnn": False, + "exclusive": True, + "data_format": data_format, + }) + + return (pool_out, mask) if return_indices else pool_out + + +def avg_pool2d(x, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + data_format="NCHW", + name=None): + """ + This operation applies 2D average pooling over input features based on the input, + and kernel_size, stride, padding parameters. Input(X) and Output(Out) are + in NCHW format, where N is batch size, C is the number of channels, + H is the height of the feature, and W is the width of the feature. + + Example: + Input: + X shape: $(N, C, H_{in}, W_{in})$ + Attr: + kernel_size: ksize + + Output: + Out shape: $(N, C, H_{out}, W_{out})$ + $$ + out(N_i, C_j, h, w) = \frac{1}{ksize[0] * ksize[1]} \sum_{m=0}^{ksize[0]-1} \sum_{n=0}^{ksize[1]-1} + input(N_i, C_j, stride[0] \times h + m, stride[1] \times w + n) + $$ + + Args: + x (Tensor): The input tensor of pooling operator which is a 4-D tensor with + shape [N, C, H, W]. The format of input tensor is `"NCHW"` or + `"NHWC"`, where `N` is batch size, `C` is the number of channels, + `H` is the height of the feature, and `W` is the width of the + feature. The data type if float32 or float64. + kernel_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, + it must contain two integers, (pool_size_Height, pool_size_Width). + Otherwise, the pool kernel size will be a square of an int. + stride (int|list|tuple): The pool stride size. If pool stride size is a tuple or list, + it must contain two integers, (pool_stride_Height, pool_stride_Width). + Otherwise, the pool stride size will be a square of an int. + padding (string|int|list|tuple): The pool padding. If `pool_padding` is a string, either 'VALID' or + 'SAME' which is the padding algorithm. If pool padding size is a tuple or list, + it could be in three forms: `[pad_height, pad_width]` or + `[pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`, and when `data_format` is `"NCHW"`, + `pool_padding` can be in the form `[[0,0], [0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`. + when `data_format` is `"NHWC"`, `pool_padding` can be in the form + `[[0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`. + Otherwise, the pool padding size will be a square of an int. + ceil_mode (bool): when True, will use `ceil` instead of `floor` to compute the output shape + count_include_pad (bool): Whether to exclude padding points in average pooling + mode, default is `true`. + divisor_override (float): if specified, it will be used as divisor, otherwise kernel_size will be used. Default None. + data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NDHW"`. + The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of: + `[batch_size, input_channels, input_height, input_width]`. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + Tensor: The output tensor of pooling result. The data type is same as input tensor. + Raises: + ValueError: If `padding` is a string, but not "SAME" or "VALID". + ValueError: If `padding` is "VALID", but `ceil_mode` is True. + ShapeError: If the output's shape calculated is not greater than 0. + Examples: + .. code-block:: python + import paddle + import paddle.nn.functional as F + import numpy as np + paddle.disable_static() + + # avg pool2d + input = paddle.to_tensor(np.random.uniform(-1, 1, [1, 3, 32, 32]).astype(np.float32)) + output = F.avg_pool2d(input, + kernel_size=2, + stride=2, padding=0) + # output.shape [1, 3, 16, 16] + + """ + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'avg_pool2d') + kernel_size = utils.convert_to_list(kernel_size, 2, 'pool_size') + if stride is None: + stride = kernel_size + else: + stride = utils.convert_to_list(stride, 2, 'pool_stride') + + padding_algorithm = "EXPLICIT" + if isinstance(padding, str): + padding = padding.upper() + if padding not in ["SAME", "VALID"]: + raise ValueError( + "Unknown Attr(pool_padding): '%s'. It can only be 'SAME' or 'VALID'." + % str(padding)) + if padding == "VALID": + padding_algorithm = "VALID" + padding = [0, 0] + if ceil_mode != False: + raise ValueError( + "When Attr(pool_padding) is \"VALID\", Attr(ceil_mode) must be False. " + "Received ceil_mode: True.") + elif padding == "SAME": + padding_algorithm = "SAME" + padding = [0, 0] + + if data_format not in ["NCHW", "NHWC"]: + raise ValueError( + "Attr(data_format) should be 'NCHW' or 'NHWC'. Received " + "Attr(data_format): %s." % str(data_format)) + pool_padding = update_padding2d(padding, data_format) + + if in_dygraph_mode(): + output = core.ops.pool2d( + x, 'pooling_type', 'avg', 'ksize', kernel_size, 'global_pooling', + False, 'padding_algorithm', padding_algorithm, 'strides', stride, + 'paddings', pool_padding, 'use_cudnn', True, 'ceil_mode', ceil_mode, + 'use_mkldnn', False, 'exclusive', not count_include_pad, + 'data_format', data_format) + if divisor_override is None: + return output + else: + check_instance(divisor_override, "divisor_override") + return output * (kernel_size[0] * kernel_size[1]) / divisor_override + + op_type = 'pool2d' + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype() + pool_out = helper.create_variable_for_type_inference(dtype) + + helper.append_op( + type=op_type, + inputs={"X": x}, + outputs={"Out": pool_out}, + attrs={ + "pooling_type": "avg", + "ksize": kernel_size, + "global_pooling": False, + "strides": stride, + "paddings": pool_padding, + "padding_algorithm": padding_algorithm, + "use_cudnn": True, + "ceil_mode": ceil_mode, + "use_mkldnn": False, + "exclusive": not count_include_pad, + "data_format": data_format, + }) + + if divisor_override is None: + return pool_out + else: + check_instance(divisor_override, "divisor_override") + return pool_out * (kernel_size[0] * kernel_size[1]) / divisor_override + + +def max_pool3d(x, + kernel_size, + stride=None, + padding=0, + return_indices=False, + ceil_mode=False, + data_format="NCDHW", + name=None): + """ + This operation applies 3D max pooling over input features based on the input, + and kernel_size, stride, padding parameters. Input(X) and Output(Out) are + in NCDHW format, where N is batch size, C is the number of channels, + H is the height of the feature, D is the depth of the feature, and W is the width of the feature. + + Example: + Input: + X shape: $(N, C, D_{in}, H_{in}, W_{in})$ + Attr: + kernel_size: ksize + + Output: + Out shape: $(N, C, D_{out}, H_{out}, W_{out})$ + $$ + \text{out}(N_i, C_j, d, h, w) ={} & \max_{k=0, \ldots, ksize[0]-1} \max_{m=0, \ldots, ksize[1]-1} \max_{n=0, \ldots, ksize[2]-1} \\ + & \text{input}(N_i, C_j, \text{stride[0]} \times d + k, + \text{stride[1]} \times h + m, \text{stride[2]} \times w + n) + $$ + + Args: + x (Tensor): The input tensor of pooling operator, which is a 5-D tensor with + shape [N, C, D, H, W]. The format of + input tensor is `"NCDHW"` or `"NDHWC"`, where `N` is batch size, `C` is + the number of channels, `D` is the depth of the feature, + `H` is the height of the feature, and `W` is the width + of the feature. + kernel_size (int|list|tuple): The pool kernel size. If pool kernel size + is a tuple or list, it must contain three integers, + (pool_size_Depth, pool_size_Height, pool_size_Width). + Otherwise, the pool kernel size will be the cube of an int. + stride (string|int|list|tuple)): The pool padding. If `pool_padding` is a string, either 'VALID' or + 'SAME' which is the padding algorithm. If pool stride size is a tuple or list, + it must contain three integers, `[stride_Depth, stride_Height, stride_Width]`. + Otherwise, the pool stride size will be a cube of an int. + padding (int|list|tuple): The pool padding size. If pool padding size is a tuple or list, + it could be in three forms: `[pad_depth, pad_height, pad_width]` or + `[pad_depth_front, pad_depth_back, pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`, + and when `data_format` is `"NCDHW"`, `pool_padding` can be in the form + `[[0,0], [0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`. + when `data_format` is `"NDHWC"`, `pool_padding` can be in the form + `[[0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`. + ceil_mode (bool): ${ceil_mode_comment} + return_indices (bool): Whether to return the max indices along with the outputs. + data_format (string): The data format of the input and output data. An optional string from: `"NCDHW"`, `"NDHWC"`. + The default is `"NCDHW"`. When it is `"NCDHW"`, the data is stored in the order of: + `[batch_size, input_channels, input_depth, input_height, input_width]`. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + Tensor: The output tensor of pooling result. The data type is same as input tensor. + Raises: + ValueError: If `padding` is a string, but not "SAME" or "VALID". + ValueError: If `padding` is "VALID", but `ceil_mode` is True. + ShapeError: If the output's shape calculated is not greater than 0. + Examples: + .. code-block:: python + import paddle + import paddle.nn.functional as F + import numpy as np + paddle.disable_static() + + # max pool3d + input = paddle.to_tensor(np.random.uniform(-1, 1, [1, 3, 32, 32, 32]).astype(np.float32)) + output = F.max_pool2d(input, + kernel_size=2, + stride=2, padding=0) + output.shape [1, 3, 16, 16, 16] + + # for return_indices=True + input = paddle.to_tensor(np.random.uniform(-1, 1, [1, 3, 32, 32, 32]).astype(np.float32)) + output, max_indices = paddle.nn.functional.max_pool3d(input, + kernel_size = 2, + stride = 2, + padding=0, + return_indices=True) + # output.shape [None, 3, 16, 16, 16], max_indices.shape [None, 3, 16, 16, 16], + + """ + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'max_pool3d') + kernel_size = utils.convert_to_list(kernel_size, 3, 'pool_size') + if stride is None: + stride = kernel_size + else: + stride = utils.convert_to_list(stride, 3, 'pool_stride') + + padding_algorithm = "EXPLICIT" + if isinstance(padding, str): + padding = padding.upper() + if padding not in ["SAME", "VALID"]: + raise ValueError( + "Unknown Attr(pool_padding): '%s'. It can only be 'SAME' or 'VALID'." + % str(padding)) + if padding == "VALID": + padding_algorithm = "VALID" + padding = [0, 0, 0] + if ceil_mode != False: + raise ValueError( + "When Attr(pool_padding) is \"VALID\", ceil_mode must be False. " + "Received ceil_mode: True.") + elif padding == "SAME": + padding_algorithm = "SAME" + padding = [0, 0, 0] + + if data_format not in ["NCDHW", "NDHWC"]: + raise ValueError( + "Attr(data_format) should be 'NCDHW' or 'NDHWC'. Received " + "Attr(data_format): %s" % str(data_format)) + padding = update_padding3d(padding, data_format) + + if in_dygraph_mode(): + output = core.ops.max_pool3d_with_index( + x, 'pooling_type', 'max', 'ksize', kernel_size, 'strides', stride, + 'paddings', padding, 'global_pooling', False, 'padding_algorithm', + padding_algorithm, 'use_cudnn', True, 'ceil_mode', ceil_mode, + 'use_mkldnn', False, 'exclusive', True, 'data_format', data_format) + return output if return_indices else output[0] + + op_type = "max_pool3d_with_index" + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype() + pool_out = helper.create_variable_for_type_inference(dtype) + mask = helper.create_variable_for_type_inference(dtype) + outputs = {"Out": pool_out, "Mask": mask} + + helper.append_op( + type=op_type, + inputs={"X": x}, + outputs=outputs, + attrs={ + "pooling_type": 'max', + "ksize": kernel_size, + "global_pooling": False, + "strides": stride, + "paddings": padding, + "padding_algorithm": padding_algorithm, + "use_cudnn": True, + "ceil_mode": ceil_mode, + "use_mkldnn": False, + "exclusive": False, + "data_format": data_format, + }) + + return (pool_out, mask) if return_indices else pool_out + + +def avg_pool3d(x, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + count_include_pad=False, + divisor_override=None, + data_format="NCDHW", + name=None): + """ + This operation applies 3D max pooling over input features based on the input, + and kernel_size, stride, padding parameters. Input(X) and Output(Out) are + in NCDHW format, where N is batch size, C is the number of channels, + H is the height of the feature, D is the depth of the feature, and W is the width of the feature. + + Args: + input (Tensor): The input tensor of pooling operator, which is a 5-D tensor with + shape [N, C, D, H, W], where `N` is batch size, `C` is + the number of channels, `D` is the depth of the feature, + `H` is the height of the feature, and `W` is the width + of the feature. + kernel_size (int|list|tuple): The pool kernel size. If pool kernel size + is a tuple or list, it must contain three integers, + (pool_size_Depth, pool_size_Height, pool_size_Width). + Otherwise, the pool kernel size will be the cube of an int. + stride (string|int|list|tuple)): The pool padding. If `pool_padding` is a string, either 'VALID' or + 'SAME' which is the padding algorithm. If pool stride size is a tuple or list, + it must contain three integers, `[stride_Depth, stride_Height, stride_Width]`. + Otherwise, the pool stride size will be a cube of an int. + padding (int|list|tuple): The pool padding size. If pool padding size is a tuple or list, + it could be in three forms: `[pad_depth, pad_height, pad_width]` or + `[pad_depth_front, pad_depth_back, pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`, + and when `data_format` is `"NCDHW"`, `pool_padding` can be in the form + `[[0,0], [0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`. + when `data_format` is `"NDHWC"`, `pool_padding` can be in the form + `[[0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`. + ceil_mode (bool): ${ceil_mode_comment} + count_include_pad (bool): Whether to exclude padding points in average pooling + mode, default is True. + divisor_override (int|float) if specified, it will be used as divisor, otherwise kernel_size will be used. Default None. + data_format (string): The data format of the input and output data. An optional string from: `"NCDHW"`, `"NDHWC"`. + The default is `"NCDHW"`. When it is `"NCDHW"`, the data is stored in the order of: + `[batch_size, input_channels, input_depth, input_height, input_width]`. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + + Returns: + Tensor: The output tensor of pooling result. The data type is same as input tensor. + Raises: + ValueError: If `padding` is a string, but not "SAME" or "VALID". + ValueError: If `padding` is "VALID", but `ceil_mode` is True. + ShapeError: If the output's shape calculated is not greater than 0. + Examples: + .. code-block:: python + import paddle.fluid as fluid + import paddle + input = paddle.to_tensor(np.random.uniform(-1, 1, [1, 3, 32, 32, 32]).astype(np.float32)) + # avg pool3d + pool3d = paddle.nn.functional.avg_pool3d( + input, + kernel_size = 2, + stride = 2, + padding=0) + # pool3d.shape: [1, 3, 16, 16, 16] + """ + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'max_pool3d') + kernel_size = utils.convert_to_list(kernel_size, 3, 'pool_size') + if stride is None: + stride = kernel_size + else: + stride = utils.convert_to_list(stride, 3, 'pool_stride') + + padding_algorithm = "EXPLICIT" + if isinstance(padding, str): + padding = padding.upper() + if padding not in ["SAME", "VALID"]: + raise ValueError( + "Unknown Attr(pool_padding): '%s'. It can only be 'SAME' or 'VALID'." + % str(padding)) + if padding == "VALID": + padding_algorithm = "VALID" + padding = [0, 0, 0] + if ceil_mode != False: + raise ValueError( + "When Attr(pool_padding) is \"VALID\", ceil_mode must be False. " + "Received ceil_mode: True.") + elif padding == "SAME": + padding_algorithm = "SAME" + padding = [0, 0, 0] + + if data_format not in ["NCDHW", "NDHWC"]: + raise ValueError( + "Attr(data_format) should be 'NCDHW' or 'NDHWC'. Received " + "Attr(data_format): %s" % str(data_format)) + padding = update_padding3d(padding, data_format) + + if in_dygraph_mode(): + output = core.ops.pool3d( + x, 'pooling_type', 'avg', 'ksize', kernel_size, 'strides', stride, + 'paddings', padding, 'global_pooling', False, 'padding_algorithm', + padding_algorithm, 'use_cudnn', True, 'ceil_mode', ceil_mode, + 'use_mkldnn', False, 'exclusive', not count_include_pad, + 'data_format', data_format) + if divisor_override is None: + return output + else: + check_instance(divisor_override, "divisor_override") + return output * (kernel_size[0] * kernel_size[1] * + kernel_size[2]) / divisor_override + + op_type = "pool3d" + helper = LayerHelper(op_type, **locals()) + dtype = helper.input_dtype() + pool_out = helper.create_variable_for_type_inference(dtype) + outputs = {"Out": pool_out} + + helper.append_op( + type=op_type, + inputs={"X": x}, + outputs=outputs, + attrs={ + "pooling_type": 'avg', + "ksize": kernel_size, + "global_pooling": False, + "strides": stride, + "paddings": padding, + "padding_algorithm": padding_algorithm, + "use_cudnn": True, + "ceil_mode": ceil_mode, + "use_mkldnn": False, + "exclusive": not count_include_pad, + "data_format": data_format, + }) + + if divisor_override is None: + return pool_out + else: + check_instance(divisor_override, "divisor_override") + return pool_out * (kernel_size[0] * kernel_size[1] * + kernel_size[2]) / divisor_override + + def adaptive_avg_pool2d(x, output_size, data_format='NCHW', name=None): """ diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index de52744e651d4b5c90bf1e89d25836433fa52806..b7098aee423feeb3670206c1459a05401c86e5c0 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -60,6 +60,14 @@ from .common import Dropout3D #DEFINE_ALIAS from .common import AlphaDropout #DEFINE_ALIAS from .pooling import AdaptiveAvgPool2d #DEFINE_ALIAS from .pooling import AdaptiveAvgPool3d #DEFINE_ALIAS +from .pooling import AvgPool1d #DEFINE_ALIAS +from .pooling import MaxPool1d #DEFINE_ALIAS +from .pooling import AdaptiveAvgPool1d #DEFINE_ALIAS +from .pooling import AdaptiveMaxPool1d #DEFINE_ALIAS +from .pooling import AvgPool2d #DEFINE_ALIAS +from .pooling import MaxPool2d #DEFINE_ALIAS +from .pooling import AvgPool3d #DEFINE_ALIAS +from .pooling import MaxPool3d #DEFINE_ALIAS from .conv import Conv1d #DEFINE_ALIAS from .conv import Conv2d #DEFINE_ALIAS from .conv import Conv3d #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index e0d751eef42a9b5e9675fd933b470429785a2f08..1a96a3738af96b5abe8f44d1bb5178855251f7b9 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# TODO: define the common classes to build a neural network +# TODO: define the common classes to build a neural network from ...fluid.dygraph import BilinearTensorProduct #DEFINE_ALIAS from ...fluid.dygraph import Pool2D #DEFINE_ALIAS from ...fluid.dygraph import Embedding #DEFINE_ALIAS @@ -583,8 +583,8 @@ class ReflectionPad1d(layers.Layer): Default is "NCL" name (str, optional) : The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: + + Returns: None Examples: @@ -642,8 +642,8 @@ class ReplicationPad1d(layers.Layer): Default is "NCL" name (str, optional) : The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: + + Returns: None Examples: @@ -657,7 +657,7 @@ class ReplicationPad1d(layers.Layer): Code Examples: .. code-block:: python - + import paddle import paddle.nn as nn import numpy as np @@ -702,8 +702,8 @@ class ConstantPad1d(layers.Layer): Default is "NCL" name (str, optional) : The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: + + Returns: None Examples: @@ -718,7 +718,7 @@ class ConstantPad1d(layers.Layer): Code Examples: .. code-block:: python - + import paddle import paddle.nn as nn import numpy as np @@ -765,8 +765,8 @@ class ConstantPad2d(layers.Layer): Default is "NCHW" name (str, optional) : The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: + + Returns: None Examples: @@ -781,7 +781,7 @@ class ConstantPad2d(layers.Layer): Code Examples: .. code-block:: python - + import paddle import paddle.nn as nn import numpy as np @@ -830,8 +830,8 @@ class ZeroPad2d(layers.Layer): Default is "NCHW" name (str, optional) : The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: + + Returns: None Examples: @@ -845,7 +845,7 @@ class ZeroPad2d(layers.Layer): Code Examples: .. code-block:: python - + import paddle import paddle.nn as nn import numpy as np @@ -892,8 +892,8 @@ class ReplicationPad2d(layers.Layer): Default is "NCHW" name (str, optional) : The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: + + Returns: None Examples: @@ -907,7 +907,7 @@ class ReplicationPad2d(layers.Layer): Code Examples: .. code-block:: python - + import paddle import paddle.nn as nn import numpy as np @@ -954,8 +954,8 @@ class ReflectionPad2d(layers.Layer): Default is "NCHW" name (str, optional) : The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: + + Returns: None Examples: @@ -969,7 +969,7 @@ class ReflectionPad2d(layers.Layer): Code Examples: .. code-block:: python - + import paddle import paddle.nn as nn import numpy as np @@ -1019,8 +1019,8 @@ class ConstantPad3d(layers.Layer): Default is "NCDHW" name (str, optional) : The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: + + Returns: None Examples: @@ -1035,7 +1035,7 @@ class ConstantPad3d(layers.Layer): Code Examples: .. code-block:: python - + import paddle import paddle.nn as nn import numpy as np @@ -1084,8 +1084,8 @@ class ReplicationPad3d(layers.Layer): Default is "NCDHW" name (str, optional) : The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. - - Returns: + + Returns: None Examples: @@ -1099,7 +1099,7 @@ class ReplicationPad3d(layers.Layer): Code Examples: .. code-block:: python - + import paddle import paddle.nn as nn import numpy as np @@ -1141,7 +1141,7 @@ class CosineSimilarity(layers.Layer): Parameters: axis (int): Dimension of vectors to compute cosine similarity. Default is 1. eps(float): Small value to avoid division by zero. Default is 1e-8. - Returns: + Returns: None Examples: @@ -1162,7 +1162,7 @@ class CosineSimilarity(layers.Layer): Code Examples: .. code-block:: python - + import paddle import paddle.nn as nn import numpy as np diff --git a/python/paddle/nn/layer/pooling.py b/python/paddle/nn/layer/pooling.py index 65ea6b0b05d14dc6a9119f7dc6f0afa10aed97b2..d4e50bd993c2a07c01fd6f4ebcadb20ed8e38cdb 100755 --- a/python/paddle/nn/layer/pooling.py +++ b/python/paddle/nn/layer/pooling.py @@ -23,6 +23,14 @@ from .. import functional as F __all__ = [ 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', + 'AvgPool1d', + 'maxPool1d', + 'AdaptiveMaxPool1d', + 'AdaptiveAvgPool1d', + 'AvgPool2d', + 'MaxPool2d', + 'AvgPool3d', + 'MaxPool3d', ] @@ -194,3 +202,676 @@ class AdaptiveAvgPool3d(layers.Layer): output_size=self._output_size, data_format=self._data_format, name=self._name) + + +class AvgPool1d(layers.Layer): + """ + This operation applies a 1D average pooling over an input signal composed + of several input planes, based on the input, output_size, return_indices parameters. + Input(X) and output(Out) are in NCL format, where N is batch + size, C is the number of channels, L is the length of the feature. + The output tensor shape will be [N, C, output_size]. + + The output value of the layer with input size (N, C, L), + output (N, C, L_{out}) and kernel_size k can be precisely described as + For average pool1d: + + .. math:: + + Output(N_i, C_i, l) &= mean(Input[N_i, C_i, stride \times l:stride \times l+k]) + + + Args: + kernel_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, + it must contain one integers. + stride (int|list|tuple): The pool stride size. If pool stride size is a tuple or list, + it must contain one integers. + padding (string|int|list|tuple): The pool padding. If `pool_padding` is a string, either 'VALID' or + 'SAME' which is the padding algorithm. If pool padding size is a tuple or list, + it could be the following forms: `[pad_left, pad_right]`. If padding is non-zero, + then the input is implicitly zero-padded on both sides for padding number of points. + count_include_pad (bool): Whether to exclude padding points in average pooling + mode, default is `true`. + ceil_mode (bool): ${ceil_mode_comment}Whether to use the ceil function to calculate output height and width. + If it is set to False, the floor function will be used. Default False + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + None. + + Raises: + ValueError: If `padding` is a string, but not "SAME" or "VALID". + ValueError: If `padding` is "VALID", but `ceil_mode` is True. + ValueError: If `padding` is a list or tuple but its length greater than 1. + ShapeError: If the input is not a 3-D. + ShapeError: If the output's shape calculated is not greater than 0. + + + Examples: + + .. code-block:: python + import paddle + import paddle.nn as nn + paddle.disable_static() + + data = paddle.to_tensor(np.random.uniform(-1, 1, [1, 3, 32]).astype(np.float32)) + AvgPool1d = nn.AvgPool1d(kernel_size=2, stride=2, padding=0) + pool_out = AvgPool1d(data) + # pool_out shape: [1, 3, 16] + + """ + + def __init__(self, + kernel_size, + stride=None, + padding=0, + count_include_pad=True, + ceil_mode=False, + name=None): + super(AvgPool1d, self).__init__() + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.ceil_mode = ceil_mode + self.count_include_pad = count_include_pad + self.name = name + + def forward(self, x): + out = F.avg_pool1d(x, self.kernel_size, self.stride, self.padding, + self.count_include_pad, self.ceil_mode, self.name) + return out + + +class MaxPool1d(layers.Layer): + """ + Applies a 1D max pooling over an input signal composed of several input planes based + on the input, output_size, return_indices parameters. + Input(X) and output(Out) are in NCL format, where N is batch + size, C is the number of channels, L is the length of the feature. + + The output value of the layer with input size (N, C, L), + output (N, C, L_{out}) and kernel_size k can be precisely described as + For average pool1d: + + .. math:: + + Output(N_i, C_i, l) &= max(Input[N_i, C_i, stride \times l:stride \times l+k])} + + Args: + kernel_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, + it must contain one integers. + stride (int|list|tuple): The pool stride size. If pool stride size is a tuple or list, + it must contain one integers. + padding (string|int|list|tuple): The pool padding. If `pool_padding` is a string, either 'VALID' or + 'SAME' which is the padding algorithm. If pool padding size is a tuple or list, + it could be the following forms: `[pad_left, pad_right]`. + return_indices (bool): Whether return the max indices along with the outputs. default is `False`. + ceil_mode (bool): Whether to use the ceil function to calculate output height and width. False is the default. + If it is set to False, the floor function will be used. Default False + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + None. + + Raises: + ValueError: If `padding` is a string, but not "SAME" or "VALID". + ValueError: If `padding` is "VALID", but `ceil_mode` is True. + ValueError: If `padding` is a list or tuple but its length greater than 1. + ShapeError: If the input is not a 3-D. + ShapeError: If the output's shape calculated is not greater than 0. + + + Examples: + + .. code-block:: python + + import paddle + import paddle.nn as nn + paddle.disable_static() + + data = paddle.to_tensor(np.random.uniform(-1, 1, [1, 3, 32]).astype(np.float32)) + MaxPool1d = nn.MaxPool1d(kernel_size=2, stride=2, padding=0) + pool_out = MaxPool1d(data) + # pool_out shape: [1, 3, 16] + + MaxPool1d = nn.MaxPool1d(kernel_size=2, stride=2, padding=0, return_indices=True) + pool_out, indices = MaxPool1d(data) + # pool_out shape: [1, 3, 16], indices shape: [1, 3, 16] + + """ + + def __init__(self, + kernel_size, + stride=None, + padding=0, + return_indices=False, + ceil_mode=False, + name=None): + super(MaxPool1d, self).__init__() + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.ceil_mode = ceil_mode + self.return_indices = return_indices + self.name = name + + def forward(self, input): + out = F.max_pool1d(input, self.kernel_size, self.stride, self.padding, + self.return_indices, self.ceil_mode, self.name) + return out + + +class AdaptiveAvgPool1d(layers.Layer): + """ + + This operation applies a 1D adaptive average pooling over an input signal composed + of several input planes, based on the input, output_size, return_indices parameters. + Input(X) and output(Out) are in NCL format, where N is batch + size, C is the number of channels, L is the length of the feature. + The output tensor shape will be [N, C, output_size]. + + For average adaptive pool1d: + + .. math:: + + lstart &= floor(i * L_{in} / L_{out}) + + lend &= ceil((i + 1) * L_{in} / L_{out}) + + Output(i) &= \\frac{sum(Input[lstart:lend])}{(lstart - lend)} + + Args: + output_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, + it must contain one int. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: + None. + + Raises: + ValueError: 'pool_size' should be a integer or list or tuple with length as 1. + + Examples: + .. code-block:: python + + # average adaptive pool1d + # suppose input data in shape of [N, C, L], `output_size` is m or [m], + # output shape is [N, C, m], adaptive pool divide L dimension + # of input data into m grids averagely and performs poolings in each + # grid to get output. + # adaptive max pool performs calculations as follow: + # + # for i in range(m): + # lstart = floor(i * L / m) + # lend = ceil((i + 1) * L / m) + # output[:, :, i] = sum(input[:, :, lstart: lend])/(lstart - lend) + # + import paddle + import paddle.nn as nn + paddle.disable_static() + + data = paddle.to_tensor(np.random.uniform(-1, 1, [1, 3, 32]).astype(np.float32)) + AdaptiveAvgPool1d = nn.AdaptiveAvgPool1d(output_size=16) + pool_out = AdaptiveAvgPool1d(data) + # pool_out shape: [1, 3, 16] + """ + + def __init__(self, output_size, name=None): + super(AdaptiveAvgPool1d, self).__init__() + self.output_size = output_size + self.name = name + + def forward(self, input): + return F.adaptive_avg_pool1d(input, self.output_size, self.name) + + +class AdaptiveMaxPool1d(layers.Layer): + """ + + This operation applies a 1D adaptive max pooling over an input signal composed + of several input planes, based on the input, output_size, return_indices parameters. + Input(X) and output(Out) are in NCL format, where N is batch + size, C is the number of channels, L is the length of the feature. + The output tensor shape will be [N, C, output_size]. + + For max adaptive pool1d: + + .. math:: + + lstart &= floor(i * L_{in} / L_{out}) + + lend &= ceil((i + 1) * L_{in} / L_{out}) + + Output(i) &= max(Input[lstart:lend])} + + Args: + output_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, + it must contain one int. + return_indices (bool): If true, the index of max pooling point will be returned along + with outputs. It cannot be set in average pooling type. Default False. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + Returns: + None. + + Raises: + ValueError: 'pool_size' should be a integer or list or tuple with length as 1. + + Examples: + .. code-block:: python + + # max adaptive pool1d + # suppose input data in shape of [N, C, L], `output_size` is m or [m], + # output shape is [N, C, m], adaptive pool divide L dimension + # of input data into m grids averagely and performs poolings in each + # grid to get output. + # adaptive max pool performs calculations as follow: + # + # for i in range(m): + # lstart = floor(i * L / m) + # lend = ceil((i + 1) * L / m) + # output[:, :, i] = max(input[:, :, lstart: lend]) + # + import paddle + import paddle.nn as nn + paddle.disable_static() + + data = paddle.to_tensor(np.random.uniform(-1, 1, [1, 3, 32]).astype(np.float32)) + AdaptiveMaxPool1d = nn.AdaptiveMaxPool1d(output_size=16) + pool_out = AdaptiveMaxPool1d(data) + # pool_out shape: [1, 3, 16] + + # for return_indices = true + AdaptiveMaxPool1d = nn.AdaptiveMaxPool1d(output_size=16, return_indices=True) + pool_out, indices = AdaptiveMaxPool1d(data) + # pool_out shape: [1, 3, 16], indices shape: [1, 3, 16] + + """ + + def __init__(self, output_size, return_indices=False, name=None): + super(AdaptiveMaxPool1d, self).__init__() + self.output_size = output_size + self.return_indices = return_indices + self.name = name + + def forward(self, input): + return F.adaptive_max_pool1d(input, self.output_size, + self.return_indices, self.name) + + +class AvgPool2d(layers.Layer): + """ + This operation applies 2D average pooling over input features based on the input, + and kernel_size, stride, padding parameters. Input(X) and Output(Out) are + in NCHW format, where N is batch size, C is the number of channels, + H is the height of the feature, and W is the width of the feature. + + Example: + Input: + X shape: $(N, C, H_{in}, W_{in})$ + Attr: + kernel_size: ksize + + Output: + Out shape: $(N, C, H_{out}, W_{out})$ + $$ + out(N_i, C_j, h, w) = \frac{1}{ksize[0] * ksize[1]} \sum_{m=0}^{ksize[0]-1} \sum_{n=0}^{ksize[1]-1} + input(N_i, C_j, stride[0] \times h + m, stride[1] \times w + n) + $$ + + Args: + kernel_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, + it must contain two integers, (pool_size_Height, pool_size_Width). + Otherwise, the pool kernel size will be a square of an int. + stride (int|list|tuple): The pool stride size. If pool stride size is a tuple or list, + it must contain two integers, (pool_stride_Height, pool_stride_Width). + Otherwise, the pool stride size will be a square of an int. Default: kernel_size. + padding (string|int|list|tuple): The pool padding. If `pool_padding` is a string, either 'VALID' or + 'SAME' which is the padding algorithm. If pool padding size is a tuple or list, + it could be in three forms: `[pad_height, pad_width]` or + `[pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`, and when `data_format` is `"NCHW"`, + `pool_padding` can be in the form `[[0,0], [0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`. + when `data_format` is `"NHWC"`, `pool_padding` can be in the form + `[[0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`. + Otherwise, the pool padding size will be a square of an int. + ceil_mode (bool): when True, will use `ceil` instead of `floor` to compute the output shape + count_include_pad (bool): Whether to exclude padding points in average pooling + mode, default is `true`. + divisor_override (int|float) if specified, it will be used as divisor, otherwise kernel_size will be used. Default None. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NDHW"`. + The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of: + `[batch_size, input_channels, input_height, input_width]`. + + Returns: None. + Raises: + ValueError: If `padding` is a string, but not "SAME" or "VALID". + ValueError: If `padding` is "VALID", but `ceil_mode` is True. + ShapeError: If the output's shape calculated is not greater than 0. + Examples: + .. code-block:: python + import paddle + import paddle.nn as nn + import numpy as np + paddle.disable_static() + + # max pool2d + input = paddle.to_tensor(np.random.uniform(-1, 1, [1, 3, 32, 32]).astype(np.float32)) + AvgPool2d = nn.AvgPool2d(kernel_size=2, + stride=2, padding=0) + output = AvgPoo2d(input) + # output.shape [1, 3, 16, 16] + + """ + + def __init__(self, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + data_format="NCHW", + name=None): + super(AvgPool2d, self).__init__() + self.ksize = kernel_size + self.stride = stride + self.padding = padding + self.ceil_mode = ceil_mode + self.count_include_pad = count_include_pad + self.divisor = divisor_override + self.data_format = data_format + self.name = name + + def forward(self, x): + return F.avg_pool2d( + x, + kernel_size=self.ksize, + stride=self.stride, + padding=self.padding, + ceil_mode=self.ceil_mode, + count_include_pad=self.count_include_pad, + divisor_override=self.divisor, + data_format=self.data_format, + name=self.name) + + +class MaxPool2d(layers.Layer): + """ + This operation applies 2D max pooling over input feature based on the input, + and kernel_size, stride, padding parameters. Input(X) and Output(Out) are + in NCHW format, where N is batch size, C is the number of channels, + H is the height of the feature, and W is the width of the feature. + + Example: + Input: + X shape: $(N, C, H_{in}, W_{in})$ + Attr: + kernel_size: ksize + + Output: + Out shape: $(N, C, H_{out}, W_{out})$ + $$ + out(N_i, C_j, h, w) ={} & \max_{m=0, \ldots, ksize[0] -1} \max_{n=0, \ldots, ksize[1]-1} \\ + & \text{input}(N_i, C_j, \text{stride[0]} \times h + m, + \text{stride[1]} \times w + n) + $$ + + Args: + kernel_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, + it must contain two integers, (pool_size_Height, pool_size_Width). + Otherwise, the pool kernel size will be a square of an int. + stride (int|list|tuple): The pool stride size. If pool stride size is a tuple or list, + it must contain two integers, (pool_stride_Height, pool_stride_Width). + Otherwise, the pool stride size will be a square of an int. Default: kernel_size. + padding (string|int|list|tuple): The pool padding. If `pool_padding` is a string, either 'VALID' or + 'SAME' which is the padding algorithm. If pool padding size is a tuple or list, + it could be in three forms: `[pad_height, pad_width]` or + `[pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`, and when `data_format` is `"NCHW"`, + `pool_padding` can be in the form `[[0,0], [0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`. + when `data_format` is `"NHWC"`, `pool_padding` can be in the form + `[[0,0], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`. + Otherwise, the pool padding size will be a square of an int. + ceil_mode (bool): when True, will use `ceil` instead of `floor` to compute the output shape + return_indices (bool): Whether to return the max indices along with the outputs. + data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NDHW"`. + The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of: + `[batch_size, input_channels, input_height, input_width]`. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: None + Raises: + ValueError: If `padding` is a string, but not "SAME" or "VALID". + ValueError: If `padding` is "VALID", but `ceil_mode` is True. + ShapeError: If the output's shape calculated is not greater than 0. + Examples: + .. code-block:: python + import paddle + import paddle.nn as nn + import numpy as np + paddle.disable_static() + + # max pool2d + input = paddle.to_tensor(np.random.uniform(-1, 1, [1, 3, 32, 32]).astype(np.float32)) + MaxPool2d = nn.MaxPool2d(kernel_size=2, + stride=2, padding=0) + output = MaxPool2d(input) + # output.shape [1, 3, 16, 16] + + # for return_indices=True + MaxPool2d = nn.MaxPool2d(kernel_size=2,stride=2, padding=0, return_indices=True) + output, max_indices = MaxPool2d(input) + # output.shape [1, 3, 16, 16], max_indices.shape [1, 3, 16, 16], + """ + + def __init__(self, + kernel_size, + stride=None, + padding=0, + return_indices=False, + ceil_mode=False, + data_format="NCHW", + name=None): + super(MaxPool2d, self).__init__() + self.ksize = kernel_size + self.stride = stride + self.padding = padding + self.return_indices = return_indices + self.ceil_mode = ceil_mode + self.data_format = data_format + self.name = name + + def forward(self, x): + return F.max_pool2d( + x, + kernel_size=self.ksize, + stride=self.stride, + padding=self.padding, + return_indices=self.return_indices, + data_format=self.data_format, + name=self.name) + + +class MaxPool3d(layers.Layer): + """ + This operation applies 3D max pooling over input features based on the input, + and kernel_size, stride, padding parameters. Input(X) and Output(Out) are + in NCDHW format, where N is batch size, C is the number of channels, + H is the height of the feature, D is the depth of the feature, and W is the width of the feature. + + Args: + kernel_size (int|list|tuple): The pool kernel size. If pool kernel size + is a tuple or list, it must contain three integers, + (pool_size_Depth, pool_size_Height, pool_size_Width). + Otherwise, the pool kernel size will be the cube of an int. + stride (string|int|list|tuple)): The pool padding. If `pool_padding` is a string, either 'VALID' or + 'SAME' which is the padding algorithm. If pool stride size is a tuple or list, + it must contain three integers, `[stride_Depth, stride_Height, stride_Width]`. + Otherwise, the pool stride size will be a cube of an int. Default kernel_size. + padding (int|list|tuple): The pool padding size. If pool padding size is a tuple or list, + it could be in three forms: `[pad_depth, pad_height, pad_width]` or + `[pad_depth_front, pad_depth_back, pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`, + and when `data_format` is `"NCDHW"`, `pool_padding` can be in the form + `[[0,0], [0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`. + when `data_format` is `"NDHWC"`, `pool_padding` can be in the form + `[[0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`. + ceil_mode (bool): when True, will use ceil instead of floor to compute the output shape. + count_include_pad (bool): Whether to exclude padding points in average pooling + mode, default is True. + data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NDHW"`. + The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of: + `[batch_size, input_channels, input_height, input_width]`. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + + Returns:None. + Raises: + ValueError: If `padding` is a string, but not "SAME" or "VALID". + ValueError: If `padding` is "VALID", but `ceil_mode` is True. + ShapeError: If the output's shape calculated is not greater than 0. + Examples: + .. code-block:: python + import paddle + import paddle.nn as nn + import numpy as np + paddle.disable_static() + + # max pool3d + input = paddle.to_tensor(np.random.uniform(-1, 1, [1, 2, 3, 32, 32]).astype(np.float32)) + MaxPool3d = nn.MaxPool3d(kernel_size=2, + stride=2, padding=0) + output = MaxPool3d(input) + # output.shape [1, 2, 3, 16, 16] + + # for return_indices=True + MaxPool3d = nn.MaxPool3d(kernel_size=2,stride=2, padding=0, return_indices=True) + output, max_indices = MaxPool3d(input) + # output.shape [1, 2, 3, 16, 16], max_indices.shape [1, 2, 3, 16, 16], + """ + + def __init__(self, + kernel_size, + stride, + padding, + return_indices=False, + ceil_mode=False, + data_format="NCDHW", + name=None): + super(MaxPool3d, self).__init__() + self.ksize = kernel_size + self.stride = stride + self.padding = padding + self.return_indices = return_indices + self.ceil_mode = ceil_mode + self.data_format = data_format + self.name = name + + def forward(self, x): + return F.max_pool3d( + x, + kernel_size=self.ksize, + stride=self.stride, + padding=self.padding, + return_indices=self.return_indices, + data_format=self.data_format, + name=self.name) + + +class AvgPool3d(layers.Layer): + """ + This operation applies 3D max pooling over input features based on the input, + and kernel_size, stride, padding parameters. Input(X) and Output(Out) are + in NCDHW format, where N is batch size, C is the number of channels, + H is the height of the feature, D is the depth of the feature, and W is the width of the feature. + + Args: + kernel_size (int|list|tuple): The pool kernel size. If pool kernel size + is a tuple or list, it must contain three integers, + (pool_size_Depth, pool_size_Height, pool_size_Width). + Otherwise, the pool kernel size will be the cube of an int. + stride (string|int|list|tuple)): The pool padding. If `pool_padding` is a string, either 'VALID' or + 'SAME' which is the padding algorithm. If pool stride size is a tuple or list, + it must contain three integers, `[stride_Depth, stride_Height, stride_Width]`. + Otherwise, the pool stride size will be a cube of an int. + padding (int|list|tuple): The pool padding size. If pool padding size is a tuple or list, + it could be in three forms: `[pad_depth, pad_height, pad_width]` or + `[pad_depth_front, pad_depth_back, pad_height_top, pad_height_bottom, pad_width_left, pad_width_right]`, + and when `data_format` is `"NCDHW"`, `pool_padding` can be in the form + `[[0,0], [0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right]]`. + when `data_format` is `"NDHWC"`, `pool_padding` can be in the form + `[[0,0], [pad_depth_front, pad_depth_back], [pad_height_top, pad_height_bottom], [pad_width_left, pad_width_right], [0,0]]`. + ceil_mode (bool): ${ceil_mode_comment} + count_include_pad (bool): Whether to exclude padding points in average pooling + mode, default is True. + divisor_override (int|float) if specified, it will be used as divisor, otherwise kernel_size will be used. Default None. + data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NDHW"`. + The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of: + `[batch_size, input_channels, input_height, input_width]`. + name(str, optional): For detailed information, please refer + to :ref:`api_guide_Name`. Usually name is no need to set and + None by default. + + Returns: None. + Raises: + ValueError: If `padding` is a string, but not "SAME" or "VALID". + ValueError: If `padding` is "VALID", but `ceil_mode` is True. + ShapeError: If the output's shape calculated is not greater than 0. + Examples: + .. code-block:: python + import paddle + import paddle.nn as nn + import numpy as np + paddle.disable_static() + + # avg pool3d + input = paddle.to_tensor(np.random.uniform(-1, 1, [1, 2, 3, 32, 32]).astype(np.float32)) + AvgPool3d = nn.AvgPool3d(kernel_size=2, + stride=2, padding=0) + output = AvgPool3d(input) + # output.shape [1, 2, 3, 16, 16] + + """ + + def __init__(self, + kernel_size, + stride, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, + data_format="NCDHW", + name=None): + super(AvgPool3d, self).__init__() + self.ksize = kernel_size + self.stride = stride + self.padding = padding + self.ceil_mode = ceil_mode + self.count_include_pad = count_include_pad + self.divisor = divisor_override + self.data_format = data_format + self.name = name + + def forward(self, x): + return F.avg_pool3d( + x, + kernel_size=self.ksize, + stride=self.stride, + padding=self.padding, + ceil_mode=self.ceil_mode, + count_include_pad=self.count_include_pad, + divisor_override=self.divisor, + data_format=self.data_format, + name=self.name)