提交 bcf69d8f 编写于 作者: M Megvii Engine Team

refactor(imperative): correctly apply sqrt sampling for dtr

GitOrigin-RevId: dabd36551765af1d2646789ae9ed57d8eac4a936
上级 48100781
......@@ -646,6 +646,10 @@ void ChannelImpl::release_tensor(TensorInfo* dest) {
MGB_RECORD_EVENT(TensorReleaseEvent, dest->id);
MGB_LOCK_GUARD(m_mutex);
dest->ptr.reset();
auto& state = get_worker_state();
if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
m_dtr.erase_candidate(dest);
}
}
void ChannelImpl::regenerate(TensorInfo* dest) {
......@@ -891,8 +895,7 @@ bool ChannelImpl::auto_evict(size_t force_num) {
force_num > 0) {
MGB_RECORD_EVENT(AutoEvictEvent);
sample_on_device(m_dtr.comp_node, false);
auto best = m_dtr.find_best_tensor(
state.options.enable_dtr_sqrt_sampling && !force_num);
auto best = m_dtr.find_best_tensor(state.options.enable_dtr_sqrt_sampling);
if (!best) {
MGB_RECORD_EVENT(AutoEvictFinishEvent);
break;
......@@ -1300,7 +1303,6 @@ void ChannelImpl::CommandBuffer::enqueue(CommandData cmd) {
if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) {
return;
}
// mgb_log_debug("%s Enqueued", to_string(cmd).c_str());
m_commands.push_back(
{Profiler::next_id(), std::move(cmd), state.stack_manager.dump()});
auto flush_pos = flush_pos_for(m_commands.back());
......@@ -1365,7 +1367,6 @@ bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) {
if (apply_iter == end || find_last_usage(dest, {apply_iter + 1, end}) != end) {
return false;
}
// mgb_log_debug("%s Fused", to_string(Command{cmd}).c_str());
std::get<ApplyOp>(apply_iter->data).dels.push_back(dest);
return true;
}
......@@ -1538,16 +1539,26 @@ double ChannelImpl::DynamicSublinear::estimate_neighbor_cost(TensorInfo* ptr) {
TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor(
bool enable_dtr_sqrt_sampling = false) {
if (candidates.empty())
return nullptr;
double min_msps = -1;
TensorInfo* best = nullptr;
size_t sz = 1;
if (enable_dtr_sqrt_sampling) {
while (sz * sz <= candidates.size())
sz++;
sz--;
} else {
sz = candidates.size();
}
for (auto i : candidates) {
size_t ti = rand() % sz;
for (size_t vi = 0; vi < sz; vi++) {
if (!enable_dtr_sqrt_sampling) {
ti = vi;
}
auto i = candidates[ti];
if (i->producer && i->ptr && !i->pinned && i->evict_type == EvictType::NONE) {
double neighbor_cost = estimate_neighbor_cost(i);
size_t begin_ptr =
......@@ -1562,8 +1573,11 @@ TensorInfo* ChannelImpl::DynamicSublinear::find_best_tensor(
best = i;
}
}
if (--sz == 0)
break;
if (enable_dtr_sqrt_sampling) {
ti += rand() % sz;
if (ti > candidates.size())
break;
}
}
return best;
}
......@@ -1590,14 +1604,25 @@ std::shared_ptr<DsuNode> ChannelImpl::DynamicSublinear::find_father(
}
// Registers `ptr` as an eviction candidate: appends it to `candidates` and
// stores its position in `ptr->cand_index`.
// NOTE(review): this text looks like a rendered diff with removed and added
// lines interleaved. The `candidates.insert(ptr)` call directly below appears
// to be the pre-change line and `candidates.push_back(ptr)` its replacement --
// confirm against the real source before treating both as live code.
void ChannelImpl::DynamicSublinear::insert_candidate(TensorInfo* ptr) {
candidates.insert(ptr);
// tensor to be inserted must be brand new
// (i.e. its cand_index still holds UINT_MAX, the value erase_candidate
// writes back when a tensor leaves the list)
// NOTE(review): `%lu` is used to print `cand_index`; if that field is a
// size_t, `%zu` is presumably the matching specifier -- confirm that
// mgb_assert takes printf-style format strings.
mgb_assert(
ptr->cand_index == UINT_MAX, "got wrong candidate index : %lu",
ptr->cand_index);
// size() before push_back is the index the new element will occupy
ptr->cand_index = candidates.size();
candidates.push_back(ptr);
// adopt the tensor's comp node if none has been recorded yet
if (!comp_node.valid()) {
comp_node = ptr->ptr->comp_node();
}
}
// Removes `ptr` from `candidates` (if it is currently listed) by swapping it
// with the last element and popping the back, then resets `ptr->cand_index`
// to UINT_MAX.
// NOTE(review): this text looks like a rendered diff with removed and added
// lines interleaved. The `candidates.erase(ptr)` call directly below appears
// to be the pre-change line, superseded by the swap-and-pop code -- confirm
// against the real source before treating both as live code.
void ChannelImpl::DynamicSublinear::erase_candidate(TensorInfo* ptr) {
candidates.erase(ptr);
// some tensors may be erased already, so just skip them
if (ptr->cand_index != UINT_MAX) {
// move the last candidate into the slot being vacated
std::swap(candidates[ptr->cand_index], candidates.back());
// the element now sitting in that slot takes over ptr's old index
candidates[ptr->cand_index]->cand_index = ptr->cand_index;
// ptr itself is now at the back; drop it
candidates.pop_back();
// mark ptr as no longer being a candidate
ptr->cand_index = UINT_MAX;
}
}
void ChannelImpl::DynamicSublinear::update_used_time(TensorInfo* ptr) {
......
......@@ -335,7 +335,7 @@ private:
CompNode comp_node;
//! store all tensors that may be evicted
std::unordered_set<TensorInfo*> candidates;
SmallVector<TensorInfo*> candidates;
bool is_bad_op(std::string op_name) {
return std::find(op_blacklist.begin(), op_blacklist.end(), op_name) !=
......
......@@ -170,6 +170,9 @@ struct TensorInfo {
bool size_exceeds_thd(size_t thd) { return memory > thd; }
SmallVector<ComputePath*> users;
// UINT_MAX as a magic default value
size_t cand_index = UINT_MAX;
};
} // namespace interpreter::intl
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册