diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py index 48ab55c2362b857e00e402b067422c1a0c57d6d2..2330af73be0a1dd2822ce522256833fea6c038e4 100644 --- a/imperative/python/megengine/core/tensor/megbrain_graph.py +++ b/imperative/python/megengine/core/tensor/megbrain_graph.py @@ -50,7 +50,16 @@ class Graph(_imperative_rt.ComputingGraph): def execute(self, *args): assert self._future is None - self._future = self._executor.submit(self._function.execute, *args) + + def wrapped(*args): + try: + self._function.execute(*args) + except Exception as exc: + for i in self._function._all_rendezvous: + i.set_exception(str(exc)) + raise exc + + self._future = self._executor.submit(wrapped, *args) def wait(self): assert self._future is not None diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index fcfbd55e784e4a18de075cc6fc69ad73ef761c2c..dfe306af60a0b4d081c533e6921f2cc144e69a6f 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -49,17 +49,28 @@ class _CompGraphProfilerImpl { return json->to_string(); } }; + +struct WeakRendezvousArray: + public std::vector>, + public UserDataContainer::UserData { + MGB_TYPEINFO_OBJ_DECL; +}; +MGB_TYPEINFO_OBJ_IMPL(WeakRendezvousArray); } #define DEF_READWRITE(name) .def_readwrite(#name, &CURRENT_CLASS::name) template auto def_rendezvous(py::object m, const char* name) { return py::class_, std::shared_ptr>>(m, name) - .def(py::init([](){return std::make_shared>();})) + .def(py::init([](){return Rendezvous::make();})) .def("set", [](Rendezvous& r, T v) {r.set(std::move(v));}) .def("get", [](Rendezvous& r) {return r.get();}, py::call_guard()) .def("drop", &Rendezvous::drop) - .def("reset", &Rendezvous::reset); + .def("reset", &Rendezvous::reset) + .def("set_exception", [](Rendezvous& r, std::string&& message) { + r.set_exception(std::make_exception_ptr( + std::runtime_error(std::move(message)))); + }); } using TensorAttr = LogicalTensorDesc; @@ -186,7 +197,21 @@ void init_graph_rt(py::module m) { py::class_(m, "AsyncExecutable") .def("execute", &cg::AsyncExecutable::execute, py::call_guard()) - .def("wait", &cg::AsyncExecutable::wait, py::call_guard()); + .def("wait", &cg::AsyncExecutable::wait, py::call_guard()) + // only used for exception handle + .def_property_readonly("_all_rendezvous", [](cg::AsyncExecutable* exec) { + auto ud = exec->owner_graph()->options().user_data + .get_user_data(); + std::vector> ret; + if (ud.second) { + for (auto&& r: *ud.first[0]) { + if (auto p = r.lock()) { + ret.emplace_back(std::move(p)); + } + } + } + return ret; + }); auto PyComputingGraph = py::class_>(m, "ComputingGraph") .def(py::init(py::overload_cast<>(&cg::ComputingGraph::make))) @@ -483,7 +508,14 @@ void init_graph_rt(py::module m) { }, py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(), py::arg("graph") = py::none()); - auto output_callback = [](auto callback, const std::vector& inputs, bool borrow = false, bool prefer_host_value = false) { + auto output_callback = [](auto callback, const std::vector& inputs, + std::shared_ptr r = {}, bool borrow = false, bool prefer_host_value = false) { + if (r) { + mgb_assert(inputs.size()); + auto cg = inputs[0]->owner_graph(); + cg->options().user_data.get_user_data_or_create() + ->emplace_back(r); + } SymbolVarArray sinputs; for (auto i : inputs) { sinputs.emplace_back(i); @@ -508,7 +540,7 @@ void init_graph_rt(py::module m) { auto f = [p](DeviceTensorND dv) { p->set(std::move(dv)); }; - return output_callback(std::move(f), std::move(inputs)); + return output_callback(std::move(f), std::move(inputs), p); }); m.def("value_output_callback", [output_callback](std::shared_ptr> p, std::vector inputs) { @@ -519,13 +551,13 @@ void init_graph_rt(py::module m) { hv_with_event.second->record(); p->set(std::move(hv_with_event)); }; - return output_callback(std::move(f), std::move(inputs), true, true); + return output_callback(std::move(f), std::move(inputs), p, true, true); }); m.def("attr_output_callback", [output_callback](std::shared_ptr> p, std::vector inputs) { auto f = [p](DeviceTensorND dv) { p->set(TensorAttr{TensorLayout{dv.shape(), dv.dtype()}, dv.comp_node()}); }; - return output_callback(std::move(f), std::move(inputs), true); + return output_callback(std::move(f), std::move(inputs), p, true); }); } diff --git a/imperative/python/src/graph_rt.h b/imperative/python/src/graph_rt.h index c069e3a419f9333abc32147541f29d2e83650c56..a7ad80b1880e3f007796f021e3adcd052cf8c432 100644 --- a/imperative/python/src/graph_rt.h +++ b/imperative/python/src/graph_rt.h @@ -35,18 +35,36 @@ public: PYBIND11_DECLARE_HOLDER_TYPE(T, GraphNodePtr, true); +class RendezvousBase { +public: + virtual ~RendezvousBase() = default; + virtual void set_exception(std::exception_ptr p) = 0; +}; + template -class Rendezvous { +class Rendezvous: public RendezvousBase { std::mutex m_lock; int m_read_ahead = 0; bool m_drop_next = false; std::promise m_promise; -public: Rendezvous() = default; + struct Factory { + template + static auto make_rendezvous(Args&& ...args) { + auto ptr = new Rendezvous{std::forward(args)...}; + return std::shared_ptr>(ptr); + } + }; +public: Rendezvous(const Rendezvous& rhs) = delete; Rendezvous(Rendezvous&& rhs) = delete; Rendezvous& operator=(const Rendezvous& rhs) = delete; + template + static auto make(Args&& ...args) { + return Factory::make_rendezvous(std::forward(args)...); + } + R get() { std::future f; { @@ -96,6 +114,29 @@ public: m_read_ahead = 0; m_drop_next = false; } + + void set_exception(std::exception_ptr e) { + if (e) { + MGB_LOCK_GUARD(m_lock); + if (m_read_ahead >= 0) { + mgb_assert(m_read_ahead <= 1); + if (m_drop_next) { + m_drop_next = false; + } else { + m_promise.set_exception(e); + } + if (m_read_ahead == 1) { + m_promise = {}; + } + --m_read_ahead; + } else { + mgb_assert(m_read_ahead == -1); + // TODO: maybe exception should be ignored + // if value was already set ? + m_promise.set_exception(e); + } + } + } }; void init_graph_rt(pybind11::module m); diff --git a/imperative/python/test/unit/core/test_megbrain_graph.py b/imperative/python/test/unit/core/test_megbrain_graph.py index 0abd92b331dcda747d51186f9fdf5c1c6c6dd888..e79976064d7f44caeac88c733f7973bd67991825 100644 --- a/imperative/python/test/unit/core/test_megbrain_graph.py +++ b/imperative/python/test/unit/core/test_megbrain_graph.py @@ -82,3 +82,20 @@ def test_op(): f() np.testing.assert_equal(x.numpy(), -y.result().numpy()) + + +def test_exception(): + err_msg = "QwQ" + + def throw_exc(): + raise RuntimeError(err_msg) + + g = mgb_graph.Graph() + x, _ = mgb_graph.input_callback(throw_exc, device="xpux", dtype="float32", graph=g) + y = mgb_graph.OutputNode(F.neg(x)) + f = g.compile(y.outputs[0]) + try: + f.execute() + y.get_value() + except Exception as exc: + assert err_msg in str(exc)