未验证 提交 2092660c 编写于 作者: W wangchaochaohu 提交者: GitHub

Ones op for API 2.0: remove the device and out parameters (#25497)

上级 4a44ffdd
...@@ -650,9 +650,10 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -650,9 +650,10 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
Returns: Returns:
Variable: Tensor which is created according to shape and dtype. Variable: Tensor which is created according to shape and dtype.
Raise: Raises:
TypeError: The dtype must be one of bool, float16, float32, float64, int32 and int64 TypeError: The dtype must be one of bool, float16, float32, float64, int32 and int64
and the data type of out Tensor must be the same as the dtype. and the data type of out Tensor must be the same as the dtype.
TypeError: The shape must be one of list, tuple and Variable.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -665,7 +666,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -665,7 +666,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
# attr shape is a list which contains Variable Tensor. # attr shape is a list which contains Variable Tensor.
positive_2 = fluid.layers.fill_constant([1], "int32", 2) positive_2 = fluid.layers.fill_constant([1], "int32", 2)
data3 = fluid.layers.fill_constant(shape=[1, positive_2], dtype='float32', value=1.5) # data3=[1.5, 1.5] data3 = fluid.layers.fill_constant(shape=[1, positive_2], dtype='float32', value=1.5) # data3=[[1.5, 1.5]]
# attr shape is an Variable Tensor. # attr shape is an Variable Tensor.
shape = fluid.layers.fill_constant([2], "int32", 2) # shape=[2,2] shape = fluid.layers.fill_constant([2], "int32", 2) # shape=[2,2]
...@@ -1424,6 +1425,12 @@ def linspace(start, stop, num, dtype=None, name=None): ...@@ -1424,6 +1425,12 @@ def linspace(start, stop, num, dtype=None, name=None):
the data shape of this tensor is :math:`[num]` . If the :attr:`num` is set 1, the output tensor just has \ the data shape of this tensor is :math:`[num]` . If the :attr:`num` is set 1, the output tensor just has \
the value with input :attr:`start`. the value with input :attr:`start`.
Raises:
TypeError: The dtype must be one of float32 and float64.
TypeError: The dtype of `start` and `stop` must be one of float32 and float64.
TypeError: The dtype of `num` must be one of int32 and int64.
Examples: Examples:
.. code-block:: python .. code-block:: python
......
...@@ -83,26 +83,6 @@ class TestFillConstantOp4(OpTest): ...@@ -83,26 +83,6 @@ class TestFillConstantOp4(OpTest):
self.check_output() self.check_output()
class TestFillConstantOp5(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
out_np = np.zeros(shape=(1), dtype='float32')
out = paddle.zeros(shape=[1], dtype="float32")
place = fluid.CPUPlace()
exe = fluid.Executor(place)
result = exe.run(fetch_list=[out])
self.assertEqual((result == out_np).all(), True)
with program_guard(Program()):
data = fluid.data(name="X", shape=[1], dtype="float32")
out = paddle.ones(shape=[1], out=data, dtype="float32")
place = fluid.CPUPlace()
exe = fluid.Executor(place)
result = exe.run(feed={"X": np.array(
[0.1], dtype="float32")},
fetch_list=[data, out])
self.assertEqual(result[0], result[1])
class TestFillConstantOpWithSelectedRows(unittest.TestCase): class TestFillConstantOpWithSelectedRows(unittest.TestCase):
def check_with_place(self, place): def check_with_place(self, place):
scope = core.Scope() scope = core.Scope()
......
...@@ -26,27 +26,36 @@ import numpy as np ...@@ -26,27 +26,36 @@ import numpy as np
class ApiOnesTest(unittest.TestCase): class ApiOnesTest(unittest.TestCase):
def test_out(self): def test_paddle_ones(self):
with fluid.program_guard(fluid.Program()): with paddle.program_guard(paddle.Program()):
ones = paddle.ones(shape=[10], dtype="float64") ones = paddle.ones(shape=[10], dtype="float64")
place = fluid.CPUPlace() place = paddle.CPUPlace()
exe = fluid.Executor(place) exe = paddle.Executor(place)
result, = exe.run(fetch_list=[ones]) result, = exe.run(fetch_list=[ones])
expected_result = np.ones(10, dtype="float64") expected_result = np.ones(10, dtype="float64")
self.assertEqual((result == expected_result).all(), True) self.assertEqual((result == expected_result).all(), True)
with fluid.program_guard(fluid.Program()): with paddle.program_guard(paddle.Program()):
ones = paddle.ones(shape=[10], dtype="float64")
place = paddle.CPUPlace()
exe = paddle.Executor(place)
result, = exe.run(fetch_list=[ones])
expected_result = np.ones(10, dtype="float64")
self.assertEqual((result == expected_result).all(), True)
with paddle.program_guard(paddle.Program()):
ones = paddle.ones(shape=[10], dtype="int64") ones = paddle.ones(shape=[10], dtype="int64")
place = fluid.CPUPlace() place = paddle.CPUPlace()
exe = fluid.Executor(place) exe = paddle.Executor(place)
result, = exe.run(fetch_list=[ones]) result, = exe.run(fetch_list=[ones])
expected_result = np.ones(10, dtype="int64") expected_result = np.ones(10, dtype="int64")
self.assertEqual((result == expected_result).all(), True) self.assertEqual((result == expected_result).all(), True)
with fluid.program_guard(fluid.Program()): def test_fluid_ones(self):
ones = paddle.ones(shape=[10], dtype="int64", device="cpu") with paddle.program_guard(paddle.Program()):
place = fluid.CPUPlace() ones = fluid.layers.ones(shape=[10], dtype="int64")
exe = fluid.Executor(place) place = paddle.CPUPlace()
exe = paddle.Executor(place)
result, = exe.run(fetch_list=[ones]) result, = exe.run(fetch_list=[ones])
expected_result = np.ones(10, dtype="int64") expected_result = np.ones(10, dtype="int64")
self.assertEqual((result == expected_result).all(), True) self.assertEqual((result == expected_result).all(), True)
...@@ -55,25 +64,25 @@ class ApiOnesTest(unittest.TestCase): ...@@ -55,25 +64,25 @@ class ApiOnesTest(unittest.TestCase):
class ApiOnesZerosError(unittest.TestCase): class ApiOnesZerosError(unittest.TestCase):
def test_errors(self): def test_errors(self):
def test_error1(): def test_error1():
with fluid.program_guard(fluid.Program()): with paddle.program_guard(paddle.Program()):
ones = paddle.ones(shape=10, dtype="int64", device="opu") ones = paddle.ones(shape=10, dtype="int64")
self.assertRaises(ValueError, test_error1) self.assertRaises(TypeError, test_error1)
def test_error2(): def test_error2():
with fluid.program_guard(fluid.Program()): with paddle.program_guard(paddle.Program()):
ones = paddle.ones(shape=10, dtype="int64", device="opu") ones = paddle.ones(shape=10)
self.assertRaises(ValueError, test_error2) self.assertRaises(TypeError, test_error2)
def test_error3(): def test_error3():
with fluid.program_guard(fluid.Program()): with paddle.program_guard(paddle.Program()):
ones = fluid.layers.ones(shape=10, dtype="int64") ones = fluid.layers.ones(shape=10, dtype="int64")
self.assertRaises(TypeError, test_error3) self.assertRaises(TypeError, test_error3)
def test_error4(): def test_error4():
with fluid.program_guard(fluid.Program()): with paddle.program_guard(paddle.Program()):
ones = fluid.layers.ones(shape=[10], dtype="int8") ones = fluid.layers.ones(shape=[10], dtype="int8")
self.assertRaises(TypeError, test_error4) self.assertRaises(TypeError, test_error4)
......
...@@ -36,26 +36,43 @@ class TestZerosOpError(unittest.TestCase): ...@@ -36,26 +36,43 @@ class TestZerosOpError(unittest.TestCase):
class ApiZerosTest(unittest.TestCase): class ApiZerosTest(unittest.TestCase):
def test_out(self): def test_out(self):
with paddle.program_guard(fluid.Program()): with program_guard(Program()):
zeros = paddle.zeros(shape=[10], dtype="float64") zeros = paddle.zeros(shape=[10], dtype="float64")
place = fluid.CPUPlace() place = paddle.CPUPlace()
exe = fluid.Executor(place) exe = paddle.Executor(place)
result, = exe.run(fetch_list=[zeros]) result, = exe.run(fetch_list=[zeros])
expected_result = np.zeros(10, dtype="float64") expected_result = np.zeros(10, dtype="float64")
self.assertEqual((result == expected_result).all(), True) self.assertEqual((result == expected_result).all(), True)
with paddle.program_guard(fluid.Program()): with paddle.program_guard(Program()):
zeros = paddle.zeros(shape=[10], dtype="int64") zeros = paddle.zeros(shape=[10], dtype="int64")
place = fluid.CPUPlace() place = paddle.CPUPlace()
exe = fluid.Executor(place) exe = paddle.Executor(place)
result, = exe.run(fetch_list=[zeros]) result, = exe.run(fetch_list=[zeros])
expected_result = np.zeros(10, dtype="int64") expected_result = np.zeros(10, dtype="int64")
self.assertEqual((result == expected_result).all(), True) self.assertEqual((result == expected_result).all(), True)
with paddle.program_guard(fluid.Program()): with program_guard(Program()):
zeros = paddle.zeros(shape=[10], dtype="int64") zeros = paddle.zeros(shape=[10], dtype="int64")
place = fluid.CPUPlace() place = paddle.CPUPlace()
exe = fluid.Executor(place) exe = paddle.Executor(place)
result, = exe.run(fetch_list=[zeros])
expected_result = np.zeros(10, dtype="int64")
self.assertEqual((result == expected_result).all(), True)
with program_guard(Program()):
out_np = np.zeros(shape=(1), dtype='float32')
out = paddle.zeros(shape=[1], dtype="float32")
place = paddle.CPUPlace()
exe = paddle.Executor(place)
result = exe.run(fetch_list=[out])
self.assertEqual((result == out_np).all(), True)
def test_fluid_out(self):
with program_guard(Program()):
zeros = fluid.layers.zeros(shape=[10], dtype="int64")
place = paddle.CPUPlace()
exe = paddle.Executor(place)
result, = exe.run(fetch_list=[zeros]) result, = exe.run(fetch_list=[zeros])
expected_result = np.zeros(10, dtype="int64") expected_result = np.zeros(10, dtype="int64")
self.assertEqual((result == expected_result).all(), True) self.assertEqual((result == expected_result).all(), True)
......
...@@ -75,6 +75,9 @@ def full_like(x, fill_value, dtype=None, name=None): ...@@ -75,6 +75,9 @@ def full_like(x, fill_value, dtype=None, name=None):
Returns: Returns:
out(Variable): The Tensor variable storing the output. out(Variable): The Tensor variable storing the output.
Raises:
TypeError: The dtype must be one of bool, float16, float32, float64, int32, int64 and None.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -84,7 +87,8 @@ def full_like(x, fill_value, dtype=None, name=None): ...@@ -84,7 +87,8 @@ def full_like(x, fill_value, dtype=None, name=None):
paddle.enable_imperative() # Now we are in imperative mode paddle.enable_imperative() # Now we are in imperative mode
input = paddle.full(shape=[2, 3], fill_value=0.0, dtype='float32', name='input') input = paddle.full(shape=[2, 3], fill_value=0.0, dtype='float32', name='input')
output = paddle.full_like(input, 2.0) output = paddle.full_like(input, 2.0)
#output result : [array([[2., 2., 2.], [2., 2., 2.]], dtype=float32)] # [[2. 2. 2.]
# [2. 2. 2.]]
""" """
if dtype is None: if dtype is None:
...@@ -99,7 +103,7 @@ def full_like(x, fill_value, dtype=None, name=None): ...@@ -99,7 +103,7 @@ def full_like(x, fill_value, dtype=None, name=None):
helper = LayerHelper("full_like", **locals()) helper = LayerHelper("full_like", **locals())
check_dtype(dtype, 'dtype', check_dtype(dtype, 'dtype',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'full_like/zeros_like') 'full_like')
out = helper.create_variable_for_type_inference(dtype=dtype) out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op( helper.append_op(
...@@ -112,7 +116,7 @@ def full_like(x, fill_value, dtype=None, name=None): ...@@ -112,7 +116,7 @@ def full_like(x, fill_value, dtype=None, name=None):
return out return out
def ones(shape, dtype=None, out=None, device=None): def ones(shape, dtype=None, name=None):
""" """
:alias_main: paddle.ones :alias_main: paddle.ones
:alias: paddle.ones,paddle.tensor.ones,paddle.tensor.creation.ones :alias: paddle.ones,paddle.tensor.ones,paddle.tensor.creation.ones
...@@ -120,38 +124,44 @@ def ones(shape, dtype=None, out=None, device=None): ...@@ -120,38 +124,44 @@ def ones(shape, dtype=None, out=None, device=None):
The OP creates a tensor of specified :attr:`shape` and :attr:`dtype`, and fills it with 1. The OP creates a tensor of specified :attr:`shape` and :attr:`dtype`, and fills it with 1.
Args: Args:
shape(tuple|list): Shape of output tensor. shape(tuple|list|Variable): Shape of output tensor, the data type of shape is int32 or int64.
dtype(np.dtype|core.VarDesc.VarType|str): Data type of output tensor, it supports dtype(np.dtype|core.VarDesc.VarType|str, optional): Data type of output tensor, it supports
bool, float16, float32, float64, int32 and int64. bool, float16, float32, float64, int32 and int64. Default: if None, the data type is 'float32'.
out(Variable, optional): Optional output which can be any created name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`
Variable that meets the requirements to store the result of operation.
if out is None, a new Varibale will be create to store the result.
device(str, optional): Which device to run the operator. The :attr:`device` must be
None,'cpu', 'gpu'. If :attr:`device` is None, it will be choose the device that the user set in
the paddle program. Default value is False.
Returns: Returns:
Variable: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 1. Variable: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 1.
Raises:
TypeError: The dtype must be one of bool, float16, float32, float64, int32, int64 and None
and the data type of out Tensor must be the same as the dtype.
TypeError: The `shape` must be one of list, tuple and Variable.
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
data = paddle.ones(shape=[3, 2], dtype='float32') # [[1., 1.], [1., 1.], [1., 1.]] paddle.enable_imperative()
data = paddle.ones(shape=[2, 2], dtype='float32', device='cpu') # [[1., 1.], [1., 1.]]
#default dtype for ones OP
data1 = paddle.ones(shape=[3, 2])
# [[1. 1.]
# [1. 1.]
# [1. 1.]]
data2 = paddle.ones(shape=[2, 2], dtype='int32')
# [[1 1]
# [1 1]]
#shape is a Variable
shape = paddle.fill_constant(shape=[2], dtype='int32', value=2)
data3 = paddle.ones(shape=shape, dtype='int32')
# [[1 1]
# [1 1]]
""" """
check_dtype(dtype, 'create data type', if dtype is None:
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], dtype = 'float32'
'zeros') return fill_constant(value=1.0, shape=shape, dtype=dtype, name=name)
if device is not None:
if device not in ['cpu', 'gpu']:
raise ValueError(
"The value of 'device' in zeros_op must be cpu or gpu, but received %s."
% (device))
with fluid.device_guard(device):
return fill_constant(value=1.0, shape=shape, dtype=dtype, out=out)
return fill_constant(value=1.0, shape=shape, dtype=dtype, out=out)
def ones_like(input, dtype=None, device=None, name=None): def ones_like(input, dtype=None, device=None, name=None):
...@@ -366,7 +376,7 @@ def full(shape, fill_value, dtype=None, name=None): ...@@ -366,7 +376,7 @@ def full(shape, fill_value, dtype=None, name=None):
Raises: Raises:
TypeError: The `dtype` must be one of None, bool, float16, float32, float64, int32 and int64. TypeError: The `dtype` must be one of None, bool, float16, float32, float64, int32 and int64.
TypeError: The `shape` must be one of Variable, list tuple. TypeError: The `shape` must be one of Variable, list and tuple.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -374,23 +384,28 @@ def full(shape, fill_value, dtype=None, name=None): ...@@ -374,23 +384,28 @@ def full(shape, fill_value, dtype=None, name=None):
import paddle import paddle
paddle.enable_imperative() # Now we are in imperative mode paddle.enable_imperative() # Now we are in imperative mode
data1 = paddle.full(shape=[2,1], fill_value=0, dtype='int64') # data1=[[0],[0]] data1 = paddle.full(shape=[2,1], fill_value=0, dtype='int64')
#[[0]
# [0]]
# attr shape is a list which contains Variable Tensor. # attr shape is a list which contains Variable Tensor.
positive_2 = paddle.fill_constant([1], "int32", 2) positive_2 = paddle.fill_constant([1], "int32", 2)
data3 = paddle.full(shape=[1, positive_2], dtype='float32', fill_value=1.5) # data3=[1.5, 1.5] data3 = paddle.full(shape=[1, positive_2], dtype='float32', fill_value=1.5)
# [[1.5 1.5]]
# attr shape is an Variable Tensor. # attr shape is an Variable Tensor.
shape = paddle.fill_constant([2], "int32", 2) # shape=[2,2] shape = paddle.fill_constant([2], "int32", 2)
data4 = paddle.full(shape=shape, dtype='bool', fill_value=True) # data4=[[True,True],[True,True]] data4 = paddle.full(shape=shape, dtype='bool', fill_value=True)
# [[True True]
# [True True]]
# attr value is an Variable Tensor. # attr fill_value is an Variable Tensor.
val = paddle.fill_constant([1], "float32", 2.0) # val=[2.0] val = paddle.fill_constant([1], "float32", 2.0)
data5 = paddle.full(shape=[2,1], fill_value=val, dtype='float32') #data5=[[2.0],[2.0]] data5 = paddle.full(shape=[2,1], fill_value=val, dtype='float32')
# [[2.0]
# [2.0]]
""" """
helper = LayerHelper("full", **locals())
if dtype is None: if dtype is None:
dtype = 'float32' dtype = 'float32'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册