未验证 提交 bdac6bc8 编写于 作者: P pangyoki 提交者: GitHub

Rename uniform_random API (#26347)

* Rename uniform and gaussian APIs

* rename uniform_random API as uniform

* Fixed unittest

* fixed unittest AttributeError

* Fixed unittest

* Add uniform function rather than alias

* remove templatedoc

* solve conflict and fix doc code

* fix doc code
上级 f6d20d56
......@@ -190,7 +190,7 @@ from .tensor.math import trace #DEFINE_ALIAS
from .tensor.math import kron #DEFINE_ALIAS
from .tensor.math import prod #DEFINE_ALIAS
# from .tensor.random import gaussin #DEFINE_ALIAS
# from .tensor.random import uniform #DEFINE_ALIAS
from .tensor.random import uniform #DEFINE_ALIAS
from .tensor.random import shuffle #DEFINE_ALIAS
from .tensor.random import randn #DEFINE_ALIAS
from .tensor.random import rand #DEFINE_ALIAS
......
......@@ -26,8 +26,8 @@ import paddle
class TestDirectory(unittest.TestCase):
def get_import_command(self, module):
paths = module.split('.')
if len(paths) <= 1:
return module
if len(paths) == 1:
return 'import {}'.format(module)
package = '.'.join(paths[:-1])
func = paths[-1]
cmd = 'from {} import {}'.format(package, func)
......
......@@ -14,9 +14,12 @@
from __future__ import print_function
import sys
import subprocess
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
......@@ -472,5 +475,61 @@ class TestUniformRandomBatchSizeLikeOpError(unittest.TestCase):
self.assertRaises(TypeError, test_dtype)
class TestUniformAlias(unittest.TestCase):
def test_alias(self):
paddle.uniform([2, 3], min=-5.0, max=5.0)
paddle.tensor.uniform([2, 3], min=-5.0, max=5.0)
paddle.tensor.random.uniform([2, 3], min=-5.0, max=5.0)
def test_uniform_random():
paddle.tensor.random.uniform_random([2, 3], min=-5.0, max=5.0)
self.assertRaises(AttributeError, test_uniform_random)
class TestUniformOpError(unittest.TestCase):
def test_errors(self):
main_prog = Program()
start_prog = Program()
with program_guard(main_prog, start_prog):
def test_Variable():
x1 = fluid.create_lod_tensor(
np.zeros((4, 784)), [[1, 1, 1, 1]], fluid.CPUPlace())
paddle.tensor.random.uniform(x1)
self.assertRaises(TypeError, test_Variable)
def test_Variable2():
x1 = np.zeros((4, 784))
paddle.tensor.random.uniform(x1)
self.assertRaises(TypeError, test_Variable2)
def test_dtype():
x2 = fluid.layers.data(
name='x2', shape=[4, 784], dtype='float32')
paddle.tensor.random.uniform(x2, 'int32')
self.assertRaises(TypeError, test_dtype)
def test_out_dtype():
out = paddle.tensor.random.uniform(
shape=[3, 4], dtype='float64')
self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP64)
test_out_dtype()
class TestUniformDygraphMode(unittest.TestCase):
def test_check_output(self):
with fluid.dygraph.guard():
x = paddle.tensor.random.uniform(
[10], dtype="float32", min=0.0, max=1.0)
x_np = x.numpy()
for i in range(10):
self.assertTrue((x_np[i] > 0 and x_np[i] < 1.0))
if __name__ == "__main__":
unittest.main()
......@@ -159,7 +159,7 @@ from .math import trace #DEFINE_ALIAS
from .math import kron #DEFINE_ALIAS
from .math import prod #DEFINE_ALIAS
# from .random import gaussin #DEFINE_ALIAS
# from .random import uniform #DEFINE_ALIAS
from .random import uniform #DEFINE_ALIAS
from .random import shuffle #DEFINE_ALIAS
from .random import randn #DEFINE_ALIAS
from .random import rand #DEFINE_ALIAS
......
......@@ -21,14 +21,14 @@ from ..fluid.framework import device_guard, in_dygraph_mode, _varbase_creator, V
from ..fluid.layers.layer_function_generator import templatedoc
from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
from ..fluid.layers import utils, uniform_random, gaussian_random
from ..fluid.layers import utils, gaussian_random
from ..fluid.layers.tensor import fill_constant
from ..fluid.io import shuffle #DEFINE_ALIAS
__all__ = [
# 'gaussin',
# 'uniform',
'uniform',
'shuffle',
'randn',
'rand',
......@@ -37,6 +37,111 @@ __all__ = [
]
def uniform(shape, dtype='float32', min=-1.0, max=1.0, seed=0, name=None):
"""
This OP returns a Tensor filled with random values sampled from a uniform
distribution in the range [``min``, ``max``), with ``shape`` and ``dtype``.
Examples:
::
Input:
shape = [1, 2]
Output:
result=[[0.8505902, 0.8397286]]
Args:
shape(list|tuple|Tensor): The shape of the output Tensor. If ``shape``
is a list or tuple, the elements of it should be integers or Tensors
(with the shape [1], and the data type int32 or int64). If ``shape``
is a Tensor, it should be a 1-D Tensor(with the data type int32 or
int64).
dtype(str|np.dtype|core.VarDesc.VarType, optional): The data type of
the output Tensor. Supported data types: float32, float64.
Default is float32.
min(float|int, optional): The lower bound on the range of random values
to generate, ``min`` is included in the range. Default is -1.0.
max(float|int, optional): The upper bound on the range of random values
to generate, ``max`` is excluded in the range. Default is 1.0.
seed(int, optional): Random seed used for generating samples. 0 means
use a seed generated by the system. Note that if seed is not 0,
this operator will always generate the same random numbers every
time. Default is 0.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
Tensor: A Tensor filled with random values sampled from a uniform
distribution in the range [``min``, ``max``), with ``shape`` and ``dtype``.
Raises:
TypeError: If ``shape`` is not list, tuple, Tensor.
TypeError: If ``dtype`` is not float32, float64.
Examples:
.. code-block:: python
import numpy as np
import paddle
paddle.disable_static()
# example 1:
# attr shape is a list which doesn't contain Tensor.
result_1 = paddle.tensor.random.uniform(shape=[3, 4])
# [[ 0.84524226, 0.6921872, 0.56528175, 0.71690357],
# [-0.34646994, -0.45116323, -0.09902662, -0.11397249],
# [ 0.433519, 0.39483607, -0.8660099, 0.83664286]]
# example 2:
# attr shape is a list which contains Tensor.
dim_1 = paddle.fill_constant([1], "int64", 2)
dim_2 = paddle.fill_constant([1], "int32", 3)
result_2 = paddle.tensor.random.uniform(shape=[dim_1, dim_2])
# [[-0.9951253, 0.30757582, 0.9899647 ],
# [ 0.5864527, 0.6607096, -0.8886161 ]]
# example 3:
# attr shape is a Tensor, the data type must be int64 or int32.
shape = np.array([2, 3])
shape_tensor = paddle.to_tensor(shape)
result_3 = paddle.tensor.random.uniform(shape_tensor)
# if shape_tensor's value is [2, 3]
# result_3 is:
# [[-0.8517412, -0.4006908, 0.2551912 ],
# [ 0.3364414, 0.36278176, -0.16085452]]
paddle.enable_static()
"""
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
if in_dygraph_mode():
shape = utils._convert_shape_to_list(shape)
return core.ops.uniform_random('shape', shape, 'min',
float(min), 'max',
float(max), 'seed', seed, 'dtype', dtype)
check_type(shape, 'shape', (list, tuple, Variable), 'uniform_random/rand')
check_dtype(dtype, 'dtype', ('float32', 'float64'), 'uniform_random/rand')
inputs = dict()
attrs = {'seed': seed, 'min': min, 'max': max, 'dtype': dtype}
utils._get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=shape, op_type='uniform_random/rand')
helper = LayerHelper("uniform_random", **locals())
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="uniform_random", inputs=inputs, attrs=attrs,
outputs={"Out": out})
return out
def randint(low=0, high=None, shape=[1], dtype=None, name=None):
"""
:alias_main: paddle.randint
......@@ -352,6 +457,6 @@ def rand(shape, dtype=None, name=None):
if dtype is None:
dtype = 'float32'
out = uniform_random(shape, dtype, min=0.0, max=1.0, name=name)
out = uniform(shape, dtype, min=0.0, max=1.0, name=name)
out.stop_gradient = True
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册