提交 57e19747 编写于 作者: M Megvii Engine Team

fix(imperative): check async error when getting value

GitOrigin-RevId: 3945a9bfa2b27e8f47f6ee00184f69f3d992bb22
上级 1f0cc891
...@@ -420,6 +420,7 @@ def warp_affine( ...@@ -420,6 +420,7 @@ def warp_affine(
Here all available options for params are listed, Here all available options for params are listed,
however it does not mean that you can use all the combinations. however it does not mean that you can use all the combinations.
On different platforms, different combinations are supported. On different platforms, different combinations are supported.
``warp_affine`` only supports forward inference. Please refer to ``warp_perspective`` if backward is needed.
""" """
conv_format = _config._get_actual_op_param(format, _config.__conv_format) conv_format = _config._get_actual_op_param(format, _config.__conv_format)
......
...@@ -1074,6 +1074,10 @@ void init_tensor(py::module m) { ...@@ -1074,6 +1074,10 @@ void init_tensor(py::module m) {
[]() { []() {
interpreter_for_py->sync(); interpreter_for_py->sync();
CompNode::sync_all(); CompNode::sync_all();
CompNode::foreach ([](CompNode cn) {
auto err = cn.check_async_error();
mgb_assert(!err, "%s", err->what());
});
sync_py_task_q(); sync_py_task_q();
}, },
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
......
...@@ -156,6 +156,8 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) { ...@@ -156,6 +156,8 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
if (m_async_level == 0) { if (m_async_level == 0) {
sync_impl(); sync_impl();
info->desc.comp_node.sync(); info->desc.comp_node.sync();
auto err = info->desc.comp_node.check_async_error();
mgb_assert(!err, "%s", err->what());
} }
return info; return info;
} }
...@@ -336,6 +338,8 @@ void ChannelImpl::dispatch_kernel( ...@@ -336,6 +338,8 @@ void ChannelImpl::dispatch_kernel(
for (auto&& oup : *outputs) { for (auto&& oup : *outputs) {
auto info = reinterpret_cast<TensorInfo*>(oup); auto info = reinterpret_cast<TensorInfo*>(oup);
info->ptr->comp_node().sync(); info->ptr->comp_node().sync();
auto err = info->ptr->comp_node().check_async_error();
mgb_assert(!err, "%s", err->what());
} }
} }
} }
...@@ -944,6 +948,8 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) { ...@@ -944,6 +948,8 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
}); });
MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop); MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
m_waitee = nullptr; m_waitee = nullptr;
auto err = info->ptr->comp_node().check_async_error();
mgb_assert(!err, "%s", err->what());
return info->ptr; return info->ptr;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册