From 67cdd5bc617db671808606042ad2f41484a6df6d Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 26 Sep 2017 10:10:28 -0700 Subject: [PATCH] Make PyBind support C++ exception --- paddle/pybind/CMakeLists.txt | 2 +- paddle/pybind/exception.cc | 34 +++++++++++++++++++ paddle/pybind/exception.h | 23 +++++++++++++ paddle/pybind/pybind.cc | 3 ++ .../v2/framework/tests/test_exception.py | 12 +++++++ 5 files changed, 73 insertions(+), 1 deletion(-) create mode 100644 paddle/pybind/exception.cc create mode 100644 paddle/pybind/exception.h create mode 100644 python/paddle/v2/framework/tests/test_exception.py diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 4f05406c7f7..da8e030bb1f 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -1,6 +1,6 @@ if(WITH_PYTHON) cc_library(paddle_pybind SHARED - SRCS pybind.cc + SRCS pybind.cc exception.cc DEPS pybind python backward ${GLOB_OP_LIB}) endif(WITH_PYTHON) diff --git a/paddle/pybind/exception.cc b/paddle/pybind/exception.cc new file mode 100644 index 00000000000..ff79b12ee4b --- /dev/null +++ b/paddle/pybind/exception.cc @@ -0,0 +1,34 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/pybind/exception.h" + +namespace paddle { +namespace pybind { + +void BindException(pybind11::module& m) { + static pybind11::exception exc(m, "EnforceNotMet"); + pybind11::register_exception_translator([](std::exception_ptr p) { + try { + if (p) std::rethrow_exception(p); + } catch (const platform::EnforceNotMet& e) { + exc(e.what()); + } + }); + + m.def("__unittest_throw_exception__", [] { PADDLE_THROW("test exception"); }); +} + +} // namespace pybind +} // namespace paddle diff --git a/paddle/pybind/exception.h b/paddle/pybind/exception.h new file mode 100644 index 00000000000..12c7df93f61 --- /dev/null +++ b/paddle/pybind/exception.h @@ -0,0 +1,23 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/platform/enforce.h" +#include "pybind11/pybind11.h" +namespace paddle { +namespace pybind { + +extern void BindException(pybind11::module& m); +} // namespace pybind +} // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 25e290ffbb9..c8e6f66280d 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -24,6 +24,7 @@ limitations under the License. */ #include "paddle/operators/recurrent_op.h" #include "paddle/platform/enforce.h" #include "paddle/platform/place.h" +#include "paddle/pybind/exception.h" #include "paddle/pybind/pybind.h" #include "paddle/pybind/tensor_py.h" #include "paddle/string/to_string.h" @@ -55,6 +56,8 @@ PYBIND11_PLUGIN(core) { // not cause namespace pollution. using namespace paddle::framework; // NOLINT + BindException(m); + py::class_(m, "Tensor", py::buffer_protocol()) .def_buffer( [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); }) diff --git a/python/paddle/v2/framework/tests/test_exception.py b/python/paddle/v2/framework/tests/test_exception.py new file mode 100644 index 00000000000..5284a069a8a --- /dev/null +++ b/python/paddle/v2/framework/tests/test_exception.py @@ -0,0 +1,12 @@ +import paddle.v2.framework.core as core +import unittest + + +class TestException(unittest.TestCase): + def test_exception(self): + self.assertRaises(core.EnforceNotMet, + lambda: core.__unittest_throw_exception__()) + + +if __name__ == "__main__": + unittest.main() -- GitLab