From 57e124298d1441c6505408d5c1e7805906fe0083 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Sat, 22 Aug 2020 10:16:00 +0800 Subject: [PATCH] var, std: input->x, adjust attr order, remove out, add docs (#26446) --- python/paddle/fluid/layers/nn.py | 5 +- .../fluid/tests/unittests/test_std_layer.py | 151 ++++++++----- .../tests/unittests/test_variance_layer.py | 151 ++++++++----- python/paddle/tensor/creation.py | 24 +- python/paddle/tensor/stat.py | 210 ++++++++---------- 5 files changed, 297 insertions(+), 244 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 1f1439aa7cd..955217a46ce 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -12301,13 +12301,10 @@ def clip_by_norm(x, max_norm, name=None): return out +@deprecated(since="2.0.0", update_to="paddle.mean") @templatedoc() def mean(x, name=None): """ - :alias_main: paddle.mean - :alias: paddle.mean,paddle.tensor.mean,paddle.tensor.stat.mean - :old_api: paddle.fluid.layers.mean - ${comment} Args: diff --git a/python/paddle/fluid/tests/unittests/test_std_layer.py b/python/paddle/fluid/tests/unittests/test_std_layer.py index d1e00563042..e4551514814 100644 --- a/python/paddle/fluid/tests/unittests/test_std_layer.py +++ b/python/paddle/fluid/tests/unittests/test_std_layer.py @@ -15,65 +15,104 @@ import unittest import numpy as np import paddle -import paddle.fluid as fluid -class TestStdLayer(unittest.TestCase): +def ref_std(x, axis=None, unbiased=True, keepdim=False): + ddof = 1 if unbiased else 0 + if isinstance(axis, int): + axis = (axis, ) + if axis is not None: + axis = tuple(axis) + return np.std(x, axis=axis, ddof=ddof, keepdims=keepdim) + + +class TestStdAPI(unittest.TestCase): def setUp(self): - self._dtype = "float64" - self._input = np.random.random([2, 3, 4, 5]).astype(self._dtype) - - def static(self, axis=None, keepdim=False, unbiased=True): - prog = fluid.Program() - with fluid.program_guard(prog): - data = fluid.data( - name="data", dtype=self._dtype, shape=[None, 3, 4, 5]) - out = prog.current_block().create_var( - dtype=self._dtype, shape=[2, 3, 4, 5]) - paddle.std(input=data, - axis=axis, - keepdim=keepdim, - unbiased=unbiased, - out=out) - - exe = fluid.Executor(self._place) - return exe.run(feed={"data": self._input}, - program=prog, - fetch_list=[out])[0] - - def dynamic(self, axis=None, keepdim=False, unbiased=True): - with fluid.dygraph.guard(self._place): - data = fluid.dygraph.to_variable(self._input) - out = paddle.std(input=data, - axis=axis, - keepdim=keepdim, - unbiased=unbiased) - return out.numpy() - - def numpy(self, axis=None, keepdim=False, unbiased=True): - ddof = 1 if unbiased else 0 - axis = tuple(axis) if isinstance(axis, list) else axis - return np.std(self._input, axis=axis, keepdims=keepdim, ddof=ddof) - - def test_equal(self): - places = [] - if fluid.core.is_compiled_with_cuda(): - places.append(fluid.CUDAPlace(0)) - for place in places: - self._place = place - self.assertTrue(np.allclose(self.numpy(), self.static())) - self.assertTrue( - np.allclose( - self.numpy(axis=[0, 2]), self.dynamic(axis=[0, 2]))) - self.assertTrue( - np.allclose( - self.numpy( - axis=[1, 3], keepdim=True), - self.dynamic( - axis=[1, 3], keepdim=True))) - self.assertTrue( - np.allclose( - self.numpy(unbiased=False), self.dynamic(unbiased=False))) + self.dtype = 'float64' + self.shape = [1, 3, 4, 10] + self.axis = [1, 3] + self.keepdim = False + self.unbiased = True + self.set_attrs() + self.x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) + self.place=paddle.CUDAPlace(0) \ + if paddle.fluid.core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def set_attrs(self): + pass + + def static(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data('X', self.shape, self.dtype) + out = paddle.std(x, self.axis, self.unbiased, self.keepdim) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x}, fetch_list=[out]) + return res[0] + + def dygraph(self): + paddle.disable_static() + x = paddle.to_tensor(self.x) + out = paddle.std(x, self.axis, self.unbiased, self.keepdim) + paddle.enable_static() + return out.numpy() + + def test_api(self): + out_ref = ref_std(self.x, self.axis, self.unbiased, self.keepdim) + out_dygraph = self.dygraph() + out_static = self.static() + for out in [out_dygraph, out_static]: + self.assertTrue(np.allclose(out_ref, out)) + self.assertTrue(np.equal(out_ref.shape, out.shape).all()) + + +class TestStdAPI_dtype(TestStdAPI): + def set_attrs(self): + self.dtype = 'float32' + + +class TestStdAPI_axis_int(TestStdAPI): + def set_attrs(self): + self.axis = 2 + + +class TestStdAPI_axis_list(TestStdAPI): + def set_attrs(self): + self.axis = [1, 2] + + +class TestStdAPI_axis_tuple(TestStdAPI): + def set_attrs(self): + self.axis = (1, 3) + + +class TestStdAPI_keepdim(TestStdAPI): + def set_attrs(self): + self.keepdim = False + + +class TestStdAPI_unbiased(TestStdAPI): + def set_attrs(self): + self.unbiased = False + + +class TestStdAPI_alias(unittest.TestCase): + def test_alias(self): + paddle.disable_static() + x = paddle.to_tensor(np.array([10, 12], 'float32')) + out1 = paddle.std(x).numpy() + out2 = paddle.tensor.std(x).numpy() + out3 = paddle.tensor.stat.std(x).numpy() + self.assertTrue(np.allclose(out1, out2)) + self.assertTrue(np.allclose(out1, out3)) + paddle.enable_static() + + +class TestStdError(unittest.TestCase): + def test_error(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data('X', [2, 3, 4], 'int32') + self.assertRaises(TypeError, paddle.std, x) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_variance_layer.py b/python/paddle/fluid/tests/unittests/test_variance_layer.py index 569f064db85..b5bb3cc978a 100644 --- a/python/paddle/fluid/tests/unittests/test_variance_layer.py +++ b/python/paddle/fluid/tests/unittests/test_variance_layer.py @@ -15,65 +15,104 @@ import unittest import numpy as np import paddle -import paddle.fluid as fluid -class TestVarianceLayer(unittest.TestCase): +def ref_var(x, axis=None, unbiased=True, keepdim=False): + ddof = 1 if unbiased else 0 + if isinstance(axis, int): + axis = (axis, ) + if axis is not None: + axis = tuple(axis) + return np.var(x, axis=axis, ddof=ddof, keepdims=keepdim) + + +class TestVarAPI(unittest.TestCase): def setUp(self): - self._dtype = "float64" - self._input = np.random.random([2, 3, 4, 5]).astype(self._dtype) - - def static(self, axis=None, keepdim=False, unbiased=True): - prog = fluid.Program() - with fluid.program_guard(prog): - data = fluid.data( - name="data", dtype=self._dtype, shape=[None, 3, 4, 5]) - out = prog.current_block().create_var( - dtype=self._dtype, shape=[2, 3, 4, 5]) - paddle.var(input=data, - axis=axis, - keepdim=keepdim, - unbiased=unbiased, - out=out) - - exe = fluid.Executor(self._place) - return exe.run(feed={"data": self._input}, - program=prog, - fetch_list=[out])[0] - - def dynamic(self, axis=None, keepdim=False, unbiased=True): - with fluid.dygraph.guard(self._place): - data = fluid.dygraph.to_variable(self._input) - out = paddle.var(input=data, - axis=axis, - keepdim=keepdim, - unbiased=unbiased) - return out.numpy() - - def numpy(self, axis=None, keepdim=False, unbiased=True): - ddof = 1 if unbiased else 0 - axis = tuple(axis) if isinstance(axis, list) else axis - return np.var(self._input, axis=axis, keepdims=keepdim, ddof=ddof) - - def test_equal(self): - places = [fluid.CPUPlace()] - if fluid.core.is_compiled_with_cuda(): - places.append(fluid.CUDAPlace(0)) - for place in places: - self._place = place - self.assertTrue(np.allclose(self.numpy(), self.static())) - self.assertTrue( - np.allclose( - self.numpy(axis=[0, 2]), self.dynamic(axis=[0, 2]))) - self.assertTrue( - np.allclose( - self.numpy( - axis=[1, 3], keepdim=True), - self.dynamic( - axis=[1, 3], keepdim=True))) - self.assertTrue( - np.allclose( - self.numpy(unbiased=False), self.dynamic(unbiased=False))) + self.dtype = 'float64' + self.shape = [1, 3, 4, 10] + self.axis = [1, 3] + self.keepdim = False + self.unbiased = True + self.set_attrs() + self.x = np.random.uniform(-1, 1, self.shape).astype(self.dtype) + self.place=paddle.CUDAPlace(0) \ + if paddle.fluid.core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def set_attrs(self): + pass + + def static(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data('X', self.shape, self.dtype) + out = paddle.var(x, self.axis, self.unbiased, self.keepdim) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x}, fetch_list=[out]) + return res[0] + + def dygraph(self): + paddle.disable_static() + x = paddle.to_tensor(self.x) + out = paddle.var(x, self.axis, self.unbiased, self.keepdim) + paddle.enable_static() + return out.numpy() + + def test_api(self): + out_ref = ref_var(self.x, self.axis, self.unbiased, self.keepdim) + out_dygraph = self.dygraph() + out_static = self.static() + for out in [out_dygraph, out_static]: + self.assertTrue(np.allclose(out_ref, out)) + self.assertTrue(np.equal(out_ref.shape, out.shape).all()) + + +class TestVarAPI_dtype(TestVarAPI): + def set_attrs(self): + self.dtype = 'float32' + + +class TestVarAPI_axis_int(TestVarAPI): + def set_attrs(self): + self.axis = 2 + + +class TestVarAPI_axis_list(TestVarAPI): + def set_attrs(self): + self.axis = [1, 2] + + +class TestVarAPI_axis_tuple(TestVarAPI): + def set_attrs(self): + self.axis = (1, 3) + + +class TestVarAPI_keepdim(TestVarAPI): + def set_attrs(self): + self.keepdim = False + + +class TestVarAPI_unbiased(TestVarAPI): + def set_attrs(self): + self.unbiased = False + + +class TestVarAPI_alias(unittest.TestCase): + def test_alias(self): + paddle.disable_static() + x = paddle.to_tensor(np.array([10, 12], 'float32')) + out1 = paddle.var(x).numpy() + out2 = paddle.tensor.var(x).numpy() + out3 = paddle.tensor.stat.var(x).numpy() + self.assertTrue(np.allclose(out1, out2)) + self.assertTrue(np.allclose(out1, out3)) + paddle.enable_static() + + +class TestVarError(unittest.TestCase): + def test_error(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data('X', [2, 3, 4], 'int32') + self.assertRaises(TypeError, paddle.var, x) if __name__ == '__main__': diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 55bf1344014..449b4c8aff6 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -361,14 +361,14 @@ def ones_like(x, dtype=None, name=None): Examples: .. code-block:: python - import paddle - import numpy as np + import paddle + import numpy as np - paddle.disable_static() + paddle.disable_static() - x = paddle.to_tensor(np.array([1,2,3], dtype='float32')) - out1 = paddle.zeros_like(x) # [1., 1., 1.] - out2 = paddle.zeros_like(x, dtype='int32') # [1, 1, 1] + x = paddle.to_tensor(np.array([1,2,3], dtype='float32')) + out1 = paddle.zeros_like(x) # [1., 1., 1.] + out2 = paddle.zeros_like(x, dtype='int32') # [1, 1, 1] """ return full_like(x=x, fill_value=1, dtype=dtype, name=name) @@ -451,14 +451,14 @@ def zeros_like(x, dtype=None, name=None): Examples: .. code-block:: python - import paddle - import numpy as np + import paddle + import numpy as np - paddle.disable_static() + paddle.disable_static() - x = paddle.to_tensor(np.array([1,2,3], dtype='float32')) - out1 = paddle.zeros_like(x) # [0., 0., 0.] - out2 = paddle.zeros_like(x, dtype='int32') # [0, 0, 0] + x = paddle.to_tensor(np.array([1,2,3], dtype='float32')) + out1 = paddle.zeros_like(x) # [0., 0., 0.] + out2 = paddle.zeros_like(x, dtype='int32') # [0, 0, 0] """ return full_like(x=x, fill_value=0, dtype=dtype, name=name) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 19f315cbb69..91676a6316b 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -40,9 +40,9 @@ def mean(x, axis=None, keepdim=False, name=None): should be in range [-D, D), where D is the dimensions of ``x`` . If ``axis`` or element(s) of ``axis`` is less than 0, it works the same way as :math:`axis + D` . If ``axis`` is None, mean is - calculated along all elements of ``x``. Default is None. + calculated over all elements of ``x``. Default is None. keepdim (bool, optional): Whether to reserve the reduced dimension(s) - in the output Tensor. If ``keep_dim`` is True, the dimensions of + in the output Tensor. If ``keepdim`` is True, the dimensions of the output Tensor is the same as ``x`` except in the reduced dimensions(it is of size 1 in this case). Otherwise, the shape of the output Tensor is squeezed in ``axis`` . Default is False. @@ -67,7 +67,7 @@ def mean(x, axis=None, keepdim=False, name=None): [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]], 'float32') - x = paddle.to_variable(x) + x = paddle.to_tensor(x) out1 = paddle.mean(x) # [12.5] out2 = paddle.mean(x, axis=-1) @@ -111,142 +111,120 @@ def mean(x, axis=None, keepdim=False, name=None): return out -def var(input, axis=None, keepdim=False, unbiased=True, out=None, name=None): +def var(x, axis=None, unbiased=True, keepdim=False, name=None): """ - :alias_main: paddle.var - :alias: paddle.var,paddle.tensor.var,paddle.tensor.stat.var - - Computes the variance of the input Variable's elements along the specified - axis. + Computes the variance of ``x`` along ``axis`` . Args: - input (Variable): The input Variable to be computed variance, with data - type float32 and float64 supported. - axis (list|int, optional): The axis along which the variance is computed. - If `None`, compute the variance over all elements of :attr:`input` - and return a Variable with a single element, otherwise it must be in - the range :math:`[-rank(input), rank(input))`. If :math:`axis[i] < 0`, - the axis to compute is :math:`rank(input) + axis[i]`. - keepdim (bool, optional): Whether to reserve the reduced dimensions in - the output Variable. The dimensions in :attr:`axis` will be squeezed - and the result Variable will have :attr:`len(axis)` fewer dimensions - than the :attr:`input` unless :attr:`keepdim` is true, default False. - unbiased (bool, optional): Whether to compute variance via the unbiased - estimator, in which the divisor used in the computation is - :math:`N - 1`, where :math:`N` represents the number of elements - along :attr:`axis`, otherwise the divisor is :math:`N`. Default True. - out (Variable, optional): Alternate output Variable to store the result - variance. Default None. - name (str, optional): The name for this layer. Normally there is no - need for user to set this property. For more information, please - refer to :ref:`api_guide_Name`. Default None. + x (Tensor): The input Tensor with data type float32, float64. + axis (int|list|tuple, optional): The axis along which to perform + variance calculations. ``axis`` should be int, list(int) or + tuple(int). If ``axis`` is a list/tuple of dimension(s), variance + is calculated along all element(s) of ``axis`` . ``axis`` or + element(s) of ``axis`` should be in range [-D, D), where D is the + dimensions of ``x`` . If ``axis`` or element(s) of ``axis`` is less + than 0, it works the same way as :math:`axis + D` . If ``axis`` is + None, variance is calculated over all elements of ``x``. Default + is None. + unbiased (bool, optional): Whether to use the unbiased estimation. If + ``unbiased`` is True, the divisor used in the computation is + :math:`N - 1`, where :math:`N` represents the number of elements + along ``axis`` , otherwise the divisor is :math:`N`. Default is True. + keepdim (bool, optional): Whether to reserve the reduced dimension(s) + in the output Tensor. If ``keepdim`` is True, the dimensions of + the output Tensor is the same as ``x`` except in the reduced + dimensions(it is of size 1 in this case). Otherwise, the shape of + the output Tensor is squeezed in ``axis`` . Default is False. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. Returns: - Variable: The result variance with the same dtype as :attr:`input`. - If :attr:`out = None`, returns a new Variable containing the - variance, otherwise returns a reference to the output Variable. + Tensor, results of variance along ``axis`` of ``x``, with the same data + type as ``x``. Examples: .. code-block:: python - import numpy as np import paddle - import paddle.fluid.dygraph as dg - - a = np.array([[1.0, 2.0], [3.0, 4.0]]).astype("float32") - with dg.guard(): - data = dg.to_variable(a) - variance = paddle.var(data, axis=[1]) - print(variance.numpy()) - # [0.5 0.5] + import numpy as np + + paddle.disable_static() + + x = np.array([[1.0, 2.0, 3.0], [1.0, 4.0, 5.0]]) + x = paddle.to_tensor(x) + out1 = paddle.var(x) + # [2.66666667] + out2 = paddle.var(x, axis=1) + # [1. 4.33333333] """ - dtype = convert_dtype(input.dtype) - if dtype not in ["float32", "float64"]: - raise ValueError("Layer tensor.var() only supports floating-point " - "dtypes, but received {}.".format(dtype)) - rank = len(input.shape) - axes = axis if axis != None and axis != [] else range(rank) - axes = [e if e >= 0 else e + rank for e in axes] - inp_shape = input.shape if in_dygraph_mode() else layers.shape(input) - mean = layers.reduce_mean(input, dim=axis, keep_dim=True, name=name) - tmp = layers.reduce_mean( - (input - mean)**2, dim=axis, keep_dim=keepdim, name=name) + if not in_dygraph_mode(): + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'var') + u = mean(x, axis, True, name) + out = paddle.sum((x - u)**2, axis, keepdim=keepdim, name=name) + + n = paddle.cast(paddle.numel(x), x.dtype) \ + / paddle.cast(paddle.numel(out), x.dtype) if unbiased: - n = 1 - for i in axes: - n *= inp_shape[i] - if not in_dygraph_mode(): - n = layers.cast(n, dtype) - zero_const = layers.fill_constant(shape=[1], dtype=dtype, value=0.0) - factor = where(n > 1.0, n / (n - 1.0), zero_const) - else: - factor = n / (n - 1.0) if n > 1.0 else 0.0 - tmp *= factor - if out: - layers.assign(input=tmp, output=out) - return out - else: - return tmp - - -def std(input, axis=None, keepdim=False, unbiased=True, out=None, name=None): - """ - :alias_main: paddle.std - :alias: paddle.std,paddle.tensor.std,paddle.tensor.stat.std + one_const = paddle.ones([1], x.dtype) + n = where(n > one_const, n - 1., one_const) + out /= n + return out + - Computes the standard-deviation of the input Variable's elements along the specified - axis. +def std(x, axis=None, unbiased=True, keepdim=False, name=None): + """ + Computes the standard-deviation of ``x`` along ``axis`` . Args: - input (Variable): The input Variable to be computed standard-deviation, with data - type float32 and float64 supported. - axis (list|int, optional): The axis along which the standard-deviation is computed. - If `None`, compute the standard-deviation over all elements of :attr:`input` - and return a Variable with a single element, otherwise it must be in - the range :math:`[-rank(input), rank(input))`. If :math:`axis[i] < 0`, - the axis to compute is :math:`rank(input) + axis[i]`. - keepdim (bool, optional): Whether to reserve the reduced dimensions in - the output Variable. The dimensions in :attr:`axis` will be squeezed - and the result Variable will have :attr:`len(axis)` fewer dimensions - than the :attr:`input` unless :attr:`keepdim` is true, default False. - unbiased (bool, optional): Whether to compute standard-deviation via the unbiased - estimator, in which the divisor used in the computation is - :math:`N - 1`, where :math:`N` represents the number of elements - along :attr:`axis`, otherwise the divisor is :math:`N`. Default True. - out (Variable, optional): Alternate output Variable to store the result - standard-deviation . Default None. - name (str, optional): The name for this layer. Normally there is no - need for user to set this property. For more information, please - refer to :ref:`api_guide_Name`. Default None. + x (Tensor): The input Tensor with data type float32, float64. + axis (int|list|tuple, optional): The axis along which to perform + standard-deviation calculations. ``axis`` should be int, list(int) + or tuple(int). If ``axis`` is a list/tuple of dimension(s), + standard-deviation is calculated along all element(s) of ``axis`` . + ``axis`` or element(s) of ``axis`` should be in range [-D, D), + where D is the dimensions of ``x`` . If ``axis`` or element(s) of + ``axis`` is less than 0, it works the same way as :math:`axis + D` . + If ``axis`` is None, standard-deviation is calculated over all + elements of ``x``. Default is None. + unbiased (bool, optional): Whether to use the unbiased estimation. If + ``unbiased`` is True, the standard-deviation is calculated via the + unbiased estimator. If ``unbiased`` is True, the divisor used in + the computation is :math:`N - 1`, where :math:`N` represents the + number of elements along ``axis`` , otherwise the divisor is + :math:`N`. Default is True. + keepdim (bool, optional): Whether to reserve the reduced dimension(s) + in the output Tensor. If ``keepdim`` is True, the dimensions of + the output Tensor is the same as ``x`` except in the reduced + dimensions(it is of size 1 in this case). Otherwise, the shape of + the output Tensor is squeezed in ``axis`` . Default is False. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. Returns: - Variable: The result standard-deviation with the same dtype as :attr:`input`. - If :attr:`out = None`, returns a new Variable containing the - standard-deviation , otherwise returns a reference to the output Variable. + Tensor, results of standard-deviation along ``axis`` of ``x``, with the + same data type as ``x``. + Examples: .. code-block:: python import paddle - import paddle.fluid as fluid - # x is a Tensor variable with following elements: - # [[0.2, 0.3, 0.5, 0.9] - # [0.1, 0.2, 0.6, 0.7]] - # Each example is followed by the corresponding output tensor. - x = fluid.data(name='x', shape=[2, 4], dtype='float32') - paddle.std(x) # [0.28252685] - paddle.std(x, axis=[0]) # [0.0707107, 0.07071075, 0.07071064, 0.1414217] - paddle.std(x, axis=[-1]) # [0.30956957, 0.29439208] + import numpy as np + + paddle.disable_static() + + x = np.array([[1.0, 2.0, 3.0], [1.0, 4.0, 5.0]]) + x = paddle.to_tensor(x) + out1 = paddle.std(x) + # [1.63299316] + out2 = paddle.std(x, axis=1) + # [1. 2.081666] """ - check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'std') - - tmp = var(input, axis=axis, keepdim=keepdim, unbiased=unbiased, name=name) - tmp = layers.sqrt(tmp) - if out is not None: - layers.assign(input=tmp, output=out) - return out - else: - return tmp + if not in_dygraph_mode(): + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'std') + + out = var(**locals()) + return paddle.sqrt(out) def numel(x, name=None): -- GitLab