未验证 提交 ff4ec23a 编写于 作者: H HongyuJia 提交者: GitHub

[Cpp Extension] Support optional types (#50764)

* [Cpp Extension] Support optional type

* fix custom_extension.cc
上级 dca3a099
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include "paddle/phi/api/include/tensor.h" #include "paddle/phi/api/include/tensor.h"
#include "paddle/utils/optional.h"
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
#include "pybind11/stl.h" #include "pybind11/stl.h"
...@@ -73,5 +74,12 @@ struct type_caster<paddle::experimental::Tensor> { ...@@ -73,5 +74,12 @@ struct type_caster<paddle::experimental::Tensor> {
src, true /* return_py_none_if_not_initialize */)); src, true /* return_py_none_if_not_initialize */));
} }
}; };
// Pybind11 bindings for optional types.
// http://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers
template <typename T>
struct type_caster<paddle::optional<T>> : optional_caster<paddle::optional<T>> {
};
} // namespace detail } // namespace detail
} // namespace pybind11 } // namespace pybind11
...@@ -32,10 +32,20 @@ paddle::Tensor nullable_tensor(bool return_none = false) { ...@@ -32,10 +32,20 @@ paddle::Tensor nullable_tensor(bool return_none = false) {
return t; return t;
} }
paddle::optional<paddle::Tensor> optional_tensor(bool return_option = false) {
paddle::optional<paddle::Tensor> t;
if (!return_option) {
t = paddle::ones({2, 2});
}
return t;
}
PYBIND11_MODULE(custom_cpp_extension, m) { PYBIND11_MODULE(custom_cpp_extension, m) {
m.def("custom_add", &custom_add, "exp(x) + exp(y)"); m.def("custom_add", &custom_add, "exp(x) + exp(y)");
m.def("custom_sub", &custom_sub, "exp(x) - exp(y)"); m.def("custom_sub", &custom_sub, "exp(x) - exp(y)");
m.def("nullable_tensor", &nullable_tensor, "returned Tensor might be None"); m.def("nullable_tensor", &nullable_tensor, "returned Tensor might be None");
m.def(
"optional_tensor", &optional_tensor, "returned Tensor might be optional");
py::class_<Power>(m, "Power") py::class_<Power>(m, "Power")
.def(py::init<int, int>()) .def(py::init<int, int>())
......
...@@ -67,6 +67,8 @@ class TestCppExtensionJITInstall(unittest.TestCase): ...@@ -67,6 +67,8 @@ class TestCppExtensionJITInstall(unittest.TestCase):
def test_cpp_extension(self): def test_cpp_extension(self):
self._test_extension_function() self._test_extension_function()
self._test_extension_class() self._test_extension_class()
self._test_nullable_tensor()
self._test_optional_tensor()
def _test_extension_function(self): def _test_extension_function(self):
for dtype in self.dtypes: for dtype in self.dtypes:
...@@ -104,6 +106,30 @@ class TestCppExtensionJITInstall(unittest.TestCase): ...@@ -104,6 +106,30 @@ class TestCppExtensionJITInstall(unittest.TestCase):
atol=1e-5, atol=1e-5,
) )
def _test_nullable_tensor(self):
x = custom_cpp_extension.nullable_tensor(True)
assert x is None, "Return None when input parameter return_none = True"
x = custom_cpp_extension.nullable_tensor(False).numpy()
x_np = np.ones(shape=[2, 2])
np.testing.assert_array_equal(
x,
x_np,
err_msg='extension out: {},\n numpy out: {}'.format(x, x_np),
)
def _test_optional_tensor(self):
x = custom_cpp_extension.optional_tensor(True)
assert (
x is None
), "Return None when input parameter return_option = True"
x = custom_cpp_extension.optional_tensor(False).numpy()
x_np = np.ones(shape=[2, 2])
np.testing.assert_array_equal(
x,
x_np,
err_msg='extension out: {},\n numpy out: {}'.format(x, x_np),
)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -149,6 +149,7 @@ class TestCppExtensionSetupInstall(unittest.TestCase): ...@@ -149,6 +149,7 @@ class TestCppExtensionSetupInstall(unittest.TestCase):
self._test_extension_function_mixed() self._test_extension_function_mixed()
self._test_extension_class() self._test_extension_class()
self._test_nullable_tensor() self._test_nullable_tensor()
self._test_optional_tensor()
# Custom op # Custom op
self._test_static() self._test_static()
self._test_dynamic() self._test_dynamic()
...@@ -227,6 +228,21 @@ class TestCppExtensionSetupInstall(unittest.TestCase): ...@@ -227,6 +228,21 @@ class TestCppExtensionSetupInstall(unittest.TestCase):
err_msg='extension out: {},\n numpy out: {}'.format(x, x_np), err_msg='extension out: {},\n numpy out: {}'.format(x, x_np),
) )
def _test_optional_tensor(self):
import custom_cpp_extension
x = custom_cpp_extension.optional_tensor(True)
assert (
x is None
), "Return None when input parameter return_option = True"
x = custom_cpp_extension.optional_tensor(False).numpy()
x_np = np.ones(shape=[2, 2])
np.testing.assert_array_equal(
x,
x_np,
err_msg='extension out: {},\n numpy out: {}'.format(x, x_np),
)
def _test_static(self): def _test_static(self):
import mix_relu_extension import mix_relu_extension
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册