未验证 提交 25029254 编写于 作者: Z zhupengyang 提交者: GitHub

randn API: remove out, devive, stop_gradient; add name (#25409)

上级 41d22472
...@@ -10416,20 +10416,28 @@ def uniform_random_batch_size_like(input, ...@@ -10416,20 +10416,28 @@ def uniform_random_batch_size_like(input,
@templatedoc() @templatedoc()
def gaussian_random(shape, mean=0.0, std=1.0, seed=0, dtype='float32'): def gaussian_random(shape,
mean=0.0,
std=1.0,
seed=0,
dtype='float32',
name=None):
""" """
Generate a random tensor whose data is drawn from a Gaussian distribution. Generate a random tensor whose data is drawn from a Gaussian distribution.
Args: Args:
shape (tuple[int] | list[int] | Variable | list[Variable]): Shape of the generated random tensor. shape(list|tuple|Variable): Shape of the Tensor to be created. The data
type is ``int32`` or ``int64`` . If ``shape`` is a list or tuple,
mean (float): Mean of the random tensor, defaults to 0.0. the elements of it should be integers or Tensors with shape [1]. If
``shape`` is a Variable, it should be an 1-D Tensor .
std (float): Standard deviation of the random tensor, defaults to 1.0. mean(float): Mean of the random tensor, defaults to 0.0.
std(float): Standard deviation of the random tensor, defaults to 1.0.
seed (int): ${seed_comment} seed(int): ${seed_comment}
dtype(np.dtype|core.VarDesc.VarType|str, optional): Data type of the output
dtype(np.dtype | core.VarDesc.VarType | str): Output data type, float32 or float64. tensor, which can be float32, float64. Default is float32.
name(str, optional): Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Default is None.
Returns: Returns:
Variable: Random tensor whose data is drawn from a Gaussian distribution, dtype: flaot32 or float64 as specified. Variable: Random tensor whose data is drawn from a Gaussian distribution, dtype: flaot32 or float64 as specified.
...@@ -10492,11 +10500,16 @@ def gaussian_random(shape, mean=0.0, std=1.0, seed=0, dtype='float32'): ...@@ -10492,11 +10500,16 @@ def gaussian_random(shape, mean=0.0, std=1.0, seed=0, dtype='float32'):
# array([[2.3060477 , 2.676496 , 3.9911983 , 0.9990833 ], # array([[2.3060477 , 2.676496 , 3.9911983 , 0.9990833 ],
# [2.8675377 , 2.2279181 , 0.79029655, 2.8447366 ]], dtype=float32) # [2.8675377 , 2.2279181 , 0.79029655, 2.8447366 ]], dtype=float32)
""" """
check_type(shape, 'shape', (list, tuple, Variable), 'gaussian_random')
if not isinstance(dtype, core.VarDesc.VarType): if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype) dtype = convert_np_dtype_to_dtype_(dtype)
check_dtype(dtype, 'dtype', ['float32', 'float64'], 'gaussian_random')
if in_dygraph_mode():
shape = utils._convert_shape_to_list(shape)
return core.ops.gaussian_random('shape', shape, 'mean', mean, 'std',
std, 'seed', seed, 'dtype', dtype)
check_type(shape, 'shape', (list, tuple, Variable), 'gaussian_random/randn')
check_dtype(dtype, 'dtype', ['float32', 'float64'], 'gaussian_random/randn')
inputs = {} inputs = {}
attrs = { attrs = {
...@@ -10507,7 +10520,10 @@ def gaussian_random(shape, mean=0.0, std=1.0, seed=0, dtype='float32'): ...@@ -10507,7 +10520,10 @@ def gaussian_random(shape, mean=0.0, std=1.0, seed=0, dtype='float32'):
'use_mkldnn': False 'use_mkldnn': False
} }
utils._get_shape_tensor_inputs( utils._get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=shape, op_type='gaussian_random') inputs=inputs,
attrs=attrs,
shape=shape,
op_type='gaussian_random/randn')
helper = LayerHelper('gaussian_random', **locals()) helper = LayerHelper('gaussian_random', **locals())
out = helper.create_variable_for_type_inference(dtype) out = helper.create_variable_for_type_inference(dtype)
...@@ -15011,13 +15027,13 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0, ...@@ -15011,13 +15027,13 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0,
float(min), 'max', float(min), 'max',
float(max), 'seed', seed, 'dtype', dtype) float(max), 'seed', seed, 'dtype', dtype)
check_type(shape, 'shape', (list, tuple, Variable), 'uniform_random') check_type(shape, 'shape', (list, tuple, Variable), 'uniform_random/rand')
check_dtype(dtype, 'dtype', ('float32', 'float64'), 'uniform_random') check_dtype(dtype, 'dtype', ('float32', 'float64'), 'uniform_random/rand')
inputs = dict() inputs = dict()
attrs = {'seed': seed, 'min': min, 'max': max, 'dtype': dtype} attrs = {'seed': seed, 'min': min, 'max': max, 'dtype': dtype}
utils._get_shape_tensor_inputs( utils._get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=shape, op_type='uniform_random') inputs=inputs, attrs=attrs, shape=shape, op_type='uniform_random/rand')
helper = LayerHelper("uniform_random", **locals()) helper = LayerHelper("uniform_random", **locals())
out = helper.create_variable_for_type_inference(dtype) out = helper.create_variable_for_type_inference(dtype)
......
...@@ -17,92 +17,71 @@ from __future__ import print_function ...@@ -17,92 +17,71 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import paddle import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid import Program, program_guard from paddle import Program, program_guard
class TestRandnOp(unittest.TestCase): class TestRandnOp(unittest.TestCase):
def test_api(self): def test_api(self):
x1 = paddle.randn(shape=[1000, 784], dtype='float32') shape = [1000, 784]
x2 = paddle.randn(shape=[1000, 784], dtype='float64') train_program = Program()
x3 = fluid.layers.fill_constant( startup_program = Program()
shape=[1000, 784], dtype='float32', value=0) with program_guard(train_program, startup_program):
paddle.randn(shape=[1000, 784], out=x3, dtype='float32') x1 = paddle.randn(shape, 'float32')
x4 = paddle.randn(shape=[1000, 784], dtype='float32', device='cpu') x2 = paddle.randn(shape, 'float64')
x5 = paddle.randn(shape=[1000, 784], dtype='float32', device='gpu')
x6 = paddle.randn( dim_1 = paddle.fill_constant([1], "int64", 20)
shape=[1000, 784], dim_2 = paddle.fill_constant([1], "int32", 50)
dtype='float32', x3 = paddle.randn([dim_1, dim_2, 784])
device='gpu',
stop_gradient=False) var_shape = paddle.nn.data('X', [2], 'int32')
x4 = paddle.randn(var_shape)
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace() place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
exe = fluid.Executor(place) ) else paddle.CPUPlace()
res = exe.run(fluid.default_main_program(), exe = paddle.Executor(place)
feed={}, res = exe.run(train_program,
fetch_list=[x1, x2, x3, x4, x5, x6]) feed={'X': np.array(
shape, dtype='int32')},
self.assertAlmostEqual(np.mean(res[0]), .0, delta=0.1) fetch_list=[x1, x2, x3, x4])
self.assertAlmostEqual(np.std(res[0]), 1., delta=0.1)
self.assertAlmostEqual(np.mean(res[1]), .0, delta=0.1) for out in res:
self.assertAlmostEqual(np.std(res[1]), 1., delta=0.1) self.assertAlmostEqual(np.mean(out), .0, delta=0.1)
self.assertAlmostEqual(np.mean(res[2]), .0, delta=0.1) self.assertAlmostEqual(np.std(out), 1., delta=0.1)
self.assertAlmostEqual(np.std(res[2]), 1., delta=0.1)
self.assertAlmostEqual(np.mean(res[3]), .0, delta=0.1)
self.assertAlmostEqual(np.std(res[3]), 1., delta=0.1) class TestRandnOpForDygraph(unittest.TestCase):
self.assertAlmostEqual(np.mean(res[4]), .0, delta=0.1) def test_api(self):
self.assertAlmostEqual(np.std(res[4]), 1., delta=0.1) shape = [1000, 784]
self.assertAlmostEqual(np.mean(res[5]), .0, delta=0.1) place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
self.assertAlmostEqual(np.std(res[5]), 1., delta=0.1) ) else paddle.CPUPlace()
with paddle.imperative.guard(place):
x1 = paddle.randn(shape, 'float32')
x2 = paddle.randn(shape, 'float64')
dim_1 = paddle.fill_constant([1], "int64", 20)
dim_2 = paddle.fill_constant([1], "int32", 50)
x3 = paddle.randn(shape=[dim_1, dim_2, 784])
var_shape = paddle.imperative.to_variable(np.array(shape))
x4 = paddle.randn(var_shape)
for out in [x1, x2, x3, x4]:
self.assertAlmostEqual(np.mean(out.numpy()), .0, delta=0.1)
self.assertAlmostEqual(np.std(out.numpy()), 1., delta=0.1)
class TestRandnOpError(unittest.TestCase): class TestRandnOpError(unittest.TestCase):
def test_error(self): def test_error(self):
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
# The argument shape's size of randn_op should not be 0. # The argument shape's size of randn_op should not be 0.
def test_shape_size(): self.assertRaises(AssertionError, paddle.randn, [])
out = paddle.randn(shape=[])
self.assertRaises(AssertionError, test_shape_size)
# The argument shape's type of randn_op should be list or tuple. # The argument shape's type of randn_op should be list or tuple.
def test_shape_type(): self.assertRaises(TypeError, paddle.randn, 1)
out = paddle.randn(shape=1)
self.assertRaises(TypeError, test_shape_type)
# The argument dtype of randn_op should be float32 or float64.
def test_dtype_float16():
out = paddle.randn(shape=[1, 2], dtype='float16')
self.assertRaises(TypeError, test_dtype_float16)
# The argument dtype of randn_op should be float32 or float64. # The argument dtype of randn_op should be float32 or float64.
def test_dtype_int32(): self.assertRaises(TypeError, paddle.randn, [1, 2], 'int32')
out = paddle.randn(shape=[1, 2], dtype='int32')
self.assertRaises(TypeError, test_dtype_int32)
# The argument dtype of randn_op should be float32 or float64.
def test_dtype_int64():
out = paddle.randn(shape=[1, 2], dtype='int64')
self.assertRaises(TypeError, test_dtype_int64)
# The argument dtype of randn_op should be float32 or float64.
def test_dtype_uint8():
out = paddle.randn(shape=[1, 2], dtype='uint8')
self.assertRaises(TypeError, test_dtype_uint8)
# The argument dtype of randn_op should be float32 or float64.
def test_dtype_bool():
out = paddle.randn(shape=[1, 2], dtype='bool')
self.assertRaises(TypeError, test_dtype_bool)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -21,7 +21,7 @@ from ..fluid.framework import device_guard, in_dygraph_mode, _varbase_creator, V ...@@ -21,7 +21,7 @@ from ..fluid.framework import device_guard, in_dygraph_mode, _varbase_creator, V
from ..fluid.layers.layer_function_generator import templatedoc from ..fluid.layers.layer_function_generator import templatedoc
from ..fluid.layer_helper import LayerHelper from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
from ..fluid.layers import uniform_random, utils from ..fluid.layers import utils, uniform_random, gaussian_random
from ..fluid.layers.tensor import fill_constant from ..fluid.layers.tensor import fill_constant
from ..fluid.io import shuffle #DEFINE_ALIAS from ..fluid.io import shuffle #DEFINE_ALIAS
...@@ -206,36 +206,23 @@ def randint(low, ...@@ -206,36 +206,23 @@ def randint(low,
return out return out
def randn(shape, def randn(shape, dtype=None, name=None):
out=None,
dtype=None,
device=None,
stop_gradient=True,
name=None):
""" """
:alias_main: paddle.randn :alias_main: paddle.randn
:alias: paddle.randn,paddle.tensor.randn,paddle.tensor.random.randn :alias: paddle.randn,paddle.tensor.randn,paddle.tensor.random.randn
This function returns a tensor filled with random numbers from a normal This function returns a tensor filled with random numbers from a normal
distribution with mean 0 and variance 1 (also called the standard normal distribution with mean 0 and standard deviation 1 (also called the standard normal
distribution). distribution).
Args: Args:
shape(list|tuple): Shape of the generated random tensor. shape(list|tuple|Variable): Shape of the Tensor to be created. The data
out(Variable, optional): Optional output which can be any created Variable type is ``int32`` or ``int64`` . If ``shape`` is a list or tuple,
that meets the requirements to store the result of operation. If the the elements of it should be integers or Tensors with shape [1]. If
out is `None`, a new Variable will be returned to store the result. ``shape`` is a Variable, it should be an 1-D Tensor .
Default is None. dtype(np.dtype|core.VarDesc.VarType|str, optional): Data type of the output
dtype(np.dtype|core.VarDesc.VarType|str, optional): Data type of the output tensor, which can be float32, float64. If dtype is `None` , the data
tensor, which can be float32, float64. if dtype is `None` , the data type of output tensor is `float32` . Default is None.
type of output tensor is `float32` .
Default is None.
device(str, optional): Specific the output variable to be saved in cpu
or gpu memory. Supported None, 'cpu', 'gpu'. If it is None, the output
variable will be automatically assigned devices.
Default: None.
stop_gradient(bool, optional): Indicating if we stop gradient from current(out)
Variable. Default is True.
name(str, optional): Normally there is no need for user to set this property. name(str, optional): Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` . For more information, please refer to :ref:`api_guide_Name` .
Default is None. Default is None.
...@@ -244,75 +231,50 @@ def randn(shape, ...@@ -244,75 +231,50 @@ def randn(shape,
Random tensor whose data is drawn from a standard normal distribution, Random tensor whose data is drawn from a standard normal distribution,
dtype: flaot32 or float64 as specified. dtype: flaot32 or float64 as specified.
Return type: Return type: Variable
Variable
Raises: Raises:
TypeError: If the type of `shape` is not list or tuple. TypeError: If the type of `shape` is not Variable, list or tuple.
TypeError: If the data type of `dtype` is not float32 or float64. TypeError: If the data type of `dtype` is not float32 or float64.
ValueError: If the length of `shape` is not bigger than 0. ValueError: If the length of `shape` is not bigger than 0.
Examples: Examples:
.. code-block:: python .. code-block:: python
# declarative mode import paddle
import paddle import numpy as np
import paddle.fluid as fluid
data = paddle.randn([2, 4]) paddle.enable_imperative()
place = fluid.CPUPlace()
exe = fluid.Executor(place)
res, = exe.run(fluid.default_main_program(), feed={}, fetch_list=[data])
print(res)
# [[-1.4187592 0.7368311 -0.53748125 -0.0146909 ]
# [-0.66294265 -1.3090698 0.1898754 -0.14065823]]
.. code-block:: python # example 1: attr shape is a list which doesn't contain tensor Variable.
result_1 = paddle.randn(shape=[2, 3])
# [[-2.923464 0.11934398 -0.51249987]
# [ 0.39632758 0.08177969 0.2692008 ]]
# imperative mode # example 2: attr shape is a list which contains tensor Variable.
import paddle dim_1 = paddle.fill_constant([1], "int64", 2)
import paddle.fluid as fluid dim_2 = paddle.fill_constant([1], "int32", 3)
import paddle.fluid.dygraph as dg result_2 = paddle.randn(shape=[dim_1, dim_2, 2])
# [[[-2.8852394 -0.25898588]
place = fluid.CPUPlace() # [-0.47420555 0.17683524]
with dg.guard(place) as g: # [-0.7989969 0.00754541]]
x = paddle.randn([2, 4]) # [[ 0.85201347 0.32320443]
x_np = x.numpy() # [ 1.1399018 0.48336947]
print(x_np) # [ 0.8086993 0.6868893 ]]]
# [[ 1.5149173 -0.26234224 -0.592486 1.4523455 ]
# [ 0.04581212 -0.85345626 1.1687907 -0.02512913]] # example 3: attr shape is a Variable, the data type must be int64 or int32.
""" var_shape = paddle.imperative.to_variable(np.array([2, 3]))
helper = LayerHelper("randn", **locals()) result_3 = paddle.randn(var_shape)
check_type(shape, 'shape', (list, tuple), 'randn') # [[-2.878077 0.17099959 0.05111201]
assert len(shape) > 0, ("The size of argument(shape) can't be zero.") # [-0.3761474 -1.044801 1.1870178 ]]
"""
if dtype is None: if dtype is None:
dtype = 'float32' dtype = 'float32'
check_dtype(dtype, 'create data type', ['float32', 'float64'], 'randn') out = gaussian_random(
shape=shape, mean=0.0, std=1.0, seed=0, dtype=dtype, name=name)
if out is None: out.stop_gradient = True
out = helper.create_variable_for_type_inference(dtype=dtype)
else:
check_variable_and_dtype(out, 'out', [dtype], 'randn')
out.stop_gradient = stop_gradient
dtype = convert_np_dtype_to_dtype_(dtype)
seed = np.random.randint(0, 100)
with device_guard(device):
helper.append_op(
type='gaussian_random',
outputs={'Out': out},
attrs={
'shape': shape,
'mean': 0.0,
'std': 1.0,
'seed': seed,
'dtype': dtype,
'use_mkldnn': False
})
return out return out
...@@ -369,6 +331,7 @@ def randperm(n, dtype="int64", name=None): ...@@ -369,6 +331,7 @@ def randperm(n, dtype="int64", name=None):
attrs = {'n': n, 'dtype': dtype, 'seed': 0} attrs = {'n': n, 'dtype': dtype, 'seed': 0}
helper.append_op( helper.append_op(
type='randperm', inputs={}, outputs={'Out': out}, attrs=attrs) type='randperm', inputs={}, outputs={'Out': out}, attrs=attrs)
out.stop_gradient = True
return out return out
...@@ -439,4 +402,7 @@ def rand(shape, dtype=None, name=None): ...@@ -439,4 +402,7 @@ def rand(shape, dtype=None, name=None):
""" """
if dtype is None: if dtype is None:
dtype = 'float32' dtype = 'float32'
return uniform_random(shape, dtype, min=0.0, max=1.0, name=name)
out = uniform_random(shape, dtype, min=0.0, max=1.0, name=name)
out.stop_gradient = True
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册