Unverified commit 2092660c authored by wangchaochaohu, committed by GitHub

Ones op for API 2.0: remove the device and out parameters (#25497)

Parent 4a44ffdd
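For reference, a minimal sketch of how the reworked call looks after this change, mirroring the updated static-graph tests later in this diff (the `paddle.program_guard`, `paddle.Executor`, and `paddle.CPUPlace` aliases are the ones those tests rely on); `device` and `out` are no longer accepted by `paddle.ones`:

    import numpy as np
    import paddle

    # Build a static-graph program; dtype defaults to 'float32' when omitted.
    with paddle.program_guard(paddle.Program()):
        ones = paddle.ones(shape=[10], dtype="float64")
        exe = paddle.Executor(paddle.CPUPlace())
        result, = exe.run(fetch_list=[ones])

    # The fetched value matches a NumPy array of ones.
    assert (result == np.ones(10, dtype="float64")).all()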
......@@ -650,9 +650,10 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
Returns:
Variable: Tensor which is created according to shape and dtype.
Raise:
Raises:
TypeError: The dtype must be one of bool, float16, float32, float64, int32 and int64,
and the data type of the out Tensor must be the same as the dtype.
TypeError: The shape must be one of list, tuple and Variable.
Examples:
.. code-block:: python
......@@ -665,7 +666,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
# attr shape is a list which contains a Variable Tensor.
positive_2 = fluid.layers.fill_constant([1], "int32", 2)
data3 = fluid.layers.fill_constant(shape=[1, positive_2], dtype='float32', value=1.5) # data3=[1.5, 1.5]
data3 = fluid.layers.fill_constant(shape=[1, positive_2], dtype='float32', value=1.5) # data3=[[1.5, 1.5]]
# attr shape is a Variable Tensor.
shape = fluid.layers.fill_constant([2], "int32", 2) # shape=[2,2]
......@@ -1424,6 +1425,12 @@ def linspace(start, stop, num, dtype=None, name=None):
the data shape of this tensor is :math:`[num]` . If :attr:`num` is set to 1, the output tensor just has \
the value of :attr:`start`.
Raises:
TypeError: The dtype must be one of float32 and float64.
TypeError: The dtype of `start` and `stop` must be one of float32 and float64.
TypeError: The dtype of `num` must be one of int32 and int64.
Examples:
.. code-block:: python
......
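The docstring fix above corrects the expected output of fill_constant when the shape list contains a Variable: the result is a 2-D tensor [[1.5, 1.5]], not a flat [1.5, 1.5]. A small consolidation of that docstring example (graph-building only; the values in the comments are what executing the program would produce):

    import paddle.fluid as fluid

    # The shape list mixes a Python int and a 1-D int32 Variable holding 2,
    # so the resulting tensor has shape [1, 2].
    positive_2 = fluid.layers.fill_constant([1], "int32", 2)
    data3 = fluid.layers.fill_constant(
        shape=[1, positive_2], dtype='float32', value=1.5)  # [[1.5, 1.5]]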
......@@ -83,26 +83,6 @@ class TestFillConstantOp4(OpTest):
self.check_output()
class TestFillConstantOp5(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
out_np = np.zeros(shape=(1), dtype='float32')
out = paddle.zeros(shape=[1], dtype="float32")
place = fluid.CPUPlace()
exe = fluid.Executor(place)
result = exe.run(fetch_list=[out])
self.assertEqual((result == out_np).all(), True)
with program_guard(Program()):
data = fluid.data(name="X", shape=[1], dtype="float32")
out = paddle.ones(shape=[1], out=data, dtype="float32")
place = fluid.CPUPlace()
exe = fluid.Executor(place)
result = exe.run(feed={"X": np.array(
[0.1], dtype="float32")},
fetch_list=[data, out])
self.assertEqual(result[0], result[1])
class TestFillConstantOpWithSelectedRows(unittest.TestCase):
def check_with_place(self, place):
scope = core.Scope()
......
......@@ -26,27 +26,36 @@ import numpy as np
class ApiOnesTest(unittest.TestCase):
def test_out(self):
with fluid.program_guard(fluid.Program()):
def test_paddle_ones(self):
with paddle.program_guard(paddle.Program()):
ones = paddle.ones(shape=[10], dtype="float64")
place = fluid.CPUPlace()
exe = fluid.Executor(place)
place = paddle.CPUPlace()
exe = paddle.Executor(place)
result, = exe.run(fetch_list=[ones])
expected_result = np.ones(10, dtype="float64")
self.assertEqual((result == expected_result).all(), True)
with fluid.program_guard(fluid.Program()):
with paddle.program_guard(paddle.Program()):
ones = paddle.ones(shape=[10], dtype="float64")
place = paddle.CPUPlace()
exe = paddle.Executor(place)
result, = exe.run(fetch_list=[ones])
expected_result = np.ones(10, dtype="float64")
self.assertEqual((result == expected_result).all(), True)
with paddle.program_guard(paddle.Program()):
ones = paddle.ones(shape=[10], dtype="int64")
place = fluid.CPUPlace()
exe = fluid.Executor(place)
place = paddle.CPUPlace()
exe = paddle.Executor(place)
result, = exe.run(fetch_list=[ones])
expected_result = np.ones(10, dtype="int64")
self.assertEqual((result == expected_result).all(), True)
with fluid.program_guard(fluid.Program()):
ones = paddle.ones(shape=[10], dtype="int64", device="cpu")
place = fluid.CPUPlace()
exe = fluid.Executor(place)
def test_fluid_ones(self):
with paddle.program_guard(paddle.Program()):
ones = fluid.layers.ones(shape=[10], dtype="int64")
place = paddle.CPUPlace()
exe = paddle.Executor(place)
result, = exe.run(fetch_list=[ones])
expected_result = np.ones(10, dtype="int64")
self.assertEqual((result == expected_result).all(), True)
......@@ -55,25 +64,25 @@ class ApiOnesTest(unittest.TestCase):
class ApiOnesZerosError(unittest.TestCase):
def test_errors(self):
def test_error1():
with fluid.program_guard(fluid.Program()):
ones = paddle.ones(shape=10, dtype="int64", device="opu")
with paddle.program_guard(paddle.Program()):
ones = paddle.ones(shape=10, dtype="int64")
self.assertRaises(ValueError, test_error1)
self.assertRaises(TypeError, test_error1)
def test_error2():
with fluid.program_guard(fluid.Program()):
ones = paddle.ones(shape=10, dtype="int64", device="opu")
with paddle.program_guard(paddle.Program()):
ones = paddle.ones(shape=10)
self.assertRaises(ValueError, test_error2)
self.assertRaises(TypeError, test_error2)
def test_error3():
with fluid.program_guard(fluid.Program()):
with paddle.program_guard(paddle.Program()):
ones = fluid.layers.ones(shape=10, dtype="int64")
self.assertRaises(TypeError, test_error3)
def test_error4():
with fluid.program_guard(fluid.Program()):
with paddle.program_guard(paddle.Program()):
ones = fluid.layers.ones(shape=[10], dtype="int8")
self.assertRaises(TypeError, test_error4)
......
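As the reworked tests above show, paddle.ones now defaults dtype to 'float32' when it is omitted, while fluid.layers.ones keeps its old signature and still requires an explicit dtype (and a plain int shape is rejected with TypeError). A sketch of the contrast, assuming the same 2.0-beta aliases used in these tests:

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    with paddle.program_guard(paddle.Program()):
        # dtype is optional for paddle.ones after this change.
        default_ones = paddle.ones(shape=[10])
        # fluid.layers.ones is untouched and still needs an explicit dtype.
        fluid_ones = fluid.layers.ones(shape=[10], dtype="int64")

        exe = paddle.Executor(paddle.CPUPlace())
        res_default, res_fluid = exe.run(fetch_list=[default_ones, fluid_ones])

    assert res_default.dtype == np.float32
    assert (res_fluid == np.ones(10, dtype="int64")).all()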
......@@ -36,26 +36,43 @@ class TestZerosOpError(unittest.TestCase):
class ApiZerosTest(unittest.TestCase):
def test_out(self):
with paddle.program_guard(fluid.Program()):
with program_guard(Program()):
zeros = paddle.zeros(shape=[10], dtype="float64")
place = fluid.CPUPlace()
exe = fluid.Executor(place)
place = paddle.CPUPlace()
exe = paddle.Executor(place)
result, = exe.run(fetch_list=[zeros])
expected_result = np.zeros(10, dtype="float64")
self.assertEqual((result == expected_result).all(), True)
with paddle.program_guard(fluid.Program()):
with paddle.program_guard(Program()):
zeros = paddle.zeros(shape=[10], dtype="int64")
place = fluid.CPUPlace()
exe = fluid.Executor(place)
place = paddle.CPUPlace()
exe = paddle.Executor(place)
result, = exe.run(fetch_list=[zeros])
expected_result = np.zeros(10, dtype="int64")
self.assertEqual((result == expected_result).all(), True)
with paddle.program_guard(fluid.Program()):
with program_guard(Program()):
zeros = paddle.zeros(shape=[10], dtype="int64")
place = fluid.CPUPlace()
exe = fluid.Executor(place)
place = paddle.CPUPlace()
exe = paddle.Executor(place)
result, = exe.run(fetch_list=[zeros])
expected_result = np.zeros(10, dtype="int64")
self.assertEqual((result == expected_result).all(), True)
with program_guard(Program()):
out_np = np.zeros(shape=(1), dtype='float32')
out = paddle.zeros(shape=[1], dtype="float32")
place = paddle.CPUPlace()
exe = paddle.Executor(place)
result = exe.run(fetch_list=[out])
self.assertEqual((result == out_np).all(), True)
def test_fluid_out(self):
with program_guard(Program()):
zeros = fluid.layers.zeros(shape=[10], dtype="int64")
place = paddle.CPUPlace()
exe = paddle.Executor(place)
result, = exe.run(fetch_list=[zeros])
expected_result = np.zeros(10, dtype="int64")
self.assertEqual((result == expected_result).all(), True)
......
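The out-based check that used to live in test_fill_constant_op.py has no direct replacement, since out is gone from the 2.0 API; verifying paddle.zeros now means simply fetching the created tensor, as in the new test above. A condensed sketch of that pattern:

    import numpy as np
    import paddle
    from paddle.fluid import Program, program_guard

    with program_guard(Program()):
        out = paddle.zeros(shape=[1], dtype="float32")
        exe = paddle.Executor(paddle.CPUPlace())
        result, = exe.run(fetch_list=[out])

    # No `out` tensor to pass in any more; just compare the fetched value.
    assert (result == np.zeros((1,), dtype="float32")).all()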
......@@ -75,6 +75,9 @@ def full_like(x, fill_value, dtype=None, name=None):
Returns:
out(Variable): The Tensor variable storing the output.
Raises:
TypeError: The dtype must be one of bool, float16, float32, float64, int32, int64 and None.
Examples:
.. code-block:: python
......@@ -84,7 +87,8 @@ def full_like(x, fill_value, dtype=None, name=None):
paddle.enable_imperative() # Now we are in imperative mode
input = paddle.full(shape=[2, 3], fill_value=0.0, dtype='float32', name='input')
output = paddle.full_like(input, 2.0)
#output result : [array([[2., 2., 2.], [2., 2., 2.]], dtype=float32)]
# [[2. 2. 2.]
# [2. 2. 2.]]
"""
if dtype is None:
......@@ -99,7 +103,7 @@ def full_like(x, fill_value, dtype=None, name=None):
helper = LayerHelper("full_like", **locals())
check_dtype(dtype, 'dtype',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'full_like/zeros_like')
'full_like')
out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op(
......@@ -112,7 +116,7 @@ def full_like(x, fill_value, dtype=None, name=None):
return out
def ones(shape, dtype=None, out=None, device=None):
def ones(shape, dtype=None, name=None):
"""
:alias_main: paddle.ones
:alias: paddle.ones,paddle.tensor.ones,paddle.tensor.creation.ones
......@@ -120,38 +124,44 @@ def ones(shape, dtype=None, out=None, device=None):
The OP creates a tensor of specified :attr:`shape` and :attr:`dtype`, and fills it with 1.
Args:
shape(tuple|list): Shape of output tensor.
dtype(np.dtype|core.VarDesc.VarType|str): Data type of output tensor, it supports
bool, float16, float32, float64, int32 and int64.
out(Variable, optional): Optional output which can be any created
Variable that meets the requirements to store the result of operation.
if out is None, a new Varibale will be create to store the result.
device(str, optional): Which device to run the operator. The :attr:`device` must be
None,'cpu', 'gpu'. If :attr:`device` is None, it will be choose the device that the user set in
the paddle program. Default value is False.
shape(tuple|list|Variable): Shape of output tensor, the data type of shape is int32 or int64.
dtype(np.dtype|core.VarDesc.VarType|str, optional): Data type of output tensor, it supports
bool, float16, float32, float64, int32 and int64. Default: if None, the data type is 'float32'.
name(str, optional): The default value is None. Normally there is no need for the user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Variable: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 1.
Raises:
TypeError: The dtype must be one of bool, float16, float32, float64, int32, int64 and None.
TypeError: The `shape` must be one of list, tuple and Variable.
Examples:
.. code-block:: python
import paddle
data = paddle.ones(shape=[3, 2], dtype='float32') # [[1., 1.], [1., 1.], [1., 1.]]
data = paddle.ones(shape=[2, 2], dtype='float32', device='cpu') # [[1., 1.], [1., 1.]]
"""
check_dtype(dtype, 'create data type',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'zeros')
paddle.enable_imperative()
if device is not None:
if device not in ['cpu', 'gpu']:
raise ValueError(
"The value of 'device' in zeros_op must be cpu or gpu, but received %s."
% (device))
with fluid.device_guard(device):
return fill_constant(value=1.0, shape=shape, dtype=dtype, out=out)
return fill_constant(value=1.0, shape=shape, dtype=dtype, out=out)
#default dtype for ones OP
data1 = paddle.ones(shape=[3, 2])
# [[1. 1.]
# [1. 1.]
# [1. 1.]]
data2 = paddle.ones(shape=[2, 2], dtype='int32')
# [[1 1]
# [1 1]]
#shape is a Variable
shape = paddle.fill_constant(shape=[2], dtype='int32', value=2)
data3 = paddle.ones(shape=shape, dtype='int32')
# [[1 1]
# [1 1]]
"""
if dtype is None:
dtype = 'float32'
return fill_constant(value=1.0, shape=shape, dtype=dtype, name=name)
def ones_like(input, dtype=None, device=None, name=None):
......@@ -366,7 +376,7 @@ def full(shape, fill_value, dtype=None, name=None):
Raises:
TypeError: The `dtype` must be one of None, bool, float16, float32, float64, int32 and int64.
TypeError: The `shape` must be one of Variable, list tuple.
TypeError: The `shape` must be one of Variable, list and tuple.
Examples:
.. code-block:: python
......@@ -374,23 +384,28 @@ def full(shape, fill_value, dtype=None, name=None):
import paddle
paddle.enable_imperative() # Now we are in imperative mode
data1 = paddle.full(shape=[2,1], fill_value=0, dtype='int64') # data1=[[0],[0]]
data1 = paddle.full(shape=[2,1], fill_value=0, dtype='int64')
#[[0]
# [0]]
# attr shape is a list which contains a Variable Tensor.
positive_2 = paddle.fill_constant([1], "int32", 2)
data3 = paddle.full(shape=[1, positive_2], dtype='float32', fill_value=1.5) # data3=[1.5, 1.5]
data3 = paddle.full(shape=[1, positive_2], dtype='float32', fill_value=1.5)
# [[1.5 1.5]]
# attr shape is a Variable Tensor.
shape = paddle.fill_constant([2], "int32", 2) # shape=[2,2]
data4 = paddle.full(shape=shape, dtype='bool', fill_value=True) # data4=[[True,True],[True,True]]
# attr value is an Variable Tensor.
val = paddle.fill_constant([1], "float32", 2.0) # val=[2.0]
data5 = paddle.full(shape=[2,1], fill_value=val, dtype='float32') #data5=[[2.0],[2.0]]
shape = paddle.fill_constant([2], "int32", 2)
data4 = paddle.full(shape=shape, dtype='bool', fill_value=True)
# [[True True]
# [True True]]
# attr fill_value is a Variable Tensor.
val = paddle.fill_constant([1], "float32", 2.0)
data5 = paddle.full(shape=[2,1], fill_value=val, dtype='float32')
# [[2.0]
# [2.0]]
"""
helper = LayerHelper("full", **locals())
if dtype is None:
dtype = 'float32'
......
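For completeness, a minimal imperative-mode sketch consolidating the full and full_like examples from the updated docstrings above (paddle.enable_imperative() is the switch those examples use at this commit):

    import paddle

    paddle.enable_imperative()  # now in imperative (eager) mode

    input = paddle.full(shape=[2, 3], fill_value=0.0, dtype='float32', name='input')
    output = paddle.full_like(input, 2.0)
    # [[2. 2. 2.]
    #  [2. 2. 2.]]

    # shape can also be a 1-D int32 Variable holding the target shape.
    shape = paddle.fill_constant([2], "int32", 2)
    data4 = paddle.full(shape=shape, dtype='bool', fill_value=True)
    # [[True True]
    #  [True True]]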