# test_check_abi.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle
import os
import warnings

import paddle.utils.cpp_extension.extension_utils as utils


class TestABIBase(unittest.TestCase):
    """Base case for ABI-check tests.

    Provides ``del_environ`` so subclasses can clear the
    ``PADDLE_SKIP_CHECK_ABI`` flag before exercising the real check. Its own
    ``test_environ`` is inherited, and therefore also runs in every subclass.
    """

    def test_environ(self):
        """A truthy PADDLE_SKIP_CHECK_ABI makes the ABI check return True."""
        key = 'PADDLE_SKIP_CHECK_ABI'
        # Snapshot the current value and restore it afterwards (even on
        # failure) so the flag does not leak into other tests.
        saved = os.environ.get(key)
        self.addCleanup(self._restore_environ, key, saved)

        for compiler in ['gcc', 'cl']:
            for flag in ['1', 'True', 'true']:
                with self.subTest(compiler=compiler, flag=flag):
                    os.environ[key] = flag
                    self.assertTrue(utils.check_abi_compatibility(compiler))

    @staticmethod
    def _restore_environ(key, value):
        """Set ``os.environ[key]`` back to ``value``; unset it when None."""
        if value is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = value

    def del_environ(self):
        """Remove PADDLE_SKIP_CHECK_ABI from the environment if present."""
        os.environ.pop('PADDLE_SKIP_CHECK_ABI', None)


class TestCheckCompiler(TestABIBase):
    """Tests for compiler detection and ABI compatibility checking."""

    def test_expected_compiler(self):
        """The expected-compiler list matches the current platform."""
        if utils.OS_NAME.startswith('linux'):
            gt = ['gcc', 'g++', 'gnu-c++', 'gnu-cc']
        elif utils.IS_WINDOWS:
            gt = ['cl']
        elif utils.OS_NAME.startswith('darwin'):
            gt = ['clang', 'clang++']
        else:
            # No ground truth for this platform; skip instead of comparing
            # against an undefined value.
            self.skipTest('no expected compiler list for platform %s' %
                          utils.OS_NAME)

        self.assertListEqual(utils._expected_compiler_current_platform(), gt)

    def test_compiler_version(self):
        """The platform's default compiler passes the real ABI check."""
        # Clear the skip flag so the real version check runs.
        self.del_environ()
        if utils.OS_NAME.startswith('linux'):
            compiler = 'g++'
        elif utils.IS_WINDOWS:
            compiler = 'cl'
        elif utils.OS_NAME.startswith('darwin'):
            compiler = 'clang'
        else:
            self.skipTest('no default compiler known for platform %s' %
                          utils.OS_NAME)

        # Linux: all CI gcc version > 5.4.0
        # Windows: all CI MSVC version > 19.00.24215
        # Mac: clang has no version limitation, always return true
        self.assertTrue(utils.check_abi_compatibility(compiler, verbose=True))

    def test_wrong_compiler_warning(self):
        """An unexpected compiler yields False plus one compatibility warning."""
        # Clear the skip flag so the real check runs.
        self.del_environ()
        compiler = 'python'  # fake wrong compiler
        with warnings.catch_warnings(record=True) as error:
            flag = utils.check_abi_compatibility(compiler, verbose=True)
            # check return False
            self.assertFalse(flag)
            # check Compiler Compatibility WARNING
            self.assertEqual(len(error), 1)
            self.assertIn("Compiler Compatibility WARNING",
                          str(error[0].message))

    def test_exception(self):
        """A failing compiler-version probe yields False plus one warning."""
        # Clear the skip flag so the real check runs.
        self.del_environ()
        if not utils.OS_NAME.startswith('linux'):
            self.skipTest('this scenario is only exercised on Linux')
        compiler = 'python'  # fake command

        def fake():
            return [compiler]

        # Mock the expected-compiler list so the fake command is accepted
        # and the check proceeds to probing its version.
        raw_func = utils._expected_compiler_current_platform
        utils._expected_compiler_current_platform = fake
        # Restore the real function even if an assertion below fails, so the
        # mock cannot leak into other tests.
        self.addCleanup(setattr, utils, '_expected_compiler_current_platform',
                        raw_func)

        with warnings.catch_warnings(record=True) as error:
            flag = utils.check_abi_compatibility(compiler, verbose=True)
            # check return False
            self.assertFalse(flag)
            # check ABI Compatibility WARNING
            self.assertEqual(len(error), 1)
            self.assertIn("Failed to check compiler version for",
                          str(error[0].message))


class TestJITCompilerException(unittest.TestCase):
    """_jit_compile reports an unusable interpreter as a RuntimeError."""

    def test_exception(self):
        """A nonexistent interpreter command raises RuntimeError."""
        file_path = os.path.abspath(__file__)
        # assertRaisesRegex, not the deprecated assertRaisesRegexp alias,
        # which no longer exists in Python 3.12+.
        with self.assertRaisesRegex(RuntimeError,
                                    "Failed to check Python interpreter"):
            utils._jit_compile(file_path, interpreter='fake_cmd', verbose=True)


class TestRunCMDException(unittest.TestCase):
    """run_cmd surfaces a failing command as a RuntimeError."""

    def test_exception(self):
        """An unknown command raises RuntimeError, verbose or not."""
        cmd = "fake cmd"
        for verbose in [True, False]:
            with self.subTest(verbose=verbose):
                # assertRaisesRegex, not the deprecated assertRaisesRegexp
                # alias, which no longer exists in Python 3.12+.
                with self.assertRaisesRegex(RuntimeError,
                                            "Failed to run command"):
                    utils.run_cmd(cmd, verbose)


if __name__ == "__main__":
    # Running this file directly executes all of its TestCase classes.
    unittest.main()