Unverified commit eb3173e2, authored by zhupengyang, committed by GitHub

rand API: remove out, device, stop_gradient; add name (#25246)

Parent 22720a15
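In short: `paddle.rand` drops `out`, `device`, and `stop_gradient` and gains `name`, delegating the real work to `fluid.layers.uniform_random`. A minimal before/after sketch of what this means at call sites (illustrative, based on the signatures in the diff below; not part of the commit):

```python
import paddle

paddle.enable_imperative()  # imperative (dygraph) mode, as in the updated docstring

# New signature: rand(shape, dtype=None, name=None)
x = paddle.rand([3, 4], dtype='float32')
print(x.numpy().shape)  # (3, 4)

# The removed keyword arguments are no longer accepted:
# paddle.rand([3, 4], device='cpu')         # now a TypeError
# paddle.rand([3, 4], stop_gradient=False)  # now a TypeError
```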
@@ -98,7 +98,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
       return;
     }
-    if (!(ctx->HasInput("ShapeTensor") && !ctx->HasInputs("ShapeTensorList"))) {
+    if (!ctx->HasInput("ShapeTensor") && !ctx->HasInputs("ShapeTensorList")) {
       PADDLE_ENFORCE_GT(
           shape.size(), 0UL,
           platform::errors::InvalidArgument(
...
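The C++ change above also fixes the guard's logic: the old predicate `!(HasInput("ShapeTensor") && !HasInputs("ShapeTensorList"))` still demanded a non-empty `shape` attribute when only `ShapeTensorList` was fed. A quick truth-table check (plain Python, names illustrative):

```python
# Compare the old and new guards around PADDLE_ENFORCE_GT(shape.size(), 0UL, ...).
for has_shape_tensor in (False, True):
    for has_shape_tensor_list in (False, True):
        old = not (has_shape_tensor and not has_shape_tensor_list)
        new = (not has_shape_tensor) and (not has_shape_tensor_list)
        print(has_shape_tensor, has_shape_tensor_list, old, new)
# The two differ exactly when ShapeTensorList is present: the old form still
# enforced a non-empty shape attr even though the list input supplies the shape.
```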
@@ -10487,29 +10487,24 @@ def gaussian_random(shape, mean=0.0, std=1.0, seed=0, dtype='float32'):
           #       [2.8675377 , 2.2279181 , 0.79029655, 2.8447366 ]], dtype=float32)
 
     """
-    helper = LayerHelper('gaussian_random', **locals())
-    out = helper.create_variable_for_type_inference(dtype)
-    if not isinstance(shape, (list, tuple, Variable)):
-        raise TypeError(
-            "The type of 'shape' in fill_constant must be Variable, list or tuple, but "
-            "received %s." % (type(shape)))
-    c_dtype = convert_np_dtype_to_dtype_(dtype)
+    check_type(shape, 'shape', (list, tuple, Variable), 'gaussian_random')
+    if not isinstance(dtype, core.VarDesc.VarType):
+        dtype = convert_np_dtype_to_dtype_(dtype)
+    check_dtype(dtype, 'dtype', ['float32', 'float64'], 'gaussian_random')
+
+    inputs = {}
     attrs = {
         'mean': mean,
         'std': std,
         'seed': seed,
-        'dtype': c_dtype,
+        'dtype': dtype,
         'use_mkldnn': False
     }
 
-    inputs = {}
     utils._get_shape_tensor_inputs(
-        inputs=inputs,
-        helper=helper,
-        attrs=attrs,
-        shape=shape,
-        op_type='gaussian_random')
+        inputs=inputs, attrs=attrs, shape=shape, op_type='gaussian_random')
 
+    helper = LayerHelper('gaussian_random', **locals())
+    out = helper.create_variable_for_type_inference(dtype)
     helper.append_op(
         type='gaussian_random',
         inputs=inputs,
...
@@ -14937,7 +14932,8 @@ def gather_tree(ids, parents):
 @templatedoc()
-def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0):
+def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0,
+                   name=None):
     """
     This OP initializes a variable with random values sampled from a
     uniform distribution in the range [min, max).
...
@@ -14952,18 +14948,24 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0):
             result=[[0.8505902, 0.8397286]]
 
     Args:
-        shape (list|tuple|Variable): The shape of the output Tensor, if the shape is a list or tuple,
-            its elements can be an integer
-            or a Tensor with the shape [1], and the type of the Tensor must be int32 or int64.
-            If the shape is a Variable, it is a 1-D Tensor, and the type of the Tensor must be int32 or int64.
-        dtype(np.dtype|core.VarDesc.VarType|str, optional): The type of the output Tensor. Supported data types: float32, float64.
-            Default: float32.
-        min (float, optional): The lower bound on the range of random values to generate, the min is included in the range. Default -1.0.
-        max (float, optional): The upper bound on the range of random values to generate, the max is excluded in the range. Default 1.0.
-        seed (int, optional): Random seed used for generating samples. 0 means use a
-            seed generated by the system. Note that if seed is not 0, this
-            operator will always generate the same random numbers every time.
-            Default 0.
+        shape (list|tuple|Variable): The shape of the output Tensor, if the
+            shape is a list or tuple, its elements can be an integer or a
+            Tensor with the shape [1], and the type of the Tensor must be
+            int32 or int64. If the shape is a Variable, it is a 1-D Tensor, and
+            the type of the Tensor must be int32 or int64.
+        dtype(np.dtype|core.VarDesc.VarType|str, optional): The type of the
+            output Tensor. Supported data types: float32, float64. Default: float32.
+        min (float, optional): The lower bound on the range of random values
+            to generate, the min is included in the range. Default -1.0.
+        max (float, optional): The upper bound on the range of random values
+            to generate, the max is excluded in the range. Default 1.0.
+        seed (int, optional): Random seed used for generating samples. 0 means
+            use a seed generated by the system. Note that if seed is not 0,
+            this operator will always generate the same random numbers every
+            time. Default 0.
+        name(str, optional): The default value is None. Normally there is no
+            need for user to set this property. For more information, please
+            refer to :ref:`api_guide_Name`.
 
     Returns:
         Variable: A Tensor of the specified shape filled with uniform_random values.
...
@@ -14993,62 +14995,30 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0):
             var_shape_int32 = fluid.data(name='var_shape_int32', shape=[2], dtype="int32")
             result_4 = fluid.layers.uniform_random(var_shape_int32)
 
     """
+    check_type(shape, 'shape', (list, tuple, Variable), 'uniform_random')
     if not isinstance(dtype, core.VarDesc.VarType):
         dtype = convert_np_dtype_to_dtype_(dtype)
+    check_dtype(dtype, 'dtype', ('float32', 'float64'), 'uniform_random')
 
-    def get_new_shape_tensor(list_shape):
-        new_shape_tensor = []
-        for dim in list_shape:
-            if isinstance(dim, Variable):
-                dim.stop_gradient = True
-                new_shape_tensor.append(dim)
-            else:
-                assert (isinstance(dim, int))
-                temp_out = helper.create_variable_for_type_inference('int64')
-                fill_constant([1], 'int64', dim, force_cpu=True, out=temp_out)
-                new_shape_tensor.append(temp_out)
-        return new_shape_tensor
+    if in_dygraph_mode():
+        shape = utils._convert_shape_to_list(shape)
+        return core.ops.uniform_random('shape', shape, 'min',
+                                       float(min), 'max',
+                                       float(max), 'seed', seed, 'dtype', dtype)
 
-    def get_attr_shape(list_shape):
-        unk_dim_idx = -1
-        attrs_shape = []
-        for dim_idx, dim_size in enumerate(list_shape):
-            if isinstance(dim_size, Variable):
-                attrs_shape.append(-1)
-            else:
-                attrs_shape.append(dim_size)
-                assert dim_size > 0, (
-                    "Each dimension size given in shape must not be negative "
-                    "except one unknown dimension.")
-        return attrs_shape
+    check_type(shape, 'shape', (list, tuple, Variable), 'uniform_random')
+    check_dtype(dtype, 'dtype', ('float32', 'float64'), 'uniform_random')
 
-    helper = LayerHelper("uniform_random", **locals())
     inputs = dict()
     attrs = {'seed': seed, 'min': min, 'max': max, 'dtype': dtype}
-    if in_dygraph_mode():
-        attrs['shape'] = shape
-    else:
-        if isinstance(shape, Variable):
-            shape.stop_gradient = True
-            inputs["ShapeTensor"] = shape
-        elif isinstance(shape, (list, tuple)):
-            assert len(shape) > 0, (
-                "The size of argument(shape) can't be zero.")
-            attrs["shape"] = get_attr_shape(shape)
-            if utils._contain_var(shape):
-                inputs['ShapeTensorList'] = get_new_shape_tensor(shape)
+    utils._get_shape_tensor_inputs(
+        inputs=inputs, attrs=attrs, shape=shape, op_type='uniform_random')
 
+    helper = LayerHelper("uniform_random", **locals())
     out = helper.create_variable_for_type_inference(dtype)
     helper.append_op(
         type="uniform_random", inputs=inputs, attrs=attrs,
         outputs={"Out": out})
 
-    return helper.append_activation(out)
+    return out
 
 
 def unbind(input, axis=0):
...
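From the caller's side, the rewritten `uniform_random` behaves the same, now with an optional `name`; in dygraph mode it returns through the `core.ops` fast path shown above. A hedged usage sketch:

```python
import paddle.fluid as fluid

with fluid.dygraph.guard():
    # List shapes go through utils._convert_shape_to_list and then the
    # core.ops.uniform_random fast path added in this commit.
    x = fluid.layers.uniform_random([2, 3], dtype='float32',
                                    min=-1.0, max=1.0, name='u')
    print(x.numpy().shape)  # (2, 3)
```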
@@ -685,12 +685,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
         attrs['str_value'] = str(float(value))
 
     if in_dygraph_mode():
-        if isinstance(shape, (list, tuple)):
-            shape = list(
-                map(lambda x: x.numpy()[0] if isinstance(x, Variable) else x,
-                    shape))
-        else:
-            shape = list(shape.numpy().astype(int))
+        shape = utils._convert_shape_to_list(shape)
 
         if out is None:
             out = _varbase_creator(dtype=dtype)
@@ -719,12 +714,8 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
                 'fill_constant')
 
     helper = LayerHelper("fill_constant", **locals())
-    inputs = utils._get_shape_tensor_inputs(
-        inputs=inputs,
-        helper=helper,
-        attrs=attrs,
-        shape=shape,
-        op_type='fill_constant')
+    utils._get_shape_tensor_inputs(
+        inputs=inputs, attrs=attrs, shape=shape, op_type='fill_constant')
 
     if out is None:
         out = helper.create_variable_for_type_inference(dtype=dtype)
...
@@ -282,7 +282,7 @@ def _contain_var(list_or_tuple):
     return False
 
 
-def _get_shape_tensor_inputs(inputs, helper, attrs, shape, op_type):
+def _get_shape_tensor_inputs(inputs, attrs, shape, op_type):
     from .tensor import fill_constant, cast
 
     def _get_attr_shape(list_shape):
@@ -295,7 +295,7 @@ def _get_shape_tensor_inputs(inputs, helper, attrs, shape, op_type):
         return attr_shape
 
     def _get_shape_tensor(list_shape):
-        new_shape_tensor = []
+        shape_tensor_list = []
         for idx, dim in enumerate(list_shape):
             if isinstance(dim, Variable):
                 dim.stop_gradient = True
@@ -305,11 +305,11 @@ def _get_shape_tensor_inputs(inputs, helper, attrs, shape, op_type):
                     '(When type of shape in' + op_type + 'is list or tuple.)')
                 if convert_dtype(dim.dtype) == 'int64':
                     dim = cast(x=dim, dtype='int32')
-                new_shape_tensor.append(dim)
+                shape_tensor_list.append(dim)
             else:
                 temp_out = fill_constant([1], 'int32', dim, force_cpu=True)
-                new_shape_tensor.append(temp_out)
-        return new_shape_tensor
+                shape_tensor_list.append(temp_out)
+        return shape_tensor_list
 
     if isinstance(shape, Variable):
         shape.stop_gradient = True
@@ -325,8 +325,8 @@ def _get_shape_tensor_inputs(inputs, helper, attrs, shape, op_type):
         attrs["shape"] = _get_attr_shape(shape)
         if _contain_var(shape):
             inputs['ShapeTensorList'] = _get_shape_tensor(shape)
-
-    return inputs
+    else:
+        raise TypeError("Shape only supports Variable, or list, or tuple.")
 
 
 def _convert_to_tensor_list(old_list, dtype="int32"):
@@ -345,3 +345,16 @@ def _convert_to_tensor_list(old_list, dtype="int32"):
         temp_out = fill_constant([1], dtype, ele, force_cpu=True)
         new_list_tensor.append(temp_out)
     return new_list_tensor
+
+
+def _convert_shape_to_list(shape):
+    """
+    Convert shape(list, tuple, variable) to list in imperative mode
+    """
+    if isinstance(shape, (list, tuple)):
+        shape = list(
+            map(lambda x: x.numpy()[0] if isinstance(x, Variable) else x,
+                shape))
+    else:
+        shape = list(shape.numpy().astype(int))
+    return shape
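A standalone sketch of what the new `_convert_shape_to_list` helper does; the `FakeVariable` stand-in is hypothetical (the real `Variable` lives in `paddle.fluid.framework`):

```python
import numpy as np

class FakeVariable:
    """Minimal stand-in for a 1-D / 1-element integer tensor Variable."""

    def __init__(self, value):
        self._value = np.array(value)

    def numpy(self):
        return self._value

def convert_shape_to_list(shape):
    # Mirrors the helper above: a list/tuple may mix ints with 1-element
    # Variables; a bare Variable is treated as a 1-D shape tensor.
    if isinstance(shape, (list, tuple)):
        return [x.numpy()[0] if isinstance(x, FakeVariable) else x
                for x in shape]
    return list(shape.numpy().astype(int))

print(convert_shape_to_list([3, FakeVariable([4])]))  # [3, 4]
print(convert_shape_to_list(FakeVariable([2, 5])))    # [2, 5]
```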
@@ -47,71 +47,73 @@ class TestRandOpError(unittest.TestCase):
 
         self.assertRaises(TypeError, test_dtype)
 
-        def test_shape_list():
-            rand(shape=[2.])
-
-        self.assertRaises(TypeError, test_shape_list)
-
-        def test_shape_list2():
-            rand(shape=[2, 3.])
-
-        self.assertRaises(TypeError, test_shape_list2)
-
-        def test_device():
-            rand(shape=[3, 4], device='device')
-
-        self.assertRaises(ValueError, test_device)
-
 
 class TestRandOp(unittest.TestCase):
     """
     This class test the common usages of randop.
     """
 
-    def test_run(self):
-        use_cuda = False
+    def run_net(self, use_cuda=False):
         place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
         exe = fluid.Executor(place)
         train_program = fluid.Program()
         startup_program = fluid.Program()
         with fluid.program_guard(train_program, startup_program):
-            result_1 = rand(shape=[3, 4])
+            result_0 = rand([3, 4])
+            result_1 = rand([3, 4], 'float64')
 
             dim_1 = fluid.layers.fill_constant([1], "int64", 3)
             dim_2 = fluid.layers.fill_constant([1], "int32", 5)
             result_2 = rand(shape=[dim_1, dim_2])
 
             var_shape = fluid.data(name='var_shape', shape=[2], dtype="int64")
             result_3 = rand(var_shape)
 
             var_shape_int32 = fluid.data(
                 name='var_shape_int32', shape=[2], dtype="int32")
             result_4 = rand(var_shape_int32)
 
             exe.run(startup_program)
 
             x1 = np.array([3, 2]).astype('int64')
             x2 = np.array([4, 3]).astype('int32')
-            ret = exe.run(train_program,
-                          feed={"var_shape": x1,
-                                "var_shape_int32": x2},
-                          fetch_list=[result_1, result_2, result_3, result_4])
+            ret = exe.run(
+                train_program,
+                feed={"var_shape": x1,
+                      "var_shape_int32": x2},
+                fetch_list=[result_0, result_1, result_2, result_3, result_4])
+
+    def test_run(self):
+        self.run_net(False)
+        if core.is_compiled_with_cuda():
+            self.run_net(True)
 
 
 class TestRandOpForDygraph(unittest.TestCase):
     """
     This class test the common usages of randop.
     """
 
-    def test_run(self):
-        use_cuda = False
-        with fluid.dygraph.guard():
-            rand(shape=[3, 4])
+    def run_net(self, use_cuda=False):
+        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+        with fluid.dygraph.guard(place):
+            rand([3, 4])
+            rand([3, 4], 'float64')
 
             dim_1 = fluid.layers.fill_constant([1], "int64", 3)
             dim_2 = fluid.layers.fill_constant([1], "int32", 5)
             rand(shape=[dim_1, dim_2])
 
             var_shape = fluid.dygraph.to_variable(np.array([3, 4]))
             rand(var_shape)
 
+    def test_run(self):
+        self.run_net(False)
+        if core.is_compiled_with_cuda():
+            self.run_net(True)
+
 
 if __name__ == "__main__":
     unittest.main()
@@ -406,7 +406,7 @@ def randperm(n,
     return out
 
 
-def rand(shape, out=None, dtype=None, device=None, stop_gradient=True):
+def rand(shape, dtype=None, name=None):
     """
     :alias_main: paddle.rand
     :alias: paddle.rand,paddle.tensor.rand,paddle.tensor.random.rand
@@ -424,22 +424,19 @@ def rand(shape, out=None, dtype=None, device=None, stop_gradient=True):
         result=[[0.8505902, 0.8397286]]
 
     Args:
-        shape(list|tuple|Variable): Shape of the Tensor to be created.
-            The data type is ``int32`` or ``int64`` . If ``shape`` is a list or tuple,
-            the elements of it should be integers or Tensors with shape [1].
-            If ``shape`` is a Variable, it should be an 1-D Tensor .
-        out(Variable, optional): Optional output which can be any created
-            Variable that meets the requirements to store the result of operation.
-            if out is None, a new Varibale will be create to store the result.
-        dtype(np.dtype|core.VarDesc.VarType|str, optional): Data type of the output tensor
-            which can be float32, float64, if dytpe is `None`, the data
-            type of created tensor is `float32`
-        device(str, optional): This parameter specifies that the Tensor is created
-            on the GPU or CPU.
-        stop_gradient(bool, optional): Indicating if we stop gradient from current(out) Variable,
-            default value is True.
+        shape(list|tuple|Variable): Shape of the Tensor to be created. The data
+            type is ``int32`` or ``int64`` . If ``shape`` is a list or tuple,
+            the elements of it should be integers or Tensors with shape [1]. If
+            ``shape`` is a Variable, it should be an 1-D Tensor .
+        dtype(np.dtype|core.VarDesc.VarType|str, optional): Data type of the
+            output tensor which can be float32, float64, if dytpe is `None`,
+            the data type of created tensor is `float32`
+        name(str, optional): The default value is None. Normally there is no
+            need for user to set this property. For more information, please
+            refer to :ref:`api_guide_Name`.
 
     Returns:
-        Variable: A Tensor of the specified shape filled with random numbers from a uniform distribution on the interval [0, 1).
+        Variable: A Tensor of the specified shape filled with random numbers
+            from a uniform distribution on the interval [0, 1).
 
     Raises:
         TypeError: The shape type should be list or tupple or Variable.
@@ -447,54 +444,33 @@ def rand(shape, out=None, dtype=None, device=None, stop_gradient=True):
     Examples:
         .. code-block:: python
 
             import paddle
-            import paddle.fluid as fluid
+            import numpy as np
 
-            # example 1:
-            # attr shape is a list which doesn't contain tensor Variable.
-            result_1 = paddle.rand(shape=[3, 4])
+            paddle.enable_imperative()
 
-            # example 2:
-            # attr shape is a list which contains tensor Variable.
-            dim_1 = fluid.layers.fill_constant([1],"int64",3)
-            dim_2 = fluid.layers.fill_constant([1],"int32",5)
-            result_2 = paddle.rand(shape=[dim_1, dim_2])
+            # example 1: attr shape is a list which doesn't contain tensor Variable.
+            result_1 = paddle.rand(shape=[2, 3])
+            # [[0.451152  , 0.55825245, 0.403311  ],
+            #  [0.22550228, 0.22106001, 0.7877319 ]]
 
-            # example 3:
-            # attr shape is a Variable, the data type must be int64 or int32.
-            var_shape = fluid.data(name='var_shape', shape=[2], dtype="int64")
-            result_3 = paddle.rand(var_shape)
-            var_shape_int32 = fluid.data(name='var_shape_int32', shape=[2], dtype="int32")
-            result_4 = paddle.rand(var_shape_int32)
+            # example 2: attr shape is a list which contains tensor Variable.
+            dim_1 = paddle.fill_constant([1], "int64", 2)
+            dim_2 = paddle.fill_constant([1], "int32", 3)
+            result_2 = paddle.rand(shape=[dim_1, dim_2, 2])
+            # [[[0.8879919  0.25788337]
+            #   [0.28826773 0.9712097 ]
+            #   [0.26438272 0.01796806]]
+            #  [[0.33633623 0.28654453]
+            #   [0.79109055 0.7305809 ]
+            #   [0.870881   0.2984597 ]]]
+
+            # example 3: attr shape is a Variable, the data type must be int64 or int32.
+            var_shape = paddle.imperative.to_variable(np.array([2, 3]))
+            result_3 = paddle.rand(var_shape)
+            # [[0.22920267 0.841956   0.05981819]
+            #  [0.4836288  0.24573246 0.7516129 ]]
 
     """
     if dtype is None:
         dtype = 'float32'
 
-    check_dtype(dtype, 'dtype', ['float32', 'float64'], 'rand')
-
-    check_type(shape, 'shape', (Variable, list, tuple), 'rand')
-    if isinstance(shape, Variable):
-        check_variable_and_dtype(shape, 'shape', ['int32', 'int64'], 'rand')
-    elif isinstance(shape, (list, tuple)):
-        for i, _shape in enumerate(shape):
-            if not isinstance(_shape, Variable):
-                check_type(_shape, '_shape', (int), 'rand')
-            else:
-                check_variable_and_dtype(_shape, 'shape[' + str(i) + ']',
-                                         ['int32', 'int64'], 'rand')
-
-    if device not in [None, 'cpu', 'gpu']:
-        raise ValueError(
-            "The input device should in [None, 'cpu', 'gpu'], but received {}".
-            format(device))
-
-    helper = LayerHelper("rand", **locals())
-    if out is None:
-        out = helper.create_variable_for_type_inference(dtype=dtype)
-    else:
-        check_variable_and_dtype(out, 'out', [dtype], 'rand')
-    out.stop_gradient = stop_gradient
-
-    with device_guard(device):
-        out = uniform_random(shape, dtype, min=0., max=1.0)
-    return out
+    return uniform_random(shape, dtype, min=0.0, max=1.0, name=name)
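Callers that used the removed arguments can reproduce the old behavior explicitly. A hedged static-graph sketch using `fluid.device_guard`, which is what the deleted implementation wrapped `uniform_random` in (availability of `device_guard` at this version is assumed):

```python
import paddle
import paddle.fluid as fluid

main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    # device= is gone: pin the op's placement with device_guard instead.
    with fluid.device_guard('cpu'):
        x = paddle.rand(shape=[3, 4], dtype='float32')
    # stop_gradient= is gone: set the attribute on the result.
    x.stop_gradient = False
```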