未验证 提交 4ebb4764 编写于 作者: N NetPunk 提交者: GitHub

【Hackathon 4 No.9】Add pca_lowrank API to Paddle (#53743)

上级 7309f8ab
......@@ -29,6 +29,7 @@ from .tensor.linalg import matrix_power # noqa: F401
from .tensor.linalg import matrix_rank # noqa: F401
from .tensor.linalg import multi_dot # noqa: F401
from .tensor.linalg import norm # noqa: F401
from .tensor.linalg import pca_lowrank # noqa: F401
from .tensor.linalg import pinv # noqa: F401
from .tensor.linalg import qr # noqa: F401
from .tensor.linalg import slogdet # noqa: F401
......@@ -50,6 +51,7 @@ __all__ = [
'matrix_rank',
'svd',
'qr',
'pca_lowrank',
'lu',
'lu_unpack',
'matrix_power',
......
......@@ -28,6 +28,7 @@ from .unary import square
from .unary import log1p
from .unary import abs
from .unary import pow
from .unary import pca_lowrank
from .unary import cast
from .unary import neg
from .unary import coalesce
......@@ -69,6 +70,7 @@ __all__ = [
'log1p',
'abs',
'pow',
'pca_lowrank',
'cast',
'neg',
'deg2rad',
......
......@@ -14,6 +14,7 @@
import numpy as np
import paddle
from paddle import _C_ops, in_dynamic_mode
from paddle.common_ops_import import Variable
from paddle.fluid.data_feeder import check_type, check_variable_and_dtype
......@@ -920,3 +921,205 @@ def slice(x, axes, starts, ends, name=None):
type=op_type, inputs={'x': x}, outputs={'out': out}, attrs=attrs
)
return out
def pca_lowrank(x, q=None, center=True, niter=2, name=None):
r"""
Performs linear Principal Component Analysis (PCA) on a sparse matrix.
Let :math:`X` be the input matrix or a batch of input matrices, the output should satisfies:
.. math::
X = U * diag(S) * V^{T}
Args:
x (Tensor): The input tensor. Its shape should be `[N, M]`,
N and M can be arbitraty positive number.
The data type of x should be float32 or float64.
q (int, optional): a slightly overestimated rank of :math:`X`.
Default value is :math:`q=min(6,N,M)`.
center (bool, optional): if True, center the input tensor.
Default value is True.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
- Tensor U, is N x q matrix.
- Tensor S, is a vector with length q.
- Tensor V, is M x q matrix.
tuple (U, S, V): which is the nearly optimal approximation of a singular value decomposition of a centered matrix :math:`X`.
Examples:
.. code-block:: python
import paddle
format = "coo"
dense_x = paddle.randn((5, 5), dtype='float64')
if format == "coo":
sparse_x = dense_x.to_sparse_coo(len(dense_x.shape))
else:
sparse_x = dense_x.to_sparse_csr()
print("sparse.pca_lowrank API only support CUDA 11.x")
U, S, V = None, None, None
# use code blow when your device CUDA version >= 11.0
# U, S, V = paddle.sparse.pca_lowrank(sparse_x)
print(U)
# Tensor(shape=[5, 5], dtype=float64, place=Place(gpu:0), stop_gradient=True,
# [[ 0.02206024, 0.53170082, -0.22392168, -0.48450657, 0.65720625],
# [ 0.02206024, 0.53170082, -0.22392168, -0.32690402, -0.74819812],
# [ 0.02206024, 0.53170082, -0.22392168, 0.81141059, 0.09099187],
# [ 0.15045792, 0.37840027, 0.91333217, -0.00000000, 0.00000000],
# [ 0.98787775, -0.09325209, -0.12410317, -0.00000000, -0.00000000]])
print(S)
# Tensor(shape=[5], dtype=float64, place=Place(gpu:0), stop_gradient=True,
# [2.28621761, 0.93618564, 0.53234942, 0.00000000, 0.00000000])
print(V)
# Tensor(shape=[5, 5], dtype=float64, place=Place(gpu:0), stop_gradient=True,
# [[ 0.26828910, -0.57116436, -0.26548201, 0.67342660, -0.27894114],
# [-0.19592125, -0.31629129, 0.02001645, -0.50484498, -0.77865626],
# [-0.82913017, -0.09391036, 0.37975388, 0.39938099, -0.00241046],
# [-0.41163516, 0.27490410, -0.86666276, 0.03382656, -0.05230341],
# [ 0.18092947, 0.69952818, 0.18385126, 0.36190987, -0.55959343]])
"""
def get_floating_dtype(x):
dtype = x.dtype
if dtype in (paddle.float16, paddle.float32, paddle.float64):
return dtype
return paddle.float32
def conjugate(x):
if x.is_complex():
return x.conj()
return x
def transpose(x):
shape = x.shape
perm = list(range(0, len(shape)))
perm = perm[:-2] + [perm[-1]] + [perm[-2]]
if x.is_sparse():
return paddle.sparse.transpose(x, perm)
return paddle.transpose(x, perm)
def transjugate(x):
return conjugate(transpose(x))
def get_approximate_basis(x, q, niter=2, M=None):
niter = 2 if niter is None else niter
m, n = x.shape[-2:]
qr = paddle.linalg.qr
R = paddle.randn((n, q), dtype=x.dtype)
A_t = transpose(x)
A_H = conjugate(A_t)
if M is None:
Q = qr(paddle.sparse.matmul(x, R))[0]
for i in range(niter):
Q = qr(paddle.sparse.matmul(A_H, Q))[0]
Q = qr(paddle.sparse.matmul(x, Q))[0]
else:
M_H = transjugate(M)
Q = qr(paddle.sparse.matmul(x, R) - paddle.matmul(M, R))[0]
for i in range(niter):
Q = qr(paddle.sparse.matmul(A_H, Q) - paddle.matmul(M_H, Q))[0]
Q = qr(paddle.sparse.matmul(x, Q) - paddle.matmul(M, Q))[0]
return Q
def svd_lowrank(x, q=6, niter=2, M=None):
q = 6 if q is None else q
m, n = x.shape[-2:]
if M is None:
M_t = None
else:
M_t = transpose(M)
A_t = transpose(x)
if m < n or n > q:
Q = get_approximate_basis(A_t, q, niter=niter, M=M_t)
Q_c = conjugate(Q)
if M is None:
B_t = paddle.sparse.matmul(x, Q_c)
else:
B_t = paddle.sparse.matmul(x, Q_c) - paddle.matmul(M, Q_c)
assert B_t.shape[-2] == m, (B_t.shape, m)
assert B_t.shape[-1] == q, (B_t.shape, q)
assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
U, S, Vh = paddle.linalg.svd(B_t, full_matrices=False)
V = transjugate(Vh)
V = Q.matmul(V)
else:
Q = get_approximate_basis(x, q, niter=niter, M=M)
Q_c = conjugate(Q)
if M is None:
B = paddle.sparse.matmul(A_t, Q_c)
else:
B = paddle.sparse.matmul(A_t, Q_c) - paddle.matmul(M_t, Q_c)
B_t = transpose(B)
assert B_t.shape[-2] == q, (B_t.shape, q)
assert B_t.shape[-1] == n, (B_t.shape, n)
assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
U, S, Vh = paddle.linalg.svd(B_t, full_matrices=False)
V = transjugate(Vh)
U = Q.matmul(U)
return U, S, V
if not paddle.is_tensor(x):
raise ValueError(f'Input must be tensor, but got {type(x)}')
if not x.is_sparse():
raise ValueError('Input must be sparse, but got dense')
cuda_version = paddle.version.cuda()
if (
cuda_version is None
or cuda_version == 'False'
or int(cuda_version.split('.')[0]) < 11
):
raise ValueError('sparse.pca_lowrank API only support CUDA 11.x')
(m, n) = x.shape[-2:]
if q is None:
q = min(6, m, n)
elif not (q >= 0 and q <= min(m, n)):
raise ValueError(
'q(={}) must be non-negative integer'
' and not greater than min(m, n)={}'.format(q, min(m, n))
)
if not (niter >= 0):
raise ValueError(f'niter(={niter}) must be non-negative integer')
dtype = get_floating_dtype(x)
if not center:
return svd_lowrank(x, q, niter=niter, M=None)
if len(x.shape) != 2:
raise ValueError('input is expected to be 2-dimensional tensor')
# TODO: complement sparse_csr_tensor test
# when sparse.sum with axis(-2) is implemented
s_sum = paddle.sparse.sum(x, axis=-2)
s_val = s_sum.values() / m
c = paddle.sparse.sparse_coo_tensor(
s_sum.indices(), s_val, dtype=s_sum.dtype, place=s_sum.place
)
column_indices = c.indices()[0]
indices = paddle.zeros((2, len(column_indices)), dtype=column_indices.dtype)
indices[0] = column_indices
C_t = paddle.sparse.sparse_coo_tensor(
indices, c.values(), (n, 1), dtype=dtype, place=x.place
)
ones_m1_t = paddle.ones(x.shape[:-2] + [1, m], dtype=dtype)
M = transpose(paddle.matmul(C_t.to_dense(), ones_m1_t))
return svd_lowrank(x, q, niter=niter, M=M)
......@@ -46,6 +46,7 @@ from .linalg import dot # noqa: F401
from .linalg import cov # noqa: F401
from .linalg import corrcoef # noqa: F401
from .linalg import norm # noqa: F401
from .linalg import pca_lowrank # noqa: F401
from .linalg import cond # noqa: F401
from .linalg import transpose # noqa: F401
from .linalg import lstsq # noqa: F401
......@@ -333,6 +334,7 @@ tensor_method_func = [ # noqa
'mv',
'matrix_power',
'qr',
'pca_lowrank',
'eigvals',
'eigvalsh',
'abs',
......
......@@ -1963,6 +1963,159 @@ def svd(x, full_matrices=False, name=None):
return u, s, vh
def pca_lowrank(x, q=None, center=True, niter=2, name=None):
r"""
Performs linear Principal Component Analysis (PCA) on a low-rank matrix or batches of such matrices.
Let :math:`X` be the input matrix or a batch of input matrices, the output should satisfies:
.. math::
X = U * diag(S) * V^{T}
Args:
x (Tensor): The input tensor. Its shape should be `[..., N, M]`,
where `...` is zero or more batch dimensions. N and M can be arbitraty
positive number. The data type of x should be float32 or float64.
q (int, optional): a slightly overestimated rank of :math:`X`.
Default value is :math:`q=min(6,N,M)`.
center (bool, optional): if True, center the input tensor.
Default value is True.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
- Tensor U, is N x q matrix.
- Tensor S, is a vector with length q.
- Tensor V, is M x q matrix.
tuple (U, S, V): which is the nearly optimal approximation of a singular value decomposition of a centered matrix :math:`X`.
Examples:
.. code-block:: python
import paddle
x = paddle.randn((5, 5), dtype='float64')
U, S, V = paddle.linalg.pca_lowrank(x)
print(U)
# Tensor(shape=[5, 5], dtype=float64, place=Place(gpu:0), stop_gradient=True,
# [[ 0.41057070, 0.40364287, 0.59099574, -0.34529432, 0.44721360],
# [-0.30243321, 0.55670611, -0.15025419, 0.61321785, 0.44721360],
# [ 0.57427340, -0.15936327, -0.66414981, -0.06097905, 0.44721360],
# [-0.63897516, -0.09968973, -0.17298615, -0.59316819, 0.44721360],
# [-0.04343573, -0.70129598, 0.39639442, 0.38622370, 0.44721360]])
print(S)
# Tensor(shape=[5], dtype=float64, place=Place(gpu:0), stop_gradient=True,
# [3.33724265, 2.57573259, 1.69479048, 0.68069312, 0.00000000])
print(V)
# Tensor(shape=[5, 5], dtype=float64, place=Place(gpu:0), stop_gradient=True,
# [[ 0.09800724, -0.32627008, -0.23593953, 0.81840445, 0.39810690],
# [-0.60100303, 0.63741176, -0.01953663, 0.09023999, 0.47326173],
# [ 0.25073864, -0.21305240, -0.32662950, -0.54786156, 0.69634740],
# [ 0.33057205, 0.48282641, -0.75998527, 0.06744040, -0.27472705],
# [ 0.67604895, 0.45688227, 0.50959437, 0.13179682, 0.23908071]])
"""
def conjugate(x):
if x.is_complex():
return x.conj()
return x
def transpose(x):
shape = x.shape
perm = list(range(0, len(shape)))
perm = perm[:-2] + [perm[-1]] + [perm[-2]]
return paddle.transpose(x, perm)
def transjugate(x):
return conjugate(transpose(x))
def get_approximate_basis(x, q, niter=2, M=None):
niter = 2 if niter is None else niter
m, n = x.shape[-2:]
qr = paddle.linalg.qr
R = paddle.randn((n, q), dtype=x.dtype)
A_t = transpose(x)
A_H = conjugate(A_t)
if M is None:
Q = qr(paddle.matmul(x, R))[0]
for i in range(niter):
Q = qr(paddle.matmul(A_H, Q))[0]
Q = qr(paddle.matmul(x, Q))[0]
else:
M_H = transjugate(M)
Q = qr(paddle.matmul(x, R) - paddle.matmul(M, R))[0]
for i in range(niter):
Q = qr(paddle.matmul(A_H, Q) - paddle.matmul(M_H, Q))[0]
Q = qr(paddle.matmul(x, Q) - paddle.matmul(M, Q))[0]
return Q
def svd_lowrank(x, q=6, niter=2, M=None):
q = 6 if q is None else q
m, n = x.shape[-2:]
if M is None:
M_t = None
else:
M_t = transpose(M)
A_t = transpose(x)
if m < n or n > q:
Q = get_approximate_basis(A_t, q, niter=niter, M=M_t)
Q_c = conjugate(Q)
if M is None:
B_t = paddle.matmul(x, Q_c)
else:
B_t = paddle.matmul(x, Q_c) - paddle.matmul(M, Q_c)
assert B_t.shape[-2] == m, (B_t.shape, m)
assert B_t.shape[-1] == q, (B_t.shape, q)
assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
U, S, Vh = paddle.linalg.svd(B_t, full_matrices=False)
V = transjugate(Vh)
V = Q.matmul(V)
else:
Q = get_approximate_basis(x, q, niter=niter, M=M)
Q_c = conjugate(Q)
if M is None:
B = paddle.matmul(A_t, Q_c)
else:
B = paddle.matmul(A_t, Q_c) - paddle.matmul(M_t, Q_c)
B_t = transpose(B)
assert B_t.shape[-2] == q, (B_t.shape, q)
assert B_t.shape[-1] == n, (B_t.shape, n)
assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
U, S, Vh = paddle.linalg.svd(B_t, full_matrices=False)
V = transjugate(Vh)
U = Q.matmul(U)
return U, S, V
if not paddle.is_tensor(x):
raise ValueError(f'Input must be tensor, but got {type(x)}')
(m, n) = x.shape[-2:]
if q is None:
q = min(6, m, n)
elif not (q >= 0 and q <= min(m, n)):
raise ValueError(
'q(={}) must be non-negative integer'
' and not greater than min(m, n)={}'.format(q, min(m, n))
)
if not (niter >= 0):
raise ValueError(f'niter(={niter}) must be non-negative integer')
if not center:
return svd_lowrank(x, q, niter=niter, M=None)
C = x.mean(axis=-2, keepdim=True)
return svd_lowrank(x - C, q, niter=niter, M=None)
def matrix_power(x, n, name=None):
r"""
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
class TestPcaLowrankAPI(unittest.TestCase):
def transpose(self, x):
shape = x.shape
perm = list(range(0, len(shape)))
perm = perm[:-2] + [perm[-1]] + [perm[-2]]
return paddle.transpose(x, perm)
def random_matrix(self, rows, columns, *batch_dims, **kwargs):
dtype = kwargs.get('dtype', paddle.float64)
x = paddle.randn(batch_dims + (rows, columns), dtype=dtype)
if x.numel() == 0:
return x
u, _, vh = paddle.linalg.svd(x, full_matrices=False)
k = min(rows, columns)
s = paddle.linspace(1 / (k + 1), 1, k, dtype=dtype)
return (u * s.unsqueeze(-2)) @ vh
def random_lowrank_matrix(self, rank, rows, columns, *batch_dims, **kwargs):
B = self.random_matrix(rows, rank, *batch_dims, **kwargs)
C = self.random_matrix(rank, columns, *batch_dims, **kwargs)
return B.matmul(C)
def run_subtest(
self, guess_rank, actual_rank, matrix_size, batches, pca, **options
):
if isinstance(matrix_size, int):
rows = columns = matrix_size
else:
rows, columns = matrix_size
a_input = self.random_lowrank_matrix(
actual_rank, rows, columns, *batches
)
a = a_input
u, s, v = pca(a_input, q=guess_rank, **options)
self.assertEqual(s.shape[-1], guess_rank)
self.assertEqual(u.shape[-2], rows)
self.assertEqual(u.shape[-1], guess_rank)
self.assertEqual(v.shape[-1], guess_rank)
self.assertEqual(v.shape[-2], columns)
A1 = u.matmul(paddle.nn.functional.diag_embed(s)).matmul(
self.transpose(v)
)
ones_m1 = paddle.ones(batches + (rows, 1), dtype=a.dtype)
c = a.sum(axis=-2) / rows
c = c.reshape(batches + (1, columns))
A2 = a - ones_m1.matmul(c)
np.testing.assert_allclose(A1.numpy(), A2.numpy(), atol=1e-5)
detect_rank = (s.abs() > 1e-5).sum(axis=-1)
left = actual_rank * paddle.ones(batches, dtype=paddle.int64)
if not left.shape:
np.testing.assert_allclose(int(left), int(detect_rank))
else:
np.testing.assert_allclose(left.numpy(), detect_rank.numpy())
S = paddle.linalg.svd(A2, full_matrices=False)[1]
left = s[..., :actual_rank]
right = S[..., :actual_rank]
np.testing.assert_allclose(left.numpy(), right.numpy())
def test_forward(self):
pca_lowrank = paddle.linalg.pca_lowrank
all_batches = [(), (1,), (3,), (2, 3)]
for actual_rank, size in [
(2, (17, 4)),
(2, (100, 4)),
(6, (100, 40)),
]:
for batches in all_batches:
for guess_rank in [
actual_rank,
actual_rank + 2,
actual_rank + 6,
]:
if guess_rank <= min(*size):
self.run_subtest(
guess_rank, actual_rank, size, batches, pca_lowrank
)
self.run_subtest(
guess_rank,
actual_rank,
size[::-1],
batches,
pca_lowrank,
)
x = np.random.randn(5, 5).astype('float64')
x = paddle.to_tensor(x)
q = None
U, S, V = pca_lowrank(x, q, center=False)
def test_errors(self):
pca_lowrank = paddle.linalg.pca_lowrank
x = np.random.randn(5, 5).astype('float64')
x = paddle.to_tensor(x)
def test_x_not_tensor():
U, S, V = pca_lowrank(x.numpy())
self.assertRaises(ValueError, test_x_not_tensor)
def test_q_range():
q = -1
U, S, V = pca_lowrank(x, q)
self.assertRaises(ValueError, test_q_range)
def test_niter_range():
n = -1
U, S, V = pca_lowrank(x, niter=n)
self.assertRaises(ValueError, test_niter_range)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import re
import unittest
import numpy as np
import paddle
def get_cuda_version():
result = os.popen("nvcc --version").read()
regex = r'release (\S+),'
match = re.search(regex, result)
if match:
num = str(match.group(1))
integer, decimal = num.split('.')
return int(integer) * 1000 + int(float(decimal) * 10)
else:
return -1
class TestSparsePcaLowrankAPI(unittest.TestCase):
def transpose(self, x):
shape = x.shape
perm = list(range(0, len(shape)))
perm = perm[:-2] + [perm[-1]] + [perm[-2]]
return paddle.transpose(x, perm)
def random_sparse_matrix(self, rows, columns, density=0.01, **kwargs):
dtype = kwargs.get('dtype', paddle.float64)
nonzero_elements = max(
min(rows, columns), int(rows * columns * density)
)
row_indices = [i % rows for i in range(nonzero_elements)]
column_indices = [i % columns for i in range(nonzero_elements)]
random.shuffle(column_indices)
indices = [row_indices, column_indices]
values = paddle.randn((nonzero_elements,), dtype=dtype)
values *= paddle.to_tensor(
[-float(i - j) ** 2 for i, j in zip(*indices)], dtype=dtype
).exp()
indices_tensor = paddle.to_tensor(indices)
x = paddle.sparse.sparse_coo_tensor(
indices_tensor, values, (rows, columns)
)
return paddle.sparse.coalesce(x)
def run_subtest(self, guess_rank, matrix_size, batches, pca, **options):
density = options.pop('density', 0.5)
if isinstance(matrix_size, int):
rows = columns = matrix_size
else:
rows, columns = matrix_size
a_input = self.random_sparse_matrix(rows, columns, density)
a = a_input.to_dense()
u, s, v = pca(a_input, q=guess_rank, **options)
self.assertEqual(s.shape[-1], guess_rank)
self.assertEqual(u.shape[-2], rows)
self.assertEqual(u.shape[-1], guess_rank)
self.assertEqual(v.shape[-1], guess_rank)
self.assertEqual(v.shape[-2], columns)
A1 = u.matmul(paddle.nn.functional.diag_embed(s)).matmul(
self.transpose(v)
)
ones_m1 = paddle.ones(batches + (rows, 1), dtype=a.dtype)
c = a.sum(axis=-2) / rows
c = c.reshape(batches + (1, columns))
A2 = a - ones_m1.matmul(c)
np.testing.assert_allclose(A1.numpy(), A2.numpy(), atol=1e-5)
@unittest.skipIf(
not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000,
"only support cuda>=11.0",
)
def test_sparse(self):
pca_lowrank = paddle.sparse.pca_lowrank
for guess_rank, size in [
(4, (17, 4)),
(4, (4, 17)),
(16, (17, 17)),
(21, (100, 40)),
]:
for density in [0.005, 0.01]:
self.run_subtest(
guess_rank, size, (), pca_lowrank, density=density
)
def test_errors(self):
pca_lowrank = paddle.sparse.pca_lowrank
x = np.random.randn(5, 5).astype('float64')
dense_x = paddle.to_tensor(x)
sparse_x = dense_x.to_sparse_coo(len(x.shape))
def test_x_not_tensor():
U, S, V = pca_lowrank(x)
self.assertRaises(ValueError, test_x_not_tensor)
def test_x_not_sparse():
U, S, V = pca_lowrank(sparse_x.to_dense())
self.assertRaises(ValueError, test_x_not_sparse)
def test_q_range():
q = -1
U, S, V = pca_lowrank(sparse_x, q)
self.assertRaises(ValueError, test_q_range)
def test_niter_range():
n = -1
U, S, V = pca_lowrank(sparse_x, niter=n)
self.assertRaises(ValueError, test_niter_range)
def test_x_wrong_shape():
x = np.random.randn(5, 5, 5).astype('float64')
dense_x = paddle.to_tensor(x)
sparse_x = dense_x.to_sparse_coo(len(x.shape))
U, S, V = pca_lowrank(sparse_x)
self.assertRaises(ValueError, test_x_wrong_shape)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册