#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
import paddle
from op_test import OpTest
from gradient_checker import grad_check


def valid_eigh_result(A, eigh_value, eigh_vector, uplo):
    """Check eigh results for a single matrix or a batch of matrices.

    Args:
        A: input array of shape (..., N, N); leading axes are batch axes.
        eigh_value: eigenvalues, expected shape (..., N).
        eigh_vector: eigenvectors stored as columns, expected shape (..., N, N).
        uplo: which triangle of ``A`` the decomposition used ('L' or 'U').

    Raises:
        AssertionError: on a shape mismatch, or if any matrix of the batch
            fails ``valid_single_eigh_result``.
    """
    assert A.ndim >= 2, "eigh input must have at least 2 dimensions"
    # zip() stops at the shortest argument, so without these checks a result
    # with missing batch entries would be silently accepted.
    assert eigh_value.shape == A.shape[:-1], (eigh_value.shape, A.shape)
    assert eigh_vector.shape == A.shape, (eigh_vector.shape, A.shape)

    # Collapse every leading batch axis into one so that 2-D, 3-D and
    # higher-rank inputs are all validated the same way.
    flat_A = A.reshape((-1, ) + A.shape[-2:])
    flat_w = eigh_value.reshape((-1, eigh_value.shape[-1]))
    flat_v = eigh_vector.reshape((-1, ) + eigh_vector.shape[-2:])

    for batch_A, batch_w, batch_v in zip(flat_A, flat_w, flat_v):
        valid_single_eigh_result(batch_A, batch_w, batch_v, uplo)


def valid_single_eigh_result(A, eigh_value, eigh_vector, uplo):
    """Check one eigen-decomposition of a symmetric / Hermitian matrix.

    Only the triangle of ``A`` selected by ``uplo`` is used; the full
    symmetric (real) or Hermitian (complex) matrix is rebuilt from it
    before checking. Two properties are verified with a dtype-dependent
    relative tolerance:

    * reconstruction: ``||A - Q*T*Q^H||_inf / (N*||A||_inf) < rtol``
    * orthonormality: ``||I - Q*Q^H||_inf / M < rtol``

    Args:
        A: square 2-D numpy array, real or complex.
        eigh_value: 1-D array of eigenvalues (the diagonal of T).
        eigh_vector: 2-D array Q whose columns are the eigenvectors.
        uplo: 'L' to read the lower triangle of ``A``; any other value
            reads the upper triangle.

    Raises:
        AssertionError: if either residual is not below the tolerance.
    """
    FP32_MAX_RELATIVE_ERR = 5e-5
    FP64_MAX_RELATIVE_ERR = 1e-14

    if A.dtype == np.single or A.dtype == np.csingle:
        rtol = FP32_MAX_RELATIVE_ERR
    else:
        rtol = FP64_MAX_RELATIVE_ERR

    M, N = A.shape

    # The mirrored part must be the *strict* opposite triangle: offset -1
    # (below the diagonal) for 'L', +1 (above the diagonal) for 'U'.
    # Using -1 for 'U' would double-count the diagonal.
    if uplo == 'L':
        triangle, strict = np.tril(A), np.tril(A, -1)
    else:
        triangle, strict = np.triu(A), np.triu(A, 1)

    # Mirror the strict triangle (conjugated for complex input) so the
    # rebuilt matrix is symmetric / Hermitian.
    A = triangle + strict.conj().T

    # Q^H (conjugate transpose). Using it instead of inv(Q) is what makes
    # these checks also verify that Q is unitary; Q @ inv(Q) == I holds for
    # any invertible Q.
    eigh_vector_h = eigh_vector.conj().T

    # Diagonal matrix of eigen value
    T = np.diag(eigh_value)

    # ||A - Q*T*Q^H|| / (N*||A||) < rtol
    residual = A - eigh_vector @ T @ eigh_vector_h
    np.testing.assert_array_less(
        np.linalg.norm(residual, np.inf) / (N * np.linalg.norm(A, np.inf)),
        rtol)

    # ||I - Q*Q^H|| / M < rtol
    residual = np.eye(M) - eigh_vector @ eigh_vector_h
    np.testing.assert_array_less(np.linalg.norm(residual, np.inf) / M, rtol)


class TestEighOp(OpTest):
    """OpTest for the ``eigh`` operator, checked against ``np.linalg.eigh``."""

    def setUp(self):
        paddle.enable_static()
        self.op_type = "eigh"
        # Seed before init_input() so the random input is reproducible;
        # seeding after the data has been drawn has no effect on it.
        np.random.seed(123)
        self.init_input()
        self.init_config()
        out_w, out_v = np.linalg.eigh(self.x_np, self.UPLO)
        self.inputs = {"X": self.x_np}
        self.attrs = {"UPLO": self.UPLO}
        self.outputs = {'Eigenvalues': out_w, "Eigenvectors": out_v}

    def init_config(self):
        self.UPLO = 'L'

    def init_input(self):
        self.x_shape = (10, 10)
        self.x_type = np.float64
        self.x_np = np.random.random(self.x_shape).astype(self.x_type)

    def test_check_output(self):
        # Eigenvectors are excluded from the direct comparison, presumably
        # because they are only unique up to sign/phase.
        self.check_output(no_check_set=['Eigenvectors'])

    def test_grad(self):
        # Gradient of the eigenvalues with respect to the input X.
        self.check_grad(["X"], ["Eigenvalues"])


class TestEighUPLOCase(TestEighOp):
    """Same checks as TestEighOp, but reading the upper triangle (UPLO='U')."""

    def init_config(self):
        self.UPLO = 'U'


class TestEighGPUCase(unittest.TestCase):
    """Runs paddle.linalg.eigh on CUDA device 0 and validates the result."""

    def setUp(self):
        self.x_shape = [32, 32]
        self.dtype = "float32"
        self.UPLO = "L"
        np.random.seed(123)
        self.x_np = np.random.random(self.x_shape).astype(self.dtype)

    def test_check_output_gpu(self):
        # Report as skipped, rather than silently passing, on CPU-only builds.
        if not paddle.is_compiled_with_cuda():
            self.skipTest("Paddle is not compiled with CUDA")
        paddle.disable_static(place=paddle.CUDAPlace(0))
        input_real_data = paddle.to_tensor(self.x_np)
        actual_w, actual_v = paddle.linalg.eigh(input_real_data, self.UPLO)
        valid_eigh_result(self.x_np, actual_w.numpy(), actual_v.numpy(),
                          self.UPLO)


class TestEighAPI(unittest.TestCase):
    """API tests for paddle.linalg.eigh on real and complex Hermitian input.

    Subclasses override ``init_input_shape`` to test other (batched) shapes.
    """

    def setUp(self):
        # Seed first so init_input_data() draws reproducible matrices.
        np.random.seed(123)
        self.init_input_data()
        self.UPLO = 'L'
        self.rtol = 1e-5  # for test_eigh_grad
        self.atol = 1e-5  # for test_eigh_grad
        self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
            else paddle.CPUPlace()

    def init_input_shape(self):
        self.x_shape = [5, 5]

    def init_input_data(self):
        self.init_input_shape()
        self.dtype = "float32"
        self.real_data = np.random.random(self.x_shape).astype(self.dtype)
        complex_data = np.random.random(self.x_shape).astype(
            self.dtype) + 1J * np.random.random(self.x_shape).astype(self.dtype)
        # Axis permutation that swaps the last two axes (matrix transpose)
        # while leaving any leading batch axes in place.
        self.trans_dims = list(range(len(self.x_shape) - 2)) + [
            len(self.x_shape) - 1, len(self.x_shape) - 2
        ]
        # Build a random Hermitian matrix: (X + X^H) / 2.
        self.complex_symm = np.divide(
            complex_data + np.conj(complex_data.transpose(self.trans_dims)), 2)

    def check_static_float_result(self):
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, startup_prog):
            input_x = paddle.static.data('input_x',
                                         shape=self.x_shape,
                                         dtype=self.dtype)
            output_w, output_v = paddle.linalg.eigh(input_x, self.UPLO)
            exe = paddle.static.Executor(self.place)
            actual_w, actual_v = exe.run(main_prog,
                                         feed={"input_x": self.real_data},
                                         fetch_list=[output_w, output_v])
            valid_eigh_result(self.real_data, actual_w, actual_v, self.UPLO)

    def check_static_complex_result(self):
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, startup_prog):
            x_dtype = np.complex64 if self.dtype == "float32" else np.complex128
            input_x = paddle.static.data('input_x',
                                         shape=self.x_shape,
                                         dtype=x_dtype)
            output_w, output_v = paddle.linalg.eigh(input_x, self.UPLO)
            exe = paddle.static.Executor(self.place)
            actual_w, actual_v = exe.run(main_prog,
                                         feed={"input_x": self.complex_symm},
                                         fetch_list=[output_w, output_v])
            valid_eigh_result(self.complex_symm, actual_w, actual_v, self.UPLO)

    def test_in_static_mode(self):
        paddle.enable_static()
        self.check_static_float_result()
        self.check_static_complex_result()

    def test_in_dynamic_mode(self):
        paddle.disable_static()
        input_real_data = paddle.to_tensor(self.real_data)
        actual_w, actual_v = paddle.linalg.eigh(input_real_data, self.UPLO)
        valid_eigh_result(self.real_data, actual_w.numpy(), actual_v.numpy(),
                          self.UPLO)

        input_complex_data = paddle.to_tensor(self.complex_symm)
        actual_w, actual_v = paddle.linalg.eigh(input_complex_data, self.UPLO)
        valid_eigh_result(self.complex_symm, actual_w.numpy(), actual_v.numpy(),
                          self.UPLO)

    def test_eigh_grad(self):
        paddle.disable_static()
        x = paddle.to_tensor(self.complex_symm, stop_gradient=False)
        w, v = paddle.linalg.eigh(x, self.UPLO)
        (w.sum() + paddle.abs(v).sum()).backward()
        # The gradient magnitude is expected to be symmetric under the
        # conjugate transpose of the last two axes.
        np.testing.assert_allclose(abs(x.grad.numpy()),
                                   abs(x.grad.numpy().conj().transpose(
                                       self.trans_dims)),
                                   rtol=self.rtol,
                                   atol=self.atol)


class TestEighBatchAPI(TestEighAPI):
    """Runs the TestEighAPI checks on a batch of two 5x5 matrices."""

    def init_input_shape(self):
        self.x_shape = [2, 5, 5]


class TestEighAPIError(unittest.TestCase):
    """Invalid inputs to paddle.linalg.eigh must raise in static mode."""

    def test_error(self):
        # program_guard / static.data are static-graph APIs; enable static
        # mode explicitly instead of relying on an earlier test to do it.
        paddle.enable_static()
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, startup_prog):
            # input matrix must have at least 2 dimensions
            input_x = paddle.static.data(name='x_1',
                                         shape=[12],
                                         dtype='float32')
            self.assertRaises(ValueError, paddle.linalg.eigh, input_x)

            # input matrix must be a square matrix
            input_x = paddle.static.data(name='x_2',
                                         shape=[12, 32],
                                         dtype='float32')
            self.assertRaises(ValueError, paddle.linalg.eigh, input_x)

            # uplo must be 'L' or 'U'
            input_x = paddle.static.data(name='x_3',
                                         shape=[4, 4],
                                         dtype="float32")
            uplo = 'R'
            self.assertRaises(ValueError, paddle.linalg.eigh, input_x, uplo)

            # x_data cannot be integer
            input_x = paddle.static.data(name='x_4',
                                         shape=[4, 4],
                                         dtype="int32")
            self.assertRaises(TypeError, paddle.linalg.eigh, input_x)


if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script.
    unittest.main()