From 221ec38a48906ad76f0d12c090d2229ea5da7329 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 18 Mar 2021 13:40:35 +0800
Subject: [PATCH] feat(imperative): reduce profiler overhead

GitOrigin-RevId: dded9d9391a49883f1c1035c1aa10f8174867367
---
 .../src/impl/interpreter/interpreter_impl.cpp | 133 ++++++++++++------
 .../include/megbrain/imperative/profiler.h    |   3 +-
 2 files changed, 89 insertions(+), 47 deletions(-)

diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp
index 133155165..dd8e7d8ea 100644
--- a/imperative/src/impl/interpreter/interpreter_impl.cpp
+++ b/imperative/src/impl/interpreter/interpreter_impl.cpp
@@ -52,7 +52,9 @@ Handle ChannelImpl::put(const DeviceTensorND& data) {
     info->desc.layout = data.layout();
     info->desc.comp_node = data.comp_node();
     info->ptr = Tensor::make(data);
-    m_channel_state.profiler->record_host(info->id, info->desc.layout, info->desc.comp_node);
+    if (m_channel_state.profiler->is_profiling()) {
+        m_channel_state.profiler->record_host(info->id, info->desc.layout, info->desc.comp_node);
+    }
     return info;
 }
 
@@ -147,8 +149,9 @@ void ChannelImpl::dispatch_default_cpu(
         return tid;
     };
     OpEvent event_data = {++m_last_id, op, tinfo_to_tid(input_infos), {}};
-
-    m_channel_state.profiler->record_host(event_data);
+    if (m_channel_state.profiler->is_profiling()) {
+        m_channel_state.profiler->record_host(event_data);
+    }
 
     OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);
 
@@ -169,8 +172,9 @@ void ChannelImpl::dispatch_default_cpu(
     }
 
     event_data.outputs = tinfo_to_tid(output_infos);
-
-    m_channel_state.profiler->record_host(event_data);
+    if (m_channel_state.profiler->is_profiling()) {
+        m_channel_state.profiler->record_host(event_data);
+    }
 }
 
 void ChannelImpl::dispatch_kernel(
@@ -267,13 +271,17 @@ HostTensorND ChannelImpl::get_value(Handle handle) {
         m_waitee = info;
         regenerate(info);
         m_buffer.enqueue(GetValue{info});
-        m_channel_state.profiler->record_host(info->id, TensorInfo::HostValue);
+        if (m_channel_state.profiler->is_profiling()) {
+            m_channel_state.profiler->record_host(info->id, TensorInfo::HostValue);
+        }
         m_cv.wait(lock, [&]() {
             check_worker_exc_unsafe();
             tensor_ptr = info->ptr;
             return value_fetched();
         });
-        m_channel_state.profiler->record_host(info->id, TensorInfo::HostValue);
+        if (m_channel_state.profiler->is_profiling()) {
+            m_channel_state.profiler->record_host(info->id, TensorInfo::HostValue);
+        }
         m_waitee = nullptr;
     }
     return tensor_ptr->get_value();
@@ -290,12 +298,16 @@ TensorShape ChannelImpl::get_shape(Handle handle) {
     mgb_assert(!m_waitee);
     m_waitee = info;
     m_buffer.flush();
-    m_channel_state.profiler->record_host(info->id, TensorInfo::Shape);
+    if (m_channel_state.profiler->is_profiling()) {
+        m_channel_state.profiler->record_host(info->id, TensorInfo::Shape);
+    }
     m_cv.wait(lock, [&]() {
         check_worker_exc_unsafe();
        return static_cast<bool>(info->ptr);
    });
-    m_channel_state.profiler->record_host(info->id, TensorInfo::Shape);
+    if (m_channel_state.profiler->is_profiling()) {
+        m_channel_state.profiler->record_host(info->id, TensorInfo::Shape);
+    }
     m_waitee = nullptr;
     TensorShape ret = info->ptr->layout();
     mgb_assert(ret.ndim != 0);
@@ -306,7 +318,9 @@ DType ChannelImpl::get_dtype(Handle handle) {
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
-    m_channel_state.profiler->record_host(info->id, TensorInfo::DType);
+    if (m_channel_state.profiler->is_profiling()) {
+        m_channel_state.profiler->record_host(info->id, TensorInfo::DType);
+    }
     auto ret = info->desc.layout.dtype;
     mgb_assert(ret.valid());
     return ret;
@@ -316,7 +330,9 @@ CompNode ChannelImpl::get_device(Handle handle) {
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
               "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
-    m_channel_state.profiler->record_host(info->id, TensorInfo::Device);
+    if (m_channel_state.profiler->is_profiling()) {
+        m_channel_state.profiler->record_host(info->id, TensorInfo::Device);
+    }
     auto ret = info->desc.comp_node;
     mgb_assert(ret.valid());
     return ret;
@@ -331,22 +347,30 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
     m_waitee = info;
     regenerate(info);
     m_buffer.flush();
-    m_channel_state.profiler->record_host(info->id, TensorInfo::DevValue);
+    if (m_channel_state.profiler->is_profiling()) {
+        m_channel_state.profiler->record_host(info->id, TensorInfo::DevValue);
+    }
     m_cv.wait(lock, [&]() {
         check_worker_exc_unsafe();
        return static_cast<bool>(info->ptr);
    });
-    m_channel_state.profiler->record_host(info->id, TensorInfo::DevValue);
+    if (m_channel_state.profiler->is_profiling()) {
+        m_channel_state.profiler->record_host(info->id, TensorInfo::DevValue);
+    }
     m_waitee = nullptr;
     return info->ptr->dev_tensor();
 }
 
 void ChannelImpl::sync() {
     m_buffer.flush();
-    m_channel_state.profiler->record_host();
+    if (m_channel_state.profiler->is_profiling()) {
+        m_channel_state.profiler->record_host();
+    }
     m_worker.wait_all_task_finish();
     CompNode::sync_all();
-    m_channel_state.profiler->record_host();
+    if (m_channel_state.profiler->is_profiling()) {
+        m_channel_state.profiler->record_host();
+    }
     MGB_LOCK_GUARD(m_mutex);
     check_worker_exc_unsafe();
 }
@@ -369,13 +393,17 @@ TensorInfo* ChannelImpl::alloc() {
     auto info = m_pool.alloc();
     m_valid_handle.insert(info);
     info->id = m_last_id++;
-    m_channel_state.profiler->record_host(info->id);
+    if (m_channel_state.profiler->is_profiling()) {
+        m_channel_state.profiler->record_host(info->id);
+    }
     return info;
 }
 
 void ChannelImpl::free(TensorInfo* ptr) {
     MGB_LOCK_GUARD(m_mutex);
-    m_channel_state.profiler->record_host(ptr->id);
+    if (m_channel_state.profiler->is_profiling()) {
+        m_channel_state.profiler->record_host(ptr->id);
+    }
     m_pool.free(ptr);
 }
 
@@ -389,7 +417,9 @@ ChannelImpl::~ChannelImpl() {
 
 void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
     MGB_LOCK_GUARD(m_mutex);
-    m_worker_state.profiler->record_host(dest->id, ptr->layout(), ptr->comp_node());
+    if (m_worker_state.profiler->is_profiling()) {
+        m_worker_state.profiler->record_host(dest->id, ptr->layout(), ptr->comp_node());
+    }
     dest->value_fetched = ptr->value_fetched();
     // update tensor desc for static infer
     dest->desc.layout = ptr->layout();
@@ -471,13 +501,17 @@ void ChannelImpl::sync_device_scope(CompNode device) {
 }
 
 void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
-    m_worker_state.profiler->record_host(icmd);
+    if (m_worker_state.profiler->is_profiling()) {
+        m_worker_state.profiler->record_host(icmd);
+    }
     bool finished = false;
     auto do_finish_command = [&]{
         if (finished) {
            return;
        }
-        m_worker_state.profiler->record_host(icmd);
+        if (m_worker_state.profiler->is_profiling()) {
+            m_worker_state.profiler->record_host(icmd);
+        }
        finished = true;
    };
    //TODO: remove std::visit for support osx 10.12
@@ -498,22 +532,25 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
                 tensor_inputs.push_back(i->ptr);
             }
             // Begin profiling operator
-            auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) {
-                SmallVector<uint64_t> tid;
-                for (auto* ptinfo: tinfo) {
-                    tid.push_back(ptinfo->id);
+            OpEvent event_data;
+            if (m_worker_state.profiler->is_profiling()) {
+                auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) {
+                    SmallVector<uint64_t> tid;
+                    for (auto* ptinfo: tinfo) {
+                        tid.push_back(ptinfo->id);
+                    }
+                    return tid;
+                };
+                event_data = {apply_id, cmd.op, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs)};
+                // Collecting devices
+                for (auto i : cmd.inputs) {
+                    devices.push_back(i->desc.comp_node);
                 }
-                return tid;
-            };
-            OpEvent event_data = {apply_id, cmd.op, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs)};
-            // Collecting devices
-            for (auto i : cmd.inputs) {
-                devices.push_back(i->desc.comp_node);
-            }
-            for (auto i : cmd.outputs) {
-                devices.push_back(i->desc.comp_node);
+                for (auto i : cmd.outputs) {
+                    devices.push_back(i->desc.comp_node);
+                }
+                devices.erase(std::unique(devices.begin(), devices.end()), devices.end());
             }
-            devices.erase(std::unique(devices.begin(), devices.end()), devices.end());
             // Fused by command buffer. @see: CommandBuffer::fuse_del
             // Now if dest is inplacable, it's refcnt would be decreased to 1 and owned by tensor_inputs after Del.
             // Note for exprs like 'y = x op x', inplace is unsupported yet but Del would be also fused.
@@ -643,7 +680,7 @@ void ChannelImpl::CommandBuffer::enqueue(Command cmd) {
     if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) {
         return;
     }
-    mgb_log_debug("%s Enqueued", to_string(cmd).c_str());
+    // mgb_log_debug("%s Enqueued", to_string(cmd).c_str());
     m_commands.push_back(std::move(cmd));
     auto flush_pos = flush_pos_for(m_commands.back());
     flush(flush_pos);
@@ -655,9 +692,11 @@ void ChannelImpl::CommandBuffer::flush() {
 
 void ChannelImpl::CommandBuffer::flush(Handle pos) {
     for (auto iter = m_commands.begin(); iter != pos; ++iter) {
-        mgb_log_debug("%s Flushed", to_string(*iter).c_str());
+        // mgb_log_debug("%s Flushed", to_string(*iter).c_str());
         IdentifiedCommand icmd{++m_owner->m_last_id, std::move(*iter)};
-        m_owner->m_channel_state.profiler->record_host(icmd);
+        if (m_owner->m_channel_state.profiler->is_profiling()) {
+            m_owner->m_channel_state.profiler->record_host(icmd);
+        }
         m_owner->m_worker.add_task(std::move(icmd));
     }
     m_commands.erase(m_commands.begin(), pos);
@@ -705,7 +744,7 @@ bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) {
     if (apply_iter == end || find_last_usage(dest, {apply_iter+1, end}) != end) {
         return false;
     }
-    mgb_log_debug("%s Fused", to_string(Command{cmd}).c_str());
+    // mgb_log_debug("%s Fused", to_string(Command{cmd}).c_str());
     std::get<ApplyOp>(*apply_iter).dels.push_back(dest);
     return true;
 }
@@ -771,16 +810,20 @@ void ChannelImpl::stop_profile(std::string basename, std::string format) {
 }
 
 void ChannelImpl::push_scope(std::string name) {
-    m_channel_state.profiler->record_host(name);
-    m_channel_state.scopes.push_back(name);
-    m_buffer.enqueue(PushScope{name});
+    if (m_channel_state.profiler->is_profiling()) {
+        m_channel_state.profiler->record_host(name);
+        m_channel_state.scopes.push_back(name);
+        m_buffer.enqueue(PushScope{name});
+    }
 }
 
 void ChannelImpl::pop_scope(std::string name) {
-    mgb_assert((!m_channel_state.scopes.empty()) && m_channel_state.scopes.back() == name, "scope name mismatch");
-    m_channel_state.scopes.pop_back();
-    m_channel_state.profiler->record_host(name);
-    m_buffer.enqueue(PopScope{name});
+    if (m_channel_state.profiler->is_profiling()) {
+        mgb_assert((!m_channel_state.scopes.empty()) && m_channel_state.scopes.back() == name, "scope name mismatch");
+        m_channel_state.scopes.pop_back();
+        m_channel_state.profiler->record_host(name);
+        m_buffer.enqueue(PopScope{name});
+    }
 }
 
 void ChannelImpl::assert_in_channel() {
diff --git a/imperative/src/include/megbrain/imperative/profiler.h b/imperative/src/include/megbrain/imperative/profiler.h
index f3253a63b..b092f6044 100644
--- a/imperative/src/include/megbrain/imperative/profiler.h
+++ b/imperative/src/include/megbrain/imperative/profiler.h
@@ -163,7 +163,6 @@ public:
     }
     // unsafe
     bool is_profiling() {
-        MGB_LOCK_GUARD(m_lock);
         return m_status == Profiling;
     }
     void start(Mask mask) {
@@ -188,7 +187,7 @@ public:
 protected:
     std::vector m_record_list;
     Mask m_event_mask;
-    Status m_status = NotStarted;
+    std::atomic<Status> m_status = NotStarted;
 };
-- 
GitLab
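
Note on the pattern this patch applies (not part of the patch itself): is_profiling() becomes a plain
atomic load with no lock, and every record_host(...) call site checks it first, so that when profiling
is disabled no event payload is built at all. Below is a minimal, self-contained sketch of that pattern;
ToyProfiler, on_tensor_produced, and their members are illustrative names for this note only, not
MegEngine's actual profiler API.

// Sketch of the two ideas in the patch:
// 1) the profiler status is an std::atomic, so is_profiling() needs no lock;
// 2) call sites test is_profiling() before building and recording an event,
//    so the only cost with profiling off is a single atomic load.
#include <atomic>
#include <cstdint>
#include <mutex>
#include <string>
#include <vector>

class ToyProfiler {
public:
    enum Status { NotStarted, Profiling, Stopped };

    // Lock-free fast path; mirrors dropping MGB_LOCK_GUARD from is_profiling().
    bool is_profiling() const { return m_status == Profiling; }

    void start() { m_status = Profiling; }
    void stop() { m_status = Stopped; }

    // Slow path: only reached when the caller has already checked is_profiling().
    void record_host(uint64_t id, std::string msg) {
        std::lock_guard<std::mutex> lock{m_mutex};
        m_records.push_back({id, std::move(msg)});
    }

private:
    struct Record {
        uint64_t id;
        std::string msg;
    };
    std::atomic<Status> m_status{NotStarted};
    std::mutex m_mutex;
    std::vector<Record> m_records;
};

// Caller-side guard, as applied throughout interpreter_impl.cpp: the event
// payload is only constructed when profiling is actually enabled.
void on_tensor_produced(ToyProfiler& profiler, uint64_t tensor_id) {
    if (profiler.is_profiling()) {
        profiler.record_host(tensor_id, "produce_tensor");
    }
}

Guarding at the call site rather than inside record_host is where most of the saving comes from in
process_one_task: the OpEvent, the tensor-id vectors, and the device list are now only assembled when
profiling is switched on.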