Skip profiler bookkeeping in ChannelImpl::process_one_task when not profiling.

interpreter_impl.cpp: the record_host/record_device calls before and after
OpDef::apply_on_physical_tensor, and the sync_device_scope call before it, are
wrapped in `if (m_worker_state.profiler->is_profiling())`.

profiler.h: record_host and record_device now read their timestamps only after
taking m_lock, passing the event-mask test and the "record after stop" assert.
A new is_profiling() returns `m_status == Profiling` under m_lock.

NOTE(review): is_profiling() is tagged "// unsafe" in the patch; presumably the
status can change between the check and the record_* calls that follow - confirm.
NOTE(review): template argument lists look stripped in this copy (`template void`,
`index_of()`, `std::forward(args)`), and the indentation below is reconstructed.
Regenerate the patch from the repository before applying it.

diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp
index 4430c2fcfdaf701030ffe68f0dafe4fa68e6d0bb..1331551652d6953a0db47b53e22e9860944ec74a 100644
--- a/imperative/src/impl/interpreter/interpreter_impl.cpp
+++ b/imperative/src/impl/interpreter/interpreter_impl.cpp
@@ -525,19 +525,23 @@ void ChannelImpl::process_one_task(IdentifiedCommand& icmd) {
                 // Before wait
                 //TODO: split operator wait and execute so that OpWait could be corrected recorded.
                 // Before execute
-                m_worker_state.profiler->record_host(event_data);
-                for (auto&& device: devices) {
-                    sync_device_scope(device);
-                    m_worker_state.profiler->record_device(device, event_data);
+                if (m_worker_state.profiler->is_profiling()) {
+                    m_worker_state.profiler->record_host(event_data);
+                    for (auto&& device: devices) {
+                        sync_device_scope(device);
+                        m_worker_state.profiler->record_device(device, event_data);
+                    }
                 }
                 // Apply op
                 // Here std::move is REQUIRED for removing duplicated references.
                 auto tensor_outputs = OpDef::apply_on_physical_tensor(
                         *cmd.op, std::move(tensor_inputs));
                 // After execute
-                m_worker_state.profiler->record_host(event_data);
-                for (auto&& device: devices) {
-                    m_worker_state.profiler->record_device(device, event_data);
+                if (m_worker_state.profiler->is_profiling()) {
+                    m_worker_state.profiler->record_host(event_data);
+                    for (auto&& device: devices) {
+                        m_worker_state.profiler->record_device(device, event_data);
+                    }
                 }
                 // End profiling operator
                 mgb_assert(tensor_outputs.size() == cmd.outputs.size());
diff --git a/imperative/src/include/megbrain/imperative/profiler.h b/imperative/src/include/megbrain/imperative/profiler.h
index 702befabf3763287dec10a56b5d2e52678983369..f3253a63bca80dae2aff4c3c82f9aea4cea97aed 100644
--- a/imperative/src/include/megbrain/imperative/profiler.h
+++ b/imperative/src/include/megbrain/imperative/profiler.h
@@ -140,27 +140,32 @@ public:
 public:
     template
     void record_host(TArgs&&... args) {
-        auto instant = HostInstant{std::this_thread::get_id(), m_host_timer.get_msecs()};
         MGB_LOCK_GUARD(m_lock);
         if (!m_event_mask.test(index_of())) {
             return;
         }
         mgb_assert(m_status != Stopped, "record after stop");
+        auto instant = HostInstant{std::this_thread::get_id(), m_host_timer.get_msecs()};
         m_record_list.emplace_back(EventRecord{std::move(instant), {std::forward(args)...}});
     }
     template
     void record_device(Device device, TArgs&&... args) {
-        auto before = m_host_timer.get_msecs();
-        auto event = m_device_timer.get_device_time(device);
-        auto after = m_host_timer.get_msecs();
-        auto instant = DeviceInstant{before, event, after};
         MGB_LOCK_GUARD(m_lock);
         if (!m_event_mask.test(index_of())) {
             return;
         }
         mgb_assert(m_status != Stopped, "record after stop");
+        auto before = m_host_timer.get_msecs();
+        auto event = m_device_timer.get_device_time(device);
+        auto after = m_host_timer.get_msecs();
+        auto instant = DeviceInstant{before, event, after};
         m_record_list.emplace_back(EventRecord{std::move(instant), {std::forward(args)...}});
     }
+    // unsafe
+    bool is_profiling() {
+        MGB_LOCK_GUARD(m_lock);
+        return m_status == Profiling;
+    }
     void start(Mask mask) {
         MGB_LOCK_GUARD(m_lock);
         mgb_assert(m_status == NotStarted, "profiler already started");