diff --git a/imperative/python/src/imperative_rt.cpp b/imperative/python/src/imperative_rt.cpp
index c96efbfe757662649549d1eee4661b6c5aaa51bf..fc181293412033b6f70bc55435d02f195ba2b25d 100644
--- a/imperative/python/src/imperative_rt.cpp
+++ b/imperative/python/src/imperative_rt.cpp
@@ -77,12 +77,14 @@ void init_imperative_rt(py::module m) {
         .def("get_shape", &Interpreter::Channel::get_shape)
         .def("_get_dev_tensor", &Interpreter::Channel::get_dev_tensor)
         .def("apply_op", &Interpreter::Channel::apply_op)
+        .def("config_async_level", &Interpreter::Channel::config_async_level)
+        .def("get_async_level", &Interpreter::Channel::get_async_level)
         .def("sync", &Interpreter::Channel::sync, py::call_guard<py::gil_scoped_release>());
 
     std::unique_ptr<Interpreter::Channel> ch = Interpreter::inst().create_channel();
     m.attr("interpreter") = py::detail::make_caster<std::unique_ptr<Interpreter::Channel>>::cast(
         std::move(ch), py::return_value_policy::move, {});
-    for (auto name : {"put", "delete", "get_value", "get_dtype", "get_device", "get_shape", "_get_dev_tensor", "apply_op"}) {
+    for (auto name : {"put", "delete", "get_value", "get_dtype", "get_device", "get_shape", "_get_dev_tensor", "apply_op", "config_async_level", "get_async_level"}) {
         m.attr(name) = m.attr("interpreter").attr(name);
     }
 
diff --git a/imperative/python/test/unit/core/test_async_level.py b/imperative/python/test/unit/core/test_async_level.py
new file mode 100644
index 0000000000000000000000000000000000000000..08f4d28ce2f0ef0b3e3fc1adf22d6005e65fbd0b
--- /dev/null
+++ b/imperative/python/test/unit/core/test_async_level.py
@@ -0,0 +1,35 @@
+import pytest
+
+import megengine as mge
+import megengine.functional as F
+from megengine.core._imperative_rt.imperative import config_async_level, get_async_level
+
+
+def test_basic():
+    config_async_level(2)
+    assert get_async_level() == 2
+    with pytest.raises(RuntimeError):
+        config_async_level(3)
+
+
+def test_level1_infer_value():
+    config_async_level(1)
+    a = mge.tensor([[1, 2], [2, 3], [3, 4]], dtype="float32")
+    b = mge.tensor([1, 1], dtype="float32")
+    # make DepType::VALUE unknown
+    c = b * 2
+    with pytest.raises(RuntimeError):
+        d = F.reshape(a, c)
+
+
+def test_level1_infer_shape_with_unknown():
+    config_async_level(2)
+    a = mge.tensor([[1, 2, 2, 3]], dtype="float32")
+    b = mge.tensor([1, 1])
+    c = b * 2
+    # make DepType::SHAPE unknown
+    d = F.reshape(a, c)
+    config_async_level(1)
+    e = mge.tensor([[1, 2]], dtype="float32")
+    with pytest.raises(RuntimeError):
+        f = F.matmul(d, e)
diff --git a/imperative/src/impl/interpreter_impl.cpp b/imperative/src/impl/interpreter_impl.cpp
index 79800ed2bd19eba93f072b8d5b04e7dbb92d2367..b0500549ab13cf65f8d57e83ab6f49b8d840ec24 100644
--- a/imperative/src/impl/interpreter_impl.cpp
+++ b/imperative/src/impl/interpreter_impl.cpp
@@ -54,21 +54,25 @@ void ChannelImpl::del(void* handle) {
 SmallVector<void*> ChannelImpl::apply_op(
         std::shared_ptr<OpDef> op,
         const SmallVector<void*>& inputs) {
+    SmallVector<TensorInfo*> input_infos;
+    input_infos.reserve(inputs.size());
     SmallVector<LogicalTensorDesc> input_descs;
     input_descs.reserve(inputs.size());
-    for (auto h : inputs) {
-        auto info = reinterpret_cast<TensorInfo*>(h);
+    for (auto i : inputs) {
+        auto info = reinterpret_cast<TensorInfo*>(i);
+        input_infos.push_back(info);
         input_descs.push_back(info->desc);
     }
     auto output_descs = OpDef::infer_output_attrs_fallible(*op, input_descs);
     ApplyOp cmd{std::move(op)};
-    cmd.inputs.reserve(inputs.size());
-    for (auto i : inputs) {
-        cmd.inputs.push_back(reinterpret_cast<TensorInfo*>(i));
-    }
+    cmd.inputs = std::move(input_infos);
     cmd.outputs.reserve(output_descs.size());
     SmallVector<void*> outputs;
+    bool is_fallible = false;
     for (auto&& desc : output_descs) {
+        if (desc.layout.ndim == 0) {
+            is_fallible = true;
+        }
         auto info = alloc();
         info->desc = desc;
         m_valid_handle.insert(info);
@@ -76,6 +80,9 @@ SmallVector<void*> ChannelImpl::apply_op(
         outputs.push_back(info);
     }
     m_worker.add_task(std::move(cmd));
+    if (is_fallible && m_async_level <= 1) {
+        sync();
+    }
     return outputs;
 }
 
@@ -162,7 +169,12 @@ void ChannelImpl::close() {
 }
 
 void ChannelImpl::config_async_level(int level) {
-    mgb_assert(0);
+    mgb_assert(level <= 2 and level >= 0, "async_level should be 0, 1 or 2");
+    m_async_level = level;
+}
+
+int ChannelImpl::get_async_level() {
+    return m_async_level;
 }
 
 TensorInfo* ChannelImpl::alloc() {
diff --git a/imperative/src/impl/interpreter_impl.h b/imperative/src/impl/interpreter_impl.h
index 4676d27af95cf56bda3cefb32460806145ab6056..652a31ea27352280e255cede01ee1bcd9d749bbc 100644
--- a/imperative/src/impl/interpreter_impl.h
+++ b/imperative/src/impl/interpreter_impl.h
@@ -74,6 +74,7 @@ struct ChannelImpl : Interpreter::Channel {
     void close() override;
 
     void config_async_level(int level) override;
+    int get_async_level() override;
 
 private:
     TensorInfo* alloc();
@@ -101,7 +102,11 @@ private:
         ChannelImpl* m_owner;
     } m_worker;
 
-    int m_async_level = 2;
+    //! config whether to raise an error exactly when an op is invoked.
+    //! level 2: both device-side and user-side errors are async;
+    //! level 1: user-side errors are sync;
+    //! level 0: both are sync.
+    int m_async_level = 1;
 };
 
 } // namespace mgb::imperative::interpreter::intl
diff --git a/imperative/src/include/megbrain/imperative/interpreter.h b/imperative/src/include/megbrain/imperative/interpreter.h
index 016d20551d4c41b857837cbebd1b8781fb23db19..12de2a729a2a6ac6585a7ec148e82ce1bc64dea2 100644
--- a/imperative/src/include/megbrain/imperative/interpreter.h
+++ b/imperative/src/include/megbrain/imperative/interpreter.h
@@ -41,6 +41,7 @@ struct Interpreter {
         virtual void close() = 0;
 
         virtual void config_async_level(int level) = 0;
+        virtual int get_async_level() = 0;
     };
 
     virtual std::unique_ptr<Channel> create_channel() = 0;
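
Note: a minimal usage sketch of the Python-level API added by this patch, not part of the diff itself. It reuses the import path, tensor construction, and ops already exercised by the new test, and the comments restate the level semantics from the header comment in interpreter_impl.h; treat it as illustrative rather than authoritative.

import megengine as mge
import megengine.functional as F
from megengine.core._imperative_rt.imperative import config_async_level, get_async_level

# Level 1 (the new default): when an op's output attrs cannot be fully inferred,
# apply_op falls back to a sync, so user-side errors surface at the Python call site.
config_async_level(1)
assert get_async_level() == 1

a = mge.tensor([[1.0, 2.0], [3.0, 4.0]], dtype="float32")
b = mge.tensor([[1.0], [1.0]], dtype="float32")
print(F.matmul(a, b).numpy())

# Level 2: fully asynchronous; both device-side and user-side errors may be
# reported later than the call that caused them.
config_async_level(2)

# Levels outside [0, 2] hit the mgb_assert and surface as RuntimeError in Python:
# config_async_level(3)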