未验证 提交 59b00e8c 编写于 作者: A Aurelius84 提交者: GitHub

[CustomOP]Support Incremental compilation and Add Version management (#31228)

* Support Incremental compilation and Add Version management

* replace hash with hashlib
上级 126633c5
...@@ -22,11 +22,14 @@ from setuptools.command.easy_install import easy_install ...@@ -22,11 +22,14 @@ from setuptools.command.easy_install import easy_install
from setuptools.command.build_ext import build_ext from setuptools.command.build_ext import build_ext
from distutils.command.build import build from distutils.command.build import build
from .extension_utils import find_cuda_home, normalize_extension_kwargs, add_compile_flag, bootstrap_context from .extension_utils import find_cuda_home, normalize_extension_kwargs, add_compile_flag
from .extension_utils import is_cuda_file, prepare_unix_cudaflags, prepare_win_cudaflags, add_std_without_repeat, get_build_directory from .extension_utils import is_cuda_file, prepare_unix_cudaflags, prepare_win_cudaflags
from .extension_utils import _import_module_from_library, CustomOpInfo, _write_setup_file, _jit_compile, parse_op_name_from from .extension_utils import _import_module_from_library, _write_setup_file, _jit_compile
from .extension_utils import check_abi_compatibility, log_v, IS_WINDOWS, OS_NAME from .extension_utils import check_abi_compatibility, log_v, CustomOpInfo, parse_op_name_from
from .extension_utils import use_new_custom_op_load_method, MSVC_COMPILE_FLAGS from .extension_utils import use_new_custom_op_load_method, clean_object_if_change_cflags
from .extension_utils import bootstrap_context, get_build_directory, add_std_without_repeat
from .extension_utils import IS_WINDOWS, OS_NAME, MSVC_COMPILE_FLAGS, MSVC_COMPILE_FLAGS
# Note(zhouwei): On windows, it will export function 'PyInit_[name]' by default, # Note(zhouwei): On windows, it will export function 'PyInit_[name]' by default,
# The solution is: 1.User add function PyInit_[name] 2. set not to export # The solution is: 1.User add function PyInit_[name] 2. set not to export
...@@ -357,6 +360,13 @@ class BuildExtension(build_ext, object): ...@@ -357,6 +360,13 @@ class BuildExtension(build_ext, object):
def build_extensions(self): def build_extensions(self):
self._check_abi() self._check_abi()
# Note(Aurelius84): If already compiling source before, we should check whether
# cflags have changed and delete the built shared library to re-compile the source
# even though source file content keep unchanaged.
so_name = self.get_ext_fullpath(self.extensions[0].name)
clean_object_if_change_cflags(
os.path.abspath(so_name), self.extensions[0])
# Consider .cu, .cu.cc as valid source extensions. # Consider .cu, .cu.cc as valid source extensions.
self.compiler.src_extensions += ['.cu', '.cu.cc'] self.compiler.src_extensions += ['.cu', '.cu.cc']
# Save the original _compile method for later. # Save the original _compile method for later.
......
...@@ -16,7 +16,9 @@ import os ...@@ -16,7 +16,9 @@ import os
import re import re
import six import six
import sys import sys
import json
import glob import glob
import hashlib
import logging import logging
import collections import collections
import textwrap import textwrap
...@@ -219,6 +221,106 @@ class CustomOpInfo: ...@@ -219,6 +221,106 @@ class CustomOpInfo:
return next(reversed(self.op_info_map.items())) return next(reversed(self.op_info_map.items()))
VersionFields = collections.namedtuple('VersionFields', [
'sources',
'extra_compile_args',
'extra_link_args',
'library_dirs',
'runtime_library_dirs',
'include_dirs',
'define_macros',
'undef_macros',
])
class VersionManager:
def __init__(self, version_field):
self.version_field = version_field
self.version = self.hasher(version_field)
def hasher(self, version_field):
from paddle.fluid.layers.utils import flatten
md5 = hashlib.md5()
for field in version_field._fields:
elem = getattr(version_field, field)
if not elem: continue
if isinstance(elem, (list, tuple, dict)):
flat_elem = flatten(elem)
md5 = combine_hash(md5, tuple(flat_elem))
else:
raise RuntimeError(
"Support types with list, tuple and dict, but received {} with {}.".
format(type(elem), elem))
return md5.hexdigest()
@property
def details(self):
return self.version_field._asdict()
def combine_hash(md5, value):
"""
Return new hash value.
DO NOT use `hash()` beacuse it doesn't generate stable value between different process.
See https://stackoverflow.com/questions/27522626/hash-function-in-python-3-3-returns-different-results-between-sessions
"""
md5.update(repr(value).encode())
return md5
def clean_object_if_change_cflags(so_path, extension):
"""
If already compiling source before, we should check whether cflags
have changed and delete the built object to re-compile the source
even though source file content keeps unchanaged.
"""
def serialize(path, version_info):
assert isinstance(version_info, dict)
with open(path, 'w') as f:
f.write(json.dumps(version_info, indent=4, sort_keys=True))
def deserialize(path):
assert os.path.exists(path)
with open(path, 'r') as f:
content = f.read()
return json.loads(content)
# version file
VERSION_FILE = "version.txt"
base_dir = os.path.dirname(so_path)
so_name = os.path.basename(so_path)
version_file = os.path.join(base_dir, VERSION_FILE)
# version info
args = [getattr(extension, field, None) for field in VersionFields._fields]
version_field = VersionFields._make(args)
versioner = VersionManager(version_field)
if os.path.exists(so_path) and os.path.exists(version_file):
old_version_info = deserialize(version_file)
so_version = old_version_info.get(so_name, None)
# delete shared library file if versison is changed to re-compile it.
if so_version is not None and so_version != versioner.version:
log_v(
"Re-Compiling {}, because specified cflags have been changed. New signature {} has been saved into {}.".
format(so_name, versioner.version, version_file))
os.remove(so_path)
# upate new version information
new_version_info = versioner.details
new_version_info[so_name] = versioner.version
serialize(version_file, new_version_info)
else:
# If compile at first time, save compiling detail information for debug.
if not os.path.exists(base_dir):
os.makedirs(base_dir)
details = versioner.details
details[so_name] = versioner.version
serialize(version_file, details)
def prepare_unix_cudaflags(cflags): def prepare_unix_cudaflags(cflags):
""" """
Prepare all necessary compiled flags for nvcc compiling CUDA files. Prepare all necessary compiled flags for nvcc compiling CUDA files.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册