From 9f9ed3ae32682dd763b4c3fe652c6d197a735fd2 Mon Sep 17 00:00:00 2001 From: huangjun12 <2399845970@qq.com> Date: Wed, 27 Oct 2021 11:03:16 +0800 Subject: [PATCH] add paddle.linalg.eigvalsh API (#35615) * add eigvalsh with is_test * add eigvalsh op * fix backward bug * forward and backward, float and complex, unittest * remove eigvalsh_helper.h * remove changes of cusolver.h * fix unittest * fix unittest bug * update code following eigh * fix test * update lapack * pull develop * update funcor * fix unittest bug * fix details * add tensor_method_func * fix notes --- cmake/operators.cmake | 1 + paddle/fluid/operators/eigvalsh_op.cc | 163 +++++++++++++++ paddle/fluid/operators/eigvalsh_op.cu | 36 ++++ paddle/fluid/operators/eigvalsh_op.h | 79 +++++++ python/paddle/__init__.py | 1 + .../fluid/tests/unittests/test_eigvalsh_op.py | 192 ++++++++++++++++++ .../white_list/no_check_set_white_list.py | 1 + python/paddle/linalg.py | 2 + python/paddle/tensor/__init__.py | 2 + python/paddle/tensor/linalg.py | 69 ++++++- 10 files changed, 545 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/eigvalsh_op.cc create mode 100644 paddle/fluid/operators/eigvalsh_op.cu create mode 100644 paddle/fluid/operators/eigvalsh_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_eigvalsh_op.py diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 7830cf7b50..a537719cc7 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -185,6 +185,7 @@ function(op_library TARGET) list(REMOVE_ITEM hip_srcs "cholesky_op.cu") list(REMOVE_ITEM hip_srcs "matrix_rank_op.cu") list(REMOVE_ITEM hip_srcs "svd_op.cu") + list(REMOVE_ITEM hip_srcs "eigvalsh_op.cu") list(REMOVE_ITEM hip_srcs "qr_op.cu") list(REMOVE_ITEM hip_srcs "eigh_op.cu") list(REMOVE_ITEM hip_srcs "multinomial_op.cu") diff --git a/paddle/fluid/operators/eigvalsh_op.cc b/paddle/fluid/operators/eigvalsh_op.cc new file mode 100644 index 0000000000..fd5893df0c --- /dev/null +++ b/paddle/fluid/operators/eigvalsh_op.cc @@ -0,0 +1,163 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/eigvalsh_op.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class EigvalshOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Eigvalsh"); + OP_INOUT_CHECK(ctx->HasOutput("Eigenvalues"), "Output", "Eigenvalues", + "Eigvalsh"); + + auto input_dim = ctx->GetInputDim("X"); + auto rank = input_dim.size(); + + PADDLE_ENFORCE_GE(rank, 2, + platform::errors::InvalidArgument( + "The Input(X) should have at least 2 dimensions." + "But received a %d dimension tensor.", + rank)); + PADDLE_ENFORCE_EQ( + input_dim[rank - 2], input_dim[rank - 1], + platform::errors::InvalidArgument( + "Eigvalsh op is designed for square matrix, consequently" + "inner-most 2 dimensions of Input(X) should be symmetric." + "But received X's shape[-2] = %d and shape[-1] = %d.", + input_dim[rank - 2], input_dim[rank - 1])); + + std::vector values_dim; + + for (auto i = 0; i < rank - 1; i++) { + values_dim.emplace_back(input_dim[i]); + } + + ctx->SetOutputDim("Eigenvalues", framework::make_ddim(values_dim)); + + if (ctx->HasOutput("Eigenvectors")) { + ctx->SetOutputDim("Eigenvectors", input_dim); + } + } +}; + +class EigvalshOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor), Hermitian or real symmetric matrices." + "Its shape should be [*, N, N] where * is zero or" + "more batch dimensions. The data type is float32 ," + "float64, complex64, complex128."); + AddOutput("Eigenvalues", + "(Tensor), The eigenvalues in ascending order." + "The data type is float32 or float64."); + AddOutput( + "Eigenvectors", + "(Tensor), The column is the normalized eigenvector " + "corresponding to the eigenvalue. The data type is the same as ``X``." + "Eigenvectors are required to calculate gradient when backward."); + AddAttr( + "UPLO", + "(string, default 'L'), 'L' represents the lower triangular matrix," + "'U' represents the upper triangular matrix.") + .SetDefault("L"); + AddAttr("is_test", + "(bool, default false) Set to true for inference only, false " + "for training.") + .SetDefault(false); + AddComment(R"DOC( +Eigvalsh Operator. + +Computes the eigenvalues of a complex Hermitian + (conjugate symmetric) or a real symmetric matrix. + +)DOC"); + } +}; + +class EigvalshGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("Eigenvectors"), "Input", "Eigenvectors", + "EigvalshGrad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Eigenvalues")), + "Input", "Eigenvalues@GRAD", "EigvalshGrad"); + auto dims = ctx->GetInputDim("Eigenvectors"); + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, dims); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Eigenvectors"), + ctx.device_context()); + } +}; + +template +class EigvalshGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType(this->ForwardOpType() + "_grad"); + op->SetInput("Eigenvectors", this->Output("Eigenvectors")); + op->SetInput(framework::GradVarName("Eigenvalues"), + this->OutputGrad("Eigenvalues")); + op->SetAttrMap(this->Attrs()); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(eigvalsh, ops::EigvalshOp, ops::EigvalshOpMaker, + ops::EigvalshGradOpMaker, + ops::EigvalshGradOpMaker); +REGISTER_OPERATOR(eigvalsh_grad, ops::EigvalshGradOp); + +REGISTER_OP_CPU_KERNEL( + eigvalsh, + ops::EigvalshKernel, + ops::EigvalshKernel, + ops::EigvalshKernel>, + ops::EigvalshKernel>); + +REGISTER_OP_CPU_KERNEL( + eigvalsh_grad, + ops::EigvalshGradKernel, + ops::EigvalshGradKernel, + ops::EigvalshGradKernel>, + ops::EigvalshGradKernel>); diff --git a/paddle/fluid/operators/eigvalsh_op.cu b/paddle/fluid/operators/eigvalsh_op.cu new file mode 100644 index 0000000000..a623307857 --- /dev/null +++ b/paddle/fluid/operators/eigvalsh_op.cu @@ -0,0 +1,36 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/eigvalsh_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + eigvalsh, + ops::EigvalshKernel, + ops::EigvalshKernel, + ops::EigvalshKernel>, + ops::EigvalshKernel>); + +REGISTER_OP_CUDA_KERNEL( + eigvalsh_grad, + ops::EigvalshGradKernel, + ops::EigvalshGradKernel, + ops::EigvalshGradKernel>, + ops::EigvalshGradKernel>); diff --git a/paddle/fluid/operators/eigvalsh_op.h b/paddle/fluid/operators/eigvalsh_op.h new file mode 100644 index 0000000000..6c40ce107a --- /dev/null +++ b/paddle/fluid/operators/eigvalsh_op.h @@ -0,0 +1,79 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/eigen_values_vectors.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +using EigenVector = framework::EigenVector; + +template +class EigvalshKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto input = ctx.Input("X"); + auto output_w = ctx.Output("Eigenvalues"); + + std::string lower = ctx.Attr("UPLO"); + bool is_lower = (lower == "L"); + bool is_test = ctx.Attr("is_test"); + math::MatrixEighFunctor functor; + if (is_test) { + functor(ctx, *input, output_w, nullptr, is_lower, false); + } else { + auto output_v = ctx.Output("Eigenvectors"); + functor(ctx, *input, output_w, output_v, is_lower, true); + } + } +}; + +template +class EigvalshGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& x_grad = *ctx.Output(framework::GradVarName("X")); + auto& output_v = *ctx.Input("Eigenvectors"); + auto& output_w_grad = + *ctx.Input(framework::GradVarName("Eigenvalues")); + + auto dito = + math::DeviceIndependenceTensorOperations( + ctx); + auto tV = dito.Transpose(dito.Conj(output_v)); + + // compute elementwise multiply of output_v and output_w_grad + x_grad.mutable_data(output_v.dims(), ctx.GetPlace()); + auto output_v_vector = EigenVector::Flatten(output_v); + auto output_w_grad_vector = EigenVector::Flatten(output_w_grad); + auto result_vector = EigenVector::Flatten(x_grad); + auto& place = *ctx.template device_context().eigen_device(); + std::vector broadcast_factor; + broadcast_factor.push_back(output_v.dims().at(output_v.dims().size() - 1)); + result_vector.device(place) = + output_v_vector * output_w_grad_vector.broadcast(broadcast_factor); + + x_grad = dito.Matmul(x_grad, tV); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 29548a64f3..351b6ecb9f 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -101,6 +101,7 @@ from .tensor.linalg import histogram # noqa: F401 from .tensor.linalg import bincount # noqa: F401 from .tensor.linalg import mv # noqa: F401 from .tensor.logic import equal # noqa: F401 +from .tensor.linalg import eigvalsh # noqa: F401 from .tensor.logic import greater_equal # noqa: F401 from .tensor.logic import greater_than # noqa: F401 from .tensor.logic import is_empty # noqa: F401 diff --git a/python/paddle/fluid/tests/unittests/test_eigvalsh_op.py b/python/paddle/fluid/tests/unittests/test_eigvalsh_op.py new file mode 100644 index 0000000000..db02372267 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_eigvalsh_op.py @@ -0,0 +1,192 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +from op_test import OpTest +from gradient_checker import grad_check + + +class TestEigvalshOp(OpTest): + def setUp(self): + paddle.enable_static() + self.op_type = "eigvalsh" + self.init_input() + self.init_config() + np.random.seed(123) + out_w, out_v = np.linalg.eigh(self.x_np, self.UPLO) + self.inputs = {"X": self.x_np} + self.attrs = {"UPLO": self.UPLO, "is_test": False} + self.outputs = {'Eigenvalues': out_w, 'Eigenvectors': out_v} + + def init_config(self): + self.UPLO = 'L' + + def init_input(self): + self.x_shape = (10, 10) + self.x_type = np.float64 + self.x_np = np.random.random(self.x_shape).astype(self.x_type) + + def test_check_output(self): + # Vectors in posetive or negative is equivalent + self.check_output(no_check_set=['Eigenvectors']) + + def test_grad(self): + self.check_grad(["X"], ["Eigenvalues"]) + + +class TestEigvalshUPLOCase(TestEigvalshOp): + def init_config(self): + self.UPLO = 'U' + + +class TestEigvalshGPUCase(unittest.TestCase): + def setUp(self): + self.x_shape = [32, 32] + self.dtype = "float32" + np.random.seed(123) + self.x_np = np.random.random(self.x_shape).astype(self.dtype) + self.rtol = 1e-5 + self.atol = 1e-5 + + def test_check_output_gpu(self): + if paddle.is_compiled_with_cuda(): + paddle.disable_static(place=paddle.CUDAPlace(0)) + input_real_data = paddle.to_tensor(self.x_np) + expected_w = np.linalg.eigvalsh(self.x_np) + actual_w = paddle.linalg.eigvalsh(input_real_data) + np.testing.assert_allclose( + actual_w, expected_w, rtol=self.rtol, atol=self.atol) + + +class TestEigvalshAPI(unittest.TestCase): + def setUp(self): + self.init_input_shape() + self.dtype = "float32" + self.UPLO = 'L' + self.rtol = 1e-6 + self.atol = 1e-6 + self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \ + else paddle.CPUPlace() + np.random.seed(123) + self.real_data = np.random.random(self.x_shape).astype(self.dtype) + self.complex_data = np.random.random(self.x_shape).astype( + self.dtype) + 1J * np.random.random(self.x_shape).astype(self.dtype) + self.trans_dims = list(range(len(self.x_shape) - 2)) + [ + len(self.x_shape) - 1, len(self.x_shape) - 2 + ] + + def init_input_shape(self): + self.x_shape = [5, 5] + + def compare_result(self, actual_w, expected_w): + np.testing.assert_allclose( + actual_w, expected_w, rtol=self.rtol, atol=self.atol) + + def check_static_float_result(self): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): + input_x = paddle.static.data( + 'input_x', shape=self.x_shape, dtype=self.dtype) + output_w = paddle.linalg.eigvalsh(input_x) + exe = paddle.static.Executor(self.place) + expected_w = exe.run(main_prog, + feed={"input_x": self.real_data}, + fetch_list=[output_w]) + + actual_w = np.linalg.eigvalsh(self.real_data) + self.compare_result(actual_w, expected_w[0]) + + def check_static_complex_result(self): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): + x_dtype = np.complex64 if self.dtype == "float32" else np.complex128 + input_x = paddle.static.data( + 'input_x', shape=self.x_shape, dtype=x_dtype) + output_w = paddle.linalg.eigvalsh(input_x) + exe = paddle.static.Executor(self.place) + expected_w = exe.run(main_prog, + feed={"input_x": self.complex_data}, + fetch_list=[output_w]) + actual_w = np.linalg.eigvalsh(self.complex_data) + self.compare_result(actual_w, expected_w[0]) + + def test_in_static_mode(self): + paddle.enable_static() + self.check_static_float_result() + self.check_static_complex_result() + + def test_in_dynamic_mode(self): + paddle.disable_static(self.place) + input_real_data = paddle.to_tensor(self.real_data) + expected_w = np.linalg.eigvalsh(self.real_data) + actual_w = paddle.linalg.eigvalsh(input_real_data) + self.compare_result(actual_w, expected_w) + + input_complex_data = paddle.to_tensor(self.complex_data) + expected_w = np.linalg.eigvalsh(self.complex_data) + actual_w = paddle.linalg.eigvalsh(input_complex_data) + self.compare_result(actual_w, expected_w) + + def test_eigvalsh_grad(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.complex_data, stop_gradient=False) + w = paddle.linalg.eigvalsh(x) + (w.sum()).backward() + np.testing.assert_allclose( + abs(x.grad.numpy()), + abs(x.grad.numpy().conj().transpose(self.trans_dims)), + rtol=self.rtol, + atol=self.atol) + + +class TestEigvalshBatchAPI(TestEigvalshAPI): + def init_input_shape(self): + self.x_shape = [2, 5, 5] + + +class TestEigvalshAPIError(unittest.TestCase): + def test_error(self): + main_prog = paddle.static.Program() + startup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, startup_prog): + #input maxtrix must greater than 2 dimensions + input_x = paddle.static.data( + name='x_1', shape=[12], dtype='float32') + self.assertRaises(ValueError, paddle.linalg.eigvalsh, input_x) + + #input matrix must be square matrix + input_x = paddle.static.data( + name='x_2', shape=[12, 32], dtype='float32') + self.assertRaises(ValueError, paddle.linalg.eigvalsh, input_x) + + #uplo must be in 'L' or 'U' + input_x = paddle.static.data( + name='x_3', shape=[4, 4], dtype="float32") + uplo = 'R' + self.assertRaises(ValueError, paddle.linalg.eigvalsh, input_x, uplo) + + #x_data cannot be integer + input_x = paddle.static.data( + name='x_4', shape=[4, 4], dtype="int32") + self.assertRaises(TypeError, paddle.linalg.eigvalsh, input_x) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py index fd87e7584c..23bbc377ca 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py @@ -33,5 +33,6 @@ no_check_set_white_list = [ 'softmax_with_cross_entropy', 'svd', 'eigh', + 'eigvalsh', 'class_center_sample', ] diff --git a/python/paddle/linalg.py b/python/paddle/linalg.py index 06b512150c..b58ccab6cb 100644 --- a/python/paddle/linalg.py +++ b/python/paddle/linalg.py @@ -23,6 +23,7 @@ from .tensor.linalg import eigvals # noqa: F401 from .tensor.linalg import multi_dot # noqa: F401 from .tensor.linalg import matrix_rank from .tensor.linalg import svd +from .tensor.linalg import eigvalsh from .tensor.linalg import qr from .tensor.linalg import eigh # noqa: F401 from .tensor.linalg import det @@ -44,6 +45,7 @@ __all__ = [ 'det', 'slogdet', 'eigh', + 'eigvalsh', 'pinv', 'solve' ] diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 04d0a3c745..69154378a7 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -52,6 +52,7 @@ from .linalg import qr # noqa: F401 from .linalg import eigvals # noqa: F401 from .linalg import multi_dot # noqa: F401 from .linalg import svd # noqa: F401 +from .linalg import eigvalsh # noqa: F401 from .linalg import eigh # noqa: F401 from .linalg import pinv # noqa: F401 from .linalg import solve # noqa: F401 @@ -240,6 +241,7 @@ tensor_method_func = [ #noqa 'matrix_power', 'qr', 'eigvals', + 'eigvalsh', 'abs', 'acos', 'all', diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index aea56432fa..227769e98a 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -14,8 +14,8 @@ import numpy as np from ..fluid.layer_helper import LayerHelper +from ..fluid.framework import in_dygraph_mode, _varbase_creator, Variable, _dygraph_tracer from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype -from ..fluid.framework import in_dygraph_mode, _varbase_creator, Variable from ..fluid.layers import transpose, cast # noqa: F401 from ..fluid import layers @@ -2313,3 +2313,70 @@ def solve(x, y, name=None): type="solve", inputs={"X": x, "Y": y}, outputs={"Out": out}) return out + + +def eigvalsh(x, UPLO='L', name=None): + """ + Computes the eigenvalues of a + complex Hermitian (conjugate symmetric) or a real symmetric matrix. + + Args: + x (Tensor): A tensor with shape :math:`[_, M, M]` , The data type of the input Tensor x + should be one of float32, float64, complex64, complex128. + UPLO(str, optional): Lower triangular part of a (ā€˜Lā€™, default) or the upper triangular part (ā€˜Uā€™). + name(str, optional): The default value is None. Normally there is no need for user to set this + property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: The tensor eigenvalues in ascending order. + + Examples: + .. code-block:: python + + import numpy as np + import paddle + + x_data = np.array([[1, -2j], [2j, 5]]) + x = paddle.to_tensor(x_data) + out_value = paddle.eigvalsh(x, UPLO='L') + print(out_value) + #[0.17157288, 5.82842712] + """ + if in_dygraph_mode(): + is_test = x.stop_gradient + values, _ = _C_ops.eigvalsh(x, 'UPLO', UPLO, 'is_test', is_test) + return values + + def __check_input(x, UPLO): + x_shape = list(x.shape) + if len(x.shape) < 2: + raise ValueError( + "Input(input) only support >=2 tensor, but received " + "length of Input(input) is %s." % len(x.shape)) + if x_shape[-1] != x_shape[-2]: + raise ValueError( + "The input matrix must be batches of square matrices. But received x's dimention: {}". + format(x_shape)) + if UPLO is not 'L' and UPLO is not 'U': + raise ValueError( + "UPLO must be L or U. But received UPLO is: {}".format(UPLO)) + + __check_input(x, UPLO) + + helper = LayerHelper('eigvalsh', **locals()) + check_variable_and_dtype(x, 'dtype', + ['float32', 'float64', 'complex64', 'complex128'], + 'eigvalsh') + + out_value = helper.create_variable_for_type_inference(dtype=x.dtype) + out_vector = helper.create_variable_for_type_inference(dtype=x.dtype) + + is_test = x.stop_gradient + helper.append_op( + type='eigvalsh', + inputs={'X': x}, + outputs={'Eigenvalues': out_value, + 'Eigenvectors': out_vector}, + attrs={'UPLO': UPLO, + 'is_test': is_test}) + return out_value -- GitLab