diff --git a/python/paddle/utils/cpp_extension/cpp_extension.py b/python/paddle/utils/cpp_extension/cpp_extension.py index 57bcea658b53c400234afd22d4d5acc77f7f43ce..5d132217bba91f84924cbff9f2bd951d381326cf 100644 --- a/python/paddle/utils/cpp_extension/cpp_extension.py +++ b/python/paddle/utils/cpp_extension/cpp_extension.py @@ -22,11 +22,14 @@ from setuptools.command.easy_install import easy_install from setuptools.command.build_ext import build_ext from distutils.command.build import build -from .extension_utils import find_cuda_home, normalize_extension_kwargs, add_compile_flag, bootstrap_context -from .extension_utils import is_cuda_file, prepare_unix_cudaflags, prepare_win_cudaflags, add_std_without_repeat, get_build_directory -from .extension_utils import _import_module_from_library, CustomOpInfo, _write_setup_file, _jit_compile, parse_op_name_from -from .extension_utils import check_abi_compatibility, log_v, IS_WINDOWS, OS_NAME -from .extension_utils import use_new_custom_op_load_method, MSVC_COMPILE_FLAGS +from .extension_utils import find_cuda_home, normalize_extension_kwargs, add_compile_flag +from .extension_utils import is_cuda_file, prepare_unix_cudaflags, prepare_win_cudaflags +from .extension_utils import _import_module_from_library, _write_setup_file, _jit_compile +from .extension_utils import check_abi_compatibility, log_v, CustomOpInfo, parse_op_name_from +from .extension_utils import use_new_custom_op_load_method, clean_object_if_change_cflags +from .extension_utils import bootstrap_context, get_build_directory, add_std_without_repeat + +from .extension_utils import IS_WINDOWS, OS_NAME, MSVC_COMPILE_FLAGS, MSVC_COMPILE_FLAGS # Note(zhouwei): On windows, it will export function 'PyInit_[name]' by default, # The solution is: 1.User add function PyInit_[name] 2. set not to export @@ -357,6 +360,13 @@ class BuildExtension(build_ext, object): def build_extensions(self): self._check_abi() + # Note(Aurelius84): If already compiling source before, we should check whether + # cflags have changed and delete the built shared library to re-compile the source + # even though source file content keep unchanaged. + so_name = self.get_ext_fullpath(self.extensions[0].name) + clean_object_if_change_cflags( + os.path.abspath(so_name), self.extensions[0]) + # Consider .cu, .cu.cc as valid source extensions. self.compiler.src_extensions += ['.cu', '.cu.cc'] # Save the original _compile method for later. diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py index 896293246a275a40d735b1b6390aebf007848c51..712342b41e57e68ebb127d9a75dc97c19e43669b 100644 --- a/python/paddle/utils/cpp_extension/extension_utils.py +++ b/python/paddle/utils/cpp_extension/extension_utils.py @@ -16,7 +16,9 @@ import os import re import six import sys +import json import glob +import hashlib import logging import collections import textwrap @@ -219,6 +221,106 @@ class CustomOpInfo: return next(reversed(self.op_info_map.items())) +VersionFields = collections.namedtuple('VersionFields', [ + 'sources', + 'extra_compile_args', + 'extra_link_args', + 'library_dirs', + 'runtime_library_dirs', + 'include_dirs', + 'define_macros', + 'undef_macros', +]) + + +class VersionManager: + def __init__(self, version_field): + self.version_field = version_field + self.version = self.hasher(version_field) + + def hasher(self, version_field): + from paddle.fluid.layers.utils import flatten + + md5 = hashlib.md5() + for field in version_field._fields: + elem = getattr(version_field, field) + if not elem: continue + if isinstance(elem, (list, tuple, dict)): + flat_elem = flatten(elem) + md5 = combine_hash(md5, tuple(flat_elem)) + else: + raise RuntimeError( + "Support types with list, tuple and dict, but received {} with {}.". + format(type(elem), elem)) + + return md5.hexdigest() + + @property + def details(self): + return self.version_field._asdict() + + +def combine_hash(md5, value): + """ + Return new hash value. + DO NOT use `hash()` beacuse it doesn't generate stable value between different process. + See https://stackoverflow.com/questions/27522626/hash-function-in-python-3-3-returns-different-results-between-sessions + """ + md5.update(repr(value).encode()) + return md5 + + +def clean_object_if_change_cflags(so_path, extension): + """ + If already compiling source before, we should check whether cflags + have changed and delete the built object to re-compile the source + even though source file content keeps unchanaged. + """ + + def serialize(path, version_info): + assert isinstance(version_info, dict) + with open(path, 'w') as f: + f.write(json.dumps(version_info, indent=4, sort_keys=True)) + + def deserialize(path): + assert os.path.exists(path) + with open(path, 'r') as f: + content = f.read() + return json.loads(content) + + # version file + VERSION_FILE = "version.txt" + base_dir = os.path.dirname(so_path) + so_name = os.path.basename(so_path) + version_file = os.path.join(base_dir, VERSION_FILE) + + # version info + args = [getattr(extension, field, None) for field in VersionFields._fields] + version_field = VersionFields._make(args) + versioner = VersionManager(version_field) + + if os.path.exists(so_path) and os.path.exists(version_file): + old_version_info = deserialize(version_file) + so_version = old_version_info.get(so_name, None) + # delete shared library file if versison is changed to re-compile it. + if so_version is not None and so_version != versioner.version: + log_v( + "Re-Compiling {}, because specified cflags have been changed. New signature {} has been saved into {}.". + format(so_name, versioner.version, version_file)) + os.remove(so_path) + # upate new version information + new_version_info = versioner.details + new_version_info[so_name] = versioner.version + serialize(version_file, new_version_info) + else: + # If compile at first time, save compiling detail information for debug. + if not os.path.exists(base_dir): + os.makedirs(base_dir) + details = versioner.details + details[so_name] = versioner.version + serialize(version_file, details) + + def prepare_unix_cudaflags(cflags): """ Prepare all necessary compiled flags for nvcc compiling CUDA files.