From e49d0746ddf9741c871fca8ce53ddbe7a295e4a2 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 3 Feb 2021 14:13:07 +0800 Subject: [PATCH] [CustomOp] Support install as Package and Add load interface (#30798) * support setup.py to compile custom op * move file into paddle.utils.cpp_extension * support python setup.py install * refine code style * Enrich code and add unittest * Polish code and api doc * fix cpp_extension not include in package * fix relative import * fix os.makedirs exist_ok param compatibility PY2 * add compile flags in test_jit_load --- python/paddle/fluid/framework.py | 9 +- .../fluid/tests/custom_op/CMakeLists.txt | 2 + .../fluid/tests/custom_op/cpp_extension.py | 179 ------ .../fluid/tests/custom_op/extension_utils.py | 216 ------- .../fluid/tests/custom_op/setup_build.py | 33 ++ .../fluid/tests/custom_op/setup_install.py | 27 + .../custom_op/test_custom_op_with_setup.py | 5 +- .../fluid/tests/custom_op/test_jit_load.py | 42 ++ .../tests/custom_op/test_setup_install.py | 59 ++ .../tests/custom_op/{setup.py => utils.py} | 26 +- python/paddle/utils/__init__.py | 2 + python/paddle/utils/cpp_extension/__init__.py | 29 + .../utils/cpp_extension/cpp_extension.py | 339 +++++++++++ .../utils/cpp_extension/extension_utils.py | 543 ++++++++++++++++++ python/setup.py.in | 1 + 15 files changed, 1093 insertions(+), 419 deletions(-) delete mode 100644 python/paddle/fluid/tests/custom_op/cpp_extension.py delete mode 100644 python/paddle/fluid/tests/custom_op/extension_utils.py create mode 100644 python/paddle/fluid/tests/custom_op/setup_build.py create mode 100644 python/paddle/fluid/tests/custom_op/setup_install.py create mode 100644 python/paddle/fluid/tests/custom_op/test_jit_load.py create mode 100644 python/paddle/fluid/tests/custom_op/test_setup_install.py rename python/paddle/fluid/tests/custom_op/{setup.py => utils.py} (70%) create mode 100644 python/paddle/utils/cpp_extension/__init__.py create mode 100644 python/paddle/utils/cpp_extension/cpp_extension.py create mode 100644 python/paddle/utils/cpp_extension/extension_utils.py diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 7c492655968..e7a641b7aaf 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1991,9 +1991,13 @@ class OpProtoHolder(object): def update_op_proto(self): op_protos = get_all_op_protos() + custom_op_names = [] for proto in op_protos: if proto.type not in self.op_proto_map: self.op_proto_map[proto.type] = proto + custom_op_names.append(proto.type) + + return custom_op_names @staticmethod def generated_op_attr_names(): @@ -5702,6 +5706,9 @@ def load_op_library(lib_filename): Args: lib_filename (str): name of dynamic library. + + Returns: + list[str]: new registered custom op names. Examples: .. code-block:: python @@ -5711,7 +5718,7 @@ def load_op_library(lib_filename): """ core.load_op_library(lib_filename) - OpProtoHolder.instance().update_op_proto() + return OpProtoHolder.instance().update_op_proto() def switch_device(device): diff --git a/python/paddle/fluid/tests/custom_op/CMakeLists.txt b/python/paddle/fluid/tests/custom_op/CMakeLists.txt index 85d38c7548b..cc3c9c098c9 100644 --- a/python/paddle/fluid/tests/custom_op/CMakeLists.txt +++ b/python/paddle/fluid/tests/custom_op/CMakeLists.txt @@ -28,3 +28,5 @@ endforeach() # Compiling .so will cost some time, but running process is very fast. set_tests_properties(test_custom_op_with_setup PROPERTIES TIMEOUT 180) +set_tests_properties(test_jit_load PROPERTIES TIMEOUT 180) +set_tests_properties(test_setup_install PROPERTIES TIMEOUT 180) diff --git a/python/paddle/fluid/tests/custom_op/cpp_extension.py b/python/paddle/fluid/tests/custom_op/cpp_extension.py deleted file mode 100644 index e1243f00185..00000000000 --- a/python/paddle/fluid/tests/custom_op/cpp_extension.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import six -import sys -import copy -import setuptools -from setuptools.command.build_ext import build_ext - -from extension_utils import find_cuda_home, normalize_extension_kwargs, add_compile_flag -from extension_utils import is_cuda_file, prepare_unix_cflags, add_std_without_repeat, get_build_directory - -IS_WINDOWS = os.name == 'nt' -CUDA_HOME = find_cuda_home() - - -def CppExtension(name, sources, *args, **kwargs): - """ - Returns setuptools.CppExtension instance for setup.py to make it easy - to specify compile flags while build C++ custommed op kernel. - """ - kwargs = normalize_extension_kwargs(kwargs, use_cuda=False) - - return setuptools.Extension(name, sources, *args, **kwargs) - - -def CUDAExtension(name, sources, *args, **kwargs): - """ - Returns setuptools.CppExtension instance for setup.py to make it easy - to specify compile flags while build CUDA custommed op kernel. - """ - kwargs = normalize_extension_kwargs(kwargs, use_cuda=True) - - return setuptools.Extension(name, sources, *args, **kwargs) - - -class BuildExtension(build_ext, object): - """ - For setuptools.cmd_class. - """ - - @classmethod - def with_options(cls, **options): - ''' - Returns a BuildExtension subclass that support to specific use-defined options. - ''' - - class cls_with_options(cls): - def __init__(self, *args, **kwargs): - kwargs.update(options) - cls.__init__(self, *args, **kwargs) - - return cls_with_options - - def __init__(self, *args, **kwargs): - super(BuildExtension, self).__init__(*args, **kwargs) - self.no_python_abi_suffix = kwargs.get("no_python_abi_suffix", False) - - def initialize_options(self): - super(BuildExtension, self).initialize_options() - # update options here - # FIXME(Aurelius84): for unittest - self.build_lib = './' - - def finalize_options(self): - super(BuildExtension, self).finalize_options() - - def build_extensions(self): - self._check_abi() - for extension in self.extensions: - # check settings of compiler - if isinstance(extension.extra_compile_args, dict): - for compiler in ['cxx', 'nvcc']: - if compiler not in extension.extra_compile_args: - extension.extra_compile_args[compiler] = [] - # add determine compile flags - add_compile_flag(extension, '-std=c++11') - # add_compile_flag(extension, '-lpaddle_framework') - - # Consider .cu, .cu.cc as valid source extensions. - self.compiler.src_extensions += ['.cu', '.cu.cc'] - # Save the original _compile method for later. - if self.compiler.compiler_type == 'msvc' or IS_WINDOWS: - raise NotImplementedError("Not support on MSVC currently.") - else: - original_compile = self.compiler._compile - - def unix_custom_single_compiler(obj, src, ext, cc_args, extra_postargs, - pp_opts): - """ - Monkey patch machanism to replace inner compiler to custom complie process on Unix platform. - """ - # use abspath to ensure no warning - src = os.path.abspath(src) - cflags = copy.deepcopy(extra_postargs) - - try: - original_compiler = self.compiler.compiler_so - # ncvv compile CUDA source - if is_cuda_file(src): - assert CUDA_HOME is not None - nvcc_cmd = os.path.join(CUDA_HOME, 'bin', 'nvcc') - self.compiler.set_executable('compiler_so', nvcc_cmd) - # {'nvcc': {}, 'cxx: {}} - if isinstance(cflags, dict): - cflags = cflags['nvcc'] - else: - cflags = prepare_unix_cflags(cflags) - # cxx compile Cpp source - elif isinstance(cflags, dict): - cflags = cflags['cxx'] - - add_std_without_repeat( - cflags, self.compiler.compiler_type, use_std14=False) - original_compile(obj, src, ext, cc_args, cflags, pp_opts) - finally: - # restore original_compiler - self.compiler.compiler_so = original_compiler - - def object_filenames_with_cuda(origina_func): - """ - Decorated the function to add customized naming machanism. - """ - - def wrapper(source_filenames, strip_dir=0, output_dir=''): - try: - objects = origina_func(source_filenames, strip_dir, - output_dir) - for i, source in enumerate(source_filenames): - # modify xx.o -> xx.cu.o - if is_cuda_file(source): - old_obj = objects[i] - objects[i] = old_obj[:-1] + 'cu.o' - # ensure to use abspath - objects = [os.path.abspath(obj) for obj in objects] - finally: - self.compiler.object_filenames = origina_func - - return objects - - return wrapper - - # customized compile process - self.compiler._compile = unix_custom_single_compiler - self.compiler.object_filenames = object_filenames_with_cuda( - self.compiler.object_filenames) - - build_ext.build_extensions(self) - - def get_ext_filename(self, fullname): - # for example: custommed_extension.cpython-37m-x86_64-linux-gnu.so - ext_name = super(BuildExtension, self).get_ext_filename(fullname) - if self.no_python_abi_suffix and six.PY3: - split_str = '.' - name_items = ext_name.split(split_str) - assert len( - name_items - ) > 2, "Expected len(name_items) > 2, but received {}".format( - len(name_items)) - name_items.pop(-2) - # custommed_extension.so - ext_name = split_str.join(name_items) - - return ext_name - - def _check_abi(self): - pass diff --git a/python/paddle/fluid/tests/custom_op/extension_utils.py b/python/paddle/fluid/tests/custom_op/extension_utils.py deleted file mode 100644 index c2683140e8e..00000000000 --- a/python/paddle/fluid/tests/custom_op/extension_utils.py +++ /dev/null @@ -1,216 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import six -import sys -import copy -import glob -import warnings -import subprocess - -import paddle - -IS_WINDOWS = os.name == 'nt' -# TODO(Aurelius84): Need check version of gcc and g++ is same. -# After CI path is fixed, we will modify into cc. -NVCC_COMPILE_FLAGS = [ - '-ccbin', 'gcc', '-DPADDLE_WITH_CUDA', '-DEIGEN_USE_GPU', - '-DPADDLE_USE_DSO', '-Xcompiler', '-fPIC', '-w', '--expt-relaxed-constexpr', - '-O3', '-DNVCC' -] - - -def prepare_unix_cflags(cflags): - """ - Prepare all necessary compiled flags for nvcc compiling CUDA files. - """ - cflags = NVCC_COMPILE_FLAGS + cflags + get_cuda_arch_flags(cflags) - - return cflags - - -def add_std_without_repeat(cflags, compiler_type, use_std14=False): - """ - Append -std=c++11/14 in cflags if without specific it before. - """ - cpp_flag_prefix = '/std:' if compiler_type == 'msvc' else '-std=' - if not any(cpp_flag_prefix in flag for flag in cflags): - suffix = 'c++14' if use_std14 else 'c++11' - cpp_flag = cpp_flag_prefix + suffix - cflags.append(cpp_flag) - - -def get_cuda_arch_flags(cflags): - """ - For an arch, say "6.1", the added compile flag will be - ``-gencode=arch=compute_61,code=sm_61``. - For an added "+PTX", an additional - ``-gencode=arch=compute_xx,code=compute_xx`` is added. - """ - # TODO(Aurelius84): - return [] - - -def normalize_extension_kwargs(kwargs, use_cuda=False): - """ - Normalize include_dirs, library_dir and other attributes in kwargs. - """ - assert isinstance(kwargs, dict) - # append necessary include dir path of paddle - include_dirs = kwargs.get('include_dirs', []) - include_dirs.extend(find_paddle_includes(use_cuda)) - kwargs['include_dirs'] = include_dirs - - # append necessary lib path of paddle - library_dirs = kwargs.get('library_dirs', []) - library_dirs.extend(find_paddle_libraries(use_cuda)) - kwargs['library_dirs'] = library_dirs - - # add runtime library dirs - runtime_library_dirs = kwargs.get('runtime_library_dirs', []) - runtime_library_dirs.extend(find_paddle_libraries(use_cuda)) - kwargs['runtime_library_dirs'] = runtime_library_dirs - - # append compile flags - extra_compile_args = kwargs.get('extra_compile_args', []) - extra_compile_args.extend(['-g']) - kwargs['extra_compile_args'] = extra_compile_args - - # append link flags - extra_link_args = kwargs.get('extra_link_args', []) - extra_link_args.extend(['-lpaddle_framework', '-lcudart']) - kwargs['extra_link_args'] = extra_link_args - - kwargs['language'] = 'c++' - return kwargs - - -def find_paddle_includes(use_cuda=False): - """ - Return Paddle necessary include dir path. - """ - # pythonXX/site-packages/paddle/include - paddle_include_dir = paddle.sysconfig.get_include() - third_party_dir = os.path.join(paddle_include_dir, 'third_party') - - include_dirs = [paddle_include_dir, third_party_dir] - - return include_dirs - - -def find_cuda_includes(): - - cuda_home = find_cuda_home() - if cuda_home is None: - raise ValueError( - "Not found CUDA runtime, please use `export CUDA_HOME=XXX` to specific it." - ) - - return [os.path.join(cuda_home, 'lib64')] - - -def find_cuda_home(): - """ - Use heuristic method to find cuda path - """ - # step 1. find in $CUDA_HOME or $CUDA_PATH - cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') - - # step 2. find path by `which nvcc` - if cuda_home is None: - which_cmd = 'where' if IS_WINDOWS else 'which' - try: - with open(os.devnull, 'w') as devnull: - nvcc_path = subprocess.check_output( - [which_cmd, 'nvcc'], stderr=devnull) - if six.PY3: - nvcc_path = nvcc_path.decode() - nvcc_path = nvcc_path.rstrip('\r\n') - # for example: /usr/local/cuda/bin/nvcc - cuda_home = os.path.dirname(os.path.dirname(nvcc_path)) - except: - if IS_WINDOWS: - # search from default NVIDIA GPU path - candidate_paths = glob.glob( - 'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*') - if len(candidate_paths) > 0: - cuda_home = candidate_paths[0] - else: - cuda_home = "/usr/local/cuda" - # step 3. check whether path is valid - if not os.path.exists(cuda_home) and paddle.is_compiled_with_cuda(): - cuda_home = None - warnings.warn( - "Not found CUDA runtime, please use `export CUDA_HOME= XXX` to specific it." - ) - - return cuda_home - - -def find_paddle_libraries(use_cuda=False): - """ - Return Paddle necessary library dir path. - """ - # pythonXX/site-packages/paddle/libs - paddle_lib_dirs = [paddle.sysconfig.get_lib()] - if use_cuda: - cuda_dirs = find_cuda_includes() - paddle_lib_dirs.extend(cuda_dirs) - return paddle_lib_dirs - - -def append_necessary_flags(extra_compile_args, use_cuda=False): - """ - Add necessary compile flags for gcc/nvcc compiler. - """ - necessary_flags = ['-std=c++11'] - - if use_cuda: - necessary_flags.extend(NVCC_COMPILE_FLAGS) - - -def add_compile_flag(extension, flag): - extra_compile_args = copy.deepcopy(extension.extra_compile_args) - if isinstance(extra_compile_args, dict): - for args in extra_compile_args.values(): - args.append(flag) - else: - extra_compile_args.append(flag) - - extension.extra_compile_args = extra_compile_args - - -def is_cuda_file(path): - - cuda_suffix = set(['.cu']) - items = os.path.splitext(path) - assert len(items) > 1 - return items[-1] in cuda_suffix - - -def get_build_directory(name): - """ - Return paddle extension root directory, default specific by `PADDLE_EXTENSION_DIR` - """ - root_extensions_directory = os.envsiron.get('PADDLE_EXTENSION_DIR') - if root_extensions_directory is None: - # TODO(Aurelius84): consider wind32/macOs - here = os.path.abspath(__file__) - root_extensions_directory = os.path.realpath(here) - warnings.warn( - "$PADDLE_EXTENSION_DIR is not set, using path: {} by default." - .format(root_extensions_directory)) - - return root_extensions_directory diff --git a/python/paddle/fluid/tests/custom_op/setup_build.py b/python/paddle/fluid/tests/custom_op/setup_build.py new file mode 100644 index 00000000000..01da3bba710 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/setup_build.py @@ -0,0 +1,33 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from utils import paddle_includes, extra_compile_args +from paddle.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension, setup + +file_dir = os.path.dirname(os.path.abspath(__file__)) + +setup( + name='relu2_op_shared', + ext_modules=[ + CUDAExtension( + name='librelu2_op_from_setup', + sources=['relu_op.cc', 'relu_op.cu'], + include_dirs=paddle_includes, + extra_compile_args=extra_compile_args) + ], + cmdclass={ + 'build_ext': BuildExtension.with_options( + no_python_abi_suffix=True, output_dir=file_dir) # for unittest + }) diff --git a/python/paddle/fluid/tests/custom_op/setup_install.py b/python/paddle/fluid/tests/custom_op/setup_install.py new file mode 100644 index 00000000000..286f3a7044c --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/setup_install.py @@ -0,0 +1,27 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from utils import paddle_includes, extra_compile_args +from paddle.utils.cpp_extension import CUDAExtension, setup + +setup( + name='custom_relu2', + ext_modules=[ + CUDAExtension( + name='custom_relu2', + sources=['relu_op.cc', 'relu_op.cu'], + include_dirs=paddle_includes, + extra_compile_args=extra_compile_args) + ]) diff --git a/python/paddle/fluid/tests/custom_op/test_custom_op_with_setup.py b/python/paddle/fluid/tests/custom_op/test_custom_op_with_setup.py index be9442cc71a..1e87161c846 100644 --- a/python/paddle/fluid/tests/custom_op/test_custom_op_with_setup.py +++ b/python/paddle/fluid/tests/custom_op/test_custom_op_with_setup.py @@ -14,8 +14,8 @@ import os import unittest - from test_custom_op import CustomOpTest, load_so +from paddle.utils.cpp_extension.extension_utils import run_cmd def compile_so(): @@ -24,7 +24,8 @@ def compile_so(): """ # build .so with setup.py file_dir = os.path.dirname(os.path.abspath(__file__)) - os.system('cd {} && python setup.py build'.format(file_dir)) + cmd = 'cd {} && python setup_build.py build'.format(file_dir) + run_cmd(cmd) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/custom_op/test_jit_load.py b/python/paddle/fluid/tests/custom_op/test_jit_load.py new file mode 100644 index 00000000000..47b45169cb8 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/test_jit_load.py @@ -0,0 +1,42 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +import paddle +import numpy as np +from paddle.utils.cpp_extension import load +from utils import paddle_includes, extra_compile_args + +# Compile and load custom op Just-In-Time. +relu2 = load( + name='relu2', + sources=['relu_op.cc', 'relu_op.cu'], + extra_include_paths=paddle_includes, # add for Coverage CI + extra_cflags=extra_compile_args) # add for Coverage CI + + +class TestJITLoad(unittest.TestCase): + def test_api(self): + raw_data = np.array([[-1, 1, 0], [1, -1, -1]]).astype('float32') + x = paddle.to_tensor(raw_data, dtype='float32') + # use custom api + out = relu2(x) + self.assertTrue( + np.array_equal(out.numpy(), + np.array([[0, 1, 0], [1, 0, 0]]).astype('float32'))) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/custom_op/test_setup_install.py b/python/paddle/fluid/tests/custom_op/test_setup_install.py new file mode 100644 index 00000000000..3ebf9b8b032 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/test_setup_install.py @@ -0,0 +1,59 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import site +import unittest +import paddle +import subprocess +import numpy as np +from paddle.utils.cpp_extension.extension_utils import run_cmd + + +class TestSetUpInstall(unittest.TestCase): + def setUp(self): + cur_dir = os.path.dirname(os.path.abspath(__file__)) + # compile, install the custom op egg into site-packages under background + cmd = 'cd {} && python setup_install.py install'.format(cur_dir) + run_cmd(cmd) + + # NOTE(Aurelius84): Normally, it's no need to add following codes for users. + # But we simulate to pip install in current process, so interpreter don't snap + # sys.path has been updated. So we update it manually. + + # See: https://stackoverflow.com/questions/56974185/import-runtime-installed-module-using-pip-in-python-3 + site_dir = site.getsitepackages()[0] + custom_egg_path = [ + x for x in os.listdir(site_dir) if 'custom_relu2' in x + ] + assert len(custom_egg_path) == 1 + sys.path.append(os.path.join(site_dir, custom_egg_path[0])) + + def test_api(self): + # usage: import the package directly + import custom_relu2 + + raw_data = np.array([[-1, 1, 0], [1, -1, -1]]).astype('float32') + x = paddle.to_tensor(raw_data, dtype='float32') + # use custom api + out = custom_relu2.relu2(x) + + self.assertTrue( + np.array_equal(out.numpy(), + np.array([[0, 1, 0], [1, 0, 0]]).astype('float32'))) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/custom_op/setup.py b/python/paddle/fluid/tests/custom_op/utils.py similarity index 70% rename from python/paddle/fluid/tests/custom_op/setup.py rename to python/paddle/fluid/tests/custom_op/utils.py index b61b745508d..f293c751942 100644 --- a/python/paddle/fluid/tests/custom_op/setup.py +++ b/python/paddle/fluid/tests/custom_op/utils.py @@ -1,24 +1,22 @@ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os import six from distutils.sysconfig import get_python_lib -from setuptools import setup -from cpp_extension import CppExtension, CUDAExtension, BuildExtension, IS_WINDOWS -from setuptools import Extension +from paddle.utils.cpp_extension.extension_utils import IS_WINDOWS -file_dir = os.path.dirname(os.path.abspath(__file__)) site_packages_path = get_python_lib() # Note(Aurelius84): We use `add_test` in Cmake to config how to run unittest in CI. # `PYTHONPATH` will be set as `build/python/paddle` that will make no way to find @@ -33,17 +31,3 @@ paddle_includes = [ # and will lead to ABI problem on Coverage CI. We will handle it in next PR. extra_compile_args = ['-DPADDLE_WITH_MKLDNN' ] if six.PY2 and not IS_WINDOWS else [] - -setup( - name='relu_op_shared', - ext_modules=[ - CUDAExtension( - name='librelu2_op_from_setup', - sources=['relu_op.cc', 'relu_op.cu'], - include_dirs=paddle_includes, - extra_compile_args=extra_compile_args, - output_dir=file_dir) - ], - cmdclass={ - 'build_ext': BuildExtension.with_options(no_python_abi_suffix=True) - }) diff --git a/python/paddle/utils/__init__.py b/python/paddle/utils/__init__.py index faf0fd4984d..1db1b66426c 100644 --- a/python/paddle/utils/__init__.py +++ b/python/paddle/utils/__init__.py @@ -25,6 +25,8 @@ from ..fluid.framework import require_version from . import download +from . import cpp_extension + __all__ = ['dump_config', 'deprecated', 'download', 'run_check'] #TODO: define new api under this directory diff --git a/python/paddle/utils/cpp_extension/__init__.py b/python/paddle/utils/cpp_extension/__init__.py new file mode 100644 index 00000000000..04e32842b0e --- /dev/null +++ b/python/paddle/utils/cpp_extension/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .cpp_extension import CUDAExtension +from .cpp_extension import CppExtension +from .cpp_extension import BuildExtension +from .cpp_extension import load, setup + +from .extension_utils import parse_op_info +from .extension_utils import get_build_directory + +from . import cpp_extension +from . import extension_utils + +__all__ = [ + 'CppExtension', 'CUDAExtension', 'BuildExtension', 'load', 'setup', + 'get_build_directory' +] diff --git a/python/paddle/utils/cpp_extension/cpp_extension.py b/python/paddle/utils/cpp_extension/cpp_extension.py new file mode 100644 index 00000000000..8cd48100c99 --- /dev/null +++ b/python/paddle/utils/cpp_extension/cpp_extension.py @@ -0,0 +1,339 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import six +import sys +import textwrap +import copy + +import setuptools +from setuptools.command.easy_install import easy_install +from setuptools.command.build_ext import build_ext + +from .extension_utils import find_cuda_home, normalize_extension_kwargs, add_compile_flag, bootstrap_context +from .extension_utils import is_cuda_file, prepare_unix_cflags, add_std_without_repeat, get_build_directory +from .extension_utils import _import_module_from_library, CustomOpInfo, _write_setup_file, _jit_compile, parse_op_name_from + +IS_WINDOWS = os.name == 'nt' +CUDA_HOME = find_cuda_home() + + +def setup(**attr): + """ + Wrapper setuptools.setup function to valid `build_ext` command and + implement paddle api code injection by switching `write_stub` + function in bdist_egg with `custom_write_stub`. + """ + cmdclass = attr.get('cmdclass', {}) + assert isinstance(cmdclass, dict) + # if not specific cmdclass in setup, add it automaticaly. + if 'build_ext' not in cmdclass: + cmdclass['build_ext'] = BuildExtension.with_options( + no_python_abi_suffix=True) + attr['cmdclass'] = cmdclass + # elif not isinstance(cmdclass['build_ext'], BuildExtension): + # raise ValueError( + # "Require paddle.utils.cpp_extension.BuildExtension in setup(cmdclass={'build_ext: ...'}), but received {}". + # format(type(cmdclass['build_ext']))) + + # Add rename .so hook in easy_install + assert 'easy_install' not in cmdclass + cmdclass['easy_install'] = EasyInstallCommand + + # Always set zip_safe=False to make compatible in PY2 and PY3 + # See http://peak.telecommunity.com/DevCenter/setuptools#setting-the-zip-safe-flag + attr['zip_safe'] = False + + # switch `write_stub` to inject paddle api in .egg + with bootstrap_context(): + setuptools.setup(**attr) + + +def CppExtension(name, sources, *args, **kwargs): + """ + Returns setuptools.CppExtension instance for setup.py to make it easy + to specify compile flags while building C++ custommed op kernel. + + Args: + name(str): The extension name used as generated shared library name + sources(list[str]): The C++/CUDA source file names + args(list[options]): list of config options used to compile shared library + kwargs(dict[option]): dict of config options used to compile shared library + + Returns: + Extension: An instance of setuptools.Extension + """ + kwargs = normalize_extension_kwargs(kwargs, use_cuda=False) + + return setuptools.Extension(name, sources, *args, **kwargs) + + +def CUDAExtension(name, sources, *args, **kwargs): + """ + Returns setuptools.CppExtension instance for setup.py to make it easy + to specify compile flags while build CUDA custommed op kernel. + + Args: + name(str): The extension name used as generated shared library name + sources(list[str]): The C++/CUDA source file names + args(list[options]): list of config options used to compile shared library + kwargs(dict[option]): dict of config options used to compile shared library + + Returns: + Extension: An instance of setuptools.Extension + """ + kwargs = normalize_extension_kwargs(kwargs, use_cuda=True) + + return setuptools.Extension(name, sources, *args, **kwargs) + + +class BuildExtension(build_ext, object): + """ + Inherited from setuptools.command.build_ext to customize how to apply + compilation process with share library. + """ + + @classmethod + def with_options(cls, **options): + """ + Returns a BuildExtension subclass containing use-defined options. + """ + + class cls_with_options(cls): + def __init__(self, *args, **kwargs): + kwargs.update(options) + cls.__init__(self, *args, **kwargs) + + return cls_with_options + + def __init__(self, *args, **kwargs): + """ + Attributes is initialized with following oreder: + + 1. super(self).__init__() + 2. initialize_options(self) + 3. the reset of current __init__() + 4. finalize_options(self) + + So, it is recommended to set attribute value in `finalize_options`. + """ + super(BuildExtension, self).__init__(*args, **kwargs) + self.no_python_abi_suffix = kwargs.get("no_python_abi_suffix", True) + self.output_dir = kwargs.get("output_dir", None) + + def initialize_options(self): + super(BuildExtension, self).initialize_options() + + def finalize_options(self): + super(BuildExtension, self).finalize_options() + # NOTE(Aurelius84): Set location of compiled shared library. + # Carefully to modify this because `setup.py build/install` + # and `load` interface rely on this attribute. + if self.output_dir is not None: + self.build_lib = self.output_dir + + def build_extensions(self): + self._check_abi() + for extension in self.extensions: + # check settings of compiler + if isinstance(extension.extra_compile_args, dict): + for compiler in ['cxx', 'nvcc']: + if compiler not in extension.extra_compile_args: + extension.extra_compile_args[compiler] = [] + # add determine compile flags + add_compile_flag(extension, '-std=c++11') + + # Consider .cu, .cu.cc as valid source extensions. + self.compiler.src_extensions += ['.cu', '.cu.cc'] + # Save the original _compile method for later. + if self.compiler.compiler_type == 'msvc' or IS_WINDOWS: + raise NotImplementedError("Not support on MSVC currently.") + else: + original_compile = self.compiler._compile + + def unix_custom_single_compiler(obj, src, ext, cc_args, extra_postargs, + pp_opts): + """ + Monkey patch machanism to replace inner compiler to custom complie process on Unix platform. + """ + # use abspath to ensure no warning and don't remove deecopy because modify params + # with dict type is dangerous. + src = os.path.abspath(src) + cflags = copy.deepcopy(extra_postargs) + try: + original_compiler = self.compiler.compiler_so + # ncvv compile CUDA source + if is_cuda_file(src): + assert CUDA_HOME is not None + nvcc_cmd = os.path.join(CUDA_HOME, 'bin', 'nvcc') + self.compiler.set_executable('compiler_so', nvcc_cmd) + # {'nvcc': {}, 'cxx: {}} + if isinstance(cflags, dict): + cflags = cflags['nvcc'] + else: + cflags = prepare_unix_cflags(cflags) + # cxx compile Cpp source + elif isinstance(cflags, dict): + cflags = cflags['cxx'] + + add_std_without_repeat( + cflags, self.compiler.compiler_type, use_std14=False) + original_compile(obj, src, ext, cc_args, cflags, pp_opts) + finally: + # restore original_compiler + self.compiler.compiler_so = original_compiler + + def object_filenames_with_cuda(origina_func, build_directory): + """ + Decorated the function to add customized naming machanism. + Originally, both .cc/.cu will have .o object output that will + bring file override problem. Use .cu.o as CUDA object suffix. + """ + + def wrapper(source_filenames, strip_dir=0, output_dir=''): + try: + objects = origina_func(source_filenames, strip_dir, + output_dir) + for i, source in enumerate(source_filenames): + # modify xx.o -> xx.cu.o + if is_cuda_file(source): + old_obj = objects[i] + objects[i] = old_obj[:-1] + 'cu.o' + # if user set build_directory, output objects there. + if build_directory is not None: + objects = [ + os.path.join(build_directory, os.path.basename(obj)) + for obj in objects + ] + # ensure to use abspath + objects = [os.path.abspath(obj) for obj in objects] + finally: + self.compiler.object_filenames = origina_func + + return objects + + return wrapper + + # customized compile process + self.compiler._compile = unix_custom_single_compiler + self.compiler.object_filenames = object_filenames_with_cuda( + self.compiler.object_filenames, self.build_lib) + + self._record_op_info() + build_ext.build_extensions(self) + + def get_ext_filename(self, fullname): + # for example: custommed_extension.cpython-37m-x86_64-linux-gnu.so + ext_name = super(BuildExtension, self).get_ext_filename(fullname) + if self.no_python_abi_suffix and six.PY3: + split_str = '.' + name_items = ext_name.split(split_str) + assert len( + name_items + ) > 2, "Expected len(name_items) > 2, but received {}".format( + len(name_items)) + name_items.pop(-2) + # custommed_extension.so + ext_name = split_str.join(name_items) + + return ext_name + + def _check_abi(self): + # TODO(Aurelius84): Enhance abi check + pass + + def _record_op_info(self): + """ + Record custum op inforomation. + """ + # parse op name + sources = [] + for extension in self.extensions: + sources.extend(extension.sources) + + sources = [os.path.abspath(s) for s in sources] + op_name = parse_op_name_from(sources) + + # parse shared library abs path + outputs = self.get_outputs() + assert len(outputs) == 1 + + build_directory = os.path.abspath(outputs[0]) + so_name = os.path.basename(build_directory) + CustomOpInfo.instance().add(op_name, + so_name=so_name, + build_directory=build_directory) + + +class EasyInstallCommand(easy_install, object): + """ + Extend easy_intall Command to control the behavior of naming shared library + file. + + NOTE(Aurelius84): This is a hook subclass inherited Command used to rename shared + library file after extracting egg-info into site-packages. + """ + + def __init__(self, *args, **kwargs): + super(EasyInstallCommand, self).__init__(*args, **kwargs) + + # NOTE(Aurelius84): Add args and kwargs to make compatible with PY2/PY3 + def run(self, *args, **kwargs): + super(EasyInstallCommand, self).run(*args, **kwargs) + # NOTE: To avoid failing import .so file instead of + # python file because they have same name, we rename + # .so shared library to another name. + for egg_file in self.outputs: + filename, ext = os.path.splitext(egg_file) + if ext == '.so': + new_so_path = filename + "_pd_" + ext + if not os.path.exists(new_so_path): + os.rename(r'%s' % egg_file, r'%s' % new_so_path) + assert os.path.exists(new_so_path) + + +def load(name, + sources, + extra_cflags=None, + extra_cuda_cflags=None, + extra_ldflags=None, + extra_include_paths=None, + build_directory=None, + verbose=False): + + # TODO(Aurelius84): It just contains main logic codes, more details + # will be added later. + if build_directory is None: + build_directory = get_build_directory() + # ensure to use abs path + build_directory = os.path.abspath(build_directory) + file_path = os.path.join(build_directory, "setup.py") + + sources = [os.path.abspath(source) for source in sources] + + # TODO(Aurelius84): split cflags and cuda_flags + if extra_cflags is None: extra_cflags = [] + if extra_cuda_cflags is None: extra_cuda_cflags = [] + compile_flags = extra_cflags + extra_cuda_cflags + + # write setup.py file and compile it + _write_setup_file(name, sources, file_path, extra_include_paths, + compile_flags, extra_ldflags) + _jit_compile(file_path) + + # import as callable python api + custom_op_api = _import_module_from_library(name, build_directory) + + return custom_op_api diff --git a/python/paddle/utils/cpp_extension/extension_utils.py b/python/paddle/utils/cpp_extension/extension_utils.py new file mode 100644 index 00000000000..14aaddfd6b5 --- /dev/null +++ b/python/paddle/utils/cpp_extension/extension_utils.py @@ -0,0 +1,543 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import six +import sys +import copy +import glob +import collections +import textwrap +import platform +import warnings +import subprocess + +from contextlib import contextmanager +from setuptools.command import bdist_egg + +from .. import load_op_library +from ...fluid import core +from ...sysconfig import get_include, get_lib + +OS_NAME = platform.system() +IS_WINDOWS = OS_NAME == 'Windows' +NVCC_COMPILE_FLAGS = [ + '-ccbin', 'cc', '-DPADDLE_WITH_CUDA', '-DEIGEN_USE_GPU', '-DPADDLE_USE_DSO', + '-Xcompiler', '-fPIC', '-w', '--expt-relaxed-constexpr', '-O3', '-DNVCC' +] + + +@contextmanager +def bootstrap_context(): + """ + Context to manage how to write `__bootstrap__` code in .egg + """ + origin_write_stub = bdist_egg.write_stub + bdist_egg.write_stub = custom_write_stub + yield + + bdist_egg.write_stub = origin_write_stub + + +def custom_write_stub(resource, pyfile): + """ + Customized write_stub function to allow us to inject generated python + api codes into egg python file. + """ + _stub_template = textwrap.dedent(""" + import os + import sys + import paddle + + def inject_ext_module(module_name, api_name): + if module_name in sys.modules: + return sys.modules[module_name] + + new_module = imp.new_module(module_name) + setattr(new_module, api_name, eval(api_name)) + + return new_module + + def __bootstrap__(): + cur_dir = os.path.dirname(os.path.abspath(__file__)) + so_path = os.path.join(cur_dir, "{resource}") + + assert os.path.exists(so_path) + + # load custom op shared library with abs path + new_custom_op = paddle.utils.load_op_library(so_path) + assert len(new_custom_op) == 1 + m = inject_ext_module(__name__, new_custom_op[0]) + + __bootstrap__() + + {custom_api} + """).lstrip() + + # Parse registerring op information + _, op_info = CustomOpInfo.instance().last() + so_path = op_info.build_directory + + new_custom_op = load_op_library(so_path) + assert len(new_custom_op) == 1 + + # NOTE: To avoid importing .so file instead of python file because they have same name, + # we rename .so shared library to another name, see EasyInstallCommand. + filename, ext = os.path.splitext(resource) + resource = filename + "_pd_" + ext + + with open(pyfile, 'w') as f: + f.write( + _stub_template.format( + resource=resource, + custom_api=_custom_api_content(new_custom_op[0]))) + + +OpInfo = collections.namedtuple('OpInfo', + ['so_name', 'build_directory', 'out_dtypes']) + + +class CustomOpInfo: + """ + A global Singleton map to record all compiled custom ops information. + """ + + @classmethod + def instance(cls): + if not hasattr(cls, '_instance'): + cls._instance = cls() + return cls._instance + + def __init__(self): + assert not hasattr( + self.__class__, + '_instance'), 'Please use `instance()` to get CustomOpInfo object!' + # NOTE(Aurelius84): Use OrderedDict to save more order information + self.op_info_map = collections.OrderedDict() + + def add(self, op_name, so_name, build_directory=None, out_dtypes=None): + self.op_info_map[op_name] = OpInfo(so_name, build_directory, out_dtypes) + + def last(self): + """ + Return the lastest insert custom op info. + """ + assert len(self.op_info_map) > 0 + return next(reversed(self.op_info_map.items())) + + +def prepare_unix_cflags(cflags): + """ + Prepare all necessary compiled flags for nvcc compiling CUDA files. + """ + cflags = NVCC_COMPILE_FLAGS + cflags + get_cuda_arch_flags(cflags) + + return cflags + + +def add_std_without_repeat(cflags, compiler_type, use_std14=False): + """ + Append -std=c++11/14 in cflags if without specific it before. + """ + cpp_flag_prefix = '/std:' if compiler_type == 'msvc' else '-std=' + if not any(cpp_flag_prefix in flag for flag in cflags): + suffix = 'c++14' if use_std14 else 'c++11' + cpp_flag = cpp_flag_prefix + suffix + cflags.append(cpp_flag) + + +def get_cuda_arch_flags(cflags): + """ + For an arch, say "6.1", the added compile flag will be + ``-gencode=arch=compute_61,code=sm_61``. + For an added "+PTX", an additional + ``-gencode=arch=compute_xx,code=compute_xx`` is added. + """ + # TODO(Aurelius84): + return [] + + +def normalize_extension_kwargs(kwargs, use_cuda=False): + """ + Normalize include_dirs, library_dir and other attributes in kwargs. + """ + assert isinstance(kwargs, dict) + # append necessary include dir path of paddle + include_dirs = kwargs.get('include_dirs', []) + include_dirs.extend(find_paddle_includes(use_cuda)) + kwargs['include_dirs'] = include_dirs + + # append necessary lib path of paddle + library_dirs = kwargs.get('library_dirs', []) + library_dirs.extend(find_paddle_libraries(use_cuda)) + kwargs['library_dirs'] = library_dirs + + # add runtime library dirs + runtime_library_dirs = kwargs.get('runtime_library_dirs', []) + runtime_library_dirs.extend(find_paddle_libraries(use_cuda)) + kwargs['runtime_library_dirs'] = runtime_library_dirs + + # append compile flags + extra_compile_args = kwargs.get('extra_compile_args', []) + extra_compile_args.extend(['-g']) + kwargs['extra_compile_args'] = extra_compile_args + + # append link flags + extra_link_args = kwargs.get('extra_link_args', []) + extra_link_args.extend(['-lpaddle_framework', '-lcudart']) + kwargs['extra_link_args'] = extra_link_args + + kwargs['language'] = 'c++' + return kwargs + + +def find_paddle_includes(use_cuda=False): + """ + Return Paddle necessary include dir path. + """ + # pythonXX/site-packages/paddle/include + paddle_include_dir = get_include() + third_party_dir = os.path.join(paddle_include_dir, 'third_party') + + include_dirs = [paddle_include_dir, third_party_dir] + + return include_dirs + + +def find_cuda_includes(): + + cuda_home = find_cuda_home() + if cuda_home is None: + raise ValueError( + "Not found CUDA runtime, please use `export CUDA_HOME=XXX` to specific it." + ) + + return [os.path.join(cuda_home, 'lib64')] + + +def find_cuda_home(): + """ + Use heuristic method to find cuda path + """ + # step 1. find in $CUDA_HOME or $CUDA_PATH + cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') + + # step 2. find path by `which nvcc` + if cuda_home is None: + which_cmd = 'where' if IS_WINDOWS else 'which' + try: + with open(os.devnull, 'w') as devnull: + nvcc_path = subprocess.check_output( + [which_cmd, 'nvcc'], stderr=devnull) + if six.PY3: + nvcc_path = nvcc_path.decode() + nvcc_path = nvcc_path.rstrip('\r\n') + # for example: /usr/local/cuda/bin/nvcc + cuda_home = os.path.dirname(os.path.dirname(nvcc_path)) + except: + if IS_WINDOWS: + # search from default NVIDIA GPU path + candidate_paths = glob.glob( + 'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*') + if len(candidate_paths) > 0: + cuda_home = candidate_paths[0] + else: + cuda_home = "/usr/local/cuda" + # step 3. check whether path is valid + if not os.path.exists(cuda_home) and core.is_compiled_with_cuda(): + cuda_home = None + warnings.warn( + "Not found CUDA runtime, please use `export CUDA_HOME= XXX` to specific it." + ) + + return cuda_home + + +def find_paddle_libraries(use_cuda=False): + """ + Return Paddle necessary library dir path. + """ + # pythonXX/site-packages/paddle/libs + paddle_lib_dirs = [get_lib()] + if use_cuda: + cuda_dirs = find_cuda_includes() + paddle_lib_dirs.extend(cuda_dirs) + return paddle_lib_dirs + + +def append_necessary_flags(extra_compile_args, use_cuda=False): + """ + Add necessary compile flags for gcc/nvcc compiler. + """ + necessary_flags = ['-std=c++11'] + + if use_cuda: + necessary_flags.extend(NVCC_COMPILE_FLAGS) + + +def add_compile_flag(extension, flag): + extra_compile_args = copy.deepcopy(extension.extra_compile_args) + if isinstance(extra_compile_args, dict): + for args in extra_compile_args.values(): + args.append(flag) + else: + extra_compile_args.append(flag) + + extension.extra_compile_args = extra_compile_args + + +def is_cuda_file(path): + + cuda_suffix = set(['.cu']) + items = os.path.splitext(path) + assert len(items) > 1 + return items[-1] in cuda_suffix + + +def get_build_directory(): + """ + Return paddle extension root directory, default specific by `PADDLE_EXTENSION_DIR` + """ + root_extensions_directory = os.environ.get('PADDLE_EXTENSION_DIR') + if root_extensions_directory is None: + dir_name = "paddle_extensions" + if OS_NAME == 'Linux': + root_extensions_directory = os.path.join( + os.path.expanduser('~/.cache'), dir_name) + else: + # TODO(Aurelius84): consider wind32/macOs + raise NotImplementedError("Only support Linux now.") + + warnings.warn( + "$PADDLE_EXTENSION_DIR is not set, using path: {} by default.". + format(root_extensions_directory)) + + if not os.path.exists(root_extensions_directory): + os.makedirs(root_extensions_directory) + + return root_extensions_directory + + +def parse_op_info(op_name): + """ + Parse input names and outpus detail information from registered custom op + from OpInfoMap. + """ + from paddle.fluid.framework import OpProtoHolder + if op_name not in OpProtoHolder.instance().op_proto_map: + raise ValueError( + "Please load {} shared library file firstly by `paddle.utils.load_op_library(...)`". + format(op_name)) + op_proto = OpProtoHolder.instance().get_op_proto(op_name) + + in_names = [x.name for x in op_proto.inputs] + assert len(op_proto.outputs) == 1 + out_name = op_proto.outputs[0].name + + # TODO(Aurelius84): parse necessary out_dtype of custom op + out_infos = {out_name: ['float32']} + return in_names, out_infos + + +def _import_module_from_library(name, build_directory): + """ + Load .so shared library and import it as callable python module. + """ + ext_path = os.path.join(build_directory, name + '.so') + if not os.path.exists(ext_path): + raise FileNotFoundError("Extension path: {} does not exist.".format( + ext_path)) + + # load custom op_info and kernels from .so shared library + op_names = load_op_library(ext_path) + assert len(op_names) == 1 + + # generate Python api in ext_path + return _generate_python_module(op_names[0], build_directory) + + +def _generate_python_module(op_name, build_directory): + """ + Automatically generate python file to allow import or load into as module + """ + api_file = os.path.join(build_directory, op_name + '.py') + + # write into .py file + api_content = _custom_api_content(op_name) + with open(api_file, 'w') as f: + f.write(api_content) + + # load module + custom_api = _load_module_from_file(op_name, api_file) + return custom_api + + +def _custom_api_content(op_name): + params_str, ins_str = _get_api_inputs_str(op_name) + + API_TEMPLATE = textwrap.dedent(""" + from paddle.fluid.layer_helper import LayerHelper + from paddle.utils.cpp_extension import parse_op_info + + _, _out_infos = parse_op_info('{op_name}') + + def {op_name}({inputs}): + helper = LayerHelper("{op_name}", **locals()) + + # prepare inputs and output + ins = {ins} + outs = {{}} + for out_name in _out_infos: + outs[out_name] = [helper.create_variable(dtype=dtype) for dtype in _out_infos[out_name]] + + helper.append_op(type="{op_name}", inputs=ins, outputs=outs) + + res = list(outs.values())[0] + if len(res) == 1: + return res[0] + else: + return res + """).lstrip() + + # generate python api file + api_content = API_TEMPLATE.format( + op_name=op_name, inputs=params_str, ins=ins_str) + + return api_content + + +def _load_module_from_file(op_name, api_file_path): + """ + Load module from python file. + """ + if not os.path.exists(api_file_path): + raise FileNotFoundError("File : {} does not exist.".format( + api_file_path)) + + # Unique readable module name to place custom api. + ext_name = "_paddle_cpp_extension_" + if six.PY2: + import imp + module = imp.load_source(ext_name, api_file_path) + else: + from importlib import machinery + loader = machinery.SourceFileLoader(ext_name, api_file_path) + module = loader.load_module() + + assert hasattr(module, op_name) + return getattr(module, op_name) + + +def _get_api_inputs_str(op_name): + """ + Returns string of api parameters and inputs dict. + """ + in_names, _ = parse_op_info(op_name) + # e.g: x, y, z + params_str = ','.join([p.lower() for p in in_names]) + # e.g: {'X': x, 'Y': y, 'Z': z} + ins_str = "{%s}" % ','.join( + ["'{}' : {}".format(in_name, in_name.lower()) for in_name in in_names]) + return params_str, ins_str + + +def _write_setup_file(name, sources, file_path, include_dirs, compile_flags, + link_args): + """ + Automatically generate setup.py and write it into build directory. + """ + template = textwrap.dedent(""" + import os + from paddle.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension, setup + from paddle.utils.cpp_extension import get_build_directory + setup( + name='{name}', + ext_modules=[ + {prefix}Extension( + name='{name}', + sources={sources}, + include_dirs={include_dirs}, + extra_compile_args={extra_compile_args}, + extra_link_args={extra_link_args})], + cmdclass={{"build_ext" : BuildExtension.with_options( + output_dir=get_build_directory(), + no_python_abi_suffix=True) + }})""").lstrip() + + with_cuda = False + if any([is_cuda_file(source) for source in sources]): + with_cuda = True + + content = template.format( + name=name, + prefix='CUDA' if with_cuda else 'Cpp', + sources=list2str(sources), + include_dirs=list2str(include_dirs), + extra_compile_args=list2str(compile_flags), + extra_link_args=list2str(link_args)) + with open(file_path, 'w') as f: + f.write(content) + + +def list2str(args): + """ + Convert list[str] into string. For example: [x, y] -> "['x', 'y']" + """ + if args is None: return '[]' + assert isinstance(args, (list, tuple)) + args = ["'{}'".format(arg) for arg in args] + return '[' + ','.join(args) + ']' + + +def _jit_compile(file_path): + """ + Build shared library in subprocess + """ + ext_dir = os.path.dirname(file_path) + setup_file = os.path.basename(file_path) + compile_cmd = 'cd {} && python {} build'.format(ext_dir, setup_file) + run_cmd(compile_cmd) + + +def parse_op_name_from(sources): + """ + Parse registerring custom op name from sources. + """ + + def regex(content): + pattern = re.compile(r'REGISTER_OPERATOR\(([^,]+),') + + content = re.sub(r'\s|\t|\n', '', content) + op_name = pattern.findall(content) + op_name = set([re.sub('_grad', '', name) for name in op_name]) + + return op_name + + op_names = set() + for source in sources: + with open(source, 'r') as f: + content = f.read() + op_names |= regex(content) + + # TODO(Aurelius84): Support register more customs op at once + assert len(op_names) == 1 + return list(op_names)[0] + + +def run_cmd(command, wait=True): + """ + Execute command with subprocess. + """ + return subprocess.check_call(command, shell=True) diff --git a/python/setup.py.in b/python/setup.py.in index f8f941ff935..55fdbaff264 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -139,6 +139,7 @@ write_distributed_training_mode_py(filename='@PADDLE_BINARY_DIR@/python/paddle/f packages=['paddle', 'paddle.libs', 'paddle.utils', + 'paddle.utils.cpp_extension', 'paddle.dataset', 'paddle.reader', 'paddle.distributed', -- GitLab