Unverified commit 74836ec7 authored by Bai Yifan, committed by GitHub

[2.0API]Add adaptive_avg_pool_2/3d (#26369)

* add adaptive_avg_pool2d

* add adaptive_avg_pool3d
Parent 8d194524
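At a glance, the new interfaces can be exercised as follows (a minimal sketch, assuming the paddle 2.0-beta imperative mode used throughout this PR):

import numpy as np
import paddle

paddle.disable_static()
x2d = paddle.to_tensor(np.random.rand(2, 3, 32, 32).astype("float32"))
x3d = paddle.to_tensor(np.random.rand(2, 3, 8, 32, 32).astype("float32"))

# functional form
out2d = paddle.nn.functional.adaptive_avg_pool2d(x2d, output_size=[3, 3])     # shape [2, 3, 3, 3]
out3d = paddle.nn.functional.adaptive_avg_pool3d(x3d, output_size=[3, 3, 3])  # shape [2, 3, 3, 3, 3]

# Layer form
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2d(output_size=3)
out2d = adaptive_avg_pool(x2d)  # shape [2, 3, 3, 3]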
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import division
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
def adaptive_start_index(index, input_size, output_size):
return int(np.floor(index * input_size / output_size))
def adaptive_end_index(index, input_size, output_size):
return int(np.ceil((index + 1) * input_size / output_size))
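# Worked example of the adaptive window computation above: for input_size=7
# and output_size=3, the windows along one axis are
#   index 0: [floor(0 * 7 / 3), ceil(1 * 7 / 3)) = [0, 3)
#   index 1: [floor(1 * 7 / 3), ceil(2 * 7 / 3)) = [2, 5)
#   index 2: [floor(2 * 7 / 3), ceil(3 * 7 / 3)) = [4, 7)
# so adjacent windows may overlap and together cover every input element.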
def adaptive_pool2d_forward(x, output_size, data_format='NCHW',
pool_type="avg"):
N = x.shape[0]
C, H, W = [x.shape[1], x.shape[2], x.shape[3]] if data_format == 'NCHW' \
else [x.shape[3], x.shape[1], x.shape[2]]
    if isinstance(output_size, int) or output_size is None:
        H_out = output_size
        W_out = output_size
        output_size = [H_out, W_out]
    else:
        H_out, W_out = output_size
        if output_size[0] is None:
            output_size[0] = H
            H_out = H
        if output_size[1] is None:
            output_size[1] = W
            W_out = W
out = np.zeros((N, C, H_out, W_out)) if data_format=='NCHW' \
else np.zeros((N, H_out, W_out, C))
for i in range(H_out):
in_h_start = adaptive_start_index(i, H, output_size[0])
in_h_end = adaptive_end_index(i, H, output_size[0])
for j in range(W_out):
in_w_start = adaptive_start_index(j, W, output_size[1])
in_w_end = adaptive_end_index(j, W, output_size[1])
if data_format == 'NCHW':
x_masked = x[:, :, in_h_start:in_h_end, in_w_start:in_w_end]
if pool_type == 'avg':
field_size = (
(in_h_end - in_h_start) * (in_w_end - in_w_start))
out[:, :, i, j] = np.sum(x_masked, axis=(2, 3)) / field_size
elif pool_type == 'max':
out[:, :, i, j] = np.max(x_masked, axis=(2, 3))
elif data_format == 'NHWC':
x_masked = x[:, in_h_start:in_h_end, in_w_start:in_w_end, :]
if pool_type == 'avg':
field_size = (
(in_h_end - in_h_start) * (in_w_end - in_w_start))
out[:, i, j, :] = np.sum(x_masked, axis=(1, 2)) / field_size
elif pool_type == 'max':
out[:, i, j, :] = np.max(x_masked, axis=(1, 2))
return out
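# Illustrative sanity check for the reference implementation above: pooling a
# 7x7 NCHW input down to [3, 3] averages the overlapping windows computed by
# the index helpers, e.g. out[:, :, 0, 0] == x[:, :, 0:3, 0:3].mean(axis=(2, 3)).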
class TestAdaptiveAvgPool2dAPI(unittest.TestCase):
def setUp(self):
self.x_np = np.random.random([2, 3, 7, 7]).astype("float32")
self.res_1_np = adaptive_pool2d_forward(
x=self.x_np, output_size=[3, 3], pool_type="avg")
self.res_2_np = adaptive_pool2d_forward(
x=self.x_np, output_size=5, pool_type="avg")
self.res_3_np = adaptive_pool2d_forward(
x=self.x_np, output_size=[2, 5], pool_type="avg")
self.res_4_np = adaptive_pool2d_forward(
x=self.x_np,
output_size=[3, 3],
pool_type="avg",
data_format="NHWC")
self.res_5_np = adaptive_pool2d_forward(
x=self.x_np, output_size=[None, 3], pool_type="avg")
def test_static_graph(self):
for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.enable_static()
x = paddle.data(name="x", shape=[2, 3, 7, 7], dtype="float32")
out_1 = paddle.nn.functional.adaptive_avg_pool2d(
x=x, output_size=[3, 3])
out_2 = paddle.nn.functional.adaptive_avg_pool2d(x=x, output_size=5)
out_3 = paddle.nn.functional.adaptive_avg_pool2d(
x=x, output_size=[2, 5])
out_4 = paddle.nn.functional.adaptive_avg_pool2d(
x=x, output_size=[3, 3], data_format="NHWC")
out_5 = paddle.nn.functional.adaptive_avg_pool2d(
x=x, output_size=[None, 3])
exe = paddle.static.Executor(place=place)
[res_1, res_2, res_3, res_4, res_5] = exe.run(
fluid.default_main_program(),
feed={"x": self.x_np},
fetch_list=[out_1, out_2, out_3, out_4, out_5])
assert np.allclose(res_1, self.res_1_np)
assert np.allclose(res_2, self.res_2_np)
assert np.allclose(res_3, self.res_3_np)
assert np.allclose(res_4, self.res_4_np)
assert np.allclose(res_5, self.res_5_np)
def test_dynamic_graph(self):
for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.disable_static(place=place)
x = paddle.to_variable(self.x_np)
out_1 = paddle.nn.functional.adaptive_avg_pool2d(
x=x, output_size=[3, 3])
out_2 = paddle.nn.functional.adaptive_avg_pool2d(x=x, output_size=5)
out_3 = paddle.nn.functional.adaptive_avg_pool2d(
x=x, output_size=[2, 5])
out_4 = paddle.nn.functional.adaptive_avg_pool2d(
x=x, output_size=[3, 3], data_format="NHWC")
out_5 = paddle.nn.functional.adaptive_avg_pool2d(
x=x, output_size=[None, 3])
assert np.allclose(out_1.numpy(), self.res_1_np)
assert np.allclose(out_2.numpy(), self.res_2_np)
assert np.allclose(out_3.numpy(), self.res_3_np)
assert np.allclose(out_4.numpy(), self.res_4_np)
assert np.allclose(out_5.numpy(), self.res_5_np)
class TestAdaptiveAvgPool2dClassAPI(unittest.TestCase):
def setUp(self):
self.x_np = np.random.random([2, 3, 7, 7]).astype("float32")
self.res_1_np = adaptive_pool2d_forward(
x=self.x_np, output_size=[3, 3], pool_type="avg")
self.res_2_np = adaptive_pool2d_forward(
x=self.x_np, output_size=5, pool_type="avg")
self.res_3_np = adaptive_pool2d_forward(
x=self.x_np, output_size=[2, 5], pool_type="avg")
self.res_4_np = adaptive_pool2d_forward(
x=self.x_np,
output_size=[3, 3],
pool_type="avg",
data_format="NHWC")
self.res_5_np = adaptive_pool2d_forward(
x=self.x_np, output_size=[None, 3], pool_type="avg")
def test_static_graph(self):
for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.enable_static()
x = paddle.data(name="x", shape=[2, 3, 7, 7], dtype="float32")
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2d(output_size=[3, 3])
out_1 = adaptive_avg_pool(x=x)
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2d(output_size=5)
out_2 = adaptive_avg_pool(x=x)
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2d(output_size=[2, 5])
out_3 = adaptive_avg_pool(x=x)
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2d(
output_size=[3, 3], data_format="NHWC")
out_4 = adaptive_avg_pool(x=x)
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2d(
output_size=[None, 3])
out_5 = adaptive_avg_pool(x=x)
exe = paddle.static.Executor(place=place)
[res_1, res_2, res_3, res_4, res_5] = exe.run(
fluid.default_main_program(),
feed={"x": self.x_np},
fetch_list=[out_1, out_2, out_3, out_4, out_5])
assert np.allclose(res_1, self.res_1_np)
assert np.allclose(res_2, self.res_2_np)
assert np.allclose(res_3, self.res_3_np)
assert np.allclose(res_4, self.res_4_np)
assert np.allclose(res_5, self.res_5_np)
def test_dynamic_graph(self):
for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.disable_static(place=place)
x = paddle.to_variable(self.x_np)
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2d(output_size=[3, 3])
out_1 = adaptive_avg_pool(x=x)
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2d(output_size=5)
out_2 = adaptive_avg_pool(x=x)
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2d(output_size=[2, 5])
out_3 = adaptive_avg_pool(x=x)
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2d(
output_size=[3, 3], data_format="NHWC")
out_4 = adaptive_avg_pool(x=x)
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2d(
output_size=[None, 3])
out_5 = adaptive_avg_pool(x=x)
assert np.allclose(out_1.numpy(), self.res_1_np)
assert np.allclose(out_2.numpy(), self.res_2_np)
assert np.allclose(out_3.numpy(), self.res_3_np)
assert np.allclose(out_4.numpy(), self.res_4_np)
assert np.allclose(out_5.numpy(), self.res_5_np)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import division
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
def adaptive_start_index(index, input_size, output_size):
return int(np.floor(index * input_size / output_size))
def adaptive_end_index(index, input_size, output_size):
return int(np.ceil((index + 1) * input_size / output_size))
def adaptive_pool3d_forward(x,
output_size,
adaptive=True,
data_format='NCDHW',
pool_type='avg'):
N = x.shape[0]
C, D, H, W = [x.shape[1], x.shape[2], x.shape[3], x.shape[4]] \
        if data_format == 'NCDHW' else [x.shape[4], x.shape[1], x.shape[2], x.shape[3]]
    if isinstance(output_size, int) or output_size is None:
        D_out = output_size
        H_out = output_size
        W_out = output_size
        output_size = [D_out, H_out, W_out]
    else:
        D_out, H_out, W_out = output_size
        if output_size[0] is None:
            output_size[0] = D
            D_out = D
        if output_size[1] is None:
            output_size[1] = H
            H_out = H
        if output_size[2] is None:
            output_size[2] = W
            W_out = W
out = np.zeros((N, C, D_out, H_out, W_out)) if data_format=='NCDHW' \
else np.zeros((N, D_out, H_out, W_out, C))
for k in range(D_out):
d_start = adaptive_start_index(k, D, output_size[0])
d_end = adaptive_end_index(k, D, output_size[0])
for i in range(H_out):
h_start = adaptive_start_index(i, H, output_size[1])
h_end = adaptive_end_index(i, H, output_size[1])
for j in range(W_out):
w_start = adaptive_start_index(j, W, output_size[2])
w_end = adaptive_end_index(j, W, output_size[2])
if data_format == 'NCDHW':
x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:
w_end]
if pool_type == 'avg':
field_size = (d_end - d_start) * (h_end - h_start) * (
w_end - w_start)
out[:, :, k, i, j] = np.sum(x_masked,
axis=(2, 3, 4)) / field_size
elif pool_type == 'max':
out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4))
elif data_format == 'NDHWC':
x_masked = x[:, d_start:d_end, h_start:h_end, w_start:
w_end, :]
if pool_type == 'avg':
field_size = (d_end - d_start) * (h_end - h_start) * (
w_end - w_start)
out[:, k, i, j, :] = np.sum(x_masked,
axis=(1, 2, 3)) / field_size
elif pool_type == 'max':
out[:, k, i, j, :] = np.max(x_masked, axis=(1, 2, 3))
return out
class TestAdaptiveAvgPool3dAPI(unittest.TestCase):
def setUp(self):
self.x_np = np.random.random([2, 3, 5, 7, 7]).astype("float32")
self.res_1_np = adaptive_pool3d_forward(
x=self.x_np, output_size=[3, 3, 3], pool_type="avg")
self.res_2_np = adaptive_pool3d_forward(
x=self.x_np, output_size=5, pool_type="avg")
self.res_3_np = adaptive_pool3d_forward(
x=self.x_np, output_size=[2, 3, 5], pool_type="avg")
self.res_4_np = adaptive_pool3d_forward(
x=self.x_np,
output_size=[3, 3, 3],
pool_type="avg",
data_format="NDHWC")
self.res_5_np = adaptive_pool3d_forward(
x=self.x_np, output_size=[None, 3, None], pool_type="avg")
def test_static_graph(self):
for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.enable_static()
x = paddle.data(name="x", shape=[2, 3, 5, 7, 7], dtype="float32")
out_1 = paddle.nn.functional.adaptive_avg_pool3d(
x=x, output_size=[3, 3, 3])
out_2 = paddle.nn.functional.adaptive_avg_pool3d(x=x, output_size=5)
out_3 = paddle.nn.functional.adaptive_avg_pool3d(
x=x, output_size=[2, 3, 5])
out_4 = paddle.nn.functional.adaptive_avg_pool3d(
x=x, output_size=[3, 3, 3], data_format="NDHWC")
out_5 = paddle.nn.functional.adaptive_avg_pool3d(
x=x, output_size=[None, 3, None])
exe = paddle.static.Executor(place=place)
[res_1, res_2, res_3, res_4, res_5] = exe.run(
fluid.default_main_program(),
feed={"x": self.x_np},
fetch_list=[out_1, out_2, out_3, out_4, out_5])
assert np.allclose(res_1, self.res_1_np)
assert np.allclose(res_2, self.res_2_np)
assert np.allclose(res_3, self.res_3_np)
assert np.allclose(res_4, self.res_4_np)
assert np.allclose(res_5, self.res_5_np)
def test_dynamic_graph(self):
for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.disable_static(place=place)
x = paddle.to_variable(self.x_np)
out_1 = paddle.nn.functional.adaptive_avg_pool3d(
x=x, output_size=[3, 3, 3])
out_2 = paddle.nn.functional.adaptive_avg_pool3d(x=x, output_size=5)
out_3 = paddle.nn.functional.adaptive_avg_pool3d(
x=x, output_size=[2, 3, 5])
out_4 = paddle.nn.functional.adaptive_avg_pool3d(
x=x, output_size=[3, 3, 3], data_format="NDHWC")
out_5 = paddle.nn.functional.adaptive_avg_pool3d(
x=x, output_size=[None, 3, None])
assert np.allclose(out_1.numpy(), self.res_1_np)
assert np.allclose(out_2.numpy(), self.res_2_np)
assert np.allclose(out_3.numpy(), self.res_3_np)
assert np.allclose(out_4.numpy(), self.res_4_np)
assert np.allclose(out_5.numpy(), self.res_5_np)
class TestAdaptiveAvgPool3dClassAPI(unittest.TestCase):
def setUp(self):
self.x_np = np.random.random([2, 3, 5, 7, 7]).astype("float32")
self.res_1_np = adaptive_pool3d_forward(
x=self.x_np, output_size=[3, 3, 3], pool_type="avg")
self.res_2_np = adaptive_pool3d_forward(
x=self.x_np, output_size=5, pool_type="avg")
self.res_3_np = adaptive_pool3d_forward(
x=self.x_np, output_size=[2, 3, 5], pool_type="avg")
self.res_4_np = adaptive_pool3d_forward(
x=self.x_np,
output_size=[3, 3, 3],
pool_type="avg",
data_format="NDHWC")
self.res_5_np = adaptive_pool3d_forward(
x=self.x_np, output_size=[None, 3, None], pool_type="avg")
def test_static_graph(self):
for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.enable_static()
x = paddle.data(name="x", shape=[2, 3, 5, 7, 7], dtype="float32")
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool3d(
output_size=[3, 3, 3])
out_1 = adaptive_avg_pool(x=x)
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool3d(output_size=5)
out_2 = adaptive_avg_pool(x=x)
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool3d(
output_size=[2, 3, 5])
out_3 = adaptive_avg_pool(x=x)
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool3d(
output_size=[3, 3, 3], data_format="NDHWC")
out_4 = adaptive_avg_pool(x=x)
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool3d(
output_size=[None, 3, None])
out_5 = adaptive_avg_pool(x=x)
exe = paddle.static.Executor(place=place)
[res_1, res_2, res_3, res_4, res_5] = exe.run(
fluid.default_main_program(),
feed={"x": self.x_np},
fetch_list=[out_1, out_2, out_3, out_4, out_5])
assert np.allclose(res_1, self.res_1_np)
assert np.allclose(res_2, self.res_2_np)
assert np.allclose(res_3, self.res_3_np)
assert np.allclose(res_4, self.res_4_np)
assert np.allclose(res_5, self.res_5_np)
def test_dynamic_graph(self):
for use_cuda in ([False, True]
if core.is_compiled_with_cuda() else [False]):
place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
paddle.disable_static(place=place)
x = paddle.to_variable(self.x_np)
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool3d(
output_size=[3, 3, 3])
out_1 = adaptive_avg_pool(x=x)
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool3d(output_size=5)
out_2 = adaptive_avg_pool(x=x)
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool3d(
output_size=[2, 3, 5])
out_3 = adaptive_avg_pool(x=x)
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool3d(
output_size=[3, 3, 3], data_format="NDHWC")
out_4 = adaptive_avg_pool(x=x)
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool3d(
output_size=[None, 3, None])
out_5 = adaptive_avg_pool(x=x)
assert np.allclose(out_1.numpy(), self.res_1_np)
assert np.allclose(out_2.numpy(), self.res_2_np)
assert np.allclose(out_3.numpy(), self.res_3_np)
assert np.allclose(out_4.numpy(), self.res_4_np)
assert np.allclose(out_5.numpy(), self.res_5_np)
if __name__ == '__main__':
unittest.main()
@@ -87,6 +87,8 @@ from .layer.common import Embedding #DEFINE_ALIAS
from .layer.common import Linear #DEFINE_ALIAS
from .layer.common import Flatten #DEFINE_ALIAS
from .layer.common import UpSample #DEFINE_ALIAS
from .layer.pooling import AdaptiveAvgPool2d #DEFINE_ALIAS
from .layer.pooling import AdaptiveAvgPool3d #DEFINE_ALIAS
from .layer.conv import Conv2D #DEFINE_ALIAS
from .layer.conv import Conv2DTranspose #DEFINE_ALIAS
from .layer.conv import Conv3D #DEFINE_ALIAS
@@ -160,6 +160,8 @@ from .pooling import pool2d #DEFINE_ALIAS
from .pooling import pool3d #DEFINE_ALIAS
from .pooling import adaptive_pool2d #DEFINE_ALIAS
from .pooling import adaptive_pool3d #DEFINE_ALIAS
from .pooling import adaptive_avg_pool2d #DEFINE_ALIAS
from .pooling import adaptive_avg_pool3d #DEFINE_ALIAS
# from .rnn import gru_unit #DEFINE_ALIAS
# from .rnn import lstm #DEFINE_ALIAS
# from .rnn import lstm_unit #DEFINE_ALIAS
@@ -13,9 +13,268 @@
# limitations under the License.
# TODO: define pooling functions
import paddle
from ...fluid import core
from ...fluid.layers import pool2d #DEFINE_ALIAS
from ...fluid.layers import pool3d #DEFINE_ALIAS
from ...fluid.layers import adaptive_pool2d #DEFINE_ALIAS
from ...fluid.layers import adaptive_pool3d #DEFINE_ALIAS
from ...fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
from ...fluid.layers import utils
from ...fluid.layer_helper import LayerHelper
from ...fluid.framework import in_dygraph_mode
__all__ = ['pool2d', 'pool3d', 'adaptive_pool2d', 'adaptive_pool3d']
__all__ = [
'pool2d', 'pool3d', 'adaptive_pool2d', 'adaptive_pool3d',
'adaptive_avg_pool2d', 'adaptive_avg_pool3d'
]
def adaptive_avg_pool2d(x, output_size, data_format='NCHW', name=None):
"""
    This operation applies 2D adaptive average pooling over the input tensor. The H and W dimensions
    of the output tensor are determined by the parameter output_size.
See more detail in :ref:`api_nn_pooling_AdaptiveAvgPool2d` .
For avg adaptive pool2d:
.. math::
hstart &= floor(i * H_{in} / H_{out})
hend &= ceil((i + 1) * H_{in} / H_{out})
wstart &= floor(j * W_{in} / W_{out})
wend &= ceil((j + 1) * W_{in} / W_{out})
        Output(i, j) &= \\frac{sum(Input[hstart:hend, wstart:wend])}{(hend - hstart) * (wend - wstart)}
Args:
x (Tensor): The input tensor of adaptive avg pool2d operator, which is a 4-D tensor.
The data type can be float16, float32, float64, int32 or int64.
        output_size (int|list|tuple): The pool kernel size. If the pool kernel size is a tuple or list,
            it must contain two elements, (H, W). H and W can each be an int, or None, which means
            the size will be the same as that of the input.
data_format (str): The data format of the input and output data. An optional string
from: "NCHW", "NHWC". The default is "NCHW". When it is "NCHW", the data is stored in
the order of: [batch_size, input_channels, input_height, input_width].
        name (str, optional): For detailed information, please refer
            to :ref:`api_guide_Name`. Usually name does not need to be set and
            is None by default.
Returns:
        Tensor: The output tensor of the adaptive average pooling result. The data type is the same as the input tensor.
Raises:
ValueError: If `data_format` is not "NCHW" or "NHWC".
Examples:
.. code-block:: python
            # adaptive avg pool2d
            # suppose the input data is in the shape of [N, C, H, W] and `output_size` is [m, n];
            # the output shape is [N, C, m, n]. Adaptive pooling divides the H and W dimensions
            # of the input data into an m * n grid and performs average pooling in each
            # grid cell to get the output.
            # adaptive avg pool performs the following calculations:
            #
            # for i in range(m):
            #     for j in range(n):
            #         hstart = floor(i * H / m)
            #         hend = ceil((i + 1) * H / m)
            #         wstart = floor(j * W / n)
            #         wend = ceil((j + 1) * W / n)
            #         output[:, :, i, j] = avg(input[:, :, hstart: hend, wstart: wend])
#
import paddle
import numpy as np
paddle.disable_static()
input_data = np.random.rand(2, 3, 32, 32)
x = paddle.to_tensor(input_data)
# x.shape is [2, 3, 32, 32]
pool_out = paddle.nn.functional.adaptive_avg_pool2d(
x = x,
output_size=[3, 3])
# pool_out.shape is [2, 3, 3, 3]
"""
if not in_dygraph_mode():
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'],
'adaptive_avg_pool2d')
check_type(data_format, 'data_format', str, 'adaptive_avg_pool2d')
if data_format not in ["NCHW", "NHWC"]:
raise ValueError(
"Attr(data_format) should be 'NCHW' or 'NHWC'. Received "
"Attr(data_format): %s." % str(data_format))
if data_format == "NCHW":
in_h, in_w = x.shape[2:4]
else:
in_h, in_w = x.shape[1:3]
if isinstance(output_size, int):
output_size = utils.convert_to_list(output_size, 2, 'output_size')
else:
        if output_size[0] is None:
            output_size[0] = in_h
        if output_size[1] is None:
            output_size[1] = in_w
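    # In dygraph mode, the branch below dispatches straight to the C++ pool2d
    # kernel; core.ops passes operator attributes as flattened (name, value)
    # pairs after the input tensors.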
if in_dygraph_mode():
output = core.ops.pool2d(x, 'pooling_type', 'avg', 'ksize', output_size,
'global_pooling', False, 'adaptive', True,
'data_format', data_format)
return output
l_type = 'pool2d'
helper = LayerHelper(l_type, **locals())
dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype)
outputs = {"Out": pool_out}
helper.append_op(
type=l_type,
inputs={"X": x},
outputs=outputs,
attrs={
"pooling_type": "avg",
"ksize": output_size,
"adaptive": True,
"data_format": data_format,
})
return pool_out
def adaptive_avg_pool3d(x, output_size, data_format='NCDHW', name=None):
"""
    This operation applies 3D adaptive average pooling over the input tensor. The D, H and W dimensions
    of the output tensor are determined by the parameter output_size.
See more detail in :ref:`api_nn_pooling_AdaptiveAvgPool3d` .
For avg adaptive pool3d:
.. math::
dstart &= floor(i * D_{in} / D_{out})
dend &= ceil((i + 1) * D_{in} / D_{out})
hstart &= floor(j * H_{in} / H_{out})
hend &= ceil((j + 1) * H_{in} / H_{out})
wstart &= floor(k * W_{in} / W_{out})
wend &= ceil((k + 1) * W_{in} / W_{out})
        Output(i, j, k) &= \\frac{sum(Input[dstart:dend, hstart:hend, wstart:wend])}{(dend - dstart) * (hend - hstart) * (wend - wstart)}
Args:
x (Tensor): The input tensor of adaptive avg pool3d operator, which is a 5-D tensor.
The data type can be float16, float32, float64, int32 or int64.
        output_size (int|list|tuple): The pool kernel size. If the pool kernel size is a tuple or list,
            it must contain three elements, (D, H, W). D, H and W can each be an int, or None, which means
            the size will be the same as that of the input.
data_format (str): The data format of the input and output data. An optional string
from: "NCDHW", "NDHWC". The default is "NCDHW". When it is "NCDHW", the data is stored in
the order of: [batch_size, input_channels, input_depth, input_height, input_width].
        name (str, optional): For detailed information, please refer
            to :ref:`api_guide_Name`. Usually name does not need to be set and
            is None by default.
Returns:
        Tensor: The output tensor of the adaptive average pooling result. The data type is the same as the input tensor.
Raises:
ValueError: If `data_format` is not "NCDHW" or "NDHWC".
Examples:
.. code-block:: python
            # adaptive avg pool3d
            # suppose the input data is in the shape of [N, C, D, H, W] and `output_size` is [l, m, n];
            # the output shape is [N, C, l, m, n]. Adaptive pooling divides the D, H and W dimensions
            # of the input data into an l * m * n grid and performs average pooling in each
            # grid cell to get the output.
            # adaptive avg pool performs the following calculations:
#
# for i in range(l):
# for j in range(m):
# for k in range(n):
# dstart = floor(i * D / l)
# dend = ceil((i + 1) * D / l)
# hstart = floor(j * H / m)
# hend = ceil((j + 1) * H / m)
# wstart = floor(k * W / n)
# wend = ceil((k + 1) * W / n)
# output[:, :, i, j, k] =
# avg(input[:, :, dstart:dend, hstart: hend, wstart: wend])
import paddle
import numpy as np
paddle.disable_static()
input_data = np.random.rand(2, 3, 8, 32, 32)
x = paddle.to_tensor(input_data)
# x.shape is [2, 3, 8, 32, 32]
pool_out = paddle.nn.functional.adaptive_avg_pool3d(
x = x,
output_size=[3, 3, 3])
# pool_out.shape is [2, 3, 3, 3, 3]
"""
if not in_dygraph_mode():
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'],
'adaptive_avg_pool3d')
check_type(data_format, 'data_format', str, 'adaptive_avg_pool3d')
if data_format not in ["NCDHW", "NDHWC"]:
raise ValueError(
"Attr(data_format) should be 'NCDHW' or 'NDHWC'. Received "
"Attr(data_format): %s." % str(data_format))
if data_format == "NCDHW":
in_l, in_h, in_w = x.shape[2:5]
else:
in_l, in_h, in_w = x.shape[1:4]
if isinstance(output_size, int):
output_size = utils.convert_to_list(output_size, 3, 'output_size')
else:
        if output_size[0] is None:
            output_size[0] = in_l
        if output_size[1] is None:
            output_size[1] = in_h
        if output_size[2] is None:
            output_size[2] = in_w
if in_dygraph_mode():
output = core.ops.pool3d(x, 'pooling_type', 'avg', 'ksize', output_size,
'global_pooling', False, 'adaptive', True,
'data_format', data_format)
return output
l_type = 'pool3d'
helper = LayerHelper(l_type, **locals())
dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype)
outputs = {"Out": pool_out}
helper.append_op(
type=l_type,
inputs={"X": x},
outputs=outputs,
attrs={
"pooling_type": "avg",
"ksize": output_size,
"adaptive": True,
"data_format": data_format,
})
return pool_out
@@ -52,6 +52,8 @@ from .common import Embedding #DEFINE_ALIAS
from .common import Linear #DEFINE_ALIAS
from .common import Flatten #DEFINE_ALIAS
from .common import UpSample #DEFINE_ALIAS
from .pooling import AdaptiveAvgPool2d #DEFINE_ALIAS
from .pooling import AdaptiveAvgPool3d #DEFINE_ALIAS
from .conv import Conv2D #DEFINE_ALIAS
from .conv import Conv2DTranspose #DEFINE_ALIAS
from .conv import Conv3D #DEFINE_ALIAS
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from ...fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
from ...fluid.layers import utils
from ...fluid.dygraph import layers
from ...fluid.layer_helper import LayerHelper
from .. import functional as F
__all__ = [
'AdaptiveAvgPool2d',
'AdaptiveAvgPool3d',
]
class AdaptiveAvgPool2d(layers.Layer):
"""
    This operation applies 2D adaptive average pooling over the input tensor. The H and W dimensions
    of the output tensor are determined by the parameter output_size.
For avg adaptive pool2d:
.. math::
hstart &= floor(i * H_{in} / H_{out})
hend &= ceil((i + 1) * H_{in} / H_{out})
wstart &= floor(j * W_{in} / W_{out})
wend &= ceil((j + 1) * W_{in} / W_{out})
        Output(i, j) &= \\frac{sum(Input[hstart:hend, wstart:wend])}{(hend - hstart) * (wend - wstart)}
Parameters:
        output_size (int|list|tuple): The pool kernel size. If the pool kernel size is a tuple or list,
            it must contain two elements, (H, W). H and W can each be an int, or None, which means
            the size will be the same as that of the input.
data_format (str): The data format of the input and output data. An optional string
from: "NCHW", "NHWC". The default is "NCHW". When it is "NCHW", the data is stored in
the order of: [batch_size, input_channels, input_height, input_width].
        name (str, optional): For detailed information, please refer
            to :ref:`api_guide_Name`. Usually name does not need to be set and
            is None by default.
Shape:
x (Tensor): The input tensor of adaptive avg pool2d operator, which is a 4-D tensor. The data type can be float16, float32, float64, int32 or int64.
output (Tensor): The output tensor of adaptive avg pool2d operator, which is a 4-D tensor. The data type is same as input x.
Returns:
A callable object of AdaptiveAvgPool2d.
Examples:
.. code-block:: python
            # adaptive avg pool2d
            # suppose the input data is in the shape of [N, C, H, W] and `output_size` is [m, n];
            # the output shape is [N, C, m, n]. Adaptive pooling divides the H and W dimensions
            # of the input data into an m * n grid and performs average pooling in each
            # grid cell to get the output.
            # adaptive avg pool performs the following calculations:
            #
            # for i in range(m):
            #     for j in range(n):
            #         hstart = floor(i * H / m)
            #         hend = ceil((i + 1) * H / m)
            #         wstart = floor(j * W / n)
            #         wend = ceil((j + 1) * W / n)
            #         output[:, :, i, j] = avg(input[:, :, hstart: hend, wstart: wend])
#
import paddle
import numpy as np
paddle.disable_static()
input_data = np.random.rand(2, 3, 32, 32)
x = paddle.to_tensor(input_data)
# x.shape is [2, 3, 32, 32]
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool2d(output_size=3)
pool_out = adaptive_avg_pool(x = x)
# pool_out.shape is [2, 3, 3, 3]
"""
def __init__(self, output_size, data_format="NCHW", name=None):
super(AdaptiveAvgPool2d, self).__init__()
self._output_size = output_size
self._data_format = data_format
self._name = name
def forward(self, x):
return F.adaptive_avg_pool2d(
x,
output_size=self._output_size,
data_format=self._data_format,
name=self._name)
class AdaptiveAvgPool3d(layers.Layer):
"""
    This operation applies 3D adaptive average pooling over the input tensor. The D, H and W dimensions
    of the output tensor are determined by the parameter output_size.
For avg adaptive pool3d:
.. math::
dstart &= floor(i * D_{in} / D_{out})
dend &= ceil((i + 1) * D_{in} / D_{out})
hstart &= floor(j * H_{in} / H_{out})
hend &= ceil((j + 1) * H_{in} / H_{out})
wstart &= floor(k * W_{in} / W_{out})
wend &= ceil((k + 1) * W_{in} / W_{out})
        Output(i, j, k) &= \\frac{sum(Input[dstart:dend, hstart:hend, wstart:wend])}{(dend - dstart) * (hend - hstart) * (wend - wstart)}
Parameters:
        output_size (int|list|tuple): The pool kernel size. If the pool kernel size is a tuple or list,
            it must contain three elements, (D, H, W). D, H and W can each be an int, or None, which means
            the size will be the same as that of the input.
data_format (str): The data format of the input and output data. An optional string
from: "NCDHW", "NDHWC". The default is "NCDHW". When it is "NCDHW", the data is stored in
the order of: [batch_size, input_channels, input_depth, input_height, input_width].
        name (str, optional): For detailed information, please refer
            to :ref:`api_guide_Name`. Usually name does not need to be set and
            is None by default.
Shape:
x (Tensor): The input tensor of adaptive avg pool3d operator, which is a 5-D tensor. The data type can be float16, float32, float64, int32 or int64.
output (Tensor): The output tensor of adaptive avg pool3d operator, which is a 5-D tensor. The data type is same as input x.
Returns:
A callable object of AdaptiveAvgPool3d.
Examples:
.. code-block:: python
            # adaptive avg pool3d
            # suppose the input data is in the shape of [N, C, D, H, W] and `output_size` is [l, m, n];
            # the output shape is [N, C, l, m, n]. Adaptive pooling divides the D, H and W dimensions
            # of the input data into an l * m * n grid and performs average pooling in each
            # grid cell to get the output.
            # adaptive avg pool performs the following calculations:
#
# for i in range(l):
# for j in range(m):
# for k in range(n):
# dstart = floor(i * D / l)
# dend = ceil((i + 1) * D / l)
# hstart = floor(j * H / m)
# hend = ceil((j + 1) * H / m)
# wstart = floor(k * W / n)
# wend = ceil((k + 1) * W / n)
# output[:, :, i, j, k] =
# avg(input[:, :, dstart:dend, hstart: hend, wstart: wend])
import paddle
import numpy as np
paddle.disable_static()
input_data = np.random.rand(2, 3, 8, 32, 32)
x = paddle.to_tensor(input_data)
# x.shape is [2, 3, 8, 32, 32]
adaptive_avg_pool = paddle.nn.AdaptiveAvgPool3d(output_size=3)
pool_out = adaptive_avg_pool(x = x)
            # pool_out.shape is [2, 3, 3, 3, 3]
"""
def __init__(self, output_size, data_format="NCDHW", name=None):
super(AdaptiveAvgPool3d, self).__init__()
self._output_size = output_size
self._data_format = data_format
self._name = name
def forward(self, x):
return F.adaptive_avg_pool3d(
x,
output_size=self._output_size,
data_format=self._data_format,
name=self._name)
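For reference, the Layer classes above are thin wrappers over the functional API, so the two call styles give identical results. A minimal sketch (assuming the same 2.0-beta imperative mode as the tests, with None entries keeping the corresponding input size):

import numpy as np
import paddle

paddle.disable_static()
x = paddle.to_tensor(np.random.rand(2, 3, 8, 32, 32).astype("float32"))

out_fn = paddle.nn.functional.adaptive_avg_pool3d(x, output_size=[None, 3, None])
out_layer = paddle.nn.AdaptiveAvgPool3d(output_size=[None, 3, None])(x)

assert np.allclose(out_fn.numpy(), out_layer.numpy())
# Both keep D and W (None -> input size) and pool H down to 3: shape [2, 3, 8, 3, 32]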