未验证 提交 4c675a45 编写于 作者: L LutaoChu 提交者: GitHub

Add 2 new ops: paddle.tensor.div() and paddle.tensor.add() to API2.0(#23352)

* add new op paddle.tensor.div(x, y, out=None, name=None) 
* add gpu and dygraph unittests.
* Performance optimization: scale op is not called when alpha=1. 
* op error message optimization.
上级 036121b7
......@@ -133,8 +133,8 @@ from .tensor.math import tanh #DEFINE_ALIAS
# from .tensor.math import max #DEFINE_ALIAS
# from .tensor.math import min #DEFINE_ALIAS
# from .tensor.math import mm #DEFINE_ALIAS
# from .tensor.math import div #DEFINE_ALIAS
# from .tensor.math import add #DEFINE_ALIAS
from .tensor.math import div #DEFINE_ALIAS
from .tensor.math import add #DEFINE_ALIAS
# from .tensor.math import atan #DEFINE_ALIAS
# from .tensor.math import logsumexp #DEFINE_ALIAS
# from .tensor.math import inverse #DEFINE_ALIAS
......
......@@ -15,6 +15,7 @@
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from op_test import OpTest, skip_check_grad_ci
import paddle.fluid as fluid
......@@ -380,5 +381,104 @@ class TestElementwiseAddOpError(unittest.TestCase):
self.assertRaises(TypeError, fluid.layers.elementwise_add, x2, y2)
class TestAddOp(unittest.TestCase):
def test_out(self):
with fluid.program_guard(fluid.Program()):
x = fluid.data(name="x", shape=[3], dtype="float32")
y = fluid.data(name='y', shape=[3], dtype='float32')
res = fluid.data(name="output", shape=[3], dtype="float32")
y_1 = paddle.add(x, y, out=res)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
data1 = np.array([2, 3, 4], dtype='float32')
data2 = np.array([1, 5, 2], dtype='float32')
np_res, np_y_1 = exe.run(feed={'x': data1,
'y': data2},
fetch_list=[res, y_1])
self.assertEqual((np_res == np_y_1).all(), True)
def test_out_gpu(self):
if not fluid.core.is_compiled_with_cuda():
return
with fluid.program_guard(fluid.Program()):
x = fluid.data(name="x", shape=[3], dtype="float32")
y = fluid.data(name='y', shape=[3], dtype='float32')
res = fluid.data(name="output", shape=[3], dtype="float32")
y_1 = paddle.add(x, y, out=res)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
data1 = np.array([2, 3, 4], dtype='float32')
data2 = np.array([1, 5, 2], dtype='float32')
np_res, np_y_1 = exe.run(feed={'x': data1,
'y': data2},
fetch_list=[res, y_1])
self.assertEqual((np_res == np_y_1).all(), True)
def test_name(self):
with fluid.program_guard(fluid.Program()):
x = fluid.data(name="x", shape=[2, 3], dtype="float32")
y = fluid.data(name='y', shape=[2, 3], dtype='float32')
y_1 = paddle.add(x, y, name='add_res')
self.assertEqual(('add_res' in y_1.name), True)
def test_alpha(self):
with fluid.program_guard(fluid.Program()):
def gen_data():
return {
"x": np.array([2, 3, 4]).astype('float32'),
"y": np.array([1, 5, 2]).astype('float32')
}
x = fluid.data(name="x", shape=[3], dtype='float32')
y = fluid.data(name="y", shape=[3], dtype='float32')
z = paddle.add(x, y, alpha=10)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(), fetch_list=[z.name])
z_expected = np.array([12., 53., 24.])
self.assertEqual((z_value == z_expected).all(), True)
def test_alpha_gpu(self):
if not fluid.core.is_compiled_with_cuda():
return
with fluid.program_guard(fluid.Program()):
def gen_data():
return {
"x": np.array([2, 3, 4]).astype('float32'),
"y": np.array([1, 5, 2]).astype('float32')
}
x = fluid.data(name="x", shape=[3], dtype='float32')
y = fluid.data(name="y", shape=[3], dtype='float32')
z = paddle.add(x, y, alpha=-0.5)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(), fetch_list=[z.name])
z_expected = np.array([1.5, 0.5, 3.])
self.assertEqual((z_value == z_expected).all(), True)
def test_dygraph(self):
with fluid.dygraph.guard():
np_x = np.array([2, 3, 4]).astype('float64')
np_y = np.array([1, 5, 2]).astype('float64')
x = fluid.dygraph.to_variable(np_x)
y = fluid.dygraph.to_variable(np_y)
z = paddle.add(x, y, alpha=-0.5)
np_z = z.numpy()
z_expected = np.array([1.5, 0.5, 3.])
self.assertEqual((np_z == z_expected).all(), True)
if __name__ == '__main__':
unittest.main()
......@@ -15,6 +15,8 @@
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from op_test import OpTest, skip_check_grad_ci
......@@ -225,5 +227,64 @@ class TestElementwiseDivOpFp16(ElementwiseDivOp):
['X'], 'Out', max_relative_error=1, no_grad_set=set('Y'))
class TestDivOp(unittest.TestCase):
def test_out(self):
with fluid.program_guard(fluid.Program()):
x = fluid.data(name="x", shape=[3], dtype="float32")
y = fluid.data(name='y', shape=[3], dtype='float32')
res = fluid.data(name="output", shape=[3], dtype="float32")
y_1 = paddle.div(x, y, out=res)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
data1 = np.array([2, 3, 4], dtype='float32')
data2 = np.array([1, 5, 2], dtype='float32')
np_res, np_y_1 = exe.run(feed={'x': data1,
'y': data2},
fetch_list=[res, y_1])
self.assertEqual((np_res == np_y_1).all(), True)
def test_out_gpu(self):
if not fluid.core.is_compiled_with_cuda():
return
with fluid.program_guard(fluid.Program()):
x = fluid.data(name="x", shape=[3], dtype="float32")
y = fluid.data(name='y', shape=[3], dtype='float32')
res = fluid.data(name="output", shape=[3], dtype="float32")
y_1 = paddle.div(x, y, out=res)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
data1 = np.array([2, 3, 4], dtype='float32')
data2 = np.array([1, 5, 2], dtype='float32')
np_res, np_y_1 = exe.run(feed={'x': data1,
'y': data2},
fetch_list=[res, y_1])
self.assertEqual((np_res == np_y_1).all(), True)
def test_name(self):
with fluid.program_guard(fluid.Program()):
x = fluid.data(name="x", shape=[2, 3], dtype="float32")
y = fluid.data(name='y', shape=[2, 3], dtype='float32')
y_1 = paddle.div(x, y, name='div_res')
self.assertEqual(('div_res' in y_1.name), True)
def test_dygraph(self):
with fluid.dygraph.guard():
np_x = np.array([2, 3, 4]).astype('float64')
np_y = np.array([1, 5, 2]).astype('float64')
x = fluid.dygraph.to_variable(np_x)
y = fluid.dygraph.to_variable(np_y)
z = paddle.div(x, y)
np_z = z.numpy()
z_expected = np.array([2., 0.6, 2.])
self.assertEqual((np_z == z_expected).all(), True)
if __name__ == '__main__':
unittest.main()
......@@ -109,8 +109,8 @@ from .math import tanh #DEFINE_ALIAS
# from .math import max #DEFINE_ALIAS
# from .math import min #DEFINE_ALIAS
# from .math import mm #DEFINE_ALIAS
# from .math import div #DEFINE_ALIAS
# from .math import add #DEFINE_ALIAS
from .math import div #DEFINE_ALIAS
from .math import add #DEFINE_ALIAS
# from .math import atan #DEFINE_ALIAS
# from .math import logsumexp #DEFINE_ALIAS
# from .math import inverse #DEFINE_ALIAS
......
......@@ -11,14 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
math functions
"""
from __future__ import print_function
import warnings
from ..fluid.framework import OpProtoHolder, core, in_dygraph_mode
from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_variable_and_dtype
from paddle.common_ops_import import *
from ..fluid.framework import core
from ..fluid.layers.layer_function_generator import _generate_doc_string_
# TODO: define math functions
......@@ -67,8 +66,8 @@ __all__ = [
# 'max',
# 'min',
# 'mm',
# 'div',
# 'add',
'div',
'add',
# 'atan',
# 'logsumexp',
# 'inverse',
......@@ -154,3 +153,345 @@ __ops__noattr__ = [
for _OP in set(__ops__noattr__):
globals()[_OP] = generate_op_noattr(_OP)
@dygraph_only
def _elementwise_op_in_dygraph(x,
y,
axis=-1,
act=None,
use_mkldnn=False,
op_name=None):
op = getattr(core.ops, op_name)
out = op(x, y, 'axis', axis, 'use_mkldnn', use_mkldnn)
return dygraph_utils._append_activation_in_dygraph(
out, act, use_mkldnn=use_mkldnn)
def _elementwise_op(helper):
op_type = helper.layer_type
original_op_type = helper.kwargs.get('original_op_type', op_type)
x = helper.kwargs.get('x', None)
y = helper.kwargs.get('y', None)
assert x is not None, 'x cannot be None in {}'.format(original_op_type)
assert y is not None, 'y cannot be None in {}'.format(original_op_type)
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'],
original_op_type)
check_variable_and_dtype(
y, 'y', ['float16', 'float32', 'float64', 'int32', 'int64'],
original_op_type)
axis = helper.kwargs.get('axis', -1)
use_mkldnn = helper.kwargs.get('use_mkldnn', False)
name = helper.kwargs.get('name', None)
out = helper.kwargs.get('out', None)
if out is None:
if name is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False)
helper.append_op(
type=op_type,
inputs={'X': x,
'Y': y},
outputs={'Out': out},
attrs={'axis': axis,
'use_mkldnn': use_mkldnn})
return helper.append_activation(out)
def add(x, y, alpha=1, out=None, name=None):
"""
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
import numpy as np
def gen_data():
return {
"x": np.array([2, 3, 4]).astype('float32'),
"y": np.array([1, 5, 2]).astype('float32')
}
x = fluid.data(name="x", shape=[3], dtype='float32')
y = fluid.data(name="y", shape=[3], dtype='float32')
z1 = paddle.add(x, y)
z2 = paddle.add(x, y, alpha=10)
# z = x + y
place = fluid.CPUPlace()
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(),
fetch_list=[z1.name, z2.name])
print(z_value[0]) # [3., 8., 6.]
print(z_value[1]) # [12. 53. 24.]
.. code-block:: python
import paddle
import paddle.fluid as fluid
import numpy as np
def gen_data():
return {
"x": np.ones((2, 3, 4, 5)).astype('float32'),
"y": np.zeros((4, 5)).astype('float32')
}
x = fluid.data(name="x", shape=[2, 3, 4, 5], dtype='float32')
y = fluid.data(name="y", shape=[4, 5], dtype='float32')
z = paddle.add(x, y, name='z')
# z = x + y
place = fluid.CPUPlace()
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(),
fetch_list=[z.name])
print(z_value[0])
print(z_value[0].shape) # z.shape=[2,3,4,5]
.. code-block:: python
import paddle
import paddle.fluid as fluid
import numpy as np
def gen_data():
return {
"x": np.random.randint(1, 5, size=[2, 3, 4, 5]).astype('float32'),
"y": np.random.randint(1, 5, size=[5]).astype('float32')
}
x = fluid.data(name="x", shape=[2,3,4,5], dtype='float32')
y = fluid.data(name="y", shape=[5], dtype='float32')
z = paddle.add(x, y)
# z = x / y
place = fluid.CPUPlace()
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(),
fetch_list=[z.name])
print(z_value[0])
print(z_value[0].shape) # z.shape=[2,3,4,5]
.. code-block:: python
import paddle
import paddle.fluid as fluid
import numpy as np
x = fluid.data(name="x", shape=[3], dtype="float32")
y = fluid.data(name='y', shape=[3], dtype='float32')
output = fluid.data(name="output", shape=[3], dtype="float32")
z = paddle.add(x, y, out=output)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
data1 = np.array([2, 3, 4], dtype='float32')
data2 = np.array([1, 5, 2], dtype='float32')
z_value = exe.run(feed={'x': data1,
'y': data2},
fetch_list=[z])
print(z_value[0]) # [3. 8. 6.]
.. code-block:: python
import paddle
import paddle.fluid as fluid
import numpy as np
with fluid.dygraph.guard():
np_x = np.array([2, 3, 4]).astype('float64')
np_y = np.array([1, 5, 2]).astype('float64')
x = fluid.dygraph.to_variable(np_x)
y = fluid.dygraph.to_variable(np_y)
z = paddle.add(x, y, alpha=-0.5)
np_z = z.numpy()
print(np_z) # [1.5, 0.5, 3. ]
"""
op_type = 'elementwise_add'
axis = -1
act = None
if alpha != 1:
y = scale(y, scale=alpha)
if in_dygraph_mode():
return _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name=op_type)
original_op_type = 'add'
if name and out:
warnings.warn(
"Both name and out parameters have been set in paddle.tensor.%s, only out will take effect to specify the result storage. "
"You can discard either one to solve this warning." %
original_op_type,
category=UserWarning,
stacklevel=2)
return _elementwise_op(LayerHelper(op_type, **locals()))
def div(x, y, out=None, name=None):
"""
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
import numpy as np
def gen_data():
return {
"x": np.array([2, 3, 4]).astype('float32'),
"y": np.array([1, 5, 2]).astype('float32')
}
x = fluid.data(name="x", shape=[3], dtype='float32')
y = fluid.data(name="y", shape=[3], dtype='float32')
z = paddle.div(x, y)
# z = x / y
place = fluid.CPUPlace()
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(),
fetch_list=[z.name])
print(z_value) # [2., 0.6, 2.]
.. code-block:: python
import paddle
import paddle.fluid as fluid
import numpy as np
def gen_data():
return {
"x": np.ones((2, 3, 4, 5)).astype('float32'),
"y": np.zeros((4, 5)).astype('float32')
}
x = fluid.data(name="x", shape=[2, 3, 4, 5], dtype='float32')
y = fluid.data(name="y", shape=[4, 5], dtype='float32')
z = paddle.div(x, y, name='z')
# z = x / y
place = fluid.CPUPlace()
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(),
fetch_list=[z.name])
print(z_value[0])
print(z_value[0].shape) # z.shape=[2,3,4,5]
.. code-block:: python
import paddle
import paddle.fluid as fluid
import numpy as np
def gen_data():
return {
"x": np.random.randint(1, 5, size=[2, 3, 4, 5]).astype('float32'),
"y": np.random.randint(1, 5, size=[5]).astype('float32')
}
x = fluid.data(name="x", shape=[2,3,4,5], dtype='float32')
y = fluid.data(name="y", shape=[5], dtype='float32')
output = fluid.data(name="output", shape=[2,3,4,5], dtype="float32")
z = paddle.div(x, y, out=output)
# z = x / y
place = fluid.CPUPlace()
exe = fluid.Executor(place)
z_value = exe.run(feed=gen_data(),
fetch_list=[z.name])
print(z_value[0])
print(z_value[0].shape) # z.shape=[2,3,4,5]
.. code-block:: python
import paddle
import paddle.fluid as fluid
import numpy as np
with fluid.dygraph.guard(fluid.CPUPlace()):
np_x = np.array([2, 3, 4]).astype('float64')
np_y = np.array([1, 5, 2]).astype('float64')
x = fluid.dygraph.to_variable(np_x)
y = fluid.dygraph.to_variable(np_y)
z = paddle.div(x, y)
np_z = z.numpy()
print(np_z) # [2., 0.6, 2.]
"""
op_type = 'elementwise_div'
axis = -1
act = None
if in_dygraph_mode():
return _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name=op_type)
original_op_type = 'div'
if name and out:
warnings.warn(
"Both name and out parameters have been set in paddle.tensor.%s, only out will take effect to specify the result storage. "
"You can discard either one to solve this warning." %
original_op_type,
category=UserWarning,
stacklevel=2)
return _elementwise_op(LayerHelper(op_type, **locals()))
for func in [
add,
div,
]:
proto_dict = {'add': 'elementwise_add', 'div': 'elementwise_div'}
op_proto = OpProtoHolder.instance().get_op_proto(proto_dict[func.__name__])
if func.__name__ in ['add']:
additional_args_lines = [
"alpha (int|float, optional): The alpha factor of the input. Default is 1. If alpha is not 1, the equation becomes Out = X + alpha * Y.",
"out (Variable, optinal): The Variable that stores results of the operation. Default is None. If out is None, \
a new Variable will be created to store the results."
,
"name (string, optional): Name of the output. \
Default is None. It's used to print debug info for developers. Details: \
:ref:`api_guide_Name` "
]
else:
additional_args_lines = [
"out (Variable, optinal): The Variable that stores results of the operation. If out is None, \
a new Variable will be created to store the results."
,
"name (string, optional): Name of the output. \
Default is None. It's used to print debug info for developers. Details: \
:ref:`api_guide_Name` "
]
func.__doc__ = _generate_doc_string_(
op_proto,
additional_args_lines=additional_args_lines,
skip_attrs_set={"x_data_format", "y_data_format", "axis"
}) + """\n""" + str(func.__doc__)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册