diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index d23f20e1b3d4b917a7618aa36a5efe4ac734a22d..dc6c37925efaf3544e012f341505cc56d25eacb0 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4397,12 +4397,9 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None): return out +@deprecated(since="2.0.0", update_to="paddle.mean") def reduce_mean(input, dim=None, keep_dim=False, name=None): """ - :alias_main: paddle.reduce_mean - :alias: paddle.reduce_mean,paddle.tensor.reduce_mean,paddle.tensor.stat.reduce_mean - :old_api: paddle.fluid.layers.reduce_mean - Computes the mean of the input tensor's elements along the given dimension. Args: @@ -4451,31 +4448,7 @@ def reduce_mean(input, dim=None, keep_dim=False, name=None): fluid.layers.reduce_mean(y, dim=[0, 1]) # [4.0, 5.0] """ - if dim is not None and not isinstance(dim, list): - dim = [dim] - - if in_dygraph_mode(): - reduce_all = True if dim == None or dim == [] or len(dim) == len( - input.shape) else False - dim = dim if dim != None and dim != [] else [0] - return core.ops.reduce_mean(input, 'dim', dim, 'keep_dim', keep_dim, - 'reduce_all', reduce_all) - attrs = { - 'dim': dim if dim != None and dim != [] else [0], - 'keep_dim': keep_dim, - 'reduce_all': True - if dim == None or dim == [] or len(dim) == len(input.shape) else False - } - check_variable_and_dtype( - input, 'input', ['float32', 'float64', 'int32', 'int64'], 'reduce_mean') - helper = LayerHelper('reduce_mean', **locals()) - out = helper.create_variable_for_type_inference(dtype=helper.input_dtype()) - helper.append_op( - type='reduce_mean', - inputs={'X': input}, - outputs={'Out': out}, - attrs=attrs) - return out + return paddle.mean(x=input, axis=dim, keepdim=keep_dim, name=name) def reduce_max(input, dim=None, keep_dim=False, name=None): @@ -12331,6 +12304,7 @@ def mean(x, name=None): name='data', shape=[2, 3], dtype='float32') mean = fluid.layers.mean(input) """ + if in_dygraph_mode(): return core.ops.mean(x) diff --git a/python/paddle/fluid/tests/unittests/test_mean_op.py b/python/paddle/fluid/tests/unittests/test_mean_op.py index f3abd1acce6fafb8d187bfbe82765f982acae010..3799640b98800f660e72e3c8b4580949d5deb12a 100644 --- a/python/paddle/fluid/tests/unittests/test_mean_op.py +++ b/python/paddle/fluid/tests/unittests/test_mean_op.py @@ -17,6 +17,7 @@ from __future__ import print_function import unittest import numpy as np from op_test import OpTest +import paddle import paddle.fluid.core as core import paddle.fluid as fluid from paddle.fluid import Program, program_guard @@ -73,5 +74,61 @@ class TestFP16MeanOp(TestMeanOp): place, ['X'], 'Out', max_relative_error=0.8) +class TestMeanAPI(unittest.TestCase): + """ + test paddle.tensor.stat.mean + """ + + def setUp(self): + self.x_shape = [2, 3, 4, 5] + self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32) + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def test_api_static(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data('X', self.x_shape) + out1 = paddle.mean(x) + out2 = paddle.tensor.mean(x) + out3 = paddle.tensor.stat.mean(x) + axis = np.arange(len(self.x_shape)).tolist() + out4 = paddle.mean(x, axis) + out5 = paddle.mean(x, tuple(axis)) + + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x}, + fetch_list=[out1, out2, out3, out4, out5]) + out_ref = np.mean(self.x) + for out in res: + self.assertEqual(np.allclose(out, out_ref), True) + + def test_api_imperative(self): + def test_case(x, axis=None, keepdim=False): + x_tensor = paddle.to_variable(x) + out = paddle.mean(x_tensor, axis, keepdim) + if isinstance(axis, list): + axis = tuple(axis) + if len(axis) == 0: + axis = None + out_ref = np.mean(x, axis, keepdims=keepdim) + self.assertEqual(np.allclose(out.numpy(), out_ref), True) + + paddle.disable_static(self.place) + test_case(self.x) + test_case(self.x, []) + test_case(self.x, -1) + test_case(self.x, keepdim=True) + test_case(self.x, 2, keepdim=True) + test_case(self.x, [0, 2]) + test_case(self.x, (0, 2)) + test_case(self.x, [0, 1, 2, 3]) + paddle.enable_static() + + def test_errors(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data('X', [10, 12], 'int8') + self.assertRaises(TypeError, paddle.mean, x) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py index 82cfe0a2423f3d5eef2a1f214917349b332367fb..0531da2b06ec37fd60389cd2abb85822ebc9d0f9 100644 --- a/python/paddle/fluid/tests/unittests/test_reduce_op.py +++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py @@ -628,5 +628,24 @@ class API_TestSumOp(unittest.TestCase): self.assertEqual((np_z == z_expected).all(), True) +class API_TestReduceMeanOp(unittest.TestCase): + def test_static(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = fluid.data("x", shape=[10, 10], dtype="float32") + out = fluid.layers.reduce_mean(input=x, dim=1) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + x_np = np.random.rand(10, 10).astype(np.float32) + res = exe.run(feed={"x": x_np}, fetch_list=[out]) + self.assertEqual(np.allclose(res[0], np.mean(x_np, axis=1)), True) + + def test_dygraph(self): + with fluid.dygraph.guard(): + x_np = np.random.rand(10, 10).astype(np.float32) + x = fluid.dygraph.to_variable(x_np) + out = fluid.layers.reduce_mean(input=x, dim=1) + self.assertEqual(np.allclose(out.numpy(), np.mean(x_np, axis=1)), True) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 9b3bb081d9776a7ef88245f76d564c7b107ca669..7d22a0be5b0a9a2088f22535c6e2e56f7dc1f959 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -13,17 +13,99 @@ # limitations under the License. # TODO: define statistical functions of a tensor -from ..fluid.layers import mean #DEFINE_ALIAS from ..fluid.layers import reduce_mean #DEFINE_ALIAS __all__ = ['mean', 'reduce_mean', 'std', 'var'] import numpy as np from ..fluid.layer_helper import LayerHelper -from ..fluid.framework import in_dygraph_mode +from ..fluid.framework import core, in_dygraph_mode from ..fluid import layers from .search import where from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype +import paddle + + +def mean(x, axis=None, keepdim=False, name=None): + """ + Computes the mean of the input tensor's elements along ``axis``. + + Args: + x (Tensor): The input Tensor with data type float32, float64, int32, + int64. + axis (int|list|tuple, optional): The axis along which to perform mean + calculations. ``axis`` should be int, list(int) or tuple(int). If + ``axis`` is a list/tuple of dimension(s), mean is calculated along + all element(s) of ``axis`` . ``axis`` or element(s) of ``axis`` + should be in range [-D, D), where D is the dimensions of ``x`` . If + ``axis`` or element(s) of ``axis`` is less than 0, it works the + same way as :math:`axis + D` . If ``axis`` is None, mean is + calculated along all elements of ``x``. Default is None. + keepdim (bool, optional): Whether to reserve the reduced dimension(s) + in the output Tensor. If ``keep_dim`` is True, the dimensions of + the output Tensor is the same as ``x`` except in the reduced + dimensions(it is of size 1 in this case). Otherwise, the shape of + the output Tensor is squeezed in ``axis`` . Default is False. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, results of average along ``axis`` of ``x``, with the same data + type as ``x``. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + + x = np.array([[[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12]], + [[13, 14, 15, 16], + [17, 18, 19, 20], + [21, 22, 23, 24]]], 'float32') + x = paddle.to_variable(x) + out1 = paddle.mean(x) + # [12.5] + out2 = paddle.mean(x, axis=-1) + # [[ 2.5 6.5 10.5] + # [14.5 18.5 22.5]] + out3 = paddle.mean(x, axis=-1, keepdim=True) + # [[[ 2.5] + # [ 6.5] + # [10.5]] + # [[14.5] + # [18.5] + # [22.5]]] + out4 = paddle.mean(x, axis=[0, 2]) + # [ 8.5 12.5 16.5] + """ + + if isinstance(axis, int): + axis = [axis] + reduce_all = True if axis is None \ + or len(axis)==0 \ + or len(axis) == len(x.shape) else False + if axis is None or len(axis) == 0: + axis = [0] + + if in_dygraph_mode(): + return core.ops.reduce_mean(x, 'dim', axis, 'keep_dim', keepdim, + 'reduce_all', reduce_all) + + check_variable_and_dtype(x, 'x/input', + ['float32', 'float64', 'int32', 'int64'], + 'mean/reduce_mean') + + helper = LayerHelper('mean', **locals()) + attrs = {'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all} + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op( + type='reduce_mean', inputs={'X': x}, outputs={'Out': out}, attrs=attrs) + return out def var(input, axis=None, keepdim=False, unbiased=True, out=None, name=None):