From bdac6bc880efdcbd5a08579366ea6dc92b2b6822 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Fri, 21 Aug 2020 00:13:01 -0500 Subject: [PATCH] Rename uniform_random API (#26347) * Rename uniform and gaussian APIs * rename uniform_random API as uniform * Fixed unittest * fixed unittest AttributeError * Fixed unittest * Add uniform function rather than alias * remove templatedoc * solve conflict and fix doc code * fix doc code --- python/paddle/__init__.py | 2 +- .../unittests/test_directory_migration.py | 4 +- .../tests/unittests/test_uniform_random_op.py | 59 ++++++++++ python/paddle/tensor/__init__.py | 2 +- python/paddle/tensor/random.py | 111 +++++++++++++++++- 5 files changed, 171 insertions(+), 7 deletions(-) diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 7e0c4cd283..1597a1a78e 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -190,7 +190,7 @@ from .tensor.math import trace #DEFINE_ALIAS from .tensor.math import kron #DEFINE_ALIAS from .tensor.math import prod #DEFINE_ALIAS # from .tensor.random import gaussin #DEFINE_ALIAS -# from .tensor.random import uniform #DEFINE_ALIAS +from .tensor.random import uniform #DEFINE_ALIAS from .tensor.random import shuffle #DEFINE_ALIAS from .tensor.random import randn #DEFINE_ALIAS from .tensor.random import rand #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_directory_migration.py b/python/paddle/fluid/tests/unittests/test_directory_migration.py index 4dc2c92ad9..bc85882805 100644 --- a/python/paddle/fluid/tests/unittests/test_directory_migration.py +++ b/python/paddle/fluid/tests/unittests/test_directory_migration.py @@ -26,8 +26,8 @@ import paddle class TestDirectory(unittest.TestCase): def get_import_command(self, module): paths = module.split('.') - if len(paths) <= 1: - return module + if len(paths) == 1: + return 'import {}'.format(module) package = '.'.join(paths[:-1]) func = paths[-1] cmd = 'from {} import {}'.format(package, func) diff --git a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py index 9a64dd1dee..158462a1e6 100644 --- a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py +++ b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py @@ -14,9 +14,12 @@ from __future__ import print_function +import sys +import subprocess import unittest import numpy as np from op_test import OpTest +import paddle import paddle.fluid.core as core from paddle.fluid.op import Operator import paddle.fluid as fluid @@ -472,5 +475,61 @@ class TestUniformRandomBatchSizeLikeOpError(unittest.TestCase): self.assertRaises(TypeError, test_dtype) +class TestUniformAlias(unittest.TestCase): + def test_alias(self): + paddle.uniform([2, 3], min=-5.0, max=5.0) + paddle.tensor.uniform([2, 3], min=-5.0, max=5.0) + paddle.tensor.random.uniform([2, 3], min=-5.0, max=5.0) + + def test_uniform_random(): + paddle.tensor.random.uniform_random([2, 3], min=-5.0, max=5.0) + + self.assertRaises(AttributeError, test_uniform_random) + + +class TestUniformOpError(unittest.TestCase): + def test_errors(self): + main_prog = Program() + start_prog = Program() + with program_guard(main_prog, start_prog): + + def test_Variable(): + x1 = fluid.create_lod_tensor( + np.zeros((4, 784)), [[1, 1, 1, 1]], fluid.CPUPlace()) + paddle.tensor.random.uniform(x1) + + self.assertRaises(TypeError, test_Variable) + + def test_Variable2(): + x1 = np.zeros((4, 784)) + paddle.tensor.random.uniform(x1) + + self.assertRaises(TypeError, test_Variable2) + + def test_dtype(): + x2 = fluid.layers.data( + name='x2', shape=[4, 784], dtype='float32') + paddle.tensor.random.uniform(x2, 'int32') + + self.assertRaises(TypeError, test_dtype) + + def test_out_dtype(): + out = paddle.tensor.random.uniform( + shape=[3, 4], dtype='float64') + self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP64) + + test_out_dtype() + + +class TestUniformDygraphMode(unittest.TestCase): + def test_check_output(self): + with fluid.dygraph.guard(): + x = paddle.tensor.random.uniform( + [10], dtype="float32", min=0.0, max=1.0) + x_np = x.numpy() + for i in range(10): + self.assertTrue((x_np[i] > 0 and x_np[i] < 1.0)) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index f6ac029f43..ad01caf0bb 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -159,7 +159,7 @@ from .math import trace #DEFINE_ALIAS from .math import kron #DEFINE_ALIAS from .math import prod #DEFINE_ALIAS # from .random import gaussin #DEFINE_ALIAS -# from .random import uniform #DEFINE_ALIAS +from .random import uniform #DEFINE_ALIAS from .random import shuffle #DEFINE_ALIAS from .random import randn #DEFINE_ALIAS from .random import rand #DEFINE_ALIAS diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py index d26003fd82..353a22f87f 100644 --- a/python/paddle/tensor/random.py +++ b/python/paddle/tensor/random.py @@ -21,14 +21,14 @@ from ..fluid.framework import device_guard, in_dygraph_mode, _varbase_creator, V from ..fluid.layers.layer_function_generator import templatedoc from ..fluid.layer_helper import LayerHelper from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype -from ..fluid.layers import utils, uniform_random, gaussian_random +from ..fluid.layers import utils, gaussian_random from ..fluid.layers.tensor import fill_constant from ..fluid.io import shuffle #DEFINE_ALIAS __all__ = [ # 'gaussin', - # 'uniform', + 'uniform', 'shuffle', 'randn', 'rand', @@ -37,6 +37,111 @@ __all__ = [ ] +def uniform(shape, dtype='float32', min=-1.0, max=1.0, seed=0, name=None): + """ + This OP returns a Tensor filled with random values sampled from a uniform + distribution in the range [``min``, ``max``), with ``shape`` and ``dtype``. + + Examples: + :: + + Input: + shape = [1, 2] + + Output: + result=[[0.8505902, 0.8397286]] + + Args: + shape(list|tuple|Tensor): The shape of the output Tensor. If ``shape`` + is a list or tuple, the elements of it should be integers or Tensors + (with the shape [1], and the data type int32 or int64). If ``shape`` + is a Tensor, it should be a 1-D Tensor(with the data type int32 or + int64). + dtype(str|np.dtype|core.VarDesc.VarType, optional): The data type of + the output Tensor. Supported data types: float32, float64. + Default is float32. + min(float|int, optional): The lower bound on the range of random values + to generate, ``min`` is included in the range. Default is -1.0. + max(float|int, optional): The upper bound on the range of random values + to generate, ``max`` is excluded in the range. Default is 1.0. + seed(int, optional): Random seed used for generating samples. 0 means + use a seed generated by the system. Note that if seed is not 0, + this operator will always generate the same random numbers every + time. Default is 0. + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. + + Returns: + Tensor: A Tensor filled with random values sampled from a uniform + distribution in the range [``min``, ``max``), with ``shape`` and ``dtype``. + + Raises: + TypeError: If ``shape`` is not list, tuple, Tensor. + TypeError: If ``dtype`` is not float32, float64. + + Examples: + .. code-block:: python + + import numpy as np + import paddle + + paddle.disable_static() + + # example 1: + # attr shape is a list which doesn't contain Tensor. + result_1 = paddle.tensor.random.uniform(shape=[3, 4]) + # [[ 0.84524226, 0.6921872, 0.56528175, 0.71690357], + # [-0.34646994, -0.45116323, -0.09902662, -0.11397249], + # [ 0.433519, 0.39483607, -0.8660099, 0.83664286]] + + # example 2: + # attr shape is a list which contains Tensor. + dim_1 = paddle.fill_constant([1], "int64", 2) + dim_2 = paddle.fill_constant([1], "int32", 3) + result_2 = paddle.tensor.random.uniform(shape=[dim_1, dim_2]) + # [[-0.9951253, 0.30757582, 0.9899647 ], + # [ 0.5864527, 0.6607096, -0.8886161 ]] + + # example 3: + # attr shape is a Tensor, the data type must be int64 or int32. + shape = np.array([2, 3]) + shape_tensor = paddle.to_tensor(shape) + + result_3 = paddle.tensor.random.uniform(shape_tensor) + # if shape_tensor's value is [2, 3] + # result_3 is: + # [[-0.8517412, -0.4006908, 0.2551912 ], + # [ 0.3364414, 0.36278176, -0.16085452]] + + paddle.enable_static() + + """ + if not isinstance(dtype, core.VarDesc.VarType): + dtype = convert_np_dtype_to_dtype_(dtype) + + if in_dygraph_mode(): + shape = utils._convert_shape_to_list(shape) + return core.ops.uniform_random('shape', shape, 'min', + float(min), 'max', + float(max), 'seed', seed, 'dtype', dtype) + + check_type(shape, 'shape', (list, tuple, Variable), 'uniform_random/rand') + check_dtype(dtype, 'dtype', ('float32', 'float64'), 'uniform_random/rand') + + inputs = dict() + attrs = {'seed': seed, 'min': min, 'max': max, 'dtype': dtype} + utils._get_shape_tensor_inputs( + inputs=inputs, attrs=attrs, shape=shape, op_type='uniform_random/rand') + + helper = LayerHelper("uniform_random", **locals()) + out = helper.create_variable_for_type_inference(dtype) + helper.append_op( + type="uniform_random", inputs=inputs, attrs=attrs, + outputs={"Out": out}) + return out + + def randint(low=0, high=None, shape=[1], dtype=None, name=None): """ :alias_main: paddle.randint @@ -352,6 +457,6 @@ def rand(shape, dtype=None, name=None): if dtype is None: dtype = 'float32' - out = uniform_random(shape, dtype, min=0.0, max=1.0, name=name) + out = uniform(shape, dtype, min=0.0, max=1.0, name=name) out.stop_gradient = True return out -- GitLab