提交 221ec38a 编写于 作者: Megvii Engine Team

feat(imperative): reduce profiler overhead

GitOrigin-RevId: dded9d9391a49883f1c1035c1aa10f8174867367
上级 1c01128f
......@@ -52,7 +52,9 @@ Handle ChannelImpl::put(const DeviceTensorND& data) {
info->desc.layout = data.layout();
info->desc.comp_node = data.comp_node();
info->ptr = Tensor::make(data);
m_channel_state.profiler->record_host<TensorProduceEvent>(info->id, info->desc.layout, info->desc.comp_node);
if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<TensorProduceEvent>(info->id, info->desc.layout, info->desc.comp_node);
}
return info;
}
......@@ -147,8 +149,9 @@ void ChannelImpl::dispatch_default_cpu(
return tid;
};
OpEvent event_data = {++m_last_id, op, tinfo_to_tid(input_infos), {}};
m_channel_state.profiler->record_host<HostOpExecuteEvent>(event_data);
if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<HostOpExecuteEvent>(event_data);
}
OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);
......@@ -169,8 +172,9 @@ void ChannelImpl::dispatch_default_cpu(
}
event_data.outputs = tinfo_to_tid(output_infos);
m_channel_state.profiler->record_host<HostOpFinishEvent>(event_data);
if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<HostOpFinishEvent>(event_data);
}
}
void ChannelImpl::dispatch_kernel(
......@@ -267,13 +271,17 @@ HostTensorND ChannelImpl::get_value(Handle handle) {
m_waitee = info;
regenerate(info);
m_buffer.enqueue(GetValue{info});
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::HostValue);
if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::HostValue);
}
m_cv.wait(lock, [&]() {
check_worker_exc_unsafe();
tensor_ptr = info->ptr;
return value_fetched();
});
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::HostValue);
if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::HostValue);
}
m_waitee = nullptr;
}
return tensor_ptr->get_value();
......@@ -290,12 +298,16 @@ TensorShape ChannelImpl::get_shape(Handle handle) {
mgb_assert(!m_waitee);
m_waitee = info;
m_buffer.flush();
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::Shape);
if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::Shape);
}
m_cv.wait(lock, [&]() {
check_worker_exc_unsafe();
return static_cast<bool>(info->ptr);
});
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::Shape);
if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::Shape);
}
m_waitee = nullptr;
TensorShape ret = info->ptr->layout();
mgb_assert(ret.ndim != 0);
......@@ -306,7 +318,9 @@ DType ChannelImpl::get_dtype(Handle handle) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::DType);
if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::DType);
}
auto ret = info->desc.layout.dtype;
mgb_assert(ret.valid());
return ret;
......@@ -316,7 +330,9 @@ CompNode ChannelImpl::get_device(Handle handle) {
mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
"invalid handle: %p", handle);
auto info = reinterpret_cast<TensorInfo*>(handle);
m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::Device);
if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<TensorGetPropEvent>(info->id, TensorInfo::Device);
}
auto ret = info->desc.comp_node;
mgb_assert(ret.valid());
return ret;
......@@ -331,22 +347,30 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
m_waitee = info;
regenerate(info);
m_buffer.flush();
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::DevValue);
if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<TensorWaitPropEvent>(info->id, TensorInfo::DevValue);
}
m_cv.wait(lock, [&]() {
check_worker_exc_unsafe();
return static_cast<bool>(info->ptr);
});
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::DevValue);
if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<TensorWaitPropFinishEvent>(info->id, TensorInfo::DevValue);
}
m_waitee = nullptr;
return info->ptr->dev_tensor();
}
void ChannelImpl::sync() {
    m_buffer.flush();
    // Only pay the profiler bookkeeping cost when profiling is actually on;
    // is_profiling() is a cheap flag read, record_host() is not.
    // (The scraped diff showed both the old unconditional record_host calls
    // and the new guarded ones; only the guarded form is kept.)
    if (m_channel_state.profiler->is_profiling()) {
        m_channel_state.profiler->record_host<SyncStartEvent>();
    }
    m_worker.wait_all_task_finish();
    CompNode::sync_all();
    if (m_channel_state.profiler->is_profiling()) {
        m_channel_state.profiler->record_host<SyncFinishEvent>();
    }
    // Re-raise any exception the worker thread captured while we waited.
    MGB_LOCK_GUARD(m_mutex);
    check_worker_exc_unsafe();
}
......@@ -369,13 +393,17 @@ TensorInfo* ChannelImpl::alloc() {
auto info = m_pool.alloc();
m_valid_handle.insert(info);
info->id = m_last_id++;
m_channel_state.profiler->record_host<TensorDeclareEvent>(info->id);
if (m_channel_state.profiler->is_profiling()) {
m_channel_state.profiler->record_host<TensorDeclareEvent>(info->id);
}
return info;
}
void ChannelImpl::free(TensorInfo* ptr) {
    MGB_LOCK_GUARD(m_mutex);
    // Record the tensor-erase event only while profiling; unconditional
    // recording added overhead on every tensor free. (The scraped diff showed
    // both the old unconditional call and the guarded one; keep the guard.)
    if (m_channel_state.profiler->is_profiling()) {
        m_channel_state.profiler->record_host<TensorEraseEvent>(ptr->id);
    }
    m_pool.free(ptr);
}
......@@ -389,7 +417,9 @@ ChannelImpl::~ChannelImpl() {
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
MGB_LOCK_GUARD(m_mutex);
m_worker_state.profiler->record_host<TensorProduceEvent>(dest->id, ptr->layout(), ptr->comp_node());
if (m_worker_state.profiler->is_profiling()) {
m_worker_state.profiler->record_host<TensorProduceEvent>(dest->id, ptr->layout(), ptr->comp_node());
}
dest->value_fetched = ptr->value_fetched();
// update tensor desc for static infer
dest->desc.layout = ptr->layout();
......@@ -471,13 +501,17 @@ void ChannelImpl::sync_device_scope(CompNode device) {
}
void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
m_worker_state.profiler->record_host<CommandExecuteEvent>(icmd);
if (m_worker_state.profiler->is_profiling()) {
m_worker_state.profiler->record_host<CommandExecuteEvent>(icmd);
}
bool finished = false;
auto do_finish_command = [&]{
if (finished) {
return;
}
m_worker_state.profiler->record_host<CommandFinishEvent>(icmd);
if (m_worker_state.profiler->is_profiling()) {
m_worker_state.profiler->record_host<CommandFinishEvent>(icmd);
}
finished = true;
};
//TODO: remove std::visit for support osx 10.12
......@@ -498,22 +532,25 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
tensor_inputs.push_back(i->ptr);
}
// Begin profiling operator
auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) {
SmallVector<uint64_t> tid;
for (auto* ptinfo: tinfo) {
tid.push_back(ptinfo->id);
OpEvent event_data;
if (m_worker_state.profiler->is_profiling()) {
auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) {
SmallVector<uint64_t> tid;
for (auto* ptinfo: tinfo) {
tid.push_back(ptinfo->id);
}
return tid;
};
event_data = {apply_id, cmd.op, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs)};
// Collecting devices
for (auto i : cmd.inputs) {
devices.push_back(i->desc.comp_node);
}
return tid;
};
OpEvent event_data = {apply_id, cmd.op, tinfo_to_tid(cmd.inputs), tinfo_to_tid(cmd.outputs)};
// Collecting devices
for (auto i : cmd.inputs) {
devices.push_back(i->desc.comp_node);
}
for (auto i : cmd.outputs) {
devices.push_back(i->desc.comp_node);
for (auto i : cmd.outputs) {
devices.push_back(i->desc.comp_node);
}
devices.erase(std::unique(devices.begin(), devices.end()), devices.end());
}
devices.erase(std::unique(devices.begin(), devices.end()), devices.end());
// Fused by command buffer. @see: CommandBuffer::fuse_del
// Now if dest is inplacable, it's refcnt would be decreased to 1 and owned by tensor_inputs after Del.
// Note for exprs like 'y = x op x', inplace is unsupported yet but Del would be also fused.
......@@ -643,7 +680,7 @@ void ChannelImpl::CommandBuffer::enqueue(Command cmd) {
if (std::get_if<Del>(&cmd) && fuse_del(std::get<Del>(cmd))) {
return;
}
mgb_log_debug("%s Enqueued", to_string(cmd).c_str());
// mgb_log_debug("%s Enqueued", to_string(cmd).c_str());
m_commands.push_back(std::move(cmd));
auto flush_pos = flush_pos_for(m_commands.back());
flush(flush_pos);
......@@ -655,9 +692,11 @@ void ChannelImpl::CommandBuffer::flush() {
void ChannelImpl::CommandBuffer::flush(Handle pos) {
for (auto iter = m_commands.begin(); iter != pos; ++iter) {
mgb_log_debug("%s Flushed", to_string(*iter).c_str());
// mgb_log_debug("%s Flushed", to_string(*iter).c_str());
IdentifiedCommand icmd{++m_owner->m_last_id, std::move(*iter)};
m_owner->m_channel_state.profiler->record_host<CommandEnqueueEvent>(icmd);
if (m_owner->m_channel_state.profiler->is_profiling()) {
m_owner->m_channel_state.profiler->record_host<CommandEnqueueEvent>(icmd);
}
m_owner->m_worker.add_task(std::move(icmd));
}
m_commands.erase(m_commands.begin(), pos);
......@@ -705,7 +744,7 @@ bool ChannelImpl::CommandBuffer::fuse_del(const Del& cmd) {
if (apply_iter == end || find_last_usage(dest, {apply_iter+1, end}) != end) {
return false;
}
mgb_log_debug("%s Fused", to_string(Command{cmd}).c_str());
// mgb_log_debug("%s Fused", to_string(Command{cmd}).c_str());
std::get<ApplyOp>(*apply_iter).dels.push_back(dest);
return true;
}
......@@ -771,16 +810,20 @@ void ChannelImpl::stop_profile(std::string basename, std::string format) {
}
void ChannelImpl::push_scope(std::string name) {
    // Scopes are pure profiling metadata: when the profiler is off, skip the
    // host record, the channel-side scope stack AND the worker-side PushScope
    // command entirely. (The scraped diff showed both the old unconditional
    // statements and the new guarded block; only the guarded form is kept.)
    // NOTE(review): pop_scope must use the same guard, otherwise its
    // scope-stack assertion would fire if profiling toggles between the two.
    if (m_channel_state.profiler->is_profiling()) {
        m_channel_state.profiler->record_host<ChannelBeginScope>(name);
        m_channel_state.scopes.push_back(name);
        m_buffer.enqueue(PushScope{name});
    }
}
void ChannelImpl::pop_scope(std::string name) {
    // Mirror of push_scope: all work here (stack pop, host record, worker
    // PopScope command) exists only for profiling, so guard the whole body.
    // The name-match assertion stays inside the guard because push_scope only
    // pushes onto m_channel_state.scopes while profiling is on.
    // (The scraped diff showed both the old unconditional statements and the
    // new guarded block; only the guarded form is kept.)
    if (m_channel_state.profiler->is_profiling()) {
        mgb_assert((!m_channel_state.scopes.empty()) && m_channel_state.scopes.back() == name,
                   "scope name mismatch");
        m_channel_state.scopes.pop_back();
        m_channel_state.profiler->record_host<ChannelEndScope>(name);
        m_buffer.enqueue(PopScope{name});
    }
}
void ChannelImpl::assert_in_channel() {
......
......@@ -163,7 +163,6 @@ public:
}
// unsafe
// Lock-free status query so hot paths can cheaply skip record_host() calls.
// m_status is std::atomic (see the member declaration below), so the plain
// load here is safe without taking m_lock; the "unsafe" label above refers
// to it being a momentary snapshot, not a data race.
bool is_profiling() {
    return m_status == Profiling;
}
void start(Mask mask) {
......@@ -188,7 +187,7 @@ public:
protected:
    std::vector<Record> m_record_list;
    Mask m_event_mask;
    // Atomic so is_profiling() can read it without holding m_lock.
    // (The scraped diff showed both the old non-atomic declaration and the
    // new atomic one; a class may declare m_status only once, so only the
    // atomic declaration is kept.)
    std::atomic<Status> m_status = NotStarted;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册