# test_eigh_op.py
#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle
from op_test import OpTest
from gradient_checker import grad_check


def valid_eigh_result(A, eigh_value, eigh_vector, uplo):
    """Validate an eigh decomposition of one matrix or a batch of matrices.

    A 2-D ``A`` is validated directly. A 3-D ``A`` is treated as a stack of
    matrices, and each (matrix, eigenvalues, eigenvectors) triple is
    validated with ``valid_single_eigh_result``.
    """
    assert A.ndim in (2, 3)

    if A.ndim == 3:
        for matrix, values, vectors in zip(A, eigh_value, eigh_vector):
            valid_single_eigh_result(matrix, values, vectors, uplo)
    else:
        valid_single_eigh_result(A, eigh_value, eigh_vector, uplo)


def valid_single_eigh_result(A, eigh_value, eigh_vector, uplo):
    """Check that (eigh_value, eigh_vector) is an eigendecomposition of A.

    Only the triangle of ``A`` selected by ``uplo`` ('L' = lower, anything
    else = upper) is used to rebuild the symmetric/Hermitian matrix. With
    Q = ``eigh_vector`` and T = ``diag(eigh_value)``, two relative residuals
    are checked against a dtype-dependent tolerance:

    * reconstruction: ||A - Q*T*Q^H||_inf / (N * ||A||_inf) < rtol
    * orthonormality: ||I - Q*Q^H||_inf / M < rtol

    Raises:
        AssertionError: if either residual is not below the tolerance.
    """
    FP32_MAX_RELATIVE_ERR = 5e-5
    FP64_MAX_RELATIVE_ERR = 1e-14

    if A.dtype == np.single or A.dtype == np.csingle:
        rtol = FP32_MAX_RELATIVE_ERR
    else:
        rtol = FP64_MAX_RELATIVE_ERR

    M, N = A.shape

    # Strictly off-diagonal part of the selected triangle. The offset must
    # exclude the diagonal in both cases (k=-1 for tril, k=+1 for triu);
    # otherwise the diagonal would be added twice.
    if uplo == 'L':
        strict = np.tril(A, -1)
    else:
        strict = np.triu(A, 1)
    # Like LAPACK / numpy.linalg.eigh, ignore imaginary parts on the diagonal
    # so the rebuilt matrix is exactly Hermitian.
    diagonal = np.diag(np.real(np.diag(A)))
    # Mirror the strict triangle; conj() is a no-op for real input.
    A = strict + strict.conj().T + diagonal

    # Diagonal matrix of eigenvalues and the conjugate transpose of Q.
    T = np.diag(eigh_value)
    Q_H = eigh_vector.conj().T

    # A = Q*T*Q^H
    residual = A - eigh_vector @ T @ Q_H
    np.testing.assert_array_less(
        np.linalg.norm(residual, np.inf) / (N * np.linalg.norm(A, np.inf)),
        rtol)

    # Q must be unitary: Q*Q^H = I. Comparing against inv(Q) instead would
    # make this check vacuous, because Q @ inv(Q) == I for any invertible Q.
    residual = np.eye(M) - eigh_vector @ Q_H
    np.testing.assert_array_less(np.linalg.norm(residual, np.inf) / M, rtol)


class TestEighOp(OpTest):
    """Operator-level test of `eigh` against numpy.linalg.eigh."""

    def setUp(self):
        paddle.enable_static()
        self.op_type = "eigh"
        # Seed before init_input() draws the random matrix; seeding after
        # the draw would leave the test input non-reproducible.
        np.random.seed(123)
        self.init_input()
        self.init_config()
        out_w, out_v = np.linalg.eigh(self.x_np, self.UPLO)
        self.inputs = {"X": self.x_np}
        self.attrs = {"UPLO": self.UPLO}
        self.outputs = {'Eigenvalues': out_w, "Eigenvectors": out_v}

    def init_config(self):
        # Read the lower triangle of the input; subclasses may override.
        self.UPLO = 'L'

    def init_input(self):
        self.x_shape = (10, 10)
        self.x_type = np.float64
        self.x_np = np.random.random(self.x_shape).astype(self.x_type)

    def test_check_output(self):
        # Eigenvectors are excluded from the exact comparison (presumably
        # because they are only determined up to sign).
        self.check_output(no_check_set=['Eigenvectors'])

    def test_grad(self):
        self.check_grad(["X"], ["Eigenvalues"])


class TestEighUPLOCase(TestEighOp):
    """Repeats the TestEighOp checks, reading the upper triangle instead."""

    def init_config(self):
        # Override the parent's lower-triangle ('L') setting.
        self.UPLO = 'U'


class TestEighGPUCase(unittest.TestCase):
    """Runs eigh on CUDA device 0 (when built with CUDA) and validates it."""

    def setUp(self):
        np.random.seed(123)
        self.x_shape = [32, 32]
        self.dtype = "float32"
        self.UPLO = "L"
        self.x_np = np.random.random(self.x_shape).astype(self.dtype)

    def test_check_output_gpu(self):
        # Nothing to check on builds without CUDA support.
        if not paddle.is_compiled_with_cuda():
            return
        paddle.disable_static(place=paddle.CUDAPlace(0))
        x_tensor = paddle.to_tensor(self.x_np)
        eigvals, eigvecs = paddle.linalg.eigh(x_tensor, self.UPLO)
        valid_eigh_result(self.x_np, eigvals.numpy(), eigvecs.numpy(),
                          self.UPLO)


class TestEighAPI(unittest.TestCase):
    """Checks paddle.linalg.eigh on real and complex Hermitian input, in both
    static-graph and dynamic-graph mode."""

    def setUp(self):
        # Seed before init_input_data() draws random numbers; seeding after
        # it would leave the test inputs non-reproducible.
        np.random.seed(123)
        self.init_input_data()
        self.UPLO = 'L'
        self.rtol = 1e-5  # for test_eigh_grad
        self.atol = 1e-5  # for test_eigh_grad
        self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
            else paddle.CPUPlace()

    def init_input_shape(self):
        self.x_shape = [5, 5]

    def init_input_data(self):
        self.init_input_shape()
        self.dtype = "float32"
        self.real_data = np.random.random(self.x_shape).astype(self.dtype)
        complex_data = np.random.random(self.x_shape).astype(
            self.dtype) + 1J * np.random.random(self.x_shape).astype(self.dtype)
        # Axis permutation that swaps the last two axes, i.e. a matrix
        # transpose that also works for batched input.
        self.trans_dims = list(range(len(self.x_shape) - 2)) + [
            len(self.x_shape) - 1, len(self.x_shape) - 2
        ]
        # Build a random Hermitian matrix: (X + X^H) / 2.
        self.complex_symm = np.divide(
            complex_data + np.conj(complex_data.transpose(self.trans_dims)), 2)

    def check_static_float_result(self):
        """Run eigh on ``self.real_data`` in a static program and validate."""
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, startup_prog):
            input_x = paddle.static.data('input_x',
                                         shape=self.x_shape,
                                         dtype=self.dtype)
            # Pass UPLO explicitly so the op reads the same triangle that
            # valid_eigh_result reconstructs.
            output_w, output_v = paddle.linalg.eigh(input_x, self.UPLO)
            exe = paddle.static.Executor(self.place)
            actual_w, actual_v = exe.run(main_prog,
                                         feed={"input_x": self.real_data},
                                         fetch_list=[output_w, output_v])
            valid_eigh_result(self.real_data, actual_w, actual_v, self.UPLO)

    def check_static_complex_result(self):
        """Run eigh on ``self.complex_symm`` in a static program and validate."""
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, startup_prog):
            x_dtype = np.complex64 if self.dtype == "float32" else np.complex128
            input_x = paddle.static.data('input_x',
                                         shape=self.x_shape,
                                         dtype=x_dtype)
            output_w, output_v = paddle.linalg.eigh(input_x, self.UPLO)
            exe = paddle.static.Executor(self.place)
            actual_w, actual_v = exe.run(main_prog,
                                         feed={"input_x": self.complex_symm},
                                         fetch_list=[output_w, output_v])
            valid_eigh_result(self.complex_symm, actual_w, actual_v, self.UPLO)

    def test_in_static_mode(self):
        paddle.enable_static()
        self.check_static_float_result()
        self.check_static_complex_result()

    def test_in_dynamic_mode(self):
        paddle.disable_static()

        input_real_data = paddle.to_tensor(self.real_data)
        actual_w, actual_v = paddle.linalg.eigh(input_real_data, self.UPLO)
        valid_eigh_result(self.real_data, actual_w.numpy(), actual_v.numpy(),
                          self.UPLO)

        input_complex_data = paddle.to_tensor(self.complex_symm)
        actual_w, actual_v = paddle.linalg.eigh(input_complex_data, self.UPLO)
        valid_eigh_result(self.complex_symm, actual_w.numpy(),
                          actual_v.numpy(), self.UPLO)

    def test_eigh_grad(self):
        """The gradient w.r.t. the Hermitian input must have element-wise
        magnitudes that are symmetric: |g[..., i, j]| == |g[..., j, i]|."""
        paddle.disable_static()
        x = paddle.to_tensor(self.complex_symm, stop_gradient=False)
        w, v = paddle.linalg.eigh(x, self.UPLO)
        (w.sum() + paddle.abs(v).sum()).backward()
        np.testing.assert_allclose(abs(x.grad.numpy()),
                                   abs(x.grad.numpy().conj().transpose(
                                       self.trans_dims)),
                                   rtol=self.rtol,
                                   atol=self.atol)


class TestEighBatchAPI(TestEighAPI):
    """Runs every TestEighAPI check on a batch of two 5x5 matrices."""

    def init_input_shape(self):
        batch_size, n = 2, 5
        self.x_shape = [batch_size, n, n]


class TestEighAPIError(unittest.TestCase):
    """paddle.linalg.eigh must reject malformed input with the right error."""

    def test_error(self):
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, startup_prog):
            # (data name, shape, dtype, extra positional args, expected error)
            cases = [
                # input matrix must have at least 2 dimensions
                ('x_1', [12], 'float32', (), ValueError),
                # input matrix must be square
                ('x_2', [12, 32], 'float32', (), ValueError),
                # UPLO must be 'L' or 'U'
                ('x_3', [4, 4], "float32", ('R', ), ValueError),
                # input data cannot be an integer type
                ('x_4', [4, 4], "int32", (), TypeError),
            ]
            for name, shape, dtype, extra_args, error in cases:
                input_x = paddle.static.data(name=name,
                                             shape=shape,
                                             dtype=dtype)
                self.assertRaises(error, paddle.linalg.eigh, input_x,
                                  *extra_args)


if __name__ == "__main__":
    # Discover and run every TestCase defined in this module.
    unittest.main()