未验证 提交 cc89120a 编写于 作者: Z Zhou Wei 提交者: GitHub

[Custom OP]add MSVC compile check on Windows (#31265)

上级 af9066e8
...@@ -23,15 +23,14 @@ set_tests_properties(test_multi_out_jit PROPERTIES TIMEOUT 120) ...@@ -23,15 +23,14 @@ set_tests_properties(test_multi_out_jit PROPERTIES TIMEOUT 120)
py_test(test_custom_attrs_jit SRCS test_custom_attrs_jit.py) py_test(test_custom_attrs_jit SRCS test_custom_attrs_jit.py)
set_tests_properties(test_custom_attrs_jit PROPERTIES TIMEOUT 120) set_tests_properties(test_custom_attrs_jit PROPERTIES TIMEOUT 120)
py_test(test_check_abi SRCS test_check_abi.py)
cc_test(test_check_error SRCS test_check_error.cc DEPS gtest) cc_test(test_check_error SRCS test_check_error.cc DEPS gtest)
if(NOT LINUX) if(NOT LINUX)
return() return()
endif() endif()
# TODO(zhouwei): support test_check_abi and abi check on Windows
py_test(test_check_abi SRCS test_check_abi.py)
# Old custom OP only support Linux, only run on Linux # Old custom OP only support Linux, only run on Linux
py_test(test_custom_op SRCS test_custom_op.py) py_test(test_custom_op SRCS test_custom_op.py)
py_test(test_jit_load SRCS test_jit_load.py) py_test(test_jit_load SRCS test_jit_load.py)
......
...@@ -22,10 +22,11 @@ import paddle.utils.cpp_extension.extension_utils as utils ...@@ -22,10 +22,11 @@ import paddle.utils.cpp_extension.extension_utils as utils
class TestABIBase(unittest.TestCase): class TestABIBase(unittest.TestCase):
def test_environ(self): def test_environ(self):
compiler = 'gcc' compiler_list = ['gcc', 'cl']
for flag in ['1', 'True', 'true']: for compiler in compiler_list:
os.environ['PADDLE_SKIP_CHECK_ABI'] = flag for flag in ['1', 'True', 'true']:
self.assertTrue(utils.check_abi_compatibility(compiler)) os.environ['PADDLE_SKIP_CHECK_ABI'] = flag
self.assertTrue(utils.check_abi_compatibility(compiler))
def del_environ(self): def del_environ(self):
key = 'PADDLE_SKIP_CHECK_ABI' key = 'PADDLE_SKIP_CHECK_ABI'
...@@ -33,43 +34,49 @@ class TestABIBase(unittest.TestCase): ...@@ -33,43 +34,49 @@ class TestABIBase(unittest.TestCase):
del os.environ[key] del os.environ[key]
class TestCheckLinux(TestABIBase): class TestCheckCompiler(TestABIBase):
def test_expected_compiler(self): def test_expected_compiler(self):
if utils.OS_NAME.startswith('linux'): if utils.OS_NAME.startswith('linux'):
gt = ['gcc', 'g++', 'gnu-c++', 'gnu-cc'] gt = ['gcc', 'g++', 'gnu-c++', 'gnu-cc']
self.assertListEqual(utils._expected_compiler_current_platform(), elif utils.IS_WINDOWS:
gt) gt = ['cl']
elif utils.OS_NAME.startswith('darwin'):
gt = ['clang', 'clang++']
self.assertListEqual(utils._expected_compiler_current_platform(), gt)
def test_gcc_version(self): def test_compiler_version(self):
# clear environ # clear environ
self.del_environ() self.del_environ()
compiler = 'g++'
if utils.OS_NAME.startswith('linux'): if utils.OS_NAME.startswith('linux'):
# all CI gcc version > 5.4.0 compiler = 'g++'
self.assertTrue( elif utils.IS_WINDOWS:
utils.check_abi_compatibility( compiler = 'cl'
compiler, verbose=True))
# Linux: all CI gcc version > 5.4.0
# Windows: all CI MSVC version > 19.00.24215
# Mac: clang has no version limitation, always return true
self.assertTrue(utils.check_abi_compatibility(compiler, verbose=True))
def test_wrong_compiler_warning(self): def test_wrong_compiler_warning(self):
# clear environ # clear environ
self.del_environ() self.del_environ()
compiler = 'nvcc' # fake wrong compiler compiler = 'nvcc' # fake wrong compiler
if utils.OS_NAME.startswith('linux'): with warnings.catch_warnings(record=True) as error:
with warnings.catch_warnings(record=True) as error: flag = utils.check_abi_compatibility(compiler, verbose=True)
flag = utils.check_abi_compatibility(compiler, verbose=True) # check return False
# check return False self.assertFalse(flag)
self.assertFalse(flag) # check Compiler Compatibility WARNING
# check Compiler Compatibility WARNING self.assertTrue(len(error) == 1)
self.assertTrue(len(error) == 1) self.assertTrue(
self.assertTrue( "Compiler Compatibility WARNING" in str(error[0].message))
"Compiler Compatibility WARNING" in str(error[0].message))
def test_exception(self): def test_exception(self):
# clear environ # clear environ
self.del_environ() self.del_environ()
compiler = 'python' # fake command compiler = 'python' # fake command
if utils.OS_NAME.startswith('linux'): if utils.OS_NAME.startswith('linux'):
# to skip _expected_compiler_current_platform
def fake(): def fake():
return [compiler] return [compiler]
...@@ -89,32 +96,6 @@ class TestCheckLinux(TestABIBase): ...@@ -89,32 +96,6 @@ class TestCheckLinux(TestABIBase):
utils._expected_compiler_current_platform = raw_func utils._expected_compiler_current_platform = raw_func
class TestCheckMacOs(TestABIBase):
def test_expected_compiler(self):
if utils.OS_NAME.startswith('darwin'):
gt = ['clang', 'clang++']
self.assertListEqual(utils._expected_compiler_current_platform(),
gt)
def test_gcc_version(self):
# clear environ
self.del_environ()
if utils.OS_NAME.startswith('darwin'):
# clang has no version limitation.
self.assertTrue(utils.check_abi_compatibility())
class TestCheckWindows(TestABIBase):
def test_gcc_version(self):
# clear environ
self.del_environ()
if utils.IS_WINDOWS:
# we skip windows now
self.assertTrue(utils.check_abi_compatibility())
class TestJITCompilerException(unittest.TestCase): class TestJITCompilerException(unittest.TestCase):
def test_exception(self): def test_exception(self):
with self.assertRaisesRegexp(RuntimeError, with self.assertRaisesRegexp(RuntimeError,
......
...@@ -51,6 +51,7 @@ MSVC_LINK_FLAGS = ['/MACHINE:X64', 'paddle_custom_op.lib'] ...@@ -51,6 +51,7 @@ MSVC_LINK_FLAGS = ['/MACHINE:X64', 'paddle_custom_op.lib']
COMMON_NVCC_FLAGS = ['-DPADDLE_WITH_CUDA', '-DEIGEN_USE_GPU', '-O3'] COMMON_NVCC_FLAGS = ['-DPADDLE_WITH_CUDA', '-DEIGEN_USE_GPU', '-O3']
GCC_MINI_VERSION = (5, 4, 0) GCC_MINI_VERSION = (5, 4, 0)
MSVC_MINI_VERSION = (19, 0, 24215)
# Give warning if using wrong compiler # Give warning if using wrong compiler
WRONG_COMPILER_WARNING = ''' WRONG_COMPILER_WARNING = '''
************************************* *************************************
...@@ -64,7 +65,7 @@ built Paddle for this platform, which is {paddle_compiler} on {platform}. Please ...@@ -64,7 +65,7 @@ built Paddle for this platform, which is {paddle_compiler} on {platform}. Please
use {paddle_compiler} to compile your custom op. Or you may compile Paddle from use {paddle_compiler} to compile your custom op. Or you may compile Paddle from
source using {user_compiler}, and then also use it compile your custom op. source using {user_compiler}, and then also use it compile your custom op.
See https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/2.0/install/compile/linux-compile.html See https://www.paddlepaddle.org.cn/documentation/docs/zh/install/compile/fromsource.html
for help with compiling Paddle from source. for help with compiling Paddle from source.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
...@@ -877,13 +878,12 @@ def check_abi_compatibility(compiler, verbose=False): ...@@ -877,13 +878,12 @@ def check_abi_compatibility(compiler, verbose=False):
Check whether GCC version on user local machine is compatible with Paddle in Check whether GCC version on user local machine is compatible with Paddle in
site-packages. site-packages.
""" """
# TODO(Aurelius84): After we support windows, remove IS_WINDOWS in following code. if os.environ.get('PADDLE_SKIP_CHECK_ABI') in ['True', 'true', '1']:
if os.environ.get('PADDLE_SKIP_CHECK_ABI') in ['True', 'true', '1'
] or IS_WINDOWS:
return True return True
which = 'where' if IS_WINDOWS else 'which'
cmd_out = subprocess.check_output( cmd_out = subprocess.check_output(
['which', compiler], stderr=subprocess.STDOUT) [which, compiler], stderr=subprocess.STDOUT)
compiler_path = os.path.realpath(cmd_out.decode() compiler_path = os.path.realpath(cmd_out.decode()
if six.PY3 else cmd_out).strip() if six.PY3 else cmd_out).strip()
# step 1. if not found any suitable compiler, raise error # step 1. if not found any suitable compiler, raise error
...@@ -896,32 +896,41 @@ def check_abi_compatibility(compiler, verbose=False): ...@@ -896,32 +896,41 @@ def check_abi_compatibility(compiler, verbose=False):
platform=OS_NAME)) platform=OS_NAME))
return False return False
version = (0, 0, 0)
# clang++ have no ABI compatibility problem # clang++ have no ABI compatibility problem
if OS_NAME.startswith('darwin'): if OS_NAME.startswith('darwin'):
return True return True
try: try:
if OS_NAME.startswith('linux'): if OS_NAME.startswith('linux'):
mini_required_version = GCC_MINI_VERSION
version_info = subprocess.check_output( version_info = subprocess.check_output(
[compiler, '-dumpfullversion', '-dumpversion']) [compiler, '-dumpfullversion', '-dumpversion'])
if six.PY3: if six.PY3:
version_info = version_info.decode() version_info = version_info.decode()
version = version_info.strip().split('.') version = version_info.strip().split('.')
assert len(version) == 3
# check version compatibility
if tuple(map(int, version)) >= GCC_MINI_VERSION:
return True
else:
warnings.warn(
ABI_INCOMPATIBILITY_WARNING.format(
user_compiler=compiler, version=version_info.strip()))
elif IS_WINDOWS: elif IS_WINDOWS:
# TODO(zhouwei): support check abi compatibility on windows mini_required_version = MSVC_MINI_VERSION
warnings.warn("We don't support Windows now.") compiler_info = subprocess.check_output(
compiler, stderr=subprocess.STDOUT)
if six.PY3:
compiler_info = compiler_info.decode()
match = re.search(r'(\d+)\.(\d+)\.(\d+)', compiler_info.strip())
if match is not None:
version = match.groups()
except Exception: except Exception:
# check compiler version failed
_, error, _ = sys.exc_info() _, error, _ = sys.exc_info()
warnings.warn('Failed to check compiler version for {}: {}'.format( warnings.warn('Failed to check compiler version for {}: {}'.format(
compiler, error)) compiler, error))
return False
# check version compatibility
assert len(version) == 3
if tuple(map(int, version)) >= mini_required_version:
return True
warnings.warn(
ABI_INCOMPATIBILITY_WARNING.format(
user_compiler=compiler, version=version.strip()))
return False return False
...@@ -929,8 +938,12 @@ def _expected_compiler_current_platform(): ...@@ -929,8 +938,12 @@ def _expected_compiler_current_platform():
""" """
Returns supported compiler string on current platform Returns supported compiler string on current platform
""" """
expect_compilers = ['clang', 'clang++'] if OS_NAME.startswith( if OS_NAME.startswith('darwin'):
'darwin') else ['gcc', 'g++', 'gnu-c++', 'gnu-cc'] expect_compilers = ['clang', 'clang++']
elif OS_NAME.startswith('linux'):
expect_compilers = ['gcc', 'g++', 'gnu-c++', 'gnu-cc']
elif IS_WINDOWS:
expect_compilers = ['cl']
return expect_compilers return expect_compilers
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册