Unverified    Commit 08e3d9c0    authored by wawltor    committed by GitHub

Add the matmul, elementwise_equal, elementwise_sum ops to API 2.0 (#23437)

* Add the matmul, elementwise_equal, elementwise_sum ops to API 2.0
* Fix the import message in common_ops_import
* Update the test case for mm

Parent d223a249
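For orientation, a minimal usage sketch of the three APIs this commit exposes, written in the static-graph style the new tests below use (all names and shapes here are illustrative, assuming the Paddle version of this PR):

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    with fluid.program_guard(fluid.Program(), fluid.Program()):
        # elementwise comparison -> bool tensor
        label = fluid.layers.assign(np.array([3, 3], dtype="int32"))
        limit = fluid.layers.assign(np.array([3, 2], dtype="int32"))
        eq = paddle.elementwise_equal(x=label, y=limit)    # [True, False]

        # elementwise sum of a list of same-shape tensors
        a = fluid.layers.fill_constant(shape=[2, 3], dtype='int64', value=5)
        b = fluid.layers.fill_constant(shape=[2, 3], dtype='int64', value=3)
        total = paddle.elementwise_sum([a, b])             # [[8, 8, 8], [8, 8, 8]]

        # matrix multiplication
        x = fluid.data(name='x', shape=[2, 3], dtype='float32')
        y = fluid.data(name='y', shape=[3, 2], dtype='float32')
        prod = paddle.mm(x, y)                             # out shape: [2, 2]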
@@ -78,7 +78,7 @@ from .tensor.logic import equal  #DEFINE_ALIAS
 # from .tensor.logic import reduce_all  #DEFINE_ALIAS
 # from .tensor.logic import reduce_any  #DEFINE_ALIAS
 from .tensor.logic import allclose  #DEFINE_ALIAS
-# from .tensor.logic import elementwise_equal  #DEFINE_ALIAS
+from .tensor.logic import elementwise_equal  #DEFINE_ALIAS
 # from .tensor.logic import isnan  #DEFINE_ALIAS
 # from .tensor..tensor import Tensor  #DEFINE_ALIAS
 # from .tensor..tensor import LoDTensor  #DEFINE_ALIAS
@@ -129,10 +129,10 @@ from .tensor.math import sqrt  #DEFINE_ALIAS
 from .tensor.math import sum  #DEFINE_ALIAS
 # from .tensor.math import sums  #DEFINE_ALIAS
 from .tensor.math import tanh  #DEFINE_ALIAS
-# from .tensor.math import elementwise_sum  #DEFINE_ALIAS
+from .tensor.math import elementwise_sum  #DEFINE_ALIAS
 # from .tensor.math import max  #DEFINE_ALIAS
 # from .tensor.math import min  #DEFINE_ALIAS
-# from .tensor.math import mm  #DEFINE_ALIAS
+from .tensor.math import mm  #DEFINE_ALIAS
 from .tensor.math import div  #DEFINE_ALIAS
 from .tensor.math import add  #DEFINE_ALIAS
 # from .tensor.math import atan  #DEFINE_ALIAS
...
@@ -34,6 +34,7 @@ from .. import unique_name
 from functools import reduce
 from .. import core
 from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
+import paddle
 
 __all__ = [
     'fc',
@@ -10155,16 +10156,7 @@ def sum(x):
     # and '__int64' on Windows. They both represent 64-bit integer variables.
     """
-    helper = LayerHelper('sum', **locals())
-    out = helper.create_variable_for_type_inference(
-        dtype=helper.input_dtype('x'))
-    helper.append_op(
-        type='sum',
-        inputs={'X': x},
-        outputs={'Out': out},
-        attrs={'use_mkldnn': False})
-    return out
+    return paddle.elementwise_sum(x)
 
 
 @templatedoc()
...
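Note that `fluid.layers.sum` now simply delegates to the new `paddle.elementwise_sum`; a minimal sketch of the equivalence, assuming the test-style setup used elsewhere in this PR:

    import paddle
    import paddle.fluid as fluid

    with fluid.program_guard(fluid.Program(), fluid.Program()):
        a = fluid.layers.fill_constant(shape=[2, 3], dtype='int64', value=5)
        b = fluid.layers.fill_constant(shape=[2, 3], dtype='int64', value=3)
        old_api = fluid.layers.sum([a, b])        # now routed through the new op
        new_api = paddle.elementwise_sum([a, b])  # added by this PR
        exe = fluid.Executor(fluid.CPUPlace())
        res_old, res_new = exe.run(fetch_list=[old_api, new_api])
        assert (res_old == res_new).all()  # both are [[8, 8, 8], [8, 8, 8]]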
@@ -17,6 +17,8 @@ from __future__ import print_function
 import op_test
 import unittest
 import numpy
+import numpy as np
+import paddle
 import paddle.fluid as fluid
 from paddle.fluid import Program, program_guard
@@ -58,5 +60,26 @@ class TestCompareOpError(unittest.TestCase):
         self.assertRaises(TypeError, fluid.layers.greater_equal, x, y)
 
 
+class API_TestElementwise_Equal(unittest.TestCase):
+    def test_api(self):
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            label = fluid.layers.assign(np.array([3, 3], dtype="int32"))
+            limit = fluid.layers.assign(np.array([3, 2], dtype="int32"))
+            out = paddle.elementwise_equal(x=label, y=limit)
+            place = fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            res, = exe.run(fetch_list=[out])
+        self.assertEqual((res == np.array([True, False])).all(), True)
+
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            label = fluid.layers.assign(np.array([3, 3], dtype="int32"))
+            limit = fluid.layers.assign(np.array([3, 3], dtype="int32"))
+            out = paddle.elementwise_equal(x=label, y=limit)
+            place = fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            res, = exe.run(fetch_list=[out])
+        self.assertEqual((res == np.array([True, True])).all(), True)
+
+
 if __name__ == '__main__':
     unittest.main()
@@ -17,6 +17,7 @@ from __future__ import print_function
 import unittest
 import numpy as np
 from op_test import OpTest
+import paddle
 import paddle.fluid as fluid
 from paddle.fluid import Program, program_guard
@@ -241,5 +242,67 @@ for dim in [4]:
             'transpose_Y': transpose_Y,
         })
 
 
+class API_TestMm(unittest.TestCase):
+    def test_out(self):
+        with fluid.program_guard(fluid.Program()):
+            x = fluid.data(name="x", shape=[3, 2], dtype="float32")
+            y = fluid.data(name='y', shape=[2, 3], dtype='float32')
+            res = fluid.data(name="output", shape=[3, 3], dtype="float32")
+            y_1 = paddle.mm(x, y, out=res)
+            exe = fluid.Executor(fluid.CPUPlace())
+            data1 = np.random.rand(3, 2).astype('float32')
+            data2 = np.random.rand(2, 3).astype('float32')
+            np_res, np_y_1 = exe.run(feed={'x': data1,
+                                           'y': data2},
+                                     fetch_list=[res, y_1])
+        self.assertEqual((np_res == np_y_1).all(), True)
+
+        with fluid.program_guard(fluid.Program()):
+            x = fluid.data(name="x", shape=[2], dtype="float32")
+            y = fluid.data(name='y', shape=[2], dtype='float32')
+            res = fluid.data(name="output", shape=[1], dtype="float32")
+            result = paddle.mm(x, y)
+            exe = fluid.Executor(fluid.CPUPlace())
+            data1 = np.random.rand(2).astype('float32')
+            data2 = np.random.rand(2).astype('float32')
+            np_res = exe.run(feed={'x': data1, 'y': data2}, fetch_list=[result])
+            expected_result = np.matmul(
+                data1.reshape(1, 2), data2.reshape(2, 1))
+        self.assertEqual((np_res == expected_result).all(), True)
+
+
+class API_TestMmError(unittest.TestCase):
+    def test_errors(self):
+        def test_error1():
+            with fluid.program_guard(fluid.Program(), fluid.Program()):
+                data1 = fluid.data(name="data1", shape=[10, 2], dtype="float32")
+                data2 = fluid.data(name="data2", shape=[3, 10], dtype="float32")
+                paddle.mm(data1, data2)
+
+        self.assertRaises(ValueError, test_error1)
+
+        def test_error2():
+            with fluid.program_guard(fluid.Program(), fluid.Program()):
+                data1 = fluid.data(
+                    name="data1", shape=[-1, 10, 2], dtype="float32")
+                data2 = fluid.data(
+                    name="data2", shape=[-1, 2, 10], dtype="float32")
+                paddle.mm(data1, data2)
+
+        test_error2()
+
+        def test_error3():
+            with fluid.program_guard(fluid.Program(), fluid.Program()):
+                data1 = fluid.data(
+                    name="data1", shape=[10, 10, 2], dtype="float32")
+                data2 = fluid.data(
+                    name="data2", shape=[3, 2, 10], dtype="float32")
+                paddle.mm(data1, data2)
+
+        self.assertRaises(ValueError, test_error3)
+
+
 if __name__ == "__main__":
     unittest.main()
@@ -17,6 +17,8 @@ from __future__ import print_function
 import unittest
 import numpy as np
 from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
 import paddle.fluid.core as core
 from paddle.fluid.op import Operator
@@ -223,6 +225,22 @@ def create_test_sum_fp16_class(parent):
     globals()[cls_name] = TestSumFp16Case
 
 
+class API_Test_Elementwise_Sum(unittest.TestCase):
+    def test_api(self):
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            input0 = fluid.layers.fill_constant(
+                shape=[2, 3], dtype='int64', value=5)
+            input1 = fluid.layers.fill_constant(
+                shape=[2, 3], dtype='int64', value=3)
+            expected_result = np.empty((2, 3))
+            expected_result.fill(8)
+            sum_value = paddle.elementwise_sum([input0, input1])
+            exe = fluid.Executor(fluid.CPUPlace())
+            result = exe.run(fetch_list=[sum_value])
+
+        self.assertEqual((result == expected_result).all(), True)
+
+
 create_test_sum_fp16_class(TestSelectedRowsSumOp)
 create_test_sum_fp16_class(TestLoDTensorAndSelectedRowsOp)
...
@@ -54,7 +54,7 @@ from .logic import equal  #DEFINE_ALIAS
 # from .logic import reduce_all  #DEFINE_ALIAS
 # from .logic import reduce_any  #DEFINE_ALIAS
 from .logic import allclose  #DEFINE_ALIAS
-# from .logic import elementwise_equal  #DEFINE_ALIAS
+from .logic import elementwise_equal  #DEFINE_ALIAS
 # from .logic import isnan  #DEFINE_ALIAS
 # from . import Tensor  #DEFINE_ALIAS
 # from . import LoDTensor  #DEFINE_ALIAS
@@ -105,10 +105,10 @@ from .math import sqrt  #DEFINE_ALIAS
 from .math import sum  #DEFINE_ALIAS
 # from .math import sums  #DEFINE_ALIAS
 from .math import tanh  #DEFINE_ALIAS
-# from .math import elementwise_sum  #DEFINE_ALIAS
+from .math import elementwise_sum  #DEFINE_ALIAS
 # from .math import max  #DEFINE_ALIAS
 # from .math import min  #DEFINE_ALIAS
-# from .math import mm  #DEFINE_ALIAS
+from .math import mm  #DEFINE_ALIAS
 from .math import div  #DEFINE_ALIAS
 from .math import add  #DEFINE_ALIAS
 # from .math import atan  #DEFINE_ALIAS
...
@@ -33,7 +33,7 @@ __all__ = [
     # 'reduce_all',
     # 'reduce_any',
     'allclose',
-    # 'elementwise_equal',
+    'elementwise_equal',
     # 'isnan'
 ]
@@ -186,3 +186,40 @@ def allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
         type='allclose', inputs=inputs, outputs=outputs, attrs=attrs)
     return out
+
+
+def elementwise_equal(x, y, name=None):
+    """
+    This layer returns the truth value of :math:`x == y` elementwise.
+
+    Args:
+        x(Variable): Tensor, data type is float32, float64, int32, int64.
+        y(Variable): Tensor, data type is float32, float64, int32, int64.
+        name(str, optional): The default value is None. Normally there is no need for
+            users to set this property. For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Variable: output Tensor, its shape is the same as the input Tensors',
+        and the data type is bool. The result of this op has stop_gradient set to True.
+
+    Examples:
+        .. code-block:: python
+
+          import paddle
+          import paddle.fluid as fluid
+          import numpy as np
+
+          label = fluid.layers.assign(np.array([3, 3], dtype="int32"))
+          limit = fluid.layers.assign(np.array([3, 2], dtype="int32"))
+          out1 = paddle.elementwise_equal(x=label, y=limit)  # out1=[True, False]
+    """
+    helper = LayerHelper("elementwise_equal", **locals())
+    out = helper.create_variable_for_type_inference(dtype='bool')
+    out.stop_gradient = True
+
+    helper.append_op(
+        type='equal',
+        inputs={'X': [x],
+                'Y': [y]},
+        outputs={'Out': [out]},
+        attrs={'force_cpu': False})
+    return out
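For reference, the semantics this wrapper exposes (it reuses the existing `equal` op with a bool output) match NumPy's elementwise equality; a tiny sketch:

    import numpy as np

    x = np.array([3, 3], dtype="int32")
    y = np.array([3, 2], dtype="int32")
    out = (x == y)  # array([ True, False]) -- the same values paddle.elementwise_equal returns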
@@ -63,10 +63,10 @@ __all__ = [
     'sum',
     # 'sums',
     'tanh',
-    # 'elementwise_sum',
+    'elementwise_sum',
     # 'max',
     # 'min',
-    # 'mm',
+    'mm',
     'div',
     'add',
     # 'atan',
@@ -747,3 +747,186 @@ def sum(input, dim=None, dtype=None, keep_dim=False, name=None):
         outputs={'Out': out},
         attrs=attrs)
     return out
+
+
+@templatedoc(op_type="sum")
+def elementwise_sum(inputs, name=None):
+    """
+    ${comment}
+
+    Case 1:
+    ::
+        Input:
+            Input. Shape = [2, 3]
+            Input = [[1, 2, 3],
+                     [4, 5, 6]]
+
+        Output:
+            The output. Shape = [2, 3]
+            Output = [[1, 2, 3],
+                      [4, 5, 6]]
+
+    Case 2:
+    ::
+        Input:
+            First input:
+            Input1. Shape = [2, 3]
+            Input1 = [[1, 2, 3],
+                      [4, 5, 6]]
+
+            The second input:
+            Input2. Shape = [2, 3]
+            Input2 = [[7, 8, 9],
+                      [10, 11, 12]]
+
+        Output:
+            The output. Shape = [2, 3]
+            Output = [[8, 10, 12],
+                      [14, 16, 18]]
+
+    Args:
+        inputs (Variable|list(Variable)): A Variable list. The shape and data type of the list elements should be consistent.
+            Each Variable can be a multi-dimensional Tensor or LoDTensor, and the data types can be: float32, float64, int32, int64.
+        name(str, optional): The default value is None. Normally there is no need for
+            users to set this property. For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Variable: the sum of the input :math:`inputs`. Its shape and data type are consistent with :math:`inputs`.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import paddle.fluid as fluid
+
+            input0 = fluid.layers.fill_constant(shape=[2, 3], dtype='int64', value=5)
+            input1 = fluid.layers.fill_constant(shape=[2, 3], dtype='int64', value=3)
+            sum = paddle.elementwise_sum([input0, input1])
+
+            # You can print out 'sum' via executor.
+            out = fluid.layers.Print(sum, message="the sum of input0 and input1: ")
+            exe = fluid.Executor(fluid.CPUPlace())
+            exe.run(fluid.default_main_program())
+
+            # The printed result is:
+            # 1570701754    the sum of input0 and input1:   The place is:CPUPlace
+            # Tensor[elementwise_sum_0.tmp_0]
+            #     shape: [2,3,]
+            #     dtype: l
+            #     data: 8,8,8,8,8,8,
+
+            # The sum of input0 and input1 is a 2-D Tensor with shape [2, 3].
+            # The dtype is the corresponding C++ data type, which may vary across environments.
+            # E.g., if the data type of the tensor is int64, the corresponding C++ data type is int64_t,
+            # so the dtype value is typeid(int64_t).name(), which is 'x' on MacOS, 'l' on Linux,
+            # and '__int64' on Windows. They all represent 64-bit integer variables.
+    """
+    helper = LayerHelper('elementwise_sum', **locals())
+    out = helper.create_variable_for_type_inference(
+        dtype=helper.input_dtype('inputs'))
+    helper.append_op(
+        type='sum',
+        inputs={'X': inputs},
+        outputs={'Out': out},
+        attrs={'use_mkldnn': False})
+    return out
+
+
+def mm(input, mat2, out=None, name=None):
+    """
+    Applies matrix multiplication to two tensors.
+
+    Currently, the input tensors may have any rank, but when the rank of
+    either input is bigger than 3, the two inputs' ranks must be equal.
+
+    Also note that if the raw tensor :math:`input` or :math:`mat2` is rank-1 and
+    nontransposed, the prepended or appended dimension :math:`1` will be
+    removed after matrix multiplication.
+
+    Args:
+        input (Variable): The input variable which is a Tensor or LoDTensor.
+        mat2 (Variable): The input variable which is a Tensor or LoDTensor.
+        out(Variable, optional): An optional output which can be any created
+            Variable that meets the requirements to store the result of the operation.
+            If out is None, a new Variable will be created to store the result.
+        name(str, optional): The default value is None. Normally there is no need for
+            users to set this property. For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Variable: The product Tensor (or LoDTensor) variable.
+
+    Examples:
+        .. code-block:: python
+
+            # Examples to clarify shapes of the inputs and output
+            # input: [B, ..., M, K], mat2: [B, ..., K, N]
+            # paddle.mm(input, mat2)  # out: [B, ..., M, N]
+
+            # input: [B, M, K], mat2: [B, K, N]
+            # paddle.mm(input, mat2)  # out: [B, M, N]
+
+            # input: [B, M, K], mat2: [K, N]
+            # paddle.mm(input, mat2)  # out: [B, M, N]
+
+            # input: [M, K], mat2: [K, N]
+            # paddle.mm(input, mat2)  # out: [M, N]
+
+            # input: [B, M, K], mat2: [K]
+            # paddle.mm(input, mat2)  # out: [B, M]
+
+            # input: [K], mat2: [K]
+            # paddle.mm(input, mat2)  # out: [1]
+
+            import paddle
+            import paddle.fluid as fluid
+
+            x = fluid.data(name='x', shape=[2, 3], dtype='float32')
+            mat2 = fluid.data(name='mat2', shape=[3, 2], dtype='float32')
+            out = paddle.mm(x, mat2)  # out shape is [2, 2]
+    """
+    if in_dygraph_mode():
+        return core.ops.matmul(input, mat2)
+
+    def __check_input(x, y):
+        var_names = {'x': x, 'y': y}
+        for name, val in var_names.items():
+            check_variable_and_dtype(val, name,
+                                     ['float16', 'float32', 'float64'], 'mm')
+        x_shape = list(x.shape)
+        y_shape = list(y.shape)
+        if len(x_shape) == 1:
+            x_shape = [1] + x_shape
+        if len(y_shape) == 1:
+            y_shape = y_shape + [1]
+
+        # check the inner 2 dimensions
+        if x_shape[-1] != y_shape[-2]:
+            if not ((x_shape[-1] == -1) or (y_shape[-2] == -1)):
+                raise ValueError(
+                    "After performing an optional transpose, Input X's width should be "
+                    "equal to Y's width for multiplication "
+                    "prerequisites. But received X's shape: %s, Y's shape: %s\n"
+                    % (x_shape, y_shape))
+
+        if len(y_shape) > 2 and len(x_shape) > 2:
+            for i, dim_x in enumerate(x_shape[:-2]):
+                # don't check negative (unknown) dimensions
+                if dim_x < 0 or y_shape[i] < 0:
+                    continue
+                if dim_x != y_shape[i]:
+                    raise ValueError(
+                        "When the matrix is larger than 2 dimensions, the higher "
+                        "dimensional values of the two matrices need to be equal. "
+                        "But received x_shape[%d] != y_shape[%d]. X's shape: %s, "
+                        "Y's shape: %s.\n" % (i, i, x_shape, y_shape))
+
+    __check_input(input, mat2)
+
+    helper = LayerHelper('mm', **locals())
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype=input.dtype)
+    helper.append_op(
+        type='matmul', inputs={'X': input,
+                               'Y': mat2}, outputs={'Out': out})
+    return out
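The shape rules that `__check_input` enforces can be cross-checked against NumPy, whose matmul behavior `paddle.mm` mirrors for these cases (a sketch; shapes are illustrative):

    import numpy as np

    # Rank-1 inputs: a 1 is prepended/appended before multiplication and
    # squeezed away afterwards, so [K] x [K] reduces to a single value,
    # matching the docstring's last case and the reshape-based check in the new test.
    a = np.random.rand(2).astype('float32')
    b = np.random.rand(2).astype('float32')
    ref = np.matmul(a.reshape(1, 2), b.reshape(2, 1))  # shape (1, 1), value a.dot(b)

    # Batched inputs: the leading (batch) dimensions must match unless a
    # dimension is unknown (-1), which the validation loop skips.
    x = np.random.rand(4, 3, 2).astype('float32')
    y = np.random.rand(4, 2, 5).astype('float32')
    assert np.matmul(x, y).shape == (4, 3, 5)  # [B, M, K] x [B, K, N] -> [B, M, N]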