Commit 619d78ed authored by Megvii Engine Team

fix(imperative): check async error when getting value

GitOrigin-RevId: 52b8a29932d2abb33f4bb3d4acff91fe53a6a998
Parent 2afa0af9
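MegEngine dispatches kernels asynchronously by default, so a device-side failure (for example an out-of-range index) is not raised at the Python call site and could previously go unnoticed. This commit makes the imperative runtime check the comp node's async error whenever a value is synchronized back to the host, so the failure surfaces as a RuntimeError at the point where the value is actually requested. A minimal sketch of the user-visible behaviour, mirroring the test added below (it assumes the usual aliases import megengine as mge and import megengine.functional as F, and a GPU so that execution really is asynchronous):

    import megengine as mge
    import megengine.functional as F

    src = mge.tensor([[1.0, 2.0]])          # only 2 columns
    index = mge.tensor([3])                 # out-of-range index
    val = F.indexing_one_hot(src, index)    # dispatched asynchronously, no error yet
    try:
        val.numpy()                         # fetching the value syncs and reports the async error
    except RuntimeError as exc:
        print("async error surfaced at value fetch:", exc)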
@@ -420,6 +420,7 @@ def warp_affine(
     Here all available options for params are listed,
     however it does not mean that you can use all the combinations.
     On different platforms, different combinations are supported.
+    ``warp_affine`` only supports forward inference. Please refer to ``warp_perspective`` if backward is needed.
     """
     conv_format = _config._get_actual_op_param(format, _config.__conv_format)
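For context on the docstring change above, a hedged usage sketch of forward-only warp_affine. The exact signature, the (batch, 2, 3) matrix layout, and the NHWC default layout are assumptions that may differ between MegEngine versions; consult the functional.vision docs for the authoritative parameter list.

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    inp = mge.tensor(np.random.randn(1, 16, 16, 3).astype("float32"))  # assumed NHWC layout
    # identity affine matrix per batch item, assumed shape (batch, 2, 3)
    mat = mge.tensor(np.array([[[1.0, 0.0, 0.0],
                                [0.0, 1.0, 0.0]]], dtype="float32"))
    out = F.vision.warp_affine(inp, mat, (8, 8))  # forward only; use warp_perspective if a backward pass is needed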
@@ -104,6 +104,7 @@ class TensorInfo:
         "shape",
         "is_const",
         "bound_data",
+        "bound_data_numpy",
         # resources for execution
         "varnode",
         "data_setter",
@@ -119,12 +120,18 @@ class TensorInfo:
         self.shape_read = None
         self.value_read = None
         self.bound_data = None
+        self.bound_data_numpy = None
         self.data_setter = None
         self.shape_reader = None
         self.value_reader = None
         self.data_reader = None
+
+    def get_numpy(self):
+        if self.bound_data_numpy is None:
+            self.bound_data_numpy = self.bound_data.numpy()
+        return self.bound_data_numpy
+

 _io_op_types = {AssertEqual, CollectiveComm, RemoteSend, RemoteRecv}
@@ -292,7 +299,7 @@ class trace:
             # Const op is represented by a str
             assert isinstance(op_, str) and op_ == "Const"
-            expected = self._tinfo[ohandles[0]].bound_data.numpy()
+            expected = self._tinfo[ohandles[0]].get_numpy()
             shape = value.shape
             if shape != expected.shape or dtype != expected.dtype:
                 eq = False
@@ -369,6 +376,7 @@ class trace:
             info.dtype = x.dtype
             info.shape = x.shape
             info.bound_data = x
+            info.bound_data_numpy = None
             info.is_const = True
             x._mixin_handle = h
             x._recording = True
@@ -612,9 +620,7 @@ class trace:
                 assert info.external
                 assert info.bound_data
                 info.varnode = graph.make_const(
-                    info.bound_data.numpy(),
-                    info.bound_data.dtype,
-                    info.bound_data.device,
+                    info.get_numpy(), info.bound_data.dtype, info.bound_data.device,
                 )
                 continue
@@ -627,7 +633,7 @@ class trace:
             if info.bound_data:
                 if getattr(info, "is_const", False):
                     info.varnode = graph.make_const(
-                        info.bound_data.numpy(),
+                        info.get_numpy(),
                         info.bound_data.dtype,
                         info.bound_data.device,
                     )
@@ -1174,7 +1180,7 @@ class trace:
                 assert info.external
                 assert info.bound_data
                 h2v[h] = graph.make_const(
-                    info.bound_data.numpy(),
+                    info.get_numpy(),
                     dtype=info.dtype,
                     device=dumped_device(info),
                     name=info.name,
@@ -1187,7 +1193,7 @@ class trace:
             assert info.external
             assert info.bound_data
             h2v[h] = graph.make_const(
-                info.bound_data.numpy(),
+                info.get_numpy(),
                 dtype=info.dtype,
                 device=dumped_device(info),
                 name=info.name,
@@ -1074,6 +1074,10 @@ void init_tensor(py::module m) {
             []() {
                 interpreter_for_py->sync();
                 CompNode::sync_all();
+                CompNode::foreach ([](CompNode cn) {
+                    auto err = cn.check_async_error();
+                    mgb_assert(!err, "%s", err->what());
+                });
                 sync_py_task_q();
             },
             py::call_guard<py::gil_scoped_release>());
@@ -96,6 +96,15 @@ def test_regression_2870():
     (x + x).numpy()


+@pytest.mark.require_ngpu(1)
+def test_async_error_check():
+    src = mge.tensor([[1.0, 2.0]])
+    index = mge.tensor([3])
+    val = F.indexing_one_hot(src, index)
+    with pytest.raises(RuntimeError):
+        val.numpy()
+
+
 # NOTE: DO NOT REMOVE THIS TEST
 # This is also a compatibility test for
 # mge.core.set_option('async_level', 0).
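The compatibility note above concerns synchronous execution. A short hedged sketch of the knob it refers to (the option name comes straight from that comment; the behaviour described in the inline comments is an assumption based on this commit):

    import megengine as mge

    # async_level 0: ops run synchronously, so device errors are raised at the call site
    # (this is the path covered by the new check_async_error call in put_impl below).
    mge.core.set_option("async_level", 0)

    # With the default asynchronous level, errors are instead reported when the value is
    # fetched, e.g. at tensor.numpy(), as exercised by test_async_error_check above.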
@@ -156,6 +156,8 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) {
     if (m_async_level == 0) {
         sync_impl();
         info->desc.comp_node.sync();
+        auto err = info->desc.comp_node.check_async_error();
+        mgb_assert(!err, "%s", err->what());
     }
     return info;
 }
@@ -336,6 +338,8 @@ void ChannelImpl::dispatch_kernel(
         for (auto&& oup : *outputs) {
             auto info = reinterpret_cast<TensorInfo*>(oup);
             info->ptr->comp_node().sync();
+            auto err = info->ptr->comp_node().check_async_error();
+            mgb_assert(!err, "%s", err->what());
         }
     }
 }
@@ -931,7 +935,8 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
     MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop);
     bool require_host = prop == TensorProp::HostValue;
     auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); };
-    if (require_host && !host_available()) {
+    bool wait_host = !host_available();
+    if (require_host && wait_host) {
         // avoid dead lock
         lock.unlock();
         m_buffer.enqueue(GetValue{info});
@@ -944,6 +949,10 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
     });
     MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
     m_waitee = nullptr;
+    if (require_host && wait_host) {
+        auto err = info->ptr->comp_node().check_async_error();
+        mgb_assert(!err, "%s", err->what());
+    }
     return info->ptr;
 }