提交 0a6f4a88 编写于 作者: M Megvii Engine Team

fix(mge/dtr): fix dtr problem

GitOrigin-RevId: 2a703f9ee4ebf8667ac889a73f67c688dfebd9bc
上级 529b394f
......@@ -41,6 +41,10 @@ void BlobManagerImpl::unregister_blob(Blob* blob) {
}
void BlobManagerImpl::alloc_with_defrag(Blob* blob, size_t size) {
if (custom_allocator) {
blob->m_storage = custom_allocator(blob->m_comp_node, size);
return;
}
// try alloc
MGB_TRY { alloc_direct(blob, size); }
// if fail, try defrag, alloc again
......@@ -61,6 +65,13 @@ void BlobManagerImpl::alloc_direct(Blob* blob, size_t size) {
DeviceTensorND BlobManagerImpl::alloc_workspace_with_defrag(
CompNode cn, TensorLayout& layout) {
DeviceTensorND dev_tensor;
if (custom_allocator) {
DeviceTensorStorage storage(cn);
size_t sz = layout.dtype.size(layout.total_nr_elems());
storage.reset(cn, sz, custom_allocator(cn, sz));
dev_tensor.reset(storage, layout);
return dev_tensor;
}
MGB_TRY { return alloc_workspace(cn, layout); }
MGB_CATCH(MemAllocError&, {
mgb_log_warn("memory allocation failed for workspace; try defragmenting");
......@@ -78,6 +89,10 @@ DeviceTensorND BlobManagerImpl::alloc_workspace(CompNode cn, TensorLayout layout
return dev_tensor;
}
void BlobManagerImpl::set_allocator(allocator_t allocator) {
custom_allocator = allocator;
}
void BlobManagerImpl::defrag(const CompNode& cn) {
BlobSetWithMux* blobs_set_ptr;
{
......@@ -159,6 +174,9 @@ struct BlobManagerStub : BlobManager {
void defrag(const CompNode& cn) {
mgb_assert(0, "prohibited after global variable destruction");
};
virtual void set_allocator(allocator_t allocator) {
mgb_assert(0, "prohibited after global variable destruction");
};
};
BlobManager* BlobManager::inst() {
......
......@@ -45,6 +45,8 @@ class BlobManagerImpl final : public BlobManager {
DeviceTensorND alloc_workspace(CompNode cn, TensorLayout layout);
BlobManager::allocator_t custom_allocator;
public:
static BlobManager* inst();
......@@ -56,6 +58,8 @@ public:
void register_blob(Blob* blob) override;
void unregister_blob(Blob* blob) override;
void set_allocator(allocator_t allocator) override;
};
} // namespace imperative
......
......@@ -49,6 +49,7 @@ struct ApplyOp {
std::shared_ptr<OpDef> op;
SmallVector<TensorInfo*> inputs;
SmallVector<TensorInfo*> outputs;
SmallVector<LogicalTensorDesc> outputs_descs;
bool validated = false;
template <typename TFunctor>
......
......@@ -114,11 +114,13 @@ ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
void ChannelImpl::WorkQueue::on_async_queue_worker_thread_start() {
sys::set_thread_name("worker");
m_owner->m_worker_state.tid = std::this_thread::get_id();
OpDef::set_allocator([&](CompNode device, size_t size) {
auto custom_allocator = [&](CompNode device, size_t size) {
auto blob = Blob::make(device, size);
m_owner->alloc_tensor_with_evict(blob.get());
return blob->storage();
});
};
OpDef::set_allocator(custom_allocator);
BlobManager::inst()->set_allocator(custom_allocator);
}
// Do not use m_xxx_state directly
......@@ -353,7 +355,7 @@ void ChannelImpl::dispatch_kernel(
for (int i = 0; i < output_descs.size(); ++i) {
auto&& desc = output_descs[i];
auto info = alloc();
init(info, std::move(desc));
init(info, desc);
// make sure desc's value is consistent with h_value
if (!info->desc.value.empty()) {
info->h_value = HostTensorND::make_proxy(desc.value)
......@@ -362,9 +364,9 @@ void ChannelImpl::dispatch_kernel(
output_infos.push_back(info);
outputs->push_back(reinterpret_cast<Handle>(info));
}
ApplyOp cmd{
Profiler::next_id(), std::move(op), std::move(input_infos),
std::move(output_infos), validated};
ApplyOp cmd{Profiler::next_id(), std::move(op),
std::move(input_infos), std::move(output_infos),
std::move(output_descs), validated};
if (Profiler::is_profiling()) {
auto op_info_getter = [op = cmd.op] {
std::unordered_map<std::string, std::string> op_info;
......@@ -594,7 +596,7 @@ TensorInfo* ChannelImpl::alloc() {
return info;
}
void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc&& desc) {
void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc desc) {
m_valid_handle.insert(reinterpret_cast<Handle>(info));
MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name);
info->status = TensorInfo::Allocated;
......@@ -692,6 +694,11 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
"shape infer error, %s vs %s", dest->desc.layout.to_string().c_str(),
ptr->layout().to_string().c_str());
}
// in order to avoid performance impact,
// memory forwarding is disabled when DTR is enabled
if (state.options.enable_dtr_auto_drop) {
ptr->to_contiguous_inplace();
}
dest->desc.layout = ptr->layout();
dest->desc.comp_node = ptr->comp_node();
dest->memory = ptr->blob()->size();
......@@ -719,8 +726,9 @@ void ChannelImpl::regenerate(TensorInfo* dest) {
if (dest->evict_type == EvictType::DROP) {
auto&& path = dest->producer;
m_apply_stack.push(
{ApplyOp{path->id, path->op, path->inputs, path->outputs}, 0, dest,
"dtr"});
{ApplyOp{path->id, path->op, path->inputs, path->outputs,
path->outputs_descs},
0, dest, "dtr"});
if (!m_applying)
flush_apply_stack();
}
......@@ -812,8 +820,8 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
}
// Apply op
SmallVector<LogicalTensorDesc> output_descs;
for (auto i : cmd.outputs) {
output_descs.push_back(i->desc);
for (auto i : cmd.outputs_descs) {
output_descs.push_back(i);
}
// Here std::move is REQUIRED for removing duplicated references.
auto outputs = apply_on_physical_tensor(
......@@ -1031,6 +1039,7 @@ std::unordered_set<TensorInfo*> ChannelImpl::collect_valid_tensors() {
}
void ChannelImpl::alloc_tensor_with_evict(Blob* x) {
bool in_worker = (get_worker_tid() == std::this_thread::get_id());
auto reserve_size = [&](size_t size) {
if (!m_dtr.comp_node.valid()) {
return false;
......@@ -1043,17 +1052,21 @@ void ChannelImpl::alloc_tensor_with_evict(Blob* x) {
return true;
};
auto pre_level = set_log_level(LogLevel::NO_LOG);
reserve_size(x->size());
if (in_worker) {
reserve_size(x->size());
}
MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
MGB_CATCH(MemAllocError&, {
bool suc = false;
while (!suc) {
if (!auto_evict(1)) {
break;
if (in_worker) {
while (!suc) {
if (!auto_evict(1)) {
break;
}
MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
MGB_CATCH(MemAllocError&, { continue; });
suc = true;
}
MGB_TRY { BlobManager::inst()->alloc_direct(x, x->size()); }
MGB_CATCH(MemAllocError&, { continue; });
suc = true;
}
if (!suc) {
set_log_level(pre_level);
......@@ -1143,10 +1156,10 @@ void ChannelImpl::process_one_task(Command& icmd) {
if (!inplace && !cross_cn && !m_dtr.is_bad_op(get_name(*cmd.op))) {
TensorInfo::ComputePath::make(
cmd.id, cmd.op, cmd.inputs, cmd.outputs);
cmd.id, cmd.op, cmd.inputs, cmd.outputs, cmd.outputs_descs);
size_t detach_cnt = 0;
if (!strcmp(get_name(*cmd.op), "BatchNorm") &&
cmd.outputs.size() == 5) {
cmd.outputs.size() == 6) {
cmd.outputs[0]->detach_producer(); // detach running_mean
cmd.outputs[1]->detach_producer(); // detach running_var
for (auto input : cmd.inputs) {
......
......@@ -77,7 +77,7 @@ private:
struct State;
TensorInfo* alloc();
void init(TensorInfo*, LogicalTensorDesc&& desc);
void init(TensorInfo*, LogicalTensorDesc desc);
void free(TensorInfo*);
void real_free(TensorInfo*);
void recursive_free(TensorInfo*);
......
......@@ -91,6 +91,7 @@ struct TensorInfo {
SmallVector<TensorInfo*> inputs;
SmallVector<TensorInfo*> unique_inputs;
SmallVector<TensorInfo*> outputs;
SmallVector<LogicalTensorDesc> outputs_descs;
size_t ref_cnt() {
return outputs.size() - std::count(outputs.begin(), outputs.end(), nullptr);
......@@ -98,12 +99,14 @@ struct TensorInfo {
static ComputePath* make(
uint64_t id, std::shared_ptr<OpDef> op, SmallVector<TensorInfo*> inputs,
SmallVector<TensorInfo*> outputs) {
SmallVector<TensorInfo*> outputs,
SmallVector<LogicalTensorDesc> outputs_descs) {
auto* path = new TensorInfo::ComputePath();
path->id = id;
path->op = op;
path->inputs = inputs;
path->outputs = outputs;
path->outputs_descs = outputs_descs;
// dedup
SmallVector<TensorInfo*> unique_inputs = inputs;
std::sort(unique_inputs.begin(), unique_inputs.end());
......
......@@ -87,7 +87,7 @@ Blob::~Blob() {
}
const Blob::RawStorage& Blob::storage() {
if (!m_storage) {
if (!m_storage && m_size) {
BlobManager::inst()->alloc_with_defrag(this, m_size);
}
return m_storage;
......
......@@ -18,6 +18,8 @@ namespace imperative {
class BlobManager : public NonCopyableObj {
public:
using allocator_t =
std::function<DeviceTensorStorage::RawStorage(CompNode, size_t)>;
virtual ~BlobManager() = default;
static BlobManager* inst();
......@@ -26,6 +28,8 @@ public:
virtual void alloc_with_defrag(Blob* blob, size_t size) = 0;
virtual void set_allocator(allocator_t allocator) = 0;
virtual DeviceTensorND alloc_workspace_with_defrag(
CompNode cn, TensorLayout& layout) = 0;
......
......@@ -119,7 +119,7 @@ public:
return make_scalar(value, m_blob->comp_node());
}
BlobPtr blob() { return m_blob; }
BlobPtr& blob() { return m_blob; }
void fetch_value();
bool value_fetched();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册