#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle
import paddle.static as static
from paddle.fluid.framework import _test_eager_guard

# Norm orders that paddle.linalg.cond accepts only for square matrices (or
# batches of them); the error tests expect ValueError for non-square inputs.
p_list_n_n = (
    "fro",
    "nuc",
    1,
    -1,
    np.inf,
    -np.inf,
)
# Norm orders exercised on both square and non-square inputs.
p_list_m_n = (
    None,
    2,
    -2,
)


def test_static_assert_true(self, x_list, p_list):
    """Check ``paddle.linalg.cond`` against ``np.linalg.cond`` in static mode.

    Every input in ``x_list`` is fed through a freshly built static program
    for every norm order in ``p_list``, and the fetched result is compared
    with NumPy's reference value (relative tolerance 5e-5).

    Args:
        self: the calling ``unittest.TestCase`` (kept for signature
            compatibility with existing callers; not used here).
        x_list: iterable of NumPy arrays fed as the static input ``"X"``.
        p_list: iterable of norm orders passed to both implementations.
    """
    for p in p_list:
        for x in x_list:
            # Fresh main/startup programs per case, so each case builds its
            # graph from scratch even though the input is always named "X".
            with static.program_guard(static.Program(), static.Program()):
                input_data = static.data("X", shape=x.shape, dtype=x.dtype)
                output = paddle.linalg.cond(input_data, p)
                exe = static.Executor()
                result = exe.run(feed={"X": x}, fetch_list=[output])
                expected_output = np.linalg.cond(x, p)
                np.testing.assert_allclose(
                    result[0],
                    expected_output,
                    rtol=5e-5,
                    err_msg="p={}, input shape={}".format(p, x.shape))


# Called from TestCase methods with an explicit ``self``; without this flag
# pytest would collect the ``test_``-prefixed helper as a standalone test
# and fail to resolve its ``self``/``x_list``/``p_list`` arguments.
test_static_assert_true.__test__ = False


def test_dygraph_assert_true(self, x_list, p_list):
    """Check ``paddle.linalg.cond`` against ``np.linalg.cond`` in dygraph mode.

    Every input in ``x_list`` is converted with ``paddle.to_tensor`` and
    evaluated for every norm order in ``p_list``; the result is compared
    with NumPy's reference value (relative tolerance 5e-5).

    Args:
        self: the calling ``unittest.TestCase`` (kept for signature
            compatibility with existing callers; not used here).
        x_list: iterable of NumPy arrays to test.
        p_list: iterable of norm orders passed to both implementations.
    """
    for p in p_list:
        for x in x_list:
            input_tensor = paddle.to_tensor(x)
            output = paddle.linalg.cond(input_tensor, p)
            expected_output = np.linalg.cond(x, p)
            np.testing.assert_allclose(
                output.numpy(),
                expected_output,
                rtol=5e-5,
                err_msg="p={}, input shape={}".format(p, x.shape))


# Called from TestCase methods with an explicit ``self``; without this flag
# pytest would collect the ``test_``-prefixed helper as a standalone test
# and fail to resolve its ``self``/``x_list``/``p_list`` arguments.
test_dygraph_assert_true.__test__ = False


def gen_input():
    """Build deterministic random inputs for ``cond``.

    Values are drawn from a private ``np.random.RandomState(2021)``. They are
    therefore reproducible, and identical to the draws produced after
    ``np.random.seed(2021)``, without reseeding NumPy's global generator as
    a side effect.

    Returns:
        tuple: ``(list_n_n, list_m_n)``. ``list_n_n`` holds a square matrix
        and batches of square matrices; ``list_m_n`` holds a non-square
        matrix and batches of non-square matrices.
    """
    rng = np.random.RandomState(2021)

    # generate square matrix or batches of square matrices
    input_1 = rng.rand(5, 5).astype('float32')
    input_2 = rng.rand(3, 6, 6).astype('float64')
    input_3 = rng.rand(2, 4, 3, 3).astype('float32')

    # generate non-square matrix or batches of non-square matrices
    input_4 = rng.rand(9, 7).astype('float64')
    input_5 = rng.rand(4, 2, 10).astype('float32')
    input_6 = rng.rand(3, 5, 4, 1).astype('float32')

    list_n_n = (input_1, input_2, input_3)
    list_m_n = (input_4, input_5, input_6)
    return list_n_n, list_m_n


def gen_empty_input():
    """Build empty-tensor inputs for ``cond``.

    Returns:
        tuple: ``(list_n_n, list_m_n)``. ``list_n_n`` holds square matrices
        (or batches of them) and ``list_m_n`` holds non-square ones. Every
        array has a zero-sized dimension and therefore contains no elements.
    """
    # (shape, dtype) pairs; the trailing two dimensions are equal here...
    square_specs = (
        ((0, 7, 7), 'float32'),
        ((0, 9, 9), 'float32'),
        ((0, 4, 5, 5), 'float64'),
    )
    # ...and differ here.
    non_square_specs = (
        ((0, 7, 11), 'float32'),
        ((0, 10, 8), 'float64'),
        ((5, 0, 4, 3), 'float32'),
    )

    def _build(specs):
        return tuple(
            np.random.rand(*shape).astype(dtype) for shape, dtype in specs)

    return _build(square_specs), _build(non_square_specs)


class API_TestStaticCond(unittest.TestCase):
    """Compare ``paddle.linalg.cond`` with NumPy in static-graph mode."""

    def test_out(self):
        paddle.enable_static()
        square_inputs, non_square_inputs = gen_input()
        # Square inputs are checked with every norm order; non-square inputs
        # only with the orders in p_list_m_n.
        all_orders = p_list_n_n + p_list_m_n
        test_static_assert_true(self, square_inputs, all_orders)
        test_static_assert_true(self, non_square_inputs, p_list_m_n)


class API_TestDygraphCond(unittest.TestCase):
    """Compare ``paddle.linalg.cond`` with NumPy in dynamic-graph mode."""

    def func_out(self):
        paddle.disable_static()
        # test calling results of 'cond' in dynamic mode
        x_list_n_n, x_list_m_n = gen_input()
        test_dygraph_assert_true(self, x_list_n_n, p_list_n_n + p_list_m_n)
        test_dygraph_assert_true(self, x_list_m_n, p_list_m_n)

    def test_out(self):
        # Run the checks once inside _test_eager_guard() and once outside it.
        with _test_eager_guard():
            self.func_out()
        self.func_out()


class TestCondAPIError(unittest.TestCase):
    """Check that ``paddle.linalg.cond`` rejects invalid calls with ValueError."""

    def _assert_static_raises(self, x, p):
        """Declare ``x`` as static input "X" and expect ``cond(X, p)`` to raise ValueError."""
        with static.program_guard(static.Program(), static.Program()):
            x_data = static.data("X", shape=x.shape, dtype=x.dtype)
            self.assertRaises(ValueError, paddle.linalg.cond, x_data, p)

    def func_dygraph_api_error(self):
        paddle.disable_static()
        # test raising errors when 'cond' is called in dygraph mode
        p_list_error = ('fro_', '_nuc', -0.7, 0, 1.5, 3)
        x_list_n_n, x_list_m_n = gen_input()
        # unsupported norm orders are rejected for every input shape
        for p in p_list_error:
            for x in (x_list_n_n + x_list_m_n):
                x_tensor = paddle.to_tensor(x)
                self.assertRaises(ValueError, paddle.linalg.cond, x_tensor, p)

        # square-only norm orders are rejected for non-square inputs
        for p in p_list_n_n:
            for x in x_list_m_n:
                x_tensor = paddle.to_tensor(x)
                self.assertRaises(ValueError, paddle.linalg.cond, x_tensor, p)

    def test_dygraph_api_error(self):
        # Run the checks once inside _test_eager_guard() and once outside it.
        with _test_eager_guard():
            self.func_dygraph_api_error()
        self.func_dygraph_api_error()

    def test_static_api_error(self):
        paddle.enable_static()
        # test raising errors when 'cond' is called in static mode
        p_list_error = ('f ro', 'fre', 'NUC', -1.6, 0, 5)
        x_list_n_n, x_list_m_n = gen_input()
        # unsupported norm orders are rejected for every input shape
        for p in p_list_error:
            for x in (x_list_n_n + x_list_m_n):
                self._assert_static_raises(x, p)

        # square-only norm orders are rejected for non-square inputs
        for p in p_list_n_n:
            for x in x_list_m_n:
                self._assert_static_raises(x, p)

    # it's not supported when input is an empty tensor in static mode
    def test_static_empty_input_error(self):
        paddle.enable_static()

        x_list_n_n, x_list_m_n = gen_empty_input()
        # empty square inputs, with every norm order
        for p in (p_list_n_n + p_list_m_n):
            for x in x_list_n_n:
                self._assert_static_raises(x, p)

        # empty non-square inputs, with the norm orders valid for any shape,
        # so the empty-input rejection is also checked for non-square shapes
        for p in p_list_m_n:
            for x in x_list_m_n:
                self._assert_static_raises(x, p)


class TestCondEmptyTensorInput(unittest.TestCase):
    """Compare ``paddle.linalg.cond`` with NumPy on empty tensors in dygraph mode."""

    def func_dygraph_empty_tensor_input(self):
        paddle.disable_static()
        # test calling results of 'cond' when input is an empty tensor in dynamic mode
        x_list_n_n, x_list_m_n = gen_empty_input()
        test_dygraph_assert_true(self, x_list_n_n, p_list_n_n + p_list_m_n)
        test_dygraph_assert_true(self, x_list_m_n, p_list_m_n)

    def test_dygraph_empty_tensor_input(self):
        # Run the checks once inside _test_eager_guard() and once outside it.
        with _test_eager_guard():
            self.func_dygraph_empty_tensor_input()
        self.func_dygraph_empty_tensor_input()


if __name__ == "__main__":
    # Start in static-graph mode; individual tests still switch modes
    # themselves via paddle.enable_static() / paddle.disable_static().
    paddle.enable_static()
    unittest.main()