未验证 提交 940e673b 编写于 作者: D Double_V 提交者: GitHub

add pool unittest (#26949) (#27097)

上级 e7ea2166
...@@ -143,6 +143,27 @@ class TestPool1d_API(unittest.TestCase): ...@@ -143,6 +143,27 @@ class TestPool1d_API(unittest.TestCase):
result = avg_pool1d_dg(input) result = avg_pool1d_dg(input)
self.assertTrue(np.allclose(result.numpy(), result_np)) self.assertTrue(np.allclose(result.numpy(), result_np))
def check_avg_dygraph_padding_results(self, place):
with fluid.dygraph.guard(place):
input_np = np.random.random([2, 3, 32]).astype("float32")
input = fluid.dygraph.to_variable(input_np)
result = F.avg_pool1d(
input,
kernel_size=2,
stride=2,
padding=[1],
count_include_pad=True)
result_np = avg_pool1D_forward_naive(
input_np, ksize=[2], strides=[2], paddings=[1], exclusive=False)
self.assertTrue(np.allclose(result.numpy(), result_np))
avg_pool1d_dg = paddle.nn.AvgPool1d(
kernel_size=2, stride=None, padding=1, count_include_pad=True)
result = avg_pool1d_dg(input)
self.assertTrue(np.allclose(result.numpy(), result_np))
def check_max_static_results(self, place): def check_max_static_results(self, place):
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
input = fluid.data(name="input", shape=[2, 3, 32], dtype="float32") input = fluid.data(name="input", shape=[2, 3, 32], dtype="float32")
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from test_pool2d_op import adaptive_start_index, adaptive_end_index, pool2D_forward_naive from test_pool2d_op import adaptive_start_index, adaptive_end_index, pool2D_forward_naive, avg_pool2D_forward_naive, max_pool2D_forward_naive
import unittest import unittest
from op_test import OpTest from op_test import OpTest
import numpy as np import numpy as np
...@@ -68,6 +68,47 @@ class TestPool2d_API(unittest.TestCase): ...@@ -68,6 +68,47 @@ class TestPool2d_API(unittest.TestCase):
result = avg_pool2d_dg(input) result = avg_pool2d_dg(input)
self.assertTrue(np.allclose(result.numpy(), result_np)) self.assertTrue(np.allclose(result.numpy(), result_np))
def check_avg_dygraph_padding_results(self, place):
with fluid.dygraph.guard(place):
input_np = np.random.random([2, 3, 32, 32]).astype("float32")
input = fluid.dygraph.to_variable(input_np)
result = avg_pool2d(
input, kernel_size=2, stride=2, padding=1, ceil_mode=False)
result_np = avg_pool2D_forward_naive(
input_np,
ksize=[2, 2],
strides=[2, 2],
paddings=[1, 1],
ceil_mode=False,
exclusive=False)
self.assertTrue(np.allclose(result.numpy(), result_np))
avg_pool2d_dg = paddle.nn.layer.AvgPool2d(
kernel_size=2, stride=2, padding=1, ceil_mode=False)
result = avg_pool2d_dg(input)
self.assertTrue(np.allclose(result.numpy(), result_np))
def check_avg_dygraph_ceilmode_results(self, place):
with fluid.dygraph.guard(place):
input_np = np.random.random([2, 3, 32, 32]).astype("float32")
input = fluid.dygraph.to_variable(input_np)
result = avg_pool2d(
input, kernel_size=2, stride=2, padding=0, ceil_mode=True)
result_np = avg_pool2D_forward_naive(
input_np,
ksize=[2, 2],
strides=[2, 2],
paddings=[0, 0],
ceil_mode=True)
self.assertTrue(np.allclose(result.numpy(), result_np))
avg_pool2d_dg = paddle.nn.layer.AvgPool2d(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
result = avg_pool2d_dg(input)
self.assertTrue(np.allclose(result.numpy(), result_np))
def check_max_static_results(self, place): def check_max_static_results(self, place):
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
input = fluid.data( input = fluid.data(
...@@ -108,6 +149,70 @@ class TestPool2d_API(unittest.TestCase): ...@@ -108,6 +149,70 @@ class TestPool2d_API(unittest.TestCase):
result = max_pool2d_dg(input) result = max_pool2d_dg(input)
self.assertTrue(np.allclose(result.numpy(), result_np)) self.assertTrue(np.allclose(result.numpy(), result_np))
def check_max_dygraph_nhwc_results(self, place):
with fluid.dygraph.guard(place):
input_np = np.random.random([2, 3, 32, 32]).astype("float32")
input = fluid.dygraph.to_variable(
np.transpose(input_np, [0, 2, 3, 1]))
result = max_pool2d(
input,
kernel_size=2,
stride=2,
padding=0,
return_indices=False,
data_format="NHWC")
result_np = pool2D_forward_naive(
input_np,
ksize=[2, 2],
strides=[2, 2],
paddings=[0, 0],
pool_type='max')
self.assertTrue(
np.allclose(
np.transpose(result.numpy(), [0, 3, 1, 2]), result_np))
def check_max_dygraph_padding_results(self, place):
with fluid.dygraph.guard(place):
input_np = np.random.random([2, 3, 32, 32]).astype("float32")
input = fluid.dygraph.to_variable(input_np)
result = max_pool2d(
input, kernel_size=2, stride=2, padding=1, ceil_mode=False)
result_np = max_pool2D_forward_naive(
input_np,
ksize=[2, 2],
strides=[2, 2],
paddings=[1, 1],
ceil_mode=False,
exclusive=False)
self.assertTrue(np.allclose(result.numpy(), result_np))
max_pool2d_dg = paddle.nn.layer.MaxPool2d(
kernel_size=2, stride=2, padding=1, ceil_mode=False)
result = max_pool2d_dg(input)
self.assertTrue(np.allclose(result.numpy(), result_np))
def check_max_dygraph_ceilmode_results(self, place):
with fluid.dygraph.guard(place):
input_np = np.random.random([2, 3, 32, 32]).astype("float32")
input = fluid.dygraph.to_variable(input_np)
result = max_pool2d(
input, kernel_size=2, stride=2, padding=0, ceil_mode=True)
result_np = max_pool2D_forward_naive(
input_np,
ksize=[2, 2],
strides=[2, 2],
paddings=[0, 0],
ceil_mode=True)
self.assertTrue(np.allclose(result.numpy(), result_np))
max_pool2d_dg = paddle.nn.layer.MaxPool2d(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
result = max_pool2d_dg(input)
self.assertTrue(np.allclose(result.numpy(), result_np))
def check_max_dygraph_stride_is_none(self, place): def check_max_dygraph_stride_is_none(self, place):
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
input_np = np.random.random([2, 3, 32, 32]).astype("float32") input_np = np.random.random([2, 3, 32, 32]).astype("float32")
...@@ -215,6 +320,9 @@ class TestPool2d_API(unittest.TestCase): ...@@ -215,6 +320,9 @@ class TestPool2d_API(unittest.TestCase):
self.check_avg_dygraph_stride_is_none(place) self.check_avg_dygraph_stride_is_none(place)
self.check_max_dygraph_padding(place) self.check_max_dygraph_padding(place)
self.check_avg_divisor(place) self.check_avg_divisor(place)
self.check_max_dygraph_padding_results(place)
self.check_max_dygraph_ceilmode_results(place)
self.check_max_dygraph_nhwc_results(place)
class TestPool2dError_API(unittest.TestCase): class TestPool2dError_API(unittest.TestCase):
...@@ -370,6 +478,22 @@ class TestPool2dError_API(unittest.TestCase): ...@@ -370,6 +478,22 @@ class TestPool2dError_API(unittest.TestCase):
self.assertRaises(ValueError, run8) self.assertRaises(ValueError, run8)
def run9():
with fluid.dygraph.guard():
input_np = np.random.uniform(-1, 1,
[2, 3, 32, 32]).astype(np.float32)
input_pd = fluid.dygraph.to_variable(input_np)
res_pd = max_pool2d(
input_pd,
kernel_size=2,
stride=2,
padding=0,
ceil_mode=False,
data_format='NHWC',
return_indices=True)
self.assertRaises(ValueError, run9)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -22,7 +22,7 @@ import paddle.fluid.core as core ...@@ -22,7 +22,7 @@ import paddle.fluid.core as core
from op_test import OpTest from op_test import OpTest
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.nn.functional import avg_pool3d, max_pool3d from paddle.nn.functional import avg_pool3d, max_pool3d
from test_pool3d_op import adaptive_start_index, adaptive_end_index, pool3D_forward_naive from test_pool3d_op import adaptive_start_index, adaptive_end_index, pool3D_forward_naive, avg_pool3D_forward_naive, max_pool3D_forward_naive
class TestPool3d_API(unittest.TestCase): class TestPool3d_API(unittest.TestCase):
...@@ -73,6 +73,58 @@ class TestPool3d_API(unittest.TestCase): ...@@ -73,6 +73,58 @@ class TestPool3d_API(unittest.TestCase):
result = avg_pool3d_dg(input) result = avg_pool3d_dg(input)
self.assertTrue(np.allclose(result.numpy(), result_np)) self.assertTrue(np.allclose(result.numpy(), result_np))
def check_avg_dygraph_padding_results(self, place):
with fluid.dygraph.guard(place):
input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32")
input = fluid.dygraph.to_variable(input_np)
result = avg_pool3d(
input,
kernel_size=2,
stride=2,
padding=1,
ceil_mode=False,
count_include_pad=True)
result_np = avg_pool3D_forward_naive(
input_np,
ksize=[2, 2, 2],
strides=[2, 2, 2],
paddings=[1, 1, 1],
ceil_mode=False,
exclusive=False)
self.assertTrue(np.allclose(result.numpy(), result_np))
avg_pool3d_dg = paddle.nn.layer.AvgPool3d(
kernel_size=2,
stride=None,
padding=1,
ceil_mode=False,
count_include_pad=True)
result = avg_pool3d_dg(input)
self.assertTrue(np.allclose(result.numpy(), result_np))
def check_avg_dygraph_ceilmode_results(self, place):
with fluid.dygraph.guard(place):
input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32")
input = fluid.dygraph.to_variable(input_np)
result = avg_pool3d(
input, kernel_size=2, stride=2, padding=0, ceil_mode=True)
result_np = avg_pool3D_forward_naive(
input_np,
ksize=[2, 2, 2],
strides=[2, 2, 2],
paddings=[0, 0, 0],
ceil_mode=True)
self.assertTrue(np.allclose(result.numpy(), result_np))
avg_pool3d_dg = paddle.nn.layer.AvgPool3d(
kernel_size=2, stride=None, padding=0, ceil_mode=True)
result = avg_pool3d_dg(input)
self.assertTrue(np.allclose(result.numpy(), result_np))
def check_max_static_results(self, place): def check_max_static_results(self, place):
with fluid.program_guard(fluid.Program(), fluid.Program()): with fluid.program_guard(fluid.Program(), fluid.Program()):
input = fluid.data( input = fluid.data(
...@@ -112,6 +164,74 @@ class TestPool3d_API(unittest.TestCase): ...@@ -112,6 +164,74 @@ class TestPool3d_API(unittest.TestCase):
result = max_pool3d_dg(input) result = max_pool3d_dg(input)
self.assertTrue(np.allclose(result.numpy(), result_np)) self.assertTrue(np.allclose(result.numpy(), result_np))
def check_max_dygraph_ndhwc_results(self, place):
print("run ndchw max pool3d")
with fluid.dygraph.guard(place):
input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32")
input = fluid.dygraph.to_variable(
np.transpose(input_np, [0, 2, 3, 4, 1]))
result = max_pool3d(
input,
kernel_size=2,
stride=2,
padding=0,
data_format="NDHWC",
return_indices=False)
result_np = pool3D_forward_naive(
input_np,
ksize=[2, 2, 2],
strides=[2, 2, 2],
paddings=[0, 0, 0],
pool_type='max')
self.assertTrue(
np.allclose(
np.transpose(result.numpy(), [0, 4, 1, 2, 3]), result_np))
def check_max_dygraph_ceilmode_results(self, place):
print("run ceil mode max pool3d")
with fluid.dygraph.guard(place):
input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32")
input = fluid.dygraph.to_variable(input_np)
result = max_pool3d(
input, kernel_size=2, stride=2, padding=0, ceil_mode=True)
result_np = max_pool3D_forward_naive(
input_np,
ksize=[2, 2, 2],
strides=[2, 2, 2],
paddings=[0, 0, 0],
ceil_mode=True)
self.assertTrue(np.allclose(result.numpy(), result_np))
max_pool3d_dg = paddle.nn.layer.MaxPool3d(
kernel_size=2, stride=None, padding=0, ceil_mode=True)
result = max_pool3d_dg(input)
self.assertTrue(np.allclose(result.numpy(), result_np))
def check_max_dygraph_padding_results(self, place):
with fluid.dygraph.guard(place):
input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32")
input = fluid.dygraph.to_variable(input_np)
result = max_pool3d(
input, kernel_size=2, stride=2, padding=1, ceil_mode=False)
result_np = max_pool3D_forward_naive(
input_np,
ksize=[2, 2, 2],
strides=[2, 2, 2],
paddings=[1, 1, 1],
ceil_mode=False)
self.assertTrue(np.allclose(result.numpy(), result_np))
max_pool3d_dg = paddle.nn.layer.MaxPool3d(
kernel_size=2, stride=None, padding=1, ceil_mode=False)
result = max_pool3d_dg(input)
self.assertTrue(np.allclose(result.numpy(), result_np))
def check_max_dygraph_stride_is_none(self, place): def check_max_dygraph_stride_is_none(self, place):
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32") input_np = np.random.random([2, 3, 32, 32, 32]).astype("float32")
...@@ -205,6 +325,8 @@ class TestPool3d_API(unittest.TestCase): ...@@ -205,6 +325,8 @@ class TestPool3d_API(unittest.TestCase):
self.check_max_dygraph_stride_is_none(place) self.check_max_dygraph_stride_is_none(place)
self.check_max_dygraph_padding(place) self.check_max_dygraph_padding(place)
self.check_avg_divisor(place) self.check_avg_divisor(place)
self.check_max_dygraph_ndhwc_results(place)
self.check_max_dygraph_ceilmode_results(place)
class TestPool3dError_API(unittest.TestCase): class TestPool3dError_API(unittest.TestCase):
...@@ -336,6 +458,21 @@ class TestPool3dError_API(unittest.TestCase): ...@@ -336,6 +458,21 @@ class TestPool3dError_API(unittest.TestCase):
self.assertRaises(ValueError, run9) self.assertRaises(ValueError, run9)
def run10():
with fluid.dygraph.guard():
input_np = np.random.uniform(
-1, 1, [2, 3, 32, 32, 32]).astype(np.float32)
input_pd = fluid.dygraph.to_variable(input_np)
res_pd = max_pool3d(
input_pd,
kernel_size=2,
stride=2,
padding=0,
data_format='NDHWC',
return_indices=True)
self.assertRaises(ValueError, run10)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -168,7 +168,7 @@ def avg_pool1d(x, ...@@ -168,7 +168,7 @@ def avg_pool1d(x,
count_include_pad=True, count_include_pad=True,
ceil_mode=False, ceil_mode=False,
name=None): name=None):
""" """
This API implements average pooling 1d operation, This API implements average pooling 1d operation,
See more details in :ref:`api_nn_pooling_AvgPool1d` . See more details in :ref:`api_nn_pooling_AvgPool1d` .
...@@ -280,7 +280,7 @@ def avg_pool2d(x, ...@@ -280,7 +280,7 @@ def avg_pool2d(x,
""" """
This API implements average pooling 2d operation. This API implements average pooling 2d operation.
See more details in :ref:`api_nn_pooling_AvgPool2d` . See more details in :ref:`api_nn_pooling_AvgPool2d` .
Args: Args:
x (Tensor): The input tensor of pooling operator which is a 4-D tensor with x (Tensor): The input tensor of pooling operator which is a 4-D tensor with
shape [N, C, H, W]. The format of input tensor is `"NCHW"` or shape [N, C, H, W]. The format of input tensor is `"NCHW"` or
...@@ -640,7 +640,7 @@ def max_pool2d(x, ...@@ -640,7 +640,7 @@ def max_pool2d(x,
5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension should be [0,0] or (0,0). 5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension should be [0,0] or (0,0).
The default value is 0. The default value is 0.
ceil_mode (bool): when True, will use `ceil` instead of `floor` to compute the output shape ceil_mode (bool): when True, will use `ceil` instead of `floor` to compute the output shape
return_indices (bool): Whether to return the max indices along with the outputs. return_indices (bool): Whether to return the max indices along with the outputs. Default False, only support `"NCHW"` data format
data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NHWC"`. data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NHWC"`.
The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of: The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_height, input_width]`. `[batch_size, input_channels, input_height, input_width]`.
...@@ -690,15 +690,30 @@ def max_pool2d(x, ...@@ -690,15 +690,30 @@ def max_pool2d(x,
padding, padding_algorithm = _update_padding_nd( padding, padding_algorithm = _update_padding_nd(
padding, num_dims=2, channel_last=channel_last, ceil_mode=ceil_mode) padding, num_dims=2, channel_last=channel_last, ceil_mode=ceil_mode)
if data_format == "NHWC" and return_indices:
raise ValueError(
"When setting return_indices to true, data_format must be set to NCHW in API:max_pool2d"
)
if in_dygraph_mode(): if in_dygraph_mode():
output = core.ops.max_pool2d_with_index( if data_format == "NCHW":
x, 'ksize', kernel_size, 'global_pooling', False, 'strides', stride, output = core.ops.max_pool2d_with_index(
'paddings', padding, 'padding_algorithm', padding_algorithm, x, 'ksize', kernel_size, 'global_pooling', False, 'strides',
'use_cudnn', True, 'ceil_mode', ceil_mode, 'use_mkldnn', False, stride, 'paddings', padding, 'padding_algorithm',
'exclusive', True, 'data_format', data_format) padding_algorithm, 'use_cudnn', True, 'ceil_mode', ceil_mode,
return output if return_indices else output[0] 'use_mkldnn', False, 'exclusive', True, 'data_format',
data_format)
return output if return_indices else output[0]
elif data_format == "NHWC" and not return_indices:
output = core.ops.pool2d(
x, 'pooling_type', 'max', 'ksize', kernel_size,
'global_pooling', False, 'padding_algorithm', padding_algorithm,
'strides', stride, 'paddings', padding, 'use_cudnn', True,
'ceil_mode', ceil_mode, 'use_mkldnn', False, 'exclusive', True,
'data_format', data_format)
return output
op_type = 'max_pool2d_with_index' op_type = 'max_pool2d_with_index' if data_format == "NCHW" else "max_pool2d"
helper = LayerHelper(op_type, **locals()) helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype) pool_out = helper.create_variable_for_type_inference(dtype)
...@@ -739,7 +754,7 @@ def max_pool3d(x, ...@@ -739,7 +754,7 @@ def max_pool3d(x,
See more details in :ref:`api_nn_pooling_MaxPool3d` . See more details in :ref:`api_nn_pooling_MaxPool3d` .
Args: Args:
x (Tensor): The input tensor of pooling operator, which is a 5-D tensor with x (Tensor): The input tensor of pooling operator, which is a 5-D tensor with
shape [N, C, D, H, W]. The format of input tensor is `"NCDHW"` or `"NDHWC"`, where N represents batch size, C represents the number of channels, D, H and W represent the depth, height and width of the feature respectively. shape [N, C, D, H, W]. The format of input tensor is `"NCDHW"` or `"NDHWC"`, where N represents batch size, C represents the number of channels, D, H and W represent the depth, height and width of the feature respectively.
kernel_size (int|list|tuple): The pool kernel size. If the kernel size kernel_size (int|list|tuple): The pool kernel size. If the kernel size
is a tuple or list, it must contain three integers, is a tuple or list, it must contain three integers,
(kernel_size_Depth, kernel_size_Height, kernel_size_Width). (kernel_size_Depth, kernel_size_Height, kernel_size_Width).
...@@ -755,7 +770,7 @@ def max_pool3d(x, ...@@ -755,7 +770,7 @@ def max_pool3d(x,
5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension should be [0,0] or (0,0). 5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension should be [0,0] or (0,0).
The default value is 0. The default value is 0.
ceil_mode (bool): ${ceil_mode_comment} ceil_mode (bool): ${ceil_mode_comment}
return_indices (bool): Whether to return the max indices along with the outputs. return_indices (bool): Whether to return the max indices along with the outputs. Default False. Only support "NDCHW" data_format.
data_format (string): The data format of the input and output data. An optional string from: `"NCDHW"`, `"NDHWC"`. data_format (string): The data format of the input and output data. An optional string from: `"NCDHW"`, `"NDHWC"`.
The default is `"NCDHW"`. When it is `"NCDHW"`, the data is stored in the order of: The default is `"NCDHW"`. When it is `"NCDHW"`, the data is stored in the order of:
`[batch_size, input_channels, input_depth, input_height, input_width]`. `[batch_size, input_channels, input_depth, input_height, input_width]`.
...@@ -801,15 +816,30 @@ def max_pool3d(x, ...@@ -801,15 +816,30 @@ def max_pool3d(x,
padding, padding_algorithm = _update_padding_nd( padding, padding_algorithm = _update_padding_nd(
padding, 3, channel_last=channel_last, ceil_mode=ceil_mode) padding, 3, channel_last=channel_last, ceil_mode=ceil_mode)
if data_format == "NDHWC" and return_indices:
raise ValueError(
"When setting return_indices to true, data_format must be set to NCDHW in API:max_pool3d"
)
if in_dygraph_mode(): if in_dygraph_mode():
output = core.ops.max_pool3d_with_index( if data_format == "NCDHW":
x, 'pooling_type', 'max', 'ksize', kernel_size, 'strides', stride, output = core.ops.max_pool3d_with_index(
'paddings', padding, 'global_pooling', False, 'padding_algorithm', x, 'pooling_type', 'max', 'ksize', kernel_size, 'strides',
padding_algorithm, 'use_cudnn', True, 'ceil_mode', ceil_mode, stride, 'paddings', padding, 'global_pooling', False,
'use_mkldnn', False, 'exclusive', True, 'data_format', data_format) 'padding_algorithm', padding_algorithm, 'use_cudnn', True,
return output if return_indices else output[0] 'ceil_mode', ceil_mode, 'use_mkldnn', False, 'exclusive', True,
'data_format', data_format)
return output if return_indices else output[0]
elif data_format == "NDHWC" and not return_indices:
output = core.ops.pool3d(
x, 'pooling_type', 'max', 'ksize', kernel_size,
'global_pooling', False, 'padding_algorithm', padding_algorithm,
'strides', stride, 'paddings', padding, 'use_cudnn', True,
'ceil_mode', ceil_mode, 'use_mkldnn', False, 'exclusive', True,
'data_format', data_format)
return output
op_type = "max_pool3d_with_index" op_type = "max_pool3d_with_index" if data_format == "NCDHW" else "max_pool3d"
helper = LayerHelper(op_type, **locals()) helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
pool_out = helper.create_variable_for_type_inference(dtype) pool_out = helper.create_variable_for_type_inference(dtype)
...@@ -841,7 +871,7 @@ def adaptive_avg_pool1d(x, output_size, name=None): ...@@ -841,7 +871,7 @@ def adaptive_avg_pool1d(x, output_size, name=None):
""" """
This API implements adaptive average pooling 1d operation. This API implements adaptive average pooling 1d operation.
See more details in :ref:`api_nn_pooling_AdaptiveAvgPool1d` . See more details in :ref:`api_nn_pooling_AdaptiveAvgPool1d` .
Args: Args:
x (Tensor): The input tensor of pooling operator, which is a 3-D tensor x (Tensor): The input tensor of pooling operator, which is a 3-D tensor
with shape [N, C, L]. The format of input tensor is NCL, with shape [N, C, L]. The format of input tensor is NCL,
......
...@@ -850,7 +850,7 @@ class AdaptiveMaxPool1d(layers.Layer): ...@@ -850,7 +850,7 @@ class AdaptiveMaxPool1d(layers.Layer):
lend &= ceil((i + 1) * L_{in} / L_{out}) lend &= ceil((i + 1) * L_{in} / L_{out})
Output(i) &= max(Input[lstart:lend])} Output(i) &= max(Input[lstart:lend])
Args: Args:
output_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, output_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list,
...@@ -932,7 +932,7 @@ class AdaptiveMaxPool2d(layers.Layer): ...@@ -932,7 +932,7 @@ class AdaptiveMaxPool2d(layers.Layer):
Shape: Shape:
x (Tensor): The input tensor of adaptive max pool2d operator, which is a 4-D tensor. The data type can be float32, float64. x (Tensor): The input tensor of adaptive max pool2d operator, which is a 4-D tensor. The data type can be float32, float64.
output (Tensor): The output tensor of adaptive max pool2d operator, which is a 4-D tensor. The data type is same as input x. output (Tensor): The output tensor of adaptive max pool2d operator, which is a 4-D tensor. The data type is same as input x.
Returns: Returns:
A callable object of AdaptiveMaxPool2d. A callable object of AdaptiveMaxPool2d.
Examples: Examples:
...@@ -1032,7 +1032,7 @@ class AdaptiveMaxPool3d(layers.Layer): ...@@ -1032,7 +1032,7 @@ class AdaptiveMaxPool3d(layers.Layer):
pool, indices = paddle.nn.AdaptiveMaxPool3d(output_size=3, return_indices=True) pool, indices = paddle.nn.AdaptiveMaxPool3d(output_size=3, return_indices=True)
out = pool(x) out = pool(x)
# out shape: [2, 3, 4, 4, 4], indices shape: [2, 3, 4, 4, 4] # out shape: [2, 3, 4, 4, 4], indices shape: [2, 3, 4, 4, 4]
""" """
def __init__(self, output_size, return_indices=False, name=None): def __init__(self, output_size, return_indices=False, name=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册