未验证 提交 702fc894 编写于 作者: H HongyuJia 提交者: GitHub

[Custom Op] Support extra_library_paths of load (#51249)

* [Custom Op] Support extra_library_paths of load

* change API order

* change unittest parameter order
上级 6bfb8152
...@@ -17,7 +17,13 @@ import unittest ...@@ -17,7 +17,13 @@ import unittest
import numpy as np import numpy as np
from test_custom_relu_op_setup import custom_relu_dynamic, custom_relu_static from test_custom_relu_op_setup import custom_relu_dynamic, custom_relu_static
from utils import IS_MAC, extra_cc_args, extra_nvcc_args, paddle_includes from utils import (
IS_MAC,
extra_cc_args,
extra_nvcc_args,
paddle_includes,
paddle_libraries,
)
import paddle import paddle
from paddle.utils.cpp_extension import get_build_directory, load from paddle.utils.cpp_extension import get_build_directory, load
...@@ -44,6 +50,7 @@ custom_module = load( ...@@ -44,6 +50,7 @@ custom_module = load(
name='custom_relu_module_jit', name='custom_relu_module_jit',
sources=sources, sources=sources,
extra_include_paths=paddle_includes, # add for Coverage CI extra_include_paths=paddle_includes, # add for Coverage CI
extra_library_paths=paddle_libraries,
extra_cxx_cflags=extra_cc_args, # test for cc flags extra_cxx_cflags=extra_cc_args, # test for cc flags
extra_cuda_cflags=extra_nvcc_args, # test for nvcc flags extra_cuda_cflags=extra_nvcc_args, # test for nvcc flags
verbose=True, verbose=True,
......
...@@ -25,6 +25,7 @@ IS_MAC = sys.platform.startswith('darwin') ...@@ -25,6 +25,7 @@ IS_MAC = sys.platform.startswith('darwin')
# paddle include directory. Because the following path is generated after installing # paddle include directory. Because the following path is generated after installing
# PaddlePaddle whl. So here we specific `include_dirs` to avoid errors in CI. # PaddlePaddle whl. So here we specific `include_dirs` to avoid errors in CI.
paddle_includes = [] paddle_includes = []
paddle_libraries = []
for site_packages_path in getsitepackages(): for site_packages_path in getsitepackages():
paddle_includes.append( paddle_includes.append(
os.path.join(site_packages_path, 'paddle', 'include') os.path.join(site_packages_path, 'paddle', 'include')
...@@ -32,6 +33,7 @@ for site_packages_path in getsitepackages(): ...@@ -32,6 +33,7 @@ for site_packages_path in getsitepackages():
paddle_includes.append( paddle_includes.append(
os.path.join(site_packages_path, 'paddle', 'include', 'third_party') os.path.join(site_packages_path, 'paddle', 'include', 'third_party')
) )
paddle_libraries.append(os.path.join(site_packages_path, 'paddle', 'libs'))
# Test for extra compile args # Test for extra compile args
extra_cc_args = ['-w', '-g'] if not IS_WINDOWS else ['/w'] extra_cc_args = ['-w', '-g'] if not IS_WINDOWS else ['/w']
......
...@@ -804,6 +804,7 @@ def load( ...@@ -804,6 +804,7 @@ def load(
extra_cuda_cflags=None, extra_cuda_cflags=None,
extra_ldflags=None, extra_ldflags=None,
extra_include_paths=None, extra_include_paths=None,
extra_library_paths=None,
build_directory=None, build_directory=None,
verbose=False, verbose=False,
): ):
...@@ -879,10 +880,13 @@ def load( ...@@ -879,10 +880,13 @@ def load(
extra_include_paths(list[str], optional): Specify additional include path used to search header files. By default extra_include_paths(list[str], optional): Specify additional include path used to search header files. By default
all basic headers are included implicitly from ``site-package/paddle/include`` . all basic headers are included implicitly from ``site-package/paddle/include`` .
Default is None. Default is None.
extra_library_paths(list[str], optional): Specify additional library path used to search library files. By default
all basic libraries are included implicitly from ``site-packages/paddle/libs`` .
Default is None.
build_directory(str, optional): Specify root directory path to put shared library file. If set None, build_directory(str, optional): Specify root directory path to put shared library file. If set None,
it will use ``PADDLE_EXTENSION_DIR`` from os.environ. Use it will use ``PADDLE_EXTENSION_DIR`` from os.environ. Use
``paddle.utils.cpp_extension.get_build_directory()`` to see the location. Default is None. ``paddle.utils.cpp_extension.get_build_directory()`` to see the location. Default is None.
verbose(bool, optional): whether to verbose compiled log information. Default is False verbose(bool, optional): whether to verbose compiled log information. Default is False.
Returns: Returns:
Module: A callable python module contains all CustomOp Layer APIs. Module: A callable python module contains all CustomOp Layer APIs.
...@@ -931,6 +935,7 @@ def load( ...@@ -931,6 +935,7 @@ def load(
file_path, file_path,
build_base_dir, build_base_dir,
extra_include_paths, extra_include_paths,
extra_library_paths,
extra_cxx_cflags, extra_cxx_cflags,
extra_cuda_cflags, extra_cuda_cflags,
extra_ldflags, extra_ldflags,
......
...@@ -1160,6 +1160,7 @@ def _write_setup_file( ...@@ -1160,6 +1160,7 @@ def _write_setup_file(
file_path, file_path,
build_dir, build_dir,
include_dirs, include_dirs,
library_dirs,
extra_cxx_cflags, extra_cxx_cflags,
extra_cuda_cflags, extra_cuda_cflags,
link_args, link_args,
...@@ -1181,6 +1182,7 @@ def _write_setup_file( ...@@ -1181,6 +1182,7 @@ def _write_setup_file(
{prefix}Extension( {prefix}Extension(
sources={sources}, sources={sources},
include_dirs={include_dirs}, include_dirs={include_dirs},
library_dirs={library_dirs},
extra_compile_args={{'cxx':{extra_cxx_cflags}, 'nvcc':{extra_cuda_cflags}}}, extra_compile_args={{'cxx':{extra_cxx_cflags}, 'nvcc':{extra_cuda_cflags}}},
extra_link_args={extra_link_args})], extra_link_args={extra_link_args})],
cmdclass={{"build_ext" : BuildExtension.with_options( cmdclass={{"build_ext" : BuildExtension.with_options(
...@@ -1199,6 +1201,7 @@ def _write_setup_file( ...@@ -1199,6 +1201,7 @@ def _write_setup_file(
prefix='CUDA' if with_cuda else 'Cpp', prefix='CUDA' if with_cuda else 'Cpp',
sources=list2str(sources), sources=list2str(sources),
include_dirs=list2str(include_dirs), include_dirs=list2str(include_dirs),
library_dirs=list2str(library_dirs),
extra_cxx_cflags=list2str(extra_cxx_cflags), extra_cxx_cflags=list2str(extra_cxx_cflags),
extra_cuda_cflags=list2str(extra_cuda_cflags), extra_cuda_cflags=list2str(extra_cuda_cflags),
extra_link_args=list2str(link_args), extra_link_args=list2str(link_args),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册