diff --git a/imperative/python/megengine/dtr/dtr.py b/imperative/python/megengine/dtr/dtr.py index f19c0d2a054d4dd58399c856da08669440b27498..382bbec38c9ea23716f3174ea562290c2c43aec3 100644 --- a/imperative/python/megengine/dtr/dtr.py +++ b/imperative/python/megengine/dtr/dtr.py @@ -9,6 +9,7 @@ import re from typing import Union +from ..core._imperative_rt.core2 import clear_candidates as _clear_candidates from ..core._imperative_rt.core2 import set_option as _set_option _eviction_threshold = 0 @@ -128,3 +129,4 @@ def disable(): _set_option("enable_dtr_auto_drop", 0) _set_option("enable_drop", 0) _set_option("record_computing_path", 0) + _clear_candidates() diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 6a90f3c829e917404d28b9c9ce1bcf755d7e11d7..87f2459d5b7f3b3a683573996429302c191e47e5 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -1020,6 +1020,7 @@ void init_tensor(py::module m) { m.def("set_option", [](std::string name, size_t value) { interpreter_for_py->set_option(name, value); }); + m.def("clear_candidates", []() { interpreter_for_py->clear_candidates(); }); m.def("get_option", [](std::string name) { return interpreter_for_py->get_option(name); }); m.def("_set_drop_flag", diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index f3f98dc3bea52adb1aa93ff382096597089279b3..d80f23e7bcd9f6188ad2cc72c2176a40c630b963 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -494,6 +494,12 @@ void ChannelImpl::set_option(std::string name, size_t value) { m_buffer.enqueue(SetOption{name, value}); } +void ChannelImpl::clear_candidates() { + MGB_LOCK_GUARD(m_spin); + mgb_assert(check_available(), "Channel already closed"); + m_dtr.candidates.clear(); +} + TensorInfo* ChannelImpl::alloc() { auto& state = get_channel_state(); auto info = [this] { @@ -798,7 +804,7 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) { i->compute_time = estimate_compute_time; } } - m_dtr.unpin(cmd.inputs); + m_dtr.unpin(cmd.inputs, state); } MGB_RECORD_EVENT(OpExecuteFinishEvent, apply_id, {}, reason); // End profiling operator @@ -1430,12 +1436,19 @@ void ChannelImpl::sample_on_device(CompNode device, bool force) { void ChannelImpl::DynamicSublinear::pin(const SmallVector& vec) { for (auto i : vec) { i->pin(); + erase_candidate(i); } } -void ChannelImpl::DynamicSublinear::unpin(const SmallVector& vec) { +void ChannelImpl::DynamicSublinear::unpin( + const SmallVector& vec, WorkerState& state) { for (auto i : vec) { i->unpin(); + if (i->pinned == 0 && + i->size_exceeds_thd(state.options.dtr_evictee_minimum_size) && + i->cand_index == UINT_MAX) { + insert_candidate(i); + } } } @@ -1504,7 +1517,7 @@ TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor( ti = vi; } auto i = candidates[ti]; - if (i->producer && i->ptr && !i->pinned && i->evict_type == EvictType::NONE) { + if (i->producer && i->ptr && i->evict_type == EvictType::NONE) { double neighbor_cost = estimate_neighbor_cost(i); size_t begin_ptr = reinterpret_cast(i->ptr->blob()->storage().get()); @@ -1561,7 +1574,12 @@ void ChannelImpl::DynamicSublinear::insert_candidate(TensorInfo* ptr) { } void ChannelImpl::DynamicSublinear::erase_candidate(TensorInfo* ptr) { - // some tensors may be erased already, so just skip them + // close dtr will just clear candidates, so nothing to erase + if (candidates.empty()) { + ptr->cand_index = UINT_MAX; + return; + } + // some tensors may be erased already, just skip them if (ptr->cand_index != UINT_MAX) { std::swap(candidates[ptr->cand_index], candidates.back()); candidates[ptr->cand_index]->cand_index = ptr->cand_index; diff --git a/imperative/src/impl/interpreter/interpreter_impl.h b/imperative/src/impl/interpreter/interpreter_impl.h index a786130b18028e2e0d9ff8abf1ca49d0c272001d..0b480205d50db3a196a0135000dfbd5b75a41f8d 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.h +++ b/imperative/src/impl/interpreter/interpreter_impl.h @@ -64,6 +64,7 @@ struct ChannelImpl : Interpreter::Channel { size_t get_option(std::string name) override; void set_option(std::string name, size_t value) override; + void clear_candidates() override; void start_profile() override; void stop_profile() override; @@ -308,7 +309,7 @@ private: /*! * \brief unpin the tensors in vec */ - void unpin(const SmallVector& vec); + void unpin(const SmallVector& vec, WorkerState& state); /*! * \brief add the tensor to the candidate set diff --git a/imperative/src/include/megbrain/imperative/interpreter.h b/imperative/src/include/megbrain/imperative/interpreter.h index 88a28bf782e9e60d1b1bbb4c72fe8c3f5e33c26a..86fe691a454281b8a82af0625726f5af9120225a 100644 --- a/imperative/src/include/megbrain/imperative/interpreter.h +++ b/imperative/src/include/megbrain/imperative/interpreter.h @@ -57,6 +57,7 @@ struct Interpreter { virtual size_t get_option(std::string name) = 0; virtual void set_option(std::string name, size_t value) = 0; + virtual void clear_candidates() = 0; virtual void start_profile() = 0; virtual void stop_profile() = 0;