未验证 提交 406f4a75 编写于 作者: A Aurelius84 提交者: GitHub

[CustomOp] Support specifying extra_cflags and extra_cuda_cflags independently (#31059)

* split cxx/nvcc compile flags

* enhance input argument check

* rename extra_cflags into extra_cxx_cflags

* add name checking in setup

* fix test_dispatch failed

* fix word typos and remove useless import statements

* refine import statement

* fix unittest failed

* fix cuda flags error
上级 572cc8bd
...@@ -40,7 +40,8 @@ custom_module = load( ...@@ -40,7 +40,8 @@ custom_module = load(
'custom_relu_op.cc', 'custom_relu_op.cu', 'custom_relu_op_dup.cc' 'custom_relu_op.cc', 'custom_relu_op.cu', 'custom_relu_op_dup.cc'
], ],
extra_include_paths=paddle_includes, # add for Coverage CI extra_include_paths=paddle_includes, # add for Coverage CI
extra_cflags=extra_compile_args, # add for Coverage CI extra_cxx_cflags=extra_compile_args, # add for Coverage CI
extra_cuda_cflags=extra_compile_args, # add for Coverage CI
verbose=True) verbose=True)
......
...@@ -31,7 +31,8 @@ dispatch_op = load( ...@@ -31,7 +31,8 @@ dispatch_op = load(
name='dispatch_op', name='dispatch_op',
sources=['dispatch_test_op.cc'], sources=['dispatch_test_op.cc'],
extra_include_paths=paddle_includes, # add for Coverage CI extra_include_paths=paddle_includes, # add for Coverage CI
extra_cflags=extra_compile_args, # add for Coverage CI extra_cxx_cflags=extra_compile_args,
extra_cuda_cflags=extra_compile_args, # add for Coverage CI
verbose=True) verbose=True)
......
...@@ -29,7 +29,8 @@ custom_module = load( ...@@ -29,7 +29,8 @@ custom_module = load(
sources=['relu_op.cc', 'relu_op.cu', 'relu_op3.cc', 'relu_op3.cu'], sources=['relu_op.cc', 'relu_op.cu', 'relu_op3.cc', 'relu_op3.cu'],
interpreter='python', # add for unittest interpreter='python', # add for unittest
extra_include_paths=paddle_includes, # add for Coverage CI extra_include_paths=paddle_includes, # add for Coverage CI
extra_cflags=extra_compile_args, # add for Coverage CI extra_cxx_cflags=extra_compile_args, # add for Coverage CI,
extra_cuda_cflags=extra_compile_args, # add for split cpp/cuda flags
verbose=True # add for unittest verbose=True # add for unittest
) )
......
...@@ -35,7 +35,8 @@ multi_out_module = load( ...@@ -35,7 +35,8 @@ multi_out_module = load(
name='multi_out_jit', name='multi_out_jit',
sources=['multi_out_test_op.cc'], sources=['multi_out_test_op.cc'],
extra_include_paths=paddle_includes, # add for Coverage CI extra_include_paths=paddle_includes, # add for Coverage CI
extra_cflags=extra_compile_args, # add for Coverage CI extra_cxx_cflags=extra_compile_args, # add for Coverage CI
extra_cuda_cflags=extra_compile_args, # add for Coverage CI
verbose=True) verbose=True)
......
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
import os import os
import six import six
import sys
import textwrap
import copy import copy
import re import re
...@@ -50,7 +48,7 @@ def setup(**attr): ...@@ -50,7 +48,7 @@ def setup(**attr):
Its usage is almost same as `setuptools.setup` except for `ext_modules` Its usage is almost same as `setuptools.setup` except for `ext_modules`
arguments. For compiling multi custom operators, all necessary source files arguments. For compiling multi custom operators, all necessary source files
can be include into just one Extension (CppExtension/CUDAExtension). can be include into just one Extension (CppExtension/CUDAExtension).
Moreover, only one `name` argument is required in `setup` and no need to spcific Moreover, only one `name` argument is required in `setup` and no need to specify
`name` in Extension. `name` in Extension.
Example: Example:
...@@ -60,11 +58,11 @@ def setup(**attr): ...@@ -60,11 +58,11 @@ def setup(**attr):
ext_modules=CUDAExtension( ext_modules=CUDAExtension(
sources=['relu_op.cc', 'relu_op.cu'], sources=['relu_op.cc', 'relu_op.cu'],
include_dirs=[], # specific user-defined include dirs include_dirs=[], # specific user-defined include dirs
extra_compile_args=[]) # specific user-defined compil arguments. extra_compile_args=[]) # specific user-defined compiler arguments.
""" """
cmdclass = attr.get('cmdclass', {}) cmdclass = attr.get('cmdclass', {})
assert isinstance(cmdclass, dict) assert isinstance(cmdclass, dict)
# if not specific cmdclass in setup, add it automaticaly. # if not specific cmdclass in setup, add it automatically.
if 'build_ext' not in cmdclass: if 'build_ext' not in cmdclass:
cmdclass['build_ext'] = BuildExtension.with_options( cmdclass['build_ext'] = BuildExtension.with_options(
no_python_abi_suffix=True) no_python_abi_suffix=True)
...@@ -81,18 +79,22 @@ def setup(**attr): ...@@ -81,18 +79,22 @@ def setup(**attr):
sources=['relu_op.cc', 'relu_op.cu']) sources=['relu_op.cc', 'relu_op.cu'])
# After running `python setup.py install` # After running `python setup.py install`
from custom_module import relue from custom_module import relu
""" """
# name argument is required # name argument is required
if 'name' not in attr: if 'name' not in attr:
raise ValueError(error_msg) raise ValueError(error_msg)
assert not attr['name'].endswith('module'), \
"Please don't use 'module' as suffix in `name` argument, "
"it will be stripped in setuptools.bdist_egg and cause import error."
ext_modules = attr.get('ext_modules', []) ext_modules = attr.get('ext_modules', [])
if not isinstance(ext_modules, list): if not isinstance(ext_modules, list):
ext_modules = [ext_modules] ext_modules = [ext_modules]
assert len( assert len(
ext_modules ext_modules
) == 1, "Required only one Extension, but received {}. If you want to compile multi operators, you can include all necessary source files in one Extenion.".format( ) == 1, "Required only one Extension, but received {}. If you want to compile multi operators, you can include all necessary source files in one Extension.".format(
len(ext_modules)) len(ext_modules))
# replace Extension.name with attr['name] to keep consistant with Package name. # replace Extension.name with attr['name] to keep consistant with Package name.
for ext_module in ext_modules: for ext_module in ext_modules:
...@@ -233,12 +235,6 @@ class BuildExtension(build_ext, object): ...@@ -233,12 +235,6 @@ class BuildExtension(build_ext, object):
def build_extensions(self): def build_extensions(self):
self._check_abi() self._check_abi()
for extension in self.extensions:
# check settings of compiler
if isinstance(extension.extra_compile_args, dict):
for compiler in ['cxx', 'nvcc']:
if compiler not in extension.extra_compile_args:
extension.extra_compile_args[compiler] = []
# Consider .cu, .cu.cc as valid source extensions. # Consider .cu, .cu.cc as valid source extensions.
self.compiler.src_extensions += ['.cu', '.cu.cc'] self.compiler.src_extensions += ['.cu', '.cu.cc']
...@@ -248,8 +244,6 @@ class BuildExtension(build_ext, object): ...@@ -248,8 +244,6 @@ class BuildExtension(build_ext, object):
original_compile = self.compiler.compile original_compile = self.compiler.compile
original_spawn = self.compiler.spawn original_spawn = self.compiler.spawn
else: else:
# add determine compile flags
add_compile_flag(extension, '-std=c++11')
original_compile = self.compiler._compile original_compile = self.compiler._compile
def unix_custom_single_compiler(obj, src, ext, cc_args, extra_postargs, def unix_custom_single_compiler(obj, src, ext, cc_args, extra_postargs,
...@@ -271,8 +265,8 @@ class BuildExtension(build_ext, object): ...@@ -271,8 +265,8 @@ class BuildExtension(build_ext, object):
# {'nvcc': {}, 'cxx: {}} # {'nvcc': {}, 'cxx: {}}
if isinstance(cflags, dict): if isinstance(cflags, dict):
cflags = cflags['nvcc'] cflags = cflags['nvcc']
else:
cflags = prepare_unix_cudaflags(cflags) cflags = prepare_unix_cudaflags(cflags)
# cxx compile Cpp source # cxx compile Cpp source
elif isinstance(cflags, dict): elif isinstance(cflags, dict):
cflags = cflags['cxx'] cflags = cflags['cxx']
...@@ -434,7 +428,7 @@ class BuildExtension(build_ext, object): ...@@ -434,7 +428,7 @@ class BuildExtension(build_ext, object):
compiler = os.environ.get('CXX', 'c++') compiler = os.environ.get('CXX', 'c++')
check_abi_compatibility(compiler) check_abi_compatibility(compiler)
# Warn user if VC env is activated but `DISTUILS_USE_SDK` is not set. # Warn user if VC env is activated but `DISTUTILS_USE_SDK` is not set.
if IS_WINDOWS and 'VSCMD_ARG_TGT_ARCH' in os.environ and 'DISTUTILS_USE_SDK' not in os.environ: if IS_WINDOWS and 'VSCMD_ARG_TGT_ARCH' in os.environ and 'DISTUTILS_USE_SDK' not in os.environ:
msg = ( msg = (
'It seems that the VC environment is activated but DISTUTILS_USE_SDK is not set.' 'It seems that the VC environment is activated but DISTUTILS_USE_SDK is not set.'
...@@ -444,7 +438,7 @@ class BuildExtension(build_ext, object): ...@@ -444,7 +438,7 @@ class BuildExtension(build_ext, object):
def _record_op_info(self): def _record_op_info(self):
""" """
Record custum op inforomation. Record custom op information.
""" """
# parse shared library abs path # parse shared library abs path
outputs = self.get_outputs() outputs = self.get_outputs()
...@@ -535,7 +529,7 @@ class BuildCommand(build, object): ...@@ -535,7 +529,7 @@ class BuildCommand(build, object):
def load(name, def load(name,
sources, sources,
extra_cflags=None, extra_cxx_cflags=None,
extra_cuda_cflags=None, extra_cuda_cflags=None,
extra_ldflags=None, extra_ldflags=None,
extra_include_paths=None, extra_include_paths=None,
...@@ -558,14 +552,14 @@ def load(name, ...@@ -558,14 +552,14 @@ def load(name,
Args: Args:
name(str): generated shared library file name. name(str): generated shared library file name.
sources(list[str]): custom op source files name with .cc/.cu suffix. sources(list[str]): custom op source files name with .cc/.cu suffix.
extra_cflag(list[str]): additional flags used to compile CPP files. By default extra_cxx_cflags(list[str]): additional flags used to compile CPP files. By default
all basic and framework related flags have been included. all basic and framework related flags have been included.
If your pre-insall Paddle supported MKLDNN, please add If your pre-insall Paddle supported MKLDNN, please add
'-DPADDLE_WITH_MKLDNN'. Default None. '-DPADDLE_WITH_MKLDNN'. Default None.
extra_cuda_cflags(list[str]): additonal flags used to compile CUDA files. See extra_cuda_cflags(list[str]): additional flags used to compile CUDA files. See
https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html
for details. Default None. for details. Default None.
extra_ldflags(list[str]): additonal flags used to link shared library. See extra_ldflags(list[str]): additional flags used to link shared library. See
https://gcc.gnu.org/onlinedocs/gcc/Link-Options.html for details. https://gcc.gnu.org/onlinedocs/gcc/Link-Options.html for details.
Default None. Default None.
extra_include_paths(list[str]): additional include path used to search header files. extra_include_paths(list[str]): additional include path used to search header files.
...@@ -578,7 +572,7 @@ def load(name, ...@@ -578,7 +572,7 @@ def load(name,
verbose(bool): whether to verbose compiled log information verbose(bool): whether to verbose compiled log information
Returns: Returns:
custom api: A callable python function with same signature as CustomOp Kernel defination. custom api: A callable python function with same signature as CustomOp Kernel definition.
Example: Example:
...@@ -603,18 +597,25 @@ def load(name, ...@@ -603,18 +597,25 @@ def load(name,
file_path = os.path.join(build_directory, "{}_setup.py".format(name)) file_path = os.path.join(build_directory, "{}_setup.py".format(name))
sources = [os.path.abspath(source) for source in sources] sources = [os.path.abspath(source) for source in sources]
# TODO(Aurelius84): split cflags and cuda_flags if extra_cxx_cflags is None: extra_cxx_cflags = []
if extra_cflags is None: extra_cflags = []
if extra_cuda_cflags is None: extra_cuda_cflags = [] if extra_cuda_cflags is None: extra_cuda_cflags = []
compile_flags = extra_cflags + extra_cuda_cflags assert isinstance(
log_v("additonal compile_flags: [{}]".format(' '.join(compile_flags)), extra_cxx_cflags, list
verbose) ), "Required type(extra_cxx_cflags) == list[str], but received {}".format(
extra_cxx_cflags)
assert isinstance(
extra_cuda_cflags, list
), "Required type(extra_cuda_cflags) == list[str], but received {}".format(
extra_cuda_cflags)
log_v("additional extra_cxx_cflags: [{}], extra_cuda_cflags: [{}]".format(
' '.join(extra_cxx_cflags), ' '.join(extra_cuda_cflags)), verbose)
# write setup.py file and compile it # write setup.py file and compile it
build_base_dir = os.path.join(build_directory, name) build_base_dir = os.path.join(build_directory, name)
_write_setup_file(name, sources, file_path, build_base_dir, _write_setup_file(name, sources, file_path, build_base_dir,
extra_include_paths, compile_flags, extra_ldflags, extra_include_paths, extra_cxx_cflags, extra_cuda_cflags,
verbose) extra_ldflags, verbose)
_jit_compile(file_path, interpreter, verbose) _jit_compile(file_path, interpreter, verbose)
# import as callable python api # import as callable python api
......
...@@ -16,7 +16,6 @@ import os ...@@ -16,7 +16,6 @@ import os
import re import re
import six import six
import sys import sys
import copy
import glob import glob
import logging import logging
import collections import collections
...@@ -271,6 +270,13 @@ def normalize_extension_kwargs(kwargs, use_cuda=False): ...@@ -271,6 +270,13 @@ def normalize_extension_kwargs(kwargs, use_cuda=False):
library_dirs.extend(find_paddle_libraries(use_cuda)) library_dirs.extend(find_paddle_libraries(use_cuda))
kwargs['library_dirs'] = library_dirs kwargs['library_dirs'] = library_dirs
# append compile flags and check settings of compiler
extra_compile_args = kwargs.get('extra_compile_args', [])
if isinstance(extra_compile_args, dict):
for compiler in ['cxx', 'nvcc']:
if compiler not in extra_compile_args:
extra_compile_args[compiler] = []
if IS_WINDOWS: if IS_WINDOWS:
# TODO(zhouwei): may append compile flags in future # TODO(zhouwei): may append compile flags in future
pass pass
...@@ -282,9 +288,7 @@ def normalize_extension_kwargs(kwargs, use_cuda=False): ...@@ -282,9 +288,7 @@ def normalize_extension_kwargs(kwargs, use_cuda=False):
kwargs['extra_link_args'] = extra_link_args kwargs['extra_link_args'] = extra_link_args
else: else:
# append compile flags # append compile flags
extra_compile_args = kwargs.get('extra_compile_args', []) add_compile_flag(extra_compile_args, ['-g', '-w']) # disable warnings
extra_compile_args.extend(['-g', '-w']) # diable warnings
kwargs['extra_compile_args'] = extra_compile_args
# append link flags # append link flags
extra_link_args = kwargs.get('extra_link_args', []) extra_link_args = kwargs.get('extra_link_args', [])
...@@ -302,6 +306,8 @@ def normalize_extension_kwargs(kwargs, use_cuda=False): ...@@ -302,6 +306,8 @@ def normalize_extension_kwargs(kwargs, use_cuda=False):
runtime_library_dirs.extend(find_paddle_libraries(use_cuda)) runtime_library_dirs.extend(find_paddle_libraries(use_cuda))
kwargs['runtime_library_dirs'] = runtime_library_dirs kwargs['runtime_library_dirs'] = runtime_library_dirs
kwargs['extra_compile_args'] = extra_compile_args
kwargs['language'] = 'c++' kwargs['language'] = 'c++'
return kwargs return kwargs
...@@ -407,15 +413,13 @@ def find_paddle_libraries(use_cuda=False): ...@@ -407,15 +413,13 @@ def find_paddle_libraries(use_cuda=False):
return paddle_lib_dirs return paddle_lib_dirs
def add_compile_flag(extension, flag): def add_compile_flag(extra_compile_args, flags):
extra_compile_args = copy.deepcopy(extension.extra_compile_args) assert isinstance(flags, list)
if isinstance(extra_compile_args, dict): if isinstance(extra_compile_args, dict):
for args in extra_compile_args.values(): for args in extra_compile_args.values():
args.append(flag) args.extend(flags)
else: else:
extra_compile_args.append(flag) extra_compile_args.extend(flags)
extension.extra_compile_args = extra_compile_args
def is_cuda_file(path): def is_cuda_file(path):
...@@ -520,7 +524,7 @@ def _custom_api_content(op_name): ...@@ -520,7 +524,7 @@ def _custom_api_content(op_name):
def {op_name}({inputs}): def {op_name}({inputs}):
helper = LayerHelper("{op_name}", **locals()) helper = LayerHelper("{op_name}", **locals())
# prepare inputs and output # prepare inputs and outputs
ins = {ins} ins = {ins}
outs = {{}} outs = {{}}
out_names = {out_names} out_names = {out_names}
...@@ -585,7 +589,8 @@ def _write_setup_file(name, ...@@ -585,7 +589,8 @@ def _write_setup_file(name,
file_path, file_path,
build_dir, build_dir,
include_dirs, include_dirs,
compile_flags, extra_cxx_cflags,
extra_cuda_cflags,
link_args, link_args,
verbose=False): verbose=False):
""" """
...@@ -605,7 +610,7 @@ def _write_setup_file(name, ...@@ -605,7 +610,7 @@ def _write_setup_file(name,
{prefix}Extension( {prefix}Extension(
sources={sources}, sources={sources},
include_dirs={include_dirs}, include_dirs={include_dirs},
extra_compile_args={extra_compile_args}, extra_compile_args={{'cxx':{extra_cxx_cflags}, 'nvcc':{extra_cuda_cflags}}},
extra_link_args={extra_link_args})], extra_link_args={extra_link_args})],
cmdclass={{"build_ext" : BuildExtension.with_options( cmdclass={{"build_ext" : BuildExtension.with_options(
output_dir=r'{build_dir}', output_dir=r'{build_dir}',
...@@ -622,7 +627,8 @@ def _write_setup_file(name, ...@@ -622,7 +627,8 @@ def _write_setup_file(name,
prefix='CUDA' if with_cuda else 'Cpp', prefix='CUDA' if with_cuda else 'Cpp',
sources=list2str(sources), sources=list2str(sources),
include_dirs=list2str(include_dirs), include_dirs=list2str(include_dirs),
extra_compile_args=list2str(compile_flags), extra_cxx_cflags=list2str(extra_cxx_cflags),
extra_cuda_cflags=list2str(extra_cuda_cflags),
extra_link_args=list2str(link_args), extra_link_args=list2str(link_args),
build_dir=build_dir, build_dir=build_dir,
use_new_method=use_new_custom_op_load_method()) use_new_method=use_new_custom_op_load_method())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册