Unverified commit b463dff4, authored by zhiboniu, committed by GitHub

new API inner&outer (#37706)

Parent 42cf2bee
...@@ -245,6 +245,8 @@ from .tensor.math import diff # noqa: F401
from .tensor.math import angle # noqa: F401
from .tensor.math import fmax # noqa: F401
from .tensor.math import fmin # noqa: F401
from .tensor.math import inner # noqa: F401
from .tensor.math import outer # noqa: F401
from .tensor.random import bernoulli # noqa: F401
from .tensor.random import poisson # noqa: F401
...@@ -500,6 +502,8 @@ __all__ = [ # noqa
'lgamma',
'lerp',
'erfinv',
'inner',
'outer',
'square',
'divide',
'ceil',
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
from paddle.static import Program, program_guard
class TestMultiplyApi(unittest.TestCase):
def _run_static_graph_case(self, x_data, y_data):
with program_guard(Program(), Program()):
paddle.enable_static()
x = paddle.static.data(
name='x', shape=x_data.shape, dtype=x_data.dtype)
y = paddle.static.data(
name='y', shape=y_data.shape, dtype=y_data.dtype)
res = paddle.inner(x, y)
            place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
                else paddle.CPUPlace()
exe = paddle.static.Executor(place)
outs = exe.run(paddle.static.default_main_program(),
feed={'x': x_data,
'y': y_data},
fetch_list=[res])
res = outs[0]
return res
def _run_dynamic_graph_case(self, x_data, y_data):
paddle.disable_static()
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
res = paddle.inner(x, y)
return res.numpy()
def test_multiply(self):
np.random.seed(7)
# test static computation graph: 3-d array
x_data = np.random.rand(2, 10, 10).astype(np.float64)
y_data = np.random.rand(2, 5, 10).astype(np.float64)
res = self._run_static_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.inner(x_data, y_data)))
# test static computation graph: 2-d array
x_data = np.random.rand(200, 5).astype(np.float64)
y_data = np.random.rand(50, 5).astype(np.float64)
res = self._run_static_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.inner(x_data, y_data)))
# test static computation graph: 1-d array
x_data = np.random.rand(50).astype(np.float64)
y_data = np.random.rand(50).astype(np.float64)
res = self._run_static_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.inner(x_data, y_data)))
# test dynamic computation graph: 3-d array
x_data = np.random.rand(5, 10, 10).astype(np.float64)
y_data = np.random.rand(2, 10).astype(np.float64)
res = self._run_dynamic_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.inner(x_data, y_data)))
# test dynamic computation graph: 2-d array
x_data = np.random.rand(20, 50).astype(np.float64)
y_data = np.random.rand(50).astype(np.float64)
res = self._run_dynamic_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.inner(x_data, y_data)))
# test dynamic computation graph: Scalar
x_data = np.random.rand(20, 10).astype(np.float32)
y_data = np.random.rand(1).astype(np.float32).item()
res = self._run_dynamic_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.inner(x_data, y_data)))
# test dynamic computation graph: 2-d array Complex
        x_data = np.random.rand(20, 50).astype(np.float64) + \
            1J * np.random.rand(20, 50).astype(np.float64)
        y_data = np.random.rand(50).astype(np.float64) + \
            1J * np.random.rand(50).astype(np.float64)
res = self._run_dynamic_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.inner(x_data, y_data)))
# test dynamic computation graph: 3-d array Complex
        x_data = np.random.rand(5, 10, 10).astype(np.float64) + \
            1J * np.random.rand(5, 10, 10).astype(np.float64)
        y_data = np.random.rand(2, 10).astype(np.float64) + \
            1J * np.random.rand(2, 10).astype(np.float64)
res = self._run_dynamic_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.inner(x_data, y_data)))
class TestMultiplyError(unittest.TestCase):
def test_errors(self):
# test static computation graph: dtype can not be int8
paddle.enable_static()
with program_guard(Program(), Program()):
x = paddle.static.data(name='x', shape=[100], dtype=np.int8)
y = paddle.static.data(name='y', shape=[100], dtype=np.int8)
self.assertRaises(TypeError, paddle.inner, x, y)
# test static computation graph: inputs must be broadcastable
with program_guard(Program(), Program()):
x = paddle.static.data(name='x', shape=[20, 50], dtype=np.float64)
y = paddle.static.data(name='y', shape=[20], dtype=np.float64)
self.assertRaises(ValueError, paddle.inner, x, y)
np.random.seed(7)
# test dynamic computation graph: dtype can not be int8
paddle.disable_static()
x_data = np.random.randn(200).astype(np.int8)
y_data = np.random.randn(200).astype(np.int8)
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
self.assertRaises(RuntimeError, paddle.inner, x, y)
# test dynamic computation graph: inputs must be broadcastable
x_data = np.random.rand(20, 5)
y_data = np.random.rand(10, 2)
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
self.assertRaises(ValueError, paddle.inner, x, y)
# test dynamic computation graph: dtype must be same
x_data = np.random.randn(200).astype(np.float32)
y_data = np.random.randn(200).astype(np.float64)
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
self.assertRaises(ValueError, paddle.inner, x, y)
# test dynamic computation graph: dtype must be Tensor type
x_data = np.random.randn(200).astype(np.float64)
y_data = np.random.randn(200).astype(np.float64)
y = paddle.to_tensor(y_data)
self.assertRaises(ValueError, paddle.inner, x_data, y)
# test dynamic computation graph: dtype must be Tensor type
x_data = np.random.randn(200).astype(np.float64)
y_data = np.random.randn(200).astype(np.float64)
x = paddle.to_tensor(x_data)
self.assertRaises(ValueError, paddle.inner, x, y_data)
# test dynamic computation graph: dtype must be Tensor type
x_data = np.random.randn(200).astype(np.float32)
y_data = np.random.randn(200).astype(np.float32)
self.assertRaises(ValueError, paddle.inner, x_data, y_data)
if __name__ == '__main__':
unittest.main()
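
Reviewer note: the expected shapes in TestMultiplyApi follow NumPy's inner-product
semantics, where the shared last axis is contracted and the result shape is
x.shape[:-1] + y.shape[:-1]. A minimal, self-contained NumPy sketch of that rule
(editor's illustration only; the shapes below are arbitrary):

import numpy as np

x = np.random.rand(5, 10, 10)            # last dim is 10
y = np.random.rand(2, 10)                # last dim matches x's
res = np.inner(x, y)                     # sum product over the last axes
assert res.shape == x.shape[:-1] + y.shape[:-1]  # (5, 10, 2)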
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
from paddle.static import Program, program_guard
class TestMultiplyApi(unittest.TestCase):
def _run_static_graph_case(self, x_data, y_data):
with program_guard(Program(), Program()):
paddle.enable_static()
x = paddle.static.data(
name='x', shape=x_data.shape, dtype=x_data.dtype)
y = paddle.static.data(
name='y', shape=y_data.shape, dtype=y_data.dtype)
res = paddle.outer(x, y)
            place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
                else paddle.CPUPlace()
exe = paddle.static.Executor(place)
outs = exe.run(paddle.static.default_main_program(),
feed={'x': x_data,
'y': y_data},
fetch_list=[res])
res = outs[0]
return res
def _run_dynamic_graph_case(self, x_data, y_data):
paddle.disable_static()
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
res = paddle.outer(x, y)
return res.numpy()
def test_multiply(self):
np.random.seed(7)
# test static computation graph: 3-d array
x_data = np.random.rand(2, 10, 10).astype(np.float64)
y_data = np.random.rand(2, 5, 10).astype(np.float64)
res = self._run_static_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.outer(x_data, y_data)))
# test static computation graph: 2-d array
x_data = np.random.rand(200, 5).astype(np.float64)
y_data = np.random.rand(50, 5).astype(np.float64)
res = self._run_static_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.outer(x_data, y_data)))
# test static computation graph: 1-d array
x_data = np.random.rand(50).astype(np.float64)
y_data = np.random.rand(50).astype(np.float64)
res = self._run_static_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.outer(x_data, y_data)))
# test dynamic computation graph: 3-d array
x_data = np.random.rand(5, 10, 10).astype(np.float64)
y_data = np.random.rand(2, 10).astype(np.float64)
res = self._run_dynamic_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.outer(x_data, y_data)))
# test dynamic computation graph: 2-d array
x_data = np.random.rand(20, 50).astype(np.float64)
y_data = np.random.rand(50).astype(np.float64)
res = self._run_dynamic_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.outer(x_data, y_data)))
# test dynamic computation graph: Scalar
x_data = np.random.rand(20, 10).astype(np.float32)
y_data = np.random.rand(1).astype(np.float32).item()
res = self._run_dynamic_graph_case(x_data, y_data)
        self.assertTrue(np.allclose(res, np.outer(x_data, y_data), rtol=1e-4))
# test dynamic computation graph: 2-d array Complex
        x_data = np.random.rand(20, 50).astype(np.float64) + \
            1J * np.random.rand(20, 50).astype(np.float64)
        y_data = np.random.rand(50).astype(np.float64) + \
            1J * np.random.rand(50).astype(np.float64)
res = self._run_dynamic_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.outer(x_data, y_data)))
# test dynamic computation graph: 3-d array Complex
        x_data = np.random.rand(5, 10, 10).astype(np.float64) + \
            1J * np.random.rand(5, 10, 10).astype(np.float64)
        y_data = np.random.rand(2, 10).astype(np.float64) + \
            1J * np.random.rand(2, 10).astype(np.float64)
res = self._run_dynamic_graph_case(x_data, y_data)
self.assertTrue(np.allclose(res, np.outer(x_data, y_data)))
class TestMultiplyError(unittest.TestCase):
def test_errors(self):
# test static computation graph: dtype can not be int8
paddle.enable_static()
with program_guard(Program(), Program()):
x = paddle.static.data(name='x', shape=[100], dtype=np.int8)
y = paddle.static.data(name='y', shape=[100], dtype=np.int8)
self.assertRaises(TypeError, paddle.outer, x, y)
np.random.seed(7)
# test dynamic computation graph: dtype can not be int8
paddle.disable_static()
x_data = np.random.randn(200).astype(np.int8)
y_data = np.random.randn(200).astype(np.int8)
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
self.assertRaises(RuntimeError, paddle.outer, x, y)
# test dynamic computation graph: dtype must be same
x_data = np.random.randn(200).astype(np.float32)
y_data = np.random.randn(200).astype(np.float64)
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
self.assertRaises(ValueError, paddle.outer, x, y)
# test dynamic computation graph: dtype must be Tensor type
x_data = np.random.randn(200).astype(np.float64)
y_data = np.random.randn(200).astype(np.float64)
y = paddle.to_tensor(y_data)
self.assertRaises(ValueError, paddle.outer, x_data, y)
# test dynamic computation graph: dtype must be Tensor type
x_data = np.random.randn(200).astype(np.float32)
y_data = np.random.randn(200).astype(np.float32)
x = paddle.to_tensor(x_data)
self.assertRaises(ValueError, paddle.outer, x, y_data)
# test dynamic computation graph: dtype must be Tensor type
x_data = np.random.randn(200).astype(np.float32)
y_data = np.random.randn(200).astype(np.float32)
self.assertRaises(ValueError, paddle.outer, x_data, y_data)
if __name__ == '__main__':
unittest.main()
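
Reviewer note: np.outer flattens both operands before multiplying, so whatever
the input ranks, the reference result has shape (x.size, y.size); that is why the
3-D cases above compare cleanly against it. A minimal sketch (editor's
illustration only):

import numpy as np

x = np.random.rand(5, 10, 10)
y = np.random.rand(2, 10)
res = np.outer(x, y)                     # both operands are flattened to 1-D
assert res.shape == (x.size, y.size)     # (500, 20)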
...@@ -215,6 +215,8 @@ from .math import diff # noqa: F401
from .math import angle # noqa: F401
from .math import fmax # noqa: F401
from .math import fmin # noqa: F401
from .math import inner # noqa: F401
from .math import outer # noqa: F401
from .random import multinomial # noqa: F401
from .random import standard_normal # noqa: F401
...@@ -323,6 +325,8 @@ tensor_method_func = [ #noqa
'fmax',
'fmin',
'mm',
'inner',
'outer',
'divide',
'floor_divide',
'remainder',
...
...@@ -1195,6 +1195,129 @@ def addmm(input, x, y, beta=1.0, alpha=1.0, name=None):
    return out
def inner(x, y, name=None):
"""
    Inner product of two input Tensors.

    This is the ordinary inner product for 1-D Tensors; in higher dimensions it is a sum product over the last axes.

    Args:
        x (Tensor): An N-D Tensor or a Scalar Tensor. If it is not a scalar Tensor, its last dimension must match y's.
        y (Tensor): An N-D Tensor or a Scalar Tensor. If it is not a scalar Tensor, its last dimension must match x's.
        name (str, optional): The default value is None. Normally there is no need for
            the user to set this property. For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: The inner-product Tensor. The output shape is x.shape[:-1] + y.shape[:-1].

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.arange(1, 7).reshape((2, 3)).astype('float32')
            y = paddle.arange(1, 10).reshape((3, 3)).astype('float32')
            out = paddle.inner(x, y)

            print(out)
            #        ([[14, 32, 50],
            #         [32, 77, 122]])
"""
if x.size == 1 or y.size == 1:
return multiply(x, y)
else:
xshape = x.shape
yshape = y.shape
        dstshape = list(xshape[:-1]) + list(yshape[:-1])
        if len(dstshape) == 0:
dstshape = [1]
nx = x.reshape((-1, xshape[-1]))
ny = y.reshape((-1, yshape[-1]))
if in_dygraph_mode():
return _C_ops.matmul_v2(nx, ny.T).reshape(dstshape)
def __check_input(x, y):
var_names = {'x': x, 'y': y}
for name, val in var_names.items():
check_variable_and_dtype(val, name,
['float16', 'float32', 'float64'], 'inner')
x_shape = list(xshape)
y_shape = list(yshape)
# check the inner 2 dimensions
if x_shape[-1] != y_shape[-1]:
if not ((x_shape[-1] == -1) or (y_shape[-1] == -1)):
                    raise ValueError(
                        "Input X's last dim must be equal to Y's last dim "
                        "for the inner product, but received X's shape: %s, "
                        "Y's shape: %s\n" % (x_shape, y_shape))
__check_input(nx, ny)
helper = LayerHelper('inner', **locals())
out = helper.create_variable_for_type_inference(dtype=nx.dtype)
helper.append_op(
type='matmul_v2', inputs={'X': nx,
'Y': ny.T}, outputs={'Out': out})
return out.reshape(dstshape)
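
The reduction above rewrites inner as a single 2-D matmul on flattened views and
then restores the broadcast shape. A NumPy sketch of the same shape algebra
(editor's illustration, not part of the commit; inner_ref is a hypothetical name):

import numpy as np

def inner_ref(x, y):
    # Collapse all leading axes; keep the shared last axis for contraction.
    dstshape = list(x.shape[:-1]) + list(y.shape[:-1]) or [1]
    nx = x.reshape(-1, x.shape[-1])
    ny = y.reshape(-1, y.shape[-1])
    return (nx @ ny.T).reshape(dstshape)

x = np.random.rand(2, 3, 4)
y = np.random.rand(5, 4)
assert np.allclose(inner_ref(x, y), np.inner(x, y))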
def outer(x, y, name=None):
"""
    Outer product of two Tensors.

    Inputs are flattened if they are not already 1-dimensional.

    Args:
        x (Tensor): An N-D Tensor or a Scalar Tensor.
        y (Tensor): An N-D Tensor or a Scalar Tensor.
        name (str, optional): The default value is None. Normally there is no need for
            the user to set this property. For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: The outer-product Tensor.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.arange(1, 4).astype('float32')
            y = paddle.arange(1, 6).astype('float32')
            out = paddle.outer(x, y)

            print(out)
            #        ([[1, 2, 3, 4, 5],
            #         [2, 4, 6, 8, 10],
            #         [3, 6, 9, 12, 15]])
"""
nx = x.reshape((-1, 1))
ny = y.reshape((1, -1))
if in_dygraph_mode():
return _C_ops.matmul_v2(nx, ny)
def __check_input(x, y):
var_names = {'x': x, 'y': y}
for name, val in var_names.items():
            check_variable_and_dtype(
                val, name, ['float16', 'float32', 'float64'], 'outer')
__check_input(nx, ny)
helper = LayerHelper('outer', **locals())
out = helper.create_variable_for_type_inference(dtype=nx.dtype)
helper.append_op(
type='matmul_v2', inputs={'X': nx,
'Y': ny}, outputs={'Out': out})
return out
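
The implementation flattens x into a column and y into a row so that a single
matmul yields the outer product. A NumPy sketch of that strategy (editor's
illustration, not part of the commit; outer_ref is a hypothetical name):

import numpy as np

def outer_ref(x, y):
    nx = x.reshape(-1, 1)    # column vector, shape (x.size, 1)
    ny = y.reshape(1, -1)    # row vector, shape (1, y.size)
    return nx @ ny           # shape (x.size, y.size)

x = np.random.rand(2, 3)
y = np.random.rand(4)
assert np.allclose(outer_ref(x, y), np.outer(x, y))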
def logsumexp(x, axis=None, keepdim=False, name=None):
    r"""
    This OP calculates the log of the sum of exponentials of ``x`` along ``axis`` .
...