cpp_extension.py 36.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
# isort: skip_file

17 18
import os
import copy
19
import re
20 21 22 23

import setuptools
from setuptools.command.easy_install import easy_install
from setuptools.command.build_ext import build_ext
24
from distutils.command.build import build
25

26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
from .extension_utils import (
    add_compile_flag,
    find_cuda_home,
    find_rocm_home,
    normalize_extension_kwargs,
)
from .extension_utils import (
    is_cuda_file,
    prepare_unix_cudaflags,
    prepare_win_cudaflags,
)
from .extension_utils import (
    _import_module_from_library,
    _write_setup_file,
    _jit_compile,
)
from .extension_utils import (
    check_abi_compatibility,
    log_v,
    CustomOpInfo,
    parse_op_name_from,
)
48
from .extension_utils import _reset_so_rpath, clean_object_if_change_cflags
49 50 51 52 53 54 55 56 57 58 59
from .extension_utils import (
    bootstrap_context,
    get_build_directory,
    add_std_without_repeat,
)

from .extension_utils import (
    IS_WINDOWS,
    OS_NAME,
    MSVC_COMPILE_FLAGS,
)
60
from .extension_utils import CLANG_COMPILE_FLAGS, CLANG_LINK_FLAGS
61

62 63
from ...fluid import core

# NOTE(zhouwei): On Windows, distutils' build_ext exports the function
# 'PyInit_[name]' by default, which custom-op libraries normally do not define.
# Two possible fixes: 1. the user adds PyInit_[name]; 2. export nothing.
# Option 2 is applied here by stubbing out `get_export_symbols`.
# Refer to https://stackoverflow.com/questions/34689210/error-exporting-symbol-when-building-python-c-extension-in-windows
if IS_WINDOWS:
    from unittest.mock import Mock
    from distutils.command.build_ext import build_ext as _du_build_ext

    # The Mock is stored as a plain class attribute (it is not a descriptor,
    # so it is never bound to the instance) and always returns None.
    _du_build_ext.get_export_symbols = Mock(return_value=None)

# Root of the GPU toolkit. When Paddle is compiled with ROCm, CUDA_HOME is
# redirected to the ROCm installation, which is also kept in ROCM_HOME.
CUDA_HOME = find_cuda_home()
if core.is_compiled_with_rocm():
    CUDA_HOME = ROCM_HOME = find_rocm_home()


def setup(**attr):
    """
81
    The interface is used to config the process of compiling customized operators,
82
    mainly includes how to compile shared library, automatically generate python API
83 84 85 86
    and install it into site-package. It supports using customized operators directly with
    ``import`` statement.

    It encapsulates the python built-in ``setuptools.setup`` function and keeps arguments
H
HongyuJia 已提交
87
    and usage same as the native interface. Meanwhile, it hides Paddle inner framework
88
    concepts, such as necessary compiling flags, included paths of head files, and linking
89 90
    flags. It also will automatically search and valid local environment and versions of
    ``cc(Linux)`` , ``cl.exe(Windows)`` and ``nvcc`` , then compiles customized operators
91
    supporting CPU or GPU device according to the specified Extension type.
92

93
    Moreover, `ABI compatibility <https://gcc.gnu.org/onlinedocs/libstdc++/manual/abi.html>`_
94
    will be checked to ensure that compiler version from ``cc(Linux)`` , ``cl.exe(Windows)``
95 96
    on local machine is compatible with pre-installed Paddle whl in python site-packages.

97 98 99 100 101
    For Linux, GCC version will be checked . For example if Paddle with CUDA 10.1 is built with GCC 8.2,
    then the version of user's local machine should satisfy GCC >= 8.2.
    For Windows, Visual Studio version will be checked, and it should be greater than or equal to that of
    PaddlePaddle (Visual Studio 2017).
    If the above conditions are not met, the corresponding warning will be printed, and a fatal error may
102
    occur because of ABI compatibility.
103

104
    Note:
105

H
HongyuJia 已提交
106 107
        1. Currently we support Linux, MacOS and Windows platform.
        2. On Linux platform, we recommend to use GCC 8.2 as soft linking candidate of ``/usr/bin/cc`` .
108
           Then, Use ``which cc`` to ensure location of ``cc`` and using ``cc --version`` to ensure linking
109
           GCC version.
110
        3. On Windows platform, we recommend to install `` Visual Studio`` (>=2017).
111 112 113 114 115 116


    Compared with Just-In-Time ``load`` interface, it only compiles once by executing
    ``python setup.py install`` . Then customized operators API will be available everywhere
    after importing it.

117
    A simple example of ``setup.py`` as followed:
118 119 120

    .. code-block:: text

121
        # setup.py
122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154

        # Case 1: Compiling customized operators supporting CPU and GPU devices
        from paddle.utils.cpp_extension import CUDAExtension, setup

        setup(
            name='custom_op',  # name of package used by "import"
            ext_modules=CUDAExtension(
                sources=['relu_op.cc', 'relu_op.cu', 'tanh_op.cc', 'tanh_op.cu']  # Support for compilation of multiple OPs
            )
        )

        # Case 2: Compiling customized operators supporting only CPU device
        from paddle.utils.cpp_extension import CppExtension, setup

        setup(
            name='custom_op',  # name of package used by "import"
            ext_modules=CppExtension(
                sources=['relu_op.cc', 'tanh_op.cc']  # Support for compilation of multiple OPs
            )
        )


    Applying compilation and installation by executing ``python setup.py install`` under source files directory.
    Then we can use the layer api as followed:

    .. code-block:: text

        import paddle
        from custom_op import relu, tanh

        x = paddle.randn([4, 10], dtype='float32')
        relu_out = relu(x)
        tanh_out = tanh(x)
155

156 157 158

    Args:
        name(str): Specify the name of shared library file and installed python package.
159
        ext_modules(Extension): Specify the Extension instance including customized operator source files, compiling flags et.al.
160 161
                                If only compile operator supporting CPU device, please use ``CppExtension`` ; If compile operator
                                supporting CPU and GPU devices, please use ``CUDAExtension`` .
162
        include_dirs(list[str], optional): Specify the extra include directories to search head files. The interface will automatically add
163 164 165
                                 ``site-package/paddle/include`` . Please add the corresponding directory path if including third-party
                                 head files. Default is None.
        extra_compile_args(list[str] | dict, optional): Specify the extra compiling flags such as ``-O3`` . If set ``list[str]`` , all these flags
H
HongyuJia 已提交
166
                                will be applied for ``cc`` and ``nvcc`` compiler. It supports specify flags only applied ``cc`` or ``nvcc``
167 168 169
                                compiler using dict type with ``{'cxx': [...], 'nvcc': [...]}`` . Default is None.
        **attr(dict, optional): Specify other arguments same as ``setuptools.setup`` .

170
    Returns:
171
        None
172

173 174 175
    """
    cmdclass = attr.get('cmdclass', {})
    assert isinstance(cmdclass, dict)
176
    # if not specific cmdclass in setup, add it automatically.
177 178
    if 'build_ext' not in cmdclass:
        cmdclass['build_ext'] = BuildExtension.with_options(
179 180
            no_python_abi_suffix=True
        )
181 182
        attr['cmdclass'] = cmdclass

183 184 185 186 187 188 189 190 191
    error_msg = """
    Required to specific `name` argument in paddle.utils.cpp_extension.setup.
    It's used as `import XXX` when you want install and import your custom operators.\n
    For Example:
        # setup.py file
        from paddle.utils.cpp_extension import CUDAExtension, setup
        setup(name='custom_module',
              ext_modules=CUDAExtension(
              sources=['relu_op.cc', 'relu_op.cu'])
192

193
        # After running `python setup.py install`
194
        from custom_module import relu
195 196 197 198 199
    """
    # name argument is required
    if 'name' not in attr:
        raise ValueError(error_msg)

200 201 202
    assert not attr['name'].endswith(
        'module'
    ), "Please don't use 'module' as suffix in `name` argument, "
203 204
    "it will be stripped in setuptools.bdist_egg and cause import error."

205 206 207
    ext_modules = attr.get('ext_modules', [])
    if not isinstance(ext_modules, list):
        ext_modules = [ext_modules]
208 209 210 211 212
    assert (
        len(ext_modules) == 1
    ), "Required only one Extension, but received {}. If you want to compile multi operators, you can include all necessary source files in one Extension.".format(
        len(ext_modules)
    )
H
HongyuJia 已提交
213
    # replace Extension.name with attr['name] to keep consistent with Package name.
214 215 216 217 218
    for ext_module in ext_modules:
        ext_module.name = attr['name']

    attr['ext_modules'] = ext_modules

219 220 221 222
    # Add rename .so hook in easy_install
    assert 'easy_install' not in cmdclass
    cmdclass['easy_install'] = EasyInstallCommand

223 224 225 226 227 228 229
    # Note(Aurelius84): Add rename build_base directory hook in build command.
    # To avoid using same build directory that will lead to remove the directory
    # by mistake while parallelling execute setup.py, for example on CI.
    assert 'build' not in cmdclass
    build_base = os.path.join('build', attr['name'])
    cmdclass['build'] = BuildCommand.with_options(build_base=build_base)

230 231 232 233 234 235 236 237 238
    # Always set zip_safe=False to make compatible in PY2 and PY3
    # See http://peak.telecommunity.com/DevCenter/setuptools#setting-the-zip-safe-flag
    attr['zip_safe'] = False

    # switch `write_stub` to inject paddle api in .egg
    with bootstrap_context():
        setuptools.setup(**attr)


239
def CppExtension(sources, *args, **kwargs):
    """
    The interface is used to configure source files of customized operators and compiles
    Op Kernel only supporting CPU device. Please use ``CUDAExtension`` if you want to
    compile Op Kernel that supports both CPU and GPU devices.

    It further encapsulates python built-in ``setuptools.Extension`` . The arguments and
    usage are the same as the native interface, except that ``name`` does not need to be
    specified explicitly.

    **A simple example:**

    .. code-block:: text

        # setup.py

        # Compiling customized operators supporting only CPU device
        from paddle.utils.cpp_extension import CppExtension, setup

        setup(
            name='custom_op',
            ext_modules=CppExtension(sources=['relu_op.cc'])
        )


    Note:
        It is mainly used in ``setup`` and the name of built shared library keeps same
        as ``name`` argument specified in ``setup`` interface.


    Args:
        sources(list[str]): Specify the C++/CUDA source files of customized operators.
        *args(list[options], optional): Specify other arguments same as ``setuptools.Extension`` .
        **kwargs(dict[option], optional): Specify other arguments same as ``setuptools.Extension`` .
                If ``name`` is given (and not None) it is used as the Extension name;
                otherwise a name is generated from ``sources`` .

    Returns:
        setuptools.Extension: An instance of ``setuptools.Extension``
    """
    kwargs = normalize_extension_kwargs(kwargs, use_cuda=False)
    # Note(Aurelius84): While using `setup` and `jit`, the Extension `name` will
    # be replaced as `setup.name` to keep consistent with package. Because we allow
    # users not to specify name in Extension.
    # See `paddle.utils.cpp_extension.setup` for details.
    # `pop` (not `get`): `name` is passed positionally below, so leaving it in
    # kwargs as well would raise "got multiple values for argument 'name'".
    name = kwargs.pop('name', None)
    if name is None:
        name = _generate_extension_name(sources)

    return setuptools.Extension(name, sources, *args, **kwargs)


def CUDAExtension(sources, *args, **kwargs):
    """
    The interface is used to configure source files of customized operators and compiles
    Op Kernel supporting both CPU and GPU devices. Please use ``CppExtension`` if you want to
    compile Op Kernel that supports only CPU device.

    It further encapsulates python built-in ``setuptools.Extension`` . The arguments and
    usage are the same as the native interface, except that ``name`` does not need to be
    specified explicitly.

    **A simple example:**

    .. code-block:: text

        # setup.py

        # Compiling customized operators supporting CPU and GPU devices
        from paddle.utils.cpp_extension import CUDAExtension, setup

        setup(
            name='custom_op',
            ext_modules=CUDAExtension(
                sources=['relu_op.cc', 'relu_op.cu']
            )
        )


    Note:
        It is mainly used in ``setup`` and the name of built shared library keeps same
        as ``name`` argument specified in ``setup`` interface.


    Args:
        sources(list[str]): Specify the C++/CUDA source files of customized operators.
        *args(list[options], optional): Specify other arguments same as ``setuptools.Extension`` .
        **kwargs(dict[option], optional): Specify other arguments same as ``setuptools.Extension`` .
                If ``name`` is given (and not None) it is used as the Extension name;
                otherwise a name is generated from ``sources`` .

    Returns:
        setuptools.Extension: An instance of setuptools.Extension.
    """
    kwargs = normalize_extension_kwargs(kwargs, use_cuda=True)
    # Note(Aurelius84): While using `setup` and `jit`, the Extension `name` will
    # be replaced as `setup.name` to keep consistent with package. Because we allow
    # users not to specify name in Extension.
    # See `paddle.utils.cpp_extension.setup` for details.
    # `pop` (not `get`): `name` is passed positionally below, so leaving it in
    # kwargs as well would raise "got multiple values for argument 'name'".
    name = kwargs.pop('name', None)
    if name is None:
        name = _generate_extension_name(sources)

    return setuptools.Extension(name, sources, *args, **kwargs)


def _generate_extension_name(sources):
    """
    Generate extension name by source files.
    """
    assert len(sources) > 0, "source files is empty"
    file_prefix = []
    for source in sources:
        source = os.path.basename(source)
        filename, _ = os.path.splitext(source)
        # Use list to generate same order.
        if filename not in file_prefix:
            file_prefix.append(filename)

    return '_'.join(file_prefix)


357
class BuildExtension(build_ext):
    """
    Inherited from setuptools.command.build_ext to customize how to apply
    the compilation process for the shared library.
    """

    @classmethod
    def with_options(cls, **options):
        """
        Returns a BuildExtension subclass containing user-defined options.

        Every ``__init__`` call of the returned subclass merges ``options``
        into its keyword arguments (``options`` wins on conflicts).
        """

        class cls_with_options(cls):
            def __init__(self, *args, **kwargs):
                kwargs.update(options)
                cls.__init__(self, *args, **kwargs)

        return cls_with_options

    def __init__(self, *args, **kwargs):
        """
        Attributes are initialized in the following order:

            1. super().__init__()
            2. initialize_options(self)
            3. the rest of current __init__()
            4. finalize_options(self)

        So, it is recommended to set attribute values in `finalize_options`.
        """
        super().__init__(*args, **kwargs)
        self.no_python_abi_suffix = kwargs.get("no_python_abi_suffix", True)
        self.output_dir = kwargs.get("output_dir", None)
        # Whether any Extension contains a CUDA source file. It is computed in
        # `_record_op_info`, which `build_extensions` calls before compiling.
        self.contain_cuda_file = False

    def initialize_options(self):
        super().initialize_options()

    def finalize_options(self):
        super().finalize_options()
        # NOTE(Aurelius84): Set location of compiled shared library.
        # Carefully to modify this because `setup.py build/install`
        # and `load` interface rely on this attribute.
        if self.output_dir is not None:
            self.build_lib = self.output_dir

    def build_extensions(self):
        """
        Compile the Extensions with Paddle's customized compile hooks.

        ``self.compiler`` is patched so that CUDA (or HIP) sources are compiled
        by ``nvcc`` (or ``hipcc``) and get ``.cu.o`` / ``.cu.obj`` object names,
        then the build is delegated to ``build_ext.build_extensions``.
        """
        if OS_NAME.startswith("darwin"):
            self._valid_clang_compiler()

        self._check_abi()

        # Note(Aurelius84): If already compiling source before, we should check whether
        # cflags have changed and delete the built shared library to re-compile the source
        # even though source file content keep unchanged.
        so_name = self.get_ext_fullpath(self.extensions[0].name)
        clean_object_if_change_cflags(
            os.path.abspath(so_name), self.extensions[0]
        )

        # Consider .cu, .cu.cc as valid source extensions. Assign new lists
        # instead of using `+=`: in distutils these lists are class attributes
        # of the compiler classes, so extending them in place would leak into
        # (and accumulate duplicates across) every later build in the process.
        self.compiler.src_extensions = self.compiler.src_extensions + [
            '.cu',
            '.cu.cc',
        ]
        # Save the original compile (and spawn on MSVC) method for later.
        if self.compiler.compiler_type == 'msvc':
            self.compiler._cpp_extensions = self.compiler._cpp_extensions + [
                '.cu',
                '.cuh',
            ]
            original_compile = self.compiler.compile
            original_spawn = self.compiler.spawn
        else:
            original_compile = self.compiler._compile

        def unix_custom_single_compiler(
            obj, src, ext, cc_args, extra_postargs, pp_opts
        ):
            """
            Monkey patch mechanism to replace inner compiler to custom compile process on Unix platform.

            CUDA/HIP sources are compiled with nvcc/hipcc by temporarily
            swapping ``compiler_so``; the original executable is always
            restored afterwards.
            """
            # use abspath to ensure no warning and don't remove deepcopy because modify params
            # with dict type is dangerous.
            src = os.path.abspath(src)
            cflags = copy.deepcopy(extra_postargs)
            # Capture before `try` so that `finally` can always restore it.
            original_compiler = self.compiler.compiler_so
            try:
                # nvcc or hipcc compile CUDA source
                if is_cuda_file(src):
                    if core.is_compiled_with_rocm():
                        assert ROCM_HOME is not None, (
                            "Not found ROCM runtime, "
                            "please use `export ROCM_PATH= XXX` to specify it."
                        )
                        hipcc_cmd = os.path.join(ROCM_HOME, 'bin', 'hipcc')
                        self.compiler.set_executable('compiler_so', hipcc_cmd)
                        # cflags may be a dict such as {'hipcc': [...], 'cxx': [...]}
                        if isinstance(cflags, dict):
                            cflags = cflags['hipcc']
                    else:
                        assert CUDA_HOME is not None, (
                            "Not found CUDA runtime, "
                            "please use `export CUDA_HOME= XXX` to specify it."
                        )
                        nvcc_cmd = os.path.join(CUDA_HOME, 'bin', 'nvcc')
                        self.compiler.set_executable('compiler_so', nvcc_cmd)
                        # cflags may be a dict such as {'nvcc': [...], 'cxx': [...]}
                        if isinstance(cflags, dict):
                            cflags = cflags['nvcc']

                    cflags = prepare_unix_cudaflags(cflags)
                # cxx compile Cpp source
                elif isinstance(cflags, dict):
                    cflags = cflags['cxx']

                # Note(qili93): HIP require some additional flags for CMAKE_C_FLAGS
                if core.is_compiled_with_rocm():
                    cflags.append('-D__HIP_PLATFORM_HCC__')
                    cflags.append('-D__HIP_NO_HALF_CONVERSIONS__=1')
                    cflags.append(
                        '-DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP'
                    )

                # NOTE(Aurelius84): Since Paddle 2.0, we require gcc version > 5.x,
                # so we add this flag to ensure the symbol names from user compiled
                # shared library have same ABI suffix with libpaddle.so.
                # See https://stackoverflow.com/questions/34571583/understanding-gcc-5s-glibcxx-use-cxx11-abi-or-the-new-abi
                add_compile_flag(cflags, ['-D_GLIBCXX_USE_CXX11_ABI=1'])
                # Append this macro only when jointly compiling .cc with .cu
                if not is_cuda_file(src) and self.contain_cuda_file:
                    if core.is_compiled_with_rocm():
                        cflags.append('-DPADDLE_WITH_HIP')
                    else:
                        cflags.append('-DPADDLE_WITH_CUDA')

                add_std_without_repeat(
                    cflags, self.compiler.compiler_type, use_std14=True
                )
                original_compile(obj, src, ext, cc_args, cflags, pp_opts)
            finally:
                # restore original_compiler
                self.compiler.set_executable('compiler_so', original_compiler)

        def win_custom_single_compiler(
            sources,
            output_dir=None,
            macros=None,
            include_dirs=None,
            debug=0,
            extra_preargs=None,
            extra_postargs=None,
            depends=None,
        ):
            """
            Monkey patch mechanism to replace inner compiler to custom compile process on Windows.

            ``extra_postargs`` is stashed in ``self.cflags`` and withheld from
            the original compile. ``win_custom_spawn`` then rewrites each
            spawned command instead, sending CUDA sources to nvcc.
            """
            self.cflags = copy.deepcopy(extra_postargs)
            extra_postargs = None

            def win_custom_spawn(cmd):
                # Using regex to modify compile options: any argument
                # containing /MD becomes /MT, and /W1-/W4 become /W0.
                for i, arg in enumerate(cmd):
                    if re.search('/MD', arg) is not None:
                        cmd[i] = '/MT'
                    # re-read cmd[i]: it may have just been replaced above
                    if re.search('/W[1-4]', cmd[i]) is not None:
                        cmd[i] = '/W0'

                # Using regex to match src, obj and include files
                src_regex = re.compile('/T(p|c)(.*)')
                src_list = [
                    m.group(2)
                    for m in (src_regex.match(elem) for elem in cmd)
                    if m
                ]

                obj_regex = re.compile('/Fo(.*)')
                obj_list = [
                    m.group(1)
                    for m in (obj_regex.match(elem) for elem in cmd)
                    if m
                ]

                include_regex = re.compile(r'((\-|\/)I.*)')
                include_list = [
                    m.group(1)
                    for m in (include_regex.match(elem) for elem in cmd)
                    if m
                ]

                assert len(src_list) == 1 and len(obj_list) == 1
                src = src_list[0]
                obj = obj_list[0]
                if is_cuda_file(src):
                    assert CUDA_HOME is not None, (
                        "Not found CUDA runtime, "
                        "please use `export CUDA_HOME= XXX` to specify it."
                    )
                    nvcc_cmd = os.path.join(CUDA_HOME, 'bin', 'nvcc')
                    if isinstance(self.cflags, dict):
                        cflags = self.cflags['nvcc']
                    elif isinstance(self.cflags, list):
                        cflags = self.cflags
                    else:
                        cflags = []

                    cflags = prepare_win_cudaflags(cflags) + ['--use-local-env']
                    # forward every MSVC flag to the host compiler via nvcc
                    for flag in MSVC_COMPILE_FLAGS:
                        cflags = ['-Xcompiler', flag] + cflags
                    cmd = (
                        [nvcc_cmd, '-c', src, '-o', obj] + include_list + cflags
                    )
                elif isinstance(self.cflags, dict):
                    cflags = MSVC_COMPILE_FLAGS + self.cflags['cxx']
                    cmd += cflags
                elif isinstance(self.cflags, list):
                    cflags = MSVC_COMPILE_FLAGS + self.cflags
                    cmd += cflags
                # Append this macro only when jointly compiling .cc with .cu
                if not is_cuda_file(src) and self.contain_cuda_file:
                    cmd.append('-DPADDLE_WITH_CUDA')

                return original_spawn(cmd)

            try:
                self.compiler.spawn = win_custom_spawn
                return original_compile(
                    sources,
                    output_dir,
                    macros,
                    include_dirs,
                    debug,
                    extra_preargs,
                    extra_postargs,
                    depends,
                )
            finally:
                self.compiler.spawn = original_spawn

        def object_filenames_with_cuda(original_func, build_directory):
            """
            Decorated the function to add customized naming mechanism.
            Originally, both .cc/.cu will have .o object output that will
            bring file override problem. Use .cu.o as CUDA object suffix.

            The wrapper puts ``original_func`` back on the compiler after its
            first call, so it only affects a single invocation.
            """

            def wrapper(source_filenames, strip_dir=0, output_dir=''):
                try:
                    objects = original_func(
                        source_filenames, strip_dir, output_dir
                    )
                    for i, source in enumerate(source_filenames):
                        # modify xx.o -> xx.cu.o/xx.cu.obj
                        if is_cuda_file(source):
                            old_obj = objects[i]
                            if self.compiler.compiler_type == 'msvc':
                                objects[i] = old_obj[:-3] + 'cu.obj'
                            else:
                                objects[i] = old_obj[:-1] + 'cu.o'
                    # if user set build_directory, output objects there.
                    if build_directory is not None:
                        objects = [
                            os.path.join(build_directory, os.path.basename(obj))
                            for obj in objects
                        ]
                    # ensure to use abspath
                    objects = [os.path.abspath(obj) for obj in objects]
                finally:
                    self.compiler.object_filenames = original_func

                return objects

            return wrapper

        # customized compile process
        if self.compiler.compiler_type == 'msvc':
            self.compiler.compile = win_custom_single_compiler
        else:
            self.compiler._compile = unix_custom_single_compiler

        self.compiler.object_filenames = object_filenames_with_cuda(
            self.compiler.object_filenames, self.build_lib
        )
        # Must run before compiling: it sets `self.contain_cuda_file`, which
        # the compile hooks above read.
        self._record_op_info()

        print("Compiling user custom op, it will cost a few seconds.....")
        build_ext.build_extensions(self)

        # Reset runtime library path on MacOS platform
        so_path = self.get_ext_fullpath(self.extensions[0]._full_name)
        _reset_so_rpath(so_path)

    def get_ext_filename(self, fullname):
        """
        Return the file name of the built shared library for ``fullname``.

        The Python ABI tag is removed when ``no_python_abi_suffix`` is set,
        and on MacOS the last suffix is replaced by ``dylib``.
        """
        # for example: customized_extension.cpython-37m-x86_64-linux-gnu.so
        ext_name = super().get_ext_filename(fullname)
        split_str = '.'
        name_items = ext_name.split(split_str)
        if self.no_python_abi_suffix:
            assert (
                len(name_items) > 2
            ), "Expected len(name_items) > 2, but received {}".format(
                len(name_items)
            )
            # drop the second-to-last item, e.g. `cpython-37m-x86_64-linux-gnu`
            name_items.pop(-2)
            ext_name = split_str.join(name_items)

        # customized_extension.dylib
        if OS_NAME.startswith('darwin'):
            name_items[-1] = 'dylib'
            ext_name = split_str.join(name_items)
        return ext_name

    def _valid_clang_compiler(self):
        """
        Make sure to use Clang as compiler on Mac platform
        """
        compiler_infos = ['clang'] + CLANG_COMPILE_FLAGS
        linker_infos = ['clang'] + CLANG_LINK_FLAGS
        self.compiler.set_executables(
            compiler=compiler_infos,
            compiler_so=compiler_infos,
            compiler_cxx=['clang'],
            linker_exe=['clang'],
            linker_so=linker_infos,
        )

    def _check_abi(self):
        """
        Check ABI Compatibility.

        Raises:
            UserWarning: On Windows, if the VC environment is activated but
                ``DISTUTILS_USE_SDK`` is not set.
        """
        if hasattr(self.compiler, 'compiler_cxx'):
            compiler = self.compiler.compiler_cxx[0]
        elif IS_WINDOWS:
            compiler = os.environ.get('CXX', 'cl')
        else:
            compiler = os.environ.get('CXX', 'c++')

        check_abi_compatibility(compiler)
        # Warn user if VC env is activated but `DISTUTILS_USE_SDK` is not set.
        if (
            IS_WINDOWS
            and 'VSCMD_ARG_TGT_ARCH' in os.environ
            and 'DISTUTILS_USE_SDK' not in os.environ
        ):
            # Note the trailing spaces: adjacent literals are joined verbatim.
            msg = (
                'It seems that the VC environment is activated but DISTUTILS_USE_SDK is not set. '
                'This may lead to multiple activations of the VC env. '
                'Please run `set DISTUTILS_USE_SDK=1` and try again.'
            )
            raise UserWarning(msg)

    def _record_op_info(self):
        """
        Record custom op information.

        Registers every op parsed from the Extensions' sources with
        ``CustomOpInfo`` under the single output shared library, and sets
        ``self.contain_cuda_file`` if any source is a CUDA file.
        """
        # parse shared library abs path
        outputs = self.get_outputs()
        assert len(outputs) == 1
        # multi operators built into same one .so file
        so_path = os.path.abspath(outputs[0])
        so_name = os.path.basename(so_path)

        for extension in self.extensions:
            sources = [os.path.abspath(s) for s in extension.sources]
            if not self.contain_cuda_file:
                self.contain_cuda_file = any(is_cuda_file(s) for s in sources)
            op_names = parse_op_name_from(sources)

            for op_name in op_names:
                CustomOpInfo.instance().add(
                    op_name, so_name=so_name, so_path=so_path
                )


class EasyInstallCommand(easy_install):
    """
    Extend easy_install Command to control the behavior of naming shared library
    file.

    NOTE(Aurelius84): This is a hook subclass inherited Command used to rename shared
                    library file after extracting egg-info into site-packages.
    """

    @staticmethod
    def _should_rename(ext):
        """Return True if ``ext`` is the shared-library suffix renamed on this platform."""
        return (
            (OS_NAME.startswith('linux') and ext == '.so')
            or (OS_NAME.startswith('darwin') and ext == '.dylib')
            or (IS_WINDOWS and ext == '.pyd')
        )

    def run(self, *args, **kwargs):
        """
        Run easy_install, then rename each installed shared library in
        ``self.outputs`` from ``<name><ext>`` to ``<name>_pd_<ext>``.

        If the ``_pd_`` target already exists, the source file is left where
        it is.

        Raises:
            AssertionError: if the renamed file does not exist afterwards.
        """
        super().run(*args, **kwargs)
        # NOTE: To avoid failing import .so file instead of
        # python file because they have same name, we rename
        # .so shared library to another name.
        for egg_file in self.outputs:
            filename, ext = os.path.splitext(egg_file)
            if not self._should_rename(ext):
                continue

            new_so_path = filename + "_pd_" + ext
            if not os.path.exists(new_so_path):
                os.rename(egg_file, new_so_path)
            assert os.path.exists(
                new_so_path
            ), "Failed to rename {} to {}".format(egg_file, new_so_path)


763
class BuildCommand(build):
    """
    Extend the ``build`` Command so the caller can choose the ``build_base``
    root directory.

    NOTE(Aurelius84): This is a hook subclass of Command used to specify a
                      customized build_base directory.
    """

    @classmethod
    def with_options(cls, **options):
        """
        Return a subclass of ``cls`` whose constructor always receives
        ``options`` as keyword arguments.

        Values in ``options`` take precedence over same-named keywords passed
        by the caller.
        """

        class cls_with_options(cls):
            def __init__(self, *args, **kwargs):
                cls.__init__(self, *args, **{**kwargs, **options})

        return cls_with_options

    def __init__(self, *args, **kwargs):
        # Must be assigned before super().__init__(): distutils' Command
        # constructor calls initialize_options(), which reads this attribute.
        self._specified_build_base = kwargs.get('build_base')
        super().__init__(*args, **kwargs)

    def initialize_options(self):
        """
        Reset options to their defaults, then re-apply any caller-specified
        ``build_base``.

        ``build_base`` is the root directory for all sub-commands' output
        directories, such as build_lib and build_temp. See
        `distutils.command.build` for details.
        """
        super().initialize_options()
        specified = self._specified_build_base
        if specified is None:
            return
        self.build_base = specified


800 801 802 803 804 805 806 807 808 809
def load(
    name,
    sources,
    extra_cxx_cflags=None,
    extra_cuda_cflags=None,
    extra_ldflags=None,
    extra_include_paths=None,
    build_directory=None,
    verbose=False,
):
    """
    An Interface to automatically compile C++/CUDA source files Just-In-Time
    and return callable python function as other Paddle layers API. It will
    append user defined custom operators in background while building models.

    It will perform compiling, linking, Python API generation and module loading
    processes under an individual subprocess. It does not require CMake or Ninja
    environment. On Linux platform, it requires GCC compiler whose version is
    greater than 5.4 and it should be soft linked to ``/usr/bin/cc`` . On Windows
    platform, it requires Visual Studio whose version is greater than 2017.
    On MacOS, clang++ is required. In addition, if compiling Operators supporting
    GPU device, please make sure ``nvcc`` compiler is installed in local environment.

    Moreover, `ABI compatibility <https://gcc.gnu.org/onlinedocs/libstdc++/manual/abi.html>`_
    will be checked to ensure that compiler version from ``cc(Linux)`` , ``cl.exe(Windows)``
    on local machine is compatible with pre-installed Paddle whl in python site-packages.

    For Linux, GCC version will be checked. For example if Paddle with CUDA 10.1 is built with GCC 8.2,
    then the version of user's local machine should satisfy GCC >= 8.2.
    For Windows, Visual Studio version will be checked, and it should be greater than or equal to that of
    PaddlePaddle (Visual Studio 2017).
    If the above conditions are not met, the corresponding warning will be printed, and a fatal error may
    occur because of ABI compatibility.

    Compared with ``setup`` interface, it doesn't need an extra ``setup.py`` or to execute the
    ``python setup.py install`` command. The interface contains all compiling and installing
    process underground.

    Note:

        1. Currently we support Linux, MacOS and Windows platform.
        2. On Linux platform, we recommend to use GCC 8.2 as soft linking candidate of ``/usr/bin/cc`` .
           Then, use ``which cc`` to ensure location of ``cc`` and using ``cc --version`` to ensure linking
           GCC version.
        3. On Windows platform, we recommend to install ``Visual Studio`` (>=2017).


    **A simple example:**

    .. code-block:: text

        import paddle
        from paddle.utils.cpp_extension import load

        custom_op_module = load(
            name="op_shared_library_name",               # name of shared library
            sources=['relu_op.cc', 'relu_op.cu'],        # source files of customized op
            extra_cxx_cflags=['-g', '-w'],               # optional, specify extra flags to compile .cc/.cpp file
            extra_cuda_cflags=['-O2'],                   # optional, specify extra flags to compile .cu file
            verbose=True                                 # optional, specify to output log information
        )

        x = paddle.randn([4, 10], dtype='float32')
        out = custom_op_module.relu(x)


    Args:
        name(str): Specify the name of generated shared library file name, not including ``.so`` and ``.dll`` suffix.
        sources(str|list[str]): Specify source files name of customized operators. Supporting ``.cc`` , ``.cpp`` for CPP file
                            and ``.cu`` for CUDA file. A single path string is treated as a one-element list.
        extra_cxx_cflags(list[str], optional): Specify additional flags used to compile CPP files. By default
                               all basic and framework related flags have been included. Default is None.
        extra_cuda_cflags(list[str], optional): Specify additional flags used to compile CUDA files. By default
                               all basic and framework related flags have been included.
                               See `Cuda Compiler Driver NVCC <https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html>`_
                               for details. Default is None.
        extra_ldflags(list[str], optional): Specify additional flags used to link shared library. See
                                `GCC Link Options <https://gcc.gnu.org/onlinedocs/gcc/Link-Options.html>`_ for details.
                                Default is None.
        extra_include_paths(list[str], optional): Specify additional include path used to search header files. By default
                                all basic headers are included implicitly from ``site-package/paddle/include`` .
                                Default is None.
        build_directory(str, optional): Specify root directory path to put shared library file. If set None,
                            it will use ``PADDLE_EXTENSION_DIR`` from os.environ. Use
                            ``paddle.utils.cpp_extension.get_build_directory()`` to see the location. Default is None.
        verbose(bool, optional): whether to verbose compiled log information. Default is False

    Returns:
        Module: A callable python module contains all CustomOp Layer APIs.

    Raises:
        AssertionError: if ``extra_cxx_cflags`` or ``extra_cuda_cflags`` is given
            but is not a list.

    """
    # Accept a single source path as well as a list of paths. Iterating a bare
    # string would otherwise produce one bogus "path" per character.
    if isinstance(sources, str):
        sources = [sources]

    if build_directory is None:
        build_directory = get_build_directory(verbose)

    # ensure to use abs path
    build_directory = os.path.abspath(build_directory)

    log_v("build_directory: {}".format(build_directory), verbose)

    file_path = os.path.join(build_directory, "{}_setup.py".format(name))
    sources = [os.path.abspath(source) for source in sources]

    if extra_cxx_cflags is None:
        extra_cxx_cflags = []
    if extra_cuda_cflags is None:
        extra_cuda_cflags = []
    assert isinstance(
        extra_cxx_cflags, list
    ), "Required type(extra_cxx_cflags) == list[str], but received {}".format(
        extra_cxx_cflags
    )
    assert isinstance(
        extra_cuda_cflags, list
    ), "Required type(extra_cuda_cflags) == list[str], but received {}".format(
        extra_cuda_cflags
    )

    log_v(
        "additional extra_cxx_cflags: [{}], extra_cuda_cflags: [{}]".format(
            ' '.join(extra_cxx_cflags), ' '.join(extra_cuda_cflags)
        ),
        verbose,
    )

    # write setup.py file and compile it
    build_base_dir = os.path.join(build_directory, name)

    _write_setup_file(
        name,
        sources,
        file_path,
        build_base_dir,
        extra_include_paths,
        extra_cxx_cflags,
        extra_cuda_cflags,
        extra_ldflags,
        verbose,
    )
    _jit_compile(file_path, verbose)

    # import as callable python api
    custom_op_api = _import_module_from_library(name, build_base_dir, verbose)

    return custom_op_api