Unverified · Commit e0be4b94 · authored by zhouweiwei2014 · committed by GitHub

[Zero-Dim] support input 0D Tensor as scalar attribute for some api (#47689)

* [Zero-Dim] support input 0D Tensor as scalar attribute for some api

* fix doc
Parent: 1a145aab
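
For illustration, in dygraph mode the change allows 0-D Tensors (shape []) to be passed directly as scalar arguments and as shape elements, mirroring the tests added in this commit:

    import paddle

    # 0-D Tensors used as the scalar start/stop/step of arange
    start = paddle.full([], 1.0)
    stop = paddle.full([], 6.0)
    step = paddle.full([], 1.0)
    out = paddle.arange(start, stop, step)
    # [1., 2., 3., 4., 5.]

    # 0-D int Tensors used as the elements of a shape list
    shape = [paddle.full([], 2, 'int32'), paddle.full([], 3, 'int32')]
    out = paddle.ones(shape)
    # out.shape == [2, 3]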
@@ -259,48 +259,24 @@ void ArangeInferMeta(const MetaTensor& start,
                      const MetaTensor& end,
                      const MetaTensor& step,
                      MetaTensor* out) {
-  auto start_dims = start.dims();
-  auto end_dims = end.dims();
-  auto step_dims = step.dims();
-  PADDLE_ENFORCE_EQ(
-      start_dims.size(),
-      1,
-      phi::errors::InvalidArgument(
-          "The dim of the shape of Input(Start) should be 1, but got %d",
-          start_dims.size()));
-  PADDLE_ENFORCE_EQ(start_dims[0],
-                    1,
-                    phi::errors::InvalidArgument(
-                        "The first dim of the shape of Input(Start) should "
-                        "be 1, but got %d",
-                        start_dims[0]));
-  PADDLE_ENFORCE_EQ(
-      end_dims.size(),
-      1,
-      phi::errors::InvalidArgument(
-          "The dim of the shape of Input(End) should be 1, but got %d",
-          end_dims.size()));
-  PADDLE_ENFORCE_EQ(
-      end_dims[0],
-      1,
-      phi::errors::InvalidArgument("The first dim of the shape of "
-                                   "Input(End) should be 1, but got %d",
-                                   end_dims[0]));
-  PADDLE_ENFORCE_EQ(
-      step_dims.size(),
-      1,
-      phi::errors::InvalidArgument(
-          "The dim of the shape of Input(Step) should be 1, but got %d",
-          step_dims.size()));
-  PADDLE_ENFORCE_EQ(step_dims[0],
-                    1,
-                    phi::errors::InvalidArgument(
-                        "The first dim of the shape of Input(Step) should "
-                        "be 1, but got %d",
-                        step_dims[0]));
+  PADDLE_ENFORCE_EQ(phi::product(start.dims()),
+                    1,
+                    phi::errors::InvalidArgument(
+                        "The numel of Input(start) should be 1, but got %d",
+                        phi::product(start.dims())));
+  PADDLE_ENFORCE_EQ(phi::product(end.dims()),
+                    1,
+                    phi::errors::InvalidArgument(
+                        "The numel of Input(end) should be 1, but got %d",
+                        phi::product(end.dims())));
+  PADDLE_ENFORCE_EQ(phi::product(step.dims()),
+                    1,
+                    phi::errors::InvalidArgument(
+                        "The numel of Input(step) should be 1, but got %d",
+                        phi::product(step.dims())));
   out->set_dims({-1});
   out->set_dtype(start.dtype());
 }
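
The check moves from "rank 1 with first dim 1" to "numel equals 1", which admits both a 0-D Tensor (shape []) and a shape-[1] Tensor. A minimal sketch of why the product-of-dims test accepts the 0-D case (using numpy only as an analogy for phi::product):

    import numpy as np

    assert np.prod(()) == 1      # shape [] -> numel 1, now accepted
    assert np.prod((1,)) == 1    # shape [1] -> numel 1, still accepted
    assert np.prod((2,)) == 2    # shape [2] -> numel 2, rejected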
@@ -635,27 +611,27 @@ void LinspaceRawInferMeta(const MetaTensor& start,
                           const MetaTensor& stop,
                           const MetaTensor& number,
                           MetaTensor* out) {
-  auto s_dims = start.dims();
-  PADDLE_ENFORCE_EQ(
-      (s_dims.size() == 1) && (s_dims[0] == 1),
-      true,
-      phi::errors::InvalidArgument("The shape of Input(Start) must be [1],"
-                                   "but received input shape is [%s].",
-                                   s_dims));
-  auto e_dims = stop.dims();
-  PADDLE_ENFORCE_EQ(
-      (e_dims.size() == 1) && (e_dims[0] == 1),
-      true,
-      phi::errors::InvalidArgument("The shape of Input(Stop) must be [1],"
-                                   "but received input shape is [%s].",
-                                   e_dims));
-  auto step_dims = number.dims();
-  PADDLE_ENFORCE_EQ(
-      (step_dims.size() == 1) && (step_dims[0] == 1),
-      true,
-      phi::errors::InvalidArgument("The shape of Input(Num) must be [1],"
-                                   "but received input shape is [%s].",
-                                   step_dims));
+  PADDLE_ENFORCE_EQ(
+      phi::product(start.dims()),
+      1,
+      phi::errors::InvalidArgument("The size of Input(start) should be 1,"
+                                   "but got %d.",
+                                   phi::product(start.dims())));
+  PADDLE_ENFORCE_EQ(
+      phi::product(stop.dims()),
+      1,
+      phi::errors::InvalidArgument("The size of Input(stop) should be 1,"
+                                   "but got %d.",
+                                   phi::product(stop.dims())));
+  PADDLE_ENFORCE_EQ(
+      phi::product(number.dims()),
+      1,
+      phi::errors::InvalidArgument("The size of Input(number) should be 1,"
+                                   "but got %d.",
+                                   phi::product(number.dims())));
   out->set_dims(phi::make_ddim({-1}));
   out->set_dtype(start.dtype());
 }
...
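
With the same numel-based check, linspace accepts 0-D Tensors for start, stop and num, as exercised by the new test_linspace further below:

    import paddle

    start = paddle.full([], 1.0)
    stop = paddle.full([], 5.0)
    num = paddle.full([], 5, 'int32')
    out = paddle.linspace(start, stop, num)
    # [1., 2., 3., 4., 5.]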
@@ -918,17 +918,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
         if force_cpu:
             place = core.CPUPlace()
         if isinstance(shape, (list, tuple)):
-            for item in shape:
-                if not isinstance(item, Variable):
-                    shape = list(
-                        map(
-                            lambda x: x.numpy().flat[0]
-                            if isinstance(x, Variable)
-                            else x,
-                            shape,
-                        )
-                    )
-                    break
+            shape = utils.convert_shape_to_list(shape)

         if not isinstance(dtype, core.VarDesc.VarType):
             dtype = convert_np_dtype_to_dtype_(dtype)
...
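
fill_constant now delegates the mixed int/Tensor shape handling to utils.convert_shape_to_list. A rough, hypothetical stand-in for what the removed inline logic did (not the actual helper's implementation):

    import paddle

    # Hypothetical sketch: turn a shape list that may mix Python ints and
    # Tensors into a list of Python ints (dygraph only), reading each
    # Tensor's single element, as the removed lambda did with .numpy().flat[0].
    def _shape_to_int_list(shape):
        return [
            int(dim.numpy().flat[0]) if isinstance(dim, paddle.Tensor) else dim
            for dim in shape
        ]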
@@ -498,5 +498,304 @@ class TestSundryAPI(unittest.TestCase):
        np.testing.assert_array_equal(out.numpy(), np.array([]))

# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.shape = [
paddle.full([], 2, 'int32'),
paddle.full([], 3, 'int32'),
paddle.full([], 4, 'int32'),
]
def test_slice(self):
starts = [paddle.full([], 1, 'int32'), paddle.full([], 1, 'int32')]
ends = [paddle.full([], 3, 'int32'), paddle.full([], 3, 'int32')]
x = paddle.rand([5, 3, 3])
out = paddle.slice(x, [1, 2], starts, ends)
self.assertEqual(out.shape, [5, 2, 2])
def test_strided_slice(self):
starts = [paddle.full([], 0, 'int32'), paddle.full([], 0, 'int32')]
ends = [paddle.full([], 4, 'int32'), paddle.full([], 4, 'int32')]
strides = [paddle.full([], 2, 'int32'), paddle.full([], 2, 'int32')]
x = paddle.rand([5, 5, 5])
out = paddle.strided_slice(x, [1, 2], starts, ends, strides)
self.assertEqual(out.shape, [5, 2, 2])
def test_linspace(self):
start = paddle.full([], 1.0)
stop = paddle.full([], 5.0)
num = paddle.full([], 5, 'int32')
out = paddle.linspace(start, stop, num)
np.testing.assert_array_equal(out.numpy(), [1.0, 2.0, 3.0, 4.0, 5.0])
def test_arange(self):
start = paddle.full([], 1.0)
stop = paddle.full([], 6.0)
step = paddle.full([], 1.0)
out = paddle.arange(start, stop, step)
np.testing.assert_array_equal(out.numpy(), [1.0, 2.0, 3.0, 4.0, 5.0])
def test_normal(self):
mean = paddle.full([], 0.0)
std = paddle.full([], 0.0)
out = paddle.normal(mean, std)
self.assertEqual(out.shape, [])
out = paddle.normal(0.0, 1.0, [])
self.assertEqual(out.shape, [])
out = paddle.normal(0.0, 1.0, self.shape)
self.assertEqual(out.shape, [2, 3, 4])
def test_rand(self):
out = paddle.rand([])
self.assertEqual(out.shape, [])
out = paddle.rand(self.shape)
self.assertEqual(out.shape, [2, 3, 4])
def test_randn(self):
out = paddle.randn([])
self.assertEqual(out.shape, [])
out = paddle.randn(self.shape)
self.assertEqual(out.shape, [2, 3, 4])
def test_randint_and_randint_like(self):
out = paddle.randint(-10, 10, [])
self.assertEqual(out.shape, [])
out = paddle.randint_like(out, -10, 10)
self.assertEqual(out.shape, [])
out = paddle.randint(-10, 10, self.shape)
self.assertEqual(out.shape, [2, 3, 4])
def test_standard_normal(self):
out = paddle.standard_normal([])
self.assertEqual(out.shape, [])
out = paddle.standard_normal(self.shape)
self.assertEqual(out.shape, [2, 3, 4])
def test_uniform(self):
out = paddle.uniform([])
self.assertEqual(out.shape, [])
out = paddle.uniform(self.shape)
self.assertEqual(out.shape, [2, 3, 4])
def test_empty_and_empty_like(self):
out = paddle.empty([])
self.assertEqual(out.shape, [])
out = paddle.empty_like(out)
self.assertEqual(out.shape, [])
out = paddle.empty(self.shape)
self.assertEqual(out.shape, [2, 3, 4])
def test_full_and_full_like(self):
out = paddle.full([], 0.5)
self.assertEqual(out.shape, [])
out = paddle.full_like(out, 0.5)
self.assertEqual(out.shape, [])
out = paddle.full(self.shape, 0.5)
self.assertEqual(out.shape, [2, 3, 4])
def test_ones_and_ones_like(self):
out = paddle.ones([])
self.assertEqual(out.shape, [])
out = paddle.ones_like(out)
self.assertEqual(out.shape, [])
out = paddle.ones(self.shape)
self.assertEqual(out.shape, [2, 3, 4])
def test_zeros_and_zeros_like(self):
out = paddle.zeros([])
self.assertEqual(out.shape, [])
out = paddle.zeros_like(out)
self.assertEqual(out.shape, [])
out = paddle.zeros(self.shape)
self.assertEqual(out.shape, [2, 3, 4])
class TestNoBackwardAPIStatic(unittest.TestCase):
def setUp(self):
paddle.enable_static()
self.exe = paddle.static.Executor()
self.shape = [
paddle.full([], 2, 'int32'),
paddle.full([], 3, 'int32'),
paddle.full([], 4, 'int32'),
]
def test_slice(self):
starts = [paddle.full([], 1, 'int32'), paddle.full([], 1, 'int32')]
ends = [paddle.full([], 3, 'int32'), paddle.full([], 3, 'int32')]
x = paddle.rand([5, 3, 3])
out = paddle.slice(x, [1, 2], starts, ends)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out]
)[0]
self.assertEqual(res.shape, (5, 2, 2))
def test_strided_slice(self):
starts = [paddle.full([], 0, 'int32'), paddle.full([], 0, 'int32')]
ends = [paddle.full([], 4, 'int32'), paddle.full([], 4, 'int32')]
strides = [paddle.full([], 2, 'int32'), paddle.full([], 2, 'int32')]
x = paddle.rand([5, 5, 5])
out = paddle.strided_slice(x, [1, 2], starts, ends, strides)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out]
)[0]
self.assertEqual(res.shape, (5, 2, 2))
def test_linspace(self):
start = paddle.full([], 1.0)
stop = paddle.full([], 5.0)
num = paddle.full([], 5, 'int32')
out = paddle.linspace(start, stop, num)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out]
)[0]
np.testing.assert_array_equal(res, [1.0, 2.0, 3.0, 4.0, 5.0])
def test_arange(self):
start = paddle.full([], 1.0)
stop = paddle.full([], 6.0)
step = paddle.full([], 1.0)
out = paddle.arange(start, stop, step)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out]
)[0]
np.testing.assert_array_equal(res, [1.0, 2.0, 3.0, 4.0, 5.0])
def test_normal(self):
mean = paddle.full([], 0.0)
std = paddle.full([], 0.0)
out1 = paddle.normal(mean, std)
out2 = paddle.normal(0.0, 1.0, [])
out3 = paddle.normal(0.0, 1.0, self.shape)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2, out3]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2, 3, 4))
def test_rand(self):
out1 = paddle.rand([])
out2 = paddle.rand(self.shape)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2, 3, 4))
def test_randn(self):
out1 = paddle.randn([])
out2 = paddle.randn(self.shape)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2, 3, 4))
def test_randint_and_randint_like(self):
out1 = paddle.randint(-10, 10, [])
out2 = paddle.randint_like(out1, -10, 10)
out3 = paddle.randint(-10, 10, self.shape)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2, out3]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2, 3, 4))
def test_standard_normal(self):
out1 = paddle.standard_normal([])
out2 = paddle.standard_normal(self.shape)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2, 3, 4))
def test_uniform(self):
out1 = paddle.uniform([])
out2 = paddle.uniform(self.shape)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2, 3, 4))
def test_empty_and_empty_like(self):
out1 = paddle.empty([])
out2 = paddle.empty_like(out1)
out3 = paddle.empty(self.shape)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2, out3]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2, 3, 4))
def test_full_and_full_like(self):
out1 = paddle.full([], 0.5)
out2 = paddle.full_like(out1, 0.5)
out3 = paddle.full(self.shape, 0.5)
out4 = paddle.full(self.shape, paddle.full([], 0.5))
res = self.exe.run(
paddle.static.default_main_program(),
fetch_list=[out1, out2, out3, out4],
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2, 3, 4))
self.assertEqual(res[3].shape, (2, 3, 4))
def test_ones_and_ones_like(self):
out1 = paddle.ones([])
out2 = paddle.ones_like(out1)
out3 = paddle.ones(self.shape)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2, out3]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2, 3, 4))
def test_zeros_and_zeros_like(self):
out1 = paddle.zeros([])
out2 = paddle.zeros_like(out1)
out3 = paddle.zeros(self.shape)
res = self.exe.run(
paddle.static.default_main_program(), fetch_list=[out1, out2, out3]
)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (2, 3, 4))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
@@ -65,12 +65,12 @@ def linspace(start, stop, num, dtype=None, name=None):
     Return fixed number of evenly spaced values within a given interval.

     Args:
-        start(int|float|Tensor): The input :attr:`start` is start variable of range. It is a scalar, \
-            or a Tensor of shape [1] with input data type int32, int64, float32 or float64.
-        stop(int|float|Tensor): The input :attr:`stop` is start variable of range. It is a scalar, \
-            or a Tensor of shape [1] with input data type int32, int64, float32 or float64.
-        num(int|Tensor): The input :attr:`num` is given num of the sequence. It is an int scalar, \
-            or a Tensor of shape [1] with data type int32.
+        start(int|float|Tensor): The input :attr:`start` is start of range. It is a int, float, \
+            or a 0-D Tensor with data type int32, int64, float32 or float64.
+        stop(int|float|Tensor): The input :attr:`stop` is start variable of range. It is a int, float, \
+            or a 0-D Tensor with data type int32, int64, float32 or float64.
+        num(int|Tensor): The input :attr:`num` is given num of the sequence. It is an int, \
+            or a 0-D Tensor with data type int32.
         dtype(np.dtype|str, optional): The data type of output tensor, it could be
             int32, int64, float32 and float64. Default: if None, the data type is float32.
         name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
@@ -620,7 +620,9 @@ def ones(shape, dtype=None, name=None):
     Create a Tensor of specified :attr:`shape` and :attr:`dtype` and fill it with 1.

     Args:
-        shape (tuple|list|Tensor): Shape of the Tensor to be created, the data type of shape should be int32 or int64.
+        shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
+            If ``shape`` is a list or tuple, the elements of it should be integers or 0-D Tensor with shape [].
+            If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
         dtype (np.dtype|str, optional): Data type of output Tensor, it should be one of
             bool, float16, float32, float64, int32 and int64. If it is set to None, the data type will be float32.
         name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
@@ -633,21 +635,25 @@ def ones(shape, dtype=None, name=None):
             import paddle

-            # default dtype for ones OP
-            data1 = paddle.ones(shape=[3, 2])
-            # [[1. 1.]
-            # [1. 1.]
-            # [1. 1.]]
-
-            data2 = paddle.ones(shape=[2, 2], dtype='int32')
-            # [[1 1]
-            # [1 1]]
-
-            # shape is a Tensor
-            shape = paddle.full(shape=[2], dtype='int32', fill_value=2)
-            data3 = paddle.ones(shape=shape, dtype='int32')
-            # [[1 1]
-            # [1 1]]
+            # shape is a list/tuple
+            data1 = paddle.ones(shape=[3, 2])
+            # [[1. 1.]
+            # [1. 1.]
+            # [1. 1.]]
+
+            # shape is a Tensor
+            shape = paddle.to_tensor([3, 2])
+            data2 = paddle.ones(shape=shape)
+            # [[1. 1.]
+            # [1. 1.]
+            # [1. 1.]]
+
+            # shape is a Tensor List
+            shape = [paddle.to_tensor(3), paddle.to_tensor(2)]
+            data3 = paddle.ones(shape=shape)
+            # [[1. 1.]
+            # [1. 1.]
+            # [1. 1.]]
     """
     if dtype is None:
         dtype = 'float32'
@@ -690,7 +696,9 @@ def zeros(shape, dtype=None, name=None):
     Creates a tensor of specified :attr:`shape` and :attr:`dtype`, and fills it with 0.

     Args:
-        shape(tuple|list|Tensor): Shape of the Tensor to be created, the data type of ``shape`` is int32 or int64.
+        shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
+            If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
+            If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
         dtype(np.dtype|str, optional): Data type of output Tensor, it supports
             bool, float16, float32, float64, int32 and int64. Default: if None, the date type is float32.
         name(str, optional): The default value is None. Normally there is no need for user to set this
@@ -702,21 +710,27 @@ def zeros(shape, dtype=None, name=None):
     Examples:
         .. code-block:: python

            import paddle

-            data = paddle.zeros(shape=[3, 2], dtype='float32')
-            # [[0. 0.]
-            # [0. 0.]
-            # [0. 0.]]
-            data = paddle.zeros(shape=[2, 2])
-            # [[0. 0.]
-            # [0. 0.]]
-
-            # shape is a Tensor
-            shape = paddle.full(shape=[2], dtype='int32', fill_value=2)
-            data3 = paddle.zeros(shape=shape, dtype='int32')
-            # [[0 0]
-            # [0 0]]
+            # shape is a list/tuple
+            data1 = paddle.zeros(shape=[3, 2])
+            # [[0. 0.]
+            # [0. 0.]
+            # [0. 0.]]
+
+            # shape is a Tensor
+            shape = paddle.to_tensor([3, 2])
+            data2 = paddle.zeros(shape=shape)
+            # [[0. 0.]
+            # [0. 0.]
+            # [0. 0.]]
+
+            # shape is a Tensor List
+            shape = [paddle.to_tensor(3), paddle.to_tensor(2)]
+            data3 = paddle.zeros(shape=shape)
+            # [[0. 0.]
+            # [0. 0.]
+            # [0. 0.]]
     """
     if dtype is None:
         dtype = 'float32'
@@ -844,12 +858,11 @@ def full(shape, fill_value, dtype=None, name=None):
     Return a Tensor with the ``fill_value`` which size is same as ``shape``.

     Args:
-        shape(list|tuple|Tensor): Shape of the Tensor to be created.
-            The data type is ``int32`` or ``int64`` . If ``shape`` is a list or tuple,
-            the elements of it should be integers or Tensors with shape [1].
-            If ``shape`` is an Tensor, it should be an 1-D Tensor.
-        fill_value(bool|float|int|Tensor): The constant value
-            used to initialize the Tensor to be created. If ``fill_value`` is an Tensor, it must be an 1-D Tensor.
+        shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
+            If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
+            If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
+        fill_value(bool|float|int|Tensor): The constant value used to initialize the Tensor to be created.
+            If ``fill_value`` is an Tensor, it shoule be an 0-D Tensor which represents a scalar.
         dtype(np.dtype|str, optional): Data type of the output Tensor
             which can be float16, float32, float64, int32, int64, if dytpe is `None`, the data
             type of created Tensor is `float32`.
@@ -863,26 +876,32 @@ def full(shape, fill_value, dtype=None, name=None):
            import paddle

-            data1 = paddle.full(shape=[2,1], fill_value=0, dtype='int64')
-            #[[0]
-            # [0]]
-
-            # attr shape is a list which contains Tensor.
-            positive_2 = paddle.full([1], 2, "int32")
-            data3 = paddle.full(shape=[1, positive_2], dtype='float32', fill_value=1.5)
-            # [[1.5 1.5]]
-
-            # attr shape is a Tensor.
-            shape = paddle.full([2], 2, "int32")
-            data4 = paddle.full(shape=shape, dtype='bool', fill_value=True)
-            # [[True True]
-            # [True True]]
-
-            # attr fill_value is a Tensor.
-            val = paddle.full([1], 2.0, "float32")
-            data5 = paddle.full(shape=[2,1], fill_value=val, dtype='float32')
-            # [[2.0]
-            # [2.0]]
+            # shape is a list/tuple
+            data1 = paddle.full(shape=[3, 2], fill_value=1.)
+            # [[1. 1.]
+            # [1. 1.]
+            # [1. 1.]]
+
+            # shape is a Tensor
+            shape = paddle.to_tensor([3, 2])
+            data2 = paddle.full(shape=shape, fill_value=2.)
+            # [[2. 2.]
+            # [2. 2.]
+            # [2. 2.]]
+
+            # shape is a Tensor List
+            shape = [paddle.to_tensor(3), paddle.to_tensor(2)]
+            data3 = paddle.full(shape=shape, fill_value=3.)
+            # [[3. 3.]
+            # [3. 3.]
+            # [3. 3.]]
+
+            # fill_value is a Tensor.
+            val = paddle.full([], 2.0, "float32")
+            data5 = paddle.full(shape=[3, 2], fill_value=val)
+            # [[2. 2.]
+            # [2. 2.]
+            # [2. 2.]]
    """
    if dtype is None:
@@ -904,16 +923,17 @@ def arange(start=0, end=None, step=1, dtype=None, name=None):
    Parameters:
        start(float|int|Tensor): Start of interval. The interval includes this
            value. If ``end`` is None, the half-open interval is [0, ``start``).
-            If ``start`` is a Tensor, it is a 1-D Tensor with shape [1], with
-            data type int32, int64, float32, float64. Default is 0.
+            If ``start`` is a Tensor, it is a 0-D Tensor which represents a scalar
+            and data type is int32, int64, float32, float64. Default is 0.
        end(float|int|Tensor, optional): End of interval. The interval does not
-            include this value. If ``end`` is a Tensor, it is a 1-D Tensor with
-            shape [1], with data type int32, int64, float32, float64. If ``end``
-            is None, the half-open interval is [0, ``start``). Default is None.
+            include this value. If ``end`` is a Tensor, it is a 0-D Tensor which
+            represents a scalar and data type is int32, int64, float32, float64.
+            If ``end`` is None, the half-open interval is [0, ``start``).
+            Default is None.
        step(float|int|Tensor, optional): Spacing between values. For any out,
            it is the istance between two adjacent values, out[i+1] - out[i].
-            If ``step`` is a Tensor, it is a 1-D Tensor with shape [1], with
-            data type int32, int64, float32, float64. Default is 1.
+            If ``step`` is a Tensor, it is a 0-D Tensor which represents a scalar
+            and data type is int32, int64, float32, float64. . Default is 1.
        dtype(str|np.dtype, optional): The data type of the
            output tensor. Supported data types: int32, int64, float32, float64.
            If ``dytpe`` is None, the data type is float32. Default is None.

@@ -939,7 +959,7 @@ def arange(start=0, end=None, step=1, dtype=None, name=None):
            out3 = paddle.arange(4.999, dtype='float32')
            # [0., 1., 2., 3., 4.]

-            start_var = paddle.to_tensor([3])
+            start_var = paddle.to_tensor(3)
            out4 = paddle.arange(start_var, 7)
            # [3, 4, 5, 6]
@@ -1501,10 +1521,9 @@ def empty(shape, dtype=None, name=None):
    Returns a Tensor with uninitialized data which size is same as ``shape``.

    Args:
-        shape(list|tuple|Tensor): Shape of the Tensor to be created.
-            The data type of dimension of shape is ``int32`` or ``int64`` . If ``shape`` is a list or tuple,
-            the elements of it should be integers or Tensors with shape [1].
-            If ``shape`` is an Tensor, it should be an 1-D Tensor.
+        shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
+            If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
+            If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
        dtype(np.dtype|str, optional): Data type of the output Tensor
            which can be bool, float16, float32, float64, int32, int64, if dytpe is `None`, the data
            type of created Tensor use global default dtype (see ``get_default_dtype``

@@ -1519,30 +1538,25 @@ def empty(shape, dtype=None, name=None):
            import paddle

-            paddle.set_device("cpu")  # and use cpu device
-
-            # example 1: argument ``shape`` is a list which doesn't contain Tensor.
-            data1 = paddle.empty(shape=[2, 3], dtype='float32')
-            print(data1)
-            # Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
-            #        [[0.00000000, 0.        , 0.00000000],
-            #         [0.        , 0.29652897, 0.09356152]])  # uninitialized
-
-            # example 2: argument ``shape`` is a Tensor, the data type must be int64 or int32.
-            shape_data = paddle.to_tensor([2, 3]).astype('int32')
-            data2 = paddle.empty(shape=shape_data, dtype='float32')
-            print(data2)
-            # Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
-            #        [[-0.50543123, -0.09872390, -0.92634487],
-            #         [-0.51007903, -0.02454148,  1.29315734]])  # uninitialized
-
-            # example 3: argument ``shape`` is a list which contains Tensor.
-            dim2 = paddle.to_tensor([3]).astype('int32')
-            data3 = paddle.empty(shape=[2, dim2], dtype='float32')
-            print(data3)
-            # Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
-            #        [[ 0.00000000,  0.        , -0.92634487],
-            #         [-0.51007903, -0.02454148,  1.29315734]])  # uninitialized
+            # shape is a list/tuple
+            data1 = paddle.empty(shape=[3, 2])
+            # [[1. 1.]
+            # [1. 1.]
+            # [1. 1.]]
+
+            # shape is a Tensor
+            shape = paddle.to_tensor([3, 2])
+            data2 = paddle.empty(shape=shape)
+            # [[1. 1.]
+            # [1. 1.]
+            # [1. 1.]]
+
+            # shape is a Tensor List
+            shape = [paddle.to_tensor(3), paddle.to_tensor(2)]
+            data3 = paddle.empty(shape=shape)
+            # [[1. 1.]
+            # [1. 1.]
+            # [1. 1.]]
    """
    if dtype is None:
...
@@ -221,11 +221,9 @@ def gaussian(shape, mean=0.0, std=1.0, dtype=None, name=None):
    distribution, with ``shape`` and ``dtype``.

    Args:
-        shape (list|tuple|Tensor): The shape of the output Tensor. If ``shape``
-            is a list or tuple, the elements of it should be integers or Tensors
-            (with the shape [1], and the data type int32 or int64). If ``shape``
-            is a Tensor, it should be a 1-D Tensor(with the data type int32 or
-            int64).
+        shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
+            If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
+            If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
        mean (float|int, optional): Mean of the output tensor, default is 0.0.
        std (float|int, optional): Standard deviation of the output tensor, default
            is 1.0.
@@ -307,11 +305,9 @@ def standard_normal(shape, dtype=None, name=None):
    and ``dtype``.

    Args:
-        shape (list|tuple|Tensor): The shape of the output Tensor. If ``shape``
-            is a list or tuple, the elements of it should be integers or Tensors
-            (with the shape [1], and the data type int32 or int64). If ``shape``
-            is a Tensor, it should be a 1-D Tensor(with the data type int32 or
-            int64).
+        shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
+            If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
+            If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
        dtype (str|np.dtype, optional): The data type of the output Tensor.
            Supported data types: float32, float64.
            Default is None, use global default dtype (see ``get_default_dtype``

@@ -335,8 +331,8 @@ def standard_normal(shape, dtype=None, name=None):
            #  [ 0.39632758,  0.08177969,  0.2692008 ]]  # random

            # example 2: attr shape is a list which contains Tensor.
-            dim1 = paddle.to_tensor([2], 'int64')
-            dim2 = paddle.to_tensor([3], 'int32')
+            dim1 = paddle.to_tensor(2, 'int64')
+            dim2 = paddle.to_tensor(3, 'int32')
            out2 = paddle.standard_normal(shape=[dim1, dim2, 2])
            # [[[-2.8852394 , -0.25898588],  # random
            #   [-0.47420555,  0.17683524],  # random
@@ -362,11 +358,9 @@ def randn(shape, dtype=None, name=None):
    and ``dtype``.

    Args:
-        shape (list|tuple|Tensor): The shape of the output Tensor. If ``shape``
-            is a list or tuple, the elements of it should be integers or Tensors
-            (with the shape [1], and the data type int32 or int64). If ``shape``
-            is a Tensor, it should be a 1-D Tensor(with the data type int32 or
-            int64).
+        shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
+            If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
+            If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
        dtype (str|np.dtype, optional): The data type of the output Tensor.
            Supported data types: float32, float64.
            Default is None, use global default dtype (see ``get_default_dtype``

@@ -390,8 +384,8 @@ def randn(shape, dtype=None, name=None):
            #  [ 0.39632758,  0.08177969,  0.2692008 ]]  # random

            # example 2: attr shape is a list which contains Tensor.
-            dim1 = paddle.to_tensor([2], 'int64')
-            dim2 = paddle.to_tensor([3], 'int32')
+            dim1 = paddle.to_tensor(2, 'int64')
+            dim2 = paddle.to_tensor(3, 'int32')
            out2 = paddle.randn(shape=[dim1, dim2, 2])
            # [[[-2.8852394 , -0.25898588],  # random
            #   [-0.47420555,  0.17683524],  # random
@@ -429,12 +423,10 @@ def normal(mean=0.0, std=1.0, shape=None, name=None):
        If ``std`` is float, all elements of the output Tensor shared the same standard deviation.
        If ``std`` is a Tensor(data type supports float32, float64), it has per-element standard deviations.
        Defaule is 1.0
-        shape (list|tuple|Tensor, optional): The shape of the output Tensor. If ``shape``
-            is a list or tuple, the elements of it should be integers or Tensors
-            (with the shape [1], and the data type int32 or int64). If ``shape``
-            is a Tensor, it should be a 1-D Tensor(with the data type int32 or
-            int64). If ``mean`` or ``std`` is a Tensor, the shape of the output
-            Tensor is the same as ``mean`` or ``std`` , attr ``shape`` is ignored.
+        shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
+            If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
+            If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list. If ``mean`` or ``std``
+            is a Tensor, the shape of the output Tensor is the same as ``mean`` or ``std`` , attr ``shape`` is ignored.
            Default is None
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.
@@ -518,11 +510,9 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
        result=[[0.8505902, 0.8397286]]

    Args:
-        shape(list|tuple|Tensor): The shape of the output Tensor. If ``shape``
-            is a list or tuple, the elements of it should be integers or Tensors
-            (with the shape [1], and the data type int32 or int64). If ``shape``
-            is a Tensor, it should be a 1-D Tensor(with the data type int32 or
-            int64).
+        shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
+            If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
+            If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
        dtype(str|np.dtype, optional): The data type of the output Tensor.
            Supported data types: float32, float64.
            Default is None, use global default dtype (see ``get_default_dtype``

@@ -557,8 +547,8 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
            # example 2:
            # attr shape is a list which contains Tensor.
-            dim1 = paddle.to_tensor([2], 'int64')
-            dim2 = paddle.to_tensor([3], 'int32')
+            dim1 = paddle.to_tensor(2, 'int64')
+            dim2 = paddle.to_tensor(3, 'int32')
            out2 = paddle.uniform(shape=[dim1, dim2])
            # [[-0.9951253,  0.30757582, 0.9899647 ],  # random
            #  [ 0.5864527,  0.6607096, -0.8886161]]  # random
@@ -684,11 +674,9 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
        high (int, optional): The upper bound on the range of random values to
            generate, the ``high`` is excluded in the range. Default is None
            (see above for behavior if high = None). Default is None.
-        shape (list|tuple|Tensor, optional): The shape of the output Tensor. If ``shape``
-            is a list or tuple, the elements of it should be integers or Tensors
-            (with the shape [1], and the data type int32 or int64). If ``shape``
-            is a Tensor, it should be a 1-D Tensor(with the data type int32 or
-            int64). Default is [1].
+        shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
+            If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
+            If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list. Default is [1].
        dtype (str|np.dtype, optional): The data type of the
            output tensor. Supported data types: int32, int64. If ``dytpe``
            is None, the data type is int64. Default is None.

@@ -707,22 +695,23 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
            # example 1:
            # attr shape is a list which doesn't contain Tensor.
-            out1 = paddle.randint(low=-5, high=5, shape=[3])
+            out1 = paddle.randint(low=-5, high=5, shape=[2, 3])
            # [0, -3, 2]  # random

            # example 2:
            # attr shape is a list which contains Tensor.
-            dim1 = paddle.to_tensor([2], 'int64')
-            dim2 = paddle.to_tensor([3], 'int32')
+            dim1 = paddle.to_tensor(2, 'int64')
+            dim2 = paddle.to_tensor(3, 'int32')
            out2 = paddle.randint(low=-5, high=5, shape=[dim1, dim2])
            # [[0, -1, -3],  # random
            #  [4, -2,  0]]  # random

            # example 3:
            # attr shape is a Tensor
-            shape_tensor = paddle.to_tensor(3)
+            shape_tensor = paddle.to_tensor([2, 3])
            out3 = paddle.randint(low=-5, high=5, shape=shape_tensor)
-            # [-2, 2, 3]  # random
+            # [[ 2, -3, -1],   # random
+            #  [-3, -2,  1]])  # random

            # example 4:
            # data type is int32
@@ -1033,11 +1022,9 @@ def rand(shape, dtype=None, name=None):
    distribution in the range [0, 1), with ``shape`` and ``dtype``.

    Args:
-        shape (list|tuple|Tensor): The shape of the output Tensor. If ``shape``
-            is a list or tuple, the elements of it should be integers or Tensors
-            (with the shape [1], and the data type int32 or int64). If ``shape``
-            is a Tensor, it should be a 1-D Tensor(with the data type int32 or
-            int64).
+        shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
+            If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
+            If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
        dtype (str|np.dtype, optional): The data type of the output Tensor.
            Supported data types: float32, float64.
            Default is None, use global default dtype (see ``get_default_dtype``

@@ -1061,8 +1048,8 @@ def rand(shape, dtype=None, name=None):
            #  [0.22550228, 0.22106001, 0.7877319 ]]  # random

            # example 2: attr shape is a list which contains Tensor.
-            dim1 = paddle.to_tensor([2], 'int64')
-            dim2 = paddle.to_tensor([3], 'int32')
+            dim1 = paddle.to_tensor(2, 'int64')
+            dim2 = paddle.to_tensor(3, 'int32')
            out2 = paddle.rand(shape=[dim1, dim2, 2])
            # [[[0.8879919 , 0.25788337],  # random
            #   [0.28826773, 0.9712097 ],  # random

@@ -1076,7 +1063,6 @@ def rand(shape, dtype=None, name=None):
            out3 = paddle.rand(shape_tensor)
            # [[0.22920267, 0.841956  , 0.05981819],  # random
            #  [0.4836288 , 0.24573246, 0.7516129 ]]  # random
-
    """
    return uniform(shape, dtype, min=0.0, max=1.0, name=name)
...