diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 4f05406c7f74113d8fb10aa6914166e553858338..da8e030bb1f4861536954a230153277790be97c0 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -1,6 +1,6 @@ if(WITH_PYTHON) cc_library(paddle_pybind SHARED - SRCS pybind.cc + SRCS pybind.cc exception.cc DEPS pybind python backward ${GLOB_OP_LIB}) endif(WITH_PYTHON) diff --git a/paddle/pybind/exception.cc b/paddle/pybind/exception.cc new file mode 100644 index 0000000000000000000000000000000000000000..ff79b12ee4b28c53ee04f4c170b5bca9ca28d14a --- /dev/null +++ b/paddle/pybind/exception.cc @@ -0,0 +1,34 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/pybind/exception.h" + +namespace paddle { +namespace pybind { + +void BindException(pybind11::module& m) { + static pybind11::exception exc(m, "EnforceNotMet"); + pybind11::register_exception_translator([](std::exception_ptr p) { + try { + if (p) std::rethrow_exception(p); + } catch (const platform::EnforceNotMet& e) { + exc(e.what()); + } + }); + + m.def("__unittest_throw_exception__", [] { PADDLE_THROW("test exception"); }); +} + +} // namespace pybind +} // namespace paddle diff --git a/paddle/pybind/exception.h b/paddle/pybind/exception.h new file mode 100644 index 0000000000000000000000000000000000000000..12c7df93f617d40b5e028d1ae897ce47197c47c6 --- /dev/null +++ b/paddle/pybind/exception.h @@ -0,0 +1,23 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/platform/enforce.h" +#include "pybind11/pybind11.h" +namespace paddle { +namespace pybind { + +extern void BindException(pybind11::module& m); +} // namespace pybind +} // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 25e290ffbb94354da3393ca0b769aff512d74a41..c8e6f66280dab2c9a69267f81b3df000c8d29570 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -24,6 +24,7 @@ limitations under the License. */ #include "paddle/operators/recurrent_op.h" #include "paddle/platform/enforce.h" #include "paddle/platform/place.h" +#include "paddle/pybind/exception.h" #include "paddle/pybind/pybind.h" #include "paddle/pybind/tensor_py.h" #include "paddle/string/to_string.h" @@ -55,6 +56,8 @@ PYBIND11_PLUGIN(core) { // not cause namespace pollution. using namespace paddle::framework; // NOLINT + BindException(m); + py::class_(m, "Tensor", py::buffer_protocol()) .def_buffer( [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); }) diff --git a/python/paddle/v2/framework/tests/test_exception.py b/python/paddle/v2/framework/tests/test_exception.py new file mode 100644 index 0000000000000000000000000000000000000000..5284a069a8abae1e02288d63faca35228be74f8e --- /dev/null +++ b/python/paddle/v2/framework/tests/test_exception.py @@ -0,0 +1,12 @@ +import paddle.v2.framework.core as core +import unittest + + +class TestException(unittest.TestCase): + def test_exception(self): + self.assertRaises(core.EnforceNotMet, + lambda: core.__unittest_throw_exception__()) + + +if __name__ == "__main__": + unittest.main()