Unverified commit 57e12429, authored by zhupengyang, committed by GitHub

var, std: input->x, adjust attr order, remove out, add docs (#26446)

Parent e966d0b6
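In short, the commit renames the first parameter of paddle.var/paddle.std from `input` to `x`, moves `unbiased` ahead of `keepdim`, drops the `out` parameter, and rewrites the docstrings. A minimal before/after sketch of a call site (illustrative only; it assumes the post-commit dygraph API shown in the hunks below):

    import numpy as np
    import paddle

    paddle.disable_static()
    x = paddle.to_tensor(np.array([[1.0, 2.0], [3.0, 4.0]], dtype='float32'))

    # Before this commit:
    #   paddle.var(input=x, axis=1, keepdim=False, unbiased=True, out=None)
    # After this commit: the first argument is `x`, `unbiased` precedes
    # `keepdim`, and `out` is removed.
    out = paddle.var(x, axis=1, unbiased=True, keepdim=False)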
@@ -12301,13 +12301,10 @@ def clip_by_norm(x, max_norm, name=None):
     return out
 
 
+@deprecated(since="2.0.0", update_to="paddle.mean")
 @templatedoc()
 def mean(x, name=None):
     """
-    :alias_main: paddle.mean
-    :alias: paddle.mean,paddle.tensor.mean,paddle.tensor.stat.mean
-    :old_api: paddle.fluid.layers.mean
-
     ${comment}
 
     Args:
......
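The hunk above only attaches a `@deprecated` marker to `fluid.layers.mean`. For readers unfamiliar with the pattern, here is a generic sketch of what such a decorator typically does (warn once, then delegate); this is an illustration of the idea, not Paddle's actual implementation:

    import functools
    import warnings

    def deprecated(since, update_to):
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                # Emit a DeprecationWarning pointing at the caller, then
                # forward the call unchanged to the wrapped function.
                warnings.warn(
                    "{} is deprecated since {}; use {} instead.".format(
                        fn.__name__, since, update_to),
                    DeprecationWarning, stacklevel=2)
                return fn(*args, **kwargs)
            return wrapper
        return decorator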
@@ -15,65 +15,104 @@
 import unittest
 import numpy as np
 import paddle
-import paddle.fluid as fluid
-
-
-class TestStdLayer(unittest.TestCase):
-    def setUp(self):
-        self._dtype = "float64"
-        self._input = np.random.random([2, 3, 4, 5]).astype(self._dtype)
-
-    def static(self, axis=None, keepdim=False, unbiased=True):
-        prog = fluid.Program()
-        with fluid.program_guard(prog):
-            data = fluid.data(
-                name="data", dtype=self._dtype, shape=[None, 3, 4, 5])
-            out = prog.current_block().create_var(
-                dtype=self._dtype, shape=[2, 3, 4, 5])
-            paddle.std(input=data,
-                       axis=axis,
-                       keepdim=keepdim,
-                       unbiased=unbiased,
-                       out=out)
-
-        exe = fluid.Executor(self._place)
-        return exe.run(feed={"data": self._input},
-                       program=prog,
-                       fetch_list=[out])[0]
-
-    def dynamic(self, axis=None, keepdim=False, unbiased=True):
-        with fluid.dygraph.guard(self._place):
-            data = fluid.dygraph.to_variable(self._input)
-            out = paddle.std(input=data,
-                             axis=axis,
-                             keepdim=keepdim,
-                             unbiased=unbiased)
-            return out.numpy()
-
-    def numpy(self, axis=None, keepdim=False, unbiased=True):
-        ddof = 1 if unbiased else 0
-        axis = tuple(axis) if isinstance(axis, list) else axis
-        return np.std(self._input, axis=axis, keepdims=keepdim, ddof=ddof)
-
-    def test_equal(self):
-        places = []
-        if fluid.core.is_compiled_with_cuda():
-            places.append(fluid.CUDAPlace(0))
-        for place in places:
-            self._place = place
-            self.assertTrue(np.allclose(self.numpy(), self.static()))
-            self.assertTrue(
-                np.allclose(
-                    self.numpy(axis=[0, 2]), self.dynamic(axis=[0, 2])))
-            self.assertTrue(
-                np.allclose(
-                    self.numpy(
-                        axis=[1, 3], keepdim=True),
-                    self.dynamic(
-                        axis=[1, 3], keepdim=True)))
-            self.assertTrue(
-                np.allclose(
-                    self.numpy(unbiased=False), self.dynamic(unbiased=False)))
+
+
+def ref_std(x, axis=None, unbiased=True, keepdim=False):
+    ddof = 1 if unbiased else 0
+    if isinstance(axis, int):
+        axis = (axis, )
+    if axis is not None:
+        axis = tuple(axis)
+    return np.std(x, axis=axis, ddof=ddof, keepdims=keepdim)
+
+
+class TestStdAPI(unittest.TestCase):
+    def setUp(self):
+        self.dtype = 'float64'
+        self.shape = [1, 3, 4, 10]
+        self.axis = [1, 3]
+        self.keepdim = False
+        self.unbiased = True
+        self.set_attrs()
+        self.x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
+        self.place = paddle.CUDAPlace(0) \
+            if paddle.fluid.core.is_compiled_with_cuda() \
+            else paddle.CPUPlace()
+
+    def set_attrs(self):
+        pass
+
+    def static(self):
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.data('X', self.shape, self.dtype)
+            out = paddle.std(x, self.axis, self.unbiased, self.keepdim)
+            exe = paddle.static.Executor(self.place)
+            res = exe.run(feed={'X': self.x}, fetch_list=[out])
+        return res[0]
+
+    def dygraph(self):
+        paddle.disable_static()
+        x = paddle.to_tensor(self.x)
+        out = paddle.std(x, self.axis, self.unbiased, self.keepdim)
+        paddle.enable_static()
+        return out.numpy()
+
+    def test_api(self):
+        out_ref = ref_std(self.x, self.axis, self.unbiased, self.keepdim)
+        out_dygraph = self.dygraph()
+        out_static = self.static()
+        for out in [out_dygraph, out_static]:
+            self.assertTrue(np.allclose(out_ref, out))
+            self.assertTrue(np.equal(out_ref.shape, out.shape).all())
+
+
+class TestStdAPI_dtype(TestStdAPI):
+    def set_attrs(self):
+        self.dtype = 'float32'
+
+
+class TestStdAPI_axis_int(TestStdAPI):
+    def set_attrs(self):
+        self.axis = 2
+
+
+class TestStdAPI_axis_list(TestStdAPI):
+    def set_attrs(self):
+        self.axis = [1, 2]
+
+
+class TestStdAPI_axis_tuple(TestStdAPI):
+    def set_attrs(self):
+        self.axis = (1, 3)
+
+
+class TestStdAPI_keepdim(TestStdAPI):
+    def set_attrs(self):
+        self.keepdim = False
+
+
+class TestStdAPI_unbiased(TestStdAPI):
+    def set_attrs(self):
+        self.unbiased = False
+
+
+class TestStdAPI_alias(unittest.TestCase):
+    def test_alias(self):
+        paddle.disable_static()
+        x = paddle.to_tensor(np.array([10, 12], 'float32'))
+        out1 = paddle.std(x).numpy()
+        out2 = paddle.tensor.std(x).numpy()
+        out3 = paddle.tensor.stat.std(x).numpy()
+        self.assertTrue(np.allclose(out1, out2))
+        self.assertTrue(np.allclose(out1, out3))
+        paddle.enable_static()
+
+
+class TestStdError(unittest.TestCase):
+    def test_error(self):
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.data('X', [2, 3, 4], 'int32')
+            self.assertRaises(TypeError, paddle.std, x)
+
+
 if __name__ == '__main__':
......
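The `ref_std` helper above maps `unbiased=True` onto NumPy's `ddof=1`. A quick self-contained check of that correspondence (pure NumPy; nothing here is Paddle-specific):

    import numpy as np

    x = np.array([1.0, 2.0, 3.0, 4.0])
    u = x.mean()
    # unbiased: divide the summed squared deviations by N - 1
    assert np.isclose(np.std(x, ddof=1),
                      np.sqrt(((x - u)**2).sum() / (x.size - 1)))
    # biased: divide by N
    assert np.isclose(np.std(x, ddof=0),
                      np.sqrt(((x - u)**2).sum() / x.size))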
@@ -15,65 +15,104 @@
 import unittest
 import numpy as np
 import paddle
-import paddle.fluid as fluid
-
-
-class TestVarianceLayer(unittest.TestCase):
-    def setUp(self):
-        self._dtype = "float64"
-        self._input = np.random.random([2, 3, 4, 5]).astype(self._dtype)
-
-    def static(self, axis=None, keepdim=False, unbiased=True):
-        prog = fluid.Program()
-        with fluid.program_guard(prog):
-            data = fluid.data(
-                name="data", dtype=self._dtype, shape=[None, 3, 4, 5])
-            out = prog.current_block().create_var(
-                dtype=self._dtype, shape=[2, 3, 4, 5])
-            paddle.var(input=data,
-                       axis=axis,
-                       keepdim=keepdim,
-                       unbiased=unbiased,
-                       out=out)
-
-        exe = fluid.Executor(self._place)
-        return exe.run(feed={"data": self._input},
-                       program=prog,
-                       fetch_list=[out])[0]
-
-    def dynamic(self, axis=None, keepdim=False, unbiased=True):
-        with fluid.dygraph.guard(self._place):
-            data = fluid.dygraph.to_variable(self._input)
-            out = paddle.var(input=data,
-                             axis=axis,
-                             keepdim=keepdim,
-                             unbiased=unbiased)
-            return out.numpy()
-
-    def numpy(self, axis=None, keepdim=False, unbiased=True):
-        ddof = 1 if unbiased else 0
-        axis = tuple(axis) if isinstance(axis, list) else axis
-        return np.var(self._input, axis=axis, keepdims=keepdim, ddof=ddof)
-
-    def test_equal(self):
-        places = [fluid.CPUPlace()]
-        if fluid.core.is_compiled_with_cuda():
-            places.append(fluid.CUDAPlace(0))
-        for place in places:
-            self._place = place
-            self.assertTrue(np.allclose(self.numpy(), self.static()))
-            self.assertTrue(
-                np.allclose(
-                    self.numpy(axis=[0, 2]), self.dynamic(axis=[0, 2])))
-            self.assertTrue(
-                np.allclose(
-                    self.numpy(
-                        axis=[1, 3], keepdim=True),
-                    self.dynamic(
-                        axis=[1, 3], keepdim=True)))
-            self.assertTrue(
-                np.allclose(
-                    self.numpy(unbiased=False), self.dynamic(unbiased=False)))
+
+
+def ref_var(x, axis=None, unbiased=True, keepdim=False):
+    ddof = 1 if unbiased else 0
+    if isinstance(axis, int):
+        axis = (axis, )
+    if axis is not None:
+        axis = tuple(axis)
+    return np.var(x, axis=axis, ddof=ddof, keepdims=keepdim)
+
+
+class TestVarAPI(unittest.TestCase):
+    def setUp(self):
+        self.dtype = 'float64'
+        self.shape = [1, 3, 4, 10]
+        self.axis = [1, 3]
+        self.keepdim = False
+        self.unbiased = True
+        self.set_attrs()
+        self.x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
+        self.place = paddle.CUDAPlace(0) \
+            if paddle.fluid.core.is_compiled_with_cuda() \
+            else paddle.CPUPlace()
+
+    def set_attrs(self):
+        pass
+
+    def static(self):
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.data('X', self.shape, self.dtype)
+            out = paddle.var(x, self.axis, self.unbiased, self.keepdim)
+            exe = paddle.static.Executor(self.place)
+            res = exe.run(feed={'X': self.x}, fetch_list=[out])
+        return res[0]
+
+    def dygraph(self):
+        paddle.disable_static()
+        x = paddle.to_tensor(self.x)
+        out = paddle.var(x, self.axis, self.unbiased, self.keepdim)
+        paddle.enable_static()
+        return out.numpy()
+
+    def test_api(self):
+        out_ref = ref_var(self.x, self.axis, self.unbiased, self.keepdim)
+        out_dygraph = self.dygraph()
+        out_static = self.static()
+        for out in [out_dygraph, out_static]:
+            self.assertTrue(np.allclose(out_ref, out))
+            self.assertTrue(np.equal(out_ref.shape, out.shape).all())
+
+
+class TestVarAPI_dtype(TestVarAPI):
+    def set_attrs(self):
+        self.dtype = 'float32'
+
+
+class TestVarAPI_axis_int(TestVarAPI):
+    def set_attrs(self):
+        self.axis = 2
+
+
+class TestVarAPI_axis_list(TestVarAPI):
+    def set_attrs(self):
+        self.axis = [1, 2]
+
+
+class TestVarAPI_axis_tuple(TestVarAPI):
+    def set_attrs(self):
+        self.axis = (1, 3)
+
+
+class TestVarAPI_keepdim(TestVarAPI):
+    def set_attrs(self):
+        self.keepdim = False
+
+
+class TestVarAPI_unbiased(TestVarAPI):
+    def set_attrs(self):
+        self.unbiased = False
+
+
+class TestVarAPI_alias(unittest.TestCase):
+    def test_alias(self):
+        paddle.disable_static()
+        x = paddle.to_tensor(np.array([10, 12], 'float32'))
+        out1 = paddle.var(x).numpy()
+        out2 = paddle.tensor.var(x).numpy()
+        out3 = paddle.tensor.stat.var(x).numpy()
+        self.assertTrue(np.allclose(out1, out2))
+        self.assertTrue(np.allclose(out1, out3))
+        paddle.enable_static()
+
+
+class TestVarError(unittest.TestCase):
+    def test_error(self):
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.data('X', [2, 3, 4], 'int32')
+            self.assertRaises(TypeError, paddle.var, x)
+
+
 if __name__ == '__main__':
......
@@ -361,14 +361,14 @@ def ones_like(x, dtype=None, name=None):
     Examples:
         .. code-block:: python
 
            import paddle
            import numpy as np
 
           paddle.disable_static()
 
           x = paddle.to_tensor(np.array([1,2,3], dtype='float32'))
           out1 = paddle.ones_like(x)  # [1., 1., 1.]
           out2 = paddle.ones_like(x, dtype='int32')  # [1, 1, 1]
 
     """
     return full_like(x=x, fill_value=1, dtype=dtype, name=name)
@@ -451,14 +451,14 @@ def zeros_like(x, dtype=None, name=None):
     Examples:
         .. code-block:: python
 
            import paddle
            import numpy as np
 
           paddle.disable_static()
 
           x = paddle.to_tensor(np.array([1,2,3], dtype='float32'))
           out1 = paddle.zeros_like(x)  # [0., 0., 0.]
           out2 = paddle.zeros_like(x, dtype='int32')  # [0, 0, 0]
 
     """
     return full_like(x=x, fill_value=0, dtype=dtype, name=name)
......
@@ -40,9 +40,9 @@ def mean(x, axis=None, keepdim=False, name=None):
             should be in range [-D, D), where D is the dimensions of ``x`` . If
             ``axis`` or element(s) of ``axis`` is less than 0, it works the
             same way as :math:`axis + D` . If ``axis`` is None, mean is
-            calculated along all elements of ``x``. Default is None.
+            calculated over all elements of ``x``. Default is None.
         keepdim (bool, optional): Whether to reserve the reduced dimension(s)
-            in the output Tensor. If ``keep_dim`` is True, the dimensions of
+            in the output Tensor. If ``keepdim`` is True, the dimensions of
             the output Tensor is the same as ``x`` except in the reduced
             dimensions(it is of size 1 in this case). Otherwise, the shape of
             the output Tensor is squeezed in ``axis`` . Default is False.
@@ -67,7 +67,7 @@ def mean(x, axis=None, keepdim=False, name=None):
               [[13, 14, 15, 16],
                [17, 18, 19, 20],
                [21, 22, 23, 24]]], 'float32')
-            x = paddle.to_variable(x)
+            x = paddle.to_tensor(x)
             out1 = paddle.mean(x)
             # [12.5]
             out2 = paddle.mean(x, axis=-1)
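The `keepdim` wording fixed above is easiest to see on output shapes. A small sketch, assuming the paddle 2.0 dygraph API introduced by this commit:

    import numpy as np
    import paddle

    paddle.disable_static()
    x = paddle.to_tensor(np.zeros([2, 3, 4], dtype='float32'))
    print(paddle.mean(x, axis=1).shape)                # [2, 4]: axis 1 squeezed
    print(paddle.mean(x, axis=1, keepdim=True).shape)  # [2, 1, 4]: kept, size 1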
@@ -111,142 +111,120 @@ def mean(x, axis=None, keepdim=False, name=None):
     return out
 
 
-def var(input, axis=None, keepdim=False, unbiased=True, out=None, name=None):
+def var(x, axis=None, unbiased=True, keepdim=False, name=None):
     """
-    :alias_main: paddle.var
-    :alias: paddle.var,paddle.tensor.var,paddle.tensor.stat.var
-
-    Computes the variance of the input Variable's elements along the specified
-    axis.
+    Computes the variance of ``x`` along ``axis`` .
 
     Args:
-        input (Variable): The input Variable to be computed variance, with data
-            type float32 and float64 supported.
-        axis (list|int, optional): The axis along which the variance is computed.
-            If `None`, compute the variance over all elements of :attr:`input`
-            and return a Variable with a single element, otherwise it must be in
-            the range :math:`[-rank(input), rank(input))`. If :math:`axis[i] < 0`,
-            the axis to compute is :math:`rank(input) + axis[i]`.
-        keepdim (bool, optional): Whether to reserve the reduced dimensions in
-            the output Variable. The dimensions in :attr:`axis` will be squeezed
-            and the result Variable will have :attr:`len(axis)` fewer dimensions
-            than the :attr:`input` unless :attr:`keepdim` is true, default False.
-        unbiased (bool, optional): Whether to compute variance via the unbiased
-            estimator, in which the divisor used in the computation is
-            :math:`N - 1`, where :math:`N` represents the number of elements
-            along :attr:`axis`, otherwise the divisor is :math:`N`. Default True.
-        out (Variable, optional): Alternate output Variable to store the result
-            variance. Default None.
-        name (str, optional): The name for this layer. Normally there is no
-            need for user to set this property. For more information, please
-            refer to :ref:`api_guide_Name`. Default None.
+        x (Tensor): The input Tensor with data type float32, float64.
+        axis (int|list|tuple, optional): The axis along which to perform
+            variance calculations. ``axis`` should be int, list(int) or
+            tuple(int). If ``axis`` is a list/tuple of dimension(s), variance
+            is calculated along all element(s) of ``axis`` . ``axis`` or
+            element(s) of ``axis`` should be in range [-D, D), where D is the
+            dimensions of ``x`` . If ``axis`` or element(s) of ``axis`` is less
+            than 0, it works the same way as :math:`axis + D` . If ``axis`` is
+            None, variance is calculated over all elements of ``x``. Default
+            is None.
+        unbiased (bool, optional): Whether to use the unbiased estimation. If
+            ``unbiased`` is True, the divisor used in the computation is
+            :math:`N - 1`, where :math:`N` represents the number of elements
+            along ``axis`` , otherwise the divisor is :math:`N`. Default is True.
+        keepdim (bool, optional): Whether to reserve the reduced dimension(s)
+            in the output Tensor. If ``keepdim`` is True, the dimensions of
+            the output Tensor is the same as ``x`` except in the reduced
+            dimensions(it is of size 1 in this case). Otherwise, the shape of
+            the output Tensor is squeezed in ``axis`` . Default is False.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
-        Variable: The result variance with the same dtype as :attr:`input`.
-            If :attr:`out = None`, returns a new Variable containing the
-            variance, otherwise returns a reference to the output Variable.
+        Tensor, results of variance along ``axis`` of ``x``, with the same data
+        type as ``x``.
 
     Examples:
         .. code-block:: python
 
-            import numpy as np
             import paddle
-            import paddle.fluid.dygraph as dg
-
-            a = np.array([[1.0, 2.0], [3.0, 4.0]]).astype("float32")
-            with dg.guard():
-                data = dg.to_variable(a)
-                variance = paddle.var(data, axis=[1])
-                print(variance.numpy())
-                # [0.5 0.5]
+            import numpy as np
+
+            paddle.disable_static()
+
+            x = np.array([[1.0, 2.0, 3.0], [1.0, 4.0, 5.0]])
+            x = paddle.to_tensor(x)
+            out1 = paddle.var(x)
+            # [2.66666667]
+            out2 = paddle.var(x, axis=1)
+            # [1.         4.33333333]
 
     """
-    dtype = convert_dtype(input.dtype)
-    if dtype not in ["float32", "float64"]:
-        raise ValueError("Layer tensor.var() only supports floating-point "
-                         "dtypes, but received {}.".format(dtype))
-    rank = len(input.shape)
-    axes = axis if axis != None and axis != [] else range(rank)
-    axes = [e if e >= 0 else e + rank for e in axes]
-    inp_shape = input.shape if in_dygraph_mode() else layers.shape(input)
-    mean = layers.reduce_mean(input, dim=axis, keep_dim=True, name=name)
-    tmp = layers.reduce_mean(
-        (input - mean)**2, dim=axis, keep_dim=keepdim, name=name)
-
+    if not in_dygraph_mode():
+        check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'var')
+
+    u = mean(x, axis, True, name)
+    out = paddle.sum((x - u)**2, axis, keepdim=keepdim, name=name)
+
+    n = paddle.cast(paddle.numel(x), x.dtype) \
+        / paddle.cast(paddle.numel(out), x.dtype)
     if unbiased:
-        n = 1
-        for i in axes:
-            n *= inp_shape[i]
-        if not in_dygraph_mode():
-            n = layers.cast(n, dtype)
-            zero_const = layers.fill_constant(shape=[1], dtype=dtype, value=0.0)
-            factor = where(n > 1.0, n / (n - 1.0), zero_const)
-        else:
-            factor = n / (n - 1.0) if n > 1.0 else 0.0
-        tmp *= factor
-    if out:
-        layers.assign(input=tmp, output=out)
-        return out
-    else:
-        return tmp
+        one_const = paddle.ones([1], x.dtype)
+        n = where(n > one_const, n - 1., one_const)
+    out /= n
+    return out
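The new `var` computes its divisor as `numel(x) / numel(out)`, i.e. the number of elements reduced per output entry, and subtracts one when `unbiased` is set. The same arithmetic, checked in plain NumPy:

    import numpy as np

    x = np.random.uniform(-1, 1, (2, 3, 4))
    axis = (1, 2)
    u = x.mean(axis=axis, keepdims=True)
    sq = ((x - u)**2).sum(axis=axis)
    n = x.size / sq.size  # elements reduced per output entry: 3 * 4 = 12
    assert np.allclose(sq / (n - 1), np.var(x, axis=axis, ddof=1))  # unbiased
    assert np.allclose(sq / n, np.var(x, axis=axis, ddof=0))        # biased

The corresponding rewrite of `std` follows.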
-def std(input, axis=None, keepdim=False, unbiased=True, out=None, name=None):
+def std(x, axis=None, unbiased=True, keepdim=False, name=None):
     """
-    :alias_main: paddle.std
-    :alias: paddle.std,paddle.tensor.std,paddle.tensor.stat.std
-
-    Computes the standard-deviation of the input Variable's elements along the specified
-    axis.
+    Computes the standard-deviation of ``x`` along ``axis`` .
 
     Args:
-        input (Variable): The input Variable to be computed standard-deviation, with data
-            type float32 and float64 supported.
-        axis (list|int, optional): The axis along which the standard-deviation is computed.
-            If `None`, compute the standard-deviation over all elements of :attr:`input`
-            and return a Variable with a single element, otherwise it must be in
-            the range :math:`[-rank(input), rank(input))`. If :math:`axis[i] < 0`,
-            the axis to compute is :math:`rank(input) + axis[i]`.
-        keepdim (bool, optional): Whether to reserve the reduced dimensions in
-            the output Variable. The dimensions in :attr:`axis` will be squeezed
-            and the result Variable will have :attr:`len(axis)` fewer dimensions
-            than the :attr:`input` unless :attr:`keepdim` is true, default False.
-        unbiased (bool, optional): Whether to compute standard-deviation via the unbiased
-            estimator, in which the divisor used in the computation is
-            :math:`N - 1`, where :math:`N` represents the number of elements
-            along :attr:`axis`, otherwise the divisor is :math:`N`. Default True.
-        out (Variable, optional): Alternate output Variable to store the result
-            standard-deviation . Default None.
-        name (str, optional): The name for this layer. Normally there is no
-            need for user to set this property. For more information, please
-            refer to :ref:`api_guide_Name`. Default None.
+        x (Tensor): The input Tensor with data type float32, float64.
+        axis (int|list|tuple, optional): The axis along which to perform
+            standard-deviation calculations. ``axis`` should be int, list(int)
+            or tuple(int). If ``axis`` is a list/tuple of dimension(s),
+            standard-deviation is calculated along all element(s) of ``axis`` .
+            ``axis`` or element(s) of ``axis`` should be in range [-D, D),
+            where D is the dimensions of ``x`` . If ``axis`` or element(s) of
+            ``axis`` is less than 0, it works the same way as :math:`axis + D` .
+            If ``axis`` is None, standard-deviation is calculated over all
+            elements of ``x``. Default is None.
+        unbiased (bool, optional): Whether to use the unbiased estimation. If
+            ``unbiased`` is True, the standard-deviation is calculated via the
+            unbiased estimator, in which the divisor used in the computation
+            is :math:`N - 1`, where :math:`N` represents the number of
+            elements along ``axis`` ; otherwise the divisor is :math:`N`.
+            Default is True.
+        keepdim (bool, optional): Whether to reserve the reduced dimension(s)
+            in the output Tensor. If ``keepdim`` is True, the dimensions of
+            the output Tensor is the same as ``x`` except in the reduced
+            dimensions(it is of size 1 in this case). Otherwise, the shape of
+            the output Tensor is squeezed in ``axis`` . Default is False.
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
-        Variable: The result standard-deviation with the same dtype as :attr:`input`.
-            If :attr:`out = None`, returns a new Variable containing the
-            standard-deviation , otherwise returns a reference to the output Variable.
+        Tensor, results of standard-deviation along ``axis`` of ``x``, with the
+        same data type as ``x``.
 
     Examples:
         .. code-block:: python
 
             import paddle
-            import paddle.fluid as fluid
-
-            # x is a Tensor variable with following elements:
-            #    [[0.2, 0.3, 0.5, 0.9]
-            #     [0.1, 0.2, 0.6, 0.7]]
-            # Each example is followed by the corresponding output tensor.
-            x = fluid.data(name='x', shape=[2, 4], dtype='float32')
-            paddle.std(x)  # [0.28252685]
-            paddle.std(x, axis=[0])  # [0.0707107, 0.07071075, 0.07071064, 0.1414217]
-            paddle.std(x, axis=[-1])  # [0.30956957, 0.29439208]
+            import numpy as np
+
+            paddle.disable_static()
+
+            x = np.array([[1.0, 2.0, 3.0], [1.0, 4.0, 5.0]])
+            x = paddle.to_tensor(x)
+            out1 = paddle.std(x)
+            # [1.63299316]
+            out2 = paddle.std(x, axis=1)
+            # [1.       2.081666]
 
     """
-    check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'std')
-
-    tmp = var(input, axis=axis, keepdim=keepdim, unbiased=unbiased, name=name)
-    tmp = layers.sqrt(tmp)
-    if out is not None:
-        layers.assign(input=tmp, output=out)
-        return out
-    else:
-        return tmp
+    if not in_dygraph_mode():
+        check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'std')
+
+    out = var(**locals())
+    return paddle.sqrt(out)
 def numel(x, name=None):
......
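The rewritten `std` forwards its entire argument list to `var` with `var(**locals())`, which works because `locals()` is read before any new local is bound and the two functions have identical signatures. A standalone sketch of the same pattern, using hypothetical helpers `_var` and `_std`:

    import math

    def _var(x, unbiased=True):
        n = len(x)
        u = sum(x) / n
        d = (n - 1) if (unbiased and n > 1) else n
        return sum((v - u) ** 2 for v in x) / d

    def _std(x, unbiased=True):
        # locals() here is exactly {'x': x, 'unbiased': unbiased},
        # so it can be splatted straight into the matching signature.
        return math.sqrt(_var(**locals()))

    print(_std([1.0, 2.0, 3.0]))  # 1.0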