diff --git a/doc/fluid/api_cn/dygraph_cn/Pool2D_cn.rst b/doc/fluid/api_cn/dygraph_cn/Pool2D_cn.rst index fb47ba8ab256f2895f2d85b803e76d833b70159e..e66ec6b3237edbe73446be147aef39efe3cb66a8 100644 --- a/doc/fluid/api_cn/dygraph_cn/Pool2D_cn.rst +++ b/doc/fluid/api_cn/dygraph_cn/Pool2D_cn.rst @@ -3,7 +3,7 @@ Pool2D ------------------------------- -.. py:class:: paddle.fluid.dygraph.Pool2D(pool_size=-1, pool_type='max', pool_stride=1, pool_padding=0, global_pooling=False, use_cudnn=True, ceil_mode=False, exclusive=True) +.. py:class:: paddle.fluid.dygraph.Pool2D(pool_size=-1, pool_type='max', pool_stride=1, pool_padding=0, global_pooling=False, use_cudnn=True, ceil_mode=False, exclusive=True, data_format="NCHW") :alias_main: paddle.nn.Pool2D :alias: paddle.nn.Pool2D,paddle.nn.layer.Pool2D,paddle.nn.layer.common.Pool2D @@ -13,7 +13,7 @@ Pool2D 该接口用于构建 ``Pool2D`` 类的一个可调用对象,具体用法参照 ``代码示例`` 。其将在神经网络中构建一个二维池化层,并使用上述输入参数的池化配置,为二维空间池化操作,根据 ``input`` , 池化类型 ``pool_type`` , 池化核大小 ``pool_size`` , 步长 ``pool_stride`` ,填充 ``pool_padding`` 这些参数得到输出。 -输入X和输出Out是NCHW格式,N为批大小,C是通道数,H是特征高度,W是特征宽度。参数( ``ksize``, ``strides``, ``paddings`` )含有两个整型元素。分别表示高度和宽度上的参数。输入X的大小和输出Out的大小可能不一致。 +输入X和输出Out默认是NCHW格式,N为批大小,C是通道数,H是特征高度,W是特征宽度。参数( ``ksize``, ``strides``, ``paddings`` )含有两个整型元素。分别表示高度和宽度上的参数。输入X的大小和输出Out的大小可能不一致。 例如: @@ -66,13 +66,15 @@ Pool2D - **use_cudnn** (bool, 可选)- 是否用cudnn核,只有已安装cudnn库时才有效。默认True。 - **ceil_mode** (bool, 可选)- 是否用ceil函数计算输出高度和宽度。如果设为False,则使用floor函数。默认为False。 - **exclusive** (bool, 可选) - 是否在平均池化模式忽略填充值。默认为True。 + - **data_format** (str,可选) - 指定输入的数据格式,输出的数据格式将与输入保持一致,可以是"NCHW"和"NHWC"。N是批尺寸,C是通道数,H是特征高度,W是特征宽度。默认值:"NCHW"。 返回:无 抛出异常: - - ``ValueError`` - 如果 ``pool_type`` 既不是“max”也不是“avg” - - ``ValueError`` - 如果 ``global_pooling`` 为False并且‘pool_size’为-1 - - ``ValueError`` - 如果 ``use_cudnn`` 不是bool值 + - ``ValueError`` - 如果 ``pool_type`` 既不是“max”也不是“avg”。 + - ``ValueError`` - 如果 ``global_pooling`` 为False并且 ``pool_size`` 为-1。 + - ``ValueError`` - 如果 ``use_cudnn`` 不是bool值。 + - ``ValueError`` - 如果 ``data_format`` 既不是"NCHW"也不是"NHWC"。 **代码示例** @@ -80,9 +82,10 @@ Pool2D import paddle.fluid as fluid from paddle.fluid.dygraph.base import to_variable + import numpy as np with fluid.dygraph.guard(): - data = numpy.random.random((3, 32, 32, 5)).astype('float32') + data = np.random.random((3, 32, 32, 5)).astype('float32') pool2d = fluid.dygraph.Pool2D(pool_size=2, pool_type='max', pool_stride=1,