未验证 提交 22501202 编写于 作者: Z zhupengyang 提交者: GitHub

randint API: remove out, devive, stop_gradient; add name (#25433)

* randint API: remove out, devive, stop_gradient; add name; test=develop

* test=develop

* test=develop

* test=develop
上级 ca725c82
...@@ -17,12 +17,9 @@ from __future__ import print_function ...@@ -17,12 +17,9 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
import paddle import paddle
from paddle.fluid import core
from paddle import Program, program_guard
def output_hist(out): def output_hist(out):
...@@ -56,25 +53,10 @@ class TestRandintOp(OpTest): ...@@ -56,25 +53,10 @@ class TestRandintOp(OpTest):
class TestRandintOpError(unittest.TestCase): class TestRandintOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
main_prog = Program() with program_guard(Program(), Program()):
start_prog = Program() self.assertRaises(TypeError, paddle.randint, 5, shape=np.array([2]))
with program_guard(main_prog, start_prog): self.assertRaises(TypeError, paddle.randint, 5, dtype='float32')
self.assertRaises(ValueError, paddle.randint, 5, 5)
def test_shape():
shape = np.array([2, 3])
paddle.randint(5, shape=shape, dtype='int32')
self.assertRaises(TypeError, test_shape)
def test_dtype():
paddle.randint(5, shape=[32, 32], dtype='float32')
self.assertRaises(TypeError, test_dtype)
def test_low_high():
paddle.randint(low=5, high=5, shape=[32, 32], dtype='int32')
self.assertRaises(ValueError, test_low_high)
class TestRandintOp_attr_tensorlist(OpTest): class TestRandintOp_attr_tensorlist(OpTest):
...@@ -127,46 +109,44 @@ class TestRandint_attr_tensor(OpTest): ...@@ -127,46 +109,44 @@ class TestRandint_attr_tensor(OpTest):
# Test python API # Test python API
class TestRandintAPI(unittest.TestCase): class TestRandintAPI(unittest.TestCase):
def test_api(self): def test_api(self):
startup_program = fluid.Program() with program_guard(Program(), Program()):
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
# results are from [0, 5). # results are from [0, 5).
output1 = paddle.randint(5) out1 = paddle.randint(5)
# shape is a list and dtype is 'int32' # shape is a list and dtype is 'int32'
output2 = paddle.randint( out2 = paddle.randint(
low=-100, high=100, shape=[64, 64], dtype='int32') low=-100, high=100, shape=[64, 64], dtype='int32')
# shape is a tuple and dtype is 'int64' # shape is a tuple and dtype is 'int64'
output3 = paddle.randint( out3 = paddle.randint(
low=-100, high=100, shape=(32, 32, 3), dtype='int64') low=-100, high=100, shape=(32, 32, 3), dtype='int64')
# shape is a tensorlist and dtype is 'float32' # shape is a tensorlist and dtype is 'float32'
dim_1 = fluid.layers.fill_constant([1], "int64", 32) dim_1 = paddle.fill_constant([1], "int64", 32)
dim_2 = fluid.layers.fill_constant([1], "int32", 50) dim_2 = paddle.fill_constant([1], "int32", 50)
output4 = paddle.randint( out4 = paddle.randint(
low=-100, high=100, shape=[dim_1, 5], dtype='int32') low=-100, high=100, shape=[dim_1, 5, dim_2], dtype='int32')
# shape is a tensor and dtype is 'float64' # shape is a tensor and dtype is 'float64'
var_shape = fluid.data(name='var_shape', shape=[2], dtype="int64") var_shape = paddle.nn.data(
output5 = paddle.randint( name='var_shape', shape=[2], dtype="int64")
out5 = paddle.randint(
low=1, high=1000, shape=var_shape, dtype='int64') low=1, high=1000, shape=var_shape, dtype='int64')
place = fluid.CPUPlace() place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
if fluid.core.is_compiled_with_cuda(): ) else paddle.CPUPlace()
place = fluid.CUDAPlace(0) exe = paddle.Executor(place)
exe = fluid.Executor(place)
exe.run(startup_program)
outs = exe.run( outs = exe.run(
train_program,
feed={'var_shape': np.array([100, 100]).astype('int64')}, feed={'var_shape': np.array([100, 100]).astype('int64')},
fetch_list=[output1, output2, output3, output4, output5]) fetch_list=[out1, out2, out3, out4, out5])
class TestRandintDygraphMode(unittest.TestCase): class TestRandintImperative(unittest.TestCase):
def test_check_output(self): def test_api(self):
with fluid.dygraph.guard(): n = 10
x = paddle.randint(10, shape=[10], dtype="int32") with paddle.imperative.guard():
x_np = x.numpy() x1 = paddle.randint(n, shape=[10], dtype="int32")
for i in range(10): x2 = paddle.tensor.randint(n)
self.assertTrue((x_np[i] >= 0 and x_np[i] < 10)) x3 = paddle.tensor.random.randint(n)
for i in [x1, x2, x3]:
for j in i.numpy().tolist():
self.assertTrue((j >= 0 and j < n))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -37,172 +37,111 @@ __all__ = [ ...@@ -37,172 +37,111 @@ __all__ = [
] ]
def randint(low, def randint(low=0, high=None, shape=[1], dtype=None, name=None):
high=None,
shape=None,
out=None,
dtype=None,
device=None,
stop_gradient=False,
seed=0,
name=None):
""" """
:alias_main: paddle.randint :alias_main: paddle.randint
:alias: paddle.randint,paddle.tensor.randint,paddle.tensor.random.randint :alias: paddle.randint,paddle.tensor.randint,paddle.tensor.random.randint
This function returns a Tensor filled with random integers from the "discrete uniform" distribution of the This function returns a Tensor filled with random integers from the
specified data type in the interval [low, high). If high is None (the default), then results are from [0, low). "discrete uniform" distribution of the specified data type in the interval
[low, high). If high is None (the default), then results are from [0, low).
Args: Args:
low (int): The lower bound on the range of random values to generate, the low is included in the range. low (int): The lower bound on the range of random values to generate,
(unless high=None, in which case this parameter is one above the highest such integer). the low is included in the range.(unless high=None, in which case
high (int, optional): The upper bound on the range of random values to generate, the high is excluded this parameter is one above the highest such integer). Default is 0.
in the range. Default None(see above for behavior if high=None). high (int, optional): The upper bound on the range of random values to
shape (list|tuple|Variable, optional): The shape of the output Tensor, if the shape is a list or tuple, generate, the high is excluded in the range. Default is None(see
its elements can be an integer above for behavior if high=None).
or a Tensor with the shape [1], and the type of the Tensor must be int32 or int64. shape (list|tuple|Variable, optional): The shape of the output Tensor,
If the shape is a Variable, it is a 1-D Tensor, and the type of the Tensor must be if the shape is a list or tuple, its elements can be an integer or
int32 or int64. Default is None, in which case the shape is [1]. a Tensor with the shape [1], and the type of the Tensor must be
out(Variable, optional): Optional output which can be any created int32 or int64. If the shape is a Variable, it is a 1-D Tensor,
Variable that meets the requirements to store the result of operation. and the type of the Tensor must be int32 or int64. Default is None.
if out is None, a new Varibale will be create to store the result. dtype(np.dtype|core.VarDesc.VarType|str, optional): Data type of the
dtype(np.dtype|core.VarDesc.VarType|str, optional): Data type of the output Tensor output Tensor which can be int32, int64. If dtype is `None`, the
which can be int32, int64, if dytpe is `None`, the data data type of created Tensor is `int64`
type of created Tensor is `int64` name(str, optional): The default value is None. Normally there is no
device(str, optional): This parameter specifies that the Tensor is created need for user to set this property. For more information, please
on the GPU or CPU. refer to :ref:`api_guide_Name`.
stop_gradient(bool, optional): Indicating if we stop gradient from current(out) Variable,
default value is False.
seed (int, optional): Random seed used for permute samples. If seed is
equal to 0, it means use a seed generated by the system. Note that
if seed is not 0, this operator will always generate the same random
permutation every time. Default: 0.
name(str, optional): The default value is None. Normally there is no need for user to set this
property. For more information, please refer to :ref:`api_guide_Name`.
Returns: Returns:
Variable: A Tensor of the specified shape filled with random integers. Variable: A Tensor of the specified shape filled with random integers.
Raises: Raises:
TypeError: Randint's low must less then high. TypeError: If shape's type is not list, tuple or Variable.
TypeError: If dtype is not int32 or int64.
ValueError: If low is not large then high; If low is 0, and high is None.
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
import paddle.fluid as fluid import numpy as np
# example 1:
# attr shape is a list which doesn't contain tensor Variable.
result_1 = paddle.randint(low=-5, high=5, shape=[3, 4], dtype="int64")
# example 2:
# attr shape is a list which contains tensor Variable.
dim_1 = fluid.layers.fill_constant([1],"int64",3)
dim_2 = fluid.layers.fill_constant([1],"int32",5)
result_2 = paddle.randint(low=-5, high=5, shape=[dim_1, dim_2], dtype="int32")
# example 3:
# attr shape is a Variable, the data type must be int64 or int32.
var_shape = fluid.data(name='var_shape', shape=[2], dtype="int64")
result_3 = paddle.randint(low=-5, high=5, shape=var_shape, dtype="int32")
var_shape_int32 = fluid.data(name='var_shape_int32', shape=[2], dtype="int32")
result_4 = paddle.randint(low=-5, high=5, shape=var_shape_int32, dtype="int64")
# example 4:
# Input only one parameter
# low=0, high=10, shape=[1], dtype='int64'
result_4 = paddle.randint(10)
"""
def get_new_shape_tensor(list_shape):
new_shape_tensor = []
for dim in list_shape:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_shape_tensor.append(dim)
else:
assert isinstance(dim, int) or isinstance(dim, long)
temp_out = helper.create_variable_for_type_inference('int64')
fill_constant([1], 'int64', dim, force_cpu=True, out=temp_out)
new_shape_tensor.append(temp_out)
return new_shape_tensor
def get_attr_shape(list_shape):
unk_dim_idx = -1
attrs_shape = []
for dim_idx, dim_size in enumerate(list_shape):
if isinstance(dim_size, Variable):
attrs_shape.append(-1)
else:
attrs_shape.append(dim_size)
assert dim_size > 0, (
"Each dimension size given in shape must not be negative "
"except one unknown dimension.")
return attrs_shape
if dtype is None:
dtype = 'int64'
check_dtype(dtype, 'dtype', ['int32', 'int64'], 'randint')
inputs = dict()
attrs = dict()
if shape is None:
shape = [1]
assert len(shape) > 0, ("The size of argument(shape) can't be zero.")
helper = LayerHelper("randint", **locals()) paddle.enable_imperative()
if in_dygraph_mode(): # example 1:
attrs['shape'] = shape # attr shape is a list which doesn't contain tensor Variable.
else: result_1 = paddle.randint(low=-5, high=5, shape=[3])
if isinstance(shape, Variable): # [0 -3 2]
shape.stop_gradient = True
inputs["ShapeTensor"] = shape # example 2:
elif isinstance(shape, (list, tuple)): # attr shape is a list which contains tensor Variable.
assert len(shape) > 0, ( dim_1 = paddle.fill_constant([1],"int64",2)
"The size of argument(shape) can't be zero.") dim_2 = paddle.fill_constant([1],"int32",3)
if utils._contain_var(shape): result_2 = paddle.randint(low=-5, high=5, shape=[dim_1, dim_2], dtype="int32")
inputs['ShapeTensorList'] = get_new_shape_tensor(shape) print(result_2.numpy())
else: # [[ 0 -1 -3]
attrs["shape"] = get_attr_shape(shape) # [ 4 -2 0]]
check_type(shape, 'shape', (list, tuple, Variable), 'randint')
# example 3:
# attr shape is a Variable
var_shape = paddle.imperative.to_variable(np.array([3]))
result_3 = paddle.randint(low=-5, high=5, shape=var_shape)
# [-2 2 3]
# example 4:
# data type is int32
result_4 = paddle.randint(low=-5, high=5, shape=[3], dtype='int32')
# [-5 4 -4]
# example 5:
# Input only one parameter
# low=0, high=10, shape=[1], dtype='int64'
result_5 = paddle.randint(10)
# [7]
"""
if high is None: if high is None:
high = low high = low
low = 0 low = 0
attrs['low'] = low if dtype is None:
attrs['high'] = high dtype = 'int64'
attrs['seed'] = seed if not isinstance(dtype, core.VarDesc.VarType):
if (low >= high): dtype = convert_np_dtype_to_dtype_(dtype)
if in_dygraph_mode():
shape = utils._convert_shape_to_list(shape)
return core.ops.randint('shape', shape, 'low', low, 'high', high,
'seed', 0, 'dtype', dtype)
check_type(shape, 'shape', (list, tuple, Variable), 'randint')
check_dtype(dtype, 'dtype', ['int32', 'int64'], 'randint')
if low >= high:
raise ValueError( raise ValueError(
"randint's low must less then high, but received low = {0}, " "randint's low must less then high, but received low = {0}, "
"high = {1}".format(low, high)) "high = {1}".format(low, high))
if out is None: inputs = dict()
if name is None: attrs = {'low': low, 'high': high, 'seed': 0, 'dtype': dtype}
out = helper.create_variable_for_type_inference(dtype=dtype) utils._get_shape_tensor_inputs(
else: inputs=inputs, attrs=attrs, shape=shape, op_type='randint')
out = helper.create_variable(
name=name, dtype=dtype, persistable=False) helper = LayerHelper("randint", **locals())
else: out = helper.create_variable_for_type_inference(dtype=dtype)
check_dtype(dtype, 'dtype', helper.append_op(
convert_dtype(out.dtype), 'randint', type='randint', inputs=inputs, outputs={'Out': out}, attrs=attrs)
"(The dtype in randint must be the same with out's dtype.)")
attrs['dtype'] = out.dtype
out.stop_gradient = stop_gradient
if device is None:
helper.append_op(
type='randint', inputs=inputs, outputs={'Out': out}, attrs=attrs)
else:
with device_guard(device):
helper.append_op(
type='randint',
inputs=inputs,
outputs={'Out': out},
attrs=attrs)
return out return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册