diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 54e51200dc745338ee0dbf7c9e86aae1eb2e8bb8..ae4dda166c733e7b6da44fb215b5913672f5d42c 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -103,8 +103,6 @@ from .tensor.logic import logical_not #DEFINE_ALIAS from .tensor.logic import logical_or #DEFINE_ALIAS from .tensor.logic import logical_xor #DEFINE_ALIAS from .tensor.logic import not_equal #DEFINE_ALIAS -# from .tensor.logic import reduce_all #DEFINE_ALIAS -# from .tensor.logic import reduce_any #DEFINE_ALIAS from .tensor.logic import allclose #DEFINE_ALIAS from .tensor.logic import equal_all #DEFINE_ALIAS # from .tensor.logic import isnan #DEFINE_ALIAS @@ -162,6 +160,8 @@ from .tensor.math import reciprocal #DEFINE_ALIAS # from .tensor.math import reduce_min #DEFINE_ALIAS # from .tensor.math import reduce_prod #DEFINE_ALIAS # from .tensor.math import reduce_sum #DEFINE_ALIAS +from .tensor.math import all #DEFINE_ALIAS +from .tensor.math import any #DEFINE_ALIAS from .tensor.math import round #DEFINE_ALIAS from .tensor.math import rsqrt #DEFINE_ALIAS from .tensor.math import scale #DEFINE_ALIAS diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index c2bb96ead2bf985efdf6d572bd09ecf3c091353e..ac762944b3a6885b2f32ce5e1be408c5d40f0e43 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -315,6 +315,8 @@ def fc(input, .. code-block:: python import paddle.fluid as fluid + import paddle + paddle.enable_static() # when input is single tensor data = fluid.data(name="data", shape=[-1, 32], dtype="float32") fc = fluid.layers.fc(input=data, size=1000, act="tanh") @@ -468,6 +470,9 @@ def embedding(input, import paddle.fluid as fluid import numpy as np + import paddle + paddle.enable_static() + data = fluid.data(name='x', shape=[None, 1], dtype='int64') # example 1 @@ -731,6 +736,8 @@ def linear_chain_crf(input, label, param_attr=None, length=None): import paddle.fluid as fluid import numpy as np + import paddle + paddle.enable_static() #define net structure, using LodTensor train_program = fluid.Program() @@ -855,6 +862,8 @@ def crf_decoding(input, param_attr, label=None, length=None): .. code-block:: python import paddle.fluid as fluid + import paddle + paddle.enable_static() # LoDTensor-based example num_labels = 10 @@ -1458,6 +1467,9 @@ def conv2d(input, .. code-block:: python import paddle.fluid as fluid + import paddle + paddle.enable_static() + data = fluid.data(name='data', shape=[None, 3, 32, 32], dtype='float32') conv2d = fluid.layers.conv2d(input=data, num_filters=2, filter_size=3, act="relu") """ @@ -1728,6 +1740,8 @@ def conv3d(input, .. code-block:: python import paddle.fluid as fluid + import paddle + paddle.enable_static() data = fluid.data(name='data', shape=[None, 3, 12, 32, 32], dtype='float32') conv3d = fluid.layers.conv3d(input=data, num_filters=2, filter_size=3, act="relu") """ @@ -2377,6 +2391,7 @@ def adaptive_pool2d(input, # output[:, :, i, j] = avg(input[:, :, hstart: hend, wstart: wend]) # import paddle + paddle.enable_static() data = paddle.rand(shape=[1,3,32,32]) pool_out = paddle.fluid.layers.adaptive_pool2d( input=data, @@ -2531,6 +2546,7 @@ def adaptive_pool3d(input, # import paddle + paddle.enable_static() data = paddle.rand(shape=[1,3,32,32,32]) pool_out = paddle.fluid.layers.adaptive_pool3d( input=data, @@ -2726,6 +2742,8 @@ def batch_norm(input, .. code-block:: python import paddle.fluid as fluid + import paddle + paddle.enable_static() x = fluid.data(name='x', shape=[3, 7, 3, 7], dtype='float32') hidden1 = fluid.layers.fc(input=x, size=200, param_attr='fc1.w') hidden2 = fluid.layers.batch_norm(input=hidden1) @@ -2735,6 +2753,8 @@ def batch_norm(input, # batch_norm with momentum as Variable import paddle.fluid as fluid import paddle.fluid.layers.learning_rate_scheduler as lr_scheduler + import paddle + paddle.enable_static() def get_decay_momentum(momentum_init, decay_steps, decay_rate): global_step = lr_scheduler._decay_step_counter() @@ -3134,6 +3154,8 @@ def instance_norm(input, .. code-block:: python import paddle.fluid as fluid + import paddle + paddle.enable_static() x = fluid.data(name='x', shape=[3, 7, 3, 7], dtype='float32') hidden1 = fluid.layers.fc(input=x, size=200, param_attr='fc1.w') hidden2 = fluid.layers.instance_norm(input=hidden1) @@ -3269,6 +3291,7 @@ def data_norm(input, .. code-block:: python import paddle + paddle.enable_static() x = paddle.randn(shape=[32,100]) hidden2 = paddle.static.nn.data_norm(input=x) @@ -3451,6 +3474,8 @@ def layer_norm(input, import paddle.fluid as fluid import numpy as np + import paddle + paddle.enable_static() x = fluid.data(name='x', shape=[-1, 32, 32], dtype='float32') hidden1 = fluid.layers.layer_norm(input=x, begin_norm_axis=1) place = fluid.CPUPlace() @@ -3566,6 +3591,9 @@ def group_norm(input, .. code-block:: python import paddle.fluid as fluid + import paddle + paddle.enable_static() + data = fluid.data(name='data', shape=[None, 8, 32, 32], dtype='float32') x = fluid.layers.group_norm(input=data, groups=4) """ @@ -3887,6 +3915,8 @@ def conv2d_transpose(input, .. code-block:: python import paddle.fluid as fluid + import paddle + paddle.enable_static() data = fluid.data(name='data', shape=[None, 3, 32, 32], dtype='float32') conv2d_transpose = fluid.layers.conv2d_transpose(input=data, num_filters=2, filter_size=3) """ @@ -4177,6 +4207,8 @@ def conv3d_transpose(input, .. code-block:: python import paddle.fluid as fluid + import paddle + paddle.enable_static() data = fluid.data(name='data', shape=[None, 3, 12, 32, 32], dtype='float32') conv3d_transpose = fluid.layers.conv3d_transpose(input=data, num_filters=2, filter_size=3) """ @@ -4659,7 +4691,7 @@ def reduce_all(input, dim=None, keep_dim=False, name=None): This OP computes the ``logical and`` of tensor elements over the given dimension, and output the result. Args: - input (Variable): The input variable which is a Tensor or LoDTensor, the input data type should be `bool`. + input (Tensor): the input tensor, it's data type should be `bool`. dim (list|int|optional): The dimension along which the logical and is computed. If :attr:`None`, compute the logical and over all elements of :attr:`input` and return a Tensor variable with a single element, @@ -4672,11 +4704,12 @@ def reduce_all(input, dim=None, keep_dim=False, name=None): will be named automatically. The default value is None. Returns: - Variable, the output data type is bool. : The reduced tensor variable with ``logical and`` in given dims. + Tensor, the output data type is bool. : The reduced tensor variable with ``logical and`` in given dims. Examples: .. code-block:: python + import paddle import paddle.fluid as fluid import paddle.fluid.layers as layers import numpy as np @@ -4684,15 +4717,15 @@ def reduce_all(input, dim=None, keep_dim=False, name=None): # x is a bool Tensor variable with following elements: # [[True, False] # [True, True]] - x = layers.assign(np.array([[1, 0], [1, 1]], dtype='int32')) - x = layers.cast(x, 'bool') + x = fluid.layers.assign(np.array([[1, 0], [1, 1]], dtype='int32')) + x = fluid.layers.cast(x, 'bool') - out = layers.reduce_all(x) # False - out = layers.reduce_all(x, dim=0) # [True, False] - out = layers.reduce_all(x, dim=-1) # [False, True] + out = fluid.layers.reduce_all(x) # False + out = fluid.layers.reduce_all(x, dim=0) # [True, False] + out = fluid.layers.reduce_all(x, dim=-1) # [False, True] # keep_dim=False, x.shape=(2,2), out.shape=(2,) - out = layers.reduce_all(x, dim=1, keep_dim=True) # [[False], [True]] + out = fluid.layers.reduce_all(x, dim=1, keep_dim=True) # [[False], [True]] # keep_dim=True, x.shape=(2,2), out.shape=(2,1) """ @@ -4719,7 +4752,7 @@ def reduce_any(input, dim=None, keep_dim=False, name=None): This OP computes the ``logical or`` of tensor elements over the given dimension, and output the result. Args: - input (Variable): The input variable which is a Tensor or LoDTensor, the input data type should be `bool`. + input (Tensor): the input tensor, it's data type should be `bool`. dim (list|int|optional): The dimension along which the logical and is computed. If :attr:`None`, compute the logical and over all elements of :attr:`input` and return a Tensor variable with a single element, @@ -4728,14 +4761,15 @@ def reduce_any(input, dim=None, keep_dim=False, name=None): keep_dim (bool): Whether to reserve the reduced dimension in the output Tensor. The result tensor will have one fewer dimension than the :attr:`input` unless :attr:`keep_dim` is true. The default value is False. - name(str|None): A name for this layer(optional). If set None, the layer + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Variable, the output data type is bool. : The reduced tensor variable with ``logical or`` in given dims. + Tensor, the output data type is bool. : The reduced tensor variable with ``logical or`` in given dims. Examples: .. code-block:: python + import paddle import paddle.fluid as fluid import paddle.fluid.layers as layers import numpy as np @@ -4743,15 +4777,15 @@ def reduce_any(input, dim=None, keep_dim=False, name=None): # x is a bool Tensor variable with following elements: # [[True, False] # [False, False]] - x = layers.assign(np.array([[1, 0], [0, 0]], dtype='int32')) - x = layers.cast(x, 'bool') + x = fluid.layers.assign(np.array([[1, 0], [0, 0]], dtype='int32')) + x = fluid.layers.cast(x, 'bool') - out = layers.reduce_any(x) # True - out = layers.reduce_any(x, dim=0) # [True, False] - out = layers.reduce_any(x, dim=-1) # [True, False] + out = fluid.layers.reduce_any(x) # True + out = fluid.layers.reduce_any(x, dim=0) # [True, False] + out = fluid.layers.reduce_any(x, dim=-1) # [True, False] # keep_dim=False, x.shape=(2,2), out.shape=(2,) - out = layers.reduce_any(x, dim=1, + out = fluid.layers.reduce_any(x, dim=1, keep_dim=True) # [[True], [False]] # keep_dim=True, x.shape=(2,2), out.shape=(2,1) @@ -5613,6 +5647,8 @@ def im2sequence(input, .. code-block:: python import paddle.fluid as fluid + import paddle + paddle.enable_static() data = fluid.data(name='data', shape=[None, 3, 32, 32], dtype='float32') output = fluid.layers.im2sequence( @@ -5669,6 +5705,8 @@ def row_conv(input, future_context_size, param_attr=None, act=None): Examples: >>> # for LodTensor inputs >>> import paddle.fluid as fluid + >>> import paddle + >>> paddle.enable_static() >>> x = fluid.data(name='x', shape=[9, 16], >>> dtype='float32', lod_level=1) >>> out = fluid.layers.row_conv(input=x, future_context_size=2) @@ -5982,6 +6020,8 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1): .. code-block:: python import paddle.fluid as fluid + import paddle + paddle.enable_static() global_step = fluid.layers.autoincreased_step_counter( counter_name='@LR_DECAY_COUNTER@', begin=0, step=1) """ @@ -9730,6 +9770,8 @@ def prelu(x, mode, param_attr=None, name=None): .. code-block:: python import paddle.fluid as fluid + import paddle + paddle.enable_static() from paddle.fluid.param_attr import ParamAttr x = fluid.data(name="x", shape=[None,5,10,10], dtype="float32") mode = 'channel' @@ -14307,6 +14349,9 @@ def deformable_conv(input, #deformable conv v2: import paddle.fluid as fluid + import paddle + paddle.enable_static() + C_in, H_in, W_in = 3, 32, 32 filter_size, deformable_groups = 3, 1 data = fluid.data(name='data', shape=[None, C_in, H_in, W_in], dtype='float32') diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_sync_batch_norm.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_sync_batch_norm.py index dcf5151578ad5e574a9a723ee3d416a1d47ecb9c..d525009fbd734b9f743715319f00cbf1ef0ae659 100644 --- a/python/paddle/fluid/tests/unittests/parallel_dygraph_sync_batch_norm.py +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_sync_batch_norm.py @@ -63,10 +63,7 @@ class TestLayer(fluid.dygraph.Layer): bias_attr=False) self._sync_batch_norm2 = SyncBatchNorm( - num_filters, - weight_attr=False, - bias_attr=False, - track_running_stats=False) + num_filters, weight_attr=False, bias_attr=False) def forward(self, inputs): y = self._conv(inputs) diff --git a/python/paddle/fluid/tests/unittests/test_adaptive_max_pool2d.py b/python/paddle/fluid/tests/unittests/test_adaptive_max_pool2d.py index 944725fab643580f9f60336edf05e3f96ae1255e..18860db9dae51cc41e4c8c6f2563db5d444a905d 100644 --- a/python/paddle/fluid/tests/unittests/test_adaptive_max_pool2d.py +++ b/python/paddle/fluid/tests/unittests/test_adaptive_max_pool2d.py @@ -150,7 +150,7 @@ class TestAdaptiveMaxPool2DAPI(unittest.TestCase): x = paddle.to_tensor(self.x_np) out_1 = paddle.nn.functional.adaptive_max_pool2d( - x=x, return_indices=False, output_size=[3, 3]) + x=x, return_mask=False, output_size=[3, 3]) out_2 = paddle.nn.functional.adaptive_max_pool2d(x=x, output_size=5) diff --git a/python/paddle/fluid/tests/unittests/test_conv1d_transpose_layer.py b/python/paddle/fluid/tests/unittests/test_conv1d_transpose_layer.py index 9c43e2f3e6e9d834c25a7490653fa7ead214eaa3..40b7074ed3914e67d3cd58af3c8fd63e92736740 100644 --- a/python/paddle/fluid/tests/unittests/test_conv1d_transpose_layer.py +++ b/python/paddle/fluid/tests/unittests/test_conv1d_transpose_layer.py @@ -92,7 +92,7 @@ class Conv1DTransposeTestCase(unittest.TestCase): "weight", self.weight_shape, dtype=self.dtype) b_var = fluid.data( "bias", (self.out_channels, ), dtype=self.dtype) - y_var = F.conv_transpose1d( + y_var = F.conv1d_transpose( x_var, w_var, None if self.no_bias else b_var, diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_transpose_layer.py b/python/paddle/fluid/tests/unittests/test_conv2d_transpose_layer.py index 28c3a466aa6c8d6252b4a1a04a61ba79c571785f..f51baf50ec898a7364be24806138f92ea2f32c05 100644 --- a/python/paddle/fluid/tests/unittests/test_conv2d_transpose_layer.py +++ b/python/paddle/fluid/tests/unittests/test_conv2d_transpose_layer.py @@ -128,7 +128,7 @@ class Conv2DTransposeTestCase(unittest.TestCase): else: output_size = self.output_size - y_var = F.conv_transpose2d( + y_var = F.conv2d_transpose( x_var, w_var, None if self.no_bias else b_var, diff --git a/python/paddle/fluid/tests/unittests/test_conv3d_transpose_layer.py b/python/paddle/fluid/tests/unittests/test_conv3d_transpose_layer.py index dac84a8486ef243231ac473223144d9e81ea3fac..a567ec727389366e020441e336c12c4395d8e056 100644 --- a/python/paddle/fluid/tests/unittests/test_conv3d_transpose_layer.py +++ b/python/paddle/fluid/tests/unittests/test_conv3d_transpose_layer.py @@ -119,7 +119,7 @@ class Conv3DTransposeTestCase(unittest.TestCase): "weight", self.weight_shape, dtype=self.dtype) b_var = fluid.data( "bias", (self.num_filters, ), dtype=self.dtype) - y_var = F.conv_transpose3d( + y_var = F.conv3d_transpose( x_var, w_var, None if self.no_bias else b_var, diff --git a/python/paddle/fluid/tests/unittests/test_functional_conv2d_transpose.py b/python/paddle/fluid/tests/unittests/test_functional_conv2d_transpose.py index 1fb07bf4345909deb5485a89232270336658ae8b..e3b821a07bffdf0ed74fac9fc0adb4dbc31c41c2 100644 --- a/python/paddle/fluid/tests/unittests/test_functional_conv2d_transpose.py +++ b/python/paddle/fluid/tests/unittests/test_functional_conv2d_transpose.py @@ -111,7 +111,7 @@ class TestFunctionalConv2D(TestCase): "weight", self.weight.shape, dtype=self.dtype) if not self.no_bias: bias = fluid.data("bias", self.bias.shape, dtype=self.dtype) - y = F.conv_transpose2d( + y = F.conv2d_transpose( x, weight, None if self.no_bias else bias, @@ -134,7 +134,7 @@ class TestFunctionalConv2D(TestCase): x = dg.to_variable(self.input) weight = dg.to_variable(self.weight) bias = None if self.no_bias else dg.to_variable(self.bias) - y = F.conv_transpose2d( + y = F.conv2d_transpose( x, weight, bias, @@ -215,7 +215,7 @@ class TestFunctionalConv2DError(TestCase): "weight", self.weight_shape, dtype=self.dtype) if not self.no_bias: bias = fluid.data("bias", self.bias_shape, dtype=self.dtype) - y = F.conv_transpose2d( + y = F.conv2d_transpose( x, weight, None if self.no_bias else bias, diff --git a/python/paddle/fluid/tests/unittests/test_functional_conv3d_transpose.py b/python/paddle/fluid/tests/unittests/test_functional_conv3d_transpose.py index 7441f7cb915e8b1fdd2155fff79e145fb6a00c0f..910d28515b7787c4187c2a83039f19f356069bd8 100644 --- a/python/paddle/fluid/tests/unittests/test_functional_conv3d_transpose.py +++ b/python/paddle/fluid/tests/unittests/test_functional_conv3d_transpose.py @@ -113,7 +113,7 @@ class TestFunctionalConv3DTranspose(TestCase): "weight", self.weight.shape, dtype=self.dtype) if not self.no_bias: bias = fluid.data("bias", self.bias.shape, dtype=self.dtype) - y = F.conv_transpose3d( + y = F.conv3d_transpose( x, weight, None if self.no_bias else bias, @@ -138,7 +138,7 @@ class TestFunctionalConv3DTranspose(TestCase): x = dg.to_variable(self.input) weight = dg.to_variable(self.weight) bias = None if self.no_bias else dg.to_variable(self.bias) - y = F.conv_transpose3d( + y = F.conv3d_transpose( x, weight, bias, @@ -222,7 +222,7 @@ class TestFunctionalConv3DTransposeError(TestCase): "weight", self.weight_shape, dtype=self.dtype) if not self.no_bias: bias = fluid.data("bias", self.bias_shape, dtype=self.dtype) - y = F.conv_transpose3d( + y = F.conv3d_transpose( x, weight, None if self.no_bias else bias, diff --git a/python/paddle/fluid/tests/unittests/test_pool1d_api.py b/python/paddle/fluid/tests/unittests/test_pool1d_api.py index cc2490d1f1245389c370c0326d81def0ddd7198e..00f75337baafb78bc9a443152ec9da7cd721e2cd 100644 --- a/python/paddle/fluid/tests/unittests/test_pool1d_api.py +++ b/python/paddle/fluid/tests/unittests/test_pool1d_api.py @@ -148,11 +148,7 @@ class TestPool1D_API(unittest.TestCase): input_np = np.random.random([2, 3, 32]).astype("float32") input = fluid.dygraph.to_variable(input_np) result = F.avg_pool1d( - input, - kernel_size=2, - stride=2, - padding=[1], - count_include_pad=True) + input, kernel_size=2, stride=2, padding=[1], exclusive=True) result_np = avg_pool1D_forward_naive( input_np, ksize=[2], strides=[2], paddings=[1], exclusive=False) @@ -160,7 +156,8 @@ class TestPool1D_API(unittest.TestCase): self.assertTrue(np.allclose(result.numpy(), result_np)) avg_pool1d_dg = paddle.nn.AvgPool1D( - kernel_size=2, stride=None, padding=1, count_include_pad=True) + kernel_size=2, stride=None, padding=1, exclusive=True) + result = avg_pool1d_dg(input) self.assertTrue(np.allclose(result.numpy(), result_np)) @@ -200,7 +197,7 @@ class TestPool1D_API(unittest.TestCase): input_np = np.random.random([2, 3, 32]).astype("float32") input = fluid.dygraph.to_variable(input_np) result, index = F.max_pool1d( - input, kernel_size=2, stride=2, padding=0, return_indices=True) + input, kernel_size=2, stride=2, padding=0, return_mask=True) result_np = max_pool1D_forward_naive( input_np, ksize=[2], strides=[2], paddings=[0]) diff --git a/python/paddle/fluid/tests/unittests/test_pool2d_api.py b/python/paddle/fluid/tests/unittests/test_pool2d_api.py index 66505327c2df3d7249b677be6d64b02b1352360a..f4432bf33864707045513f2122f02bdd9a6394b7 100644 --- a/python/paddle/fluid/tests/unittests/test_pool2d_api.py +++ b/python/paddle/fluid/tests/unittests/test_pool2d_api.py @@ -134,7 +134,7 @@ class TestPool2D_API(unittest.TestCase): input_np = np.random.random([2, 3, 32, 32]).astype("float32") input = fluid.dygraph.to_variable(input_np) result = max_pool2d( - input, kernel_size=2, stride=2, padding=0, return_indices=False) + input, kernel_size=2, stride=2, padding=0, return_mask=False) result_np = pool2D_forward_naive( input_np, @@ -159,7 +159,7 @@ class TestPool2D_API(unittest.TestCase): kernel_size=2, stride=2, padding=0, - return_indices=False, + return_mask=False, data_format="NHWC") result_np = pool2D_forward_naive( @@ -222,7 +222,7 @@ class TestPool2D_API(unittest.TestCase): kernel_size=2, stride=None, padding="SAME", - return_indices=True) + return_mask=True) result_np = pool2D_forward_naive( input_np, @@ -269,7 +269,7 @@ class TestPool2D_API(unittest.TestCase): kernel_size=2, stride=2, padding=padding, - return_indices=False) + return_mask=False) result_np = pool2D_forward_naive( input_np, @@ -490,7 +490,7 @@ class TestPool2DError_API(unittest.TestCase): padding=0, ceil_mode=False, data_format='NHWC', - return_indices=True) + return_mask=True) self.assertRaises(ValueError, run9) diff --git a/python/paddle/fluid/tests/unittests/test_pool3d_api.py b/python/paddle/fluid/tests/unittests/test_pool3d_api.py index b2700303ee477d33d93c413fcb82f71fd452230f..91158fe674b1e8686adaf50c05fc30bf1578833e 100644 --- a/python/paddle/fluid/tests/unittests/test_pool3d_api.py +++ b/python/paddle/fluid/tests/unittests/test_pool3d_api.py @@ -83,7 +83,7 @@ class TestPool3D_API(unittest.TestCase): stride=2, padding=1, ceil_mode=False, - count_include_pad=True) + exclusive=True) result_np = avg_pool3D_forward_naive( input_np, @@ -100,7 +100,7 @@ class TestPool3D_API(unittest.TestCase): stride=None, padding=1, ceil_mode=False, - count_include_pad=True) + exclusive=True) result = avg_pool3d_dg(input) self.assertTrue(np.allclose(result.numpy(), result_np)) @@ -175,7 +175,7 @@ class TestPool3D_API(unittest.TestCase): stride=2, padding=0, data_format="NDHWC", - return_indices=False) + return_mask=False) result_np = pool3D_forward_naive( input_np, @@ -239,7 +239,7 @@ class TestPool3D_API(unittest.TestCase): kernel_size=2, stride=None, padding="SAME", - return_indices=True) + return_mask=True) result_np = pool3D_forward_naive( input_np, @@ -467,7 +467,7 @@ class TestPool3DError_API(unittest.TestCase): stride=2, padding=0, data_format='NDHWC', - return_indices=True) + return_mask=True) self.assertRaises(ValueError, run10) diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index 80b201d0842183750361d5e08bab5f78f40a858b..e549a2eca2d7d046de0ea6d03fa7855f459a0c78 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -767,5 +767,117 @@ class API_TestSumOp(unittest.TestCase): self.assertTrue((out3 == np.sum(np_x, axis=(0, 1, 2))).all()) +class TestAllAPI(unittest.TestCase): + def setUp(self): + np.random.seed(123) + paddle.enable_static() + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def check_static_result(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input = fluid.data(name="input", shape=[4, 4], dtype="bool") + result = paddle.all(x=input) + input_np = np.random.randint(0, 2, [4, 4]).astype("bool") + + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"input": input_np}, + fetch_list=[result]) + self.assertTrue(np.allclose(fetches[0], np.all(input_np))) + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) + + def test_dygraph(self): + paddle.disable_static() + for place in self.places: + with fluid.dygraph.guard(place): + np_x = np.random.randint(0, 2, (12, 10)).astype(np.bool) + x = fluid.layers.assign(np_x) + x = fluid.layers.cast(x, 'bool') + + out1 = paddle.all(x) + np_out1 = out1.numpy() + expect_res1 = np.all(np_x) + self.assertTrue((np_out1 == expect_res1).all()) + + out2 = paddle.all(x, axis=0) + np_out2 = out2.numpy() + expect_res2 = np.all(np_x, axis=0) + self.assertTrue((np_out2 == expect_res2).all()) + + out3 = paddle.all(x, axis=-1) + np_out3 = out3.numpy() + expect_res3 = np.all(np_x, axis=-1) + self.assertTrue((np_out3 == expect_res3).all()) + + out4 = paddle.all(x, axis=1, keepdim=True) + np_out4 = out4.numpy() + expect_res4 = np.all(np_x, axis=1, keepdims=True) + self.assertTrue((np_out4 == expect_res4).all()) + + paddle.enable_static() + + +class TestAnyAPI(unittest.TestCase): + def setUp(self): + np.random.seed(123) + paddle.enable_static() + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def check_static_result(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input = fluid.data(name="input", shape=[4, 4], dtype="bool") + result = paddle.any(x=input) + input_np = np.random.randint(0, 2, [4, 4]).astype("bool") + + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"input": input_np}, + fetch_list=[result]) + self.assertTrue(np.allclose(fetches[0], np.any(input_np))) + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) + + def test_dygraph(self): + paddle.disable_static() + for place in self.places: + with fluid.dygraph.guard(place): + np_x = np.random.randint(0, 2, (12, 10)).astype(np.bool) + x = fluid.layers.assign(np_x) + x = fluid.layers.cast(x, 'bool') + + out1 = paddle.any(x) + np_out1 = out1.numpy() + expect_res1 = np.any(np_x) + self.assertTrue((np_out1 == expect_res1).all()) + + out2 = paddle.any(x, axis=0) + np_out2 = out2.numpy() + expect_res2 = np.any(np_x, axis=0) + self.assertTrue((np_out2 == expect_res2).all()) + + out3 = paddle.any(x, axis=-1) + np_out3 = out3.numpy() + expect_res3 = np.any(np_x, axis=-1) + self.assertTrue((np_out3 == expect_res3).all()) + + out4 = paddle.any(x, axis=1, keepdim=True) + np_out4 = out4.numpy() + expect_res4 = np.any(np_x, axis=1, keepdims=True) + self.assertTrue((np_out4 == expect_res4).all()) + + paddle.enable_static() + + if __name__ == '__main__': + import paddle + paddle.enable_static() unittest.main() diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 5f9307845ae9d630e6506bc264f643165cdffea1..07e8b1f4d6d0fd855e0df298141435f415f2b000 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -73,12 +73,12 @@ from .common import interpolate #DEFINE_ALIAS from .common import upsample #DEFINE_ALIAS from .common import bilinear #DEFINE_ALIAS from .conv import conv1d #DEFINE_ALIAS -from .conv import conv_transpose1d #DEFINE_ALIAS +from .conv import conv1d_transpose #DEFINE_ALIAS from .common import linear #DEFINE_ALIAS from .conv import conv2d #DEFINE_ALIAS -from .conv import conv_transpose2d #DEFINE_ALIAS +from .conv import conv2d_transpose #DEFINE_ALIAS from .conv import conv3d #DEFINE_ALIAS -from .conv import conv_transpose3d #DEFINE_ALIAS +from .conv import conv3d_transpose #DEFINE_ALIAS # from .extension import add_position_encoding #DEFINE_ALIAS # from .extension import autoincreased_step_counter #DEFINE_ALIAS # from .extension import continuous_value_model #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/conv.py b/python/paddle/nn/functional/conv.py index 03dd40fb140cfc17e8bbea85f159feeaef2933a5..6df1ce368c1b0b2936ac87838b022d104a8c6eea 100644 --- a/python/paddle/nn/functional/conv.py +++ b/python/paddle/nn/functional/conv.py @@ -15,11 +15,11 @@ from __future__ import print_function __all__ = [ 'conv1d', - 'conv_transpose1d', + 'conv1d_transpose', 'conv2d', - 'conv_transpose2d', + 'conv2d_transpose', 'conv3d', - 'conv_transpose3d', + 'conv3d_transpose', ] import numpy as np @@ -541,7 +541,7 @@ def conv2d(x, return out -def conv_transpose1d(x, +def conv1d_transpose(x, weight, bias=None, stride=1, @@ -682,7 +682,7 @@ def conv_transpose1d(x, [[4, 2]]]).astype(np.float32) x_var = paddle.to_tensor(x) w_var = paddle.to_tensor(w) - y_var = F.conv_transpose1d(x_var, w_var) + y_var = F.conv1d_transpose(x_var, w_var) y_np = y_var.numpy() print y_np @@ -802,7 +802,7 @@ def conv_transpose1d(x, return out -def conv_transpose2d(x, +def conv2d_transpose(x, weight, bias=None, stride=1, @@ -920,7 +920,7 @@ def conv_transpose2d(x, None by default. Returns: - A Tensor representing the conv_transpose2d, whose + A Tensor representing the conv2d_transpose, whose data type is the same with input and shape is (num_batches, channels, out_h, out_w) or (num_batches, out_h, out_w, channels). The tensor variable storing transposed convolution result. @@ -946,7 +946,7 @@ def conv_transpose2d(x, x_var = paddle.randn((2, 3, 8, 8), dtype='float32') w_var = paddle.randn((3, 6, 3, 3), dtype='float32') - y_var = F.conv_transpose2d(x_var, w_var) + y_var = F.conv2d_transpose(x_var, w_var) y_np = y_var.numpy() print(y_np.shape) @@ -1242,7 +1242,7 @@ def conv3d(x, return out -def conv_transpose3d(x, +def conv3d_transpose(x, weight, bias=None, stride=1, @@ -1364,7 +1364,7 @@ def conv_transpose3d(x, None by default. Returns: - A Tensor representing the conv_transpose3d, whose data + A Tensor representing the conv3d_transpose, whose data type is the same with input and shape is (num_batches, channels, out_d, out_h, out_w) or (num_batches, out_d, out_h, out_w, channels). If act is None, the tensor variable storing the transposed convolution result, and if act is not None, the tensor @@ -1391,7 +1391,7 @@ def conv_transpose3d(x, x_var = paddle.randn((2, 3, 8, 8, 8), dtype='float32') w_var = paddle.randn((3, 6, 3, 3, 3), dtype='float32') - y_var = F.conv_transpose3d(x_var, w_var) + y_var = F.conv3d_transpose(x_var, w_var) y_np = y_var.numpy() print(y_np.shape) diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py index 73652ff1266f5234546e2ae2de0460bc32113064..73e3cb31221f131cf5866177a96a3fcb46d8d189 100755 --- a/python/paddle/nn/functional/pooling.py +++ b/python/paddle/nn/functional/pooling.py @@ -157,7 +157,7 @@ def avg_pool1d(x, kernel_size, stride=None, padding=0, - count_include_pad=True, + exclusive=True, ceil_mode=False, name=None): """ @@ -179,7 +179,7 @@ def avg_pool1d(x, 4. A list[int] or tuple(int) whose length is 2. It has the form [pad_before, pad_after]. 5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension should be [0,0] or (0,0). The default value is 0. - count_include_pad (bool): Whether to exclude padding points in average pooling + exclusive (bool): Whether to exclude padding points in average pooling mode, default is `True`. ceil_mode (bool): ${ceil_mode_comment}Whether to use the ceil function to calculate output height and width. If it is set to False, the floor function will be used. The default value is False. @@ -230,8 +230,8 @@ def avg_pool1d(x, x, 'pooling_type', 'avg', 'ksize', kernel_size, 'global_pooling', False, 'strides', stride, 'paddings', padding, 'padding_algorithm', padding_algorithm, 'use_cudnn', True, 'ceil_mode', ceil_mode, - 'use_mkldnn', False, 'exclusive', not count_include_pad, - 'data_format', data_format) + 'use_mkldnn', False, 'exclusive', not exclusive, 'data_format', + data_format) return squeeze(output, [2]) op_type = 'pool2d' @@ -253,7 +253,7 @@ def avg_pool1d(x, "use_cudnn": True, "ceil_mode": ceil_mode, "use_mkldnn": False, - "exclusive": not count_include_pad, + "exclusive": not exclusive, "data_format": data_format, }) @@ -265,7 +265,7 @@ def avg_pool2d(x, stride=None, padding=0, ceil_mode=False, - count_include_pad=True, + exclusive=True, divisor_override=None, data_format="NCHW", name=None): @@ -294,7 +294,7 @@ def avg_pool2d(x, 5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension should be [0,0] or (0,0). The default value is 0. ceil_mode (bool): when True, will use `ceil` instead of `floor` to compute the output shape - count_include_pad (bool): Whether to exclude padding points in average pooling + exclusive (bool): Whether to exclude padding points in average pooling mode, default is `true`. divisor_override (float): if specified, it will be used as divisor, otherwise kernel_size will be used. Default None. data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NHWC"`. @@ -338,8 +338,8 @@ def avg_pool2d(x, x, 'pooling_type', 'avg', 'ksize', kernel_size, 'global_pooling', False, 'padding_algorithm', padding_algorithm, 'strides', stride, 'paddings', padding, 'use_cudnn', True, 'ceil_mode', ceil_mode, - 'use_mkldnn', False, 'exclusive', not count_include_pad, - 'data_format', data_format) + 'use_mkldnn', False, 'exclusive', not exclusive, 'data_format', + data_format) if divisor_override is None: return output else: @@ -365,7 +365,7 @@ def avg_pool2d(x, "use_cudnn": True, "ceil_mode": ceil_mode, "use_mkldnn": False, - "exclusive": not count_include_pad, + "exclusive": not exclusive, "data_format": data_format, }) @@ -381,7 +381,7 @@ def avg_pool3d(x, stride=None, padding=0, ceil_mode=False, - count_include_pad=True, + exclusive=True, divisor_override=None, data_format="NCDHW", name=None): @@ -408,7 +408,7 @@ def avg_pool3d(x, 5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension should be [0,0] or (0,0). The default value is 0. ceil_mode (bool): ${ceil_mode_comment} - count_include_pad (bool): Whether to exclude padding points in average pooling + exclusive (bool): Whether to exclude padding points in average pooling mode, default is True. divisor_override (int|float) if specified, it will be used as divisor, otherwise kernel_size will be used. Default None. data_format (string): The data format of the input and output data. An optional string from: `"NCDHW"`, `"NDHWC"`. @@ -452,8 +452,8 @@ def avg_pool3d(x, x, 'pooling_type', 'avg', 'ksize', kernel_size, 'strides', stride, 'paddings', padding, 'global_pooling', False, 'padding_algorithm', padding_algorithm, 'use_cudnn', True, 'ceil_mode', ceil_mode, - 'use_mkldnn', False, 'exclusive', not count_include_pad, - 'data_format', data_format) + 'use_mkldnn', False, 'exclusive', not exclusive, 'data_format', + data_format) if divisor_override is None: return output else: @@ -481,7 +481,7 @@ def avg_pool3d(x, "use_cudnn": True, "ceil_mode": ceil_mode, "use_mkldnn": False, - "exclusive": not count_include_pad, + "exclusive": not exclusive, "data_format": data_format, }) @@ -497,7 +497,7 @@ def max_pool1d(x, kernel_size, stride=None, padding=0, - return_indices=False, + return_mask=False, ceil_mode=False, name=None): """ @@ -519,7 +519,7 @@ def max_pool1d(x, 4. A list[int] or tuple(int) whose length is 2. It has the form [pad_before, pad_after]. 5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension should be [0,0] or (0,0). The default value is 0. - return_indices (bool): Whether return the max indices along with the outputs. default is `False`. + return_mask (bool): Whether return the max indices along with the outputs. default is `False`. ceil_mode (bool): Whether to use the ceil function to calculate output height and width. False is the default. If it is set to False, the floor function will be used. Default False. name(str, optional): For detailed information, please refer @@ -542,7 +542,7 @@ def max_pool1d(x, data = paddle.to_tensor(np.random.uniform(-1, 1, [1, 3, 32]).astype(np.float32)) pool_out = F.max_pool1d(data, kernel_size=2, stride=2, padding=0) # pool_out shape: [1, 3, 16] - pool_out, indices = F.max_pool1d(data, kernel_size=2, stride=2, padding=0, return_indices=True) + pool_out, indices = F.max_pool1d(data, kernel_size=2, stride=2, padding=0, return_mask=True) # pool_out shape: [1, 3, 16], indices shape: [1, 3, 16] """ """NCL to NCHW""" @@ -563,16 +563,16 @@ def max_pool1d(x, padding = _expand_low_nd_padding(padding) if in_dygraph_mode(): - if return_indices: + if return_mask: pool_out = core.ops.max_pool2d_with_index( x, 'ksize', kernel_size, 'global_pooling', False, 'strides', stride, 'paddings', padding, 'padding_algorithm', padding_algorithm, 'use_cudnn', True, 'ceil_mode', ceil_mode, 'use_mkldnn', False, 'exclusive', True, 'data_format', data_format) - return (squeeze(pool_out[0], [2]), squeeze( - pool_out[1], - [2])) if return_indices else squeeze(pool_out[0], [2]) + return (squeeze(pool_out[0], [2]), + squeeze(pool_out[1], + [2])) if return_mask else squeeze(pool_out[0], [2]) else: pool_out = core.ops.pool2d( x, 'pooling_type', 'max', 'ksize', kernel_size, @@ -582,7 +582,7 @@ def max_pool1d(x, 'data_format', data_format) return squeeze(pool_out, [2]) - op_type = 'max_pool2d_with_index' if return_indices else "pool2d" + op_type = 'max_pool2d_with_index' if return_mask else "pool2d" helper = LayerHelper(op_type, **locals()) dtype = helper.input_dtype() pool_out = helper.create_variable_for_type_inference(dtype) @@ -608,14 +608,14 @@ def max_pool1d(x, }) return (squeeze(pool_out, [2]), - squeeze(mask, [2])) if return_indices else squeeze(pool_out, [2]) + squeeze(mask, [2])) if return_mask else squeeze(pool_out, [2]) def max_pool2d(x, kernel_size, stride=None, padding=0, - return_indices=False, + return_mask=False, ceil_mode=False, data_format="NCHW", name=None): @@ -643,7 +643,7 @@ def max_pool2d(x, 5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension should be [0,0] or (0,0). The default value is 0. ceil_mode (bool): when True, will use `ceil` instead of `floor` to compute the output shape - return_indices (bool): Whether to return the max indices along with the outputs. Default False, only support `"NCHW"` data format + return_mask (bool): Whether to return the max indices along with the outputs. Default False, only support `"NCHW"` data format data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NHWC"`. The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of: `[batch_size, input_channels, input_height, input_width]`. @@ -668,12 +668,12 @@ def max_pool2d(x, kernel_size=2, stride=2, padding=0) # output.shape [1, 3, 16, 16] - # for return_indices=True + # for return_mask=True out, max_indices = F.max_pool2d(x, kernel_size=2, stride=2, padding=0, - return_indices=True) + return_mask=True) # out.shape [1, 3, 16, 16], max_indices.shape [1, 3, 16, 16], """ check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'max_pool2d') @@ -693,20 +693,20 @@ def max_pool2d(x, padding, padding_algorithm = _update_padding_nd( padding, num_dims=2, channel_last=channel_last, ceil_mode=ceil_mode) - if data_format == "NHWC" and return_indices: + if data_format == "NHWC" and return_mask: raise ValueError( - "When setting return_indices to true, data_format must be set to NCHW in API:max_pool2d" + "When setting return_mask to true, data_format must be set to NCHW in API:max_pool2d" ) if in_dygraph_mode(): - if return_indices: + if return_mask: output = core.ops.max_pool2d_with_index( x, 'ksize', kernel_size, 'global_pooling', False, 'strides', stride, 'paddings', padding, 'padding_algorithm', padding_algorithm, 'use_cudnn', True, 'ceil_mode', ceil_mode, 'use_mkldnn', False, 'exclusive', True, 'data_format', data_format) - return output if return_indices else output[0] + return output if return_mask else output[0] else: output = core.ops.pool2d( x, 'pooling_type', 'max', 'ksize', kernel_size, @@ -716,7 +716,7 @@ def max_pool2d(x, 'data_format', data_format) return output - op_type = 'max_pool2d_with_index' if return_indices else "pool2d" + op_type = 'max_pool2d_with_index' if return_mask else "pool2d" helper = LayerHelper(op_type, **locals()) dtype = helper.input_dtype() pool_out = helper.create_variable_for_type_inference(dtype) @@ -741,14 +741,14 @@ def max_pool2d(x, "data_format": data_format, }) - return (pool_out, mask) if return_indices else pool_out + return (pool_out, mask) if return_mask else pool_out def max_pool3d(x, kernel_size, stride=None, padding=0, - return_indices=False, + return_mask=False, ceil_mode=False, data_format="NCDHW", name=None): @@ -773,7 +773,7 @@ def max_pool3d(x, 5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension should be [0,0] or (0,0). The default value is 0. ceil_mode (bool): ${ceil_mode_comment} - return_indices (bool): Whether to return the max indices along with the outputs. Default False. Only support "NDCHW" data_format. + return_mask (bool): Whether to return the max indices along with the outputs. Default False. Only support "NDCHW" data_format. data_format (string): The data format of the input and output data. An optional string from: `"NCDHW"`, `"NDHWC"`. The default is `"NCDHW"`. When it is `"NCDHW"`, the data is stored in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`. @@ -798,13 +798,13 @@ def max_pool3d(x, kernel_size=2, stride=2, padding=0) output.shape [1, 3, 16, 16, 16] - # for return_indices=True + # for return_mask=True x = paddle.to_tensor(np.random.uniform(-1, 1, [1, 3, 32, 32, 32]).astype(np.float32)) output, max_indices = paddle.nn.functional.max_pool3d(x, kernel_size = 2, stride = 2, padding=0, - return_indices=True) + return_mask=True) # output.shape [None, 3, 16, 16, 16], max_indices.shape [None, 3, 16, 16, 16], """ check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'max_pool3d') @@ -819,20 +819,20 @@ def max_pool3d(x, padding, padding_algorithm = _update_padding_nd( padding, 3, channel_last=channel_last, ceil_mode=ceil_mode) - if data_format == "NDHWC" and return_indices: + if data_format == "NDHWC" and return_mask: raise ValueError( - "When setting return_indices to true, data_format must be set to NCDHW in API:max_pool3d" + "When setting return_mask to true, data_format must be set to NCDHW in API:max_pool3d" ) if in_dygraph_mode(): - if return_indices: + if return_mask: output = core.ops.max_pool3d_with_index( x, 'pooling_type', 'max', 'ksize', kernel_size, 'strides', stride, 'paddings', padding, 'global_pooling', False, 'padding_algorithm', padding_algorithm, 'use_cudnn', True, 'ceil_mode', ceil_mode, 'use_mkldnn', False, 'exclusive', True, 'data_format', data_format) - return output if return_indices else output[0] + return output if return_mask else output[0] else: output = core.ops.pool3d( x, 'pooling_type', 'max', 'ksize', kernel_size, @@ -842,7 +842,7 @@ def max_pool3d(x, 'data_format', data_format) return output - op_type = "max_pool3d_with_index" if return_indices else "pool3d" + op_type = "max_pool3d_with_index" if return_mask else "pool3d" helper = LayerHelper(op_type, **locals()) dtype = helper.input_dtype() pool_out = helper.create_variable_for_type_inference(dtype) @@ -867,7 +867,7 @@ def max_pool3d(x, "data_format": data_format, }) - return (pool_out, mask) if return_indices else pool_out + return (pool_out, mask) if return_mask else pool_out def adaptive_avg_pool1d(x, output_size, name=None): @@ -1148,7 +1148,7 @@ def adaptive_avg_pool3d(x, output_size, data_format='NCDHW', name=None): return pool_out -def adaptive_max_pool1d(x, output_size, return_indices=False, name=None): +def adaptive_max_pool1d(x, output_size, return_mask=False, name=None): """ This API implements adaptive max pooling 1d operation. See more details in :ref:`api_nn_pooling_AdaptiveMaxPool1d` . @@ -1159,7 +1159,7 @@ def adaptive_max_pool1d(x, output_size, return_indices=False, name=None): where N is batch size, C is the number of channels, L is the length of the feature. The data type is float32 or float64. output_size (int): The pool kernel size. The value should be an integer. - return_indices (bool): If true, the index of max pooling point will be returned along + return_mask (bool): If true, the index of max pooling point will be returned along with outputs. It cannot be set in average pooling type. Default False. name(str, optional): For detailed information, please refer to :ref:`api_guide_Name`. Usually name is no need to set and @@ -1190,7 +1190,7 @@ def adaptive_max_pool1d(x, output_size, return_indices=False, name=None): data = paddle.to_tensor(np.random.uniform(-1, 1, [1, 3, 32]).astype(np.float32)) pool_out = F.adaptive_max_pool1d(data, output_size=16) # pool_out shape: [1, 3, 16]) - pool_out, indices = F.adaptive_max_pool1d(data, output_size=16, return_indices=True) + pool_out, indices = F.adaptive_max_pool1d(data, output_size=16, return_mask=True) # pool_out shape: [1, 3, 16] indices shape: [1, 3, 16] """ pool_type = 'max' @@ -1198,7 +1198,7 @@ def adaptive_max_pool1d(x, output_size, return_indices=False, name=None): 'adaptive_max_pool1d') _check_input(x, 3) check_type(output_size, 'pool_size', int, 'adaptive_max_pool1d') - check_type(return_indices, 'return_indices', bool, 'adaptive_max_pool1d') + check_type(return_mask, 'return_mask', bool, 'adaptive_max_pool1d') pool_size = [1] + utils.convert_to_list(output_size, 1, 'pool_size') @@ -1209,7 +1209,7 @@ def adaptive_max_pool1d(x, output_size, return_indices=False, name=None): pool_out = core.ops.max_pool2d_with_index( x, 'pooling_type', pool_type, 'ksize', pool_size, 'adaptive', True) return (squeeze(pool_out[0], [2]), squeeze( - pool_out[1], [2])) if return_indices else squeeze(pool_out[0], [2]) + pool_out[1], [2])) if return_mask else squeeze(pool_out[0], [2]) helper = LayerHelper(l_type, **locals()) dtype = helper.input_dtype() @@ -1229,10 +1229,10 @@ def adaptive_max_pool1d(x, output_size, return_indices=False, name=None): }) return (squeeze(pool_out, [2]), - squeeze(mask, [2])) if return_indices else squeeze(pool_out, [2]) + squeeze(mask, [2])) if return_mask else squeeze(pool_out, [2]) -def adaptive_max_pool2d(x, output_size, return_indices=False, name=None): +def adaptive_max_pool2d(x, output_size, return_mask=False, name=None): """ This operation applies a 2D adaptive max pooling on input tensor. See more details in :ref:`api_nn_pooling_AdaptiveMaxPool2d` . @@ -1240,7 +1240,7 @@ def adaptive_max_pool2d(x, output_size, return_indices=False, name=None): Args: x (Tensor): The input tensor of adaptive max pool2d operator, which is a 4-D tensor. The data type can be float16, float32, float64, int32 or int64. output_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, it must contain two elements, (H, W). H and W can be either a int, or None which means the size will be the same as that of the input. - return_indices (bool): If true, the index of max pooling point will be returned along with outputs. Default False. + return_mask (bool): If true, the index of max pooling point will be returned along with outputs. Default False. name(str, optional): For detailed information, please refer to :ref:`api_guide_Name`. Usually name is no need to set and None by default. Returns: @@ -1280,7 +1280,7 @@ def adaptive_max_pool2d(x, output_size, return_indices=False, name=None): 'adaptive_max_pool2d') _check_input(x, 4) #check_type(output_size, 'pool_size', (int), 'adaptive_max_pool2d') - check_type(return_indices, 'return_indices', bool, 'adaptive_max_pool2d') + check_type(return_mask, 'return_mask', bool, 'adaptive_max_pool2d') in_h, in_w = x.shape[2:4] if isinstance(output_size, int): @@ -1295,7 +1295,7 @@ def adaptive_max_pool2d(x, output_size, return_indices=False, name=None): if in_dygraph_mode(): pool_out = core.ops.max_pool2d_with_index( x, 'pooling_type', 'max', 'ksize', output_size, 'adaptive', True) - return pool_out if return_indices else pool_out[0] + return pool_out if return_mask else pool_out[0] l_type = 'max_pool2d_with_index' @@ -1315,11 +1315,11 @@ def adaptive_max_pool2d(x, output_size, return_indices=False, name=None): "ksize": output_size, "adaptive": True, }) - #return (pool_out, mask) if return_indices else pool_out + #return (pool_out, mask) if return_mask else pool_out return pool_out -def adaptive_max_pool3d(x, output_size, return_indices=False, name=None): +def adaptive_max_pool3d(x, output_size, return_mask=False, name=None): """ This operation applies a 3D adaptive max pooling on input tensor. See more details in :ref:`api_nn_pooling_AdaptiveMaxPool3d` . @@ -1327,7 +1327,7 @@ def adaptive_max_pool3d(x, output_size, return_indices=False, name=None): Args: x (Tensor): The input tensor of adaptive max pool3d operator, which is a 5-D tensor. The data type can be float32, float64. output_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, it must contain three elements, (D, H, W). D, H and W can be either a int, or None which means the size will be the same as that of the input. - return_indices (bool): If true, the index of max pooling point will be returned along with outputs. Default False. + return_mask (bool): If true, the index of max pooling point will be returned along with outputs. Default False. name(str, optional): For detailed information, please refer to :ref:`api_guide_Name`. Usually name is no need to set and None by default. Returns: @@ -1371,7 +1371,7 @@ def adaptive_max_pool3d(x, output_size, return_indices=False, name=None): 'adaptive_max_pool3d') _check_input(x, 5) #check_type(output_size, 'pool_size', (int), 'adaptive_max_pool3d') - check_type(return_indices, 'return_indices', bool, 'adaptive_max_pool3d') + check_type(return_mask, 'return_mask', bool, 'adaptive_max_pool3d') in_l, in_h, in_w = x.shape[2:5] if isinstance(output_size, int): @@ -1388,7 +1388,7 @@ def adaptive_max_pool3d(x, output_size, return_indices=False, name=None): if in_dygraph_mode(): pool_out = core.ops.max_pool3d_with_index( x, 'pooling_type', 'max', 'ksize', output_size, 'adaptive', True) - return pool_out if return_indices else pool_out[0] + return pool_out if return_mask else pool_out[0] l_type = 'max_pool3d_with_index' @@ -1409,4 +1409,4 @@ def adaptive_max_pool3d(x, output_size, return_indices=False, name=None): "adaptive": True, }) - return (pool_out, mask) if return_indices else pool_out + return (pool_out, mask) if return_mask else pool_out diff --git a/python/paddle/nn/layer/conv.py b/python/paddle/nn/layer/conv.py index 51c466d113f027693d3e27a2463d8b5d69845ddd..f97e549464738234b6f6bce529557750deb9fc2c 100644 --- a/python/paddle/nn/layer/conv.py +++ b/python/paddle/nn/layer/conv.py @@ -427,7 +427,7 @@ class Conv1DTranspose(_ConvNd): data_format=data_format) def forward(self, x, output_size=None): - out = F.conv_transpose1d( + out = F.conv1d_transpose( x, self.weight, bias=self.bias, @@ -748,7 +748,7 @@ class Conv2DTranspose(_ConvNd): else: output_padding = 0 - out = F.conv_transpose2d( + out = F.conv2d_transpose( x, self.weight, bias=self.bias, @@ -954,16 +954,16 @@ class Conv3DTranspose(_ConvNd): **Note**: - The conv_transpose3d can be seen as the backward of the conv3d. For conv3d, + The conv3d_transpose can be seen as the backward of the conv3d. For conv3d, when stride > 1, conv3d maps multiple input shape to the same output shape, - so for conv_transpose3d, when stride > 1, input shape maps multiple output shape. + so for conv3d_transpose, when stride > 1, input shape maps multiple output shape. If output_size is None, :math:`H_{out} = H^\prime_{out}, :math:`H_{out} = \ H^\prime_{out}, W_{out} = W^\prime_{out}`; else, the :math:`D_{out}` of the output size must between :math:`D^\prime_{out}` and :math:`D^\prime_{out} + strides[0]`, the :math:`H_{out}` of the output size must between :math:`H^\prime_{out}` and :math:`H^\prime_{out} + strides[1]`, and the :math:`W_{out}` of the output size must between :math:`W^\prime_{out}` and :math:`W^\prime_{out} + strides[2]`, - conv_transpose3d can compute the kernel size automatically. + conv3d_transpose can compute the kernel size automatically. Parameters: in_channels(int): The number of channels in the input image. @@ -1086,7 +1086,7 @@ class Conv3DTranspose(_ConvNd): else: output_padding = 0 - out = F.conv_transpose3d( + out = F.conv3d_transpose( x, self.weight, bias=self.bias, diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index a996844c8f5a86df5468bd29b55472002356afc0..5e2292d40d2bfbd83d6fe37f1b4ea03c82397c31 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -73,7 +73,6 @@ class _InstanceNormBase(layers.Layer): momentum=0.9, weight_attr=None, bias_attr=None, - track_running_stats=False, data_format="NCHW", name=None): super(_InstanceNormBase, self).__init__() @@ -135,9 +134,6 @@ class InstanceNorm1D(_InstanceNormBase): epsilon(float, optional): A value added to the denominator for numerical stability. Default is 1e-5. momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9. - track_running_stats(bool, optional): Whether to use global mean and - variance. In train mode, when setting track_running_stats True, the global mean - and variance are also used during train period. Default: False. weight_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale` of instance_norm. If it is set to None or one attribute of ParamAttr, instance_norm will create ParamAttr as weight_attr, the name of scale can be set in ParamAttr. @@ -159,9 +155,6 @@ class InstanceNorm1D(_InstanceNormBase): Returns: None. - **Note**: - Momentum and track_running_stats is not effective. The next version will fix the problem . - Examples: @@ -214,9 +207,6 @@ class InstanceNorm2D(_InstanceNormBase): epsilon(float, optional): A value added to the denominator for numerical stability. Default is 1e-5. momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9. - track_running_stats(bool, optional): Whether to use global mean and - variance. In train mode, when setting track_running_stats True, the global mean - and variance are also used during train period. Default: False. weight_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale` of instance_norm. If it is set to None or one attribute of ParamAttr, instance_norm will create ParamAttr as weight_attr, the name of scale can be set in ParamAttr. @@ -237,8 +227,6 @@ class InstanceNorm2D(_InstanceNormBase): Returns: None. - **Note**: - Momentum and track_running_stats is not effective. The next version will fix the problem . Examples: @@ -290,9 +278,6 @@ class InstanceNorm3D(_InstanceNormBase): epsilon(float, optional): A value added to the denominator for numerical stability. Default is 1e-5. momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9. - track_running_stats(bool, optional): Whether to use global mean and - variance. In train mode, when setting track_running_stats True, the global mean - and variance are also used during train period. Default: False. weight_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale` of instance_norm. If it is set to None or one attribute of ParamAttr, instance_norm will create ParamAttr as weight_attr, the name of scale can be set in ParamAttr. @@ -313,8 +298,6 @@ class InstanceNorm3D(_InstanceNormBase): Returns: None. - **Note**: - Momentum and track_running_stats is not effective. The next version will fix the problem . Examples: @@ -570,7 +553,6 @@ class _BatchNormBase(layers.Layer): weight_attr=None, bias_attr=None, data_format='NCHW', - track_running_stats=True, name=None): super(_BatchNormBase, self).__init__() self._num_features = num_features @@ -636,7 +618,6 @@ class _BatchNormBase(layers.Layer): self._momentum = momentum self._epsilon = epsilon self._fuse_with_relu = False - self._track_running_stats = track_running_stats self._name = name def _check_input_dim(self, input): @@ -651,11 +632,7 @@ class _BatchNormBase(layers.Layer): self._check_input_dim(input) - if not self.training and not self._track_running_stats: - raise ValueError( - 'When inference, expected track_running_stats is True.') - - if self.training and not self._track_running_stats: + if self.training: warnings.warn( "When training, we now always track global mean and variance.") @@ -720,9 +697,6 @@ class BatchNorm1D(_BatchNormBase): will create ParamAttr as bias_attr. If it is set to Fasle, the weight is not learnable. If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None. data_format(str, optional): Specify the input data format, may be "NC", "NCL" or "NLC". Defalut "NCL". - track_running_stats(bool, optional): Whether to use global mean and variance. In train period, - True will track global mean and variance used for inference. When inference, track_running_stats must be - True. Default: True. name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. Shape: @@ -732,9 +706,6 @@ class BatchNorm1D(_BatchNormBase): Returns: None. - - **Note**: - Now track_running_stats is actucal always true. The next version will fix the problem . Examples: @@ -817,9 +788,6 @@ class BatchNorm2D(_BatchNormBase): will create ParamAttr as bias_attr. If it is set to Fasle, the weight is not learnable. If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None. data_format(str, optional): Specify the input data format, the data format can be "NCHW" or "NHWC". Default: NCHW. - track_running_stats(bool, optional): Whether to use global mean and variance. In train period, - True will track global mean and variance used for inference. When inference, track_running_stats must be - True. Default: True. name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. Shape: @@ -830,9 +798,6 @@ class BatchNorm2D(_BatchNormBase): Returns: None - **Note**: - Now track_running_stats is actucal always true. The next version will fix the problem . - Examples: .. code-block:: python @@ -912,9 +877,6 @@ class BatchNorm3D(_BatchNormBase): will create ParamAttr as bias_attr. If it is set to Fasle, the weight is not learnable. If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None. data_format(str, optional): Specify the input data format, the data format can be "NCDHW" or "NDHWC. Default: NCDHW. - track_running_stats(bool, optional): Whether to use global mean and variance. In train period, - True will track global mean and variance used for inference. When inference, track_running_stats must be - True. Default: True. name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. Shape: @@ -925,9 +887,6 @@ class BatchNorm3D(_BatchNormBase): Returns: None - **Note**: - Now track_running_stats is actucal always true. The next version will fix the problem . - Examples: .. code-block:: python @@ -1024,8 +983,6 @@ class SyncBatchNorm(_BatchNormBase): will create ParamAttr as bias_attr. If the Initializer of the bias_attr is not set, the bias is initialized zero. If it is set to False, this layer will not have trainable bias parameter. Default: None. - track_running_stats(bool, optional): Whether to compute global stats, which including running mean and - running variance. Default: True. Shapes: input: Tensor that the dimension from 2 to 5. @@ -1055,11 +1012,10 @@ class SyncBatchNorm(_BatchNormBase): weight_attr=None, bias_attr=None, data_format='NCHW', - track_running_stats=True, name=None): super(SyncBatchNorm, self).__init__(num_features, momentum, epsilon, weight_attr, - bias_attr, data_format, track_running_stats, name) + bias_attr, data_format, name) def forward(self, x): # create output @@ -1147,10 +1103,10 @@ class SyncBatchNorm(_BatchNormBase): """ layer_output = layer if isinstance(layer, _BatchNormBase): - layer_output = SyncBatchNorm( - layer._num_features, layer._momentum, layer._epsilon, - layer._weight_attr, layer._bias_attr, layer._data_format, - layer._track_running_stats, layer._name) + layer_output = SyncBatchNorm(layer._num_features, layer._momentum, + layer._epsilon, layer._weight_attr, + layer._bias_attr, layer._data_format, + layer._name) if layer._weight_attr != False and layer._bias_attr != False: with no_grad(): diff --git a/python/paddle/nn/layer/pooling.py b/python/paddle/nn/layer/pooling.py index 9e544cb02e70e755eb5c0255e36267ef8ab12e6e..0b0a4909f8550efad93db8c5dc7037a99640f771 100755 --- a/python/paddle/nn/layer/pooling.py +++ b/python/paddle/nn/layer/pooling.py @@ -35,7 +35,7 @@ __all__ = [ class AvgPool1D(layers.Layer): """ This operation applies a 1D average pooling over an input signal composed - of several input planes, based on the input, output_size, return_indices parameters. + of several input planes, based on the input, output_size, return_mask parameters. Input(X) and output(Out) are in NCL format, where N is batch size, C is the number of channels, L is the length of the feature. The output tensor shape will be [N, C, output_size]. @@ -61,7 +61,7 @@ class AvgPool1D(layers.Layer): 4. A list[int] or tuple(int) whose length is 2. It has the form [pad_before, pad_after]. 5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension should be [0,0] or (0,0). The default value is 0. - count_include_pad (bool): Whether to exclude padding points in average pooling + exclusive (bool): Whether to exclude padding points in average pooling mode, default is `True`. ceil_mode (bool): ${ceil_mode_comment}Whether to use the ceil function to calculate output height and width. If it is set to False, the floor function will be used. The default value is False. @@ -103,7 +103,7 @@ class AvgPool1D(layers.Layer): kernel_size, stride=None, padding=0, - count_include_pad=True, + exclusive=True, ceil_mode=False, name=None): super(AvgPool1D, self).__init__() @@ -111,12 +111,12 @@ class AvgPool1D(layers.Layer): self.stride = stride self.padding = padding self.ceil_mode = ceil_mode - self.count_include_pad = count_include_pad + self.exclusive = exclusive self.name = name def forward(self, x): out = F.avg_pool1d(x, self.kernel_size, self.stride, self.padding, - self.count_include_pad, self.ceil_mode, self.name) + self.exclusive, self.ceil_mode, self.name) return out @@ -156,7 +156,7 @@ class AvgPool2D(layers.Layer): 5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension should be [0,0] or (0,0). The default value is 0. ceil_mode (bool): when True, will use `ceil` instead of `floor` to compute the output shape - count_include_pad (bool): Whether to exclude padding points in average pooling + exclusive (bool): Whether to exclude padding points in average pooling mode, default is `true`. divisor_override (float): if specified, it will be used as divisor, otherwise kernel_size will be used. Default None. data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NDHW"`. @@ -197,7 +197,7 @@ class AvgPool2D(layers.Layer): stride=None, padding=0, ceil_mode=False, - count_include_pad=True, + exclusive=True, divisor_override=None, data_format="NCHW", name=None): @@ -206,7 +206,7 @@ class AvgPool2D(layers.Layer): self.stride = stride self.padding = padding self.ceil_mode = ceil_mode - self.count_include_pad = count_include_pad + self.exclusive = exclusive self.divisor = divisor_override self.data_format = data_format self.name = name @@ -218,7 +218,7 @@ class AvgPool2D(layers.Layer): stride=self.stride, padding=self.padding, ceil_mode=self.ceil_mode, - count_include_pad=self.count_include_pad, + exclusive=self.exclusive, divisor_override=self.divisor, data_format=self.data_format, name=self.name) @@ -247,7 +247,7 @@ class AvgPool3D(layers.Layer): 5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension should be [0,0] or (0,0). The default value is 0. ceil_mode (bool): ${ceil_mode_comment} - count_include_pad (bool): Whether to exclude padding points in average pooling + exclusive (bool): Whether to exclude padding points in average pooling mode, default is True. divisor_override (int|float) if specified, it will be used as divisor, otherwise kernel_size will be used. Default None. data_format (string): The data format of the input and output data. An optional string from: `"NCDHW"`, `"NDHWC"`. @@ -289,7 +289,7 @@ class AvgPool3D(layers.Layer): stride, padding=0, ceil_mode=False, - count_include_pad=True, + exclusive=True, divisor_override=None, data_format="NCDHW", name=None): @@ -298,7 +298,7 @@ class AvgPool3D(layers.Layer): self.stride = stride self.padding = padding self.ceil_mode = ceil_mode - self.count_include_pad = count_include_pad + self.exclusive = exclusive self.divisor = divisor_override self.data_format = data_format self.name = name @@ -310,7 +310,7 @@ class AvgPool3D(layers.Layer): stride=self.stride, padding=self.padding, ceil_mode=self.ceil_mode, - count_include_pad=self.count_include_pad, + exclusive=self.exclusive, divisor_override=self.divisor, data_format=self.data_format, name=self.name) @@ -319,7 +319,7 @@ class AvgPool3D(layers.Layer): class MaxPool1D(layers.Layer): """ Applies a 1D max pooling over an input signal composed of several input planes based - on the input, output_size, return_indices parameters. + on the input, output_size, return_mask parameters. Input(X) and output(Out) are in NCL format, where N is batch size, C is the number of channels, L is the length of the feature. @@ -343,7 +343,7 @@ class MaxPool1D(layers.Layer): 4. A list[int] or tuple(int) whose length is 2. It has the form [pad_before, pad_after]. 5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension should be [0,0] or (0,0). The default value is 0. - return_indices (bool): Whether return the max indices along with the outputs. default is `False`. + return_mask (bool): Whether return the max indices along with the outputs. default is `False`. ceil_mode (bool): Whether to use the ceil function to calculate output height and width. False is the default. If it is set to False, the floor function will be used. Default False. name(str, optional): For detailed information, please refer @@ -377,7 +377,7 @@ class MaxPool1D(layers.Layer): pool_out = MaxPool1D(data) # pool_out shape: [1, 3, 16] - MaxPool1D = nn.MaxPool1D(kernel_size=2, stride=2, padding=0, return_indices=True) + MaxPool1D = nn.MaxPool1D(kernel_size=2, stride=2, padding=0, return_mask=True) pool_out, indices = MaxPool1D(data) # pool_out shape: [1, 3, 16], indices shape: [1, 3, 16] @@ -387,7 +387,7 @@ class MaxPool1D(layers.Layer): kernel_size, stride=None, padding=0, - return_indices=False, + return_mask=False, ceil_mode=False, name=None): super(MaxPool1D, self).__init__() @@ -395,12 +395,12 @@ class MaxPool1D(layers.Layer): self.stride = stride self.padding = padding self.ceil_mode = ceil_mode - self.return_indices = return_indices + self.return_mask = return_mask self.name = name def forward(self, input): out = F.max_pool1d(input, self.kernel_size, self.stride, self.padding, - self.return_indices, self.ceil_mode, self.name) + self.return_mask, self.ceil_mode, self.name) return out @@ -440,7 +440,7 @@ class MaxPool2D(layers.Layer): 5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension should be [0,0] or (0,0). The default value is 0. ceil_mode (bool): when True, will use `ceil` instead of `floor` to compute the output shape - return_indices (bool): Whether to return the max indices along with the outputs. + return_mask (bool): Whether to return the max indices along with the outputs. data_format (string): The data format of the input and output data. An optional string from: `"NCHW"`, `"NDHW"`. The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of: `[batch_size, input_channels, input_height, input_width]`. @@ -473,8 +473,8 @@ class MaxPool2D(layers.Layer): output = MaxPool2D(input) # output.shape [1, 3, 16, 16] - # for return_indices=True - MaxPool2D = nn.MaxPool2D(kernel_size=2,stride=2, padding=0, return_indices=True) + # for return_mask=True + MaxPool2D = nn.MaxPool2D(kernel_size=2, stride=2, padding=0, return_mask=True) output, max_indices = MaxPool2D(input) # output.shape [1, 3, 16, 16], max_indices.shape [1, 3, 16, 16], """ @@ -483,7 +483,7 @@ class MaxPool2D(layers.Layer): kernel_size, stride=None, padding=0, - return_indices=False, + return_mask=False, ceil_mode=False, data_format="NCHW", name=None): @@ -491,7 +491,7 @@ class MaxPool2D(layers.Layer): self.ksize = kernel_size self.stride = stride self.padding = padding - self.return_indices = return_indices + self.return_mask = return_mask self.ceil_mode = ceil_mode self.data_format = data_format self.name = name @@ -502,7 +502,7 @@ class MaxPool2D(layers.Layer): kernel_size=self.ksize, stride=self.stride, padding=self.padding, - return_indices=self.return_indices, + return_mask=self.return_mask, data_format=self.data_format, name=self.name) @@ -530,7 +530,7 @@ class MaxPool3D(layers.Layer): 5. A list or tuple of pairs of integers. It has the form [[pad_before, pad_after], [pad_before, pad_after], ...]. Note that, the batch dimension and channel dimension should be [0,0] or (0,0). The default value is 0. ceil_mode (bool): ${ceil_mode_comment} - return_indices (bool): Whether to return the max indices along with the outputs. + return_mask (bool): Whether to return the max indices along with the outputs. data_format (string): The data format of the input and output data. An optional string from: `"NCDHW"`, `"NDHWC"`. The default is `"NCDHW"`. When it is `"NCDHW"`, the data is stored in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`. @@ -564,8 +564,8 @@ class MaxPool3D(layers.Layer): output = MaxPool3D(input) # output.shape [1, 2, 3, 16, 16] - # for return_indices=True - MaxPool3D = nn.MaxPool3D(kernel_size=2,stride=2, padding=0, return_indices=True) + # for return_mask=True + MaxPool3D = nn.MaxPool3D(kernel_size=2, stride=2, padding=0, return_mask=True) output, max_indices = MaxPool3D(input) # output.shape [1, 2, 3, 16, 16], max_indices.shape [1, 2, 3, 16, 16], """ @@ -574,7 +574,7 @@ class MaxPool3D(layers.Layer): kernel_size, stride, padding, - return_indices=False, + return_mask=False, ceil_mode=False, data_format="NCDHW", name=None): @@ -582,7 +582,7 @@ class MaxPool3D(layers.Layer): self.ksize = kernel_size self.stride = stride self.padding = padding - self.return_indices = return_indices + self.return_mask = return_mask self.ceil_mode = ceil_mode self.data_format = data_format self.name = name @@ -593,7 +593,7 @@ class MaxPool3D(layers.Layer): kernel_size=self.ksize, stride=self.stride, padding=self.padding, - return_indices=self.return_indices, + return_mask=self.return_mask, data_format=self.data_format, name=self.name) @@ -602,7 +602,7 @@ class AdaptiveAvgPool1D(layers.Layer): """ This operation applies a 1D adaptive average pooling over an input signal composed - of several input planes, based on the input, output_size, return_indices parameters. + of several input planes, based on the input, output_size, return_mask parameters. Input(X) and output(Out) are in NCL format, where N is batch size, C is the number of channels, L is the length of the feature. The output tensor shape will be [N, C, output_size]. @@ -841,7 +841,7 @@ class AdaptiveMaxPool1D(layers.Layer): """ This operation applies a 1D adaptive max pooling over an input signal composed - of several input planes, based on the input, output_size, return_indices parameters. + of several input planes, based on the input, output_size, return_mask parameters. Input(X) and output(Out) are in NCL format, where N is batch size, C is the number of channels, L is the length of the feature. The output tensor shape will be [N, C, output_size]. @@ -859,7 +859,7 @@ class AdaptiveMaxPool1D(layers.Layer): Args: output_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, it must contain one int. - return_indices (bool): If true, the index of max pooling point will be returned along + return_mask (bool): If true, the index of max pooling point will be returned along with outputs. It cannot be set in average pooling type. Default False. name(str, optional): For detailed information, please refer to :ref:`api_guide_Name`. Usually name is no need to set and @@ -898,22 +898,22 @@ class AdaptiveMaxPool1D(layers.Layer): pool_out = AdaptiveMaxPool1D(data) # pool_out shape: [1, 3, 16] - # for return_indices = true - AdaptiveMaxPool1D = nn.AdaptiveMaxPool1D(output_size=16, return_indices=True) + # for return_mask = true + AdaptiveMaxPool1D = nn.AdaptiveMaxPool1D(output_size=16, return_mask=True) pool_out, indices = AdaptiveMaxPool1D(data) # pool_out shape: [1, 3, 16], indices shape: [1, 3, 16] """ - def __init__(self, output_size, return_indices=False, name=None): + def __init__(self, output_size, return_mask=False, name=None): super(AdaptiveMaxPool1D, self).__init__() self.output_size = output_size - self.return_indices = return_indices + self.return_mask = return_mask self.name = name def forward(self, input): - return F.adaptive_max_pool1d(input, self.output_size, - self.return_indices, self.name) + return F.adaptive_max_pool1d(input, self.output_size, self.return_mask, + self.name) class AdaptiveMaxPool2D(layers.Layer): @@ -932,7 +932,7 @@ class AdaptiveMaxPool2D(layers.Layer): Output(i ,j) &= max(Input[hstart:hend, wstart:wend]) Parameters: output_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, it must contain two element, (H, W). H and W can be either a int, or None which means the size will be the same as that of the input. - return_indices (bool): If true, the index of max pooling point will be returned along with outputs. It cannot be set in average pooling type. Default False. + return_mask (bool): If true, the index of max pooling point will be returned along with outputs. It cannot be set in average pooling type. Default False. name(str, optional): For detailed information, please refer to :ref:`api_guide_Name`. Usually name is no need to set and None by default. @@ -965,21 +965,21 @@ class AdaptiveMaxPool2D(layers.Layer): paddle.disable_static() input_data = np.random.rand(2, 3, 32, 32) x = paddle.to_tensor(input_data) - adaptive_max_pool = paddle.nn.AdaptiveMaxPool2D(output_size=3, return_indices=True) + adaptive_max_pool = paddle.nn.AdaptiveMaxPool2D(output_size=3, return_mask=True) pool_out, indices = adaptive_max_pool(x = x) """ - def __init__(self, output_size, return_indices=False, name=None): + def __init__(self, output_size, return_mask=False, name=None): super(AdaptiveMaxPool2D, self).__init__() self._output_size = output_size - self._return_indices = return_indices + self._return_mask = return_mask self._name = name def forward(self, x): return F.adaptive_max_pool2d( x, output_size=self._output_size, - return_indices=self._return_indices, + return_mask=self._return_mask, name=self._name) @@ -1002,7 +1002,7 @@ class AdaptiveMaxPool3D(layers.Layer): Parameters: output_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, it must contain three elements, (D, H, W). D, H and W can be either a int, or None which means the size will be the same as that of the input. - return_indices (bool): If true, the index of max pooling point will be returned along with outputs. Default False. + return_mask (bool): If true, the index of max pooling point will be returned along with outputs. Default False. name(str, optional): For detailed information, please refer to :ref:`api_guide_Name`. Usually name is no need to set and None by default. @@ -1040,21 +1040,21 @@ class AdaptiveMaxPool3D(layers.Layer): pool = paddle.nn.AdaptiveMaxPool3D(output_size=4) out = pool(x) # out shape: [2, 3, 4, 4, 4] - pool = paddle.nn.AdaptiveMaxPool3D(output_size=3, return_indices=True) + pool = paddle.nn.AdaptiveMaxPool3D(output_size=3, return_mask=True) out, indices = pool(x) # out shape: [2, 3, 4, 4, 4], indices shape: [2, 3, 4, 4, 4] """ - def __init__(self, output_size, return_indices=False, name=None): + def __init__(self, output_size, return_mask=False, name=None): super(AdaptiveMaxPool3D, self).__init__() self._output_size = output_size - self._return_indices = return_indices + self._return_mask = return_mask self._name = name def forward(self, x): return F.adaptive_max_pool3d( x, output_size=self._output_size, - return_indices=self._return_indices, + return_mask=self._return_mask, name=self._name) diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 958bfb304fb149c36454163ece08f365b1021981..eaade222388fa335f0fbc9d572b22bc517cdae2c 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -66,8 +66,6 @@ from .logic import logical_not #DEFINE_ALIAS from .logic import logical_or #DEFINE_ALIAS from .logic import logical_xor #DEFINE_ALIAS from .logic import not_equal #DEFINE_ALIAS -# from .logic import reduce_all #DEFINE_ALIAS -# from .logic import reduce_any #DEFINE_ALIAS from .logic import allclose #DEFINE_ALIAS from .logic import equal_all #DEFINE_ALIAS # from .logic import isnan #DEFINE_ALIAS @@ -164,6 +162,8 @@ from .math import isfinite #DEFINE_ALIAS from .math import isinf #DEFINE_ALIAS from .math import isnan #DEFINE_ALIAS from .math import prod #DEFINE_ALIAS +from .math import all #DEFINE_ALIAS +from .math import any #DEFINE_ALIAS from .random import multinomial #DEFINE_ALIAS from .random import standard_normal from .random import normal diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py index 27671a4f157475e04250ba56fd378103f7d0d829..da08270d742e54de554c109987e84c847263d38d 100644 --- a/python/paddle/tensor/logic.py +++ b/python/paddle/tensor/logic.py @@ -29,6 +29,8 @@ from ..fluid.layers import logical_and #DEFINE_ALIAS from ..fluid.layers import logical_not #DEFINE_ALIAS from ..fluid.layers import logical_or #DEFINE_ALIAS from ..fluid.layers import logical_xor #DEFINE_ALIAS +from ..fluid.layers import reduce_all #DEFINE_ALIAS +from ..fluid.layers import reduce_any #DEFINE_ALIAS __all__ = [ 'equal', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 895d0c175905cc8071c379367b4cbadb43811b06..36793e0769672250e40510b89a9813e76e73ee9e 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -21,7 +21,7 @@ from paddle.common_ops_import import * from paddle.tensor import cast import paddle from ..fluid import layers -from ..fluid.framework import core, _varbase_creator, in_dygraph_mode, Variable +from ..fluid.framework import core, _varbase_creator, in_dygraph_mode, Variable, convert_np_dtype_to_dtype_ from ..fluid.layer_helper import LayerHelper from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype from ..fluid.layers.layer_function_generator import _generate_doc_string_, generate_activation_fn, generate_layer_fn @@ -46,6 +46,8 @@ from ..fluid.layers import exp #DEFINE_ALIAS from ..fluid.layers import floor #DEFINE_ALIAS from ..fluid.layers import log #DEFINE_ALIAS from ..fluid.layers import reciprocal #DEFINE_ALIAS +from ..fluid.layers import reduce_all #DEFINE_ALIAS +from ..fluid.layers import reduce_any #DEFINE_ALIAS # from ..fluid.layers import reduce_max #DEFINE_ALIAS # from ..fluid.layers import reduce_min #DEFINE_ALIAS # from ..fluid.layers import reduce_prod #DEFINE_ALIAS @@ -1933,3 +1935,201 @@ def increment(x, value=1.0, name=None): outputs={'Out': [x]}, attrs={'step': float(value)}) return x + + +def all(x, axis=None, keepdim=False, name=None): + """ + Computes the the ``logical and`` of tensor elements over the given dimension. + + Args: + x (Tensor): An N-D Tensor, the input data type should be `bool`. + axis (int|list|tuple, optional): The dimensions along which the ``logical and`` is compute. If + :attr:`None`, and all elements of :attr:`x` and return a + Tensor variable with a single element, otherwise must be in the + range :math:`[-rank(x), rank(x))`. If :math:`axis[i] < 0`, + the dimension to reduce is :math:`rank + axis[i]`. + keepdim (bool, optional): Whether to reserve the reduced dimension in the + output Tensor. The result Tensor will have one fewer dimension + than the :attr:`x` unless :attr:`keepdim` is true, default + value is False. + name (str, optional): The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name` + + Returns: + Tensor: Results the ``logical and`` on the specified axis of input Tensor `x`, it's data type is bool. + + Raises: + ValueError: If the data type of `x` is not bool. + TypeError: The type of :attr:`axis` must be int, list or tuple. + + Examples: + .. code-block:: python + + import paddle + import paddle.fluid as fluid + import paddle.fluid.layers as layers + import numpy as np + + # set as static mode + paddle.disable_static() + + # x is a bool Tensor variable with following elements: + # [[True, False] + # [True, True]] + x = layers.assign(np.array([[1, 0], [1, 1]], dtype='int32')) + print(x) + x = layers.cast(x, 'bool') + + # out1 should be [False] + out1 = paddle.all(x) # [False] + print(out1) + + # out2 should be [True, False] + out2 = paddle.all(x, axis=0) # [True, False] + print(out2) + + # keep_dim=False, out3 should be [False, True], out.shape should be (2,) + out3 = paddle.all(x, axis=-1) # [False, True] + print(out3) + + # keep_dim=True, out4 should be [[False], [True]], out.shape should be (2,1) + out4 = paddle.all(x, axis=1, keep_dim=True) + out4 = layers.cast(out4, 'int32') # [[False], [True]] + print(out4) + + """ + if axis is not None and not isinstance(axis, (list, tuple)): + axis = [axis] + + if not axis: + reduce_all_flag = True + else: + if len(axis) == len(x.shape): + reduce_all_flag = True + else: + reduce_all_flag = False + + attrs = { + 'dim': axis if axis != None and axis != [] and axis != () else [0], + 'keep_dim': keepdim, + 'reduce_all': reduce_all_flag + } + dtype_flag = False + + + if in_dygraph_mode(): + axis = axis if axis != None and axis != [] else [0] + return core.ops.reduce_all(x, 'dim', axis, 'keep_dim', keepdim, + 'reduce_all', reduce_all_flag) + check_variable_and_dtype(x, 'x', ['bool'], 'all') + + + check_type(axis, 'axis', (int, list, tuple, type(None)), 'all') + + helper = LayerHelper('all', **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='reduce_all', + inputs={'X': x}, + outputs={'Out': out}, + attrs=attrs) + return out + + +def any(x, axis=None, keepdim=False, name=None): + """ + Computes the the ``logical or`` of tensor elements over the given dimension. + + Args: + x (Tensor): An N-D Tensor, the input data type should be `bool`. + axis (int|list|tuple, optional): The dimensions along which the ``logical or`` is compute. If + :attr:`None`, and all elements of :attr:`x` and return a + Tensor variable with a single element, otherwise must be in the + range :math:`[-rank(x), rank(x))`. If :math:`axis[i] < 0`, + the dimension to reduce is :math:`rank + axis[i]`. + keepdim (bool, optional): Whether to reserve the reduced dimension in the + output Tensor. The result Tensor will have one fewer dimension + than the :attr:`x` unless :attr:`keepdim` is true, default + value is False. + name (str, optional): The default value is None. Normally there is no need for + user to set this property. For more information, please refer to :ref:`api_guide_Name` + + Returns: + Tensor: Results the ``logical or`` on the specified axis of input Tensor `x`, it's data type is bool. + + Raises: + ValueError: If the data type of `x` is not bool. + TypeError: The type of :attr:`axis` must be int, list or tuple. + + Examples: + .. code-block:: python + + import paddle + import paddle.fluid as fluid + import paddle.fluid.layers as layers + import numpy as np + + # set as static mode + paddle.disable_static() + + # x is a bool Tensor variable with following elements: + # [[True, False] + # [False, False]] + x = layers.assign(np.array([[1, 0], [1, 1]], dtype='int32')) + print(x) + x = layers.cast(x, 'bool') + + # out1 should be [True] + out1 = paddle.any(x) # [True] + print(out1) + + # out2 should be [True, False] + out2 = paddle.any(x, axis=0) # [True, False] + print(out2) + + # keep_dim=False, out3 should be [True, False], out.shape should be (2,) + out3 = paddle.any(x, axis=-1) # [True, False] + print(out3) + + # keep_dim=True, result should be [[True], [False]], out.shape should be (2,1) + out4 = paddle.any(x, axis=1, keep_dim=True) + out4 = layers.cast(out4, 'int32') # [[True], [False]] + print(out4) + + """ + if axis is not None and not isinstance(axis, (list, tuple)): + axis = [axis] + + if not axis: + reduce_all_flag = True + else: + if len(axis) == len(x.shape): + reduce_all_flag = True + else: + reduce_all_flag = False + + attrs = { + 'dim': axis if axis != None and axis != [] and axis != () else [0], + 'keep_dim': keepdim, + 'reduce_all': reduce_all_flag + } + dtype_flag = False + + + if in_dygraph_mode(): + axis = axis if axis != None and axis != [] else [0] + return core.ops.reduce_any(x, 'dim', axis, 'keep_dim', keepdim, + 'reduce_all', reduce_all_flag) + check_variable_and_dtype(x, 'x', ['bool'], 'any') + + + check_type(axis, 'axis', (int, list, tuple, type(None)), 'any') + + helper = LayerHelper('any', **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='reduce_any', + inputs={'X': x}, + outputs={'Out': out}, + attrs=attrs) + return out