未验证 提交 2a0563fa 编写于 作者: H HongyuJia 提交者: GitHub

[Cpp Extension] Support Cpp Extension (#49893)

* update include directory

* fully support C++ extension, pass unittest

* fix include directory

* support both extension and operator in one file

* polish testcase

* add jit unittest

* update third_party.cmake, pass CI test

* fix cmake

* fix setup

* fix inference, fix unittest precision

* fix unittest precision

* fix inference_lib cmake

* try fix setup, try fix inference_lib

* try fix inference_lib pybind

* fix mix_op_extension, fix inference_lib

* fix mix_op_extension, fix inference_lib

* change cmake

* change cmake

* add compile flags

* add Python.h headerfile

* add test_custom_plugin_creater cmake

* comment compile flag

* pass all CI

* pass all CI

* comment compile flag

* try solve test_custom_plugin_creater link error

* try solve test_custom_plugin_creater link error

* polish codes

* remove windows compile flag

* remove python_include_path

* update pybind11, 2.4.3->2.6.0

* update pybind11, 2.6.0->2.10.0

* update pybind11, 2.10.0->2.6.0b1

* update pybind11, 2.6.0b1->2.6.0, start fix unittest

* fix pybind11 2.6.0 VarBase print error

* fix pybind11 2.6.0 VarBase print error

* handle PADDLE_ON_INFERENCE

* modify according to reviewer

* fix cmake

* cmake decouple pybind_util when not ON_INFER

* cmake decouple pybind_util when not ON_INFER

* remove copy of inference_lib.cmake

* change pybind.cc headerfile fluid->phi
上级 8c844356
......@@ -301,7 +301,7 @@ if(TARGET extern_protobuf)
list(APPEND third_party_deps extern_protobuf)
endif()
if(WITH_PYTHON)
if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
include(external/python) # find python and python_module
include(external/pybind11) # download pybind11
list(APPEND third_party_deps extern_pybind)
......
......@@ -16,3 +16,7 @@ limitations under the License. */
// All paddle apis in C++ frontend
#include "paddle/phi/api/all.h"
// Python bindings for the C++ frontend
#ifndef PADDLE_ON_INFERENCE
#include "paddle/utils/pybind.h"
#endif
......@@ -502,6 +502,7 @@ if(WITH_PYTHON)
list(APPEND PYBIND_DEPS eager_tensor_operants)
list(APPEND PYBIND_DEPS static_tensor_operants)
list(APPEND PYBIND_DEPS phi_tensor_operants)
list(APPEND PYBIND_DEPS pybind_util)
endif()
# On Linux, cc_library(paddle SHARED ..) will generate the libpaddle.so,
......
......@@ -46,8 +46,8 @@ namespace pybind {
namespace py = ::pybind11;
PyTypeObject* p_tensor_type;
PyTypeObject* p_string_tensor_type; // For StringTensor
extern PyTypeObject* p_tensor_type;
extern PyTypeObject* p_string_tensor_type; // For StringTensor
extern PyTypeObject* g_vartype_pytype;
extern PyTypeObject* g_framework_tensor_pytype;
......
......@@ -19,18 +19,13 @@ limitations under the License. */
#include "paddle/fluid/eager/hooks.h"
#include "paddle/fluid/eager/pylayer/py_layer_node.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/pybind.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace paddle {
namespace pybind {
typedef struct {
PyObject_HEAD paddle::experimental::Tensor tensor;
// Weak references
PyObject* weakrefs;
} TensorObject;
typedef struct {
PyObject_HEAD PyObject* container;
bool container_be_packed;
......
......@@ -211,24 +211,6 @@ std::string CastPyArg2AttrString(PyObject* obj, ssize_t arg_pos) {
}
}
bool PyCheckTensor(PyObject* obj) {
return PyObject_IsInstance(obj, reinterpret_cast<PyObject*>(p_tensor_type));
}
paddle::experimental::Tensor CastPyArg2Tensor(PyObject* obj, ssize_t arg_pos) {
if (PyObject_IsInstance(obj, reinterpret_cast<PyObject*>(p_tensor_type)) ||
PyObject_IsInstance(obj,
reinterpret_cast<PyObject*>(p_string_tensor_type))) {
return reinterpret_cast<TensorObject*>(obj)->tensor;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"argument (position %d) must be "
"Tensor, but got %s",
arg_pos + 1,
reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name));
}
}
std::shared_ptr<imperative::VarBase> CastPyArg2VarBase(PyObject* obj,
ssize_t arg_pos) {
return py::cast<std::shared_ptr<imperative::VarBase>>(obj);
......@@ -662,30 +644,6 @@ PyObject* ToPyObject(const std::string& value) {
return PyUnicode_FromString(value.c_str());
}
PyObject* ToPyObject(const paddle::experimental::Tensor& value,
bool return_py_none_if_not_initialize) {
if (return_py_none_if_not_initialize && !value.initialized()) {
RETURN_PY_NONE
}
PyObject* obj = nullptr;
if (value.initialized() && value.is_string_tensor()) {
// In order to return the core.eager.StringTensor, there is need
// to use p_string_tensor_type to create a python obj.
obj = p_string_tensor_type->tp_alloc(p_string_tensor_type, 0);
} else {
obj = p_tensor_type->tp_alloc(p_tensor_type, 0);
}
if (obj) {
auto v = reinterpret_cast<TensorObject*>(obj);
new (&(v->tensor)) paddle::experimental::Tensor();
v->tensor = value;
} else {
PADDLE_THROW(platform::errors::Fatal(
"tp_alloc return null, can not new a PyObject."));
}
return obj;
}
PyObject* ToPyObject(const paddle::experimental::Tensor& value,
PyObject* args,
const std::map<ssize_t, ssize_t>& inplace_var_idx_map) {
......
......@@ -34,6 +34,7 @@ typedef SSIZE_T ssize_t;
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/utils/pybind.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace paddle {
......@@ -44,14 +45,9 @@ class Scope;
namespace pybind {
namespace py = ::pybind11;
#define RETURN_PY_NONE \
Py_INCREF(Py_None); \
return Py_None;
int TensorDtype2NumpyDtype(phi::DataType dtype);
bool PyCheckTensor(PyObject* obj);
bool PyObject_CheckLongOrConvertToLong(PyObject** obj);
bool PyObject_CheckFloatOrConvertToFloat(PyObject** obj);
bool PyObject_CheckStr(PyObject* obj);
......@@ -63,7 +59,6 @@ float CastPyArg2AttrFloat(PyObject* obj, ssize_t arg_pos);
std::string CastPyArg2AttrString(PyObject* obj, ssize_t arg_pos);
paddle::CustomOpKernelContext CastPyArg2CustomOpKernelContext(PyObject* obj,
ssize_t arg_pos);
paddle::experimental::Tensor CastPyArg2Tensor(PyObject* obj, ssize_t arg_pos);
std::shared_ptr<imperative::VarBase> CastPyArg2VarBase(PyObject* obj,
ssize_t arg_pos);
std::vector<paddle::experimental::Tensor> CastPyArg2VectorOfTensor(
......@@ -94,8 +89,6 @@ PyObject* ToPyObject(float value);
PyObject* ToPyObject(double value);
PyObject* ToPyObject(const char* value);
PyObject* ToPyObject(const std::string& value);
PyObject* ToPyObject(const paddle::experimental::Tensor& value,
bool return_py_none_if_not_initialize = false);
PyObject* ToPyObject(const paddle::experimental::Tensor& value,
PyObject* args,
const std::map<ssize_t, ssize_t>& inplace_var_idx_map);
......
......@@ -12,3 +12,10 @@ cc_test(
variant_test
SRCS variant_test.cc
DEPS gtest)
if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
cc_library(
pybind_util
SRCS pybind.cc
DEPS phi_tensor_raw)
endif()
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/utils/pybind.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace pybind {
PyTypeObject* p_tensor_type;
PyTypeObject* p_string_tensor_type;
bool PyCheckTensor(PyObject* obj) {
if (!p_tensor_type) {
return false;
}
return PyObject_IsInstance(obj, reinterpret_cast<PyObject*>(p_tensor_type));
}
paddle::experimental::Tensor CastPyArg2Tensor(PyObject* obj, ssize_t arg_pos) {
if (PyObject_IsInstance(obj, reinterpret_cast<PyObject*>(p_tensor_type)) ||
PyObject_IsInstance(obj,
reinterpret_cast<PyObject*>(p_string_tensor_type))) {
return reinterpret_cast<TensorObject*>(obj)->tensor;
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"argument (position %d) must be "
"Tensor, but got %s",
arg_pos + 1,
reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name));
}
}
PyObject* ToPyObject(const paddle::experimental::Tensor& value,
bool return_py_none_if_not_initialize) {
if (return_py_none_if_not_initialize && !value.initialized()) {
RETURN_PY_NONE
}
PyObject* obj = nullptr;
if (value.initialized() && value.is_string_tensor()) {
// In order to return the core.eager.StringTensor, there is need
// to use p_string_tensor_type to create a python obj.
obj = p_string_tensor_type->tp_alloc(p_string_tensor_type, 0);
} else {
obj = p_tensor_type->tp_alloc(p_tensor_type, 0);
}
if (obj) {
auto v = reinterpret_cast<TensorObject*>(obj);
new (&(v->tensor)) paddle::experimental::Tensor();
v->tensor = value;
} else {
PADDLE_THROW(
phi::errors::Fatal("tp_alloc return null, can not new a PyObject."));
}
return obj;
}
} // namespace pybind
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/api/include/tensor.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
extern PyTypeObject* p_tensor_type;
typedef struct {
PyObject_HEAD paddle::experimental::Tensor tensor;
// Weak references
PyObject* weakrefs;
} TensorObject;
#define RETURN_PY_NONE \
Py_INCREF(Py_None); \
return Py_None;
// Internal use only, to expose the Tensor type to Python.
bool PyCheckTensor(PyObject* obj);
// Internal use only, to expose the Tensor type to Python.
paddle::experimental::Tensor CastPyArg2Tensor(PyObject* obj, ssize_t arg_pos);
// Internal use only, to expose the Tensor type to Python.
PyObject* ToPyObject(const paddle::experimental::Tensor& value,
bool return_py_none_if_not_initialize = false);
} // namespace pybind
} // namespace paddle
namespace pybind11 {
namespace detail {
template <>
struct type_caster<paddle::experimental::Tensor> {
public:
PYBIND11_TYPE_CASTER(paddle::experimental::Tensor,
_("paddle::experimental::Tensor"));
bool load(handle src, bool) {
PyObject* obj = src.ptr();
if (paddle::pybind::PyCheckTensor(obj)) {
value = paddle::pybind::CastPyArg2Tensor(obj, 0);
return true;
}
return false;
}
static handle cast(const paddle::experimental::Tensor& src,
return_value_policy /* policy */,
handle /* parent */) {
return handle(paddle::pybind::ToPyObject(src));
}
};
} // namespace detail
} // namespace pybind11
......@@ -66,5 +66,6 @@ env_dict={
'ORIGIN':'@ORIGIN@',
'WIN32':'@WIN32@',
'JIT_RELEASE_WHL':'@JIT_RELEASE_WHL@',
'WITH_PSLIB':'@WITH_PSLIB@'
'WITH_PSLIB':'@WITH_PSLIB@',
'PYBIND_INCLUDE_DIR':'@PYBIND_INCLUDE_DIR@'
}
......@@ -10,6 +10,7 @@ endforeach()
add_subdirectory(unittests)
add_subdirectory(book)
add_subdirectory(cpp_extension)
add_subdirectory(custom_op)
add_subdirectory(custom_kernel)
add_subdirectory(custom_runtime)
py_test(test_cpp_extension_setup SRCS test_cpp_extension_setup.py)
py_test(test_cpp_extension_jit SRCS test_cpp_extension_jit.py)
set_tests_properties(test_cpp_extension_setup PROPERTIES TIMEOUT 120)
set_tests_properties(test_cpp_extension_jit PROPERTIES TIMEOUT 120)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from site import getsitepackages
from paddle.utils.cpp_extension import CppExtension, setup
paddle_includes = []
for site_packages_path in getsitepackages():
paddle_includes.append(
os.path.join(site_packages_path, 'paddle', 'include')
)
paddle_includes.append(
os.path.join(site_packages_path, 'paddle', 'include', 'third_party')
)
setup(
name='custom_cpp_extension',
ext_modules=CppExtension(
sources=["custom_add.cc", "custom_sub.cc"],
include_dirs=paddle_includes
+ [os.path.dirname(os.path.abspath(__file__))],
extra_compile_args={'cc': ['-w', '-g']},
verbose=True,
),
)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <vector>
#include "custom_power.h" // NOLINT
#include "paddle/extension.h"
paddle::Tensor custom_sub(paddle::Tensor x, paddle::Tensor y);
paddle::Tensor custom_add(const paddle::Tensor& x, const paddle::Tensor& y) {
return paddle::add(paddle::exp(x), paddle::exp(y));
}
PYBIND11_MODULE(custom_cpp_extension, m) {
m.def("custom_add", &custom_add, "exp(x) + exp(y)");
m.def("custom_sub", &custom_sub, "exp(x) - exp(y)");
py::class_<Power>(m, "Power")
.def(py::init<int, int>())
.def(py::init<paddle::Tensor>())
.def("forward", &Power::forward)
.def("get", &Power::get);
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/extension.h"
struct Power {
Power(int A, int B) {
tensor_ = paddle::ones({A, B}, phi::DataType::FLOAT32, phi::CPUPlace());
}
explicit Power(paddle::Tensor x) { tensor_ = x; }
paddle::Tensor forward() { return paddle::experimental::pow(tensor_, 2); }
paddle::Tensor get() const { return tensor_; }
private:
paddle::Tensor tensor_;
};
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/extension.h"
paddle::Tensor custom_sub(paddle::Tensor x, paddle::Tensor y) {
return paddle::subtract(paddle::exp(x), paddle::exp(y));
}
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import unittest
from site import getsitepackages
import numpy as np
import paddle
from paddle.utils.cpp_extension import load
if os.name == 'nt' or sys.platform.startswith('darwin'):
# only support Linux now
exit()
# Compile and load cpp extension Just-In-Time.
sources = ["custom_add.cc", "custom_sub.cc"]
paddle_includes = []
for site_packages_path in getsitepackages():
paddle_includes.append(
os.path.join(site_packages_path, 'paddle', 'include')
)
paddle_includes.append(
os.path.join(site_packages_path, 'paddle', 'include', 'third_party')
)
# include "custom_power.h"
paddle_includes.append(os.path.dirname(os.path.abspath(__file__)))
custom_cpp_extension = load(
name='custom_cpp_extension',
sources=sources,
extra_include_paths=paddle_includes, # add for Coverage CI
extra_cxx_cflags=['-w', '-g'],
verbose=True,
)
class TestCppExtensionJITInstall(unittest.TestCase):
"""
Tests setup install cpp extensions.
"""
def setUp(self):
# config seed
SEED = 2021
paddle.seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
self.dtypes = ['float32', 'float64']
def tearDown(self):
pass
def test_cpp_extension(self):
self._test_extension_function()
self._test_extension_class()
def _test_extension_function(self):
for dtype in self.dtypes:
np_x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
x = paddle.to_tensor(np_x, dtype=dtype)
np_y = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
y = paddle.to_tensor(np_y, dtype=dtype)
out = custom_cpp_extension.custom_add(x, y)
target_out = np.exp(np_x) + np.exp(np_y)
np.testing.assert_allclose(out.numpy(), target_out, atol=1e-5)
# Test we can call a method not defined in the main C++ file.
out = custom_cpp_extension.custom_sub(x, y)
target_out = np.exp(np_x) - np.exp(np_y)
np.testing.assert_allclose(out.numpy(), target_out, atol=1e-5)
def _test_extension_class(self):
for dtype in self.dtypes:
# Test we can use CppExtension class with C++ methods.
power = custom_cpp_extension.Power(3, 3)
self.assertEqual(power.get().sum(), 9)
self.assertEqual(power.forward().sum(), 9)
np_x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
x = paddle.to_tensor(np_x, dtype=dtype)
power = custom_cpp_extension.Power(x)
np.testing.assert_allclose(
power.get().sum().numpy(), np.sum(np_x), atol=1e-5
)
np.testing.assert_allclose(
power.forward().sum().numpy(),
np.sum(np.power(np_x, 2)),
atol=1e-5,
)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import site
import sys
import unittest
import numpy as np
import paddle
from paddle.utils.cpp_extension.extension_utils import run_cmd
class TestCppExtensionSetupInstall(unittest.TestCase):
"""
Tests setup install cpp extensions.
"""
def setUp(self):
cur_dir = os.path.dirname(os.path.abspath(__file__))
# compile, install the custom op egg into site-packages under background
if os.name == 'nt':
cmd = 'cd /d {} && python cpp_extension_setup.py install'.format(
cur_dir
)
else:
cmd = 'cd {} && {} cpp_extension_setup.py install'.format(
cur_dir, sys.executable
)
run_cmd(cmd)
# os.system(cmd)
# See: https://stackoverflow.com/questions/56974185/import-runtime-installed-module-using-pip-in-python-3
if os.name == 'nt':
site_dir = site.getsitepackages()[1]
else:
site_dir = site.getsitepackages()[0]
custom_egg_path = [
x for x in os.listdir(site_dir) if 'custom_cpp_extension' in x
]
assert len(custom_egg_path) == 1, "Matched egg number is %d." % len(
custom_egg_path
)
sys.path.append(os.path.join(site_dir, custom_egg_path[0]))
# config seed
SEED = 2021
paddle.seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
self.dtypes = ['float32', 'float64']
def tearDown(self):
pass
def test_cpp_extension(self):
self._test_extension_function()
self._test_extension_class()
def _test_extension_function(self):
import custom_cpp_extension
for dtype in self.dtypes:
np_x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
x = paddle.to_tensor(np_x, dtype=dtype)
np_y = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
y = paddle.to_tensor(np_y, dtype=dtype)
out = custom_cpp_extension.custom_add(x, y)
target_out = np.exp(np_x) + np.exp(np_y)
np.testing.assert_allclose(out.numpy(), target_out, atol=1e-5)
# Test we can call a method not defined in the main C++ file.
out = custom_cpp_extension.custom_sub(x, y)
target_out = np.exp(np_x) - np.exp(np_y)
np.testing.assert_allclose(out.numpy(), target_out, atol=1e-5)
def _test_extension_class(self):
import custom_cpp_extension
for dtype in self.dtypes:
# Test we can use CppExtension class with C++ methods.
power = custom_cpp_extension.Power(3, 3)
self.assertEqual(power.get().sum(), 9)
self.assertEqual(power.forward().sum(), 9)
np_x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
x = paddle.to_tensor(np_x, dtype=dtype)
power = custom_cpp_extension.Power(x)
np.testing.assert_allclose(
power.get().sum().numpy(), np.sum(np_x), atol=1e-5
)
np.testing.assert_allclose(
power.forward().sum().numpy(),
np.sum(np.power(np_x, 2)),
atol=1e-5,
)
if __name__ == '__main__':
if os.name == 'nt' or sys.platform.startswith('darwin'):
# only support Linux now
exit()
unittest.main()
......@@ -166,56 +166,63 @@ def custom_write_stub(resource, pyfile):
"""
_stub_template = textwrap.dedent(
"""
{custom_api}
import os
import sys
import types
import paddle
import importlib.util
cur_dir = os.path.dirname(os.path.abspath(__file__))
so_path = os.path.join(cur_dir, "{resource}")
def inject_ext_module(module_name, api_names):
if module_name in sys.modules:
return sys.modules[module_name]
new_module = types.ModuleType(module_name)
for api_name in api_names:
setattr(new_module, api_name, eval(api_name))
return new_module
def __bootstrap__():
assert os.path.exists(so_path)
if os.name == 'nt' or sys.platform.startswith('darwin'):
# Cpp Extension only support Linux now
mod = types.ModuleType(__name__)
else:
try:
spec = importlib.util.spec_from_file_location(__name__, so_path)
assert spec is not None
mod = importlib.util.module_from_spec(spec)
assert isinstance(spec.loader, importlib.abc.Loader)
spec.loader.exec_module(mod)
except ImportError:
print('using custom operator only')
mod = types.ModuleType(__name__)
# load custom op shared library with abs path
new_custom_ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(so_path)
m = inject_ext_module(__name__, new_custom_ops)
custom_ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(so_path)
for custom_ops in custom_ops:
setattr(mod, custom_ops, eval(custom_ops))
__bootstrap__()
{custom_api}
"""
).lstrip()
# Parse registering op information
_, op_info = CustomOpInfo.instance().last()
so_path = op_info.so_path
new_custom_ops = load_op_meta_info_and_register_op(so_path)
assert len(new_custom_ops) > 0, (
"Required at least one custom operators, but received len(custom_op) = %d"
% len(new_custom_ops)
)
# NOTE: To avoid importing .so file instead of python file because they have same name,
# we rename .so shared library to another name, see EasyInstallCommand.
filename, ext = os.path.splitext(resource)
resource = filename + "_pd_" + ext
api_content = []
if CustomOpInfo.instance().empty():
print("Received len(custom_op) = 0, using cpp extension only")
else:
# Parse registering op information
_, op_info = CustomOpInfo.instance().last()
so_path = op_info.so_path
new_custom_ops = load_op_meta_info_and_register_op(so_path)
for op_name in new_custom_ops:
api_content.append(_custom_api_content(op_name))
print(
"Received len(custom_op) = %d, using custom operator"
% len(new_custom_ops)
)
with open(pyfile, 'w') as f:
f.write(
......@@ -256,6 +263,11 @@ class CustomOpInfo:
assert len(self.op_info_map) > 0
return next(reversed(self.op_info_map.items()))
def empty(self):
if self.op_info_map:
return False
return True
VersionFields = collections.namedtuple(
'VersionFields',
......@@ -949,11 +961,33 @@ def _import_module_from_library(module_name, build_directory, verbose=False):
log_v('loading shared library from: {}'.format(ext_path), verbose)
op_names = load_op_meta_info_and_register_op(ext_path)
# generate Python api in ext_path
if os.name == 'nt' or sys.platform.startswith('darwin'):
# Cpp Extension only support Linux now
return _generate_python_module(
module_name, op_names, build_directory, verbose
)
try:
spec = importlib.util.spec_from_file_location(module_name, ext_path)
assert spec is not None
module = importlib.util.module_from_spec(spec)
assert isinstance(spec.loader, importlib.abc.Loader)
spec.loader.exec_module(module)
except ImportError:
log_v('using custom operator only')
return _generate_python_module(
module_name, op_names, build_directory, verbose
)
# generate Python api in ext_path
op_module = _generate_python_module(
module_name, op_names, build_directory, verbose
)
for op_name in op_names:
# Mix use of Cpp Extension and Custom Operator
setattr(module, op_name, getattr(op_module, op_name))
return module
def _generate_python_module(
module_name, op_names, build_directory, verbose=False
......
......@@ -725,6 +725,8 @@ if '${WITH_GPU}' == 'ON' or '${WITH_ROCM}' == 'ON':
# externalErrorMsg.pb for External Error message
headers += list(find_files('*.pb', '${externalError_INCLUDE_DIR}'))
headers += list(find_files('*.h', '${PYBIND_INCLUDE_DIR}', True)) # pybind headers
class InstallCommand(InstallCommandBase):
def finalize_options(self):
ret = InstallCommandBase.finalize_options(self)
......@@ -771,7 +773,7 @@ class InstallHeaders(Command):
else:
# third_party
install_dir = re.sub('${THIRD_PARTY_PATH}', 'third_party', header)
patterns = ['install/mkldnn/include']
patterns = ['install/mkldnn/include', 'pybind/src/extern_pybind/include']
for pattern in patterns:
install_dir = re.sub(pattern, '', install_dir)
install_dir = os.path.join(self.install_dir, os.path.dirname(install_dir))
......
......@@ -145,7 +145,10 @@ def get_header_install_dir(header):
install_dir = re.sub(
env_dict.get("THIRD_PARTY_PATH") + '/', 'third_party', header
)
patterns = ['install/mkldnn/include']
patterns = [
'install/mkldnn/include',
'pybind/src/extern_pybind/include',
]
for pattern in patterns:
install_dir = re.sub(pattern, '', install_dir)
return install_dir
......@@ -1203,6 +1206,9 @@ def get_headers():
headers += list(
find_files('*.pb', env_dict.get("externalError_INCLUDE_DIR"))
)
# pybind headers
headers += list(find_files('*.h', env_dict.get("PYBIND_INCLUDE_DIR"), True))
return headers
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册