diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index aa9ca4e31aa8bdae159ce2d8db8eadd2ab49dffc..326cc4a75bd5cc29f79de88a3e0802d17c812ecd 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -1,6 +1,6 @@ if(WITH_PYTHON) cc_library(paddle_pybind SHARED - SRCS pybind.cc protobuf.cc + SRCS pybind.cc exception.cc protobuf.cc DEPS pybind python backward ${GLOB_OP_LIB}) endif(WITH_PYTHON) diff --git a/paddle/pybind/exception.cc b/paddle/pybind/exception.cc new file mode 100644 index 0000000000000000000000000000000000000000..ff79b12ee4b28c53ee04f4c170b5bca9ca28d14a --- /dev/null +++ b/paddle/pybind/exception.cc @@ -0,0 +1,34 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/pybind/exception.h" + +namespace paddle { +namespace pybind { + +void BindException(pybind11::module& m) { + static pybind11::exception exc(m, "EnforceNotMet"); + pybind11::register_exception_translator([](std::exception_ptr p) { + try { + if (p) std::rethrow_exception(p); + } catch (const platform::EnforceNotMet& e) { + exc(e.what()); + } + }); + + m.def("__unittest_throw_exception__", [] { PADDLE_THROW("test exception"); }); +} + +} // namespace pybind +} // namespace paddle diff --git a/paddle/pybind/exception.h b/paddle/pybind/exception.h new file mode 100644 index 0000000000000000000000000000000000000000..12c7df93f617d40b5e028d1ae897ce47197c47c6 --- /dev/null +++ b/paddle/pybind/exception.h @@ -0,0 +1,23 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/platform/enforce.h" +#include "pybind11/pybind11.h" +namespace paddle { +namespace pybind { + +extern void BindException(pybind11::module& m); +} // namespace pybind +} // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 11c1578e6aead6e256082c487a86fb9afa0e1fc2..3816aee21f8842c8fc73c56621234b66661e880c 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/operators/recurrent_op.h" #include "paddle/platform/enforce.h" #include "paddle/platform/place.h" +#include "paddle/pybind/exception.h" #include "paddle/pybind/pybind.h" #include "paddle/pybind/tensor_py.h" #include "paddle/string/to_string.h" @@ -47,6 +48,8 @@ PYBIND11_PLUGIN(core) { // not cause namespace pollution. using namespace paddle::framework; // NOLINT + BindException(m); + py::class_(m, "Tensor", py::buffer_protocol()) .def_buffer( [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); }) diff --git a/python/paddle/v2/framework/tests/test_exception.py b/python/paddle/v2/framework/tests/test_exception.py new file mode 100644 index 0000000000000000000000000000000000000000..5ae048817cfcc1ec85e0d0e0c5db749da4521012 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_exception.py @@ -0,0 +1,17 @@ +import paddle.v2.framework.core as core +import unittest + + +class TestException(unittest.TestCase): + def test_exception(self): + ex = None + try: + core.__unittest_throw_exception__() + except core.EnforceNotMet as ex: + self.assertIn("test exception", ex.message) + + self.assertIsNotNone(ex) + + +if __name__ == "__main__": + unittest.main()