未验证 提交 4dbe16c4 编写于 作者: A Aurelius84 提交者: GitHub

[CustomOp] Refine name argument in setup (#31049)

* refine setup name usage

* fix unittest failed
上级 f2dc29a9
......@@ -23,10 +23,9 @@ use_new_custom_op_load_method(False)
file_dir = os.path.dirname(os.path.abspath(__file__))
setup(
name='relu2_op_shared',
name='librelu2_op_from_setup',
ext_modules=[
CUDAExtension(
name='librelu2_op_from_setup',
sources=['relu_op3.cc', 'relu_op3.cu', 'relu_op.cc',
'relu_op.cu'], # test for multi ops
include_dirs=paddle_includes,
......
......@@ -22,11 +22,8 @@ use_new_custom_op_load_method(False)
setup(
name='custom_relu2',
ext_modules=[
CUDAExtension(
name='custom_relu2',
ext_modules=CUDAExtension( # test for not specific name here.
sources=['relu_op.cc', 'relu_op.cu', 'relu_op3.cc',
'relu_op3.cu'], # test for multi ops
include_dirs=paddle_includes,
extra_compile_args=extra_compile_args)
])
extra_compile_args=extra_compile_args))
......@@ -19,12 +19,9 @@ from paddle.utils.cpp_extension import CUDAExtension, setup
setup(
name='simple_setup_relu2',
ext_modules=[
CUDAExtension(
name='simple_setup_relu2',
ext_modules=CUDAExtension( # test for not specific name here.
sources=[
'relu_op_simple.cc', 'relu_op_simple.cu', 'relu_op3_simple.cc'
], # test for multi ops
include_dirs=paddle_includes,
extra_compile_args=extra_compile_args)
])
extra_compile_args=extra_compile_args))
......@@ -25,10 +25,9 @@ from setuptools.command.build_ext import build_ext
from .extension_utils import find_cuda_home, normalize_extension_kwargs, add_compile_flag, bootstrap_context
from .extension_utils import is_cuda_file, prepare_unix_cflags, add_std_without_repeat, get_build_directory
from .extension_utils import _import_module_from_library, CustomOpInfo, _write_setup_file, _jit_compile, parse_op_name_from
from .extension_utils import check_abi_compatibility, log_v
from .extension_utils import check_abi_compatibility, log_v, IS_WINDOWS
from .extension_utils import use_new_custom_op_load_method
IS_WINDOWS = os.name == 'nt'
CUDA_HOME = find_cuda_home()
......@@ -37,6 +36,21 @@ def setup(**attr):
Wrapper setuptools.setup function to valid `build_ext` command and
implement paddle api code injection by switching `write_stub`
function in bdist_egg with `custom_write_stub`.
Its usage is almost same as `setuptools.setup` except for `ext_modules`
arguments. For compiling multi custom operators, all necessary source files
can be include into just one Extension (CppExtension/CUDAExtension).
Moreover, only one `name` argument is required in `setup` and no need to spcific
`name` in Extension.
Example:
>> from paddle.utils.cpp_extension import CUDAExtension, setup
>> setup(name='custom_module',
ext_modules=CUDAExtension(
sources=['relu_op.cc', 'relu_op.cu'],
include_dirs=[], # specific user-defined include dirs
extra_compile_args=[]) # specific user-defined compil arguments.
"""
cmdclass = attr.get('cmdclass', {})
assert isinstance(cmdclass, dict)
......@@ -46,6 +60,36 @@ def setup(**attr):
no_python_abi_suffix=True)
attr['cmdclass'] = cmdclass
error_msg = """
Required to specific `name` argument in paddle.utils.cpp_extension.setup.
It's used as `import XXX` when you want install and import your custom operators.\n
For Example:
# setup.py file
from paddle.utils.cpp_extension import CUDAExtension, setup
setup(name='custom_module',
ext_modules=CUDAExtension(
sources=['relu_op.cc', 'relu_op.cu'])
# After running `python setup.py install`
from custom_module import relue
"""
# name argument is required
if 'name' not in attr:
raise ValueError(error_msg)
ext_modules = attr.get('ext_modules', [])
if not isinstance(ext_modules, list):
ext_modules = [ext_modules]
assert len(
ext_modules
) == 1, "Required only one Extension, but received {}. If you want to compile multi operators, you can include all necessary source files in one Extenion.".format(
len(ext_modules))
# replace Extension.name with attr['name] to keep consistant with Package name.
for ext_module in ext_modules:
ext_module.name = attr['name']
attr['ext_modules'] = ext_modules
# Add rename .so hook in easy_install
assert 'easy_install' not in cmdclass
cmdclass['easy_install'] = EasyInstallCommand
......@@ -59,13 +103,12 @@ def setup(**attr):
setuptools.setup(**attr)
def CppExtension(name, sources, *args, **kwargs):
def CppExtension(sources, *args, **kwargs):
"""
Returns setuptools.CppExtension instance for setup.py to make it easy
to specify compile flags while building C++ custommed op kernel.
Args:
name(str): The extension name used as generated shared library name
sources(list[str]): The C++/CUDA source file names
args(list[options]): list of config options used to compile shared library
kwargs(dict[option]): dict of config options used to compile shared library
......@@ -74,17 +117,23 @@ def CppExtension(name, sources, *args, **kwargs):
Extension: An instance of setuptools.Extension
"""
kwargs = normalize_extension_kwargs(kwargs, use_cuda=False)
# Note(Aurelius84): While using `setup` and `jit`, the Extension `name` will
# be replaced as `setup.name` to keep consistant with package. Because we allow
# users can not specific name in Extension.
# See `paddle.utils.cpp_extension.setup` for details.
name = kwargs.get('name', None)
if name is None:
name = _generate_extension_name(sources)
return setuptools.Extension(name, sources, *args, **kwargs)
def CUDAExtension(name, sources, *args, **kwargs):
def CUDAExtension(sources, *args, **kwargs):
"""
Returns setuptools.CppExtension instance for setup.py to make it easy
to specify compile flags while build CUDA custommed op kernel.
Args:
name(str): The extension name used as generated shared library name
sources(list[str]): The C++/CUDA source file names
args(list[options]): list of config options used to compile shared library
kwargs(dict[option]): dict of config options used to compile shared library
......@@ -93,10 +142,33 @@ def CUDAExtension(name, sources, *args, **kwargs):
Extension: An instance of setuptools.Extension
"""
kwargs = normalize_extension_kwargs(kwargs, use_cuda=True)
# Note(Aurelius84): While using `setup` and `jit`, the Extension `name` will
# be replaced as `setup.name` to keep consistant with package. Because we allow
# users can not specific name in Extension.
# See `paddle.utils.cpp_extension.setup` for details.
name = kwargs.get('name', None)
if name is None:
name = _generate_extension_name(sources)
return setuptools.Extension(name, sources, *args, **kwargs)
def _generate_extension_name(sources):
"""
Generate extension name by source files.
"""
assert len(sources) > 0, "source files is empty"
file_prefix = []
for source in sources:
source = os.path.basename(source)
filename, _ = os.path.splitext(source)
# Use list to generate same order.
if filename not in file_prefix:
file_prefix.append(filename)
return '_'.join(file_prefix)
class BuildExtension(build_ext, object):
"""
Inherited from setuptools.command.build_ext to customize how to apply
......@@ -285,7 +357,7 @@ class BuildExtension(build_ext, object):
for op_name in op_names:
CustomOpInfo.instance().add(op_name,
so_name=so_name,
build_directory=so_path)
so_path=so_path)
class EasyInstallCommand(easy_install, object):
......
......@@ -109,7 +109,6 @@ def load_op_meta_info_and_register_op(lib_filename):
if USING_NEW_CUSTOM_OP_LOAD_METHOD:
core.load_op_meta_info_and_register_op(lib_filename)
else:
print("old branch")
core.load_op_library(lib_filename)
return OpProtoHolder.instance().update_op_proto()
......@@ -152,7 +151,7 @@ def custom_write_stub(resource, pyfile):
# Parse registerring op information
_, op_info = CustomOpInfo.instance().last()
so_path = op_info.build_directory
so_path = op_info.so_path
new_custom_ops = load_op_meta_info_and_register_op(so_path)
assert len(
......@@ -175,8 +174,7 @@ def custom_write_stub(resource, pyfile):
resource=resource, custom_api='\n\n'.join(api_content)))
OpInfo = collections.namedtuple('OpInfo',
['so_name', 'build_directory', 'out_dtypes'])
OpInfo = collections.namedtuple('OpInfo', ['so_name', 'so_path'])
class CustomOpInfo:
......@@ -197,8 +195,8 @@ class CustomOpInfo:
# NOTE(Aurelius84): Use OrderedDict to save more order information
self.op_info_map = collections.OrderedDict()
def add(self, op_name, so_name, build_directory=None, out_dtypes=None):
self.op_info_map[op_name] = OpInfo(so_name, build_directory, out_dtypes)
def add(self, op_name, so_name, so_path=None):
self.op_info_map[op_name] = OpInfo(so_name, so_path)
def last(self):
"""
......@@ -266,7 +264,10 @@ def normalize_extension_kwargs(kwargs, use_cuda=False):
# append link flags
extra_link_args = kwargs.get('extra_link_args', [])
extra_link_args.extend(['-lpaddle_framework', '-lcudart'])
extra_link_args.append('-lpaddle_framework')
if use_cuda:
extra_link_args.append('-lcudart')
kwargs['extra_link_args'] = extra_link_args
kwargs['language'] = 'c++'
......@@ -533,7 +534,6 @@ def _write_setup_file(name,
name='{name}',
ext_modules=[
{prefix}Extension(
name='{name}',
sources={sources},
include_dirs={include_dirs},
extra_compile_args={extra_compile_args},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册