提交 221ec38a 编写于 作者: M Megvii Engine Team

feat(imperative): reduce profiler overhead

GitOrigin-RevId: dded9d9391a49883f1c1035c1aa10f8174867367
上级 1c01128f
...@@ -52,7 +52,9 @@ Handle ChannelImpl::put(const DeviceTensorND& data) { ...@@ -52,7 +52,9 @@ Handle ChannelImpl::put(const DeviceTensorND& data) {
info->desc.layout = data.layout(); info->desc.layout = data.layout();
info->desc.comp_node = data.comp_node(); info->desc.comp_node = data.comp_node();
info->ptr = Tensor::make(data); info->ptr = Tensor::make(data);
m_channel_state.profiler->record_host<TensorProduceEvent>(info->id, info->desc.layout, info->desc.comp_node); if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<TensorProduceEvent>(info->id, info->desc.layout, info->desc.comp_node);
}
return info; return info;
} }
...@@ -147,8 +149,9 @@ void ChannelImpl::dispatch_default_cpu( ...@@ -147,8 +149,9 @@ void ChannelImpl::dispatch_default_cpu(
return tid; return tid;
}; };
OpEvent event_data = {++m_last_id, op, tinfo_to_tid(input_infos), {}}; OpEvent event_data = {++m_last_id, op, tinfo_to_tid(input_infos), {}};
if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<HostOpExecuteEvent>(event_data); m_channel_state.profiler->record_host<HostOpExecuteEvent>(event_data);
}
OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds); OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);
...@@ -169,8 +172,9 @@ void ChannelImpl::dispatch_default_cpu( ...@@ -169,8 +172,9 @@ void ChannelImpl::dispatch_default_cpu(
} }
event_data.outputs = tinfo_to_tid(output_infos); event_data.outputs = tinfo_to_tid(output_infos);
if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<HostOpFinishEvent>(event_data); m_channel_state.profiler->record_host<HostOpFinishEvent>(event_data);
}
} }
void ChannelImpl::dispatch_kernel( void ChannelImpl::dispatch_kernel(
...@@ -267,13 +271,17 @@ HostTensorND ChannelImpl::get_value(Handle handle) { ...@@ -267,13 +271,17 @@ HostTensorND ChannelImpl::get_value(Handle handle) {
m_waitee = info; m_waitee = info;
regenerate(info); regenerate(info);
m_buffer.enqueue(GetValue{info}); m_buffer.enqueue(GetValue{info});
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::HostValue); if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::HostValue);
}
m_cv.wait(lock, [&]() { m_cv.wait(lock, [&]() {
check_worker_exc_unsafe(); check_worker_exc_unsafe();
tensor_ptr = info->ptr; tensor_ptr = info->ptr;
return value_fetched(); return value_fetched();
}); });
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::HostValue); if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::HostValue);
}
m_waitee = nullptr; m_waitee = nullptr;
} }
return tensor_ptr->get_value(); return tensor_ptr->get_value();
...@@ -290,12 +298,16 @@ TensorShape ChannelImpl::get_shape(Handle handle) { ...@@ -290,12 +298,16 @@ TensorShape ChannelImpl::get_shape(Handle handle) {
mgb_assert(!m_waitee); mgb_assert(!m_waitee);
m_waitee = info; m_waitee = info;
m_buffer.flush(); m_buffer.flush();
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::Shape); if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::Shape);
}
m_cv.wait(lock, [&]() { m_cv.wait(lock, [&]() {
check_worker_exc_unsafe(); check_worker_exc_unsafe();
return static_cast<bool>(info->ptr); return static_cast<bool>(info->ptr);
}); });
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::Shape); if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::Shape);
}
m_waitee = nullptr; m_waitee = nullptr;
TensorShape ret = info->ptr->layout(); TensorShape ret = info->ptr->layout();
mgb_assert(ret.ndim != 0); mgb_assert(ret.ndim != 0);
...@@ -306,7 +318,9 @@ DType ChannelImpl::get_dtype(Handle handle) { ...@@ -306,7 +318,9 @@ DType ChannelImpl::get_dtype(Handle handle) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle); "invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle); auto info = reinterpret_cast<TensorInfo*>(handle);
m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::DType); if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::DType);
}
auto ret = info->desc.layout.dtype; auto ret = info->desc.layout.dtype;
mgb_assert(ret.valid()); mgb_assert(ret.valid());
return ret; return ret;
...@@ -316,7 +330,9 @@ CompNode ChannelImpl::get_device(Handle handle) { ...@@ -316,7 +330,9 @@ CompNode ChannelImpl::get_device(Handle handle) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle); "invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle); auto info = reinterpret_cast<TensorInfo*>(handle);
m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::Device); if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::Device);
}
auto ret = info->desc.comp_node; auto ret = info->desc.comp_node;
mgb_assert(ret.valid()); mgb_assert(ret.valid());
return ret; return ret;
...@@ -331,22 +347,30 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) { ...@@ -331,22 +347,30 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
m_waitee = info; m_waitee = info;
regenerate(info); regenerate(info);
m_buffer.flush(); m_buffer.flush();
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::DevValue); if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::DevValue);
}
m_cv.wait(lock, [&]() { m_cv.wait(lock, [&]() {
check_worker_exc_unsafe(); check_worker_exc_unsafe();
return static_cast<bool>(info->ptr); return static_cast<bool>(info->ptr);
}); });
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::DevValue); if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::DevValue);
}
m_waitee = nullptr; m_waitee = nullptr;
return info->ptr->dev_tensor(); return info->ptr->dev_tensor();
} }
void ChannelImpl::sync() { void ChannelImpl::sync() {
m_buffer.flush(); m_buffer.flush();
m_channel_state.profiler->record_host<SyncStartEvent>(); if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<SyncStartEvent>();
}
m_worker.wait_all_task_finish(); m_worker.wait_all_task_finish();
CompNode::sync_all(); CompNode::sync_all();
m_channel_state.profiler->record_host<SyncFinishEvent>(); if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<SyncFinishEvent>();
}
MGB_LOCK_GUARD(m_mutex); MGB_LOCK_GUARD(m_mutex);
check_worker_exc_unsafe(); check_worker_exc_unsafe();
} }
...@@ -369,13 +393,17 @@ TensorInfo* ChannelImpl::alloc() { ...@@ -369,13 +393,17 @@ TensorInfo* ChannelImpl::alloc() {
auto info = m_pool.alloc(); auto info = m_pool.alloc();
m_valid_handle.insert(info); m_valid_handle.insert(info);
info->id = m_last_id++; info->id = m_last_id++;
m_channel_state.profiler->record_host<TensorDeclareEvent>(info->id); if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<TensorDeclareEvent>(info->id);
}
return info; return info;
} }
void ChannelImpl::free(TensorInfo* ptr) { void ChannelImpl::free(TensorInfo* ptr) {
MGB_LOCK_GUARD(m_mutex); MGB_LOCK_GUARD(m_mutex);
m_channel_state.profiler->record_host<TensorEraseEvent>(ptr->id); if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<TensorEraseEvent>(ptr->id);
}
m_pool.free(ptr); m_pool.free(ptr);
} }
...@@ -389,7 +417,9 @@ ChannelImpl::~ChannelImpl() { ...@@ -389,7 +417,9 @@ ChannelImpl::~ChannelImpl() {
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
MGB_LOCK_GUARD(m_mutex); MGB_LOCK_GUARD(m_mutex);
m_worker_state.profiler->record_host<TensorProduceEvent>(dest->id, ptr->layout(), ptr->comp_node()); if (m_worker_state.profiler->is_profiling()) {
m_worker_state.profiler->record_host<TensorProduceEvent>(dest->id, ptr->layout(), ptr->comp_node());
}
dest->value_fetched = ptr->value_fetched(); dest->value_fetched = ptr->value_fetched();
// update tensor desc for static infer // update tensor desc for static infer
dest->desc.layout = ptr->layout(); dest->desc.layout = ptr->layout();
...@@ -471,13 +501,17 @@ void ChannelImpl::sync_device_scope(CompNode device) { ...@@ -471,13 +501,17 @@ void ChannelImpl::sync_device_scope(CompNode device) {
} }
void ChannelImpl::process_one_task(IdentifiedCommand& icmd) { void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
m_worker_state.profiler->record_host<CommandExecuteEvent>(icmd); if (m_worker_state.profiler->is_profiling()) {
m_worker_state.profiler->record_host<CommandExecuteEvent>(icmd);
}
bool finished = false; bool finished = false;
auto do_finish_command = [&]{ auto do_finish_command = [&]{
if (finished) { if (finished) {
return; return;
} }
m_worker_state.profiler->record_host<CommandFinishEvent>(icmd); if (m_worker_state.profiler->is_profiling()) {
m_worker_state.profiler->record_host<CommandFinishEvent>(icmd);
}
finished = true; finished = true;
}; };
//TODO: remove std::visit for support osx 10.12 //TODO: remove std::visit for support osx 10.12
...@@ -498,22 +532,25 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) { ...@@ -498,22 +532,25 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
tensor_inputs.push_back(i->ptr); tensor_inputs.push_back(i->ptr);
} }
// Begin profiling operator // Begin profiling operator
auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) { OpEvent event_data;
SmallVector<uint64_t> tid; if (m_worker_state.profiler->is_profiling()) {
for (auto* ptinfo: tinfo) { auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) {
tid.push_back(ptinfo->id); SmallVector<uint64_t> tid;
for (auto* ptinfo: tinfo) {
tid.push_back(ptinfo->id);
}
return tid;
};
event_data = {apply_id, cmd.op, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs)};
// Collecting devices
for (auto i : cmd.inputs) {
devices.push_back(i->desc.comp_node);
} }
return tid; for (auto i : cmd.outputs) {
}; devices.push_back(i->desc.comp_node);
OpEvent event_data = {apply_id, cmd.op, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs)}; }
// Collecting devices devices.erase(std::unique(devices.begin(), devices.end()), devices.end());
for (auto i : cmd.inputs) {
devices.push_back(i->desc.comp_node);
}
for (auto i : cmd.outputs) {
devices.push_back(i->desc.comp_node);
} }
devices.erase(std::unique(devices.begin(), devices.end()), devices.end());
// Fused by command buffer. @see: CommandBuffer::fuse_del // Fused by command buffer. @see: CommandBuffer::fuse_del
// Now if dest is inplacable, it's refcnt would be decreased to 1 and owned by tensor_inputs after Del. // Now if dest is inplacable, it's refcnt would be decreased to 1 and owned by tensor_inputs after Del.
// Note for exprs like 'y = x op x', inplace is unsupported yet but Del would be also fused. // Note for exprs like 'y = x op x', inplace is unsupported yet but Del would be also fused.
...@@ -643,7 +680,7 @@ void ChannelImpl::CommandBuffer::enqueue(Command cmd) { ...@@ -643,7 +680,7 @@ void ChannelImpl::CommandBuffer::enqueue(Command cmd) {
if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) { if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) {
return; return;
} }
mgb_log_debug("%s Enqueued", to_string(cmd).c_str()); // mgb_log_debug("%s Enqueued", to_string(cmd).c_str());
m_commands.push_back(std::move(cmd)); m_commands.push_back(std::move(cmd));
auto flush_pos = flush_pos_for(m_commands.back()); auto flush_pos = flush_pos_for(m_commands.back());
flush(flush_pos); flush(flush_pos);
...@@ -655,9 +692,11 @@ void ChannelImpl::CommandBuffer::flush() { ...@@ -655,9 +692,11 @@ void ChannelImpl::CommandBuffer::flush() {
void ChannelImpl::CommandBuffer::flush(Handle pos) { void ChannelImpl::CommandBuffer::flush(Handle pos) {
for (auto iter = m_commands.begin(); iter != pos; ++iter) { for (auto iter = m_commands.begin(); iter != pos; ++iter) {
mgb_log_debug("%s Flushed", to_string(*iter).c_str()); // mgb_log_debug("%s Flushed", to_string(*iter).c_str());
IdentifiedCommand icmd{++m_owner->m_last_id, std::move(*iter)}; IdentifiedCommand icmd{++m_owner->m_last_id, std::move(*iter)};
m_owner->m_channel_state.profiler->record_host<CommandEnqueueEvent>(icmd); if (m_owner->m_channel_state.profiler->is_profiling()) {
m_owner->m_channel_state.profiler->record_host<CommandEnqueueEvent>(icmd);
}
m_owner->m_worker.add_task(std::move(icmd)); m_owner->m_worker.add_task(std::move(icmd));
} }
m_commands.erase(m_commands.begin(), pos); m_commands.erase(m_commands.begin(), pos);
...@@ -705,7 +744,7 @@ bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) { ...@@ -705,7 +744,7 @@ bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) {
if (apply_iter == end || find_last_usage(dest, {apply_iter+1, end}) != end) { if (apply_iter == end || find_last_usage(dest, {apply_iter+1, end}) != end) {
return false; return false;
} }
mgb_log_debug("%s Fused", to_string(Command{cmd}).c_str()); // mgb_log_debug("%s Fused", to_string(Command{cmd}).c_str());
std::get<ApplyOp>(*apply_iter).dels.push_back(dest); std::get<ApplyOp>(*apply_iter).dels.push_back(dest);
return true; return true;
} }
...@@ -771,16 +810,20 @@ void ChannelImpl::stop_profile(std::string basename, std::string format) { ...@@ -771,16 +810,20 @@ void ChannelImpl::stop_profile(std::string basename, std::string format) {
} }
void ChannelImpl::push_scope(std::string name) { void ChannelImpl::push_scope(std::string name) {
m_channel_state.profiler->record_host<ChannelBeginScope>(name); if (m_channel_state.profiler->is_profiling()) {
m_channel_state.scopes.push_back(name); m_channel_state.profiler->record_host<ChannelBeginScope>(name);
m_buffer.enqueue(PushScope{name}); m_channel_state.scopes.push_back(name);
m_buffer.enqueue(PushScope{name});
}
} }
void ChannelImpl::pop_scope(std::string name) { void ChannelImpl::pop_scope(std::string name) {
mgb_assert((!m_channel_state.scopes.empty()) && m_channel_state.scopes.back() == name, "scope name mismatch"); if (m_channel_state.profiler->is_profiling()) {
m_channel_state.scopes.pop_back(); mgb_assert((!m_channel_state.scopes.empty()) && m_channel_state.scopes.back() == name, "scope name mismatch");
m_channel_state.profiler->record_host<ChannelEndScope>(name); m_channel_state.scopes.pop_back();
m_buffer.enqueue(PopScope{name}); m_channel_state.profiler->record_host<ChannelEndScope>(name);
m_buffer.enqueue(PopScope{name});
}
} }
void ChannelImpl::assert_in_channel() { void ChannelImpl::assert_in_channel() {
......
...@@ -163,7 +163,6 @@ public: ...@@ -163,7 +163,6 @@ public:
} }
// unsafe // unsafe
bool is_profiling() { bool is_profiling() {
MGB_LOCK_GUARD(m_lock);
return m_status == Profiling; return m_status == Profiling;
} }
void start(Mask mask) { void start(Mask mask) {
...@@ -188,7 +187,7 @@ public: ...@@ -188,7 +187,7 @@ public:
protected: protected:
std::vector<Record> m_record_list; std::vector<Record> m_record_list;
Mask m_event_mask; Mask m_event_mask;
Status m_status = NotStarted; std::atomic<Status> m_status = NotStarted;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册