未验证 提交 6a09b8f1 编写于 作者: Z zhupengyang 提交者: GitHub

erase Raises and refine doce of random functions (#26901)

上级 559d9f2b
...@@ -132,6 +132,28 @@ def check_dtype(input_dtype, ...@@ -132,6 +132,28 @@ def check_dtype(input_dtype,
extra_message)) extra_message))
def check_shape(shape,
op_name,
expected_shape_type=(list, tuple, Variable),
expected_element_type=(int, Variable),
expected_tensor_dtype=('int32', 'int64')):
# See NOTE [ Why skip dynamic graph check ]
if in_dygraph_mode():
return
check_type(shape, 'shape', expected_shape_type, op_name)
if expected_element_type is not None and not isinstance(shape, Variable):
for item in shape:
check_type(item, 'element of shape', expected_element_type, op_name)
if expected_tensor_dtype is not None and isinstance(item, Variable):
check_dtype(
item.dtype, 'element of shape', expected_tensor_dtype,
op_name,
'If element of shape is Tensor, its data type should be {}'.
format(', '.join(expected_tensor_dtype)))
if expected_tensor_dtype is not None and isinstance(shape, Variable):
check_dtype(shape.dtype, 'shape', expected_tensor_dtype, op_name)
class DataToLoDTensorConverter(object): class DataToLoDTensorConverter(object):
def __init__(self, place, lod_level, shape, dtype): def __init__(self, place, lod_level, shape, dtype):
self.place = place self.place = place
......
...@@ -10610,7 +10610,7 @@ def gaussian_random(shape, ...@@ -10610,7 +10610,7 @@ def gaussian_random(shape,
dtype = convert_np_dtype_to_dtype_(dtype) dtype = convert_np_dtype_to_dtype_(dtype)
if in_dygraph_mode(): if in_dygraph_mode():
shape = utils._convert_shape_to_list(shape) shape = utils.convert_shape_to_list(shape)
return core.ops.gaussian_random('shape', shape, 'mean', return core.ops.gaussian_random('shape', shape, 'mean',
float(mean), 'std', float(mean), 'std',
float(std), 'seed', seed, 'dtype', float(std), 'seed', seed, 'dtype',
...@@ -10627,7 +10627,7 @@ def gaussian_random(shape, ...@@ -10627,7 +10627,7 @@ def gaussian_random(shape,
'dtype': dtype, 'dtype': dtype,
'use_mkldnn': False 'use_mkldnn': False
} }
utils._get_shape_tensor_inputs( utils.get_shape_tensor_inputs(
inputs=inputs, inputs=inputs,
attrs=attrs, attrs=attrs,
shape=shape, shape=shape,
...@@ -15116,7 +15116,7 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0, ...@@ -15116,7 +15116,7 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0,
dtype = convert_np_dtype_to_dtype_(dtype) dtype = convert_np_dtype_to_dtype_(dtype)
if in_dygraph_mode(): if in_dygraph_mode():
shape = utils._convert_shape_to_list(shape) shape = utils.convert_shape_to_list(shape)
return core.ops.uniform_random('shape', shape, 'min', return core.ops.uniform_random('shape', shape, 'min',
float(min), 'max', float(min), 'max',
float(max), 'seed', seed, 'dtype', dtype) float(max), 'seed', seed, 'dtype', dtype)
...@@ -15126,7 +15126,7 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0, ...@@ -15126,7 +15126,7 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0,
inputs = dict() inputs = dict()
attrs = {'seed': seed, 'min': min, 'max': max, 'dtype': dtype} attrs = {'seed': seed, 'min': min, 'max': max, 'dtype': dtype}
utils._get_shape_tensor_inputs( utils.get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=shape, op_type='uniform_random/rand') inputs=inputs, attrs=attrs, shape=shape, op_type='uniform_random/rand')
helper = LayerHelper("uniform_random", **locals()) helper = LayerHelper("uniform_random", **locals())
......
...@@ -694,7 +694,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -694,7 +694,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
attrs['str_value'] = str(float(value)) attrs['str_value'] = str(float(value))
if in_dygraph_mode(): if in_dygraph_mode():
shape = utils._convert_shape_to_list(shape) shape = utils.convert_shape_to_list(shape)
if out is None: if out is None:
out = _varbase_creator(dtype=dtype) out = _varbase_creator(dtype=dtype)
...@@ -731,7 +731,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -731,7 +731,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
'fill_constant') 'fill_constant')
helper = LayerHelper("fill_constant", **locals()) helper = LayerHelper("fill_constant", **locals())
utils._get_shape_tensor_inputs( utils.get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=shape, op_type='fill_constant') inputs=inputs, attrs=attrs, shape=shape, op_type='fill_constant')
if out is None: if out is None:
......
...@@ -282,7 +282,7 @@ def _contain_var(list_or_tuple): ...@@ -282,7 +282,7 @@ def _contain_var(list_or_tuple):
return False return False
def _get_shape_tensor_inputs(inputs, attrs, shape, op_type): def get_shape_tensor_inputs(inputs, attrs, shape, op_type):
from .tensor import fill_constant, cast from .tensor import fill_constant, cast
def _get_attr_shape(list_shape): def _get_attr_shape(list_shape):
...@@ -347,7 +347,7 @@ def _convert_to_tensor_list(old_list, dtype="int32"): ...@@ -347,7 +347,7 @@ def _convert_to_tensor_list(old_list, dtype="int32"):
return new_list_tensor return new_list_tensor
def _convert_shape_to_list(shape): def convert_shape_to_list(shape):
""" """
Convert shape(list, tuple, variable) to list in imperative mode Convert shape(list, tuple, variable) to list in imperative mode
""" """
......
...@@ -241,18 +241,18 @@ class TestGaussianRandomAPI(unittest.TestCase): ...@@ -241,18 +241,18 @@ class TestGaussianRandomAPI(unittest.TestCase):
def test_default_fp_16(): def test_default_fp_16():
paddle.framework.set_default_dtype('float16') paddle.framework.set_default_dtype('float16')
paddle.tensor.random.gaussian_random([2, 3]) paddle.tensor.random.gaussian([2, 3])
self.assertRaises(TypeError, test_default_fp_16) self.assertRaises(TypeError, test_default_fp_16)
def test_default_fp_32(): def test_default_fp_32():
paddle.framework.set_default_dtype('float32') paddle.framework.set_default_dtype('float32')
out = paddle.tensor.random.gaussian_random([2, 3]) out = paddle.tensor.random.gaussian([2, 3])
self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP32) self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP32)
def test_default_fp_64(): def test_default_fp_64():
paddle.framework.set_default_dtype('float64') paddle.framework.set_default_dtype('float64')
out = paddle.tensor.random.gaussian_random([2, 3]) out = paddle.tensor.random.gaussian([2, 3])
self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP64) self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP64)
test_default_fp_64() test_default_fp_64()
......
...@@ -58,6 +58,11 @@ class TestRandintOpError(unittest.TestCase): ...@@ -58,6 +58,11 @@ class TestRandintOpError(unittest.TestCase):
self.assertRaises(TypeError, paddle.randint, 5, dtype='float32') self.assertRaises(TypeError, paddle.randint, 5, dtype='float32')
self.assertRaises(ValueError, paddle.randint, 5, 5) self.assertRaises(ValueError, paddle.randint, 5, 5)
self.assertRaises(ValueError, paddle.randint, -5) self.assertRaises(ValueError, paddle.randint, -5)
self.assertRaises(TypeError, paddle.randint, 5, shape=['2'])
shape_tensor = paddle.static.data('X', [1])
self.assertRaises(TypeError, paddle.randint, 5, shape=shape_tensor)
self.assertRaises(
TypeError, paddle.randint, 5, shape=[shape_tensor])
class TestRandintOp_attr_tensorlist(OpTest): class TestRandintOp_attr_tensorlist(OpTest):
......
...@@ -14,17 +14,12 @@ ...@@ -14,17 +14,12 @@
# TODO: define random functions # TODO: define random functions
import numpy as np
from ..fluid import core from ..fluid import core
from ..fluid.framework import device_guard, in_dygraph_mode, _varbase_creator, Variable, convert_np_dtype_to_dtype_ from ..fluid.framework import in_dygraph_mode, Variable, convert_np_dtype_to_dtype_
from ..fluid.layers.layer_function_generator import templatedoc
from ..fluid.layer_helper import LayerHelper from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, check_shape
from ..fluid.layers import utils from ..fluid.layers import utils
from ..fluid.layers.tensor import fill_constant
import paddle import paddle
import warnings
from ..fluid.io import shuffle #DEFINE_ALIAS from ..fluid.io import shuffle #DEFINE_ALIAS
...@@ -94,26 +89,26 @@ def bernoulli(x, name=None): ...@@ -94,26 +89,26 @@ def bernoulli(x, name=None):
return out return out
def gaussian_random(shape, mean=0.0, std=1.0, dtype=None, name=None): def gaussian(shape, mean=0.0, std=1.0, dtype=None, name=None):
""" """
This OP returns a Tensor filled with random values sampled from a Gaussian This OP returns a Tensor filled with random values sampled from a Gaussian
distribution, with ``shape`` and ``dtype``. distribution, with ``shape`` and ``dtype``.
Args: Args:
shape(list|tuple|Tensor): The shape of the output Tensor. If ``shape`` shape (list|tuple|Tensor): The shape of the output Tensor. If ``shape``
is a list or tuple, the elements of it should be integers or Tensors is a list or tuple, the elements of it should be integers or Tensors
(with the shape [1], and the data type int32 or int64). If ``shape`` (with the shape [1], and the data type int32 or int64). If ``shape``
is a Tensor, it should be a 1-D Tensor(with the data type int32 or is a Tensor, it should be a 1-D Tensor(with the data type int32 or
int64). int64).
mean(float|int, optional): Mean of the output tensor, default is 0.0. mean (float|int, optional): Mean of the output tensor, default is 0.0.
std(float|int, optional): Standard deviation of the output tensor, default std (float|int, optional): Standard deviation of the output tensor, default
is 1.0. is 1.0.
seed(int, optional): ${seed_comment} seed (int, optional): Random seed of generator.
dtype(str|np.dtype, optional): The data type of the output Tensor. dtype (str|np.dtype, optional): The data type of the output Tensor.
Supported data types: float32, float64. Supported data types: float32, float64.
Default is None, use global default dtype (see ``get_default_dtype`` Default is None, use global default dtype (see ``get_default_dtype``
for details). for details).
name(str, optional): The default value is None. Normally there is no name (str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`. refer to :ref:`api_guide_Name`.
...@@ -121,26 +116,26 @@ def gaussian_random(shape, mean=0.0, std=1.0, dtype=None, name=None): ...@@ -121,26 +116,26 @@ def gaussian_random(shape, mean=0.0, std=1.0, dtype=None, name=None):
Tensor: A Tensor filled with random values sampled from a Gaussian Tensor: A Tensor filled with random values sampled from a Gaussian
distribution, with ``shape`` and ``dtype``. distribution, with ``shape`` and ``dtype``.
""" """
op_type_for_check = 'gaussian/standard_normal/randn/normal'
seed = 0
if dtype is None: if dtype is None:
dtype = paddle.framework.get_default_dtype() dtype = paddle.framework.get_default_dtype()
if dtype not in ['float32', 'float64']: if dtype not in ['float32', 'float64']:
raise TypeError( raise TypeError(
"gaussian_random only supports [float32, float64], but the default dtype is %s" "{} only supports [float32, float64], but the default dtype is {}"
% dtype) .format(op_type_for_check, dtype))
if not isinstance(dtype, core.VarDesc.VarType): if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype) dtype = convert_np_dtype_to_dtype_(dtype)
seed = 0
op_type_for_check = 'gaussian_random/standard_normal/randn/normal'
if in_dygraph_mode(): if in_dygraph_mode():
shape = utils._convert_shape_to_list(shape) shape = utils.convert_shape_to_list(shape)
return core.ops.gaussian_random('shape', shape, 'mean', return core.ops.gaussian_random('shape', shape, 'mean',
float(mean), 'std', float(mean), 'std',
float(std), 'seed', seed, 'dtype', float(std), 'seed', seed, 'dtype',
dtype) dtype)
check_type(shape, 'shape', (list, tuple, Variable), op_type_for_check) check_shape(shape, op_type_for_check)
check_dtype(dtype, 'dtype', ['float32', 'float64'], op_type_for_check) check_dtype(dtype, 'dtype', ['float32', 'float64'], op_type_for_check)
inputs = {} inputs = {}
...@@ -151,10 +146,10 @@ def gaussian_random(shape, mean=0.0, std=1.0, dtype=None, name=None): ...@@ -151,10 +146,10 @@ def gaussian_random(shape, mean=0.0, std=1.0, dtype=None, name=None):
'dtype': dtype, 'dtype': dtype,
'use_mkldnn': False 'use_mkldnn': False
} }
utils._get_shape_tensor_inputs( utils.get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=shape, op_type=op_type_for_check) inputs=inputs, attrs=attrs, shape=shape, op_type=op_type_for_check)
helper = LayerHelper('gaussian_random', **locals()) helper = LayerHelper('gaussian', **locals())
out = helper.create_variable_for_type_inference(dtype) out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type='gaussian_random', type='gaussian_random',
...@@ -172,12 +167,12 @@ def standard_normal(shape, dtype=None, name=None): ...@@ -172,12 +167,12 @@ def standard_normal(shape, dtype=None, name=None):
and ``dtype``. and ``dtype``.
Args: Args:
shape(list|tuple|Tensor): The shape of the output Tensor. If ``shape`` shape (list|tuple|Tensor): The shape of the output Tensor. If ``shape``
is a list or tuple, the elements of it should be integers or Tensors is a list or tuple, the elements of it should be integers or Tensors
(with the shape [1], and the data type int32 or int64). If ``shape`` (with the shape [1], and the data type int32 or int64). If ``shape``
is a Tensor, it should be a 1-D Tensor(with the data type int32 or is a Tensor, it should be a 1-D Tensor(with the data type int32 or
int64). int64).
dtype(str|np.dtype, optional): The data type of the output Tensor. dtype (str|np.dtype, optional): The data type of the output Tensor.
Supported data types: float32, float64. Supported data types: float32, float64.
Default is None, use global default dtype (see ``get_default_dtype`` Default is None, use global default dtype (see ``get_default_dtype``
for details). for details).
...@@ -189,10 +184,6 @@ def standard_normal(shape, dtype=None, name=None): ...@@ -189,10 +184,6 @@ def standard_normal(shape, dtype=None, name=None):
normal distribution with mean 0 and standard deviation 1, with normal distribution with mean 0 and standard deviation 1, with
``shape`` and ``dtype``. ``shape`` and ``dtype``.
Raises:
TypeError: If ``shape`` is not list, tuple, Tensor.
TypeError: If ``dtype`` is not float32, float64.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -202,14 +193,14 @@ def standard_normal(shape, dtype=None, name=None): ...@@ -202,14 +193,14 @@ def standard_normal(shape, dtype=None, name=None):
paddle.disable_static() paddle.disable_static()
# example 1: attr shape is a list which doesn't contain Tensor. # example 1: attr shape is a list which doesn't contain Tensor.
result_1 = paddle.standard_normal(shape=[2, 3]) out1 = paddle.standard_normal(shape=[2, 3])
# [[-2.923464 , 0.11934398, -0.51249987], # random # [[-2.923464 , 0.11934398, -0.51249987], # random
# [ 0.39632758, 0.08177969, 0.2692008 ]] # random # [ 0.39632758, 0.08177969, 0.2692008 ]] # random
# example 2: attr shape is a list which contains Tensor. # example 2: attr shape is a list which contains Tensor.
dim_1 = paddle.fill_constant([1], "int64", 2) dim1 = paddle.full([1], 2, "int64")
dim_2 = paddle.fill_constant([1], "int32", 3) dim2 = paddle.full([1], 3, "int32")
result_2 = paddle.standard_normal(shape=[dim_1, dim_2, 2]) out2 = paddle.standard_normal(shape=[dim1, dim2, 2])
# [[[-2.8852394 , -0.25898588], # random # [[[-2.8852394 , -0.25898588], # random
# [-0.47420555, 0.17683524], # random # [-0.47420555, 0.17683524], # random
# [-0.7989969 , 0.00754541]], # random # [-0.7989969 , 0.00754541]], # random
...@@ -218,21 +209,13 @@ def standard_normal(shape, dtype=None, name=None): ...@@ -218,21 +209,13 @@ def standard_normal(shape, dtype=None, name=None):
# [ 0.8086993 , 0.6868893 ]]] # random # [ 0.8086993 , 0.6868893 ]]] # random
# example 3: attr shape is a Tensor, the data type must be int64 or int32. # example 3: attr shape is a Tensor, the data type must be int64 or int32.
var_shape = paddle.to_tensor(np.array([2, 3])) shape_tensor = paddle.to_tensor(np.array([2, 3]))
result_3 = paddle.standard_normal(var_shape) out3 = paddle.standard_normal(shape_tensor)
# [[-2.878077 , 0.17099959, 0.05111201] # random # [[-2.878077 , 0.17099959, 0.05111201] # random
# [-0.3761474, -1.044801 , 1.1870178 ]] # random # [-0.3761474, -1.044801 , 1.1870178 ]] # random
""" """
if dtype is None: return gaussian(shape=shape, mean=0.0, std=1.0, dtype=dtype, name=name)
dtype = paddle.framework.get_default_dtype()
if dtype not in ['float32', 'float64']:
raise TypeError(
"standard_normal only supports [float32, float64], but the default dtype is %s"
% dtype)
return gaussian_random(
shape=shape, mean=0.0, std=1.0, dtype=dtype, name=name)
randn = standard_normal randn = standard_normal
...@@ -306,16 +289,7 @@ def normal(mean=0.0, std=1.0, shape=None, name=None): ...@@ -306,16 +289,7 @@ def normal(mean=0.0, std=1.0, shape=None, name=None):
"If std is Tensor, it's data type only support float32, float64." "If std is Tensor, it's data type only support float32, float64."
) )
if shape is not None: if shape is not None:
if isinstance(shape, (list, tuple)): check_shape(shape, 'normal')
for item in shape:
check_type(item, 'shape', (int), 'normal',
'Elements of shape should be int.')
elif isinstance(shape, Variable):
check_dtype(shape.dtype, 'shape', ['int32', 'int64'], 'normal')
else:
assert TypeError(
'If mean and std are all not Tensor, shape should be list, tuple, Tensor.'
)
if isinstance(mean, Variable): if isinstance(mean, Variable):
if isinstance(std, Variable): if isinstance(std, Variable):
...@@ -330,7 +304,7 @@ def normal(mean=0.0, std=1.0, shape=None, name=None): ...@@ -330,7 +304,7 @@ def normal(mean=0.0, std=1.0, shape=None, name=None):
mean = float(mean) mean = float(mean)
out = standard_normal(paddle.shape(std), std.dtype, name) out = standard_normal(paddle.shape(std), std.dtype, name)
else: else:
return gaussian_random(shape=shape, mean=mean, std=std, name=name) return gaussian(shape=shape, mean=mean, std=std, name=name)
out = out * std + mean out = out * std + mean
if not in_dygraph_mode(): if not in_dygraph_mode():
...@@ -426,7 +400,7 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None): ...@@ -426,7 +400,7 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
dtype = convert_np_dtype_to_dtype_(dtype) dtype = convert_np_dtype_to_dtype_(dtype)
if in_dygraph_mode(): if in_dygraph_mode():
shape = utils._convert_shape_to_list(shape) shape = utils.convert_shape_to_list(shape)
return core.ops.uniform_random('shape', shape, 'min', return core.ops.uniform_random('shape', shape, 'min',
float(min), 'max', float(min), 'max',
float(max), 'seed', seed, 'dtype', dtype) float(max), 'seed', seed, 'dtype', dtype)
...@@ -436,7 +410,7 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None): ...@@ -436,7 +410,7 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
inputs = dict() inputs = dict()
attrs = {'seed': seed, 'min': min, 'max': max, 'dtype': dtype} attrs = {'seed': seed, 'min': min, 'max': max, 'dtype': dtype}
utils._get_shape_tensor_inputs( utils.get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=shape, op_type='uniform_random/rand') inputs=inputs, attrs=attrs, shape=shape, op_type='uniform_random/rand')
helper = LayerHelper("uniform_random", **locals()) helper = LayerHelper("uniform_random", **locals())
...@@ -449,29 +423,26 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None): ...@@ -449,29 +423,26 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
def randint(low=0, high=None, shape=[1], dtype=None, name=None): def randint(low=0, high=None, shape=[1], dtype=None, name=None):
""" """
:alias_main: paddle.randint
:alias: paddle.tensor.randint, paddle.tensor.random.randint
This OP returns a Tensor filled with random integers from a discrete uniform This OP returns a Tensor filled with random integers from a discrete uniform
distribution in the range [``low``, ``high``), with ``shape`` and ``dtype``. distribution in the range [``low``, ``high``), with ``shape`` and ``dtype``.
If ``high`` is None (the default), the range is [0, ``low``). If ``high`` is None (the default), the range is [0, ``low``).
Args: Args:
low(int): The lower bound on the range of random values to generate. low (int): The lower bound on the range of random values to generate.
The ``low`` is included in the range. If ``high`` is None, the The ``low`` is included in the range. If ``high`` is None, the
range is [0, ``low``). Default is 0. range is [0, ``low``). Default is 0.
high(int, optional): The upper bound on the range of random values to high (int, optional): The upper bound on the range of random values to
generate, the ``high`` is excluded in the range. Default is None generate, the ``high`` is excluded in the range. Default is None
(see above for behavior if high = None). Default is None. (see above for behavior if high = None). Default is None.
shape(list|tuple|Tensor): The shape of the output Tensor. If ``shape`` shape (list|tuple|Tensor): The shape of the output Tensor. If ``shape``
is a list or tuple, the elements of it should be integers or Tensors is a list or tuple, the elements of it should be integers or Tensors
(with the shape [1], and the data type int32 or int64). If ``shape`` (with the shape [1], and the data type int32 or int64). If ``shape``
is a Tensor, it should be a 1-D Tensor(with the data type int32 or is a Tensor, it should be a 1-D Tensor(with the data type int32 or
int64). Default is [1]. int64). Default is [1].
dtype(str|np.dtype, optional): The data type of the dtype (str|np.dtype, optional): The data type of the
output tensor. Supported data types: int32, int64. If ``dytpe`` output tensor. Supported data types: int32, int64. If ``dytpe``
is None, the data type is int64. Default is None. is None, the data type is int64. Default is None.
name(str, optional): The default value is None. Normally there is no name (str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`. refer to :ref:`api_guide_Name`.
...@@ -479,12 +450,6 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None): ...@@ -479,12 +450,6 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
Tensor: A Tensor filled with random integers from a discrete uniform Tensor: A Tensor filled with random integers from a discrete uniform
distribution in the range [``low``, ``high``), with ``shape`` and ``dtype``. distribution in the range [``low``, ``high``), with ``shape`` and ``dtype``.
Raises:
TypeError: If ``shape`` is not list, tuple, Tensor.
TypeError: If ``dtype`` is not int32, int64.
ValueError: If ``high`` is not greater then ``low``; If ``high`` is
None, and ``low`` is not greater than 0.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -495,32 +460,32 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None): ...@@ -495,32 +460,32 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
# example 1: # example 1:
# attr shape is a list which doesn't contain Tensor. # attr shape is a list which doesn't contain Tensor.
result_1 = paddle.randint(low=-5, high=5, shape=[3]) out1 = paddle.randint(low=-5, high=5, shape=[3])
# [0, -3, 2] # random # [0, -3, 2] # random
# example 2: # example 2:
# attr shape is a list which contains Tensor. # attr shape is a list which contains Tensor.
dim_1 = paddle.fill_constant([1], "int64", 2) dim1 = paddle.full([1], 2, "int64")
dim_2 = paddle.fill_constant([1], "int32", 3) dim2 = paddle.full([1], 3, "int32")
result_2 = paddle.randint(low=-5, high=5, shape=[dim_1, dim_2], dtype="int32") out2 = paddle.randint(low=-5, high=5, shape=[dim1, dim2], dtype="int32")
# [[0, -1, -3], # random # [[0, -1, -3], # random
# [4, -2, 0]] # random # [4, -2, 0]] # random
# example 3: # example 3:
# attr shape is a Tensor # attr shape is a Tensor
var_shape = paddle.to_variable(np.array([3])) shape_tensor = paddle.to_tensor(np.array([3]))
result_3 = paddle.randint(low=-5, high=5, shape=var_shape) out3 = paddle.randint(low=-5, high=5, shape=shape_tensor)
# [-2, 2, 3] # random # [-2, 2, 3] # random
# example 4: # example 4:
# data type is int32 # data type is int32
result_4 = paddle.randint(low=-5, high=5, shape=[3], dtype='int32') out4 = paddle.randint(low=-5, high=5, shape=[3], dtype='int32')
# [-5, 4, -4] # random # [-5, 4, -4] # random
# example 5: # example 5:
# Input only one parameter # Input only one parameter
# low=0, high=10, shape=[1], dtype='int64' # low=0, high=10, shape=[1], dtype='int64'
result_5 = paddle.randint(10) out5 = paddle.randint(10)
# [7] # random # [7] # random
""" """
...@@ -537,11 +502,11 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None): ...@@ -537,11 +502,11 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
dtype = convert_np_dtype_to_dtype_(dtype) dtype = convert_np_dtype_to_dtype_(dtype)
if in_dygraph_mode(): if in_dygraph_mode():
shape = utils._convert_shape_to_list(shape) shape = utils.convert_shape_to_list(shape)
return core.ops.randint('shape', shape, 'low', low, 'high', high, return core.ops.randint('shape', shape, 'low', low, 'high', high,
'seed', 0, 'dtype', dtype) 'seed', 0, 'dtype', dtype)
check_type(shape, 'shape', (list, tuple, Variable), 'randint') check_shape(shape, 'randint')
check_dtype(dtype, 'dtype', ['int32', 'int64'], 'randint') check_dtype(dtype, 'dtype', ['int32', 'int64'], 'randint')
if low >= high: if low >= high:
raise ValueError( raise ValueError(
...@@ -550,7 +515,7 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None): ...@@ -550,7 +515,7 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
inputs = dict() inputs = dict()
attrs = {'low': low, 'high': high, 'seed': 0, 'dtype': dtype} attrs = {'low': low, 'high': high, 'seed': 0, 'dtype': dtype}
utils._get_shape_tensor_inputs( utils.get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=shape, op_type='randint') inputs=inputs, attrs=attrs, shape=shape, op_type='randint')
helper = LayerHelper("randint", **locals()) helper = LayerHelper("randint", **locals())
...@@ -560,21 +525,17 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None): ...@@ -560,21 +525,17 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
return out return out
@templatedoc()
def randperm(n, dtype="int64", name=None): def randperm(n, dtype="int64", name=None):
""" """
:alias_main: paddle.randperm
:alias: paddle.tensor.randperm, paddle.tensor.random.randperm
This OP returns a 1-D Tensor filled with random permutation values from 0 This OP returns a 1-D Tensor filled with random permutation values from 0
to n-1, with ``dtype``. to n-1, with ``dtype``.
Args: Args:
n(int): The upper bound (exclusive), and it should be greater than 0. n (int): The upper bound (exclusive), and it should be greater than 0.
dtype(str|np.dtype, optional): The data type of dtype (str|np.dtype, optional): The data type of
the output Tensor. Supported data types: int32, int64, float32, the output Tensor. Supported data types: int32, int64, float32,
float64. Default is int64. float64. Default is int64.
name(str, optional): The default value is None. Normally there is no name (str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`. refer to :ref:`api_guide_Name`.
...@@ -582,10 +543,6 @@ def randperm(n, dtype="int64", name=None): ...@@ -582,10 +543,6 @@ def randperm(n, dtype="int64", name=None):
Tensor: A 1-D Tensor filled with random permutation values from 0 Tensor: A 1-D Tensor filled with random permutation values from 0
to n-1, with ``dtype``. to n-1, with ``dtype``.
Raises:
ValueError: If ``n`` is not greater than 0.
TypeError: If ``dtype`` is not int32, int64, float32, float64.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -593,10 +550,10 @@ def randperm(n, dtype="int64", name=None): ...@@ -593,10 +550,10 @@ def randperm(n, dtype="int64", name=None):
paddle.disable_static() paddle.disable_static()
result_1 = paddle.randperm(5) out1 = paddle.randperm(5)
# [4, 1, 2, 3, 0] # random # [4, 1, 2, 3, 0] # random
result_2 = paddle.randperm(7, 'int32') out2 = paddle.randperm(7, 'int32')
# [1, 6, 2, 0, 4, 3, 5] # random # [1, 6, 2, 0, 4, 3, 5] # random
""" """
...@@ -622,32 +579,20 @@ def randperm(n, dtype="int64", name=None): ...@@ -622,32 +579,20 @@ def randperm(n, dtype="int64", name=None):
def rand(shape, dtype=None, name=None): def rand(shape, dtype=None, name=None):
""" """
:alias_main: paddle.rand
:alias: paddle.tensor.rand, paddle.tensor.random.rand
This OP returns a Tensor filled with random values sampled from a uniform This OP returns a Tensor filled with random values sampled from a uniform
distribution in the range [0, 1), with ``shape`` and ``dtype``. distribution in the range [0, 1), with ``shape`` and ``dtype``.
Examples:
::
Input:
shape = [1, 2]
Output:
result=[[0.8505902, 0.8397286]]
Args: Args:
shape(list|tuple|Tensor): The shape of the output Tensor. If ``shape`` shape (list|tuple|Tensor): The shape of the output Tensor. If ``shape``
is a list or tuple, the elements of it should be integers or Tensors is a list or tuple, the elements of it should be integers or Tensors
(with the shape [1], and the data type int32 or int64). If ``shape`` (with the shape [1], and the data type int32 or int64). If ``shape``
is a Tensor, it should be a 1-D Tensor(with the data type int32 or is a Tensor, it should be a 1-D Tensor(with the data type int32 or
int64). int64).
dtype(str|np.dtype, optional): The data type of the output Tensor. dtype (str|np.dtype, optional): The data type of the output Tensor.
Supported data types: float32, float64. Supported data types: float32, float64.
Default is None, use global default dtype (see ``get_default_dtype`` Default is None, use global default dtype (see ``get_default_dtype``
for details). for details).
name(str, optional): The default value is None. Normally there is no name (str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`. refer to :ref:`api_guide_Name`.
...@@ -655,10 +600,6 @@ def rand(shape, dtype=None, name=None): ...@@ -655,10 +600,6 @@ def rand(shape, dtype=None, name=None):
Tensor: A Tensor filled with random values sampled from a uniform Tensor: A Tensor filled with random values sampled from a uniform
distribution in the range [0, 1), with ``shape`` and ``dtype``. distribution in the range [0, 1), with ``shape`` and ``dtype``.
Raises:
TypeError: If ``shape`` is not list, tuple, Tensor.
ValueError: If ``dtype`` is not float32, float64.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -667,14 +608,14 @@ def rand(shape, dtype=None, name=None): ...@@ -667,14 +608,14 @@ def rand(shape, dtype=None, name=None):
paddle.disable_static() paddle.disable_static()
# example 1: attr shape is a list which doesn't contain Tensor. # example 1: attr shape is a list which doesn't contain Tensor.
result_1 = paddle.rand(shape=[2, 3]) out1 = paddle.rand(shape=[2, 3])
# [[0.451152 , 0.55825245, 0.403311 ], # random # [[0.451152 , 0.55825245, 0.403311 ], # random
# [0.22550228, 0.22106001, 0.7877319 ]] # random # [0.22550228, 0.22106001, 0.7877319 ]] # random
# example 2: attr shape is a list which contains Tensor. # example 2: attr shape is a list which contains Tensor.
dim_1 = paddle.fill_constant([1], "int64", 2) dim1 = paddle.full([1], 2, "int64")
dim_2 = paddle.fill_constant([1], "int32", 3) dim2 = paddle.full([1], 3, "int32")
result_2 = paddle.rand(shape=[dim_1, dim_2, 2]) out2 = paddle.rand(shape=[dim1, dim2, 2])
# [[[0.8879919 , 0.25788337], # random # [[[0.8879919 , 0.25788337], # random
# [0.28826773, 0.9712097 ], # random # [0.28826773, 0.9712097 ], # random
# [0.26438272, 0.01796806]], # random # [0.26438272, 0.01796806]], # random
...@@ -683,19 +624,10 @@ def rand(shape, dtype=None, name=None): ...@@ -683,19 +624,10 @@ def rand(shape, dtype=None, name=None):
# [0.870881 , 0.2984597 ]]] # random # [0.870881 , 0.2984597 ]]] # random
# example 3: attr shape is a Tensor, the data type must be int64 or int32. # example 3: attr shape is a Tensor, the data type must be int64 or int32.
var_shape = paddle.to_variable(np.array([2, 3])) shape_tensor = paddle.to_tensor(np.array([2, 3]))
result_3 = paddle.rand(var_shape) out2 = paddle.rand(shape_tensor)
# [[0.22920267, 0.841956 , 0.05981819], # random # [[0.22920267, 0.841956 , 0.05981819], # random
# [0.4836288 , 0.24573246, 0.7516129 ]] # random # [0.4836288 , 0.24573246, 0.7516129 ]] # random
""" """
if dtype is None: return uniform(shape, dtype, min=0.0, max=1.0, name=name)
dtype = paddle.framework.get_default_dtype()
if dtype not in ['float32', 'float64']:
raise TypeError(
"rand only supports [float32, float64], but the default dtype is %s"
% dtype)
out = uniform(shape, dtype, min=0.0, max=1.0, name=name)
out.stop_gradient = True
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册