未验证 提交 2df74aa6 编写于 作者: H Haohongxiang 提交者: GitHub

Support new API linalg.cond in paddle (#35140)

* Support new API linalg.cond in paddle

* check code style

* check code style

* modify codes

* add docs_eng of linalg.cond

* add svd_norm for linalg.cond

* modify docs_en of cond

* add support for empty input in dynamic mode

* modify set_time of unittest

* update

* modify unittest of cond

* update

* remove cond in paddle.__all__

* pull latest codes

* merge latest codes

* update
上级 07d0b834
......@@ -93,6 +93,7 @@ from .tensor.linalg import dot # noqa: F401
from .tensor.linalg import norm # noqa: F401
from .tensor.linalg import transpose # noqa: F401
from .tensor.linalg import dist # noqa: F401
from .tensor.linalg import cond # noqa: F401
from .tensor.linalg import t # noqa: F401
from .tensor.linalg import cross # noqa: F401
from .tensor.linalg import cholesky # noqa: F401
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.static as static
p_list_n_n = ("fro", "nuc", 1, -1, np.inf, -np.inf)
p_list_m_n = (None, 2, -2)
def test_static_assert_true(self, x_list, p_list):
for p in p_list:
for x in x_list:
with static.program_guard(static.Program(), static.Program()):
input_data = static.data("X", shape=x.shape, dtype=x.dtype)
output = paddle.cond(input_data, p)
exe = static.Executor()
result = exe.run(feed={"X": x}, fetch_list=[output])
expected_output = np.linalg.cond(x, p)
self.assertTrue(np.allclose(result, expected_output))
def test_dygraph_assert_true(self, x_list, p_list):
for p in p_list:
for x in x_list:
input_tensor = paddle.to_tensor(x)
output = paddle.cond(input_tensor, p)
expected_output = np.linalg.cond(x, p)
self.assertTrue(np.allclose(output, expected_output))
def gen_input():
# generate square matrix or batches of square matrices
input_1 = np.random.rand(5, 5).astype('float32')
input_2 = np.random.rand(3, 6, 6).astype('float64')
input_3 = np.random.rand(2, 4, 3, 3).astype('float32')
# generate non-square matrix or batches of non-square matrices
input_4 = np.random.rand(9, 7).astype('float64')
input_5 = np.random.rand(4, 2, 10).astype('float32')
input_6 = np.random.rand(3, 5, 4, 1).astype('float32')
list_n_n = (input_1, input_2, input_3)
list_m_n = (input_4, input_5, input_6)
return list_n_n, list_m_n
def gen_empty_input():
# generate square matrix or batches of square matrices which are empty tensor
input_1 = np.random.rand(0, 7, 7).astype('float32')
input_2 = np.random.rand(0, 9, 9).astype('float32')
input_3 = np.random.rand(0, 4, 5, 5).astype('float64')
# generate non-square matrix or batches of non-square matrices which are empty tensor
input_4 = np.random.rand(0, 7, 11).astype('float32')
input_5 = np.random.rand(0, 10, 8).astype('float64')
input_6 = np.random.rand(5, 0, 4, 3).astype('float32')
list_n_n = (input_1, input_2, input_3)
list_m_n = (input_4, input_5, input_6)
return list_n_n, list_m_n
class API_TestStaticCond(unittest.TestCase):
def test_out(self):
paddle.enable_static()
# test calling results of 'cond' in static mode
x_list_n_n, x_list_m_n = gen_input()
test_static_assert_true(self, x_list_n_n, p_list_n_n + p_list_m_n)
test_static_assert_true(self, x_list_m_n, p_list_m_n)
class API_TestDygraphCond(unittest.TestCase):
def test_out(self):
paddle.disable_static()
# test calling results of 'cond' in dynamic mode
x_list_n_n, x_list_m_n = gen_input()
test_dygraph_assert_true(self, x_list_n_n, p_list_n_n + p_list_m_n)
test_dygraph_assert_true(self, x_list_m_n, p_list_m_n)
class TestCondAPIError(unittest.TestCase):
def test_dygraph_api_error(self):
paddle.disable_static()
# test raising errors when 'cond' is called in dygraph mode
p_list_error = ('fro_', '_nuc', -0.7, 0, 1.5, 3)
x_list_n_n, x_list_m_n = gen_input()
for p in p_list_error:
for x in (x_list_n_n + x_list_m_n):
x_tensor = paddle.to_tensor(x)
self.assertRaises(ValueError, paddle.cond, x_tensor, p)
for p in p_list_n_n:
for x in x_list_m_n:
x_tensor = paddle.to_tensor(x)
self.assertRaises(ValueError, paddle.cond, x_tensor, p)
def test_static_api_error(self):
paddle.enable_static()
# test raising errors when 'cond' is called in static mode
p_list_error = ('f ro', 'fre', 'NUC', -1.6, 0, 5)
x_list_n_n, x_list_m_n = gen_input()
for p in p_list_error:
for x in (x_list_n_n + x_list_m_n):
with static.program_guard(static.Program(), static.Program()):
x_data = static.data("X", shape=x.shape, dtype=x.dtype)
self.assertRaises(ValueError, paddle.cond, x_data, p)
for p in p_list_n_n:
for x in x_list_m_n:
with static.program_guard(static.Program(), static.Program()):
x_data = static.data("X", shape=x.shape, dtype=x.dtype)
self.assertRaises(ValueError, paddle.cond, x_data, p)
# it's not supported when input is an empty tensor in static mode
def test_static_empty_input_error(self):
paddle.enable_static()
x_list_n_n, x_list_m_n = gen_empty_input()
for p in (p_list_n_n + p_list_m_n):
for x in x_list_n_n:
with static.program_guard(static.Program(), static.Program()):
x_data = static.data("X", shape=x.shape, dtype=x.dtype)
self.assertRaises(ValueError, paddle.cond, x_data, p)
for p in (p_list_n_n + p_list_m_n):
for x in x_list_n_n:
with static.program_guard(static.Program(), static.Program()):
x_data = static.data("X", shape=x.shape, dtype=x.dtype)
self.assertRaises(ValueError, paddle.cond, x_data, p)
class TestCondEmptyTensorInput(unittest.TestCase):
def test_dygraph_empty_tensor_input(self):
paddle.disable_static()
# test calling results of 'cond' when input is an empty tensor in dynamic mode
x_list_n_n, x_list_m_n = gen_empty_input()
test_dygraph_assert_true(self, x_list_n_n, p_list_n_n + p_list_m_n)
test_dygraph_assert_true(self, x_list_m_n, p_list_m_n)
if __name__ == "__main__":
paddle.enable_static()
# paddle.device.set_device("cpu")
unittest.main()
......@@ -14,6 +14,7 @@
from .tensor.linalg import cholesky # noqa: F401
from .tensor.linalg import norm # noqa: F401
from .tensor.linalg import cond # noqa: F401
from .tensor.linalg import matrix_power # noqa: F401
from .tensor import inverse as inv # noqa: F401
from .tensor.linalg import multi_dot # noqa: F401
......@@ -24,6 +25,7 @@ from .tensor.linalg import eigh # noqa: F401
__all__ = [
'cholesky', #noqa
'norm',
'cond',
'inv',
'multi_dot',
'matrix_rank',
......
......@@ -36,6 +36,7 @@ from .creation import empty_like # noqa: F401
from .linalg import matmul # noqa: F401
from .linalg import dot # noqa: F401
from .linalg import norm # noqa: F401
from .linalg import cond # noqa: F401
from .linalg import transpose # noqa: F401
from .linalg import dist # noqa: F401
from .linalg import t # noqa: F401
......@@ -219,6 +220,7 @@ tensor_method_func = [ #noqa
'matmul',
'dot',
'norm',
'cond',
'transpose',
'dist',
't',
......
......@@ -543,6 +543,323 @@ def dist(x, y, p=2):
return out
def cond(x, p=None, name=None):
"""
Computes the condition number of a matrix or batches of matrices with respect to a matrix norm ``p``.
Args:
x (Tensor): The input tensor could be tensor of shape ``(*, m, n)`` where ``*`` is zero or more batch dimensions
for ``p`` in ``(2, -2)``, or of shape ``(*, n, n)`` where every matrix is invertible for any supported ``p``.
And the input data type could be ``float32`` or ``float64``.
p (float|string, optional): Order of the norm. Supported values are `fro`, `nuc`, `1`, `-1`, `2`, `-2`,
`inf`, `-inf`. Default value is `None`, meaning that the order of the norm is `2`.
name (str, optional): The default value is `None`. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: computing results of condition number, its data type is the same as input Tensor ``x``.
Examples:
.. code-block:: python
import paddle
import numpy as np
x = paddle.to_tensor([[1., 0, -1], [0, 1, 0], [1, 0, 1]])
# compute conditional number when p is None
out = paddle.linalg.cond(x)
# out.numpy() [1.4142135]
# compute conditional number when order of the norm is 'fro'
out_fro = paddle.linalg.cond(x, p='fro')
# out_fro.numpy() [3.1622777]
# compute conditional number when order of the norm is 'nuc'
out_nuc = paddle.linalg.cond(x, p='nuc')
# out_nuc.numpy() [9.2426405]
# compute conditional number when order of the norm is 1
out_1 = paddle.linalg.cond(x, p=1)
# out_1.numpy() [2.]
# compute conditional number when order of the norm is -1
out_minus_1 = paddle.linalg.cond(x, p=-1)
# out_minus_1.numpy() [1.]
# compute conditional number when order of the norm is 2
out_2 = paddle.linalg.cond(x, p=2)
# out_2.numpy() [1.4142135]
# compute conditional number when order of the norm is -1
out_minus_2 = paddle.linalg.cond(x, p=-2)
# out_minus_2.numpy() [0.70710677]
# compute conditional number when order of the norm is inf
out_inf = paddle.linalg.cond(x, p=np.inf)
# out_inf.numpy() [2.]
# compute conditional number when order of the norm is -inf
out_minus_inf = paddle.linalg.cond(x, p=-np.inf)
# out_minus_inf.numpy() [1.]
a = paddle.to_tensor(np.random.randn(2, 4, 4).astype('float32'))
# a.numpy()
# [[[ 0.14063153 -0.996288 0.7996131 -0.02571543]
# [-0.16303636 1.5534962 -0.49919784 -0.04402903]
# [-1.1341571 -0.6022629 0.5445269 0.29154757]
# [-0.16816919 -0.30972657 1.7521842 -0.5402487 ]]
# [[-0.58081484 0.12402827 0.7229862 -0.55046535]
# [-0.15178485 -1.1604939 0.75810957 0.30971205]
# [-0.9669573 1.0940945 -0.27363303 -0.35416734]
# [-1.216529 2.0018666 -0.7773689 -0.17556527]]]
a_cond_fro = paddle.linalg.cond(a, p='fro')
# a_cond_fro.numpy() [31.572273 28.120834]
b = paddle.to_tensor(np.random.randn(2, 3, 4).astype('float64'))
# b.numpy()
# [[[ 1.61707487 0.46829144 0.38130416 0.82546736]
# [-1.72710298 0.08866375 -0.62518804 0.16128892]
# [-0.02822879 -1.67764516 0.11141444 0.3220113 ]]
# [[ 0.22524372 0.62474921 -0.85503233 -1.03960523]
# [-0.76620689 0.56673047 0.85064753 -0.45158196]
# [ 1.47595418 2.23646462 1.5701758 0.10497519]]]
b_cond_2 = paddle.linalg.cond(b, p=2)
# b_cond_2.numpy() [3.30064451 2.51976252]
"""
def mat_norm(input, porder=1., axis=None):
"""
NOTE:
Calculate the matrix norm of a square matrix or batches of square matrices,
when porder is in (1, -1, inf, -inf)
"""
reduce_all = True if axis is None or axis == [] else False
axis = axis if axis != None and axis != [] else [0]
keepdim = False
if in_dygraph_mode():
abs_out = _C_ops.abs(input)
sum_out = _C_ops.reduce_sum(abs_out, 'dim', axis, 'keepdim',
keepdim, 'reduce_all', reduce_all)
if porder == 1 or porder == np.inf:
return _C_ops.reduce_max(sum_out, 'dim', [-1], 'keepdim',
keepdim, 'reduce_all', reduce_all)
if porder == -1 or porder == -np.inf:
return _C_ops.reduce_min(sum_out, 'dim', [-1], 'keepdim',
keepdim, 'reduce_all', reduce_all)
block = LayerHelper('norm', **locals())
abs_out = block.create_variable_for_type_inference(
dtype=block.input_dtype())
sum_out = block.create_variable_for_type_inference(
dtype=block.input_dtype())
out = block.create_variable_for_type_inference(
dtype=block.input_dtype())
block.append_op(
type='abs', inputs={'X': input}, outputs={'Out': abs_out})
block.append_op(
type='reduce_sum',
inputs={'X': abs_out},
outputs={'Out': sum_out},
attrs={'dim': axis,
'keep_dim': keepdim,
'reduce_all': reduce_all})
if porder == 1 or porder == np.inf:
block.append_op(
type='reduce_max',
inputs={'X': sum_out},
outputs={'Out': out},
attrs={
'dim': [-1],
'keep_dim': keepdim,
'reduce_all': reduce_all
})
if porder == -1 or porder == -np.inf:
block.append_op(
type='reduce_min',
inputs={'X': sum_out},
outputs={'Out': out},
attrs={
'dim': [-1],
'keep_dim': keepdim,
'reduce_all': reduce_all
})
return out
def fro_norm(input, porder=2, axis=[-1]):
"""
NOTE:
Calculate the frobenius norm of a square matrix or batches of square matrices.
"""
reduce_all = True if axis is None or axis == [] else False
keepdim = False
if in_dygraph_mode():
pow_out = _C_ops.pow(input, 'factor', porder)
sum_out_1 = _C_ops.reduce_sum(pow_out, 'dim', axis, 'keepdim',
keepdim, 'reduce_all', reduce_all)
sum_out_2 = _C_ops.reduce_sum(sum_out_1, 'dim', axis, 'keepdim',
keepdim, 'reduce_all', reduce_all)
return _C_ops.pow(sum_out_2, 'factor', float(1. / porder))
block = LayerHelper('norm', **locals())
pow_out = block.create_variable_for_type_inference(
dtype=block.input_dtype())
sum_out_1 = block.create_variable_for_type_inference(
dtype=block.input_dtype())
sum_out_2 = block.create_variable_for_type_inference(
dtype=block.input_dtype())
out = block.create_variable_for_type_inference(
dtype=block.input_dtype())
block.append_op(
type='pow',
inputs={'X': input},
outputs={'Out': pow_out},
attrs={'factor': porder})
block.append_op(
type='reduce_sum',
inputs={'X': pow_out},
outputs={'Out': sum_out_1},
attrs={'dim': axis,
'keep_dim': keepdim,
'reduce_all': reduce_all})
block.append_op(
type='reduce_sum',
inputs={'X': sum_out_1},
outputs={'Out': sum_out_2},
attrs={'dim': axis,
'keep_dim': keepdim,
'reduce_all': reduce_all})
block.append_op(
type='pow',
inputs={'X': sum_out_2},
outputs={'Out': out},
attrs={'factor': float(1. / porder)})
return out
def svd_norm(input, porder, axis=[-1]):
"""
NOTE:
Calculate the matrix norm, which is related to singular values, of a matrix
or batches of matrices, including nuclear norm, 2-norm and (-2)-norm.
"""
reduce_all = True if axis is None or axis == [] else False
keepdim = False
u, s, vh = svd(input, full_matrices=False)
if in_dygraph_mode():
if porder == "nuc":
return _C_ops.reduce_sum(s, 'dim', axis, 'keepdim', keepdim,
'reduce_all', reduce_all)
max_out = _C_ops.reduce_max(s, 'dim', axis, 'keepdim', keepdim,
'reduce_all', reduce_all)
min_out = _C_ops.reduce_min(s, 'dim', axis, 'keepdim', keepdim,
'reduce_all', reduce_all)
if porder == 2:
return _C_ops.elementwise_div(max_out, min_out, 'aixs', axis,
'use_mkldnn', False)
if porder == -2:
return _C_ops.elementwise_div(min_out, max_out, 'aixs', axis,
'use_mkldnn', False)
block = LayerHelper('norm', **locals())
out = block.create_variable_for_type_inference(
dtype=block.input_dtype())
if porder == "nuc":
block.append_op(
type='reduce_sum',
inputs={'X': s},
outputs={'Out': out},
attrs={
'dim': axis,
'keep_dim': keepdim,
'reduce_all': reduce_all
})
return out
max_out = block.create_variable_for_type_inference(
dtype=block.input_dtype())
min_out = block.create_variable_for_type_inference(
dtype=block.input_dtype())
block.append_op(
type='reduce_max',
inputs={'X': s},
outputs={'Out': max_out},
attrs={'dim': axis,
'keep_dim': keepdim,
'reduce_all': reduce_all})
block.append_op(
type='reduce_min',
inputs={'X': s},
outputs={'Out': min_out},
attrs={'dim': axis,
'keep_dim': keepdim,
'reduce_all': reduce_all})
if porder == 2:
block.append_op(
type='elementwise_div',
inputs={'X': max_out,
'Y': min_out},
outputs={'Out': out},
attrs={'aixs': axis,
'use_mkldnn': False})
return out
if porder == -2:
block.append_op(
type='elementwise_div',
inputs={'X': min_out,
'Y': max_out},
outputs={'Out': out},
attrs={'aixs': axis,
'use_mkldnn': False})
return out
def empty_tensor(input, shape):
if in_dygraph_mode():
return input.reshape(shape)
raise ValueError("only support x is nonempty tensor in static mode")
x_shape = list(x.shape)
if not len(x_shape) >= 2:
raise ValueError("input should be a matrix or batches of matrices, " +
"but the dimention of received input is {}".format(
len(x_shape)))
if p == None:
p = 2
x_size = 0 if (0 in x_shape) else 1
if p in ("fro", "nuc", 1, -1, np.inf, -np.inf):
if x_shape[len(x_shape) - 1] == x_shape[len(x_shape) - 2]:
if x_size == 0:
return empty_tensor(x, x_shape[:-2])
x_inv = x.inverse()
if p == "fro":
return fro_norm(x) * fro_norm(x_inv)
if p == "nuc":
return svd_norm(x, p) * svd_norm(x_inv, p)
if p in (1, -1):
return mat_norm(
x, porder=p, axis=[-2]) * mat_norm(
x_inv, porder=p, axis=[-2])
if p in (np.inf, -np.inf):
return mat_norm(
x, porder=p, axis=[-1]) * mat_norm(
x_inv, porder=p, axis=[-1])
else:
raise ValueError("only support p is {} when input is a ".format(p) +
"square matrix or batches of square matrices")
elif p in (2, -2):
if x_size == 0:
return empty_tensor(x, x_shape[:-2])
return svd_norm(x, porder=p)
else:
raise ValueError(
"unsupported {} for p, only supporting ('fro', 'nuc', ".format(
p) + "1, -1, 2, -2, inf, -inf) or none")
def dot(x, y, name=None):
"""
This operator calculates inner product for vectors.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册