test_eigh_op.py 8.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle
from op_test import OpTest
from gradient_checker import grad_check
22 23
from paddle.fluid.framework import _enable_legacy_dygraph
# Module-level side effect: paddle's private _enable_legacy_dygraph() is
# called at import time, before any test case below executes (presumably to
# pin these tests to the legacy dygraph code path — TODO confirm).
_enable_legacy_dygraph()
24 25


26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
def valid_eigh_result(A, eigh_value, eigh_vector, uplo):
    """Validate an eigh result for one matrix or for a batch of matrices.

    A 2-D ``A`` is validated directly. For a 3-D ``A``, each matrix along the
    leading axis is validated together with the matching slices of
    ``eigh_value`` and ``eigh_vector``. ``uplo`` is forwarded unchanged to
    ``valid_single_eigh_result``.
    """
    assert A.ndim in (2, 3)

    if A.ndim == 3:
        cases = zip(A, eigh_value, eigh_vector)
    else:
        cases = [(A, eigh_value, eigh_vector)]

    for matrix, values, vectors in cases:
        valid_single_eigh_result(matrix, values, vectors, uplo)


def valid_single_eigh_result(A, eigh_value, eigh_vector, uplo):
    """Assert that ``(eigh_value, eigh_vector)`` is an eigendecomposition of A.

    Only the triangle of ``A`` selected by ``uplo`` is read ('L' selects the
    lower triangle, any other value the upper one). The full symmetric (real
    input) or Hermitian (complex input) matrix is rebuilt from it first.

    With Q = ``eigh_vector`` and T = diag(``eigh_value``), two relative checks
    are asserted. The tolerance depends on whether ``A`` is single or double
    precision:
      * reconstruction:  ||A - Q*T*Q^H||_inf / (N * ||A||_inf) < rtol
      * orthonormality:  ||I - Q*Q^H||_inf / M < rtol

    Raises:
        AssertionError: if either check fails.
    """
    FP32_MAX_RELATIVE_ERR = 5e-5
    FP64_MAX_RELATIVE_ERR = 1e-14

    if A.dtype == np.single or A.dtype == np.csingle:
        rtol = FP32_MAX_RELATIVE_ERR
    else:
        rtol = FP64_MAX_RELATIVE_ERR

    M, N = A.shape

    # Keep the selected triangle (diagonal included) and mirror only its
    # *strict* part. The strictly lower triangle is tril(A, -1); the strictly
    # upper one is triu(A, 1). Offset -1 for the upper case would also copy
    # the first subdiagonal into the rebuilt matrix.
    if uplo == 'L':
        triangle, strict = np.tril(A), np.tril(A, -1)
    else:
        triangle, strict = np.triu(A), np.triu(A, 1)
    # conj() is a no-op for real arrays. This gives a symmetric matrix for
    # real A and a Hermitian matrix for complex A, without np.matrix.
    A = triangle + strict.conj().T

    Q = np.asarray(eigh_vector)
    Q_H = Q.conj().T

    # Diagonal matrix of eigen value
    T = np.diag(eigh_value)

    # A = Q*T*Q^H
    residual = A - Q @ T @ Q_H
    np.testing.assert_array_less(
        np.linalg.norm(residual, np.inf) / (N * np.linalg.norm(A, np.inf)),
        rtol)

    # Q must be unitary: Q*Q^H = I. Comparing against Q @ inv(Q) would equal I
    # for *any* invertible Q and so could never detect bad eigenvectors.
    residual = np.eye(M) - Q @ Q_H
    np.testing.assert_array_less(np.linalg.norm(residual, np.inf) / M, rtol)


73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
class TestEighOp(OpTest):
    """OpTest check of the ``eigh`` operator against ``np.linalg.eigh``.

    Subclasses customize the input via ``init_input`` and the triangle the
    operator reads via ``init_config``.
    """

    def setUp(self):
        paddle.enable_static()
        self.op_type = "eigh"
        # Seed before the input is generated so every run feeds the operator
        # the same matrix. Seeding after init_input() had no effect on x_np.
        np.random.seed(123)
        self.init_input()
        self.init_config()
        out_w, out_v = np.linalg.eigh(self.x_np, self.UPLO)
        self.inputs = {"X": self.x_np}
        self.attrs = {"UPLO": self.UPLO}
        self.outputs = {'Eigenvalues': out_w, "Eigenvectors": out_v}

    def init_config(self):
        """Select the lower triangle of the input."""
        self.UPLO = 'L'

    def init_input(self):
        """Build a random float64 10x10 input (not necessarily symmetric)."""
        self.x_shape = (10, 10)
        self.x_type = np.float64
        self.x_np = np.random.random(self.x_shape).astype(self.x_type)

    def test_check_output(self):
        # Eigenvectors are excluded from the comparison; presumably because
        # their sign/phase is not unique — TODO confirm.
        self.check_output(no_check_set=['Eigenvectors'])

    def test_grad(self):
        self.check_grad(["X"], ["Eigenvalues"])


class TestEighUPLOCase(TestEighOp):
    """Same checks as TestEighOp, but the operator reads the upper triangle."""

    def init_config(self):
        upper_triangle = 'U'
        self.UPLO = upper_triangle


@unittest.skipIf(not paddle.is_compiled_with_cuda(),
                 "eigh GPU test requires Paddle compiled with CUDA")
class TestEighGPUCase(unittest.TestCase):
    """Dynamic-mode check of ``paddle.linalg.eigh`` on CUDA device 0.

    Without CUDA the test is reported as skipped instead of silently passing.
    """

    def setUp(self):
        self.x_shape = [32, 32]
        self.dtype = "float32"
        self.UPLO = "L"
        np.random.seed(123)
        self.x_np = np.random.random(self.x_shape).astype(self.dtype)

    def test_check_output_gpu(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))
        input_real_data = paddle.to_tensor(self.x_np)
        actual_w, actual_v = paddle.linalg.eigh(input_real_data, self.UPLO)
        valid_eigh_result(self.x_np,
                          actual_w.numpy(), actual_v.numpy(), self.UPLO)
120 121 122 123


class TestEighAPI(unittest.TestCase):
    """Check ``paddle.linalg.eigh`` in static and dynamic mode.

    A random real matrix (only the triangle selected by ``self.UPLO`` is
    meaningful) and a random Hermitian complex matrix are decomposed. The
    results are verified with ``valid_eigh_result``. ``test_eigh_grad`` also
    checks that the magnitude of the gradient w.r.t. the complex input is
    symmetric under swapping the last two axes. Subclasses change the input
    shape by overriding ``init_input_shape``.
    """

    def setUp(self):
        # Seed before any data is generated so the inputs are reproducible.
        np.random.seed(123)
        self.init_input_data()
        self.UPLO = 'L'
        self.rtol = 1e-5  # for test_eigh_grad
        self.atol = 1e-5  # for test_eigh_grad
        self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
            else paddle.CPUPlace()

    def init_input_shape(self):
        self.x_shape = [5, 5]

    def init_input_data(self):
        """Build ``self.real_data`` and ``self.complex_symm`` of ``self.x_shape``."""
        self.init_input_shape()
        self.dtype = "float32"
        self.real_data = np.random.random(self.x_shape).astype(self.dtype)
        complex_data = np.random.random(self.x_shape).astype(
            self.dtype) + 1J * np.random.random(self.x_shape).astype(self.dtype)
        # Axis permutation that swaps the last two (matrix) axes and keeps
        # any leading batch axes in place.
        self.trans_dims = list(range(len(self.x_shape) - 2)) + [
            len(self.x_shape) - 1, len(self.x_shape) - 2
        ]
        # Build a random Hermitian matrix: (C + C^H) / 2.
        self.complex_symm = np.divide(
            complex_data + np.conj(complex_data.transpose(self.trans_dims)), 2)

    def check_static_float_result(self):
        """Decompose ``self.real_data`` through a static program and verify."""
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, startup_prog):
            input_x = paddle.static.data(
                'input_x', shape=self.x_shape, dtype=self.dtype)
            # Pass UPLO explicitly so paddle reads the same triangle that
            # valid_eigh_result rebuilds the matrix from.
            output_w, output_v = paddle.linalg.eigh(input_x, self.UPLO)
            exe = paddle.static.Executor(self.place)
            actual_w, actual_v = exe.run(main_prog,
                                         feed={"input_x": self.real_data},
                                         fetch_list=[output_w, output_v])
            valid_eigh_result(self.real_data, actual_w, actual_v, self.UPLO)

    def check_static_complex_result(self):
        """Decompose ``self.complex_symm`` through a static program and verify."""
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, startup_prog):
            x_dtype = np.complex64 if self.dtype == "float32" else np.complex128
            input_x = paddle.static.data(
                'input_x', shape=self.x_shape, dtype=x_dtype)
            output_w, output_v = paddle.linalg.eigh(input_x, self.UPLO)
            exe = paddle.static.Executor(self.place)
            actual_w, actual_v = exe.run(main_prog,
                                         feed={"input_x": self.complex_symm},
                                         fetch_list=[output_w, output_v])
            valid_eigh_result(self.complex_symm, actual_w, actual_v, self.UPLO)

    def test_in_static_mode(self):
        paddle.enable_static()
        self.check_static_float_result()
        self.check_static_complex_result()

    def test_in_dynamic_mode(self):
        paddle.disable_static()
        input_real_data = paddle.to_tensor(self.real_data)
        actual_w, actual_v = paddle.linalg.eigh(input_real_data, self.UPLO)
        valid_eigh_result(self.real_data,
                          actual_w.numpy(), actual_v.numpy(), self.UPLO)

        input_complex_data = paddle.to_tensor(self.complex_symm)
        actual_w, actual_v = paddle.linalg.eigh(input_complex_data, self.UPLO)
        valid_eigh_result(self.complex_symm,
                          actual_w.numpy(), actual_v.numpy(), self.UPLO)

    def test_eigh_grad(self):
        paddle.disable_static()
        x = paddle.to_tensor(self.complex_symm, stop_gradient=False)
        w, v = paddle.linalg.eigh(x)
        (w.sum() + paddle.abs(v).sum()).backward()
        # |grad[..., i, j]| must equal |grad[..., j, i]|.
        np.testing.assert_allclose(
            abs(x.grad.numpy()),
            abs(x.grad.numpy().conj().transpose(self.trans_dims)),
            rtol=self.rtol,
            atol=self.atol)


class TestEighBatchAPI(TestEighAPI):
    """Run every TestEighAPI check on a batch of two 5x5 matrices."""

    def init_input_shape(self):
        batch_size, order = 2, 5
        self.x_shape = [batch_size, order, order]


class TestEighAPIError(unittest.TestCase):
    """Invalid inputs to ``paddle.linalg.eigh`` must raise the expected errors."""

    def test_error(self):
        # Enable static mode explicitly instead of relying on whichever mode
        # a previously-run test happened to leave behind.
        paddle.enable_static()
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, startup_prog):
            # input matrix must have at least 2 dimensions
            input_x = paddle.static.data(
                name='x_1', shape=[12], dtype='float32')
            self.assertRaises(ValueError, paddle.linalg.eigh, input_x)

            # input matrix must be a square matrix
            input_x = paddle.static.data(
                name='x_2', shape=[12, 32], dtype='float32')
            self.assertRaises(ValueError, paddle.linalg.eigh, input_x)

            # uplo must be 'L' or 'U'
            input_x = paddle.static.data(
                name='x_3', shape=[4, 4], dtype="float32")
            uplo = 'R'
            self.assertRaises(ValueError, paddle.linalg.eigh, input_x, uplo)

            # x_data cannot be integer
            input_x = paddle.static.data(
                name='x_4', shape=[4, 4], dtype="int32")
            self.assertRaises(TypeError, paddle.linalg.eigh, input_x)


# Allow running this test file directly with the standard unittest runner.
if __name__ == "__main__":
    unittest.main()