未验证 提交 01e85182 编写于 作者: H HongyuJia 提交者: GitHub

[Cpp Extension] return Py_None if Tensor not initialized (#50745)

* [Cpp Extension] return Py_None if Tensor not initialized

* fix jit test
上级 b197e66c
......@@ -69,7 +69,8 @@ struct type_caster<paddle::experimental::Tensor> {
static handle cast(const paddle::experimental::Tensor& src,
return_value_policy /* policy */,
handle /* parent */) {
return handle(paddle::pybind::ToPyObject(src));
return handle(paddle::pybind::ToPyObject(
src, true /* return_py_none_if_not_initialize */));
}
};
} // namespace detail
......
......@@ -25,13 +25,14 @@ for site_packages_path in getsitepackages():
paddle_includes.append(
os.path.join(site_packages_path, 'paddle', 'include', 'third_party')
)
# Add current dir, search custom_power.h
paddle_includes.append(os.path.dirname(os.path.abspath(__file__)))
setup(
name='custom_cpp_extension',
ext_modules=CppExtension(
sources=["custom_add.cc", "custom_sub.cc"],
include_dirs=paddle_includes
+ [os.path.dirname(os.path.abspath(__file__))],
sources=["custom_extension.cc", "custom_sub.cc"],
include_dirs=paddle_includes,
extra_compile_args={'cc': ['-w', '-g']},
verbose=True,
),
......
......@@ -24,9 +24,18 @@ paddle::Tensor custom_add(const paddle::Tensor& x, const paddle::Tensor& y) {
return paddle::add(paddle::exp(x), paddle::exp(y));
}
paddle::Tensor nullable_tensor(bool return_none = false) {
paddle::Tensor t;
if (!return_none) {
t = paddle::ones({2, 2});
}
return t;
}
PYBIND11_MODULE(custom_cpp_extension, m) {
m.def("custom_add", &custom_add, "exp(x) + exp(y)");
m.def("custom_sub", &custom_sub, "exp(x) - exp(y)");
m.def("nullable_tensor", &nullable_tensor, "returned Tensor might be None");
py::class_<Power>(m, "Power")
.def(py::init<int, int>())
......
......@@ -27,7 +27,7 @@ if os.name == 'nt' or sys.platform.startswith('darwin'):
exit()
# Compile and load cpp extension Just-In-Time.
sources = ["custom_add.cc", "custom_sub.cc"]
sources = ["custom_extension.cc", "custom_sub.cc"]
paddle_includes = []
for site_packages_path in getsitepackages():
paddle_includes.append(
......
......@@ -148,6 +148,7 @@ class TestCppExtensionSetupInstall(unittest.TestCase):
self._test_extension_function_plain()
self._test_extension_function_mixed()
self._test_extension_class()
self._test_nullable_tensor()
# Custom op
self._test_static()
self._test_dynamic()
......@@ -213,6 +214,19 @@ class TestCppExtensionSetupInstall(unittest.TestCase):
atol=1e-5,
)
def _test_nullable_tensor(self):
import custom_cpp_extension
x = custom_cpp_extension.nullable_tensor(True)
assert x is None, "Return None when input parameter return_none = True"
x = custom_cpp_extension.nullable_tensor(False).numpy()
x_np = np.ones(shape=[2, 2])
np.testing.assert_array_equal(
x,
x_np,
err_msg='extension out: {},\n numpy out: {}'.format(x, x_np),
)
def _test_static(self):
import mix_relu_extension
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册