diff --git a/imperative/python/test/integration/test_dtr.py b/imperative/python/test/integration/test_dtr.py index c3ea68d1e5f460b31d49ed2c0f5b5343a8e09ced..592b9273899fe340a561c4c6c6acac6c253b4693 100644 --- a/imperative/python/test/integration/test_dtr.py +++ b/imperative/python/test/integration/test_dtr.py @@ -90,6 +90,18 @@ class ResNet(M.Module): return out +def run_dtr_drop_copy_dev_tensor(): + mge.dtr.evictee_minimum_size = 128 + mge.dtr.enable() + x = F.ones((10, 100)) + x._drop() + x[...] = mge.tensor(x, no_cache=True) + x.numpy() + mge.dtr.evictee_minimum_size = 1024 ** 2 + mge.dtr.disable() + mge._exit(0) + + def run_dtr_resnet1202(): batch_size = 6 resnet1202 = ResNet(BasicBlock, [200, 200, 200]) @@ -135,3 +147,12 @@ def test_dtr_resnet1202(): p.start() p.join() assert p.exitcode == 0 + + +@pytest.mark.require_ngpu(1) +@pytest.mark.isolated_distributed +def test_dtr_drop_copy_dev_tensor(): + p = mp.Process(target=run_dtr_drop_copy_dev_tensor) + p.start() + p.join() + assert p.exitcode == 0 diff --git a/imperative/src/impl/interpreter/commands.h b/imperative/src/impl/interpreter/commands.h index ce05197f40b3ff09b8987df0d0b7d2bb1ea048ed..59c64af40b9d6d75ce3284d5fbe350f9eed47c03 100644 --- a/imperative/src/impl/interpreter/commands.h +++ b/imperative/src/impl/interpreter/commands.h @@ -136,9 +136,31 @@ struct PopScope { const char* get_name() const { return "PopScope"; } }; +struct StartRegen { + TensorInfo* dest; + + template + void get_props(TFunctor&& functor) const { + functor("dest", dest); + } + + const char* get_name() const { return "StartRegen"; } +}; + +struct StopRegen { + TensorInfo* dest; + + template + void get_props(TFunctor&& functor) const { + functor("dest", dest); + } + + const char* get_name() const { return "StopRegen"; } +}; + using CommandData = std::variant< Put, ApplyOp, Del, GetValue, Drop, SetOption, StartProfile, StopProfile, - PushScope, PopScope>; + PushScope, PopScope, StartRegen, StopRegen>; struct Command { uint64_t id; diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index b871a364ff57745d349ba39aaad81ff9752ca4ae..e71187f63ece08ec016e5fc0d71bef8c9404c718 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -1002,8 +1002,11 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) { m_waitee_id = Profiler::next_id(); MGB_RECORD_EVENT(TensorWaitPropEvent, info->id, m_waitee_id, prop); bool require_host = prop == TensorProp::HostValue; + bool require_dev = prop == TensorProp::DevValue; auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); }; + auto dev_available = [&] { return info->ptr; }; bool wait_host = false; + bool wait_regen = false; if (require_host && !host_available()) { // avoid dead lock lock.unlock(); @@ -1020,16 +1023,52 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) { lock.lock(); wait_host = true; } - m_cv.wait(lock, [&]() { - check_worker_exc_unsafe(); - return require_host ? host_available() : static_cast(info->ptr); - }); + if (require_dev && !dev_available()) { + lock.unlock(); + if (Profiler::is_profiling()) { + m_worker.add_task( + {Profiler::next_id(), StartRegen{info}, + get_channel_state().stack_manager.dump()}); + } else { + m_worker.add_task({ + Profiler::next_id(), + StartRegen{info}, + }); + } + lock.lock(); + wait_regen = true; + } + if (require_dev) { + m_cv.wait(lock, [&]() { + check_worker_exc_unsafe(); + return dev_available(); + }); + } else { + m_cv.wait(lock, [&]() { + check_worker_exc_unsafe(); + return require_host ? host_available() : static_cast(info->ptr); + }); + } MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop); m_waitee = nullptr; if (wait_host) { auto err = info->ptr->comp_node().check_async_error(); mgb_assert(!err, "%s", err->what()); } + if (wait_regen) { + lock.unlock(); + if (Profiler::is_profiling()) { + m_worker.add_task( + {Profiler::next_id(), StopRegen{info}, + get_channel_state().stack_manager.dump()}); + } else { + m_worker.add_task({ + Profiler::next_id(), + StopRegen{info}, + }); + } + lock.lock(); + } return info->ptr; } @@ -1254,6 +1293,17 @@ void ChannelImpl::process_one_task(Command& icmd) { MGB_RECORD_EVENT(ScopeEvent, cmd.scope_name); } else if constexpr (std::is_same_v) { MGB_RECORD_EVENT(ScopeFinishEvent, cmd.scope_name); + } else if constexpr (std::is_same_v) { + if (cmd.dest->invalid) + return; + cmd.dest->pin(); + if (!cmd.dest->ptr && cmd.dest->evict_type != EvictType::NONE) { + regenerate(cmd.dest); + } + MGB_LOCK_GUARD(m_mutex); + notify_tensor_unsafe(cmd.dest); + } else if constexpr (std::is_same_v) { + cmd.dest->unpin(); } else { static_assert(!std::is_same_v); }