未验证 提交 e49d0746 编写于 作者: A Aurelius84 提交者: GitHub

[CustomOp] Support install as Package and Add load interface (#30798)

* support setup.py to compile custom op

* move file into paddle.utils.cpp_extension

* support python setup.py install

* refine code style

* Enrich code and add unittest

* Polish code and api doc

* fix cpp_extension not include in package

* fix relative import

* fix os.makedirs exist_ok param compatibility PY2

* add compile flags in test_jit_load
上级 2cb55eff
......@@ -1991,9 +1991,13 @@ class OpProtoHolder(object):
def update_op_proto(self):
op_protos = get_all_op_protos()
custom_op_names = []
for proto in op_protos:
if proto.type not in self.op_proto_map:
self.op_proto_map[proto.type] = proto
custom_op_names.append(proto.type)
return custom_op_names
@staticmethod
def generated_op_attr_names():
......@@ -5702,6 +5706,9 @@ def load_op_library(lib_filename):
Args:
lib_filename (str): name of dynamic library.
Returns:
list[str]: new registered custom op names.
Examples:
.. code-block:: python
......@@ -5711,7 +5718,7 @@ def load_op_library(lib_filename):
"""
core.load_op_library(lib_filename)
OpProtoHolder.instance().update_op_proto()
return OpProtoHolder.instance().update_op_proto()
def switch_device(device):
......
......@@ -28,3 +28,5 @@ endforeach()
# Compiling .so will cost some time, but running process is very fast.
set_tests_properties(test_custom_op_with_setup PROPERTIES TIMEOUT 180)
set_tests_properties(test_jit_load PROPERTIES TIMEOUT 180)
set_tests_properties(test_setup_install PROPERTIES TIMEOUT 180)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from utils import paddle_includes, extra_compile_args
from paddle.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension, setup
file_dir = os.path.dirname(os.path.abspath(__file__))
setup(
name='relu2_op_shared',
ext_modules=[
CUDAExtension(
name='librelu2_op_from_setup',
sources=['relu_op.cc', 'relu_op.cu'],
include_dirs=paddle_includes,
extra_compile_args=extra_compile_args)
],
cmdclass={
'build_ext': BuildExtension.with_options(
no_python_abi_suffix=True, output_dir=file_dir) # for unittest
})
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from utils import paddle_includes, extra_compile_args
from paddle.utils.cpp_extension import CUDAExtension, setup
setup(
name='custom_relu2',
ext_modules=[
CUDAExtension(
name='custom_relu2',
sources=['relu_op.cc', 'relu_op.cu'],
include_dirs=paddle_includes,
extra_compile_args=extra_compile_args)
])
......@@ -14,8 +14,8 @@
import os
import unittest
from test_custom_op import CustomOpTest, load_so
from paddle.utils.cpp_extension.extension_utils import run_cmd
def compile_so():
......@@ -24,7 +24,8 @@ def compile_so():
"""
# build .so with setup.py
file_dir = os.path.dirname(os.path.abspath(__file__))
os.system('cd {} && python setup.py build'.format(file_dir))
cmd = 'cd {} && python setup_build.py build'.format(file_dir)
run_cmd(cmd)
if __name__ == '__main__':
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import paddle
import numpy as np
from paddle.utils.cpp_extension import load
from utils import paddle_includes, extra_compile_args
# Compile and load custom op Just-In-Time.
relu2 = load(
name='relu2',
sources=['relu_op.cc', 'relu_op.cu'],
extra_include_paths=paddle_includes, # add for Coverage CI
extra_cflags=extra_compile_args) # add for Coverage CI
class TestJITLoad(unittest.TestCase):
def test_api(self):
raw_data = np.array([[-1, 1, 0], [1, -1, -1]]).astype('float32')
x = paddle.to_tensor(raw_data, dtype='float32')
# use custom api
out = relu2(x)
self.assertTrue(
np.array_equal(out.numpy(),
np.array([[0, 1, 0], [1, 0, 0]]).astype('float32')))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import site
import unittest
import paddle
import subprocess
import numpy as np
from paddle.utils.cpp_extension.extension_utils import run_cmd
class TestSetUpInstall(unittest.TestCase):
def setUp(self):
cur_dir = os.path.dirname(os.path.abspath(__file__))
# compile, install the custom op egg into site-packages under background
cmd = 'cd {} && python setup_install.py install'.format(cur_dir)
run_cmd(cmd)
# NOTE(Aurelius84): Normally, it's no need to add following codes for users.
# But we simulate to pip install in current process, so interpreter don't snap
# sys.path has been updated. So we update it manually.
# See: https://stackoverflow.com/questions/56974185/import-runtime-installed-module-using-pip-in-python-3
site_dir = site.getsitepackages()[0]
custom_egg_path = [
x for x in os.listdir(site_dir) if 'custom_relu2' in x
]
assert len(custom_egg_path) == 1
sys.path.append(os.path.join(site_dir, custom_egg_path[0]))
def test_api(self):
# usage: import the package directly
import custom_relu2
raw_data = np.array([[-1, 1, 0], [1, -1, -1]]).astype('float32')
x = paddle.to_tensor(raw_data, dtype='float32')
# use custom api
out = custom_relu2.relu2(x)
self.assertTrue(
np.array_equal(out.numpy(),
np.array([[0, 1, 0], [1, 0, 0]]).astype('float32')))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import six
from distutils.sysconfig import get_python_lib
from setuptools import setup
from cpp_extension import CppExtension, CUDAExtension, BuildExtension, IS_WINDOWS
from setuptools import Extension
from paddle.utils.cpp_extension.extension_utils import IS_WINDOWS
file_dir = os.path.dirname(os.path.abspath(__file__))
site_packages_path = get_python_lib()
# Note(Aurelius84): We use `add_test` in Cmake to config how to run unittest in CI.
# `PYTHONPATH` will be set as `build/python/paddle` that will make no way to find
......@@ -33,17 +31,3 @@ paddle_includes = [
# and will lead to ABI problem on Coverage CI. We will handle it in next PR.
extra_compile_args = ['-DPADDLE_WITH_MKLDNN'
] if six.PY2 and not IS_WINDOWS else []
setup(
name='relu_op_shared',
ext_modules=[
CUDAExtension(
name='librelu2_op_from_setup',
sources=['relu_op.cc', 'relu_op.cu'],
include_dirs=paddle_includes,
extra_compile_args=extra_compile_args,
output_dir=file_dir)
],
cmdclass={
'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)
})
......@@ -25,6 +25,8 @@ from ..fluid.framework import require_version
from . import download
from . import cpp_extension
__all__ = ['dump_config', 'deprecated', 'download', 'run_check']
#TODO: define new api under this directory
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .cpp_extension import CUDAExtension
from .cpp_extension import CppExtension
from .cpp_extension import BuildExtension
from .cpp_extension import load, setup
from .extension_utils import parse_op_info
from .extension_utils import get_build_directory
from . import cpp_extension
from . import extension_utils
__all__ = [
'CppExtension', 'CUDAExtension', 'BuildExtension', 'load', 'setup',
'get_build_directory'
]
......@@ -15,21 +15,65 @@
import os
import six
import sys
import textwrap
import copy
import setuptools
from setuptools.command.easy_install import easy_install
from setuptools.command.build_ext import build_ext
from extension_utils import find_cuda_home, normalize_extension_kwargs, add_compile_flag
from extension_utils import is_cuda_file, prepare_unix_cflags, add_std_without_repeat, get_build_directory
from .extension_utils import find_cuda_home, normalize_extension_kwargs, add_compile_flag, bootstrap_context
from .extension_utils import is_cuda_file, prepare_unix_cflags, add_std_without_repeat, get_build_directory
from .extension_utils import _import_module_from_library, CustomOpInfo, _write_setup_file, _jit_compile, parse_op_name_from
IS_WINDOWS = os.name == 'nt'
CUDA_HOME = find_cuda_home()
def setup(**attr):
"""
Wrapper setuptools.setup function to valid `build_ext` command and
implement paddle api code injection by switching `write_stub`
function in bdist_egg with `custom_write_stub`.
"""
cmdclass = attr.get('cmdclass', {})
assert isinstance(cmdclass, dict)
# if not specific cmdclass in setup, add it automaticaly.
if 'build_ext' not in cmdclass:
cmdclass['build_ext'] = BuildExtension.with_options(
no_python_abi_suffix=True)
attr['cmdclass'] = cmdclass
# elif not isinstance(cmdclass['build_ext'], BuildExtension):
# raise ValueError(
# "Require paddle.utils.cpp_extension.BuildExtension in setup(cmdclass={'build_ext: ...'}), but received {}".
# format(type(cmdclass['build_ext'])))
# Add rename .so hook in easy_install
assert 'easy_install' not in cmdclass
cmdclass['easy_install'] = EasyInstallCommand
# Always set zip_safe=False to make compatible in PY2 and PY3
# See http://peak.telecommunity.com/DevCenter/setuptools#setting-the-zip-safe-flag
attr['zip_safe'] = False
# switch `write_stub` to inject paddle api in .egg
with bootstrap_context():
setuptools.setup(**attr)
def CppExtension(name, sources, *args, **kwargs):
"""
Returns setuptools.CppExtension instance for setup.py to make it easy
to specify compile flags while build C++ custommed op kernel.
to specify compile flags while building C++ custommed op kernel.
Args:
name(str): The extension name used as generated shared library name
sources(list[str]): The C++/CUDA source file names
args(list[options]): list of config options used to compile shared library
kwargs(dict[option]): dict of config options used to compile shared library
Returns:
Extension: An instance of setuptools.Extension
"""
kwargs = normalize_extension_kwargs(kwargs, use_cuda=False)
......@@ -40,6 +84,15 @@ def CUDAExtension(name, sources, *args, **kwargs):
"""
Returns setuptools.CppExtension instance for setup.py to make it easy
to specify compile flags while build CUDA custommed op kernel.
Args:
name(str): The extension name used as generated shared library name
sources(list[str]): The C++/CUDA source file names
args(list[options]): list of config options used to compile shared library
kwargs(dict[option]): dict of config options used to compile shared library
Returns:
Extension: An instance of setuptools.Extension
"""
kwargs = normalize_extension_kwargs(kwargs, use_cuda=True)
......@@ -48,14 +101,15 @@ def CUDAExtension(name, sources, *args, **kwargs):
class BuildExtension(build_ext, object):
"""
For setuptools.cmd_class.
Inherited from setuptools.command.build_ext to customize how to apply
compilation process with share library.
"""
@classmethod
def with_options(cls, **options):
'''
Returns a BuildExtension subclass that support to specific use-defined options.
'''
"""
Returns a BuildExtension subclass containing use-defined options.
"""
class cls_with_options(cls):
def __init__(self, *args, **kwargs):
......@@ -65,17 +119,30 @@ class BuildExtension(build_ext, object):
return cls_with_options
def __init__(self, *args, **kwargs):
"""
Attributes is initialized with following oreder:
1. super(self).__init__()
2. initialize_options(self)
3. the reset of current __init__()
4. finalize_options(self)
So, it is recommended to set attribute value in `finalize_options`.
"""
super(BuildExtension, self).__init__(*args, **kwargs)
self.no_python_abi_suffix = kwargs.get("no_python_abi_suffix", False)
self.no_python_abi_suffix = kwargs.get("no_python_abi_suffix", True)
self.output_dir = kwargs.get("output_dir", None)
def initialize_options(self):
super(BuildExtension, self).initialize_options()
# update options here
# FIXME(Aurelius84): for unittest
self.build_lib = './'
def finalize_options(self):
super(BuildExtension, self).finalize_options()
# NOTE(Aurelius84): Set location of compiled shared library.
# Carefully to modify this because `setup.py build/install`
# and `load` interface rely on this attribute.
if self.output_dir is not None:
self.build_lib = self.output_dir
def build_extensions(self):
self._check_abi()
......@@ -87,7 +154,6 @@ class BuildExtension(build_ext, object):
extension.extra_compile_args[compiler] = []
# add determine compile flags
add_compile_flag(extension, '-std=c++11')
# add_compile_flag(extension, '-lpaddle_framework')
# Consider .cu, .cu.cc as valid source extensions.
self.compiler.src_extensions += ['.cu', '.cu.cc']
......@@ -102,10 +168,10 @@ class BuildExtension(build_ext, object):
"""
Monkey patch machanism to replace inner compiler to custom complie process on Unix platform.
"""
# use abspath to ensure no warning
# use abspath to ensure no warning and don't remove deecopy because modify params
# with dict type is dangerous.
src = os.path.abspath(src)
cflags = copy.deepcopy(extra_postargs)
try:
original_compiler = self.compiler.compiler_so
# ncvv compile CUDA source
......@@ -129,9 +195,11 @@ class BuildExtension(build_ext, object):
# restore original_compiler
self.compiler.compiler_so = original_compiler
def object_filenames_with_cuda(origina_func):
def object_filenames_with_cuda(origina_func, build_directory):
"""
Decorated the function to add customized naming machanism.
Originally, both .cc/.cu will have .o object output that will
bring file override problem. Use .cu.o as CUDA object suffix.
"""
def wrapper(source_filenames, strip_dir=0, output_dir=''):
......@@ -143,6 +211,12 @@ class BuildExtension(build_ext, object):
if is_cuda_file(source):
old_obj = objects[i]
objects[i] = old_obj[:-1] + 'cu.o'
# if user set build_directory, output objects there.
if build_directory is not None:
objects = [
os.path.join(build_directory, os.path.basename(obj))
for obj in objects
]
# ensure to use abspath
objects = [os.path.abspath(obj) for obj in objects]
finally:
......@@ -155,8 +229,9 @@ class BuildExtension(build_ext, object):
# customized compile process
self.compiler._compile = unix_custom_single_compiler
self.compiler.object_filenames = object_filenames_with_cuda(
self.compiler.object_filenames)
self.compiler.object_filenames, self.build_lib)
self._record_op_info()
build_ext.build_extensions(self)
def get_ext_filename(self, fullname):
......@@ -176,4 +251,89 @@ class BuildExtension(build_ext, object):
return ext_name
def _check_abi(self):
# TODO(Aurelius84): Enhance abi check
pass
def _record_op_info(self):
"""
Record custum op inforomation.
"""
# parse op name
sources = []
for extension in self.extensions:
sources.extend(extension.sources)
sources = [os.path.abspath(s) for s in sources]
op_name = parse_op_name_from(sources)
# parse shared library abs path
outputs = self.get_outputs()
assert len(outputs) == 1
build_directory = os.path.abspath(outputs[0])
so_name = os.path.basename(build_directory)
CustomOpInfo.instance().add(op_name,
so_name=so_name,
build_directory=build_directory)
class EasyInstallCommand(easy_install, object):
"""
Extend easy_intall Command to control the behavior of naming shared library
file.
NOTE(Aurelius84): This is a hook subclass inherited Command used to rename shared
library file after extracting egg-info into site-packages.
"""
def __init__(self, *args, **kwargs):
super(EasyInstallCommand, self).__init__(*args, **kwargs)
# NOTE(Aurelius84): Add args and kwargs to make compatible with PY2/PY3
def run(self, *args, **kwargs):
super(EasyInstallCommand, self).run(*args, **kwargs)
# NOTE: To avoid failing import .so file instead of
# python file because they have same name, we rename
# .so shared library to another name.
for egg_file in self.outputs:
filename, ext = os.path.splitext(egg_file)
if ext == '.so':
new_so_path = filename + "_pd_" + ext
if not os.path.exists(new_so_path):
os.rename(r'%s' % egg_file, r'%s' % new_so_path)
assert os.path.exists(new_so_path)
def load(name,
sources,
extra_cflags=None,
extra_cuda_cflags=None,
extra_ldflags=None,
extra_include_paths=None,
build_directory=None,
verbose=False):
# TODO(Aurelius84): It just contains main logic codes, more details
# will be added later.
if build_directory is None:
build_directory = get_build_directory()
# ensure to use abs path
build_directory = os.path.abspath(build_directory)
file_path = os.path.join(build_directory, "setup.py")
sources = [os.path.abspath(source) for source in sources]
# TODO(Aurelius84): split cflags and cuda_flags
if extra_cflags is None: extra_cflags = []
if extra_cuda_cflags is None: extra_cuda_cflags = []
compile_flags = extra_cflags + extra_cuda_cflags
# write setup.py file and compile it
_write_setup_file(name, sources, file_path, extra_include_paths,
compile_flags, extra_ldflags)
_jit_compile(file_path)
# import as callable python api
custom_op_api = _import_module_from_library(name, build_directory)
return custom_op_api
......@@ -13,25 +13,131 @@
# limitations under the License.
import os
import re
import six
import sys
import copy
import glob
import collections
import textwrap
import platform
import warnings
import subprocess
import paddle
from contextlib import contextmanager
from setuptools.command import bdist_egg
IS_WINDOWS = os.name == 'nt'
# TODO(Aurelius84): Need check version of gcc and g++ is same.
# After CI path is fixed, we will modify into cc.
from .. import load_op_library
from ...fluid import core
from ...sysconfig import get_include, get_lib
OS_NAME = platform.system()
IS_WINDOWS = OS_NAME == 'Windows'
NVCC_COMPILE_FLAGS = [
'-ccbin', 'gcc', '-DPADDLE_WITH_CUDA', '-DEIGEN_USE_GPU',
'-DPADDLE_USE_DSO', '-Xcompiler', '-fPIC', '-w', '--expt-relaxed-constexpr',
'-O3', '-DNVCC'
'-ccbin', 'cc', '-DPADDLE_WITH_CUDA', '-DEIGEN_USE_GPU', '-DPADDLE_USE_DSO',
'-Xcompiler', '-fPIC', '-w', '--expt-relaxed-constexpr', '-O3', '-DNVCC'
]
@contextmanager
def bootstrap_context():
"""
Context to manage how to write `__bootstrap__` code in .egg
"""
origin_write_stub = bdist_egg.write_stub
bdist_egg.write_stub = custom_write_stub
yield
bdist_egg.write_stub = origin_write_stub
def custom_write_stub(resource, pyfile):
"""
Customized write_stub function to allow us to inject generated python
api codes into egg python file.
"""
_stub_template = textwrap.dedent("""
import os
import sys
import paddle
def inject_ext_module(module_name, api_name):
if module_name in sys.modules:
return sys.modules[module_name]
new_module = imp.new_module(module_name)
setattr(new_module, api_name, eval(api_name))
return new_module
def __bootstrap__():
cur_dir = os.path.dirname(os.path.abspath(__file__))
so_path = os.path.join(cur_dir, "{resource}")
assert os.path.exists(so_path)
# load custom op shared library with abs path
new_custom_op = paddle.utils.load_op_library(so_path)
assert len(new_custom_op) == 1
m = inject_ext_module(__name__, new_custom_op[0])
__bootstrap__()
{custom_api}
""").lstrip()
# Parse registerring op information
_, op_info = CustomOpInfo.instance().last()
so_path = op_info.build_directory
new_custom_op = load_op_library(so_path)
assert len(new_custom_op) == 1
# NOTE: To avoid importing .so file instead of python file because they have same name,
# we rename .so shared library to another name, see EasyInstallCommand.
filename, ext = os.path.splitext(resource)
resource = filename + "_pd_" + ext
with open(pyfile, 'w') as f:
f.write(
_stub_template.format(
resource=resource,
custom_api=_custom_api_content(new_custom_op[0])))
OpInfo = collections.namedtuple('OpInfo',
['so_name', 'build_directory', 'out_dtypes'])
class CustomOpInfo:
"""
A global Singleton map to record all compiled custom ops information.
"""
@classmethod
def instance(cls):
if not hasattr(cls, '_instance'):
cls._instance = cls()
return cls._instance
def __init__(self):
assert not hasattr(
self.__class__,
'_instance'), 'Please use `instance()` to get CustomOpInfo object!'
# NOTE(Aurelius84): Use OrderedDict to save more order information
self.op_info_map = collections.OrderedDict()
def add(self, op_name, so_name, build_directory=None, out_dtypes=None):
self.op_info_map[op_name] = OpInfo(so_name, build_directory, out_dtypes)
def last(self):
"""
Return the lastest insert custom op info.
"""
assert len(self.op_info_map) > 0
return next(reversed(self.op_info_map.items()))
def prepare_unix_cflags(cflags):
"""
Prepare all necessary compiled flags for nvcc compiling CUDA files.
......@@ -102,7 +208,7 @@ def find_paddle_includes(use_cuda=False):
Return Paddle necessary include dir path.
"""
# pythonXX/site-packages/paddle/include
paddle_include_dir = paddle.sysconfig.get_include()
paddle_include_dir = get_include()
third_party_dir = os.path.join(paddle_include_dir, 'third_party')
include_dirs = [paddle_include_dir, third_party_dir]
......@@ -150,7 +256,7 @@ def find_cuda_home():
else:
cuda_home = "/usr/local/cuda"
# step 3. check whether path is valid
if not os.path.exists(cuda_home) and paddle.is_compiled_with_cuda():
if not os.path.exists(cuda_home) and core.is_compiled_with_cuda():
cuda_home = None
warnings.warn(
"Not found CUDA runtime, please use `export CUDA_HOME= XXX` to specific it."
......@@ -164,7 +270,7 @@ def find_paddle_libraries(use_cuda=False):
Return Paddle necessary library dir path.
"""
# pythonXX/site-packages/paddle/libs
paddle_lib_dirs = [paddle.sysconfig.get_lib()]
paddle_lib_dirs = [get_lib()]
if use_cuda:
cuda_dirs = find_cuda_includes()
paddle_lib_dirs.extend(cuda_dirs)
......@@ -200,17 +306,238 @@ def is_cuda_file(path):
return items[-1] in cuda_suffix
def get_build_directory(name):
def get_build_directory():
"""
Return paddle extension root directory, default specific by `PADDLE_EXTENSION_DIR`
"""
root_extensions_directory = os.envsiron.get('PADDLE_EXTENSION_DIR')
root_extensions_directory = os.environ.get('PADDLE_EXTENSION_DIR')
if root_extensions_directory is None:
# TODO(Aurelius84): consider wind32/macOs
here = os.path.abspath(__file__)
root_extensions_directory = os.path.realpath(here)
dir_name = "paddle_extensions"
if OS_NAME == 'Linux':
root_extensions_directory = os.path.join(
os.path.expanduser('~/.cache'), dir_name)
else:
# TODO(Aurelius84): consider wind32/macOs
raise NotImplementedError("Only support Linux now.")
warnings.warn(
"$PADDLE_EXTENSION_DIR is not set, using path: {} by default."
.format(root_extensions_directory))
"$PADDLE_EXTENSION_DIR is not set, using path: {} by default.".
format(root_extensions_directory))
if not os.path.exists(root_extensions_directory):
os.makedirs(root_extensions_directory)
return root_extensions_directory
def parse_op_info(op_name):
"""
Parse input names and outpus detail information from registered custom op
from OpInfoMap.
"""
from paddle.fluid.framework import OpProtoHolder
if op_name not in OpProtoHolder.instance().op_proto_map:
raise ValueError(
"Please load {} shared library file firstly by `paddle.utils.load_op_library(...)`".
format(op_name))
op_proto = OpProtoHolder.instance().get_op_proto(op_name)
in_names = [x.name for x in op_proto.inputs]
assert len(op_proto.outputs) == 1
out_name = op_proto.outputs[0].name
# TODO(Aurelius84): parse necessary out_dtype of custom op
out_infos = {out_name: ['float32']}
return in_names, out_infos
def _import_module_from_library(name, build_directory):
"""
Load .so shared library and import it as callable python module.
"""
ext_path = os.path.join(build_directory, name + '.so')
if not os.path.exists(ext_path):
raise FileNotFoundError("Extension path: {} does not exist.".format(
ext_path))
# load custom op_info and kernels from .so shared library
op_names = load_op_library(ext_path)
assert len(op_names) == 1
# generate Python api in ext_path
return _generate_python_module(op_names[0], build_directory)
def _generate_python_module(op_name, build_directory):
"""
Automatically generate python file to allow import or load into as module
"""
api_file = os.path.join(build_directory, op_name + '.py')
# write into .py file
api_content = _custom_api_content(op_name)
with open(api_file, 'w') as f:
f.write(api_content)
# load module
custom_api = _load_module_from_file(op_name, api_file)
return custom_api
def _custom_api_content(op_name):
params_str, ins_str = _get_api_inputs_str(op_name)
API_TEMPLATE = textwrap.dedent("""
from paddle.fluid.layer_helper import LayerHelper
from paddle.utils.cpp_extension import parse_op_info
_, _out_infos = parse_op_info('{op_name}')
def {op_name}({inputs}):
helper = LayerHelper("{op_name}", **locals())
# prepare inputs and output
ins = {ins}
outs = {{}}
for out_name in _out_infos:
outs[out_name] = [helper.create_variable(dtype=dtype) for dtype in _out_infos[out_name]]
helper.append_op(type="{op_name}", inputs=ins, outputs=outs)
res = list(outs.values())[0]
if len(res) == 1:
return res[0]
else:
return res
""").lstrip()
# generate python api file
api_content = API_TEMPLATE.format(
op_name=op_name, inputs=params_str, ins=ins_str)
return api_content
def _load_module_from_file(op_name, api_file_path):
"""
Load module from python file.
"""
if not os.path.exists(api_file_path):
raise FileNotFoundError("File : {} does not exist.".format(
api_file_path))
# Unique readable module name to place custom api.
ext_name = "_paddle_cpp_extension_"
if six.PY2:
import imp
module = imp.load_source(ext_name, api_file_path)
else:
from importlib import machinery
loader = machinery.SourceFileLoader(ext_name, api_file_path)
module = loader.load_module()
assert hasattr(module, op_name)
return getattr(module, op_name)
def _get_api_inputs_str(op_name):
"""
Returns string of api parameters and inputs dict.
"""
in_names, _ = parse_op_info(op_name)
# e.g: x, y, z
params_str = ','.join([p.lower() for p in in_names])
# e.g: {'X': x, 'Y': y, 'Z': z}
ins_str = "{%s}" % ','.join(
["'{}' : {}".format(in_name, in_name.lower()) for in_name in in_names])
return params_str, ins_str
def _write_setup_file(name, sources, file_path, include_dirs, compile_flags,
link_args):
"""
Automatically generate setup.py and write it into build directory.
"""
template = textwrap.dedent("""
import os
from paddle.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension, setup
from paddle.utils.cpp_extension import get_build_directory
setup(
name='{name}',
ext_modules=[
{prefix}Extension(
name='{name}',
sources={sources},
include_dirs={include_dirs},
extra_compile_args={extra_compile_args},
extra_link_args={extra_link_args})],
cmdclass={{"build_ext" : BuildExtension.with_options(
output_dir=get_build_directory(),
no_python_abi_suffix=True)
}})""").lstrip()
with_cuda = False
if any([is_cuda_file(source) for source in sources]):
with_cuda = True
content = template.format(
name=name,
prefix='CUDA' if with_cuda else 'Cpp',
sources=list2str(sources),
include_dirs=list2str(include_dirs),
extra_compile_args=list2str(compile_flags),
extra_link_args=list2str(link_args))
with open(file_path, 'w') as f:
f.write(content)
def list2str(args):
"""
Convert list[str] into string. For example: [x, y] -> "['x', 'y']"
"""
if args is None: return '[]'
assert isinstance(args, (list, tuple))
args = ["'{}'".format(arg) for arg in args]
return '[' + ','.join(args) + ']'
def _jit_compile(file_path):
"""
Build shared library in subprocess
"""
ext_dir = os.path.dirname(file_path)
setup_file = os.path.basename(file_path)
compile_cmd = 'cd {} && python {} build'.format(ext_dir, setup_file)
run_cmd(compile_cmd)
def parse_op_name_from(sources):
"""
Parse registerring custom op name from sources.
"""
def regex(content):
pattern = re.compile(r'REGISTER_OPERATOR\(([^,]+),')
content = re.sub(r'\s|\t|\n', '', content)
op_name = pattern.findall(content)
op_name = set([re.sub('_grad', '', name) for name in op_name])
return op_name
op_names = set()
for source in sources:
with open(source, 'r') as f:
content = f.read()
op_names |= regex(content)
# TODO(Aurelius84): Support register more customs op at once
assert len(op_names) == 1
return list(op_names)[0]
def run_cmd(command, wait=True):
"""
Execute command with subprocess.
"""
return subprocess.check_call(command, shell=True)
......@@ -139,6 +139,7 @@ write_distributed_training_mode_py(filename='@PADDLE_BINARY_DIR@/python/paddle/f
packages=['paddle',
'paddle.libs',
'paddle.utils',
'paddle.utils.cpp_extension',
'paddle.dataset',
'paddle.reader',
'paddle.distributed',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册