test_eigh_op.py 8.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
16

17 18 19
import numpy as np
from op_test import OpTest

20 21
import paddle

22

23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
def valid_eigh_result(A, eigh_value, eigh_vector, uplo):
    """Validate an eigh result for a single matrix or a batch of matrices.

    Args:
        A: input matrix of shape ``(..., N, N)``; only the triangle selected
            by ``uplo`` is taken into account.
        eigh_value: eigenvalues of shape ``(..., N)``.
        eigh_vector: eigenvectors (one per column) of shape ``(..., N, N)``.
        uplo: ``'L'`` or ``'U'``, the triangle of ``A`` that was decomposed.

    Any number of leading batch dimensions is accepted; every matrix in the
    batch is checked with ``valid_single_eigh_result``.

    Raises:
        AssertionError: if ``A`` has fewer than 2 dimensions, the batch sizes
            of the arguments disagree, or any single result is inaccurate.
    """
    assert A.ndim >= 2

    if A.ndim == 2:
        valid_single_eigh_result(A, eigh_value, eigh_vector, uplo)
        return

    # Flatten all leading batch dimensions so a single loop covers them.
    flat_A = np.reshape(A, (-1,) + tuple(A.shape[-2:]))
    flat_w = np.reshape(eigh_value, (-1, np.shape(eigh_value)[-1]))
    flat_v = np.reshape(eigh_vector, (-1,) + tuple(np.shape(eigh_vector)[-2:]))
    # zip() stops at the shortest input, which would silently skip matrices.
    assert len(flat_A) == len(flat_w) == len(flat_v)

    for batch_A, batch_w, batch_v in zip(flat_A, flat_w, flat_v):
        valid_single_eigh_result(batch_A, batch_w, batch_v, uplo)


def valid_single_eigh_result(A, eigh_value, eigh_vector, uplo):
    """Check one eigendecomposition ``A = Q diag(w) Q^H`` of a square matrix.

    The full symmetric/Hermitian matrix is rebuilt from the ``uplo`` triangle
    of ``A`` (the only part eigh reads), and two bounds are asserted::

        ||A - Q T Q^H||_inf / (N * ||A||_inf) < rtol   # reconstruction
        ||I - Q Q^H||_inf / M                 < rtol   # orthonormality

    where ``Q = eigh_vector`` (eigenvectors in columns) and
    ``T = diag(eigh_value)``.  ``rtol`` is 5e-5 for float32/complex64 input
    and 1e-14 otherwise.

    Args:
        A: 2-D input matrix.
        eigh_value: 1-D array of eigenvalues.
        eigh_vector: 2-D array whose columns are the eigenvectors.
        uplo: ``'L'`` to use the lower triangle of ``A``, otherwise upper.

    Raises:
        AssertionError: if either bound is violated.
    """
    FP32_MAX_RELATIVE_ERR = 5e-5
    FP64_MAX_RELATIVE_ERR = 1e-14

    if A.dtype == np.single or A.dtype == np.csingle:
        rtol = FP32_MAX_RELATIVE_ERR
    else:
        rtol = FP64_MAX_RELATIVE_ERR

    M, N = A.shape

    # Mirror the *strict* triangle (offset -1 for lower, +1 for upper) so the
    # diagonal is counted exactly once.  conj() is a no-op for real input and
    # yields the Hermitian completion for complex input.
    if uplo == 'L':
        A = np.tril(A) + np.tril(A, -1).conj().T
    else:
        A = np.triu(A) + np.triu(A, 1).conj().T

    Q = eigh_vector
    # Use the conjugate transpose rather than inv(Q): with inv(Q) both checks
    # pass for any invertible Q, so orthonormality would never be tested.
    Q_H = Q.conj().T

    # Diagonal matrix of eigen value
    T = np.diag(eigh_value)

    # ||A - Q*T*Q^H|| / (N*||A||) < rtol
    residual = A - Q @ T @ Q_H
    np.testing.assert_array_less(
        np.linalg.norm(residual, np.inf) / (N * np.linalg.norm(A, np.inf)), rtol
    )

    # ||I - Q*Q^H|| / M < rtol
    residual = np.eye(M) - Q @ Q_H
    np.testing.assert_array_less(np.linalg.norm(residual, np.inf) / M, rtol)


70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
class TestEighOp(OpTest):
    """OpTest for the ``eigh`` operator, using NumPy as the reference.

    Subclasses override ``init_input`` / ``init_config`` to change the input
    matrix or the ``UPLO`` attribute.
    """

    def setUp(self):
        paddle.enable_static()
        self.op_type = "eigh"
        # Seed *before* building the input: init_input draws from np.random,
        # and seeding afterwards left the test input non-reproducible.
        np.random.seed(123)
        self.init_input()
        self.init_config()
        # Reference results computed from the UPLO triangle of the input.
        out_w, out_v = np.linalg.eigh(self.x_np, self.UPLO)
        self.inputs = {"X": self.x_np}
        self.attrs = {"UPLO": self.UPLO}
        self.outputs = {'Eigenvalues': out_w, "Eigenvectors": out_v}

    def init_config(self):
        # Decompose using the lower triangle of X.
        self.UPLO = 'L'

    def init_input(self):
        # A random (not necessarily symmetric) 10x10 float64 matrix.
        self.x_shape = (10, 10)
        self.x_type = np.float64
        self.x_np = np.random.random(self.x_shape).astype(self.x_type)

    def test_check_output(self):
        # Eigenvectors are excluded from the comparison (presumably because
        # they are only determined up to sign and need not match NumPy).
        self.check_output(no_check_set=['Eigenvectors'])

    def test_grad(self):
        # Gradient check of the Eigenvalues output with respect to X.
        self.check_grad(["X"], ["Eigenvalues"])


class TestEighUPLOCase(TestEighOp):
    """Repeat every TestEighOp check, decomposing from the upper triangle."""

    def init_config(self):
        upper_triangle = 'U'
        self.UPLO = upper_triangle


@unittest.skipIf(
    not paddle.is_compiled_with_cuda(), "eigh GPU test requires CUDA"
)
class TestEighGPUCase(unittest.TestCase):
    """Run ``paddle.linalg.eigh`` on a CUDA device and validate the result.

    Skipped (rather than silently passing) when Paddle is built without CUDA.
    """

    def setUp(self):
        self.x_shape = [32, 32]
        self.dtype = "float32"
        self.UPLO = "L"
        np.random.seed(123)
        self.x_np = np.random.random(self.x_shape).astype(self.dtype)

    def test_check_output_gpu(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))
        input_real_data = paddle.to_tensor(self.x_np)
        actual_w, actual_v = paddle.linalg.eigh(input_real_data, self.UPLO)
        valid_eigh_result(
            self.x_np, actual_w.numpy(), actual_v.numpy(), self.UPLO
        )
118 119 120 121


class TestEighAPI(unittest.TestCase):
    """Validate ``paddle.linalg.eigh`` in static and dynamic graph mode.

    Every case checks a random real matrix and a random complex Hermitian
    matrix.  The shape comes from ``init_input_shape`` so subclasses can add
    batch dimensions.
    """

    def setUp(self):
        # Seed before building the inputs: init_input_data draws from
        # np.random, so seeding after it left the data non-reproducible.
        np.random.seed(123)
        self.init_input_data()
        self.UPLO = 'L'
        self.rtol = 1e-5  # for test_eigh_grad
        self.atol = 1e-5  # for test_eigh_grad
        self.place = (
            paddle.CUDAPlace(0)
            if paddle.is_compiled_with_cuda()
            else paddle.CPUPlace()
        )

    def init_input_shape(self):
        self.x_shape = [5, 5]

    def init_input_data(self):
        self.init_input_shape()
        self.dtype = "float32"
        self.real_data = np.random.random(self.x_shape).astype(self.dtype)
        complex_data = np.random.random(self.x_shape).astype(
            self.dtype
        ) + 1j * np.random.random(self.x_shape).astype(self.dtype)
        # Axis permutation that swaps the last two axes and keeps any leading
        # batch axes in place (i.e. a per-matrix transpose).
        self.trans_dims = list(range(len(self.x_shape) - 2)) + [
            len(self.x_shape) - 1,
            len(self.x_shape) - 2,
        ]
        # build a random Hermitian matrix: (X + X^H) / 2
        self.complex_symm = np.divide(
            complex_data + np.conj(complex_data.transpose(self.trans_dims)), 2
        )

    def _check_static_result(self, feed_data, dtype):
        """Run eigh on ``feed_data`` in a fresh static program and validate."""
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, startup_prog):
            input_x = paddle.static.data(
                'input_x', shape=self.x_shape, dtype=dtype
            )
            # Pass UPLO explicitly so the decomposed triangle is the one
            # valid_eigh_result reconstructs.
            output_w, output_v = paddle.linalg.eigh(input_x, self.UPLO)
            exe = paddle.static.Executor(self.place)
            actual_w, actual_v = exe.run(
                main_prog,
                feed={"input_x": feed_data},
                fetch_list=[output_w, output_v],
            )
            valid_eigh_result(feed_data, actual_w, actual_v, self.UPLO)

    def check_static_float_result(self):
        self._check_static_result(self.real_data, self.dtype)

    def check_static_complex_result(self):
        x_dtype = np.complex64 if self.dtype == "float32" else np.complex128
        self._check_static_result(self.complex_symm, x_dtype)

    def test_in_static_mode(self):
        paddle.enable_static()
        self.check_static_float_result()
        self.check_static_complex_result()

    def test_in_dynamic_mode(self):
        paddle.disable_static()
        for data in (self.real_data, self.complex_symm):
            actual_w, actual_v = paddle.linalg.eigh(
                paddle.to_tensor(data), self.UPLO
            )
            valid_eigh_result(
                data, actual_w.numpy(), actual_v.numpy(), self.UPLO
            )

    def test_eigh_grad(self):
        paddle.disable_static()
        x = paddle.to_tensor(self.complex_symm, stop_gradient=False)
        w, v = paddle.linalg.eigh(x, self.UPLO)
        (w.sum() + paddle.abs(v).sum()).backward()
        # The gradient w.r.t. a Hermitian input is compared (in magnitude)
        # with its own conjugate transpose, i.e. it is expected to be
        # Hermitian up to the tolerances set in setUp.
        np.testing.assert_allclose(
            abs(x.grad.numpy()),
            abs(x.grad.numpy().conj().transpose(self.trans_dims)),
            rtol=self.rtol,
            atol=self.atol,
        )
215 216 217 218 219 220 221 222 223 224 225 226


class TestEighBatchAPI(TestEighAPI):
    """Repeat every TestEighAPI check on a batch of two 5x5 matrices."""

    def init_input_shape(self):
        batch, size = 2, 5
        self.x_shape = [batch, size, size]


class TestEighAPIError(unittest.TestCase):
    """Invalid inputs to ``paddle.linalg.eigh`` must raise."""

    def test_error(self):
        # paddle.static.data is a static-graph API; enable static mode here
        # instead of relying on a previously run test having left it on.
        paddle.enable_static()
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, startup_prog):
            # input matrix must have at least 2 dimensions
            input_x = paddle.static.data(
                name='x_1', shape=[12], dtype='float32'
            )
            self.assertRaises(ValueError, paddle.linalg.eigh, input_x)

            # input matrix must be square
            input_x = paddle.static.data(
                name='x_2', shape=[12, 32], dtype='float32'
            )
            self.assertRaises(ValueError, paddle.linalg.eigh, input_x)

            # uplo must be 'L' or 'U'
            input_x = paddle.static.data(
                name='x_3', shape=[4, 4], dtype="float32"
            )
            uplo = 'R'
            self.assertRaises(ValueError, paddle.linalg.eigh, input_x, uplo)

            # x_data cannot be integer
            input_x = paddle.static.data(
                name='x_4', shape=[4, 4], dtype="int32"
            )
            self.assertRaises(TypeError, paddle.linalg.eigh, input_x)


if __name__ == "__main__":
    # Allow running this file directly; runs the test cases in this module.
    unittest.main()