提交 7e909a09 编写于 作者: M Megvii Engine Team

fix(imperative): fix imperative exit segment fault

GitOrigin-RevId: dc498e3634a11aae4e3b3af69c0f398a55b3b0dc
上级 7302354a
......@@ -994,6 +994,15 @@ void init_tensor(py::module m) {
sync_py_task_q();
});
m.def("close", [channel]() {
// sync channel and compnode before close to ensure all tasks have been completed
if (channel->check_available()) {
channel->sync();
}
CompNode::sync_all();
CompNode::foreach ([](CompNode cn) {
auto err = cn.check_async_error();
mgb_assert(!err, "%s", err->what());
});
channel->close();
sync_py_task_q();
});
......
......@@ -71,3 +71,30 @@ def test_opdef_path():
assert Mode.__module__ == "megengine.core._imperative_rt.ops"
assert Mode.__name__ == "Mode"
assert Mode.__qualname__ == "Elemwise.Mode"
def _exit_impl():
import numpy as np
import megengine
from megengine import functional as F
megengine.set_default_device("cpu0")
in_channel = 32
out_channel = 32
x = megengine.tensor(np.random.randn(32, in_channel, 224, 224).astype(np.float32))
w = megengine.tensor(
np.random.randn(out_channel, in_channel, 3, 3).astype(np.float32)
)
y = F.conv2d(x, w)
def test_imperative_exit():
import multiprocessing as mp
recover = mp.get_start_method()
mp.set_start_method("spawn", force=True)
pro = mp.Process(target=_exit_impl)
pro.start()
pro.join()
assert pro.exitcode == 0, f"{pro.exitcode}"
mp.set_start_method(recover, force=True)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册