未验证 提交 ff4ec23a 编写于 作者: H HongyuJia 提交者: GitHub

[Cpp Extension] Support optional types (#50764)

* [Cpp Extension] Support optional type

* fix custom_extension.cc
上级 dca3a099
......@@ -15,6 +15,7 @@
#pragma once
#include "paddle/phi/api/include/tensor.h"
#include "paddle/utils/optional.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
......@@ -73,5 +74,12 @@ struct type_caster<paddle::experimental::Tensor> {
src, true /* return_py_none_if_not_initialize */));
}
};
// Pybind11 bindings for optional types.
// http://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers
template <typename T>
struct type_caster<paddle::optional<T>> : optional_caster<paddle::optional<T>> {
};
} // namespace detail
} // namespace pybind11
......@@ -32,10 +32,20 @@ paddle::Tensor nullable_tensor(bool return_none = false) {
return t;
}
paddle::optional<paddle::Tensor> optional_tensor(bool return_option = false) {
paddle::optional<paddle::Tensor> t;
if (!return_option) {
t = paddle::ones({2, 2});
}
return t;
}
PYBIND11_MODULE(custom_cpp_extension, m) {
m.def("custom_add", &custom_add, "exp(x) + exp(y)");
m.def("custom_sub", &custom_sub, "exp(x) - exp(y)");
m.def("nullable_tensor", &nullable_tensor, "returned Tensor might be None");
m.def(
"optional_tensor", &optional_tensor, "returned Tensor might be optional");
py::class_<Power>(m, "Power")
.def(py::init<int, int>())
......
......@@ -67,6 +67,8 @@ class TestCppExtensionJITInstall(unittest.TestCase):
def test_cpp_extension(self):
self._test_extension_function()
self._test_extension_class()
self._test_nullable_tensor()
self._test_optional_tensor()
def _test_extension_function(self):
for dtype in self.dtypes:
......@@ -104,6 +106,30 @@ class TestCppExtensionJITInstall(unittest.TestCase):
atol=1e-5,
)
def _test_nullable_tensor(self):
x = custom_cpp_extension.nullable_tensor(True)
assert x is None, "Return None when input parameter return_none = True"
x = custom_cpp_extension.nullable_tensor(False).numpy()
x_np = np.ones(shape=[2, 2])
np.testing.assert_array_equal(
x,
x_np,
err_msg='extension out: {},\n numpy out: {}'.format(x, x_np),
)
def _test_optional_tensor(self):
x = custom_cpp_extension.optional_tensor(True)
assert (
x is None
), "Return None when input parameter return_option = True"
x = custom_cpp_extension.optional_tensor(False).numpy()
x_np = np.ones(shape=[2, 2])
np.testing.assert_array_equal(
x,
x_np,
err_msg='extension out: {},\n numpy out: {}'.format(x, x_np),
)
if __name__ == '__main__':
unittest.main()
......@@ -149,6 +149,7 @@ class TestCppExtensionSetupInstall(unittest.TestCase):
self._test_extension_function_mixed()
self._test_extension_class()
self._test_nullable_tensor()
self._test_optional_tensor()
# Custom op
self._test_static()
self._test_dynamic()
......@@ -227,6 +228,21 @@ class TestCppExtensionSetupInstall(unittest.TestCase):
err_msg='extension out: {},\n numpy out: {}'.format(x, x_np),
)
def _test_optional_tensor(self):
import custom_cpp_extension
x = custom_cpp_extension.optional_tensor(True)
assert (
x is None
), "Return None when input parameter return_option = True"
x = custom_cpp_extension.optional_tensor(False).numpy()
x_np = np.ones(shape=[2, 2])
np.testing.assert_array_equal(
x,
x_np,
err_msg='extension out: {},\n numpy out: {}'.format(x, x_np),
)
def _test_static(self):
import mix_relu_extension
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册