From 1ce40b5bf7b57ebaade6bf3ec34136c6263476b3 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 14 May 2021 19:21:51 +0800
Subject: [PATCH] refactor(interpreter): wrap accesses to channel/worker state

GitOrigin-RevId: 1d58f2c876024a0ea2a47d5226339bc0ed626c4b
---
 .../src/impl/interpreter/interpreter_impl.cpp | 249 +++++++++++-------
 .../src/impl/interpreter/interpreter_impl.h   |  10 +-
 2 files changed, 162 insertions(+), 97 deletions(-)

diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp
index 081d06916..a50632974 100644
--- a/imperative/src/impl/interpreter/interpreter_impl.cpp
+++ b/imperative/src/impl/interpreter/interpreter_impl.cpp
@@ -23,6 +23,23 @@ using namespace imperative;
 using namespace interpreter;
 using namespace interpreter::intl;
 
+std::thread::id ChannelImpl::get_worker_tid() {
+    return m_worker_state.tid;
+}
+
+ChannelImpl::ChannelState& ChannelImpl::get_channel_state() {
+    assert_in_channel();
+    return m_channel_state;
+}
+
+ChannelImpl::WorkerState& ChannelImpl::get_worker_state() {
+    assert_in_worker();
+    return m_worker_state;
+}
+
+#define m_channel_state
+#define m_worker_state
+
 std::unique_ptr<Interpreter::Channel> InterpreterImpl::create_channel() {
     return std::make_unique<ChannelImpl>();
 }
@@ -48,13 +65,14 @@ Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
 }
 
 Handle ChannelImpl::put(const DeviceTensorND& data) {
+    auto& state = get_channel_state();
     mgb_assert(check_available(), "Channel already closed");
     auto info = alloc();
     info->desc.layout = data.layout();
     info->desc.comp_node = data.comp_node();
     info->ptr = Tensor::make(data);
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host(info->id, info->desc.layout, info->desc.comp_node);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host(info->id, info->desc.layout, info->desc.comp_node);
     }
     return info;
 }
@@ -71,7 +89,8 @@ void ChannelImpl::del(Handle handle) {
 
 void ChannelImpl::swap_in(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
-    if (m_worker_state.options.enable_swap) {
+    auto& state = get_channel_state();
+    if (state.options.enable_swap) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
         auto* info = reinterpret_cast<TensorInfo*>(handle);
@@ -81,7 +100,8 @@ void ChannelImpl::swap_out(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
-    if (m_worker_state.options.enable_swap) {
+    auto& state = get_channel_state();
+    if (state.options.enable_swap) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
         auto* info = reinterpret_cast<TensorInfo*>(handle);
@@ -91,7 +111,8 @@ void ChannelImpl::drop(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
-    if (m_worker_state.options.enable_drop) {
+    auto& state = get_channel_state();
+    if (state.options.enable_drop) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
         auto* info = reinterpret_cast<TensorInfo*>(handle);
@@ -104,6 +125,7 @@ void ChannelImpl::dispatch_default_cpu(
         const SmallVector<TensorInfo*>& input_infos,
         const SmallVector<LogicalTensorDesc>& input_descs,
         SmallVector<Handle>* outputs) {
+    auto& state = get_channel_state();
     auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
     MGB_MARK_USED_VAR(validated);
 
@@ -147,8 +169,8 @@ void ChannelImpl::dispatch_default_cpu(
         return tid;
     };
     OpEvent event_data = {++m_last_id, op, tinfo_to_tid(input_infos), {}};
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host(event_data);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host(event_data);
     }
 
     OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);
@@ -166,8 +188,8 @@ void ChannelImpl::dispatch_default_cpu(
     }
 
     event_data.outputs = tinfo_to_tid(output_infos);
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host(event_data);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host(event_data);
     }
 }
 
@@ -176,6 +198,7 @@ void ChannelImpl::dispatch_kernel(
         const SmallVector<TensorInfo*>& input_infos,
         const SmallVector<LogicalTensorDesc>& input_descs,
         SmallVector<Handle>* outputs) {
+    auto& state = get_channel_state();
     auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
 
     ApplyOp cmd{std::move(op)};
@@ -194,9 +217,9 @@ void ChannelImpl::dispatch_kernel(
         outputs->push_back(info);
     }
     m_buffer.enqueue(std::move(cmd));
-    if (!validated && m_channel_state.options.async_level == 1) {
+    if (!validated && state.options.async_level == 1) {
         sync();
-    } else if (m_channel_state.options.async_level == 0) {
+    } else if (state.options.async_level == 0) {
         sync();
         // check device error
         for (auto&& oup : *outputs) {
@@ -210,6 +233,7 @@ SmallVector<Handle> ChannelImpl::apply_op(
         std::shared_ptr<OpDef> op,
         const SmallVector<Handle>& inputs) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     for (auto i : inputs) {
         mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(),
                 "invalid handle: %p", i);
@@ -229,7 +253,7 @@ SmallVector<Handle> ChannelImpl::apply_op(
     }
 
     SmallVector<Handle> outputs;
-    DispatchMode dispatch_mode = m_channel_state.options.enable_host_compute
+    DispatchMode dispatch_mode = state.options.enable_host_compute
             ? OpDef::decide_dispatch_mode(*op, input_descs)
             : DispatchMode::KERNEL;
     switch (dispatch_mode) {
@@ -247,6 +271,7 @@ SmallVector<Handle> ChannelImpl::apply_op(
 
 HostTensorND ChannelImpl::get_value(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     // TODO: maybe get_value should be done on host. i.e. delete GetValue
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
@@ -262,16 +287,16 @@ HostTensorND ChannelImpl::get_value(Handle handle) {
     if (!value_fetched()) {
         m_waitee = info;
         m_buffer.enqueue(GetValue{info});
-        if (m_channel_state.profiler->is_profiling()) {
-            m_channel_state.profiler->record_host(info->id, TensorInfo::HostValue);
+        if (state.profiler->is_profiling()) {
+            state.profiler->record_host(info->id, TensorInfo::HostValue);
         }
         m_cv.wait(lock, [&]() {
            check_worker_exc_unsafe();
            tensor_ptr = info->ptr;
            return value_fetched();
        });
-        if (m_channel_state.profiler->is_profiling()) {
-            m_channel_state.profiler->record_host(info->id, TensorInfo::HostValue);
+        if (state.profiler->is_profiling()) {
+            state.profiler->record_host(info->id, TensorInfo::HostValue);
         }
         m_waitee = nullptr;
     }
@@ -280,6 +305,7 @@ TensorShape ChannelImpl::get_shape(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -290,15 +316,15 @@ TensorShape ChannelImpl::get_shape(Handle handle) {
     mgb_assert(!m_waitee);
     m_waitee = info;
     m_buffer.flush();
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host(info->id, TensorInfo::Shape);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host(info->id, TensorInfo::Shape);
     }
     m_cv.wait(lock, [&]() {
         check_worker_exc_unsafe();
         return static_cast<bool>(info->ptr);
     });
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host(info->id, TensorInfo::Shape);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host(info->id, TensorInfo::Shape);
     }
     m_waitee = nullptr;
     TensorShape ret = info->ptr->layout();
@@ -308,11 +334,12 @@ DType ChannelImpl::get_dtype(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host(info->id, TensorInfo::DType);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host(info->id, TensorInfo::DType);
     }
     auto ret = info->desc.layout.dtype;
     mgb_assert(ret.valid());
@@ -321,11 +348,12 @@ CompNode ChannelImpl::get_device(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host(info->id, TensorInfo::Device);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host(info->id, TensorInfo::Device);
     }
     auto ret = info->desc.comp_node;
     mgb_assert(ret.valid());
@@ -334,6 +362,7 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -341,15 +370,15 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
     mgb_assert(!m_waitee);
     m_waitee = info;
     m_buffer.flush();
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host(info->id, TensorInfo::DevValue);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host(info->id, TensorInfo::DevValue);
     }
     m_cv.wait(lock, [&]() {
         check_worker_exc_unsafe();
         return static_cast<bool>(info->ptr);
     });
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host(info->id, TensorInfo::DevValue);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host(info->id, TensorInfo::DevValue);
     }
     m_waitee = nullptr;
     return info->ptr->dev_tensor();
@@ -357,14 +386,15 @@ DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
 
 void ChannelImpl::sync() {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     m_buffer.flush();
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host();
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host();
     }
     m_worker.wait_all_task_finish();
     CompNode::sync_all();
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host();
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host();
     }
     MGB_LOCK_GUARD(m_mutex);
     check_worker_exc_unsafe();
@@ -386,22 +416,25 @@ void ChannelImpl::close() {
 
 size_t ChannelImpl::get_option(std::string name) {
     mgb_assert(check_available(), "Channel already closed");
-    return m_channel_state.options.get_option(name);
+    auto& state = get_channel_state();
+    return state.options.get_option(name);
 }
 
 void ChannelImpl::set_option(std::string name, size_t value) {
     mgb_assert(check_available(), "Channel already closed");
-    m_channel_state.options.set_option(name, value);
+    auto& state = get_channel_state();
+    state.options.set_option(name, value);
     m_buffer.enqueue(SetOption{name, value});
 }
 
 TensorInfo* ChannelImpl::alloc() {
+    auto& state = get_channel_state();
     MGB_LOCK_GUARD(m_mutex);
     auto info = m_pool.alloc();
     m_valid_handle.insert(info);
     info->id = m_last_id++;
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host(info->id);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host(info->id);
     }
     return info;
 }
@@ -422,7 +455,8 @@ void ChannelImpl::do_drop(TensorInfo* ptr, bool user=false) {
 }
 
 void ChannelImpl::free(TensorInfo* ptr) {
-    if (m_worker_state.options.enable_dtr_auto_drop) {
+    auto& state = get_worker_state();
+    if (state.options.enable_dtr_auto_drop) {
         // Evicting a tensor, rather than freeing it, can avoid pinning
         // potentially exploding amounts of memory and allow us to save
         // more memory.
@@ -455,11 +489,12 @@ void ChannelImpl::recursive_free(TensorInfo* ptr) {
 }
 
 void ChannelImpl::real_free(TensorInfo* ptr) {
+    auto& state = get_worker_state();
     MGB_LOCK_GUARD(m_mutex);
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host(ptr->id);
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host(ptr->id);
    }
-    if (ptr->size_exceeds_thd(m_worker_state.options.dtr_evictee_minimum_size)) {
+    if (ptr->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
         m_dtr.erase_candidate(ptr);
     }
     detach_users(ptr);
@@ -474,11 +509,14 @@ ChannelImpl::~ChannelImpl() {
 }
 
 void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice=true) {
-    auto lock = notice ? std::unique_lock(m_mutex)
-                       : std::unique_lock();
+    auto& state = get_worker_state();
+    auto lock = std::unique_lock(m_mutex, std::defer_lock);
+    if (notice) {
+        lock.lock();
+    }
     m_dtr.update_used_time(dest);
-    if (notice && m_worker_state.profiler->is_profiling()) {
-        m_worker_state.profiler->record_host(dest->id, ptr->layout(), ptr->comp_node());
+    if (notice && state.profiler->is_profiling()) {
+        state.profiler->record_host(dest->id, ptr->layout(), ptr->comp_node());
     }
     dest->value_fetched = ptr->value_fetched();
     // update tensor desc for static infer
@@ -487,7 +525,7 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice=tr
     dest->memory = ptr->blob()->size();
     dest->ptr = std::move(ptr);
     dest->evict_type = EvictType::NONE;
-    if (notice && dest->size_exceeds_thd(m_worker_state.options.dtr_evictee_minimum_size)) {
+    if (notice && dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
         m_dtr.insert_candidate(dest);
     }
     if (notice && m_waitee == dest) {
@@ -509,6 +547,7 @@ void ChannelImpl::regenerate(TensorInfo* dest) {
 }
 
 void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
+    auto& state = get_worker_state();
     SmallVector<TensorPtr> inputs;
     inputs.reserve(path->inputs.size());
     m_dtr.pin(path->inputs);
@@ -519,7 +558,7 @@ void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
         inputs.push_back(i->ptr);
         m_dtr.update_used_time(i);
     }
-    if (m_worker_state.options.enable_dtr_auto_drop && m_worker_state.options.dtr_eviction_threshold > 0) {
+    if (state.options.enable_dtr_auto_drop && state.options.dtr_eviction_threshold > 0) {
         auto_evict();
     }
     auto outputs = OpDef::apply_on_physical_tensor(*path->op, inputs);
@@ -531,7 +570,7 @@ void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
             o->recompute_times ++;
             if (!o->ptr) {
                 produce_tensor(o, std::move(outputs[i]), false);
-                if (m_worker_state.options.enable_dtr_auto_drop) {
+                if (state.options.enable_dtr_auto_drop) {
                     m_dtr.update_dsu_after_recompute(o);
                 }
             }
@@ -540,11 +579,12 @@ void ChannelImpl::recompute(TensorInfo::ComputePath* path) {
 }
 
 void ChannelImpl::auto_evict() {
+    auto& state = get_worker_state();
     if (!m_dtr.comp_node.valid()) {
         return;
     }
     size_t current_memory = m_dtr.comp_node.get_used_memory();
-    while (current_memory > m_worker_state.options.dtr_eviction_threshold) {
+    while (current_memory > state.options.dtr_eviction_threshold) {
         auto best = m_dtr.find_best_tensor();
         if (!best) {
             if (!m_dtr.warn_printed) {
@@ -592,13 +632,14 @@ bool ChannelImpl::check_available() {
 }
 
 void ChannelImpl::sync_device_scope(CompNode device) {
-    auto& prev = m_worker_state.device_scope_map[device];
-    auto& current = m_worker_state.scopes;
+    auto& state = get_worker_state();
+    auto& prev = state.device_scope_map[device];
+    auto& current = state.scopes;
     auto push_scope = [&](std::string name) {
-        m_worker_state.profiler->record_device(device, name);
+        state.profiler->record_device(device, name);
     };
     auto pop_scope = [&](std::string name) {
-        m_worker_state.profiler->record_device(device, name);
+        state.profiler->record_device(device, name);
     };
     size_t similarity = 0;
     for (size_t i = 0; i < prev.size() && i < current.size(); i++) {
@@ -619,16 +660,17 @@ void ChannelImpl::sync_device_scope(CompNode device) {
 }
 
 void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
-    if (m_worker_state.profiler->is_profiling()) {
-        m_worker_state.profiler->record_host(icmd);
+    auto& state = get_worker_state();
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host(icmd);
     }
     bool finished = false;
     auto do_finish_command = [&]{
         if (finished) {
             return;
         }
-        if (m_worker_state.profiler->is_profiling()) {
-            m_worker_state.profiler->record_host(icmd);
+        if (state.profiler->is_profiling()) {
+            state.profiler->record_host(icmd);
         }
         finished = true;
     };
@@ -642,7 +684,7 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
             uint64_t apply_id = ++m_last_id;
             SmallVector<TensorPtr> tensor_inputs;
             SmallVector<CompNode> devices;
-            if (m_worker_state.options.enable_dtr_auto_drop) {
+            if (state.options.enable_dtr_auto_drop) {
                 m_dtr.pin(cmd.inputs);
             }
             for (auto i : cmd.inputs) {
@@ -660,7 +702,7 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
             }
             // Begin profiling operator
             OpEvent event_data;
-            if (m_worker_state.profiler->is_profiling()) {
+            if (state.profiler->is_profiling()) {
                 auto tinfo_to_tid = [&](SmallVector<TensorInfo*> tinfo) {
                     SmallVector<uint64_t> tid;
                     for (auto* ptinfo: tinfo) {
@@ -689,14 +731,14 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
             // Before wait
            //TODO: split operator wait and execute so that OpWait could be correctly recorded.
            // Before execute
-            if (m_worker_state.profiler->is_profiling()) {
-                m_worker_state.profiler->record_host(event_data);
+            if (state.profiler->is_profiling()) {
+                state.profiler->record_host(event_data);
                 for (auto&& device: devices) {
                     sync_device_scope(device);
-                    m_worker_state.profiler->record_device(device, event_data);
+                    state.profiler->record_device(device, event_data);
                 }
             }
-            if (m_worker_state.options.enable_dtr_auto_drop && m_worker_state.options.dtr_eviction_threshold > 0) {
+            if (state.options.enable_dtr_auto_drop && state.options.dtr_eviction_threshold > 0) {
                 auto_evict();
             }
             // Apply op
@@ -704,15 +746,15 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
             auto tensor_outputs = OpDef::apply_on_physical_tensor(
                 *cmd.op, std::move(tensor_inputs));
             // After execute
-            if (m_worker_state.profiler->is_profiling()) {
-                m_worker_state.profiler->record_host(event_data);
+            if (state.profiler->is_profiling()) {
+                state.profiler->record_host(event_data);
                 for (auto&& device: devices) {
-                    m_worker_state.profiler->record_device(device, event_data);
+                    state.profiler->record_device(device, event_data);
                 }
             }
             // End profiling operator
             double estimate_compute_time = 0;
-            if (m_worker_state.options.enable_dtr_auto_drop) {
+            if (state.options.enable_dtr_auto_drop) {
                 for (auto i : cmd.inputs) {
                     estimate_compute_time += i->memory;
                 }
@@ -735,12 +777,12 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
                     continue;
                 }
                 produce_tensor(cmd.outputs[i], std::move(tensor_outputs[i]));
-                if (m_worker_state.options.enable_dtr_auto_drop) {
+                if (state.options.enable_dtr_auto_drop) {
                     cmd.outputs[i]->dsu_ptr = std::make_shared<DsuNode>(estimate_compute_time);
                 }
             }
-            if (m_worker_state.options.enable_drop == 1
-                && m_worker_state.options.record_computing_path == 1){
+            if (state.options.enable_drop == 1
+                && state.options.record_computing_path == 1){
                 bool is_inplace = false;
                 bool cross_cn = false;
                 for (auto input : cmd.inputs) {
@@ -774,7 +816,7 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
                     TensorInfo::ComputePath::make(cmd.op, cmd.inputs, cmd.outputs);
                 size_t detach_cnt = 0;
                 for (auto output : cmd.outputs) {
-                    if (!output->size_exceeds_thd(m_worker_state.options.dtr_evictee_minimum_size)) {
+                    if (!output->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
                         output->detach_producer();
                         detach_cnt ++;
                     }
@@ -808,21 +850,22 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
        } else if constexpr (std::is_same_v<T, Drop>) {
            do_drop(cmd.dest, true);
        } else if constexpr (std::is_same_v<T, SetOption>) {
-            m_worker_state.options.set_option(cmd.key, cmd.value);
+            state.options.set_option(cmd.key, cmd.value);
        } else if constexpr (std::is_same_v<T, StartProfile>) {
            CompNode::sync_all();
-            m_worker_state.profiler.reset(cmd.profiler);
+            state.profiler.reset(cmd.profiler);
        } else if constexpr (std::is_same_v<T, StopProfile>) {
-            for (auto&& [device, scopes]: m_worker_state.device_scope_map) {
+            for (auto&& [device, scopes]: state.device_scope_map) {
                MGB_MARK_USED_VAR(scopes);
                sync_device_scope(device);
            }
            do_finish_command();
            auto profiler = std::make_unique<InterpreterProfiler>();
-            std::swap(profiler, m_worker_state.profiler);
+            std::swap(profiler, state.profiler);
            auto records = profiler->stop();
-            auto host_map = [this](std::thread::id tid) {
-                if (tid == m_worker_state.tid) {
+            auto worker_tid = get_worker_tid();
+            auto host_map = [worker_tid](std::thread::id tid) {
+                if (tid == worker_tid) {
                    return "worker";
                } else {
                    return "unknown";
@@ -830,21 +873,21 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
            };
            InterpreterProfiler::dump_data(cmd.basename, cmd.format, records, profiler->get_option(), host_map);
        } else if constexpr (std::is_same_v<T, PushScope>) {
-            m_worker_state.scopes.push_back(cmd.scope_name);
+            state.scopes.push_back(cmd.scope_name);
            do_finish_command();
-            m_worker_state.profiler->record_host(cmd.scope_name);
+            state.profiler->record_host(cmd.scope_name);
        } else if constexpr (std::is_same_v<T, PopScope>) {
-            mgb_assert(m_worker_state.scopes.back() == cmd.scope_name, "scope name mismatch");
-            m_worker_state.scopes.pop_back();
+            mgb_assert(state.scopes.back() == cmd.scope_name, "scope name mismatch");
+            state.scopes.pop_back();
            do_finish_command();
-            m_worker_state.profiler->record_host(cmd.scope_name);
+            state.profiler->record_host(cmd.scope_name);
        } else {
            static_assert(!std::is_same_v<T, T>);
        }
    };
    std::visit([&](const auto& cmd){
        using T = std::decay_t<decltype(cmd)>;
-        if (!m_worker_state.options.catch_worker_execption) {
+        if (!state.options.catch_worker_execption) {
            cmd_visitor(cmd);
            return;
        }
@@ -891,11 +934,12 @@ void ChannelImpl::CommandBuffer::flush() {
 }
 
 void ChannelImpl::CommandBuffer::flush(Handle pos) {
+    auto& state = m_owner->get_channel_state();
     for (auto iter = m_commands.begin(); iter != pos; ++iter) {
         // mgb_log_debug("%s Flushed", to_string(*iter).c_str());
         IdentifiedCommand icmd{++m_owner->m_last_id, std::move(*iter)};
-        if (m_owner->m_channel_state.profiler->is_profiling()) {
-            m_owner->m_channel_state.profiler->record_host(icmd);
+        if (state.profiler->is_profiling()) {
+            state.profiler->record_host(icmd);
         }
         m_owner->m_worker.add_task(std::move(icmd));
     }
@@ -903,7 +947,8 @@ void ChannelImpl::CommandBuffer::flush(Handle pos) {
 }
 
 auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
-    return std::visit([this](const auto& cmd) {
+    auto& state = m_owner->get_channel_state();
+    return std::visit([&, this](const auto& cmd) {
         using T = std::decay_t<decltype(cmd)>;
         if constexpr (std::is_same_v<T, ApplyOp>) {
             auto* op_type = cmd.op->dyn_typeinfo();
@@ -917,7 +962,7 @@ auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
         } else if constexpr (std::is_same_v<T, GetValue>) {
             return m_commands.end();
         }
-        size_t buffer_length = m_owner->m_channel_state.options.buffer_length;
+        size_t buffer_length = state.options.buffer_length;
         if (m_commands.size() > buffer_length) {
             return m_commands.begin() + (m_commands.size() - buffer_length);
         }
@@ -993,42 +1038,54 @@ auto ChannelImpl::CommandBuffer::find_produce(TensorInfo* dest, Range range)
 
 void ChannelImpl::start_profile(std::unordered_map<std::string, int> option) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     auto profiler_option = InterpreterProfiler::Option::from_dict(option);
     auto profiler = std::make_unique<InterpreterProfiler>();
     profiler->set_option(profiler_option);
     profiler->start(InterpreterProfiler::topic_to_mask(profiler_option.topic));
-    std::swap(profiler, m_channel_state.profiler);
-    m_buffer.enqueue(StartProfile{m_channel_state.profiler.get()});
+    std::swap(profiler, state.profiler);
+    m_buffer.enqueue(StartProfile{state.profiler.get()});
 }
 
 void ChannelImpl::stop_profile(std::string basename, std::string format) {
     mgb_assert(check_available(), "Channel already closed");
+    auto& state = get_channel_state();
     m_buffer.flush();
     auto profiler = std::make_unique<InterpreterProfiler>();
-    std::swap(profiler, m_channel_state.profiler);
+    std::swap(profiler, state.profiler);
     profiler.release();
     m_buffer.enqueue(StopProfile{basename, format});
 }
 
 void ChannelImpl::push_scope(std::string name) {
     mgb_assert(check_available(), "Channel already closed");
-    if (m_channel_state.profiler->is_profiling()) {
-        m_channel_state.profiler->record_host(name);
-        m_channel_state.scopes.push_back(name);
+    auto& state = get_channel_state();
+    if (state.profiler->is_profiling()) {
+        state.profiler->record_host(name);
+        state.scopes.push_back(name);
         m_buffer.enqueue(PushScope{name});
     }
 }
 
 void ChannelImpl::pop_scope(std::string name) {
     mgb_assert(check_available(), "Channel already closed");
-    if (m_channel_state.profiler->is_profiling()) {
-        mgb_assert((!m_channel_state.scopes.empty()) && m_channel_state.scopes.back() == name, "scope name mismatch");
-        m_channel_state.scopes.pop_back();
-        m_channel_state.profiler->record_host(name);
+    auto& state = get_channel_state();
+    if (state.profiler->is_profiling()) {
+        mgb_assert((!state.scopes.empty()) && state.scopes.back() == name, "scope name mismatch");
+        state.scopes.pop_back();
+        state.profiler->record_host(name);
         m_buffer.enqueue(PopScope{name});
     }
 }
 
+void ChannelImpl::assert_in_channel() {
+    mgb_assert(get_worker_tid() != std::this_thread::get_id(), "this method cannot be called in worker thread");
+}
+
+void ChannelImpl::assert_in_worker() {
+    mgb_assert(get_worker_tid() == std::this_thread::get_id(), "this method can only be called in worker thread");
+}
+
 void ChannelImpl::DynamicSublinear::pin(const SmallVector<TensorInfo*>& vec) {
     for (auto i : vec) {
         i->pin();
diff --git a/imperative/src/impl/interpreter/interpreter_impl.h b/imperative/src/impl/interpreter/interpreter_impl.h
index d10cc5c40..2f050f11f 100644
--- a/imperative/src/impl/interpreter/interpreter_impl.h
+++ b/imperative/src/impl/interpreter/interpreter_impl.h
@@ -90,7 +90,6 @@ private:
     void regenerate(TensorInfo* dest);
     void recompute(TensorInfo::ComputePath* path);
-
 
     void dispatch_default_cpu(
         std::shared_ptr<OpDef> op,
@@ -105,6 +104,10 @@ private:
 
     bool check_available();
 
+    void assert_in_channel();
+    void assert_in_worker();
+    std::thread::id get_worker_tid();
+
     void sync_device_scope(CompNode device);
 
     template <typename T>
@@ -319,6 +322,11 @@ private:
 
     //! automatically evict an optimal tensor
     void auto_evict();
+
+    // assert thread id when calling get_xxx_state to avoid misuse
+    ChannelState& get_channel_state();
+    WorkerState& get_worker_state();
+
 };
 
 } // namespace mgb::imperative::interpreter::intl
--
GitLab