diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 48f61027bbdf60aad47d25fe88a88f3f65675764..bb020b58f44e70e493b9e4fd8b57f7f72a9f6cde 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -194,7 +194,8 @@ from .tensor.math import clip #DEFINE_ALIAS from .tensor.math import trace #DEFINE_ALIAS from .tensor.math import kron #DEFINE_ALIAS from .tensor.math import prod #DEFINE_ALIAS -# from .tensor.random import gaussin #DEFINE_ALIAS +from .tensor.random import standard_normal +from .tensor.random import normal from .tensor.random import uniform #DEFINE_ALIAS from .tensor.random import shuffle #DEFINE_ALIAS from .tensor.random import randn #DEFINE_ALIAS diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index d8b31bc6616477de721550de92dec32ed02a6384..20ccd91666190454e4711fd07bfe259c518f01d7 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -10466,6 +10466,7 @@ def uniform_random_batch_size_like(input, return out +@deprecated(since="2.0.0", update_to="paddle.normal") @templatedoc() def gaussian_random(shape, mean=0.0, diff --git a/python/paddle/fluid/tests/unittests/test_normal.py b/python/paddle/fluid/tests/unittests/test_normal.py new file mode 100644 index 0000000000000000000000000000000000000000..a9d9af4d50be77bd1d2ecc11dd872ef612209f1e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_normal.py @@ -0,0 +1,197 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle +import copy + +np.random.seed(10) + + +class TestNormalAPI(unittest.TestCase): + def setUp(self): + self.mean = 1.0 + self.std = 0.0 + self.shape = None + self.repeat_num = 1000 + self.set_attrs() + self.dtype = self.get_dtype() + self.place=paddle.CUDAPlace(0) \ + if paddle.fluid.core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def set_attrs(self): + self.shape = [8, 12] + + def get_shape(self): + if isinstance(self.mean, np.ndarray): + shape = self.mean.shape + elif isinstance(self.std, np.ndarray): + shape = self.std.shape + else: + shape = self.shape + return list(shape) + + def get_dtype(self): + if isinstance(self.mean, np.ndarray): + return self.mean.dtype + elif isinstance(self.std, np.ndarray): + return self.std.dtype + else: + return 'float32' + + def static_api(self): + shape = self.get_shape() + ret_all_shape = copy.deepcopy(shape) + ret_all_shape.insert(0, self.repeat_num) + ret_all = np.zeros(ret_all_shape, self.dtype) + if isinstance(self.mean, np.ndarray) \ + and isinstance(self.std, np.ndarray): + with paddle.static.program_guard(paddle.static.Program()): + mean = paddle.data('Mean', self.mean.shape, self.mean.dtype) + std = paddle.data('Std', self.std.shape, self.std.dtype) + out = paddle.normal(mean, std, self.shape) + + exe = paddle.static.Executor(self.place) + for i in range(self.repeat_num): + ret = exe.run(feed={ + 'Mean': self.mean, + 'Std': self.std.reshape(shape) + }, + fetch_list=[out]) + ret_all[i] = ret[0] + return ret_all + elif isinstance(self.mean, np.ndarray): + with paddle.static.program_guard(paddle.static.Program()): + mean = paddle.data('Mean', self.mean.shape, self.mean.dtype) + out = paddle.normal(mean, self.std, self.shape) + + exe = paddle.static.Executor(self.place) + for i in range(self.repeat_num): + ret = exe.run(feed={'Mean': self.mean}, fetch_list=[out]) + ret_all[i] = ret[0] + return ret_all + elif isinstance(self.std, np.ndarray): + with paddle.static.program_guard(paddle.static.Program()): + std = paddle.data('Std', self.std.shape, self.std.dtype) + out = paddle.normal(self.mean, std, self.shape) + + exe = paddle.static.Executor(self.place) + for i in range(self.repeat_num): + ret = exe.run(feed={'Std': self.std}, fetch_list=[out]) + ret_all[i] = ret[0] + return ret_all + else: + with paddle.static.program_guard(paddle.static.Program()): + out = paddle.normal(self.mean, self.std, self.shape) + + exe = paddle.static.Executor(self.place) + for i in range(self.repeat_num): + ret = exe.run(fetch_list=[out]) + ret_all[i] = ret[0] + return ret_all + + def dygraph_api(self): + paddle.disable_static(self.place) + shape = self.get_shape() + ret_all_shape = copy.deepcopy(shape) + ret_all_shape.insert(0, self.repeat_num) + ret_all = np.zeros(ret_all_shape, self.dtype) + + mean = paddle.to_tensor(self.mean) \ + if isinstance(self.mean, np.ndarray) else self.mean + std = paddle.to_tensor(self.std) \ + if isinstance(self.std, np.ndarray) else self.std + for i in range(self.repeat_num): + out = paddle.normal(mean, std, self.shape) + ret_all[i] = out.numpy() + paddle.enable_static() + return ret_all + + def test_api(self): + ret_static = self.static_api() + ret_dygraph = self.dygraph_api() + for ret in [ret_static, ret_dygraph]: + shape_ref = self.get_shape() + self.assertEqual(shape_ref, list(ret[0].shape)) + + ret = ret.flatten().reshape([self.repeat_num, -1]) + mean = np.mean(ret, axis=0) + std = np.std(ret, axis=0) + mean_ref=self.mean.reshape([1, -1]) \ + if isinstance(self.mean, np.ndarray) else self.mean + std_ref=self.std.reshape([1, -1]) \ + if isinstance(self.std, np.ndarray) else self.std + self.assertTrue(np.allclose(mean_ref, mean, 0.1, 0.1)) + self.assertTrue(np.allclose(std_ref, std, 0.1, 0.1)) + + +class TestNormalAPI_mean_is_tensor(TestNormalAPI): + def set_attrs(self): + self.mean = np.random.uniform(-2, -1, [2, 3, 4, 5]).astype('float64') + + +class TestNormalAPI_std_is_tensor(TestNormalAPI): + def set_attrs(self): + self.std = np.random.uniform(0.7, 1, [2, 3, 17]).astype('float64') + + +class TestNormalAPI_mean_std_are_tensor(TestNormalAPI): + def set_attrs(self): + self.mean = np.random.uniform(1, 2, [1, 100]).astype('float64') + self.std = np.random.uniform(0.5, 1, [1, 100]).astype('float64') + + +class TestNormalAPI_mean_std_are_tensor_with_different_dtype(TestNormalAPI): + def set_attrs(self): + self.mean = np.random.uniform(1, 2, [100]).astype('float64') + self.std = np.random.uniform(1, 2, [100]).astype('float32') + + +class TestNormalAlias(unittest.TestCase): + def test_alias(self): + paddle.disable_static() + shape = [1, 2, 3] + out1 = paddle.normal(shape=shape) + out2 = paddle.tensor.normal(shape=shape) + out3 = paddle.tensor.random.normal(shape=shape) + paddle.enable_static() + + +class TestNormalErrors(unittest.TestCase): + def test_errors(self): + with paddle.static.program_guard(paddle.static.Program()): + mean = [1, 2, 3] + self.assertRaises(TypeError, paddle.normal, mean) + + std = [1, 2, 3] + self.assertRaises(TypeError, paddle.normal, std=std) + + mean = paddle.data('Mean', [100], 'int32') + self.assertRaises(TypeError, paddle.normal, mean) + + std = paddle.data('Std', [100], 'int32') + self.assertRaises(TypeError, paddle.normal, mean=1.0, std=std) + + self.assertRaises(TypeError, paddle.normal, shape=1) + + self.assertRaises(TypeError, paddle.normal, shape=[1.0]) + + shape = paddle.data('Shape', [100], 'float32') + self.assertRaises(TypeError, paddle.normal, shape=shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index e0830505c050dbebe83baeb78aed7d8faef11fa9..35634ba701391ed99464085b70a93acfd5370709 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -162,7 +162,8 @@ from .math import clip #DEFINE_ALIAS from .math import trace #DEFINE_ALIAS from .math import kron #DEFINE_ALIAS from .math import prod #DEFINE_ALIAS -# from .random import gaussin #DEFINE_ALIAS +from .random import standard_normal +from .random import normal from .random import uniform #DEFINE_ALIAS from .random import shuffle #DEFINE_ALIAS from .random import randn #DEFINE_ALIAS diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py index 6b8986004d2756b5ea5f83534260680615c68644..005e7beefe6877530b0a3c89d3bc8bfeabebc59d 100644 --- a/python/paddle/tensor/random.py +++ b/python/paddle/tensor/random.py @@ -21,20 +21,23 @@ from ..fluid.framework import device_guard, in_dygraph_mode, _varbase_creator, V from ..fluid.layers.layer_function_generator import templatedoc from ..fluid.layer_helper import LayerHelper from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype -from ..fluid.layers import utils, gaussian_random +from ..fluid.layers import utils from ..fluid.layers.tensor import fill_constant +import paddle +import warnings from ..fluid.io import shuffle #DEFINE_ALIAS __all__ = [ 'bernoulli', - # 'gaussin', + 'standard_normal', + 'normal', 'uniform', 'shuffle', 'randn', 'rand', 'randint', - 'randperm' + 'randperm', ] @@ -91,6 +94,237 @@ def bernoulli(x, name=None): return out +def gaussian_random(shape, mean=0.0, std=1.0, dtype='float32', name=None): + """ + This OP returns a Tensor filled with random values sampled from a Gaussian + distribution, with ``shape`` and ``dtype``. + + Args: + shape(list|tuple|Tensor): The shape of the output Tensor. If ``shape`` + is a list or tuple, the elements of it should be integers or Tensors + (with the shape [1], and the data type int32 or int64). If ``shape`` + is a Tensor, it should be a 1-D Tensor(with the data type int32 or + int64). + mean(float|int, optional): Mean of the output tensor, default is 0.0. + std(float|int, optional): Standard deviation of the output tensor, default + is 1.0. + seed(int, optional): ${seed_comment} + dtype(str|np.dtype|core.VarDesc.VarType, optional): The data type of + the output Tensor. Supported data types: float32, float64. + Default is float32. + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. + + Returns: + Tensor: A Tensor filled with random values sampled from a Gaussian + distribution, with ``shape`` and ``dtype``. + """ + if not isinstance(dtype, core.VarDesc.VarType): + dtype = convert_np_dtype_to_dtype_(dtype) + seed = 0 + op_type_for_check = 'gaussian_random/standard_normal/randn/normal' + + if in_dygraph_mode(): + shape = utils._convert_shape_to_list(shape) + return core.ops.gaussian_random('shape', shape, 'mean', + float(mean), 'std', + float(std), 'seed', seed, 'dtype', + dtype) + + check_type(shape, 'shape', (list, tuple, Variable), op_type_for_check) + check_dtype(dtype, 'dtype', ['float32', 'float64'], op_type_for_check) + + inputs = {} + attrs = { + 'mean': mean, + 'std': std, + 'seed': seed, + 'dtype': dtype, + 'use_mkldnn': False + } + utils._get_shape_tensor_inputs( + inputs=inputs, attrs=attrs, shape=shape, op_type=op_type_for_check) + + helper = LayerHelper('gaussian_random', **locals()) + out = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type='gaussian_random', + inputs=inputs, + outputs={'Out': out}, + attrs=attrs) + out.stop_gradient = True + return out + + +def standard_normal(shape, dtype=None, name=None): + """ + This OP returns a Tensor filled with random values sampled from a standard + normal distribution with mean 0 and standard deviation 1, with ``shape`` + and ``dtype``. + + Args: + shape(list|tuple|Tensor): The shape of the output Tensor. If ``shape`` + is a list or tuple, the elements of it should be integers or Tensors + (with the shape [1], and the data type int32 or int64). If ``shape`` + is a Tensor, it should be a 1-D Tensor(with the data type int32 or + int64). + dtype(str|np.dtype|core.VarDesc.VarType, optional): The data type of the + output tensor. Supported data types: float32, float64. If ``dytpe`` + is None, the data type is float32. Default is None. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: A Tensor filled with random values sampled from a standard + normal distribution with mean 0 and standard deviation 1, with + ``shape`` and ``dtype``. + + Raises: + TypeError: If ``shape`` is not list, tuple, Tensor. + TypeError: If ``dtype`` is not float32, float64. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + + # example 1: attr shape is a list which doesn't contain Tensor. + result_1 = paddle.standard_normal(shape=[2, 3]) + # [[-2.923464 , 0.11934398, -0.51249987], # random + # [ 0.39632758, 0.08177969, 0.2692008 ]] # random + + # example 2: attr shape is a list which contains Tensor. + dim_1 = paddle.fill_constant([1], "int64", 2) + dim_2 = paddle.fill_constant([1], "int32", 3) + result_2 = paddle.standard_normal(shape=[dim_1, dim_2, 2]) + # [[[-2.8852394 , -0.25898588], # random + # [-0.47420555, 0.17683524], # random + # [-0.7989969 , 0.00754541]], # random + # [[ 0.85201347, 0.32320443], # random + # [ 1.1399018 , 0.48336947], # random + # [ 0.8086993 , 0.6868893 ]]] # random + + # example 3: attr shape is a Tensor, the data type must be int64 or int32. + var_shape = paddle.to_tensor(np.array([2, 3])) + result_3 = paddle.standard_normal(var_shape) + # [[-2.878077 , 0.17099959, 0.05111201] # random + # [-0.3761474, -1.044801 , 1.1870178 ]] # random + + """ + if dtype is None: + dtype = 'float32' + + return gaussian_random( + shape=shape, mean=0.0, std=1.0, dtype=dtype, name=name) + + +randn = standard_normal + + +def normal(mean=0.0, std=1.0, shape=None, name=None): + """ + This OP returns a Tensor filled with random values sampled from a normal + distribution with ``mean`` and ``std`` (standard deviation) . + + If ``mean`` is a Tensor, the output Tensor has the same shape and data type as ``mean``. + If ``mean`` is not a Tensor and ``std`` is a Tensor, the output Tensor has the same shape and data type as ``std``. + If ``mean`` and ``std`` are not a Tensor, the output Tensor has the same shape as ``shape``, with data type float32. + + If ``mean`` and ``std`` are Tensor, the num of elements of ``mean`` and ``std`` should be the same. + + Args: + mean (float|Tensor, optional): The mean of the output Tensor's normal distribution. + If ``mean`` is float, all elements of the output Tensor shared the same mean. + If ``mean`` is a Tensor(data type supports float32, float64), it has per-element means. + Default is 0.0 + std (float|Tensor, optional): The standard deviation of the output Tensor's normal distribution. + If ``std`` is float, all elements of the output Tensor shared the same standard deviation. + If ``std`` is a Tensor(data type supports float32, float64), it has per-element standard deviations. + Defaule is 1.0 + shape (list|tuple|Tensor, optional): The shape of the output Tensor. If ``shape`` + is a list or tuple, the elements of it should be integers or Tensors + (with the shape [1], and the data type int32 or int64). If ``shape`` + is a Tensor, it should be a 1-D Tensor(with the data type int32 or + int64). If ``mean`` or ``std`` is a Tensor, the shape of the output + Tensor is the same as ``mean`` or ``std`` , attr ``shape`` is ignored. + Default is None + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor filled with random values sampled from a normal distribution with ``mean`` and ``std`` . + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + + out1 = paddle.normal(shape=[2, 3]) + # [[ 0.17501129 0.32364586 1.561118 ] # random + # [-1.7232178 1.1545963 -0.76156676]] # random + + mean_tensor = paddle.to_tensor(np.array([1.0, 2.0, 3.0])) + out2 = paddle.normal(mean=mean_tensor) + # [ 0.18644847 -1.19434458 3.93694787] # random + + std_tensor = paddle.to_tensor(np.array([1.0, 2.0, 3.0])) + out3 = paddle.normal(mean=mean_tensor, std=std_tensor) + # [1.00780561 3.78457445 5.81058198] # random + + """ + if not in_dygraph_mode(): + check_type(mean, 'mean', (int, float, Variable), 'normal') + check_type(std, 'std', (int, float, Variable), 'normal') + if isinstance(mean, Variable): + check_dtype( + mean.dtype, 'mean', ['float32', 'float64'], 'normal', + "If mean is Tensor, it's data type only support float32, float64." + ) + if isinstance(std, Variable): + check_dtype( + std.dtype, 'std', ['float32', 'float64'], 'normal', + "If std is Tensor, it's data type only support float32, float64." + ) + if shape is not None: + if isinstance(shape, (list, tuple)): + for item in shape: + check_type(item, 'shape', (int), 'normal', + 'Elements of shape should be int.') + elif isinstance(shape, Variable): + check_dtype(shape.dtype, 'shape', ['int32', 'int64'], 'normal') + else: + assert TypeError( + 'If mean and std are all not Tensor, shape should be list, tuple, Tensor.' + ) + + if isinstance(mean, Variable): + if isinstance(std, Variable): + if std.dtype != mean.dtype: + std = paddle.cast(std, mean.dtype) + mean_shape = paddle.shape(mean) + std = paddle.reshape(std, mean_shape) + else: + std = float(std) + out = standard_normal(paddle.shape(mean), mean.dtype, name) + elif isinstance(std, Variable): + mean = float(mean) + out = standard_normal(paddle.shape(std), std.dtype, name) + else: + return gaussian_random(shape=shape, mean=mean, std=std, name=name) + + out = out * std + mean + if not in_dygraph_mode(): + out.stop_grediant = True + return out + + def uniform(shape, dtype='float32', min=-1.0, max=1.0, seed=0, name=None): """ This OP returns a Tensor filled with random values sampled from a uniform @@ -98,10 +332,8 @@ def uniform(shape, dtype='float32', min=-1.0, max=1.0, seed=0, name=None): Examples: :: - Input: shape = [1, 2] - Output: result=[[0.8505902, 0.8397286]] @@ -161,7 +393,6 @@ def uniform(shape, dtype='float32', min=-1.0, max=1.0, seed=0, name=None): # attr shape is a Tensor, the data type must be int64 or int32. shape = np.array([2, 3]) shape_tensor = paddle.to_tensor(shape) - result_3 = paddle.tensor.random.uniform(shape_tensor) # if shape_tensor's value is [2, 3] # result_3 is: @@ -237,40 +468,40 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None): Examples: .. code-block:: python - import paddle - import numpy as np + import paddle + import numpy as np - paddle.disable_static() + paddle.disable_static() - # example 1: - # attr shape is a list which doesn't contain Tensor. - result_1 = paddle.randint(low=-5, high=5, shape=[3]) - # [0, -3, 2] - - # example 2: - # attr shape is a list which contains Tensor. - dim_1 = paddle.fill_constant([1], "int64", 2) - dim_2 = paddle.fill_constant([1], "int32", 3) - result_2 = paddle.randint(low=-5, high=5, shape=[dim_1, dim_2], dtype="int32") - # [[0, -1, -3], - # [4, -2, 0]] - - # example 3: - # attr shape is a Tensor - var_shape = paddle.to_variable(np.array([3])) - result_3 = paddle.randint(low=-5, high=5, shape=var_shape) - # [-2, 2, 3] - - # example 4: - # data type is int32 - result_4 = paddle.randint(low=-5, high=5, shape=[3], dtype='int32') - # [-5, 4, -4] - - # example 5: - # Input only one parameter - # low=0, high=10, shape=[1], dtype='int64' - result_5 = paddle.randint(10) - # [7] + # example 1: + # attr shape is a list which doesn't contain Tensor. + result_1 = paddle.randint(low=-5, high=5, shape=[3]) + # [0, -3, 2] # random + + # example 2: + # attr shape is a list which contains Tensor. + dim_1 = paddle.fill_constant([1], "int64", 2) + dim_2 = paddle.fill_constant([1], "int32", 3) + result_2 = paddle.randint(low=-5, high=5, shape=[dim_1, dim_2], dtype="int32") + # [[0, -1, -3], # random + # [4, -2, 0]] # random + + # example 3: + # attr shape is a Tensor + var_shape = paddle.to_variable(np.array([3])) + result_3 = paddle.randint(low=-5, high=5, shape=var_shape) + # [-2, 2, 3] # random + + # example 4: + # data type is int32 + result_4 = paddle.randint(low=-5, high=5, shape=[3], dtype='int32') + # [-5, 4, -4] # random + + # example 5: + # Input only one parameter + # low=0, high=10, shape=[1], dtype='int64' + result_5 = paddle.randint(10) + # [7] # random """ if high is None: @@ -309,77 +540,6 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None): return out -def randn(shape, dtype=None, name=None): - """ - :alias_main: paddle.randn - :alias: paddle.tensor.randn, paddle.tensor.random.randn - - This OP returns a Tensor filled with random values sampled from a normal - distribution with mean 0 and standard deviation 1 (also called the standard - normal distribution), with ``shape`` and ``dtype``. - - Args: - shape(list|tuple|Tensor): The shape of the output Tensor. If ``shape`` - is a list or tuple, the elements of it should be integers or Tensors - (with the shape [1], and the data type int32 or int64). If ``shape`` - is a Tensor, it should be a 1-D Tensor(with the data type int32 or - int64). - dtype(str|np.dtype|core.VarDesc.VarType, optional): The data type of the - output tensor. Supported data types: float32, float64. If ``dytpe`` - is None, the data type is float32. Default is None. - name(str, optional): The default value is None. Normally there is no - need for user to set this property. For more information, please - refer to :ref:`api_guide_Name`. - - Returns: - Tensor: A Tensor filled with random values sampled from a normal - distribution with mean 0 and standard deviation 1 (also called the - standard normal distribution), with ``shape`` and ``dtype``. - - Raises: - TypeError: If ``shape`` is not list, tuple, Tensor. - TypeError: If ``dtype`` is not float32, float64. - - Examples: - .. code-block:: python - - import paddle - import numpy as np - - paddle.disable_static() - - # example 1: attr shape is a list which doesn't contain Tensor. - result_1 = paddle.randn(shape=[2, 3]) - # [[-2.923464 , 0.11934398, -0.51249987], - # [ 0.39632758, 0.08177969, 0.2692008 ]] - - # example 2: attr shape is a list which contains Tensor. - dim_1 = paddle.fill_constant([1], "int64", 2) - dim_2 = paddle.fill_constant([1], "int32", 3) - result_2 = paddle.randn(shape=[dim_1, dim_2, 2]) - # [[[-2.8852394 , -0.25898588], - # [-0.47420555, 0.17683524], - # [-0.7989969 , 0.00754541]], - # [[ 0.85201347, 0.32320443], - # [ 1.1399018 , 0.48336947], - # [ 0.8086993 , 0.6868893 ]]] - - # example 3: attr shape is a Tensor, the data type must be int64 or int32. - var_shape = paddle.to_variable(np.array([2, 3])) - result_3 = paddle.randn(var_shape) - # [[-2.878077 , 0.17099959, 0.05111201] - # [-0.3761474, -1.044801 , 1.1870178 ]] - - """ - if dtype is None: - dtype = 'float32' - - out = gaussian_random( - shape=shape, mean=0.0, std=1.0, seed=0, dtype=dtype, name=name) - out.stop_gradient = True - return out - - @templatedoc() def randperm(n, dtype="int64", name=None): """ @@ -409,15 +569,15 @@ def randperm(n, dtype="int64", name=None): Examples: .. code-block:: python - import paddle + import paddle - paddle.disable_static() + paddle.disable_static() - result_1 = paddle.randperm(5) - # [4, 1, 2, 3, 0] + result_1 = paddle.randperm(5) + # [4, 1, 2, 3, 0] # random - result_2 = paddle.randperm(7, 'int32') - # [1, 6, 2, 0, 4, 3, 5] + result_2 = paddle.randperm(7, 'int32') + # [1, 6, 2, 0, 4, 3, 5] # random """ if not isinstance(dtype, core.VarDesc.VarType): @@ -481,31 +641,31 @@ def rand(shape, dtype=None, name=None): Examples: .. code-block:: python - import paddle - import numpy as np + import paddle + import numpy as np - paddle.disable_static() - # example 1: attr shape is a list which doesn't contain Tensor. - result_1 = paddle.rand(shape=[2, 3]) - # [[0.451152 , 0.55825245, 0.403311 ], - # [0.22550228, 0.22106001, 0.7877319 ]] - - # example 2: attr shape is a list which contains Tensor. - dim_1 = paddle.fill_constant([1], "int64", 2) - dim_2 = paddle.fill_constant([1], "int32", 3) - result_2 = paddle.rand(shape=[dim_1, dim_2, 2]) - # [[[0.8879919 , 0.25788337], - # [0.28826773, 0.9712097 ], - # [0.26438272, 0.01796806]], - # [[0.33633623, 0.28654453], - # [0.79109055, 0.7305809 ], - # [0.870881 , 0.2984597 ]]] - - # example 3: attr shape is a Tensor, the data type must be int64 or int32. - var_shape = paddle.to_variable(np.array([2, 3])) - result_3 = paddle.rand(var_shape) - # [[0.22920267, 0.841956 , 0.05981819], - # [0.4836288 , 0.24573246, 0.7516129 ]] + paddle.disable_static() + # example 1: attr shape is a list which doesn't contain Tensor. + result_1 = paddle.rand(shape=[2, 3]) + # [[0.451152 , 0.55825245, 0.403311 ], # random + # [0.22550228, 0.22106001, 0.7877319 ]] # random + + # example 2: attr shape is a list which contains Tensor. + dim_1 = paddle.fill_constant([1], "int64", 2) + dim_2 = paddle.fill_constant([1], "int32", 3) + result_2 = paddle.rand(shape=[dim_1, dim_2, 2]) + # [[[0.8879919 , 0.25788337], # random + # [0.28826773, 0.9712097 ], # random + # [0.26438272, 0.01796806]], # random + # [[0.33633623, 0.28654453], # random + # [0.79109055, 0.7305809 ], # random + # [0.870881 , 0.2984597 ]]] # random + + # example 3: attr shape is a Tensor, the data type must be int64 or int32. + var_shape = paddle.to_variable(np.array([2, 3])) + result_3 = paddle.rand(var_shape) + # [[0.22920267, 0.841956 , 0.05981819], # random + # [0.4836288 , 0.24573246, 0.7516129 ]] # random """ if dtype is None: