未验证 提交 6a09b8f1 编写于 作者: Z zhupengyang 提交者: GitHub

erase Raises and refine doce of random functions (#26901)

上级 559d9f2b
......@@ -132,6 +132,28 @@ def check_dtype(input_dtype,
extra_message))
def check_shape(shape,
op_name,
expected_shape_type=(list, tuple, Variable),
expected_element_type=(int, Variable),
expected_tensor_dtype=('int32', 'int64')):
# See NOTE [ Why skip dynamic graph check ]
if in_dygraph_mode():
return
check_type(shape, 'shape', expected_shape_type, op_name)
if expected_element_type is not None and not isinstance(shape, Variable):
for item in shape:
check_type(item, 'element of shape', expected_element_type, op_name)
if expected_tensor_dtype is not None and isinstance(item, Variable):
check_dtype(
item.dtype, 'element of shape', expected_tensor_dtype,
op_name,
'If element of shape is Tensor, its data type should be {}'.
format(', '.join(expected_tensor_dtype)))
if expected_tensor_dtype is not None and isinstance(shape, Variable):
check_dtype(shape.dtype, 'shape', expected_tensor_dtype, op_name)
class DataToLoDTensorConverter(object):
def __init__(self, place, lod_level, shape, dtype):
self.place = place
......
......@@ -10610,7 +10610,7 @@ def gaussian_random(shape,
dtype = convert_np_dtype_to_dtype_(dtype)
if in_dygraph_mode():
shape = utils._convert_shape_to_list(shape)
shape = utils.convert_shape_to_list(shape)
return core.ops.gaussian_random('shape', shape, 'mean',
float(mean), 'std',
float(std), 'seed', seed, 'dtype',
......@@ -10627,7 +10627,7 @@ def gaussian_random(shape,
'dtype': dtype,
'use_mkldnn': False
}
utils._get_shape_tensor_inputs(
utils.get_shape_tensor_inputs(
inputs=inputs,
attrs=attrs,
shape=shape,
......@@ -15116,7 +15116,7 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0,
dtype = convert_np_dtype_to_dtype_(dtype)
if in_dygraph_mode():
shape = utils._convert_shape_to_list(shape)
shape = utils.convert_shape_to_list(shape)
return core.ops.uniform_random('shape', shape, 'min',
float(min), 'max',
float(max), 'seed', seed, 'dtype', dtype)
......@@ -15126,7 +15126,7 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0,
inputs = dict()
attrs = {'seed': seed, 'min': min, 'max': max, 'dtype': dtype}
utils._get_shape_tensor_inputs(
utils.get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=shape, op_type='uniform_random/rand')
helper = LayerHelper("uniform_random", **locals())
......
......@@ -694,7 +694,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
attrs['str_value'] = str(float(value))
if in_dygraph_mode():
shape = utils._convert_shape_to_list(shape)
shape = utils.convert_shape_to_list(shape)
if out is None:
out = _varbase_creator(dtype=dtype)
......@@ -731,7 +731,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
'fill_constant')
helper = LayerHelper("fill_constant", **locals())
utils._get_shape_tensor_inputs(
utils.get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=shape, op_type='fill_constant')
if out is None:
......
......@@ -282,7 +282,7 @@ def _contain_var(list_or_tuple):
return False
def _get_shape_tensor_inputs(inputs, attrs, shape, op_type):
def get_shape_tensor_inputs(inputs, attrs, shape, op_type):
from .tensor import fill_constant, cast
def _get_attr_shape(list_shape):
......@@ -347,7 +347,7 @@ def _convert_to_tensor_list(old_list, dtype="int32"):
return new_list_tensor
def _convert_shape_to_list(shape):
def convert_shape_to_list(shape):
"""
Convert shape(list, tuple, variable) to list in imperative mode
"""
......
......@@ -241,18 +241,18 @@ class TestGaussianRandomAPI(unittest.TestCase):
def test_default_fp_16():
paddle.framework.set_default_dtype('float16')
paddle.tensor.random.gaussian_random([2, 3])
paddle.tensor.random.gaussian([2, 3])
self.assertRaises(TypeError, test_default_fp_16)
def test_default_fp_32():
paddle.framework.set_default_dtype('float32')
out = paddle.tensor.random.gaussian_random([2, 3])
out = paddle.tensor.random.gaussian([2, 3])
self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP32)
def test_default_fp_64():
paddle.framework.set_default_dtype('float64')
out = paddle.tensor.random.gaussian_random([2, 3])
out = paddle.tensor.random.gaussian([2, 3])
self.assertEqual(out.dtype, fluid.core.VarDesc.VarType.FP64)
test_default_fp_64()
......
......@@ -58,6 +58,11 @@ class TestRandintOpError(unittest.TestCase):
self.assertRaises(TypeError, paddle.randint, 5, dtype='float32')
self.assertRaises(ValueError, paddle.randint, 5, 5)
self.assertRaises(ValueError, paddle.randint, -5)
self.assertRaises(TypeError, paddle.randint, 5, shape=['2'])
shape_tensor = paddle.static.data('X', [1])
self.assertRaises(TypeError, paddle.randint, 5, shape=shape_tensor)
self.assertRaises(
TypeError, paddle.randint, 5, shape=[shape_tensor])
class TestRandintOp_attr_tensorlist(OpTest):
......
......@@ -14,17 +14,12 @@
# TODO: define random functions
import numpy as np
from ..fluid import core
from ..fluid.framework import device_guard, in_dygraph_mode, _varbase_creator, Variable, convert_np_dtype_to_dtype_
from ..fluid.layers.layer_function_generator import templatedoc
from ..fluid.framework import in_dygraph_mode, Variable, convert_np_dtype_to_dtype_
from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, check_shape
from ..fluid.layers import utils
from ..fluid.layers.tensor import fill_constant
import paddle
import warnings
from ..fluid.io import shuffle #DEFINE_ALIAS
......@@ -94,26 +89,26 @@ def bernoulli(x, name=None):
return out
def gaussian_random(shape, mean=0.0, std=1.0, dtype=None, name=None):
def gaussian(shape, mean=0.0, std=1.0, dtype=None, name=None):
"""
This OP returns a Tensor filled with random values sampled from a Gaussian
distribution, with ``shape`` and ``dtype``.
Args:
shape(list|tuple|Tensor): The shape of the output Tensor. If ``shape``
shape (list|tuple|Tensor): The shape of the output Tensor. If ``shape``
is a list or tuple, the elements of it should be integers or Tensors
(with the shape [1], and the data type int32 or int64). If ``shape``
is a Tensor, it should be a 1-D Tensor(with the data type int32 or
int64).
mean(float|int, optional): Mean of the output tensor, default is 0.0.
std(float|int, optional): Standard deviation of the output tensor, default
mean (float|int, optional): Mean of the output tensor, default is 0.0.
std (float|int, optional): Standard deviation of the output tensor, default
is 1.0.
seed(int, optional): ${seed_comment}
dtype(str|np.dtype, optional): The data type of the output Tensor.
seed (int, optional): Random seed of generator.
dtype (str|np.dtype, optional): The data type of the output Tensor.
Supported data types: float32, float64.
Default is None, use global default dtype (see ``get_default_dtype``
for details).
name(str, optional): The default value is None. Normally there is no
name (str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
......@@ -121,26 +116,26 @@ def gaussian_random(shape, mean=0.0, std=1.0, dtype=None, name=None):
Tensor: A Tensor filled with random values sampled from a Gaussian
distribution, with ``shape`` and ``dtype``.
"""
op_type_for_check = 'gaussian/standard_normal/randn/normal'
seed = 0
if dtype is None:
dtype = paddle.framework.get_default_dtype()
if dtype not in ['float32', 'float64']:
raise TypeError(
"gaussian_random only supports [float32, float64], but the default dtype is %s"
% dtype)
"{} only supports [float32, float64], but the default dtype is {}"
.format(op_type_for_check, dtype))
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
seed = 0
op_type_for_check = 'gaussian_random/standard_normal/randn/normal'
if in_dygraph_mode():
shape = utils._convert_shape_to_list(shape)
shape = utils.convert_shape_to_list(shape)
return core.ops.gaussian_random('shape', shape, 'mean',
float(mean), 'std',
float(std), 'seed', seed, 'dtype',
dtype)
check_type(shape, 'shape', (list, tuple, Variable), op_type_for_check)
check_shape(shape, op_type_for_check)
check_dtype(dtype, 'dtype', ['float32', 'float64'], op_type_for_check)
inputs = {}
......@@ -151,10 +146,10 @@ def gaussian_random(shape, mean=0.0, std=1.0, dtype=None, name=None):
'dtype': dtype,
'use_mkldnn': False
}
utils._get_shape_tensor_inputs(
utils.get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=shape, op_type=op_type_for_check)
helper = LayerHelper('gaussian_random', **locals())
helper = LayerHelper('gaussian', **locals())
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='gaussian_random',
......@@ -172,12 +167,12 @@ def standard_normal(shape, dtype=None, name=None):
and ``dtype``.
Args:
shape(list|tuple|Tensor): The shape of the output Tensor. If ``shape``
shape (list|tuple|Tensor): The shape of the output Tensor. If ``shape``
is a list or tuple, the elements of it should be integers or Tensors
(with the shape [1], and the data type int32 or int64). If ``shape``
is a Tensor, it should be a 1-D Tensor(with the data type int32 or
int64).
dtype(str|np.dtype, optional): The data type of the output Tensor.
dtype (str|np.dtype, optional): The data type of the output Tensor.
Supported data types: float32, float64.
Default is None, use global default dtype (see ``get_default_dtype``
for details).
......@@ -189,10 +184,6 @@ def standard_normal(shape, dtype=None, name=None):
normal distribution with mean 0 and standard deviation 1, with
``shape`` and ``dtype``.
Raises:
TypeError: If ``shape`` is not list, tuple, Tensor.
TypeError: If ``dtype`` is not float32, float64.
Examples:
.. code-block:: python
......@@ -202,14 +193,14 @@ def standard_normal(shape, dtype=None, name=None):
paddle.disable_static()
# example 1: attr shape is a list which doesn't contain Tensor.
result_1 = paddle.standard_normal(shape=[2, 3])
out1 = paddle.standard_normal(shape=[2, 3])
# [[-2.923464 , 0.11934398, -0.51249987], # random
# [ 0.39632758, 0.08177969, 0.2692008 ]] # random
# example 2: attr shape is a list which contains Tensor.
dim_1 = paddle.fill_constant([1], "int64", 2)
dim_2 = paddle.fill_constant([1], "int32", 3)
result_2 = paddle.standard_normal(shape=[dim_1, dim_2, 2])
dim1 = paddle.full([1], 2, "int64")
dim2 = paddle.full([1], 3, "int32")
out2 = paddle.standard_normal(shape=[dim1, dim2, 2])
# [[[-2.8852394 , -0.25898588], # random
# [-0.47420555, 0.17683524], # random
# [-0.7989969 , 0.00754541]], # random
......@@ -218,21 +209,13 @@ def standard_normal(shape, dtype=None, name=None):
# [ 0.8086993 , 0.6868893 ]]] # random
# example 3: attr shape is a Tensor, the data type must be int64 or int32.
var_shape = paddle.to_tensor(np.array([2, 3]))
result_3 = paddle.standard_normal(var_shape)
shape_tensor = paddle.to_tensor(np.array([2, 3]))
out3 = paddle.standard_normal(shape_tensor)
# [[-2.878077 , 0.17099959, 0.05111201] # random
# [-0.3761474, -1.044801 , 1.1870178 ]] # random
"""
if dtype is None:
dtype = paddle.framework.get_default_dtype()
if dtype not in ['float32', 'float64']:
raise TypeError(
"standard_normal only supports [float32, float64], but the default dtype is %s"
% dtype)
return gaussian_random(
shape=shape, mean=0.0, std=1.0, dtype=dtype, name=name)
return gaussian(shape=shape, mean=0.0, std=1.0, dtype=dtype, name=name)
randn = standard_normal
......@@ -306,16 +289,7 @@ def normal(mean=0.0, std=1.0, shape=None, name=None):
"If std is Tensor, it's data type only support float32, float64."
)
if shape is not None:
if isinstance(shape, (list, tuple)):
for item in shape:
check_type(item, 'shape', (int), 'normal',
'Elements of shape should be int.')
elif isinstance(shape, Variable):
check_dtype(shape.dtype, 'shape', ['int32', 'int64'], 'normal')
else:
assert TypeError(
'If mean and std are all not Tensor, shape should be list, tuple, Tensor.'
)
check_shape(shape, 'normal')
if isinstance(mean, Variable):
if isinstance(std, Variable):
......@@ -330,7 +304,7 @@ def normal(mean=0.0, std=1.0, shape=None, name=None):
mean = float(mean)
out = standard_normal(paddle.shape(std), std.dtype, name)
else:
return gaussian_random(shape=shape, mean=mean, std=std, name=name)
return gaussian(shape=shape, mean=mean, std=std, name=name)
out = out * std + mean
if not in_dygraph_mode():
......@@ -426,7 +400,7 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
dtype = convert_np_dtype_to_dtype_(dtype)
if in_dygraph_mode():
shape = utils._convert_shape_to_list(shape)
shape = utils.convert_shape_to_list(shape)
return core.ops.uniform_random('shape', shape, 'min',
float(min), 'max',
float(max), 'seed', seed, 'dtype', dtype)
......@@ -436,7 +410,7 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
inputs = dict()
attrs = {'seed': seed, 'min': min, 'max': max, 'dtype': dtype}
utils._get_shape_tensor_inputs(
utils.get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=shape, op_type='uniform_random/rand')
helper = LayerHelper("uniform_random", **locals())
......@@ -449,29 +423,26 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
def randint(low=0, high=None, shape=[1], dtype=None, name=None):
"""
:alias_main: paddle.randint
:alias: paddle.tensor.randint, paddle.tensor.random.randint
This OP returns a Tensor filled with random integers from a discrete uniform
distribution in the range [``low``, ``high``), with ``shape`` and ``dtype``.
If ``high`` is None (the default), the range is [0, ``low``).
Args:
low(int): The lower bound on the range of random values to generate.
low (int): The lower bound on the range of random values to generate.
The ``low`` is included in the range. If ``high`` is None, the
range is [0, ``low``). Default is 0.
high(int, optional): The upper bound on the range of random values to
high (int, optional): The upper bound on the range of random values to
generate, the ``high`` is excluded in the range. Default is None
(see above for behavior if high = None). Default is None.
shape(list|tuple|Tensor): The shape of the output Tensor. If ``shape``
shape (list|tuple|Tensor): The shape of the output Tensor. If ``shape``
is a list or tuple, the elements of it should be integers or Tensors
(with the shape [1], and the data type int32 or int64). If ``shape``
is a Tensor, it should be a 1-D Tensor(with the data type int32 or
int64). Default is [1].
dtype(str|np.dtype, optional): The data type of the
dtype (str|np.dtype, optional): The data type of the
output tensor. Supported data types: int32, int64. If ``dytpe``
is None, the data type is int64. Default is None.
name(str, optional): The default value is None. Normally there is no
name (str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
......@@ -479,12 +450,6 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
Tensor: A Tensor filled with random integers from a discrete uniform
distribution in the range [``low``, ``high``), with ``shape`` and ``dtype``.
Raises:
TypeError: If ``shape`` is not list, tuple, Tensor.
TypeError: If ``dtype`` is not int32, int64.
ValueError: If ``high`` is not greater then ``low``; If ``high`` is
None, and ``low`` is not greater than 0.
Examples:
.. code-block:: python
......@@ -495,32 +460,32 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
# example 1:
# attr shape is a list which doesn't contain Tensor.
result_1 = paddle.randint(low=-5, high=5, shape=[3])
out1 = paddle.randint(low=-5, high=5, shape=[3])
# [0, -3, 2] # random
# example 2:
# attr shape is a list which contains Tensor.
dim_1 = paddle.fill_constant([1], "int64", 2)
dim_2 = paddle.fill_constant([1], "int32", 3)
result_2 = paddle.randint(low=-5, high=5, shape=[dim_1, dim_2], dtype="int32")
dim1 = paddle.full([1], 2, "int64")
dim2 = paddle.full([1], 3, "int32")
out2 = paddle.randint(low=-5, high=5, shape=[dim1, dim2], dtype="int32")
# [[0, -1, -3], # random
# [4, -2, 0]] # random
# example 3:
# attr shape is a Tensor
var_shape = paddle.to_variable(np.array([3]))
result_3 = paddle.randint(low=-5, high=5, shape=var_shape)
shape_tensor = paddle.to_tensor(np.array([3]))
out3 = paddle.randint(low=-5, high=5, shape=shape_tensor)
# [-2, 2, 3] # random
# example 4:
# data type is int32
result_4 = paddle.randint(low=-5, high=5, shape=[3], dtype='int32')
out4 = paddle.randint(low=-5, high=5, shape=[3], dtype='int32')
# [-5, 4, -4] # random
# example 5:
# Input only one parameter
# low=0, high=10, shape=[1], dtype='int64'
result_5 = paddle.randint(10)
out5 = paddle.randint(10)
# [7] # random
"""
......@@ -537,11 +502,11 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
dtype = convert_np_dtype_to_dtype_(dtype)
if in_dygraph_mode():
shape = utils._convert_shape_to_list(shape)
shape = utils.convert_shape_to_list(shape)
return core.ops.randint('shape', shape, 'low', low, 'high', high,
'seed', 0, 'dtype', dtype)
check_type(shape, 'shape', (list, tuple, Variable), 'randint')
check_shape(shape, 'randint')
check_dtype(dtype, 'dtype', ['int32', 'int64'], 'randint')
if low >= high:
raise ValueError(
......@@ -550,7 +515,7 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
inputs = dict()
attrs = {'low': low, 'high': high, 'seed': 0, 'dtype': dtype}
utils._get_shape_tensor_inputs(
utils.get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=shape, op_type='randint')
helper = LayerHelper("randint", **locals())
......@@ -560,21 +525,17 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
return out
@templatedoc()
def randperm(n, dtype="int64", name=None):
"""
:alias_main: paddle.randperm
:alias: paddle.tensor.randperm, paddle.tensor.random.randperm
This OP returns a 1-D Tensor filled with random permutation values from 0
to n-1, with ``dtype``.
Args:
n(int): The upper bound (exclusive), and it should be greater than 0.
dtype(str|np.dtype, optional): The data type of
n (int): The upper bound (exclusive), and it should be greater than 0.
dtype (str|np.dtype, optional): The data type of
the output Tensor. Supported data types: int32, int64, float32,
float64. Default is int64.
name(str, optional): The default value is None. Normally there is no
name (str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
......@@ -582,10 +543,6 @@ def randperm(n, dtype="int64", name=None):
Tensor: A 1-D Tensor filled with random permutation values from 0
to n-1, with ``dtype``.
Raises:
ValueError: If ``n`` is not greater than 0.
TypeError: If ``dtype`` is not int32, int64, float32, float64.
Examples:
.. code-block:: python
......@@ -593,10 +550,10 @@ def randperm(n, dtype="int64", name=None):
paddle.disable_static()
result_1 = paddle.randperm(5)
out1 = paddle.randperm(5)
# [4, 1, 2, 3, 0] # random
result_2 = paddle.randperm(7, 'int32')
out2 = paddle.randperm(7, 'int32')
# [1, 6, 2, 0, 4, 3, 5] # random
"""
......@@ -622,32 +579,20 @@ def randperm(n, dtype="int64", name=None):
def rand(shape, dtype=None, name=None):
"""
:alias_main: paddle.rand
:alias: paddle.tensor.rand, paddle.tensor.random.rand
This OP returns a Tensor filled with random values sampled from a uniform
distribution in the range [0, 1), with ``shape`` and ``dtype``.
Examples:
::
Input:
shape = [1, 2]
Output:
result=[[0.8505902, 0.8397286]]
Args:
shape(list|tuple|Tensor): The shape of the output Tensor. If ``shape``
shape (list|tuple|Tensor): The shape of the output Tensor. If ``shape``
is a list or tuple, the elements of it should be integers or Tensors
(with the shape [1], and the data type int32 or int64). If ``shape``
is a Tensor, it should be a 1-D Tensor(with the data type int32 or
int64).
dtype(str|np.dtype, optional): The data type of the output Tensor.
dtype (str|np.dtype, optional): The data type of the output Tensor.
Supported data types: float32, float64.
Default is None, use global default dtype (see ``get_default_dtype``
for details).
name(str, optional): The default value is None. Normally there is no
name (str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
......@@ -655,10 +600,6 @@ def rand(shape, dtype=None, name=None):
Tensor: A Tensor filled with random values sampled from a uniform
distribution in the range [0, 1), with ``shape`` and ``dtype``.
Raises:
TypeError: If ``shape`` is not list, tuple, Tensor.
ValueError: If ``dtype`` is not float32, float64.
Examples:
.. code-block:: python
......@@ -667,14 +608,14 @@ def rand(shape, dtype=None, name=None):
paddle.disable_static()
# example 1: attr shape is a list which doesn't contain Tensor.
result_1 = paddle.rand(shape=[2, 3])
out1 = paddle.rand(shape=[2, 3])
# [[0.451152 , 0.55825245, 0.403311 ], # random
# [0.22550228, 0.22106001, 0.7877319 ]] # random
# example 2: attr shape is a list which contains Tensor.
dim_1 = paddle.fill_constant([1], "int64", 2)
dim_2 = paddle.fill_constant([1], "int32", 3)
result_2 = paddle.rand(shape=[dim_1, dim_2, 2])
dim1 = paddle.full([1], 2, "int64")
dim2 = paddle.full([1], 3, "int32")
out2 = paddle.rand(shape=[dim1, dim2, 2])
# [[[0.8879919 , 0.25788337], # random
# [0.28826773, 0.9712097 ], # random
# [0.26438272, 0.01796806]], # random
......@@ -683,19 +624,10 @@ def rand(shape, dtype=None, name=None):
# [0.870881 , 0.2984597 ]]] # random
# example 3: attr shape is a Tensor, the data type must be int64 or int32.
var_shape = paddle.to_variable(np.array([2, 3]))
result_3 = paddle.rand(var_shape)
shape_tensor = paddle.to_tensor(np.array([2, 3]))
out2 = paddle.rand(shape_tensor)
# [[0.22920267, 0.841956 , 0.05981819], # random
# [0.4836288 , 0.24573246, 0.7516129 ]] # random
"""
if dtype is None:
dtype = paddle.framework.get_default_dtype()
if dtype not in ['float32', 'float64']:
raise TypeError(
"rand only supports [float32, float64], but the default dtype is %s"
% dtype)
out = uniform(shape, dtype, min=0.0, max=1.0, name=name)
out.stop_gradient = True
return out
return uniform(shape, dtype, min=0.0, max=1.0, name=name)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册