Commit 8a73193c authored by Megvii Engine Team, committed by huangxinda

feat(dtr): remove eviction threshold

GitOrigin-RevId: 35c2014bf3241d864db21b6e431158535442a672
Parent 69d1fd0f
@@ -114,6 +114,7 @@ def disable():
     r"""
     Stop recording computing path of tensors and performing DTR policy.
     """
+    _set_defrag(False)
     _set_option("enable_dtr_auto_drop", 0)
     _set_option("enable_drop", 0)
     _set_option("record_computing_path", 0)
@@ -156,6 +156,9 @@ void BlobManagerImpl::set_enable(bool flag) {
 }
 
 struct BlobManagerStub : BlobManager {
+    void alloc_direct(Blob* blob, size_t size) {
+        mgb_assert(0, "prohibited after global variable destruction");
+    };
     void alloc_with_defrag(Blob* blob, size_t size) {
         mgb_assert(0, "prohibited after global variable destruction");
    };
...
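The stub gains an alloc_direct override to match the widened BlobManager interface. For context, the stub pattern here guards against use-after-destruction of the global manager: once statics are torn down, a late call fails loudly instead of touching freed state. A self-contained toy sketch of the idea (names are illustrative, not the megbrain implementation):

#include <cstddef>
#include <cstdio>
#include <cstdlib>

// Toy sketch of the stub pattern: after the real manager is destroyed during
// static teardown, the singleton pointer falls back to a stub whose methods
// abort with a clear message instead of corrupting freed memory.
struct Manager {
    virtual void alloc_direct(size_t size) = 0;
    virtual ~Manager() = default;
};

struct StubManager final : Manager {
    void alloc_direct(size_t) override {
        std::fprintf(stderr, "prohibited after global variable destruction\n");
        std::abort();
    }
};

static StubManager g_stub;          // declared first, so destroyed last
static Manager* g_inst = &g_stub;

struct RealManager final : Manager {
    RealManager() { g_inst = this; }
    ~RealManager() { g_inst = &g_stub; }  // hand over to the stub at teardown
    void alloc_direct(size_t) override { /* real allocation here */ }
};

static RealManager g_real;  // constructed after g_stub, destroyed before it

Manager* manager_inst() { return g_inst; }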
@@ -43,7 +43,7 @@ class BlobManagerImpl final: public BlobManager {
     void defrag(const CompNode& cn) override;
 
-    void alloc_direct(Blob* blob, size_t size);
+    void alloc_direct(Blob* blob, size_t size) override;
 
     DeviceTensorND alloc_workspace(CompNode cn, TensorLayout layout);
...
@@ -20,6 +20,7 @@
 #include "megbrain/imperative/ops/opr_attr.h"
 #include "megbrain/imperative/utils/to_string.h"
 
+#include "../blob_manager_impl.h"
 #include "../event_pool.h"
 #include "../op_trait.h"
@@ -629,8 +630,9 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
         tensor_inputs.push_back(i->ptr);
         input_memory_desc.push_back(i->mem_desc);
     }
-    // SmallVector<MemoryDesc> outputs_mem_desc;
-    // SmallVector<TensorPtr> tensor_outputs, workspaces;
+    if (state.options.enable_dtr_auto_drop && state.options.dtr_eviction_threshold > 0) {
+        auto_evict(0);
+    }
     auto [outputs_mem_desc, tensor_outputs, workspaces] = init_output_and_workspace(*cmd.op, tensor_inputs, input_memory_desc);
     if (outputs_mem_desc.size()) {
         for (size_t i = 0;i < outputs_mem_desc.size();i ++) {
@@ -682,9 +684,6 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd) {
     for (auto&& [device, kernel_id]: kernels) {
         RECORD_EVENT(KernelExecuteEvent, apply_id, kernel_id, Timer::record_event(device));
     }
-    if (state.options.enable_dtr_auto_drop && state.options.dtr_eviction_threshold > 0) {
-        auto_evict();
-    }
     // Apply op
     // Here std::move is REQUIRED for removing duplicated references.
     if (outputs_mem_desc.size()) {
@@ -752,29 +751,26 @@ void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
     }
 }
 
-void ChannelImpl::auto_evict() {
+bool ChannelImpl::auto_evict(size_t force_num=0) {
     auto& state = get_worker_state();
     if (!m_dtr.comp_node.valid()) {
-        return;
+        return false;
     }
     size_t current_memory = m_dtr.comp_node.get_used_memory();
-    while (current_memory > state.options.dtr_eviction_threshold) {
+    size_t flag = false;
+    while ((state.options.dtr_eviction_threshold > 0 && current_memory > state.options.dtr_eviction_threshold) || force_num > 0) {
         RECORD_EVENT(AutoEvictEvent);
         sample_on_device(m_dtr.comp_node, false);
         auto best = m_dtr.find_best_tensor();
         if (!best) {
-            if (!m_dtr.warn_printed) {
-                m_dtr.warn_printed = true;
-                mgb_log_warn("No tensors on %s can be evicted automatically "
-                             "when memory usage is %.0lfMB. Maybe memory "
-                             "budget is too small.",
-                             m_dtr.comp_node.to_string().c_str(),
-                             current_memory / 1024.0 / 1024.0);
-            }
             break;
         }
         if (best->ptr.unique() && best->ptr->blob().unique()) {
             current_memory -= best->memory;
+            if (force_num > 0) {
+                force_num --;
+            }
+            flag = true;
         }
         do_drop(best);
         if (best->evict_type == EvictType::DROP) {
@@ -783,6 +779,7 @@ void ChannelImpl::auto_evict() {
         sample_on_device(m_dtr.comp_node, false);
         RECORD_EVENT(AutoEvictFinishEvent);
     }
+    return flag;
 }
 
 void ChannelImpl::detach_users(TensorInfo* dest) {
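auto_evict now serves two callers: with force_num == 0 it evicts only while usage exceeds the (now optional) threshold, while with force_num > 0 it keeps going until that many tensors have actually been freed, and its bool return tells the caller whether any memory was reclaimed. A self-contained toy sketch of this control flow, with hypothetical stand-ins for find_best_tensor and do_drop:

#include <cstddef>
#include <vector>

// Toy model of the dual-mode eviction loop (hypothetical types; the real
// loop consults the DTR candidate set via find_best_tensor()).
struct ToyTensor {
    size_t bytes;
    bool unique;  // memory is only reclaimed when the last reference drops
};

bool auto_evict_sketch(std::vector<ToyTensor>& evictable, size_t& used_memory,
                       size_t threshold, size_t force_num) {
    bool freed_any = false;
    // Threshold mode (force_num == 0) stops once usage is under budget;
    // force mode keeps evicting until force_num tensors were really freed.
    while ((threshold > 0 && used_memory > threshold) || force_num > 0) {
        if (evictable.empty())
            break;  // nothing left to evict: report failure via return value
        ToyTensor best = evictable.back();  // stand-in for find_best_tensor()
        evictable.pop_back();               // stand-in for do_drop(best)
        if (best.unique) {
            used_memory -= best.bytes;
            if (force_num > 0)
                force_num--;
            freed_any = true;
        }
    }
    return freed_any;
}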
@@ -859,6 +856,41 @@ std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
     return valid_tensors;
 }
 
+void ChannelImpl::alloc_tensor_with_evict(TensorPtr x) {
+    auto reserve_size = [&](size_t size) {
+        if (!m_dtr.comp_node.valid()) {
+            return false;
+        }
+        while (size > m_dtr.comp_node.get_max_block_size_available()) {
+            bool evict_suc = auto_evict(1);
+            if (!evict_suc) return false;
+        }
+        return true;
+    };
+    auto pre_level = set_log_level(LogLevel::NO_LOG);
+    reserve_size(x->blob()->size());
+    MGB_TRY { BlobManager::inst()->alloc_direct(x->blob().get(), x->blob()->size()); }
+    MGB_CATCH(MemAllocError&, {
+        bool suc = false;
+        while (!suc) {
+            if (!auto_evict(1)) {
+                break;
+            }
+            MGB_TRY { BlobManager::inst()->alloc_direct(x->blob().get(), x->blob()->size()); }
+            MGB_CATCH(MemAllocError&, { continue; });
+            suc = true;
+        }
+        if (!suc) {
+            set_log_level(pre_level);
+            mgb_log_warn("reallocating all cuda memory to alleviate fragmentation, the performance may be affected");
+            set_log_level(LogLevel::NO_LOG);
+            BlobManager::inst()->defrag(x->blob()->comp_node());
+            BlobManager::inst()->alloc_direct(x->blob().get(), x->blob()->size());
+        }
+    });
+    set_log_level(pre_level);
+}
+
 std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPtr>> ChannelImpl::init_output_and_workspace(
         const OpDef& def,
         SmallVector<TensorPtr> inputs,
@@ -876,11 +908,15 @@ std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPt
             desc.id->id = ++ m_storage_id;
         }
     }
+    auto& state = get_worker_state();
     auto alloc_storage = [&](SmallVector<MemoryDesc>& desc) {
         SmallVector<TensorPtr> tensors;
         for (size_t i = 0; i < desc.size(); i ++) {
             if (desc[i].id->is_sys_alloc()) {
                 tensors.push_back(Tensor::make(desc[i].layout, desc[i].cn));
+                if (!desc[i].layout.is_empty() && state.options.enable_dtr_auto_drop) {
+                    alloc_tensor_with_evict(tensors.back());
+                }
             } else if (desc[i].id->is_from_other()) {
                 for (size_t j = 0; j < inputs_mem_desc.size();j ++) {
                     if (inputs_mem_desc[j].id->desc == desc[i].id->desc) {
...
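The new alloc_tensor_with_evict is a three-stage fallback: first reserve a large-enough free block by evicting one tensor at a time, then on MemAllocError keep evicting and retrying the allocation, and only when nothing more can be evicted fall back to a full defragmentation. A minimal sketch of that ladder with hypothetical helpers (try_alloc, evict_one, defrag_all are stand-ins, not MegEngine API):

#include <cstddef>

// Hypothetical stand-ins for the allocator hooks used below.
bool try_alloc(size_t size);  // attempt a direct allocation, false on OOM
bool evict_one();             // evict one DTR candidate, false if none left
void defrag_all();            // last resort: compact all free blocks

// Sketch of the allocate / evict-and-retry / defrag ladder implemented by
// alloc_tensor_with_evict.
bool alloc_with_fallback(size_t size) {
    if (try_alloc(size))
        return true;
    while (evict_one()) {  // free one tensor per round, then retry
        if (try_alloc(size))
            return true;
    }
    defrag_all();          // nothing left to evict: defragment instead
    return try_alloc(size);
}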
@@ -403,19 +403,19 @@ private:
         //! store all tensors that may be evicted
         std::unordered_set<TensorInfo*> candidates;
 
-        //! whether the warning message has been printed
-        bool warn_printed = false;
-
         bool is_bad_op(std::string op_name) {
             return std::find(op_blacklist.begin(), op_blacklist.end(), op_name) != op_blacklist.end();
         }
 
         std::vector<std::string> op_blacklist = {"CollectiveComm", "InplaceAdd",
-                              "ParamPackSplit", "ParamPackConcat", "GaussianRNG"};
+                              "ParamPackSplit", "ParamPackConcat", "GaussianRNG", "UniformRNG",
+                              "GammaRNG", "PermutationRNG", "PoissonRNG", "BetaRNG"};
     } m_dtr;
 
     //! automatically evict an optimal tensor
-    void auto_evict();
+    bool auto_evict(size_t);
+    void alloc_tensor_with_evict(TensorPtr);
 
     // assert thread id when call get_xxx_state to avoid misuse
     ChannelState& get_channel_state();
...
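The blacklist now covers all the RNG ops, not just GaussianRNG. The reason such ops cannot participate in DTR is that recomputing a dropped random tensor draws fresh numbers rather than reproducing the evicted values, silently changing results. A tiny illustration of the hazard with a plain C++ generator:

#include <cassert>
#include <random>

// Why RNG ops are excluded from eviction: replaying the op does not replay
// the generator state, so the "recomputed" tensor differs from the original.
int main() {
    std::mt19937 gen(42);
    std::uniform_real_distribution<double> dist(0.0, 1.0);
    double original = dist(gen);    // value produced the first time
    double recomputed = dist(gen);  // naive recompute: state has moved on
    assert(original != recomputed); // results would silently diverge
    return 0;
}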
@@ -22,6 +22,8 @@ public:
     static BlobManager* inst();
 
+    virtual void alloc_direct(Blob* blob, size_t size) = 0;
+
     virtual void alloc_with_defrag(Blob* blob, size_t size) = 0;
 
     virtual DeviceTensorND alloc_workspace_with_defrag(CompNode cn, TensorLayout layout) = 0;
...
@@ -267,10 +267,14 @@ public:
     }
 #if !MGB_BUILD_SLIM_SERVING
-    std::pair<size_t, size_t> get_free_left_and_right(size_t begin_ptr,
-                                                      size_t end_ptr) override {
+    std::pair<size_t, size_t> get_free_left_and_right(size_t begin_ptr, size_t end_ptr) override {
         return m_mem_alloc->get_free_left_and_right(begin_ptr, end_ptr);
     }
+
+    size_t get_max_block_size_available() {
+        activate();
+        return m_mem_alloc->get_max_block_size_available();
+    }
 #endif
 
     Locator locator() override { return m_locator; }
...
@@ -40,6 +40,19 @@ std::pair<size_t, size_t> MemAllocImplHelper::get_free_left_and_right(size_t beg
     }
     return {left_free, right_free};
 }
+
+size_t MemAllocImplHelper::get_max_block_size_available_unsafe() {
+    if (!m_free_blk_size.size()) {
+        return 0;
+    } else {
+        return m_free_blk_size.rbegin()->first.size;
+    }
+}
+
+size_t MemAllocImplHelper::get_max_block_size_available() {
+    MGB_LOCK_GUARD(m_mutex);
+    return get_max_block_size_available_unsafe();
+}
 #endif
 
 MemAllocImplHelper::MemAddr MemAllocImplHelper::do_alloc(
...
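get_max_block_size_available_unsafe works because m_free_blk_size keeps free blocks ordered by size, so the largest block is simply the last entry of the map; the locking wrapper reuses the allocator's existing mutex. A minimal illustration of that invariant (simplified: the real key also carries the block's address to disambiguate equal sizes):

#include <cstddef>
#include <iostream>
#include <map>

int main() {
    // Free list keyed by block size, mimicking m_free_blk_size (simplified:
    // the real key type pairs the size with the block address).
    std::map<size_t, const char*> free_blk_size{
            {256, "block A"}, {4096, "block C"}, {1024, "block B"}};
    // The map is size-ordered, so the maximum is the last element; an empty
    // map means no free block is available and the query returns 0.
    size_t max_block =
            free_blk_size.empty() ? 0 : free_blk_size.rbegin()->first;
    std::cout << "max free block: " << max_block << " bytes\n";  // 4096
    return 0;
}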
@@ -116,6 +116,8 @@ class MemAllocImplHelper: virtual public MemAllocBase {
     FreeMemStat get_free_memory_self_unsafe();
 
 #if !MGB_BUILD_SLIM_SERVING
+    size_t get_max_block_size_available_unsafe();
+
     std::pair<size_t, size_t> get_free_left_and_right(size_t begin_ptr, size_t end_ptr) override;
 #endif
 
@@ -123,6 +125,11 @@ class MemAllocImplHelper: virtual public MemAllocBase {
     void print_memory_state() override;
     FreeMemStat get_free_memory() override final;
+
+#if !MGB_BUILD_SLIM_SERVING
+    size_t get_max_block_size_available() override final;
+#endif
 };
...
@@ -359,6 +359,10 @@ class CompNode {
     size_t get_used_memory() const {
         return m_impl->get_used_memory();
     }
+
+    size_t get_max_block_size_available() const {
+        return m_impl->get_max_block_size_available();
+    }
 #endif
 
     //! change to another stream on the same memory node
@@ -545,6 +549,9 @@ class CompNode {
         virtual size_t get_used_memory() {
             return 0;
         }
+        virtual size_t get_max_block_size_available() {
+            return 0;
+        }
 #endif
         virtual Locator locator() = 0;
...
@@ -141,6 +141,10 @@ class MemAllocBase {
     virtual std::pair<size_t, size_t> get_free_left_and_right(size_t begin_ptr, size_t end_ptr) {
         return {0, 0};
     }
+
+    virtual size_t get_max_block_size_available() {
+        return 0;
+    }
 #endif
 
     virtual ~MemAllocBase() = default;
...