diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py index 1503c666d2bb1886d40fdd75af4712560f549351..3113ffea786a0d07698e725aba315d7fd513ac21 100644 --- a/python/paddle/utils/cpp_extension/extension_utils.py +++ b/python/paddle/utils/cpp_extension/extension_utils.py @@ -441,13 +441,46 @@ def _reset_so_rpath(so_path): run_cmd(cmd) +def _get_include_dirs_when_compiling(compile_dir): + """ + Get all include directories when compiling the PaddlePaddle + source code. + """ + include_dirs_file = 'includes.txt' + path = os.path.abspath(compile_dir) + include_dirs_file = os.path.join(path, include_dirs_file) + if not os.path.isfile(include_dirs_file): + return [] + with open(include_dirs_file, 'r') as f: + include_dirs = [line.strip() for line in f.readlines() if line.strip()] + + extra_dirs = ['paddle/fluid/platform'] + all_include_dirs = list(include_dirs) + for extra_dir in extra_dirs: + for include_dir in include_dirs: + d = os.path.join(include_dir, extra_dir) + if os.path.isdir(d): + all_include_dirs.append(d) + all_include_dirs.append(path) + all_include_dirs.sort() + return all_include_dirs + + def normalize_extension_kwargs(kwargs, use_cuda=False): """ Normalize include_dirs, library_dir and other attributes in kwargs. """ assert isinstance(kwargs, dict) + include_dirs = [] + # NOTE: the "_compile_dir" argument is not public to users. It is only + # reserved for internal usage. We do not guarantee that this argument + # is always valid in the future release versions. + compile_dir = kwargs.get("_compile_dir", None) + if compile_dir: + include_dirs = _get_include_dirs_when_compiling(compile_dir) + # append necessary include dir path of paddle - include_dirs = kwargs.get('include_dirs', []) + include_dirs = kwargs.get('include_dirs', include_dirs) include_dirs.extend(find_paddle_includes(use_cuda)) kwargs['include_dirs'] = include_dirs