未验证 提交 e0be4b94 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

[Zero-Dim] support input 0D Tensor as scalar attribute for some api (#47689)

* [Zero-Dim] support input 0D Tensor as scalar attribute for some api

* fix doc
上级 1a145aab
......@@ -259,48 +259,24 @@ void ArangeInferMeta(const MetaTensor& start,
const MetaTensor& end,
const MetaTensor& step,
MetaTensor* out) {
auto start_dims = start.dims();
auto end_dims = end.dims();
auto step_dims = step.dims();
PADDLE_ENFORCE_EQ(
start_dims.size(),
1,
phi::errors::InvalidArgument(
"The dim of the shape of Input(Start) should be 1, but got %d",
start_dims.size()));
PADDLE_ENFORCE_EQ(start_dims[0],
PADDLE_ENFORCE_EQ(phi::product(start.dims()),
1,
phi::errors::InvalidArgument(
"The first dim of the shape of Input(Start) should "
"be 1, but got %d",
start_dims[0]));
PADDLE_ENFORCE_EQ(
end_dims.size(),
1,
phi::errors::InvalidArgument(
"The dim of the shape of Input(End) should be 1, but got %d",
end_dims.size()));
"The numel of Input(start) should be 1, but got %d",
phi::product(start.dims())));
PADDLE_ENFORCE_EQ(
end_dims[0],
1,
phi::errors::InvalidArgument("The first dim of the shape of "
"Input(End) should be 1, but got %d",
end_dims[0]));
PADDLE_ENFORCE_EQ(
step_dims.size(),
1,
phi::errors::InvalidArgument(
"The dim of the shape of Input(Step) should be 1, but got %d",
step_dims.size()));
PADDLE_ENFORCE_EQ(phi::product(end.dims()),
1,
phi::errors::InvalidArgument(
"The numel of Input(end) should be 1, but got %d",
phi::product(end.dims())));
PADDLE_ENFORCE_EQ(step_dims[0],
PADDLE_ENFORCE_EQ(phi::product(step.dims()),
1,
phi::errors::InvalidArgument(
"The first dim of the shape of Input(Step) should "
"be 1, but got %d",
step_dims[0]));
"The numel of Input(step) should be 1, but got %d",
phi::product(step.dims())));
out->set_dims({-1});
out->set_dtype(start.dtype());
}
......@@ -635,27 +611,27 @@ void LinspaceRawInferMeta(const MetaTensor& start,
const MetaTensor& stop,
const MetaTensor& number,
MetaTensor* out) {
auto s_dims = start.dims();
PADDLE_ENFORCE_EQ(
(s_dims.size() == 1) && (s_dims[0] == 1),
true,
phi::errors::InvalidArgument("The shape of Input(Start) must be [1],"
"but received input shape is [%s].",
s_dims));
auto e_dims = stop.dims();
phi::product(start.dims()),
1,
phi::errors::InvalidArgument("The size of Input(start) should be 1,"
"but got %d.",
phi::product(start.dims())));
PADDLE_ENFORCE_EQ(
(e_dims.size() == 1) && (e_dims[0] == 1),
true,
phi::errors::InvalidArgument("The shape of Input(Stop) must be [1],"
"but received input shape is [%s].",
e_dims));
auto step_dims = number.dims();
phi::product(stop.dims()),
1,
phi::errors::InvalidArgument("The size of Input(stop) should be 1,"
"but got %d.",
phi::product(stop.dims())));
PADDLE_ENFORCE_EQ(
(step_dims.size() == 1) && (step_dims[0] == 1),
true,
phi::errors::InvalidArgument("The shape of Input(Num) must be [1],"
"but received input shape is [%s].",
step_dims));
phi::product(number.dims()),
1,
phi::errors::InvalidArgument("The size of Input(number) should be 1,"
"but got %d.",
phi::product(number.dims())));
out->set_dims(phi::make_ddim({-1}));
out->set_dtype(start.dtype());
}
......
......@@ -918,17 +918,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
if force_cpu:
place = core.CPUPlace()
if isinstance(shape, (list, tuple)):
for item in shape:
if not isinstance(item, Variable):
shape = list(
map(
lambda x: x.numpy().flat[0]
if isinstance(x, Variable)
else x,
shape,
)
)
break
shape = utils.convert_shape_to_list(shape)
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
......
......@@ -498,5 +498,304 @@ class TestSundryAPI(unittest.TestCase):
np.testing.assert_array_equal(out.numpy(), np.array([]))
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.shape = [
paddle.full([], 2, 'int32'),
paddle.full([], 3, 'int32'),
paddle.full([], 4, 'int32'),
]
def test_slice(self):
starts = [paddle.full([], 1, 'int32'), paddle.full([], 1, 'int32')]
ends = [paddle.full([], 3, 'int32'), paddle.full([], 3, 'int32')]
x = paddle.rand([5, 3, 3])
out = paddle.slice(x, [1, 2], starts, ends)
self.assertEqual(out.shape, [5, 2, 2])
def test_strided_slice(self):
starts = [paddle.full([], 0, 'int32'), paddle.full([], 0, 'int32')]
ends = [paddle.full([], 4, 'int32'), paddle.full([], 4, 'int32')]
strides = [paddle.full([], 2, 'int32'), paddle.full([], 2, 'int32')]
x = paddle.rand([5, 5, 5])
out = paddle.strided_slice(x, [1, 2], starts, ends, strides)
self.assertEqual(out.shape, [5, 2, 2])
def test_linspace(self):
start = paddle.full([], 1.0)
stop = paddle.full([], 5.0)
num = paddle.full([], 5, 'int32')
out = paddle.linspace(start, stop, num)
np.testing.assert_array_equal(out.numpy(), [1.0, 2.0, 3.0, 4.0, 5.0])
def test_arange(self):
start = paddle.full([], 1.0)
stop = paddle.full([], 6.0)
step = paddle.full([], 1.0)
out = paddle.arange(start, stop, step)
np.testing.assert_array_equal(out.numpy(), [1.0, 2.0, 3.0, 4.0, 5.0])
def test_normal(self):
mean = paddle.full([], 0.0)
std = paddle.full([], 0.0)
out = paddle.normal(mean, std)
self.assertEqual(out.shape, [])
out = paddle.normal(0.0, 1.0, [])
self.assertEqual(out.shape, [])
out = paddle.normal(0.0, 1.0, self.shape)
self.assertEqual(out.shape, [2, 3, 4])
def test_rand(self):
out = paddle.rand([])
self.assertEqual(out.shape, [])
out = paddle.rand(self.shape)
self.assertEqual(out.shape, [2, 3, 4])
def test_randn(self):
out = paddle.randn([])
self.assertEqual(out.shape, [])
out = paddle.randn(self.shape)
self.assertEqual(out.shape, [2, 3, 4])
def test_randint_and_randint_like(self):
out = paddle.randint(-10, 10, [])
self.assertEqual(out.shape, [])
out = paddle.randint_like(out, -10, 10)
self.assertEqual(out.shape, [])
out = paddle.randint(-10, 10, self.shape)
self.assertEqual(out.shape, [2, 3, 4])
def test_standard_normal(self):
out = paddle.standard_normal([])
self.assertEqual(out.shape, [])
out = paddle.standard_normal(self.shape)
self.assertEqual(out.shape, [2, 3, 4])
def test_uniform(self):
out = paddle.uniform([])
self.assertEqual(out.shape, [])
out = paddle.uniform(self.shape)
self.assertEqual(out.shape, [2, 3, 4])
def test_empty_and_empty_like(self):
out = paddle.empty([])
self.assertEqual(out.shape, [])
out = paddle.empty_like(out)
self.assertEqual(out.shape, [])
out = paddle.empty(self.shape)
self.assertEqual(out.shape, [2, 3, 4])
def test_full_and_full_like(self):
out = paddle.full([], 0.5)
self.assertEqual(out.shape, [])
out = paddle.full_like(out, 0.5)
self.assertEqual(out.shape, [])
out = paddle.full(self.shape, 0.5)
self.assertEqual(out.shape, [2, 3, 4])
def test_ones_and_ones_like(self):
out = paddle.ones([])
self.assertEqual(out.shape, [])
out = paddle.ones_like(out)
self.assertEqual(out.shape, [])
out = paddle.ones(self.shape)
self.assertEqual(out.shape, [2, 3, 4])
def test_zeros_and_zeros_like(self):
out = paddle.zeros([])
self.assertEqual(out.shape, [])
out = paddle.zeros_like(out)
self.assertEqual(out.shape, [])
out = paddle.zeros(self.shape)
self.assertEqual(out.shape, [2, 3, 4])
class TestNoBackwardAPIStatic(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.exe = paddle.static.Executor()
self.shape = [
paddle.full([], 2, 'int32'),
paddle.full([], 3, 'int32'),
paddle.full([], 4, 'int32'),
]
def test_slice(self):
starts = [paddle.full([], 1, 'int32'), paddle.full([], 1, 'int32')]
ends = [paddle.full([], 3, 'int32'), paddle.full([], 3, 'int32')]
x = paddle.rand([5, 3, 3])
out = paddle.slice(x, [1, 2], starts, ends)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out]
)[0]
self.assertEqual(res.shape, (5, 2, 2))
def test_strided_slice(self):
starts = [paddle.full([], 0, 'int32'), paddle.full([], 0, 'int32')]
ends = [paddle.full([], 4, 'int32'), paddle.full([], 4, 'int32')]
strides = [paddle.full([], 2, 'int32'), paddle.full([], 2, 'int32')]
x = paddle.rand([5, 5, 5])
out = paddle.strided_slice(x, [1, 2], starts, ends, strides)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out]
)[0]
self.assertEqual(res.shape, (5, 2, 2))
def test_linspace(self):
start = paddle.full([], 1.0)
stop = paddle.full([], 5.0)
num = paddle.full([], 5, 'int32')
out = paddle.linspace(start, stop, num)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out]
)[0]
np.testing.assert_array_equal(res, [1.0, 2.0, 3.0, 4.0, 5.0])
def test_arange(self):
start = paddle.full([], 1.0)
stop = paddle.full([], 6.0)
step = paddle.full([], 1.0)
out = paddle.arange(start, stop, step)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out]
)[0]
np.testing.assert_array_equal(res, [1.0, 2.0, 3.0, 4.0, 5.0])
def test_normal(self):
mean = paddle.full([], 0.0)
std = paddle.full([], 0.0)
out1 = paddle.normal(mean, std)
out2 = paddle.normal(0.0, 1.0, [])
out3 = paddle.normal(0.0, 1.0, self.shape)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2, out3]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2, 3, 4))
def test_rand(self):
out1 = paddle.rand([])
out2 = paddle.rand(self.shape)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2, 3, 4))
def test_randn(self):
out1 = paddle.randn([])
out2 = paddle.randn(self.shape)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2, 3, 4))
def test_randint_and_randint_like(self):
out1 = paddle.randint(-10, 10, [])
out2 = paddle.randint_like(out1, -10, 10)
out3 = paddle.randint(-10, 10, self.shape)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2, out3]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2, 3, 4))
def test_standard_normal(self):
out1 = paddle.standard_normal([])
out2 = paddle.standard_normal(self.shape)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2, 3, 4))
def test_uniform(self):
out1 = paddle.uniform([])
out2 = paddle.uniform(self.shape)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2, 3, 4))
def test_empty_and_empty_like(self):
out1 = paddle.empty([])
out2 = paddle.empty_like(out1)
out3 = paddle.empty(self.shape)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2, out3]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2, 3, 4))
def test_full_and_full_like(self):
out1 = paddle.full([], 0.5)
out2 = paddle.full_like(out1, 0.5)
out3 = paddle.full(self.shape, 0.5)
out4 = paddle.full(self.shape, paddle.full([], 0.5))
res = self.exe.run(
paddle.static.default_main_program(),
fetch_list=[out1, out2, out3, out4],
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2, 3, 4))
self.assertEqual(res[3].shape, (2, 3, 4))
def test_ones_and_ones_like(self):
out1 = paddle.ones([])
out2 = paddle.ones_like(out1)
out3 = paddle.ones(self.shape)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2, out3]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2, 3, 4))
def test_zeros_and_zeros_like(self):
out1 = paddle.zeros([])
out2 = paddle.zeros_like(out1)
out3 = paddle.zeros(self.shape)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2, out3]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2, 3, 4))
if __name__ == "__main__":
unittest.main()
......@@ -65,12 +65,12 @@ def linspace(start, stop, num, dtype=None, name=None):
Return fixed number of evenly spaced values within a given interval.
Args:
start(int|float|Tensor): The input :attr:`start` is start variable of range. It is a scalar, \
or a Tensor of shape [1] with input data type int32, int64, float32 or float64.
stop(int|float|Tensor): The input :attr:`stop` is start variable of range. It is a scalar, \
or a Tensor of shape [1] with input data type int32, int64, float32 or float64.
num(int|Tensor): The input :attr:`num` is given num of the sequence. It is an int scalar, \
or a Tensor of shape [1] with data type int32.
start(int|float|Tensor): The input :attr:`start` is start of range. It is a int, float, \
or a 0-D Tensor with data type int32, int64, float32 or float64.
stop(int|float|Tensor): The input :attr:`stop` is start variable of range. It is a int, float, \
or a 0-D Tensor with data type int32, int64, float32 or float64.
num(int|Tensor): The input :attr:`num` is given num of the sequence. It is an int, \
or a 0-D Tensor with data type int32.
dtype(np.dtype|str, optional): The data type of output tensor, it could be
int32, int64, float32 and float64. Default: if None, the data type is float32.
name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
......@@ -620,7 +620,9 @@ def ones(shape, dtype=None, name=None):
Create a Tensor of specified :attr:`shape` and :attr:`dtype` and fill it with 1.
Args:
shape (tuple|list|Tensor): Shape of the Tensor to be created, the data type of shape should be int32 or int64.
shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
If ``shape`` is a list or tuple, the elements of it should be integers or 0-D Tensor with shape [].
If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
dtype (np.dtype|str, optional): Data type of output Tensor, it should be one of
bool, float16, float32, float64, int32 and int64. If it is set to None, the data type will be float32.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
......@@ -633,21 +635,25 @@ def ones(shape, dtype=None, name=None):
import paddle
# default dtype for ones OP
# shape is a list/tuple
data1 = paddle.ones(shape=[3, 2])
# [[1. 1.]
# [1. 1.]
# [1. 1.]]
data2 = paddle.ones(shape=[2, 2], dtype='int32')
# [[1 1]
# [1 1]]
# shape is a Tensor
shape = paddle.full(shape=[2], dtype='int32', fill_value=2)
data3 = paddle.ones(shape=shape, dtype='int32')
# [[1 1]
# [1 1]]
shape = paddle.to_tensor([3, 2])
data2 = paddle.ones(shape=shape)
# [[1. 1.]
# [1. 1.]
# [1. 1.]]
# shape is a Tensor List
shape = [paddle.to_tensor(3), paddle.to_tensor(2)]
data3 = paddle.ones(shape=shape)
# [[1. 1.]
# [1. 1.]
# [1. 1.]]
"""
if dtype is None:
dtype = 'float32'
......@@ -690,7 +696,9 @@ def zeros(shape, dtype=None, name=None):
Creates a tensor of specified :attr:`shape` and :attr:`dtype`, and fills it with 0.
Args:
shape(tuple|list|Tensor): Shape of the Tensor to be created, the data type of ``shape`` is int32 or int64.
shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
dtype(np.dtype|str, optional): Data type of output Tensor, it supports
bool, float16, float32, float64, int32 and int64. Default: if None, the date type is float32.
name(str, optional): The default value is None. Normally there is no need for user to set this
......@@ -702,21 +710,27 @@ def zeros(shape, dtype=None, name=None):
Examples:
.. code-block:: python
import paddle
import paddle
data = paddle.zeros(shape=[3, 2], dtype='float32')
# [[0. 0.]
# [0. 0.]
# [0. 0.]]
data = paddle.zeros(shape=[2, 2])
# [[0. 0.]
# [0. 0.]]
# shape is a Tensor
shape = paddle.full(shape=[2], dtype='int32', fill_value=2)
data3 = paddle.zeros(shape=shape, dtype='int32')
# [[0 0]
# [0 0]]
# shape is a list/tuple
data1 = paddle.zeros(shape=[3, 2])
# [[0. 0.]
# [0. 0.]
# [0. 0.]]
# shape is a Tensor
shape = paddle.to_tensor([3, 2])
data2 = paddle.zeros(shape=shape)
# [[0. 0.]
# [0. 0.]
# [0. 0.]]
# shape is a Tensor List
shape = [paddle.to_tensor(3), paddle.to_tensor(2)]
data3 = paddle.zeros(shape=shape)
# [[0. 0.]
# [0. 0.]
# [0. 0.]]
"""
if dtype is None:
dtype = 'float32'
......@@ -844,12 +858,11 @@ def full(shape, fill_value, dtype=None, name=None):
Return a Tensor with the ``fill_value`` which size is same as ``shape``.
Args:
shape(list|tuple|Tensor): Shape of the Tensor to be created.
The data type is ``int32`` or ``int64`` . If ``shape`` is a list or tuple,
the elements of it should be integers or Tensors with shape [1].
If ``shape`` is an Tensor, it should be an 1-D Tensor.
fill_value(bool|float|int|Tensor): The constant value
used to initialize the Tensor to be created. If ``fill_value`` is an Tensor, it must be an 1-D Tensor.
shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
fill_value(bool|float|int|Tensor): The constant value used to initialize the Tensor to be created.
If ``fill_value`` is an Tensor, it shoule be an 0-D Tensor which represents a scalar.
dtype(np.dtype|str, optional): Data type of the output Tensor
which can be float16, float32, float64, int32, int64, if dytpe is `None`, the data
type of created Tensor is `float32`.
......@@ -863,26 +876,32 @@ def full(shape, fill_value, dtype=None, name=None):
import paddle
data1 = paddle.full(shape=[2,1], fill_value=0, dtype='int64')
#[[0]
# [0]]
# attr shape is a list which contains Tensor.
positive_2 = paddle.full([1], 2, "int32")
data3 = paddle.full(shape=[1, positive_2], dtype='float32', fill_value=1.5)
# [[1.5 1.5]]
# attr shape is a Tensor.
shape = paddle.full([2], 2, "int32")
data4 = paddle.full(shape=shape, dtype='bool', fill_value=True)
# [[True True]
# [True True]]
# attr fill_value is a Tensor.
val = paddle.full([1], 2.0, "float32")
data5 = paddle.full(shape=[2,1], fill_value=val, dtype='float32')
# [[2.0]
# [2.0]]
# shape is a list/tuple
data1 = paddle.full(shape=[3, 2], fill_value=1.)
# [[1. 1.]
# [1. 1.]
# [1. 1.]]
# shape is a Tensor
shape = paddle.to_tensor([3, 2])
data2 = paddle.full(shape=shape, fill_value=2.)
# [[2. 2.]
# [2. 2.]
# [2. 2.]]
# shape is a Tensor List
shape = [paddle.to_tensor(3), paddle.to_tensor(2)]
data3 = paddle.full(shape=shape, fill_value=3.)
# [[3. 3.]
# [3. 3.]
# [3. 3.]]
# fill_value is a Tensor.
val = paddle.full([], 2.0, "float32")
data5 = paddle.full(shape=[3, 2], fill_value=val)
# [[2. 2.]
# [2. 2.]
# [2. 2.]]
"""
if dtype is None:
......@@ -904,16 +923,17 @@ def arange(start=0, end=None, step=1, dtype=None, name=None):
Parameters:
start(float|int|Tensor): Start of interval. The interval includes this
value. If ``end`` is None, the half-open interval is [0, ``start``).
If ``start`` is a Tensor, it is a 1-D Tensor with shape [1], with
data type int32, int64, float32, float64. Default is 0.
If ``start`` is a Tensor, it is a 0-D Tensor which represents a scalar
and data type is int32, int64, float32, float64. Default is 0.
end(float|int|Tensor, optional): End of interval. The interval does not
include this value. If ``end`` is a Tensor, it is a 1-D Tensor with
shape [1], with data type int32, int64, float32, float64. If ``end``
is None, the half-open interval is [0, ``start``). Default is None.
include this value. If ``end`` is a Tensor, it is a 0-D Tensor which
represents a scalar and data type is int32, int64, float32, float64.
If ``end`` is None, the half-open interval is [0, ``start``).
Default is None.
step(float|int|Tensor, optional): Spacing between values. For any out,
it is the istance between two adjacent values, out[i+1] - out[i].
If ``step`` is a Tensor, it is a 1-D Tensor with shape [1], with
data type int32, int64, float32, float64. Default is 1.
If ``step`` is a Tensor, it is a 0-D Tensor which represents a scalar
and data type is int32, int64, float32, float64. . Default is 1.
dtype(str|np.dtype, optional): The data type of the
output tensor. Supported data types: int32, int64, float32, float64.
If ``dytpe`` is None, the data type is float32. Default is None.
......@@ -939,7 +959,7 @@ def arange(start=0, end=None, step=1, dtype=None, name=None):
out3 = paddle.arange(4.999, dtype='float32')
# [0., 1., 2., 3., 4.]
start_var = paddle.to_tensor([3])
start_var = paddle.to_tensor(3)
out4 = paddle.arange(start_var, 7)
# [3, 4, 5, 6]
......@@ -1501,10 +1521,9 @@ def empty(shape, dtype=None, name=None):
Returns a Tensor with uninitialized data which size is same as ``shape``.
Args:
shape(list|tuple|Tensor): Shape of the Tensor to be created.
The data type of dimension of shape is ``int32`` or ``int64`` . If ``shape`` is a list or tuple,
the elements of it should be integers or Tensors with shape [1].
If ``shape`` is an Tensor, it should be an 1-D Tensor.
shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
dtype(np.dtype|str, optional): Data type of the output Tensor
which can be bool, float16, float32, float64, int32, int64, if dytpe is `None`, the data
type of created Tensor use global default dtype (see ``get_default_dtype``
......@@ -1519,30 +1538,25 @@ def empty(shape, dtype=None, name=None):
import paddle
paddle.set_device("cpu") # and use cpu device
# example 1: argument ``shape`` is a list which doesn't contain Tensor.
data1 = paddle.empty(shape=[2, 3], dtype='float32')
print(data1)
# Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
# [[0.00000000, 0. , 0.00000000],
# [0. , 0.29652897, 0.09356152]]) # uninitialized
# shape is a list/tuple
data1 = paddle.empty(shape=[3, 2])
# [[1. 1.]
# [1. 1.]
# [1. 1.]]
# example 2: argument ``shape`` is a Tensor, the data type must be int64 or int32.
shape_data = paddle.to_tensor([2, 3]).astype('int32')
data2 = paddle.empty(shape=shape_data, dtype='float32')
print(data2)
# Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
# [[-0.50543123, -0.09872390, -0.92634487],
# [-0.51007903, -0.02454148, 1.29315734]]) # uninitialized
# shape is a Tensor
shape = paddle.to_tensor([3, 2])
data2 = paddle.empty(shape=shape)
# [[1. 1.]
# [1. 1.]
# [1. 1.]]
# example 3: argument ``shape`` is a list which contains Tensor.
dim2 = paddle.to_tensor([3]).astype('int32')
data3 = paddle.empty(shape=[2, dim2], dtype='float32')
print(data3)
# Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
# [[ 0.00000000, 0. , -0.92634487],
# [-0.51007903, -0.02454148, 1.29315734]]) # uninitialized
# shape is a Tensor List
shape = [paddle.to_tensor(3), paddle.to_tensor(2)]
data3 = paddle.empty(shape=shape)
# [[1. 1.]
# [1. 1.]
# [1. 1.]]
"""
if dtype is None:
......
......@@ -221,11 +221,9 @@ def gaussian(shape, mean=0.0, std=1.0, dtype=None, name=None):
distribution, with ``shape`` and ``dtype``.
Args:
shape (list|tuple|Tensor): The shape of the output Tensor. If ``shape``
is a list or tuple, the elements of it should be integers or Tensors
(with the shape [1], and the data type int32 or int64). If ``shape``
is a Tensor, it should be a 1-D Tensor(with the data type int32 or
int64).
shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
mean (float|int, optional): Mean of the output tensor, default is 0.0.
std (float|int, optional): Standard deviation of the output tensor, default
is 1.0.
......@@ -307,11 +305,9 @@ def standard_normal(shape, dtype=None, name=None):
and ``dtype``.
Args:
shape (list|tuple|Tensor): The shape of the output Tensor. If ``shape``
is a list or tuple, the elements of it should be integers or Tensors
(with the shape [1], and the data type int32 or int64). If ``shape``
is a Tensor, it should be a 1-D Tensor(with the data type int32 or
int64).
shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
dtype (str|np.dtype, optional): The data type of the output Tensor.
Supported data types: float32, float64.
Default is None, use global default dtype (see ``get_default_dtype``
......@@ -335,8 +331,8 @@ def standard_normal(shape, dtype=None, name=None):
# [ 0.39632758, 0.08177969, 0.2692008 ]] # random
# example 2: attr shape is a list which contains Tensor.
dim1 = paddle.to_tensor([2], 'int64')
dim2 = paddle.to_tensor([3], 'int32')
dim1 = paddle.to_tensor(2, 'int64')
dim2 = paddle.to_tensor(3, 'int32')
out2 = paddle.standard_normal(shape=[dim1, dim2, 2])
# [[[-2.8852394 , -0.25898588], # random
# [-0.47420555, 0.17683524], # random
......@@ -362,11 +358,9 @@ def randn(shape, dtype=None, name=None):
and ``dtype``.
Args:
shape (list|tuple|Tensor): The shape of the output Tensor. If ``shape``
is a list or tuple, the elements of it should be integers or Tensors
(with the shape [1], and the data type int32 or int64). If ``shape``
is a Tensor, it should be a 1-D Tensor(with the data type int32 or
int64).
shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
dtype (str|np.dtype, optional): The data type of the output Tensor.
Supported data types: float32, float64.
Default is None, use global default dtype (see ``get_default_dtype``
......@@ -390,8 +384,8 @@ def randn(shape, dtype=None, name=None):
# [ 0.39632758, 0.08177969, 0.2692008 ]] # random
# example 2: attr shape is a list which contains Tensor.
dim1 = paddle.to_tensor([2], 'int64')
dim2 = paddle.to_tensor([3], 'int32')
dim1 = paddle.to_tensor(2, 'int64')
dim2 = paddle.to_tensor(3, 'int32')
out2 = paddle.randn(shape=[dim1, dim2, 2])
# [[[-2.8852394 , -0.25898588], # random
# [-0.47420555, 0.17683524], # random
......@@ -429,12 +423,10 @@ def normal(mean=0.0, std=1.0, shape=None, name=None):
If ``std`` is float, all elements of the output Tensor shared the same standard deviation.
If ``std`` is a Tensor(data type supports float32, float64), it has per-element standard deviations.
Defaule is 1.0
shape (list|tuple|Tensor, optional): The shape of the output Tensor. If ``shape``
is a list or tuple, the elements of it should be integers or Tensors
(with the shape [1], and the data type int32 or int64). If ``shape``
is a Tensor, it should be a 1-D Tensor(with the data type int32 or
int64). If ``mean`` or ``std`` is a Tensor, the shape of the output
Tensor is the same as ``mean`` or ``std`` , attr ``shape`` is ignored.
shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list. If ``mean`` or ``std``
is a Tensor, the shape of the output Tensor is the same as ``mean`` or ``std`` , attr ``shape`` is ignored.
Default is None
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
......@@ -518,11 +510,9 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
result=[[0.8505902, 0.8397286]]
Args:
shape(list|tuple|Tensor): The shape of the output Tensor. If ``shape``
is a list or tuple, the elements of it should be integers or Tensors
(with the shape [1], and the data type int32 or int64). If ``shape``
is a Tensor, it should be a 1-D Tensor(with the data type int32 or
int64).
shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
dtype(str|np.dtype, optional): The data type of the output Tensor.
Supported data types: float32, float64.
Default is None, use global default dtype (see ``get_default_dtype``
......@@ -557,8 +547,8 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
# example 2:
# attr shape is a list which contains Tensor.
dim1 = paddle.to_tensor([2], 'int64')
dim2 = paddle.to_tensor([3], 'int32')
dim1 = paddle.to_tensor(2, 'int64')
dim2 = paddle.to_tensor(3, 'int32')
out2 = paddle.uniform(shape=[dim1, dim2])
# [[-0.9951253, 0.30757582, 0.9899647 ], # random
# [ 0.5864527, 0.6607096, -0.8886161]] # random
......@@ -684,11 +674,9 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
high (int, optional): The upper bound on the range of random values to
generate, the ``high`` is excluded in the range. Default is None
(see above for behavior if high = None). Default is None.
shape (list|tuple|Tensor, optional): The shape of the output Tensor. If ``shape``
is a list or tuple, the elements of it should be integers or Tensors
(with the shape [1], and the data type int32 or int64). If ``shape``
is a Tensor, it should be a 1-D Tensor(with the data type int32 or
int64). Default is [1].
shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list. Default is [1].
dtype (str|np.dtype, optional): The data type of the
output tensor. Supported data types: int32, int64. If ``dytpe``
is None, the data type is int64. Default is None.
......@@ -707,22 +695,23 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
# example 1:
# attr shape is a list which doesn't contain Tensor.
out1 = paddle.randint(low=-5, high=5, shape=[3])
out1 = paddle.randint(low=-5, high=5, shape=[2, 3])
# [0, -3, 2] # random
# example 2:
# attr shape is a list which contains Tensor.
dim1 = paddle.to_tensor([2], 'int64')
dim2 = paddle.to_tensor([3], 'int32')
dim1 = paddle.to_tensor(2, 'int64')
dim2 = paddle.to_tensor(3, 'int32')
out2 = paddle.randint(low=-5, high=5, shape=[dim1, dim2])
# [[0, -1, -3], # random
# [4, -2, 0]] # random
# example 3:
# attr shape is a Tensor
shape_tensor = paddle.to_tensor(3)
shape_tensor = paddle.to_tensor([2, 3])
out3 = paddle.randint(low=-5, high=5, shape=shape_tensor)
# [-2, 2, 3] # random
# [[ 2, -3, -1], # random
# [-3, -2, 1]]) # random
# example 4:
# data type is int32
......@@ -1033,11 +1022,9 @@ def rand(shape, dtype=None, name=None):
distribution in the range [0, 1), with ``shape`` and ``dtype``.
Args:
shape (list|tuple|Tensor): The shape of the output Tensor. If ``shape``
is a list or tuple, the elements of it should be integers or Tensors
(with the shape [1], and the data type int32 or int64). If ``shape``
is a Tensor, it should be a 1-D Tensor(with the data type int32 or
int64).
shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
dtype (str|np.dtype, optional): The data type of the output Tensor.
Supported data types: float32, float64.
Default is None, use global default dtype (see ``get_default_dtype``
......@@ -1061,8 +1048,8 @@ def rand(shape, dtype=None, name=None):
# [0.22550228, 0.22106001, 0.7877319 ]] # random
# example 2: attr shape is a list which contains Tensor.
dim1 = paddle.to_tensor([2], 'int64')
dim2 = paddle.to_tensor([3], 'int32')
dim1 = paddle.to_tensor(2, 'int64')
dim2 = paddle.to_tensor(3, 'int32')
out2 = paddle.rand(shape=[dim1, dim2, 2])
# [[[0.8879919 , 0.25788337], # random
# [0.28826773, 0.9712097 ], # random
......@@ -1076,7 +1063,6 @@ def rand(shape, dtype=None, name=None):
out3 = paddle.rand(shape_tensor)
# [[0.22920267, 0.841956 , 0.05981819], # random
# [0.4836288 , 0.24573246, 0.7516129 ]] # random
"""
return uniform(shape, dtype, min=0.0, max=1.0, name=name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册