提交 a2a09ef9 编写于 作者: M Megvii Engine Team

fix(imperative): release dtr related resources when disable dtr

GitOrigin-RevId: eacfded9dec8989252de9956d16c1b1cd6f3a560
上级 2676fb73
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
import re import re
from typing import Union from typing import Union
from ..core._imperative_rt.core2 import clear_candidates as _clear_candidates
from ..core._imperative_rt.core2 import set_option as _set_option from ..core._imperative_rt.core2 import set_option as _set_option
_eviction_threshold = 0 _eviction_threshold = 0
...@@ -128,3 +129,4 @@ def disable(): ...@@ -128,3 +129,4 @@ def disable():
_set_option("enable_dtr_auto_drop", 0) _set_option("enable_dtr_auto_drop", 0)
_set_option("enable_drop", 0) _set_option("enable_drop", 0)
_set_option("record_computing_path", 0) _set_option("record_computing_path", 0)
_clear_candidates()
...@@ -1020,6 +1020,7 @@ void init_tensor(py::module m) { ...@@ -1020,6 +1020,7 @@ void init_tensor(py::module m) {
m.def("set_option", [](std::string name, size_t value) { m.def("set_option", [](std::string name, size_t value) {
interpreter_for_py->set_option(name, value); interpreter_for_py->set_option(name, value);
}); });
m.def("clear_candidates", []() { interpreter_for_py->clear_candidates(); });
m.def("get_option", m.def("get_option",
[](std::string name) { return interpreter_for_py->get_option(name); }); [](std::string name) { return interpreter_for_py->get_option(name); });
m.def("_set_drop_flag", m.def("_set_drop_flag",
......
...@@ -494,6 +494,12 @@ void ChannelImpl::set_option(std::string name, size_t value) { ...@@ -494,6 +494,12 @@ void ChannelImpl::set_option(std::string name, size_t value) {
m_buffer.enqueue(SetOption{name, value}); m_buffer.enqueue(SetOption{name, value});
} }
void ChannelImpl::clear_candidates() {
MGB_LOCK_GUARD(m_spin);
mgb_assert(check_available(), "Channel already closed");
m_dtr.candidates.clear();
}
TensorInfo* ChannelImpl::alloc() { TensorInfo* ChannelImpl::alloc() {
auto& state = get_channel_state(); auto& state = get_channel_state();
auto info = [this] { auto info = [this] {
...@@ -798,7 +804,7 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) { ...@@ -798,7 +804,7 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
i->compute_time = estimate_compute_time; i->compute_time = estimate_compute_time;
} }
} }
m_dtr.unpin(cmd.inputs); m_dtr.unpin(cmd.inputs, state);
} }
MGB_RECORD_EVENT(OpExecuteFinishEvent, apply_id, {}, reason); MGB_RECORD_EVENT(OpExecuteFinishEvent, apply_id, {}, reason);
// End profiling operator // End profiling operator
...@@ -1430,12 +1436,19 @@ void ChannelImpl::sample_on_device(CompNode device, bool force) { ...@@ -1430,12 +1436,19 @@ void ChannelImpl::sample_on_device(CompNode device, bool force) {
// Pins every tensor in `vec` by calling pin() on each element.
// The line that appears only once in this hunk is the one added by the commit:
// after pinning, the tensor is also removed from the candidate set through
// erase_candidate().
void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) { void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
for (auto i : vec) { for (auto i : vec) {
i->pin(); i->pin();
erase_candidate(i);
} }
} }
// Unpins every tensor in `vec` by calling unpin() on each element.
// New side of this hunk: the signature gains a WorkerState& so the loop can
// read state.options.dtr_evictee_minimum_size. A tensor is handed to
// insert_candidate() only when its `pinned` field is 0, size_exceeds_thd()
// accepts it for that minimum size, and its cand_index equals UINT_MAX
// (presumably the "not currently a candidate" marker -- confirm against
// TensorInfo).
void ChannelImpl::DynamicSublinear::unpin(const SmallVector<TensorInfo*>& vec) { void ChannelImpl::DynamicSublinear::unpin(
const SmallVector<TensorInfo*>& vec, WorkerState& state) {
for (auto i : vec) { for (auto i : vec) {
i->unpin(); i->unpin();
if (i->pinned == 0 &&
i->size_exceeds_thd(state.options.dtr_evictee_minimum_size) &&
i->cand_index == UINT_MAX) {
insert_candidate(i);
}
} }
} }
...@@ -1504,7 +1517,7 @@ TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor( ...@@ -1504,7 +1517,7 @@ TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor(
ti = vi; ti = vi;
} }
auto i = candidates[ti]; auto i = candidates[ti];
if (i->producer && i->ptr && !i->pinned && i->evict_type == EvictType::NONE) { if (i->producer && i->ptr && i->evict_type == EvictType::NONE) {
double neighbor_cost = estimate_neighbor_cost(i); double neighbor_cost = estimate_neighbor_cost(i);
size_t begin_ptr = size_t begin_ptr =
reinterpret_cast<size_t>(i->ptr->blob()->storage().get()); reinterpret_cast<size_t>(i->ptr->blob()->storage().get());
...@@ -1561,7 +1574,12 @@ void ChannelImpl::DynamicSublinear::insert_candidate(TensorInfo* ptr) { ...@@ -1561,7 +1574,12 @@ void ChannelImpl::DynamicSublinear::insert_candidate(TensorInfo* ptr) {
} }
void ChannelImpl::DynamicSublinear::erase_candidate(TensorInfo* ptr) { void ChannelImpl::DynamicSublinear::erase_candidate(TensorInfo* ptr) {
// some tensors may be erased already, so just skip them // close dtr will just clear candidates, so nothing to erase
if (candidates.empty()) {
ptr->cand_index = UINT_MAX;
return;
}
// some tensors may be erased already, just skip them
if (ptr->cand_index != UINT_MAX) { if (ptr->cand_index != UINT_MAX) {
std::swap(candidates[ptr->cand_index], candidates.back()); std::swap(candidates[ptr->cand_index], candidates.back());
candidates[ptr->cand_index]->cand_index = ptr->cand_index; candidates[ptr->cand_index]->cand_index = ptr->cand_index;
......
...@@ -64,6 +64,7 @@ struct ChannelImpl : Interpreter::Channel { ...@@ -64,6 +64,7 @@ struct ChannelImpl : Interpreter::Channel {
size_t get_option(std::string name) override; size_t get_option(std::string name) override;
void set_option(std::string name, size_t value) override; void set_option(std::string name, size_t value) override;
void clear_candidates() override;
void start_profile() override; void start_profile() override;
void stop_profile() override; void stop_profile() override;
...@@ -308,7 +309,7 @@ private: ...@@ -308,7 +309,7 @@ private:
/*! /*!
* \brief unpin the tensors in vec * \brief unpin the tensors in vec
*/ */
void unpin(const SmallVector<TensorInfo*>& vec); void unpin(const SmallVector<TensorInfo*>& vec, WorkerState& state);
/*! /*!
* \brief add the tensor to the candidate set * \brief add the tensor to the candidate set
......
...@@ -57,6 +57,7 @@ struct Interpreter { ...@@ -57,6 +57,7 @@ struct Interpreter {
virtual size_t get_option(std::string name) = 0; virtual size_t get_option(std::string name) = 0;
virtual void set_option(std::string name, size_t value) = 0; virtual void set_option(std::string name, size_t value) = 0;
virtual void clear_candidates() = 0;
virtual void start_profile() = 0; virtual void start_profile() = 0;
virtual void stop_profile() = 0; virtual void stop_profile() = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册