未验证 提交 406f4a75 编写于 作者: A Aurelius84 提交者: GitHub

[CustomOp] Support specifying extra_cflags and extra_cuda_cflags independently (#31059)

* split cxx/nvcc compile flags

* enhance input argument check

* rename extra_cflags into extra_cxx_cflags

* add name checking in setup

* fix test_dispatch failed

* fix word typo and rm useless import statement

* refine import statement

* fix unittest failed

* fix cuda flags error
上级 572cc8bd
......@@ -40,7 +40,8 @@ custom_module = load(
'custom_relu_op.cc', 'custom_relu_op.cu', 'custom_relu_op_dup.cc'
],
extra_include_paths=paddle_includes, # add for Coverage CI
extra_cflags=extra_compile_args, # add for Coverage CI
extra_cxx_cflags=extra_compile_args, # add for Coverage CI
extra_cuda_cflags=extra_compile_args, # add for Coverage CI
verbose=True)
......
......@@ -31,7 +31,8 @@ dispatch_op = load(
name='dispatch_op',
sources=['dispatch_test_op.cc'],
extra_include_paths=paddle_includes, # add for Coverage CI
extra_cflags=extra_compile_args, # add for Coverage CI
extra_cxx_cflags=extra_compile_args,
extra_cuda_cflags=extra_compile_args, # add for Coverage CI
verbose=True)
......
......@@ -29,7 +29,8 @@ custom_module = load(
sources=['relu_op.cc', 'relu_op.cu', 'relu_op3.cc', 'relu_op3.cu'],
interpreter='python', # add for unittest
extra_include_paths=paddle_includes, # add for Coverage CI
extra_cflags=extra_compile_args, # add for Coverage CI
extra_cxx_cflags=extra_compile_args, # add for Coverage CI,
extra_cuda_cflags=extra_compile_args, # add for split cpp/cuda flags
verbose=True # add for unittest
)
......
......@@ -35,7 +35,8 @@ multi_out_module = load(
name='multi_out_jit',
sources=['multi_out_test_op.cc'],
extra_include_paths=paddle_includes, # add for Coverage CI
extra_cflags=extra_compile_args, # add for Coverage CI
extra_cxx_cflags=extra_compile_args, # add for Coverage CI
extra_cuda_cflags=extra_compile_args, # add for Coverage CI
verbose=True)
......
......@@ -14,8 +14,6 @@
import os
import six
import sys
import textwrap
import copy
import re
......@@ -50,7 +48,7 @@ def setup(**attr):
Its usage is almost same as `setuptools.setup` except for `ext_modules`
arguments. For compiling multi custom operators, all necessary source files
can be include into just one Extension (CppExtension/CUDAExtension).
Moreover, only one `name` argument is required in `setup` and no need to spcific
Moreover, only one `name` argument is required in `setup` and no need to specify
`name` in Extension.
Example:
......@@ -60,11 +58,11 @@ def setup(**attr):
ext_modules=CUDAExtension(
sources=['relu_op.cc', 'relu_op.cu'],
include_dirs=[], # specific user-defined include dirs
extra_compile_args=[]) # specific user-defined compil arguments.
extra_compile_args=[]) # specific user-defined compiler arguments.
"""
cmdclass = attr.get('cmdclass', {})
assert isinstance(cmdclass, dict)
# if not specific cmdclass in setup, add it automaticaly.
# if not specific cmdclass in setup, add it automatically.
if 'build_ext' not in cmdclass:
cmdclass['build_ext'] = BuildExtension.with_options(
no_python_abi_suffix=True)
......@@ -81,18 +79,22 @@ def setup(**attr):
sources=['relu_op.cc', 'relu_op.cu'])
# After running `python setup.py install`
from custom_module import relue
from custom_module import relu
"""
# name argument is required
if 'name' not in attr:
raise ValueError(error_msg)
assert not attr['name'].endswith('module'), \
"Please don't use 'module' as suffix in `name` argument, "
"it will be stripped in setuptools.bdist_egg and cause import error."
ext_modules = attr.get('ext_modules', [])
if not isinstance(ext_modules, list):
ext_modules = [ext_modules]
assert len(
ext_modules
) == 1, "Required only one Extension, but received {}. If you want to compile multi operators, you can include all necessary source files in one Extenion.".format(
) == 1, "Required only one Extension, but received {}. If you want to compile multi operators, you can include all necessary source files in one Extension.".format(
len(ext_modules))
# replace Extension.name with attr['name'] to keep consistent with Package name.
for ext_module in ext_modules:
......@@ -233,12 +235,6 @@ class BuildExtension(build_ext, object):
def build_extensions(self):
self._check_abi()
for extension in self.extensions:
# check settings of compiler
if isinstance(extension.extra_compile_args, dict):
for compiler in ['cxx', 'nvcc']:
if compiler not in extension.extra_compile_args:
extension.extra_compile_args[compiler] = []
# Consider .cu, .cu.cc as valid source extensions.
self.compiler.src_extensions += ['.cu', '.cu.cc']
......@@ -248,8 +244,6 @@ class BuildExtension(build_ext, object):
original_compile = self.compiler.compile
original_spawn = self.compiler.spawn
else:
# add determine compile flags
add_compile_flag(extension, '-std=c++11')
original_compile = self.compiler._compile
def unix_custom_single_compiler(obj, src, ext, cc_args, extra_postargs,
......@@ -271,7 +265,7 @@ class BuildExtension(build_ext, object):
# {'nvcc': {}, 'cxx': {}}
if isinstance(cflags, dict):
cflags = cflags['nvcc']
else:
cflags = prepare_unix_cudaflags(cflags)
# cxx compile Cpp source
elif isinstance(cflags, dict):
......@@ -434,7 +428,7 @@ class BuildExtension(build_ext, object):
compiler = os.environ.get('CXX', 'c++')
check_abi_compatibility(compiler)
# Warn user if VC env is activated but `DISTUILS_USE_SDK` is not set.
# Warn user if VC env is activated but `DISTUTILS_USE_SDK` is not set.
if IS_WINDOWS and 'VSCMD_ARG_TGT_ARCH' in os.environ and 'DISTUTILS_USE_SDK' not in os.environ:
msg = (
'It seems that the VC environment is activated but DISTUTILS_USE_SDK is not set.'
......@@ -444,7 +438,7 @@ class BuildExtension(build_ext, object):
def _record_op_info(self):
"""
Record custum op inforomation.
Record custom op information.
"""
# parse shared library abs path
outputs = self.get_outputs()
......@@ -535,7 +529,7 @@ class BuildCommand(build, object):
def load(name,
sources,
extra_cflags=None,
extra_cxx_cflags=None,
extra_cuda_cflags=None,
extra_ldflags=None,
extra_include_paths=None,
......@@ -558,14 +552,14 @@ def load(name,
Args:
name(str): generated shared library file name.
sources(list[str]): custom op source files name with .cc/.cu suffix.
extra_cflag(list[str]): additional flags used to compile CPP files. By default
extra_cxx_cflags(list[str]): additional flags used to compile CPP files. By default
all basic and framework related flags have been included.
If your pre-installed Paddle supports MKLDNN, please add
'-DPADDLE_WITH_MKLDNN'. Default None.
extra_cuda_cflags(list[str]): additonal flags used to compile CUDA files. See
extra_cuda_cflags(list[str]): additional flags used to compile CUDA files. See
https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html
for details. Default None.
extra_ldflags(list[str]): additonal flags used to link shared library. See
extra_ldflags(list[str]): additional flags used to link shared library. See
https://gcc.gnu.org/onlinedocs/gcc/Link-Options.html for details.
Default None.
extra_include_paths(list[str]): additional include path used to search header files.
......@@ -578,7 +572,7 @@ def load(name,
verbose(bool): whether to verbose compiled log information
Returns:
custom api: A callable python function with same signature as CustomOp Kernel defination.
custom api: A callable python function with same signature as CustomOp Kernel definition.
Example:
......@@ -603,18 +597,25 @@ def load(name,
file_path = os.path.join(build_directory, "{}_setup.py".format(name))
sources = [os.path.abspath(source) for source in sources]
# TODO(Aurelius84): split cflags and cuda_flags
if extra_cflags is None: extra_cflags = []
if extra_cxx_cflags is None: extra_cxx_cflags = []
if extra_cuda_cflags is None: extra_cuda_cflags = []
compile_flags = extra_cflags + extra_cuda_cflags
log_v("additonal compile_flags: [{}]".format(' '.join(compile_flags)),
verbose)
assert isinstance(
extra_cxx_cflags, list
), "Required type(extra_cxx_cflags) == list[str], but received {}".format(
extra_cxx_cflags)
assert isinstance(
extra_cuda_cflags, list
), "Required type(extra_cuda_cflags) == list[str], but received {}".format(
extra_cuda_cflags)
log_v("additional extra_cxx_cflags: [{}], extra_cuda_cflags: [{}]".format(
' '.join(extra_cxx_cflags), ' '.join(extra_cuda_cflags)), verbose)
# write setup.py file and compile it
build_base_dir = os.path.join(build_directory, name)
_write_setup_file(name, sources, file_path, build_base_dir,
extra_include_paths, compile_flags, extra_ldflags,
verbose)
extra_include_paths, extra_cxx_cflags, extra_cuda_cflags,
extra_ldflags, verbose)
_jit_compile(file_path, interpreter, verbose)
# import as callable python api
......
......@@ -16,7 +16,6 @@ import os
import re
import six
import sys
import copy
import glob
import logging
import collections
......@@ -271,6 +270,13 @@ def normalize_extension_kwargs(kwargs, use_cuda=False):
library_dirs.extend(find_paddle_libraries(use_cuda))
kwargs['library_dirs'] = library_dirs
# append compile flags and check settings of compiler
extra_compile_args = kwargs.get('extra_compile_args', [])
if isinstance(extra_compile_args, dict):
for compiler in ['cxx', 'nvcc']:
if compiler not in extra_compile_args:
extra_compile_args[compiler] = []
if IS_WINDOWS:
# TODO(zhouwei): may append compile flags in future
pass
......@@ -282,9 +288,7 @@ def normalize_extension_kwargs(kwargs, use_cuda=False):
kwargs['extra_link_args'] = extra_link_args
else:
# append compile flags
extra_compile_args = kwargs.get('extra_compile_args', [])
extra_compile_args.extend(['-g', '-w']) # diable warnings
kwargs['extra_compile_args'] = extra_compile_args
add_compile_flag(extra_compile_args, ['-g', '-w']) # disable warnings
# append link flags
extra_link_args = kwargs.get('extra_link_args', [])
......@@ -302,6 +306,8 @@ def normalize_extension_kwargs(kwargs, use_cuda=False):
runtime_library_dirs.extend(find_paddle_libraries(use_cuda))
kwargs['runtime_library_dirs'] = runtime_library_dirs
kwargs['extra_compile_args'] = extra_compile_args
kwargs['language'] = 'c++'
return kwargs
......@@ -407,15 +413,13 @@ def find_paddle_libraries(use_cuda=False):
return paddle_lib_dirs
def add_compile_flag(extension, flag):
extra_compile_args = copy.deepcopy(extension.extra_compile_args)
def add_compile_flag(extra_compile_args, flags):
assert isinstance(flags, list)
if isinstance(extra_compile_args, dict):
for args in extra_compile_args.values():
args.append(flag)
args.extend(flags)
else:
extra_compile_args.append(flag)
extension.extra_compile_args = extra_compile_args
extra_compile_args.extend(flags)
def is_cuda_file(path):
......@@ -520,7 +524,7 @@ def _custom_api_content(op_name):
def {op_name}({inputs}):
helper = LayerHelper("{op_name}", **locals())
# prepare inputs and output
# prepare inputs and outputs
ins = {ins}
outs = {{}}
out_names = {out_names}
......@@ -585,7 +589,8 @@ def _write_setup_file(name,
file_path,
build_dir,
include_dirs,
compile_flags,
extra_cxx_cflags,
extra_cuda_cflags,
link_args,
verbose=False):
"""
......@@ -605,7 +610,7 @@ def _write_setup_file(name,
{prefix}Extension(
sources={sources},
include_dirs={include_dirs},
extra_compile_args={extra_compile_args},
extra_compile_args={{'cxx':{extra_cxx_cflags}, 'nvcc':{extra_cuda_cflags}}},
extra_link_args={extra_link_args})],
cmdclass={{"build_ext" : BuildExtension.with_options(
output_dir=r'{build_dir}',
......@@ -622,7 +627,8 @@ def _write_setup_file(name,
prefix='CUDA' if with_cuda else 'Cpp',
sources=list2str(sources),
include_dirs=list2str(include_dirs),
extra_compile_args=list2str(compile_flags),
extra_cxx_cflags=list2str(extra_cxx_cflags),
extra_cuda_cflags=list2str(extra_cuda_cflags),
extra_link_args=list2str(link_args),
build_dir=build_dir,
use_new_method=use_new_custom_op_load_method())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册