提交 8109cc4b 编写于 作者: M Megvii Engine Team

fix(imperative): set exception which worker thread throws for rendezvous

GitOrigin-RevId: f583888fdfdd422262a9bd0bcd3425055ce51a94
上级 2f4a75e7
......@@ -50,7 +50,16 @@ class Graph(_imperative_rt.ComputingGraph):
def execute(self, *args):
assert self._future is None
self._future = self._executor.submit(self._function.execute, *args)
def wrapped(*args):
try:
self._function.execute(*args)
except Exception as exc:
for i in self._function._all_rendezvous:
i.set_exception(str(exc))
raise exc
self._future = self._executor.submit(wrapped, *args)
def wait(self):
assert self._future is not None
......
......@@ -49,17 +49,28 @@ class _CompGraphProfilerImpl {
return json->to_string();
}
};
struct WeakRendezvousArray:
public std::vector<std::weak_ptr<RendezvousBase>>,
public UserDataContainer::UserData {
MGB_TYPEINFO_OBJ_DECL;
};
MGB_TYPEINFO_OBJ_IMPL(WeakRendezvousArray);
}
#define DEF_READWRITE(name) .def_readwrite(#name, &CURRENT_CLASS::name)
template<typename T>
auto def_rendezvous(py::object m, const char* name) {
return py::class_<Rendezvous<T>, std::shared_ptr<Rendezvous<T>>>(m, name)
.def(py::init([](){return std::make_shared<Rendezvous<T>>();}))
.def(py::init([](){return Rendezvous<T>::make();}))
.def("set", [](Rendezvous<T>& r, T v) {r.set(std::move(v));})
.def("get", [](Rendezvous<T>& r) {return r.get();}, py::call_guard<py::gil_scoped_release>())
.def("drop", &Rendezvous<T>::drop)
.def("reset", &Rendezvous<T>::reset);
.def("reset", &Rendezvous<T>::reset)
.def("set_exception", [](Rendezvous<T>& r, std::string&& message) {
r.set_exception(std::make_exception_ptr(
std::runtime_error(std::move(message))));
});
}
using TensorAttr = LogicalTensorDesc;
......@@ -186,7 +197,21 @@ void init_graph_rt(py::module m) {
py::class_<cg::AsyncExecutable>(m, "AsyncExecutable")
.def("execute", &cg::AsyncExecutable::execute, py::call_guard<py::gil_scoped_release>())
.def("wait", &cg::AsyncExecutable::wait, py::call_guard<py::gil_scoped_release>());
.def("wait", &cg::AsyncExecutable::wait, py::call_guard<py::gil_scoped_release>())
// only used for exception handle
.def_property_readonly("_all_rendezvous", [](cg::AsyncExecutable* exec) {
auto ud = exec->owner_graph()->options().user_data
.get_user_data<WeakRendezvousArray>();
std::vector<std::shared_ptr<RendezvousBase>> ret;
if (ud.second) {
for (auto&& r: *ud.first[0]) {
if (auto p = r.lock()) {
ret.emplace_back(std::move(p));
}
}
}
return ret;
});
auto PyComputingGraph = py::class_<cg::ComputingGraph, std::shared_ptr<cg::ComputingGraph>>(m, "ComputingGraph")
.def(py::init(py::overload_cast<>(&cg::ComputingGraph::make)))
......@@ -483,7 +508,14 @@ void init_graph_rt(py::module m) {
},
py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(), py::arg("graph") = py::none());
auto output_callback = [](auto callback, const std::vector<cg::VarNode*>& inputs, bool borrow = false, bool prefer_host_value = false) {
auto output_callback = [](auto callback, const std::vector<cg::VarNode*>& inputs,
std::shared_ptr<RendezvousBase> r = {}, bool borrow = false, bool prefer_host_value = false) {
if (r) {
mgb_assert(inputs.size());
auto cg = inputs[0]->owner_graph();
cg->options().user_data.get_user_data_or_create<WeakRendezvousArray>()
->emplace_back(r);
}
SymbolVarArray sinputs;
for (auto i : inputs) {
sinputs.emplace_back(i);
......@@ -508,7 +540,7 @@ void init_graph_rt(py::module m) {
auto f = [p](DeviceTensorND dv) {
p->set(std::move(dv));
};
return output_callback(std::move(f), std::move(inputs));
return output_callback(std::move(f), std::move(inputs), p);
});
m.def("value_output_callback", [output_callback](std::shared_ptr<Rendezvous<HostNDWithEvent>> p, std::vector<cg::VarNode*> inputs) {
......@@ -519,13 +551,13 @@ void init_graph_rt(py::module m) {
hv_with_event.second->record();
p->set(std::move(hv_with_event));
};
return output_callback(std::move(f), std::move(inputs), true, true);
return output_callback(std::move(f), std::move(inputs), p, true, true);
});
m.def("attr_output_callback", [output_callback](std::shared_ptr<Rendezvous<TensorAttr>> p, std::vector<cg::VarNode*> inputs) {
auto f = [p](DeviceTensorND dv) {
p->set(TensorAttr{TensorLayout{dv.shape(), dv.dtype()}, dv.comp_node()});
};
return output_callback(std::move(f), std::move(inputs), true);
return output_callback(std::move(f), std::move(inputs), p, true);
});
}
......@@ -35,18 +35,36 @@ public:
PYBIND11_DECLARE_HOLDER_TYPE(T, GraphNodePtr<T>, true);
class RendezvousBase {
public:
virtual ~RendezvousBase() = default;
virtual void set_exception(std::exception_ptr p) = 0;
};
template<typename R>
class Rendezvous {
class Rendezvous: public RendezvousBase {
std::mutex m_lock;
int m_read_ahead = 0;
bool m_drop_next = false;
std::promise<R> m_promise;
public:
Rendezvous() = default;
struct Factory {
template<typename ...Args>
static auto make_rendezvous(Args&& ...args) {
auto ptr = new Rendezvous<R>{std::forward(args)...};
return std::shared_ptr<Rendezvous<R>>(ptr);
}
};
public:
Rendezvous(const Rendezvous& rhs) = delete;
Rendezvous(Rendezvous&& rhs) = delete;
Rendezvous& operator=(const Rendezvous& rhs) = delete;
template<typename ...Args>
static auto make(Args&& ...args) {
return Factory::make_rendezvous(std::forward<Args>(args)...);
}
R get() {
std::future<R> f;
{
......@@ -96,6 +114,29 @@ public:
m_read_ahead = 0;
m_drop_next = false;
}
void set_exception(std::exception_ptr e) {
if (e) {
MGB_LOCK_GUARD(m_lock);
if (m_read_ahead >= 0) {
mgb_assert(m_read_ahead <= 1);
if (m_drop_next) {
m_drop_next = false;
} else {
m_promise.set_exception(e);
}
if (m_read_ahead == 1) {
m_promise = {};
}
--m_read_ahead;
} else {
mgb_assert(m_read_ahead == -1);
// TODO: maybe exception should be ignored
// if value was already set ?
m_promise.set_exception(e);
}
}
}
};
void init_graph_rt(pybind11::module m);
......@@ -82,3 +82,20 @@ def test_op():
f()
np.testing.assert_equal(x.numpy(), -y.result().numpy())
def test_exception():
err_msg = "QwQ"
def throw_exc():
raise RuntimeError(err_msg)
g = mgb_graph.Graph()
x, _ = mgb_graph.input_callback(throw_exc, device="xpux", dtype="float32", graph=g)
y = mgb_graph.OutputNode(F.neg(x))
f = g.compile(y.outputs[0])
try:
f.execute()
y.get_value()
except Exception as exc:
assert err_msg in str(exc)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册