未验证 提交 85f5d264 编写于 作者: Z zhiboniu 提交者: GitHub

add new API: paddle.cov (#38392)

上级 706d2c08
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import unittest
import numpy as np
import six
import paddle
def numpy_cov(np_arr, rowvar=True, ddof=1, fweights=None, aweights=None):
return np.cov(np_arr,
rowvar=rowvar,
ddof=int(ddof),
fweights=fweights,
aweights=aweights)
class Cov_Test(unittest.TestCase):
def setUp(self):
self.shape = [20, 10]
self.weightshape = [10]
def test_tensor_cov_default(self):
typelist = ['float64']
places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for idx, p in enumerate(places):
if idx == 0:
paddle.set_device('cpu')
else:
paddle.set_device('gpu')
for dtype in typelist:
np_arr = np.random.rand(*self.shape).astype(dtype)
tensor = paddle.to_tensor(np_arr, place=p)
cov = paddle.linalg.cov(tensor,
rowvar=True,
ddof=True,
fweights=None,
aweights=None)
np_cov = numpy_cov(
np_arr, rowvar=True, ddof=1, fweights=None, aweights=None)
self.assertTrue(np.allclose(np_cov, cov.numpy()))
def test_tensor_cov_rowvar(self):
typelist = ['float64']
places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for idx, p in enumerate(places):
if idx == 0:
paddle.set_device('cpu')
else:
paddle.set_device('gpu')
for dtype in typelist:
np_arr = np.random.rand(*self.shape).astype(dtype)
tensor = paddle.to_tensor(np_arr, place=p)
cov = paddle.linalg.cov(tensor,
rowvar=False,
ddof=True,
fweights=None,
aweights=None)
np_cov = numpy_cov(
np_arr, rowvar=False, ddof=1, fweights=None, aweights=None)
self.assertTrue(np.allclose(np_cov, cov.numpy()))
def test_tensor_cov_ddof(self):
typelist = ['float64']
places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for idx, p in enumerate(places):
if idx == 0:
paddle.set_device('cpu')
else:
paddle.set_device('gpu')
for dtype in typelist:
np_arr = np.random.rand(*self.shape).astype(dtype)
tensor = paddle.to_tensor(np_arr, place=p)
cov = paddle.linalg.cov(tensor,
rowvar=True,
ddof=False,
fweights=None,
aweights=None)
np_cov = numpy_cov(
np_arr, rowvar=True, ddof=0, fweights=None, aweights=None)
self.assertTrue(np.allclose(np_cov, cov.numpy()))
def test_tensor_cov_fweights(self):
typelist = ['float64']
places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for idx, p in enumerate(places):
if idx == 0:
paddle.set_device('cpu')
else:
paddle.set_device('gpu')
for dtype in typelist:
np_arr = np.random.rand(*self.shape).astype(dtype)
np_fw = np.random.randint(
10, size=self.weightshape).astype('int32')
tensor = paddle.to_tensor(np_arr, place=p)
fweights = paddle.to_tensor(np_fw, place=p)
cov = paddle.linalg.cov(tensor,
rowvar=True,
ddof=True,
fweights=fweights,
aweights=None)
np_cov = numpy_cov(
np_arr, rowvar=True, ddof=1, fweights=np_fw, aweights=None)
self.assertTrue(np.allclose(np_cov, cov.numpy()))
def test_tensor_cov_aweights(self):
typelist = ['float64']
places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for idx, p in enumerate(places):
if idx == 0:
paddle.set_device('cpu')
else:
paddle.set_device('gpu')
for dtype in typelist:
np_arr = np.random.rand(*self.shape).astype(dtype)
np_aw = np.random.randint(
10, size=self.weightshape).astype('int32')
tensor = paddle.to_tensor(np_arr, place=p)
aweights = paddle.to_tensor(np_aw, place=p)
cov = paddle.linalg.cov(tensor,
rowvar=True,
ddof=True,
fweights=None,
aweights=aweights)
np_cov = numpy_cov(
np_arr, rowvar=True, ddof=1, fweights=None, aweights=np_aw)
self.assertTrue(np.allclose(np_cov, cov.numpy()))
def test_tensor_cov_weights(self):
typelist = ['float64']
places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for idx, p in enumerate(places):
if idx == 0:
paddle.set_device('cpu')
else:
paddle.set_device('gpu')
for dtype in typelist:
np_arr = np.random.rand(*self.shape).astype(dtype)
np_fw = np.random.randint(
10, size=self.weightshape).astype('int64')
np_aw = np.random.rand(*self.weightshape).astype('float64')
tensor = paddle.to_tensor(np_arr, place=p)
fweights = paddle.to_tensor(np_fw, place=p)
aweights = paddle.to_tensor(np_aw, place=p)
cov = paddle.linalg.cov(tensor,
rowvar=True,
ddof=True,
fweights=fweights,
aweights=aweights)
np_cov = numpy_cov(
np_arr, rowvar=True, ddof=1, fweights=np_fw, aweights=np_aw)
self.assertTrue(np.allclose(np_cov, cov.numpy()))
class Cov_Test2(Cov_Test):
def setUp(self):
self.shape = [10]
self.weightshape = [10]
# Input(x) only support N-D (1<=N<=2) tensor
class Cov_Test3(unittest.TestCase):
def setUp(self):
self.shape = [2, 5, 10]
self.fweightshape = [10]
self.aweightshape = [10]
self.fw_s = 1.
self.aw_s = 1.
def test_errors(self):
def test_err():
np_arr = np.random.rand(*self.shape).astype('float64')
np_fw = self.fw_s * np.random.rand(
*self.fweightshape).astype('int32')
np_aw = self.aw_s * np.random.rand(
*self.aweightshape).astype('float64')
tensor = paddle.to_tensor(np_arr)
fweights = paddle.to_tensor(np_fw)
aweights = paddle.to_tensor(np_aw)
cov = paddle.linalg.cov(tensor,
rowvar=True,
ddof=True,
fweights=fweights,
aweights=aweights)
self.assertRaises(ValueError, test_err)
#Input(fweights) only support N-D (N<=1) tensor
class Cov_Test4(Cov_Test3):
def setUp(self):
self.shape = [5, 10]
self.fweightshape = [2, 10]
self.aweightshape = [10]
self.fw_s = 1.
self.aw_s = 1.
#The number of Input(fweights) should equal to x's dim[1]
class Cov_Test5(Cov_Test3):
def setUp(self):
self.shape = [5, 10]
self.fweightshape = [5]
self.aweightshape = [10]
self.fw_s = 1.
self.aw_s = 1.
#The value of Input(fweights) cannot be negtive
class Cov_Test6(Cov_Test3):
def setUp(self):
self.shape = [5, 10]
self.fweightshape = [10]
self.aweightshape = [10]
self.fw_s = -1.
self.aw_s = 1.
#Input(aweights) only support N-D (N<=1) tensor
class Cov_Test7(Cov_Test3):
def setUp(self):
self.shape = [5, 10]
self.fweightshape = [10]
self.aweightshape = [2, 10]
self.fw_s = 1.
self.aw_s = 1.
#The number of Input(aweights) should equal to x's dim[1]
class Cov_Test8(Cov_Test3):
def setUp(self):
self.shape = [5, 10]
self.fweightshape = [10]
self.aweightshape = [5]
self.fw_s = 1.
self.aw_s = 1.
#The value of Input(aweights) cannot be negtive
class Cov_Test9(Cov_Test3):
def setUp(self):
self.shape = [5, 10]
self.fweightshape = [10]
self.aweightshape = [10]
self.fw_s = 1.
self.aw_s = -1.
if __name__ == '__main__':
unittest.main()
......@@ -15,6 +15,7 @@
from .tensor.linalg import cholesky # noqa: F401
from .tensor.linalg import norm # noqa: F401
from .tensor.linalg import eig # noqa: F401
from .tensor.linalg import cov # noqa: F401
from .tensor.linalg import cond # noqa: F401
from .tensor.linalg import matrix_power # noqa: F401
from .tensor.linalg import solve # noqa: F401
......@@ -36,6 +37,7 @@ __all__ = [
'cholesky', #noqa
'norm',
'cond',
'cov',
'inv',
'eig',
'eigvals',
......
......@@ -39,6 +39,7 @@ from .creation import empty_like # noqa: F401
from .creation import complex # noqa: F401
from .linalg import matmul # noqa: F401
from .linalg import dot # noqa: F401
from .linalg import cov # noqa: F401
from .linalg import norm # noqa: F401
from .linalg import cond # noqa: F401
from .linalg import transpose # noqa: F401
......@@ -263,6 +264,7 @@ from .einsum import einsum # noqa: F401
tensor_method_func = [ #noqa
'matmul',
'dot',
'cov',
'norm',
'cond',
'transpose',
......
......@@ -920,6 +920,119 @@ def dot(x, y, name=None):
return out
def cov(x, rowvar=True, ddof=True, fweights=None, aweights=None, name=None):
"""
Estimate the covariance matrix of the input variables, given data and weights.
A covariance matrix is a square matrix, indicate the covariance of each pair variables in the input matrix.
For example, for an N-dimensional samples X=[x1,x2,…xN]T, then the covariance matrix
element Cij is the covariance of xi and xj. The element Cii is the variance of xi itself.
Parameters:
x(Tensor): A N-D(N<=2) Tensor containing multiple variables and observations. By default, each row of x represents a variable. Also see rowvar below.
rowvar(Bool, optional): If rowvar is True (default), then each row represents a variable, with observations in the columns. Default: True
ddof(Bool, optional): If ddof=True will return the unbiased estimate, and ddof=False will return the simple average. Default: True
fweights(Tensor, optional): 1-D Tensor of integer frequency weights; The number of times each observation vector should be repeated. Default: None
aweights(Tensor, optional): 1-D Tensor of observation vector weights. How important of the observation vector, larger data means this element is more important. Default: None
name(str, optional): Name of the output. Default is None. It's used to print debug info for developers. Details: :ref:`api_guide_Name`
Returns:
Tensor: The covariance matrix Tensor of the variables.
Examples:
.. code-block:: python
import paddle
xt = paddle.rand((3,4))
paddle.linalg.cov(xt)
'''
Tensor(shape=[3, 3], dtype=float64, place=CUDAPlace(0), stop_gradient=True,
[[0.07918842, 0.06127326, 0.01493049],
[0.06127326, 0.06166256, 0.00302668],
[0.01493049, 0.00302668, 0.01632146]])
'''
"""
op_type = 'cov'
if len(x.shape) > 2 or len(x.shape) < 1:
raise ValueError(
"Input(x) only support N-D (1<=N<=2) tensor in cov, but received "
"length of Input(input) is %s." % len(x.shape))
check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'cov')
nx = x
if len(x.shape) == 1:
nx = x.reshape((1, -1))
if not rowvar and nx.shape[0] != 1:
nx = nx.t()
w = None
observation_num = nx.shape[1]
if fweights is not None:
w = fweights.astype(nx.dtype)
if len(w.shape) > 1:
raise ValueError(
"Input(fweights) only support N-D (N<=1) tensor in cov, but received "
"shape of Input(input) is %s." % len(fweights.shape))
if fweights.shape[0] != observation_num:
raise ValueError(
"The number of Input(fweights) should equal to x's dim[1]: {}, but received "
"size of Input(fweights) is {}.".format(observation_num,
fweights.shape[0]))
if fweights.min() < 0:
raise ValueError(
"The value of Input(fweights) cannot be negtive, but received "
"min of Input(fweights) is {}.".format(fweights.min()))
if not paddle.all(fweights == paddle.round(fweights.astype('float64'))):
raise ValueError("Input(fweights) must be integer ")
if aweights is not None:
aw = aweights.astype(nx.dtype)
if len(aw.shape) > 1:
raise ValueError(
"Input(aweights) only support N-D (N<=1) tensor in cov, but received "
"length of Input(input) is %s." % len(aweights.shape))
check_variable_and_dtype(aweights, 'dtype', ['float32', 'float64'],
'cov')
if aweights.shape[0] != observation_num:
raise ValueError(
"The number of Input(aweights) should equal to x's dim[1]: {}, but received "
"size of Input(aweights) is {}.".format(observation_num,
aweights.shape[0]))
if aweights.min() < 0:
raise ValueError(
"The value of Input(aweights) cannot be negtive, but received "
"min of Input(aweights) is {}.".format(aweights.min()))
if w is not None:
w = w * aw
else:
w = aw
w_sum = paddle.to_tensor(observation_num, dtype=nx.dtype)
if fweights is not None or aweights is not None:
w_sum = w.sum()
if w_sum.item() == 0:
raise ValueError("The sum of weights is zero, can't be normalized.")
if w is not None:
nx_w = nx * w
avg = (nx_w).sum(axis=1) / w_sum
else:
avg = nx.sum(axis=1) / w_sum
nx_w = nx
if w is not None and aweights is not None and ddof == True:
norm_factor = w_sum - (w * aweights).sum() / w_sum
else:
norm_factor = w_sum - ddof
if norm_factor <= 0:
norm_factor = paddle.to_tensor(0, dtype=nx.dtype)
nx = nx - avg.unsqueeze(1)
xxt = paddle.mm(nx, nx_w.t().conj())
cov = paddle.divide(xxt, norm_factor).squeeze()
return cov
def t(input, name=None):
"""
Transpose <=2-D tensor.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册