未验证 提交 01e85182 编写于 作者: H HongyuJia 提交者: GitHub

[Cpp Extension] return Py_None if Tensor not initialized (#50745)

* [Cpp Extension] return Py_None if Tensor not initialized

* fix jit test
上级 b197e66c
...@@ -69,7 +69,8 @@ struct type_caster<paddle::experimental::Tensor> { ...@@ -69,7 +69,8 @@ struct type_caster<paddle::experimental::Tensor> {
static handle cast(const paddle::experimental::Tensor& src, static handle cast(const paddle::experimental::Tensor& src,
return_value_policy /* policy */, return_value_policy /* policy */,
handle /* parent */) { handle /* parent */) {
return handle(paddle::pybind::ToPyObject(src)); return handle(paddle::pybind::ToPyObject(
src, true /* return_py_none_if_not_initialize */));
} }
}; };
} // namespace detail } // namespace detail
......
...@@ -25,13 +25,14 @@ for site_packages_path in getsitepackages(): ...@@ -25,13 +25,14 @@ for site_packages_path in getsitepackages():
paddle_includes.append( paddle_includes.append(
os.path.join(site_packages_path, 'paddle', 'include', 'third_party') os.path.join(site_packages_path, 'paddle', 'include', 'third_party')
) )
# Add current dir, search custom_power.h
paddle_includes.append(os.path.dirname(os.path.abspath(__file__)))
setup( setup(
name='custom_cpp_extension', name='custom_cpp_extension',
ext_modules=CppExtension( ext_modules=CppExtension(
sources=["custom_add.cc", "custom_sub.cc"], sources=["custom_extension.cc", "custom_sub.cc"],
include_dirs=paddle_includes include_dirs=paddle_includes,
+ [os.path.dirname(os.path.abspath(__file__))],
extra_compile_args={'cc': ['-w', '-g']}, extra_compile_args={'cc': ['-w', '-g']},
verbose=True, verbose=True,
), ),
......
...@@ -24,9 +24,18 @@ paddle::Tensor custom_add(const paddle::Tensor& x, const paddle::Tensor& y) { ...@@ -24,9 +24,18 @@ paddle::Tensor custom_add(const paddle::Tensor& x, const paddle::Tensor& y) {
return paddle::add(paddle::exp(x), paddle::exp(y)); return paddle::add(paddle::exp(x), paddle::exp(y));
} }
paddle::Tensor nullable_tensor(bool return_none = false) {
paddle::Tensor t;
if (!return_none) {
t = paddle::ones({2, 2});
}
return t;
}
PYBIND11_MODULE(custom_cpp_extension, m) { PYBIND11_MODULE(custom_cpp_extension, m) {
m.def("custom_add", &custom_add, "exp(x) + exp(y)"); m.def("custom_add", &custom_add, "exp(x) + exp(y)");
m.def("custom_sub", &custom_sub, "exp(x) - exp(y)"); m.def("custom_sub", &custom_sub, "exp(x) - exp(y)");
m.def("nullable_tensor", &nullable_tensor, "returned Tensor might be None");
py::class_<Power>(m, "Power") py::class_<Power>(m, "Power")
.def(py::init<int, int>()) .def(py::init<int, int>())
......
...@@ -27,7 +27,7 @@ if os.name == 'nt' or sys.platform.startswith('darwin'): ...@@ -27,7 +27,7 @@ if os.name == 'nt' or sys.platform.startswith('darwin'):
exit() exit()
# Compile and load cpp extension Just-In-Time. # Compile and load cpp extension Just-In-Time.
sources = ["custom_add.cc", "custom_sub.cc"] sources = ["custom_extension.cc", "custom_sub.cc"]
paddle_includes = [] paddle_includes = []
for site_packages_path in getsitepackages(): for site_packages_path in getsitepackages():
paddle_includes.append( paddle_includes.append(
......
...@@ -148,6 +148,7 @@ class TestCppExtensionSetupInstall(unittest.TestCase): ...@@ -148,6 +148,7 @@ class TestCppExtensionSetupInstall(unittest.TestCase):
self._test_extension_function_plain() self._test_extension_function_plain()
self._test_extension_function_mixed() self._test_extension_function_mixed()
self._test_extension_class() self._test_extension_class()
self._test_nullable_tensor()
# Custom op # Custom op
self._test_static() self._test_static()
self._test_dynamic() self._test_dynamic()
...@@ -213,6 +214,19 @@ class TestCppExtensionSetupInstall(unittest.TestCase): ...@@ -213,6 +214,19 @@ class TestCppExtensionSetupInstall(unittest.TestCase):
atol=1e-5, atol=1e-5,
) )
def _test_nullable_tensor(self):
import custom_cpp_extension
x = custom_cpp_extension.nullable_tensor(True)
assert x is None, "Return None when input parameter return_none = True"
x = custom_cpp_extension.nullable_tensor(False).numpy()
x_np = np.ones(shape=[2, 2])
np.testing.assert_array_equal(
x,
x_np,
err_msg='extension out: {},\n numpy out: {}'.format(x, x_np),
)
def _test_static(self): def _test_static(self):
import mix_relu_extension import mix_relu_extension
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册