Unverified commit e6675f4f, authored by zhupengyang, committed by GitHub

normal: support mean and std tensor; randn = standard_normal (#26367)

Parent 3a9417f4
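In short: paddle.normal now accepts a Tensor for mean and/or std, and paddle.randn becomes an alias of the new paddle.standard_normal. A minimal usage sketch of the added behavior, assuming this commit is applied on top of the 2.0 imperative API used throughout the diff:

import numpy as np
import paddle

paddle.disable_static()

# Scalar mean/std with an explicit shape, as before.
out1 = paddle.normal(mean=0.0, std=1.0, shape=[2, 3])

# New: per-element means taken from a Tensor; the output follows mean's shape and dtype.
mean = paddle.to_tensor(np.array([1.0, 2.0, 3.0], dtype='float32'))
out2 = paddle.normal(mean=mean, std=2.0)

# New: randn and standard_normal are the same callable after this commit.
out3 = paddle.randn(shape=[2, 3])
assert paddle.randn is paddle.standard_normal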
@@ -194,7 +194,8 @@ from .tensor.math import clip #DEFINE_ALIAS
from .tensor.math import trace #DEFINE_ALIAS
from .tensor.math import kron #DEFINE_ALIAS
from .tensor.math import prod #DEFINE_ALIAS
-# from .tensor.random import gaussin #DEFINE_ALIAS
+from .tensor.random import standard_normal
+from .tensor.random import normal
from .tensor.random import uniform #DEFINE_ALIAS
from .tensor.random import shuffle #DEFINE_ALIAS
from .tensor.random import randn #DEFINE_ALIAS
......
@@ -10466,6 +10466,7 @@ def uniform_random_batch_size_like(input,
return out
+@deprecated(since="2.0.0", update_to="paddle.normal")
@templatedoc()
def gaussian_random(shape,
mean=0.0,
......
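fluid.layers.gaussian_random is only deprecated here, not removed. A hedged migration sketch for user code that still calls it; the replacement path simply follows the decorator's update_to hint:

import paddle

# 1.x style, now emits a deprecation warning (kept as a comment for reference):
#   out = paddle.fluid.layers.gaussian_random(shape=[2, 3], mean=0.0, std=1.0)

# 2.0 replacement named by update_to="paddle.normal":
paddle.disable_static()
out = paddle.normal(mean=0.0, std=1.0, shape=[2, 3])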
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import copy
np.random.seed(10)
class TestNormalAPI(unittest.TestCase):
def setUp(self):
self.mean = 1.0
self.std = 0.0
self.shape = None
self.repeat_num = 1000
self.set_attrs()
self.dtype = self.get_dtype()
self.place=paddle.CUDAPlace(0) \
if paddle.fluid.core.is_compiled_with_cuda() \
else paddle.CPUPlace()
def set_attrs(self):
self.shape = [8, 12]
def get_shape(self):
if isinstance(self.mean, np.ndarray):
shape = self.mean.shape
elif isinstance(self.std, np.ndarray):
shape = self.std.shape
else:
shape = self.shape
return list(shape)
def get_dtype(self):
if isinstance(self.mean, np.ndarray):
return self.mean.dtype
elif isinstance(self.std, np.ndarray):
return self.std.dtype
else:
return 'float32'
def static_api(self):
shape = self.get_shape()
ret_all_shape = copy.deepcopy(shape)
ret_all_shape.insert(0, self.repeat_num)
ret_all = np.zeros(ret_all_shape, self.dtype)
if isinstance(self.mean, np.ndarray) \
and isinstance(self.std, np.ndarray):
with paddle.static.program_guard(paddle.static.Program()):
mean = paddle.data('Mean', self.mean.shape, self.mean.dtype)
std = paddle.data('Std', self.std.shape, self.std.dtype)
out = paddle.normal(mean, std, self.shape)
exe = paddle.static.Executor(self.place)
for i in range(self.repeat_num):
ret = exe.run(feed={
'Mean': self.mean,
'Std': self.std.reshape(shape)
},
fetch_list=[out])
ret_all[i] = ret[0]
return ret_all
elif isinstance(self.mean, np.ndarray):
with paddle.static.program_guard(paddle.static.Program()):
mean = paddle.data('Mean', self.mean.shape, self.mean.dtype)
out = paddle.normal(mean, self.std, self.shape)
exe = paddle.static.Executor(self.place)
for i in range(self.repeat_num):
ret = exe.run(feed={'Mean': self.mean}, fetch_list=[out])
ret_all[i] = ret[0]
return ret_all
elif isinstance(self.std, np.ndarray):
with paddle.static.program_guard(paddle.static.Program()):
std = paddle.data('Std', self.std.shape, self.std.dtype)
out = paddle.normal(self.mean, std, self.shape)
exe = paddle.static.Executor(self.place)
for i in range(self.repeat_num):
ret = exe.run(feed={'Std': self.std}, fetch_list=[out])
ret_all[i] = ret[0]
return ret_all
else:
with paddle.static.program_guard(paddle.static.Program()):
out = paddle.normal(self.mean, self.std, self.shape)
exe = paddle.static.Executor(self.place)
for i in range(self.repeat_num):
ret = exe.run(fetch_list=[out])
ret_all[i] = ret[0]
return ret_all
def dygraph_api(self):
paddle.disable_static(self.place)
shape = self.get_shape()
ret_all_shape = copy.deepcopy(shape)
ret_all_shape.insert(0, self.repeat_num)
ret_all = np.zeros(ret_all_shape, self.dtype)
mean = paddle.to_tensor(self.mean) \
if isinstance(self.mean, np.ndarray) else self.mean
std = paddle.to_tensor(self.std) \
if isinstance(self.std, np.ndarray) else self.std
for i in range(self.repeat_num):
out = paddle.normal(mean, std, self.shape)
ret_all[i] = out.numpy()
paddle.enable_static()
return ret_all
def test_api(self):
ret_static = self.static_api()
ret_dygraph = self.dygraph_api()
for ret in [ret_static, ret_dygraph]:
shape_ref = self.get_shape()
self.assertEqual(shape_ref, list(ret[0].shape))
ret = ret.flatten().reshape([self.repeat_num, -1])
mean = np.mean(ret, axis=0)
std = np.std(ret, axis=0)
mean_ref=self.mean.reshape([1, -1]) \
if isinstance(self.mean, np.ndarray) else self.mean
std_ref=self.std.reshape([1, -1]) \
if isinstance(self.std, np.ndarray) else self.std
self.assertTrue(np.allclose(mean_ref, mean, 0.1, 0.1))
self.assertTrue(np.allclose(std_ref, std, 0.1, 0.1))
class TestNormalAPI_mean_is_tensor(TestNormalAPI):
def set_attrs(self):
self.mean = np.random.uniform(-2, -1, [2, 3, 4, 5]).astype('float64')
class TestNormalAPI_std_is_tensor(TestNormalAPI):
def set_attrs(self):
self.std = np.random.uniform(0.7, 1, [2, 3, 17]).astype('float64')
class TestNormalAPI_mean_std_are_tensor(TestNormalAPI):
def set_attrs(self):
self.mean = np.random.uniform(1, 2, [1, 100]).astype('float64')
self.std = np.random.uniform(0.5, 1, [1, 100]).astype('float64')
class TestNormalAPI_mean_std_are_tensor_with_different_dtype(TestNormalAPI):
def set_attrs(self):
self.mean = np.random.uniform(1, 2, [100]).astype('float64')
self.std = np.random.uniform(1, 2, [100]).astype('float32')
class TestNormalAlias(unittest.TestCase):
def test_alias(self):
paddle.disable_static()
shape = [1, 2, 3]
out1 = paddle.normal(shape=shape)
out2 = paddle.tensor.normal(shape=shape)
out3 = paddle.tensor.random.normal(shape=shape)
paddle.enable_static()
class TestNormalErrors(unittest.TestCase):
def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()):
mean = [1, 2, 3]
self.assertRaises(TypeError, paddle.normal, mean)
std = [1, 2, 3]
self.assertRaises(TypeError, paddle.normal, std=std)
mean = paddle.data('Mean', [100], 'int32')
self.assertRaises(TypeError, paddle.normal, mean)
std = paddle.data('Std', [100], 'int32')
self.assertRaises(TypeError, paddle.normal, mean=1.0, std=std)
self.assertRaises(TypeError, paddle.normal, shape=1)
self.assertRaises(TypeError, paddle.normal, shape=[1.0])
shape = paddle.data('Shape', [100], 'float32')
self.assertRaises(TypeError, paddle.normal, shape=shape)
if __name__ == "__main__":
unittest.main()
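The suite above never compares against golden outputs; each case draws repeat_num samples per element and checks that the empirical mean and standard deviation land within the 0.1 tolerances passed to np.allclose. A stripped-down sketch of that statistical check (the numbers are illustrative, not taken from the tests):

import numpy as np
import paddle

paddle.disable_static()
mean, std, repeat_num = 1.5, 0.8, 1000

samples = np.stack(
    [paddle.normal(mean, std, shape=[8, 12]).numpy() for _ in range(repeat_num)])
emp_mean = samples.mean(axis=0)  # per-element mean over the sampling axis
emp_std = samples.std(axis=0)    # per-element std over the sampling axis
assert np.allclose(emp_mean, mean, rtol=0.1, atol=0.1)
assert np.allclose(emp_std, std, rtol=0.1, atol=0.1)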
@@ -162,7 +162,8 @@ from .math import clip #DEFINE_ALIAS
from .math import trace #DEFINE_ALIAS
from .math import kron #DEFINE_ALIAS
from .math import prod #DEFINE_ALIAS
-# from .random import gaussin #DEFINE_ALIAS
+from .random import standard_normal
+from .random import normal
from .random import uniform #DEFINE_ALIAS
from .random import shuffle #DEFINE_ALIAS
from .random import randn #DEFINE_ALIAS
......
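With this re-export in place, the same function object is reachable under the three names exercised by TestNormalAlias above; a quick check, assuming the commit is applied:

import paddle

paddle.disable_static()
out1 = paddle.normal(shape=[1, 2, 3])
out2 = paddle.tensor.normal(shape=[1, 2, 3])
out3 = paddle.tensor.random.normal(shape=[1, 2, 3])
assert paddle.normal is paddle.tensor.random.normal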
@@ -21,20 +21,23 @@ from ..fluid.framework import device_guard, in_dygraph_mode, _varbase_creator, V
from ..fluid.layers.layer_function_generator import templatedoc
from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
-from ..fluid.layers import utils, gaussian_random
+from ..fluid.layers import utils
from ..fluid.layers.tensor import fill_constant
+import paddle
+import warnings
from ..fluid.io import shuffle #DEFINE_ALIAS
__all__ = [
'bernoulli',
-# 'gaussin',
+'standard_normal',
+'normal',
'uniform',
'shuffle',
'randn',
'rand',
'randint',
-'randperm'
+'randperm',
]
@@ -91,6 +94,237 @@ def bernoulli(x, name=None):
return out
def gaussian_random(shape, mean=0.0, std=1.0, dtype='float32', name=None):
"""
This OP returns a Tensor filled with random values sampled from a Gaussian
distribution, with ``shape`` and ``dtype``.
Args:
shape(list|tuple|Tensor): The shape of the output Tensor. If ``shape``
is a list or tuple, the elements of it should be integers or Tensors
(with the shape [1], and the data type int32 or int64). If ``shape``
is a Tensor, it should be a 1-D Tensor(with the data type int32 or
int64).
mean(float|int, optional): Mean of the output tensor, default is 0.0.
std(float|int, optional): Standard deviation of the output tensor, default
is 1.0.
dtype(str|np.dtype|core.VarDesc.VarType, optional): The data type of
the output Tensor. Supported data types: float32, float64.
Default is float32.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
Tensor: A Tensor filled with random values sampled from a Gaussian
distribution, with ``shape`` and ``dtype``.
"""
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
seed = 0
op_type_for_check = 'gaussian_random/standard_normal/randn/normal'
if in_dygraph_mode():
shape = utils._convert_shape_to_list(shape)
return core.ops.gaussian_random('shape', shape, 'mean',
float(mean), 'std',
float(std), 'seed', seed, 'dtype',
dtype)
check_type(shape, 'shape', (list, tuple, Variable), op_type_for_check)
check_dtype(dtype, 'dtype', ['float32', 'float64'], op_type_for_check)
inputs = {}
attrs = {
'mean': mean,
'std': std,
'seed': seed,
'dtype': dtype,
'use_mkldnn': False
}
utils._get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=shape, op_type=op_type_for_check)
helper = LayerHelper('gaussian_random', **locals())
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='gaussian_random',
inputs=inputs,
outputs={'Out': out},
attrs=attrs)
out.stop_gradient = True
return out
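The helper above carries no Examples section. A hedged usage sketch: it is module-level but deliberately left out of __all__, so reaching it as paddle.tensor.random.gaussian_random is an assumption about the import path rather than documented API:

import paddle

paddle.disable_static()
# Like standard_normal, but with selectable mean, std and dtype.
out = paddle.tensor.random.gaussian_random(
    shape=[2, 3], mean=1.0, std=2.0, dtype='float64')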
def standard_normal(shape, dtype=None, name=None):
"""
This OP returns a Tensor filled with random values sampled from a standard
normal distribution with mean 0 and standard deviation 1, with ``shape``
and ``dtype``.
Args:
shape(list|tuple|Tensor): The shape of the output Tensor. If ``shape``
is a list or tuple, the elements of it should be integers or Tensors
(with the shape [1], and the data type int32 or int64). If ``shape``
is a Tensor, it should be a 1-D Tensor(with the data type int32 or
int64).
dtype(str|np.dtype|core.VarDesc.VarType, optional): The data type of the
output tensor. Supported data types: float32, float64. If ``dtype``
is None, the data type is float32. Default is None.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: A Tensor filled with random values sampled from a standard
normal distribution with mean 0 and standard deviation 1, with
``shape`` and ``dtype``.
Raises:
TypeError: If ``shape`` is not list, tuple, Tensor.
TypeError: If ``dtype`` is not float32, float64.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
# example 1: attr shape is a list which doesn't contain Tensor.
result_1 = paddle.standard_normal(shape=[2, 3])
# [[-2.923464 , 0.11934398, -0.51249987], # random
# [ 0.39632758, 0.08177969, 0.2692008 ]] # random
# example 2: attr shape is a list which contains Tensor.
dim_1 = paddle.fill_constant([1], "int64", 2)
dim_2 = paddle.fill_constant([1], "int32", 3)
result_2 = paddle.standard_normal(shape=[dim_1, dim_2, 2])
# [[[-2.8852394 , -0.25898588], # random
# [-0.47420555, 0.17683524], # random
# [-0.7989969 , 0.00754541]], # random
# [[ 0.85201347, 0.32320443], # random
# [ 1.1399018 , 0.48336947], # random
# [ 0.8086993 , 0.6868893 ]]] # random
# example 3: attr shape is a Tensor, the data type must be int64 or int32.
var_shape = paddle.to_tensor(np.array([2, 3]))
result_3 = paddle.standard_normal(var_shape)
# [[-2.878077 , 0.17099959, 0.05111201] # random
# [-0.3761474, -1.044801 , 1.1870178 ]] # random
"""
if dtype is None:
dtype = 'float32'
return gaussian_random(
shape=shape, mean=0.0, std=1.0, dtype=dtype, name=name)
randn = standard_normal
def normal(mean=0.0, std=1.0, shape=None, name=None):
"""
This OP returns a Tensor filled with random values sampled from a normal
distribution with ``mean`` and ``std`` (standard deviation).
If ``mean`` is a Tensor, the output Tensor has the same shape and data type as ``mean``.
If ``mean`` is not a Tensor and ``std`` is a Tensor, the output Tensor has the same shape and data type as ``std``.
If neither ``mean`` nor ``std`` is a Tensor, the output Tensor has shape ``shape`` and data type float32.
If both ``mean`` and ``std`` are Tensors, they must have the same number of elements.
Args:
mean (float|Tensor, optional): The mean of the output Tensor's normal distribution.
If ``mean`` is a float, all elements of the output Tensor share the same mean.
If ``mean`` is a Tensor (data type float32 or float64), it holds per-element means.
Default is 0.0.
std (float|Tensor, optional): The standard deviation of the output Tensor's normal distribution.
If ``std`` is a float, all elements of the output Tensor share the same standard deviation.
If ``std`` is a Tensor (data type float32 or float64), it holds per-element standard deviations.
Default is 1.0.
shape (list|tuple|Tensor, optional): The shape of the output Tensor. If ``shape``
is a list or tuple, the elements of it should be integers or Tensors
(with the shape [1], and the data type int32 or int64). If ``shape``
is a Tensor, it should be a 1-D Tensor(with the data type int32 or
int64). If ``mean`` or ``std`` is a Tensor, the shape of the output
Tensor is the same as ``mean`` or ``std``, and ``shape`` is ignored.
Default is None.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor filled with random values sampled from a normal distribution with ``mean`` and ``std``.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
out1 = paddle.normal(shape=[2, 3])
# [[ 0.17501129 0.32364586 1.561118 ] # random
# [-1.7232178 1.1545963 -0.76156676]] # random
mean_tensor = paddle.to_tensor(np.array([1.0, 2.0, 3.0]))
out2 = paddle.normal(mean=mean_tensor)
# [ 0.18644847 -1.19434458 3.93694787] # random
std_tensor = paddle.to_tensor(np.array([1.0, 2.0, 3.0]))
out3 = paddle.normal(mean=mean_tensor, std=std_tensor)
# [1.00780561 3.78457445 5.81058198] # random
"""
if not in_dygraph_mode():
check_type(mean, 'mean', (int, float, Variable), 'normal')
check_type(std, 'std', (int, float, Variable), 'normal')
if isinstance(mean, Variable):
check_dtype(
mean.dtype, 'mean', ['float32', 'float64'], 'normal',
"If mean is Tensor, it's data type only support float32, float64."
)
if isinstance(std, Variable):
check_dtype(
std.dtype, 'std', ['float32', 'float64'], 'normal',
"If std is Tensor, it's data type only support float32, float64."
)
if shape is not None:
if isinstance(shape, (list, tuple)):
for item in shape:
check_type(item, 'shape', (int), 'normal',
'Elements of shape should be int.')
elif isinstance(shape, Variable):
check_dtype(shape.dtype, 'shape', ['int32', 'int64'], 'normal')
else:
raise TypeError(
'If mean and std are both not Tensors, shape should be a list, tuple or Tensor.')
if isinstance(mean, Variable):
if isinstance(std, Variable):
if std.dtype != mean.dtype:
std = paddle.cast(std, mean.dtype)
mean_shape = paddle.shape(mean)
std = paddle.reshape(std, mean_shape)
else:
std = float(std)
out = standard_normal(paddle.shape(mean), mean.dtype, name)
elif isinstance(std, Variable):
mean = float(mean)
out = standard_normal(paddle.shape(std), std.dtype, name)
else:
return gaussian_random(shape=shape, mean=mean, std=std, name=name)
out = out * std + mean
if not in_dygraph_mode():
out.stop_gradient = True
return out
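For the Tensor mean/std branches, the implementation draws z from a standard normal of the matching shape and applies out = z * std + mean, casting and reshaping std onto mean when both are Tensors. A tiny NumPy sketch of why that affine map yields N(mean, std) element-wise (illustrative only):

import numpy as np

rng = np.random.default_rng(0)
mean = np.array([1.0, 2.0, 3.0])
std = np.array([0.5, 1.0, 2.0])

# z ~ N(0, 1), therefore z * std + mean ~ N(mean, std) element-wise.
z = rng.standard_normal(size=(100000, 3))
samples = z * std + mean
print(samples.mean(axis=0))  # approximately [1.0, 2.0, 3.0]
print(samples.std(axis=0))   # approximately [0.5, 1.0, 2.0]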
def uniform(shape, dtype='float32', min=-1.0, max=1.0, seed=0, name=None):
"""
This OP returns a Tensor filled with random values sampled from a uniform
@@ -98,10 +332,8 @@ def uniform(shape, dtype='float32', min=-1.0, max=1.0, seed=0, name=None):
Examples:
::
Input:
shape = [1, 2]
Output:
result=[[0.8505902, 0.8397286]]
@@ -161,7 +393,6 @@ def uniform(shape, dtype='float32', min=-1.0, max=1.0, seed=0, name=None):
# attr shape is a Tensor, the data type must be int64 or int32.
shape = np.array([2, 3])
shape_tensor = paddle.to_tensor(shape)
result_3 = paddle.tensor.random.uniform(shape_tensor)
# if shape_tensor's value is [2, 3]
# result_3 is:
@@ -237,40 +468,40 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
# example 1:
# attr shape is a list which doesn't contain Tensor.
result_1 = paddle.randint(low=-5, high=5, shape=[3])
-# [0, -3, 2]
+# [0, -3, 2] # random
# example 2:
# attr shape is a list which contains Tensor.
dim_1 = paddle.fill_constant([1], "int64", 2)
dim_2 = paddle.fill_constant([1], "int32", 3)
result_2 = paddle.randint(low=-5, high=5, shape=[dim_1, dim_2], dtype="int32")
-# [[0, -1, -3],
-# [4, -2, 0]]
+# [[0, -1, -3], # random
+# [4, -2, 0]] # random
# example 3:
# attr shape is a Tensor
var_shape = paddle.to_variable(np.array([3]))
result_3 = paddle.randint(low=-5, high=5, shape=var_shape)
-# [-2, 2, 3]
+# [-2, 2, 3] # random
# example 4:
# data type is int32
result_4 = paddle.randint(low=-5, high=5, shape=[3], dtype='int32')
-# [-5, 4, -4]
+# [-5, 4, -4] # random
# example 5:
# Input only one parameter
# low=0, high=10, shape=[1], dtype='int64'
result_5 = paddle.randint(10)
-# [7]
+# [7] # random
"""
if high is None:
@@ -309,77 +540,6 @@
return out
def randn(shape, dtype=None, name=None):
"""
:alias_main: paddle.randn
:alias: paddle.tensor.randn, paddle.tensor.random.randn
This OP returns a Tensor filled with random values sampled from a normal
distribution with mean 0 and standard deviation 1 (also called the standard
normal distribution), with ``shape`` and ``dtype``.
Args:
shape(list|tuple|Tensor): The shape of the output Tensor. If ``shape``
is a list or tuple, the elements of it should be integers or Tensors
(with the shape [1], and the data type int32 or int64). If ``shape``
is a Tensor, it should be a 1-D Tensor(with the data type int32 or
int64).
dtype(str|np.dtype|core.VarDesc.VarType, optional): The data type of the
output tensor. Supported data types: float32, float64. If ``dytpe``
is None, the data type is float32. Default is None.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
Tensor: A Tensor filled with random values sampled from a normal
distribution with mean 0 and standard deviation 1 (also called the
standard normal distribution), with ``shape`` and ``dtype``.
Raises:
TypeError: If ``shape`` is not list, tuple, Tensor.
TypeError: If ``dtype`` is not float32, float64.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
# example 1: attr shape is a list which doesn't contain Tensor.
result_1 = paddle.randn(shape=[2, 3])
# [[-2.923464 , 0.11934398, -0.51249987],
# [ 0.39632758, 0.08177969, 0.2692008 ]]
# example 2: attr shape is a list which contains Tensor.
dim_1 = paddle.fill_constant([1], "int64", 2)
dim_2 = paddle.fill_constant([1], "int32", 3)
result_2 = paddle.randn(shape=[dim_1, dim_2, 2])
# [[[-2.8852394 , -0.25898588],
# [-0.47420555, 0.17683524],
# [-0.7989969 , 0.00754541]],
# [[ 0.85201347, 0.32320443],
# [ 1.1399018 , 0.48336947],
# [ 0.8086993 , 0.6868893 ]]]
# example 3: attr shape is a Tensor, the data type must be int64 or int32.
var_shape = paddle.to_variable(np.array([2, 3]))
result_3 = paddle.randn(var_shape)
# [[-2.878077 , 0.17099959, 0.05111201]
# [-0.3761474, -1.044801 , 1.1870178 ]]
"""
if dtype is None:
dtype = 'float32'
out = gaussian_random(
shape=shape, mean=0.0, std=1.0, seed=0, dtype=dtype, name=name)
out.stop_gradient = True
return out
@templatedoc()
def randperm(n, dtype="int64", name=None):
"""
@@ -409,15 +569,15 @@ def randperm(n, dtype="int64", name=None):
Examples:
.. code-block:: python
import paddle
paddle.disable_static()
result_1 = paddle.randperm(5)
-# [4, 1, 2, 3, 0]
+# [4, 1, 2, 3, 0] # random
result_2 = paddle.randperm(7, 'int32')
-# [1, 6, 2, 0, 4, 3, 5]
+# [1, 6, 2, 0, 4, 3, 5] # random
"""
if not isinstance(dtype, core.VarDesc.VarType):
@@ -481,31 +641,31 @@ def rand(shape, dtype=None, name=None):
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
# example 1: attr shape is a list which doesn't contain Tensor.
result_1 = paddle.rand(shape=[2, 3])
-# [[0.451152 , 0.55825245, 0.403311 ],
-# [0.22550228, 0.22106001, 0.7877319 ]]
+# [[0.451152 , 0.55825245, 0.403311 ], # random
+# [0.22550228, 0.22106001, 0.7877319 ]] # random
# example 2: attr shape is a list which contains Tensor.
dim_1 = paddle.fill_constant([1], "int64", 2)
dim_2 = paddle.fill_constant([1], "int32", 3)
result_2 = paddle.rand(shape=[dim_1, dim_2, 2])
-# [[[0.8879919 , 0.25788337],
-# [0.28826773, 0.9712097 ],
-# [0.26438272, 0.01796806]],
-# [[0.33633623, 0.28654453],
-# [0.79109055, 0.7305809 ],
-# [0.870881 , 0.2984597 ]]]
+# [[[0.8879919 , 0.25788337], # random
+# [0.28826773, 0.9712097 ], # random
+# [0.26438272, 0.01796806]], # random
+# [[0.33633623, 0.28654453], # random
+# [0.79109055, 0.7305809 ], # random
+# [0.870881 , 0.2984597 ]]] # random
# example 3: attr shape is a Tensor, the data type must be int64 or int32.
var_shape = paddle.to_variable(np.array([2, 3]))
result_3 = paddle.rand(var_shape)
-# [[0.22920267, 0.841956 , 0.05981819],
-# [0.4836288 , 0.24573246, 0.7516129 ]]
+# [[0.22920267, 0.841956 , 0.05981819], # random
+# [0.4836288 , 0.24573246, 0.7516129 ]] # random
"""
if dtype is None:
......