提交 1d64792b 编写于 作者: M Megvii Engine Team 提交者: huangxinda

refactor(profiler): detach profiler from interpreter

GitOrigin-RevId: f3954728d1dd8e93e2eb5a94ee5f3a030a54fb5a
上级 f2027b8d
...@@ -7,9 +7,14 @@ ...@@ -7,9 +7,14 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import json import json
from contextlib import contextmanager import os
import re
from contextlib import ContextDecorator, contextmanager
from functools import wraps
from typing import List from typing import List
from weakref import WeakSet
from .. import _atexit
from ..core._imperative_rt.core2 import ( from ..core._imperative_rt.core2 import (
pop_scope, pop_scope,
push_scope, push_scope,
...@@ -17,9 +22,13 @@ from ..core._imperative_rt.core2 import ( ...@@ -17,9 +22,13 @@ from ..core._imperative_rt.core2 import (
stop_profile, stop_profile,
sync, sync,
) )
from ..logger import get_logger
_running_profiler = None
_living_profilers = WeakSet()
class Profiler:
class Profiler(ContextDecorator):
r""" r"""
Profile graph execution in imperative mode. Profile graph execution in imperative mode.
...@@ -35,9 +44,10 @@ class Profiler: ...@@ -35,9 +44,10 @@ class Profiler:
from megengine.utils.profiler import Profiler from megengine.utils.profiler import Profiler
# With Learnable Parameters # With Learnable Parameters
profiler = Profiler()
for iter in range(0, 10): for iter in range(0, 10):
# Only profile record of last iter would be saved # Only profile record of last iter would be saved
with Profiler("profile"): with profiler:
# your code here # your code here
# Then open the profile file in chrome timeline window # Then open the profile file in chrome timeline window
...@@ -45,46 +55,105 @@ class Profiler: ...@@ -45,46 +55,105 @@ class Profiler:
CHROME_TIMELINE = "chrome_timeline.json" CHROME_TIMELINE = "chrome_timeline.json"
COMMAND = 1 << 0 valid_options = {"sample_rate": 0, "profile_device": 1, "num_tensor_watch": 10}
OPERATOR = 1 << 1 valid_formats = {"chrome_timeline.json", "memory_flow.svg"}
TENSOR_LIFETIME = 1 << 2
TENSOR_PROP = 1 << 3
SYNC = 1 << 4
SCOPE = 1 << 5
ALL = (1 << 6) - 1
def __init__( def __init__(
self, self,
path: str = "profile", path: str = "profile",
format: str = CHROME_TIMELINE, format: str = "chrome_timeline.json",
*, formats: List[str] = None,
topic=OPERATOR | SCOPE, **kwargs
align_time=True,
show_operator_name=True
) -> None: ) -> None:
self._path = path if not formats:
self._format = format formats = [format]
self._options = {
"topic": int(topic),
"align_time": int(align_time),
"show_operator_name": int(show_operator_name),
}
def __enter__(self): assert not isinstance(formats, str), "formats excepts list, got str"
for format in formats:
assert format in Profiler.valid_formats, "unsupported format {}".format(
format
)
self._path = path
self._formats = formats
self._options = {}
for opt, optval in Profiler.valid_options.items():
self._options[opt] = int(kwargs.pop(opt, optval))
self._pid = "<PID>"
@property
def path(self):
if len(self._formats) == 0:
format = "<FORMAT>"
elif len(self._formats) == 1:
format = self._formats[0]
else:
format = "{" + ",".join(self._formats) + "}"
return self.format_path(self._path, self._pid, format)
@property
def directory(self):
return self._path
@property
def formats(self):
return list(self._formats)
def start(self):
global _running_profiler
assert _running_profiler is None
_running_profiler = self
self._pid = os.getpid()
start_profile(self._options) start_profile(self._options)
return self return self
def __exit__(self, val, tp, trace): def stop(self):
stop_profile(self._path, self._format) global _running_profiler
# dump is async, so it's necessary to sync interpreter
assert _running_profiler is self
_running_profiler = None
sync() sync()
self._dump_callback = stop_profile()
self._pid = os.getpid()
_living_profilers.add(self)
def dump(self):
if self._dump_callback is not None:
if not os.path.exists(self._path):
os.makedirs(self._path)
if not os.path.isdir(self._path):
get_logger().warning(
"{} is not a directory, cannot write profiling results".format(
self._path
)
)
return
for format in self._formats:
path = self.format_path(self._path, self._pid, format)
get_logger().info("process {} generating {}".format(self._pid, format))
self._dump_callback(path, format)
get_logger().info("profiling results written to {}".format(path))
self._dump_callback = None
_living_profilers.remove(self)
def format_path(self, path, pid, format):
return os.path.join(path, "{}.{}".format(pid, format))
def __enter__(self):
self.start()
def __exit__(self, val, tp, trace):
self.stop()
def __call__(self, func): def __call__(self, func):
def wrapper(*args, **kwargs): func = super().__call__(func)
with self: func.__profiler__ = self
return func(*args, **kwargs) return func
return wrapper def __del__(self):
self.dump()
@contextmanager @contextmanager
...@@ -94,16 +163,77 @@ def scope(name): ...@@ -94,16 +163,77 @@ def scope(name):
pop_scope(name) pop_scope(name)
profile = Profiler def profile(*args, **kwargs):
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
return Profiler()(args[0])
return Profiler(*args, **kwargs)
def merge_trace_events(directory: str):
names = filter(
lambda x: re.match(r"\d+\.chrome_timeline\.json", x), os.listdir(directory)
)
def load_trace_events(name):
with open(os.path.join(directory, name), "r", encoding="utf-8") as f:
return json.load(f)
def find_metadata(content):
if isinstance(content, dict):
assert "traceEvents" in content
content = content["traceEvents"]
if len(content) == 0:
return None
assert content[0]["name"] == "Metadata"
return content[0]["args"]
contents = list(map(load_trace_events, names))
metadata_list = list(map(find_metadata, contents))
min_local_time = min(
map(lambda x: x["localTime"], filter(lambda x: x is not None, metadata_list))
)
events = []
for content, metadata in zip(contents, metadata_list):
local_events = content["traceEvents"]
if len(local_events) == 0:
continue
local_time = metadata["localTime"]
time_shift = local_time - min_local_time
for event in local_events:
if "ts" in event:
event["ts"] = int(event["ts"] + time_shift)
events.extend(filter(lambda x: x["name"] != "Metadata", local_events))
result = {
"traceEvents": events,
}
path = os.path.join(directory, "merge.chrome_timeline.json")
with open(path, "w") as f:
json.dump(result, f, ensure_ascii=False, separators=(",", ":"))
get_logger().info("profiling results written to {}".format(path))
def is_profiling():
return _running_profiler is not None
def _stop_current_profiler():
global _running_profiler
if _running_profiler is not None:
_running_profiler.stop()
living_profilers = [*_living_profilers]
for profiler in living_profilers:
profiler.dump()
def merge_trace_events(sources: List[str], target: str): _atexit(_stop_current_profiler)
names = list(map(lambda x: x + ".chrome_timeline.json", sources))
result = []
for name in names:
with open(name, "r", encoding="utf-8") as f:
content = json.load(f)
for entry in content:
result.append(entry)
with open(target + ".chrome_timeline.json", "w") as f:
json.dump(result, f, ensure_ascii=False, indent=4)
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "megbrain/common.h" #include "megbrain/common.h"
#include "megbrain/imperative/ops/utility.h" #include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/ops/backward_graph.h" #include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/profiler.h"
#include "megbrain/opr/io.h" #include "megbrain/opr/io.h"
#include "./tensor.h" #include "./tensor.h"
...@@ -927,9 +928,23 @@ void init_tensor(py::module m) { ...@@ -927,9 +928,23 @@ void init_tensor(py::module m) {
m.def("pop_scope", m.def("pop_scope",
[](std::string name) { interpreter_for_py->pop_scope(name); }); [](std::string name) { interpreter_for_py->pop_scope(name); });
m.def("start_profile", m.def("start_profile",
[](std::unordered_map<std::string, int> option) { return interpreter_for_py->start_profile(option); }); [](imperative::Profiler::options_t options) {
interpreter_for_py->sync();
imperative::Profiler::load_options(std::move(options));
imperative::Profiler::start_profile();
interpreter_for_py->start_profile();
});
m.def("stop_profile", m.def("stop_profile",
[](std::string basename, std::string format) { interpreter_for_py->stop_profile(basename, format); }); []() -> std::function<void(std::string, std::string)> {
interpreter_for_py->stop_profile();
interpreter_for_py->sync();
imperative::Profiler::stop_profile();
auto results = imperative::Profiler::collect();
auto options = imperative::Profiler::get_options();
return [results=std::move(results), options=std::move(options)](std::string basename, std::string format){
imperative::Profiler::dump_profile(basename, format, results, options);
};
});
m.def("sync", m.def("sync",
[]() { []() {
interpreter_for_py->sync(); interpreter_for_py->sync();
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied
import json import json
import os import os
import tempfile
import pytest import pytest
...@@ -28,15 +29,18 @@ class Simple(Module): ...@@ -28,15 +29,18 @@ class Simple(Module):
def test_profiler(): def test_profiler():
profile_prefix = "pytest_profile" tempdir = tempfile.NamedTemporaryFile()
profile_prefix = tempdir.name
profile_format = "chrome_timeline.json" profile_format = "chrome_timeline.json"
profile_path = "{}.{}".format(profile_prefix, profile_format) profile_path = os.path.join(
profile_prefix, "{}.{}".format(os.getpid(), profile_format)
)
with option("enable_host_compute", 0):
with Profiler(profile_prefix, format=profile_format): with Profiler(profile_prefix, format=profile_format):
with scope("my_scope"): with scope("my_scope"):
oup = Simple()(tensor([1.23], dtype="float32")) oup = Simple()(tensor([1.23], dtype="float32"))
with open(profile_path, "r") as f: with open(profile_path, "r") as f:
events = json.load(f) events = json.load(f)
os.remove(profile_path)
prev_ts = {} prev_ts = {}
scope_count = 0 scope_count = 0
for event in events: for event in events:
......
...@@ -13,11 +13,14 @@ ...@@ -13,11 +13,14 @@
#include <string> #include <string>
#include <variant> #include <variant>
#include <unordered_set>
#include "megbrain/tensor.h" #include "megbrain/tensor.h"
#include "megbrain/imperative/op_def.h" #include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/utils/to_string.h" #include "megbrain/imperative/utils/to_string.h"
#include "./tensor_info.h"
namespace mgb::imperative { namespace mgb::imperative {
namespace interpreter::intl { namespace interpreter::intl {
...@@ -43,7 +46,7 @@ struct Put { ...@@ -43,7 +46,7 @@ struct Put {
}; };
struct ApplyOp { struct ApplyOp {
uint64_t id; uint64_t id; //used by profiler to identify unique apply
std::shared_ptr<OpDef> op; std::shared_ptr<OpDef> op;
SmallVector<TensorInfo*> inputs; SmallVector<TensorInfo*> inputs;
SmallVector<TensorInfo*> outputs; SmallVector<TensorInfo*> outputs;
...@@ -143,7 +146,7 @@ struct SetOption { ...@@ -143,7 +146,7 @@ struct SetOption {
}; };
struct StartProfile { struct StartProfile {
InterpreterProfiler* profiler; std::unordered_set<TensorInfo*> capture_tensors;
template <typename TFunctor> template <typename TFunctor>
void get_props(TFunctor&& functor) const {} void get_props(TFunctor&& functor) const {}
...@@ -154,14 +157,10 @@ struct StartProfile { ...@@ -154,14 +157,10 @@ struct StartProfile {
}; };
struct StopProfile { struct StopProfile {
std::string basename; std::unordered_set<TensorInfo*> escape_tensors;
std::string format;
template <typename TFunctor> template <typename TFunctor>
void get_props(TFunctor&& functor) const { void get_props(TFunctor&& functor) const {}
functor("basename", basename);
functor("format", format);
}
const char* get_name() const { const char* get_name() const {
return "StopProfile"; return "StopProfile";
......
...@@ -24,10 +24,10 @@ ...@@ -24,10 +24,10 @@
#include "megbrain/imperative/profiler.h" #include "megbrain/imperative/profiler.h"
#include "./commands.h" #include "./commands.h"
#include "./events.h"
#include "./tensor_info.h" #include "./tensor_info.h"
#include "./option_manager.h" #include "./option_manager.h"
#include "./profiler.h"
#include "../profiler/events.h"
namespace mgb::imperative::interpreter::intl { namespace mgb::imperative::interpreter::intl {
...@@ -37,7 +37,6 @@ struct InterpreterImpl : Interpreter { ...@@ -37,7 +37,6 @@ struct InterpreterImpl : Interpreter {
std::unique_ptr<Channel> create_channel() override; std::unique_ptr<Channel> create_channel() override;
}; };
struct ChannelImpl : Interpreter::Channel { struct ChannelImpl : Interpreter::Channel {
ChannelImpl(); ChannelImpl();
~ChannelImpl() override; ~ChannelImpl() override;
...@@ -67,19 +66,27 @@ struct ChannelImpl : Interpreter::Channel { ...@@ -67,19 +66,27 @@ struct ChannelImpl : Interpreter::Channel {
size_t get_option(std::string name) override; size_t get_option(std::string name) override;
void set_option(std::string name, size_t value) override; void set_option(std::string name, size_t value) override;
void start_profile(std::unordered_map<std::string, int> option) override; void start_profile() override;
void stop_profile(std::string basename, std::string format) override; void stop_profile() override;
void push_scope(std::string) override; void push_scope(std::string) override;
void pop_scope(std::string) override; void pop_scope(std::string) override;
private: private:
struct WorkQueue;
struct State;
TensorInfo* alloc(); TensorInfo* alloc();
void init(TensorInfo*, LogicalTensorDesc desc);
void free(TensorInfo*); void free(TensorInfo*);
void real_free(TensorInfo*); void real_free(TensorInfo*);
void recursive_free(TensorInfo*); void recursive_free(TensorInfo*);
void do_drop(TensorInfo*, bool); void do_drop(TensorInfo*, bool);
void detach_users(TensorInfo*); void detach_users(TensorInfo*);
TensorInfo* put_impl(const HostTensorND& value, bool no_cache);
TensorPtr wait_tensor(TensorInfo* info, profiler::TensorProp prop);
void notify_tensor_unsafe(TensorInfo* info);
void process_one_task(IdentifiedCommand&); void process_one_task(IdentifiedCommand&);
void check_worker_exc_unsafe(); void check_worker_exc_unsafe();
...@@ -105,24 +112,31 @@ private: ...@@ -105,24 +112,31 @@ private:
bool check_available(); bool check_available();
void push_scope(std::string, State&);
void pop_scope(std::string, State&);
void assert_in_channel(); void assert_in_channel();
void assert_in_worker(); void assert_in_worker();
std::thread::id get_worker_tid(); std::thread::id get_worker_tid();
void sync_device_scope(CompNode device);
template <typename TCommand> template <typename TCommand>
void enqueue_command(TCommand&& cmd) { void enqueue_command(TCommand&& cmd) {
m_buffer.enqueue(Command{std::forward<TCommand>(cmd)}); m_buffer.enqueue(Command{std::forward<TCommand>(cmd)});
} }
void sample_on_device(CompNode device, bool force);
// valid => status != Deleted
std::unordered_set<TensorInfo*> collect_valid_tensors();
std::mutex m_mutex; std::mutex m_mutex;
std::condition_variable m_cv; std::condition_variable m_cv;
MemPool<TensorInfo> m_pool; MemPool<TensorInfo> m_pool;
std::unordered_set<Handle> m_valid_handle; std::unordered_set<Handle> m_valid_handle;
TensorInfo* m_waitee = nullptr; TensorInfo* m_waitee = nullptr;
uint64_t m_waitee_id = 0;
std::exception_ptr m_worker_exc; std::exception_ptr m_worker_exc;
std::atomic_uint64_t m_last_id = 0; std::function<void(std::string, std::string)> m_profile_dump_callback;
bool m_closed = false; bool m_closed = false;
...@@ -191,27 +205,98 @@ private: ...@@ -191,27 +205,98 @@ private:
//! level 0: both sync. //! level 0: both sync.
int m_async_level = 2; int m_async_level = 2;
struct State { struct Scope {
OptionManager options; std::string name;
std::vector<std::string> scopes; std::unordered_map<std::string, std::unique_ptr<Scope>> children;
std::unique_ptr<InterpreterProfiler> profiler; size_t version = 0;
size_t parent_version = 0;
size_t tensor_count = 0;
Scope* active_child = nullptr;
Scope* parent = nullptr;
Scope* enter(std::string name) {
auto& child = children[name];
if (!child) {
child = std::make_unique<Scope>();
child->name = name;
child->parent = this;
}
if (version != child->parent_version) {
child->version = 0;
child->parent_version = version;
} else {
child->version++;
}
child->tensor_count = 0;
return active_child = child.get();
}
State() { Scope* exit(std::string name) {
profiler = std::make_unique<InterpreterProfiler>(); mgb_assert(this->name == name, "scope name mismatch");
parent->active_child = nullptr;
return parent;
} }
}; };
struct ChannelState: State {}; class ScopeManager {
private:
Scope m_root;
Scope* m_current_scope = &m_root;
public:
class ScopeGuard{
private:
ScopeManager* m_manager;
std::string m_name;
public:
ScopeGuard(ScopeManager* manager, std::string name): m_manager{manager}, m_name{name} {
m_manager->push(m_name);
}
~ScopeGuard() {
m_manager->pop(m_name);
}
};
void push(std::string name) {
m_current_scope = m_current_scope->enter(name);
}
void pop(std::string name) {
m_current_scope = m_current_scope->exit(name);
}
std::string next_tensor_name() {
std::string builder;
Scope* scope = &m_root;
while (true) {
builder.append(scope->name);
if (scope->version != 0) {
builder.append(ssprintf("(%ld)", scope->version));
}
if (scope != &m_root) {
builder.append(".");
}
if (scope->active_child == nullptr) {
builder.append(ssprintf(":%%%ld", scope->tensor_count++));
break;
} else {
scope = scope->active_child;
}
}
return builder;
}
};
struct WorkerState: State { struct State {
std::thread::id tid; std::thread::id tid;
CompNode::UnorderedMap<std::vector<std::string>> device_scope_map; OptionManager options;
};
struct ChannelState: State {
ScopeManager scopes;
}; };
struct WorkerState: State {};
ChannelState m_channel_state; ChannelState m_channel_state;
WorkerState m_worker_state; WorkerState m_worker_state;
/*! /*!
* \brief A framework of dynamic sublienar memory optimization * \brief A framework of dynamic sublienar memory optimization
* *
...@@ -327,7 +412,6 @@ private: ...@@ -327,7 +412,6 @@ private:
// assert thread id when call get_xxx_state to avoid misuse // assert thread id when call get_xxx_state to avoid misuse
ChannelState& get_channel_state(); ChannelState& get_channel_state();
WorkerState& get_worker_state(); WorkerState& get_worker_state();
}; };
} // namespace mgb::imperative::interpreter::intl } // namespace mgb::imperative::interpreter::intl
/**
* \file imperative/src/impl/interpreter/profiler.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/imperative/profiler.h"
#include "./commands.h"
#include "./events.h"
#include "./option_manager.h"
namespace mgb::imperative::interpreter::intl {
class InterpreterProfiler: public Profiler<
CommandEnqueueEvent, CommandExecuteEvent, CommandFinishEvent,
OpExecuteEvent, OpExecuteFinishEvent,
KernelExecuteEvent, KernelExecuteFinishEvent,
TensorDeclareEvent, TensorProduceEvent, TensorEraseEvent,
TensorGetPropEvent, TensorWaitPropEvent, TensorNotifyPropEvent, TensorWaitPropFinishEvent,
SyncEvent, SyncFinishEvent,
ScopeEvent, ScopeFinishEvent,
DeviceScopeEvent, DeviceScopeFinishEvent> {
public:
enum Topic {
Command = 0b000001,
Operator = 0b000010,
TensorLifetime = 0b000100,
TensorProp = 0b001000,
Sync = 0b010000,
Scope = 0b100000,
};
struct Option {
Topic topic;
bool align_time;
bool show_operator_name;
static Option from_dict(std::unordered_map<std::string, int> dict) {
Option option;
option.topic = Topic(dict.at("topic"));
option.align_time = bool(dict.at("align_time"));
option.show_operator_name = bool(dict.at("show_operator_name"));
return option;
}
};
Option get_option() const {
return m_option;
}
void set_option(const Option& option) {
m_option = option;
}
static Mask topic_to_mask(Topic topic) {
Mask result;
if (topic & Command) {
result |= mask_of<CommandEnqueueEvent, CommandExecuteEvent, CommandFinishEvent>();
}
if (topic & Operator) {
result |= mask_of<OpExecuteEvent, OpExecuteFinishEvent>();
result |= mask_of<KernelExecuteEvent, KernelExecuteFinishEvent>();
}
if (topic & TensorLifetime) {
result |= mask_of<TensorDeclareEvent, TensorProduceEvent, TensorEraseEvent>();
}
if (topic & TensorProp) {
result |= mask_of<TensorGetPropEvent, TensorWaitPropEvent, TensorNotifyPropEvent, TensorWaitPropFinishEvent>();
}
if (topic & Sync) {
result |= mask_of<SyncEvent, SyncFinishEvent>();
}
if (topic & Scope) {
result |= mask_of<ScopeEvent, ScopeFinishEvent>();
result |= mask_of<DeviceScopeEvent, DeviceScopeFinishEvent>();
}
return result;
}
private:
Option m_option;
};
}
...@@ -47,11 +47,15 @@ struct TensorInfo; ...@@ -47,11 +47,15 @@ struct TensorInfo;
using TensorInfoPtr = std::shared_ptr<TensorInfo>; using TensorInfoPtr = std::shared_ptr<TensorInfo>;
struct TensorInfo { struct TensorInfo {
enum Prop { enum Status {
Device, Shape, DType, DevValue, HostValue InvalidStatus, Allocated, Produced, Swapped, Dropped, Deleted,
}; };
uint64_t id; uint64_t id = -1;
std::string name;
// Most attrs of TensorInfo, except `ptr` and `h_value`,
// were visited read and written in main thread.
// Lock interpreter when visiting `ptr`.
TensorPtr ptr; TensorPtr ptr;
LogicalTensorDesc desc; LogicalTensorDesc desc;
...@@ -59,13 +63,17 @@ struct TensorInfo { ...@@ -59,13 +63,17 @@ struct TensorInfo {
size_t memory; size_t memory;
double last_used_time; double last_used_time;
// FIXME: broken by drop
bool value_fetched = false;
bool invalid = false; bool invalid = false;
bool allow_delete = false; bool allow_delete = false;
EvictType evict_type = NONE; EvictType evict_type = NONE;
// Status should be only modified in worker thread
Status status = InvalidStatus;
// Used by HostCompute and Memory Swap.
// HostCompute and Swap does not happen in one thread.
// Maybe a barrier is needed.
HostTensorND h_value; HostTensorND h_value;
// reserved for auto drop // reserved for auto drop
...@@ -74,6 +82,10 @@ struct TensorInfo { ...@@ -74,6 +82,10 @@ struct TensorInfo {
size_t ref_cnt = 0; size_t ref_cnt = 0;
std::shared_ptr<DsuNode> dsu_ptr; std::shared_ptr<DsuNode> dsu_ptr;
// Not reference count, inc when used as input
size_t ptr_use_count = 0;
// Used by `Drop` action
struct ComputePath { struct ComputePath {
uint64_t id; uint64_t id;
std::shared_ptr<OpDef> op; std::shared_ptr<OpDef> op;
...@@ -126,20 +138,24 @@ struct TensorInfo { ...@@ -126,20 +138,24 @@ struct TensorInfo {
--pinned; --pinned;
} }
void detach_producer() { // returns true if producer is deleted
bool detach_producer() {
if (!producer) { if (!producer) {
return; return false;
} }
auto output = std::find(producer->outputs.begin(), producer->outputs.end(), this); auto output = std::find(producer->outputs.begin(), producer->outputs.end(), this);
mgb_assert(output != producer->outputs.end()); mgb_assert(output != producer->outputs.end());
*output = nullptr; *output = nullptr;
bool deleted = false;
if (producer->ref_cnt() == 0) { if (producer->ref_cnt() == 0) {
for (auto* input: producer->unique_inputs) { for (auto* input: producer->unique_inputs) {
input->users.erase(std::find(input->users.begin(), input->users.end(), producer)); input->users.erase(std::find(input->users.begin(), input->users.end(), producer));
} }
delete producer; delete producer;
deleted = true;
} }
producer = nullptr; producer = nullptr;
return deleted;
} }
bool size_exceeds_thd(size_t thd) { bool size_exceeds_thd(size_t thd) {
...@@ -150,26 +166,4 @@ struct TensorInfo { ...@@ -150,26 +166,4 @@ struct TensorInfo {
}; };
} }
template <>
struct ToStringTrait<interpreter::intl::TensorInfo::Prop>{
using TensorInfo = interpreter::intl::TensorInfo;
std::string operator()(TensorInfo::Prop prop) const {
switch(prop) {
case TensorInfo::DType:
return "dtype";
case TensorInfo::DevValue:
return "dev_value";
case TensorInfo::Device:
return "device";
case TensorInfo::HostValue:
return "host_value";
case TensorInfo::Shape:
return "shape";
default:
return "unknown";
}
}
};
} }
...@@ -22,47 +22,58 @@ ...@@ -22,47 +22,58 @@
#include "./event_pool.h" #include "./event_pool.h"
#include "./op_trait.h" #include "./op_trait.h"
#include "./profiler/formats.h"
namespace mgb { namespace mgb {
namespace imperative { namespace imperative {
namespace { uint64_t Timer::get_nsecs() {
using namespace std::chrono;
DeviceTimer::SharedEvent alloc_recorded_event(CompNode device) { auto finish = steady_clock::now();
auto event = EventPool::with_timer().alloc_shared(device); auto duration = duration_cast<nanoseconds>(finish - m_start);
event->record(); return duration.count();
return event;
} }
} // namespace uint64_t Timer::get_started_at() {
return m_started_at;
DeviceTimer::SharedEvent DeviceTimer::get_device_time(CompNode device) {
return alloc_recorded_event(device);
} }
SmallVector<DeviceTimer::SharedEvent> DeviceTimer::get_all(SmallVector<CompNode> device_list) { void Timer::reset() {
SmallVector<DeviceTimer::SharedEvent> results; using namespace std::chrono;
for (auto&& device: device_list) { m_start = steady_clock::now();
results.push_back(alloc_recorded_event(device)); auto now_ns = duration_cast<nanoseconds>(std::chrono::system_clock::now().time_since_epoch());
} m_started_at = now_ns.count();
return results;
} }
double HostTimer::get_msecs() { std::shared_ptr<CompNode::Event> Timer::record_event(CompNode device) {
using namespace std::chrono; auto event = EventPool::with_timer().alloc_shared(device);
auto finish = steady_clock::now(); event->record();
auto duration = duration_cast<microseconds>(finish - m_start); return event;
return (double)duration.count() / 1e3;
} }
double HostTimer::get_started_at() { Profiler::options_t Profiler::sm_profile_options;
return m_started_at; std::mutex Profiler::sm_mutex;
std::unordered_map<std::thread::id, Profiler*> Profiler::sm_profilers;
Timer Profiler::sm_timer;
std::atomic_uint64_t Profiler::sm_last_id = 0;
bool Profiler::sm_profiling = false;
thread_local std::unique_ptr<Profiler> Profiler::tm_profiler = std::make_unique<Profiler>();
std::atomic_size_t Profiler::sm_preferred_capacity;
auto Profiler::get_thread_dict() -> thread_dict_t {
MGB_LOCK_GUARD(sm_mutex);
thread_dict_t thread_dict;
for (auto&& [tid, profiler]: sm_profilers) {
thread_dict[tid] = profiler->m_thread_name;
}
return thread_dict;
} }
void HostTimer::reset() { void Profiler::dump_profile(std::string basename, std::string format, results_t results, options_t options) {
using namespace std::chrono; auto thread_dict = get_thread_dict();
m_start = steady_clock::now(); {
auto now_us = duration_cast<microseconds>(std::chrono::system_clock::now().time_since_epoch()); mgb_log_error("unsupported profiling format %s", format.c_str());
m_started_at = (double)(now_us.count()) / 1e3; }
} }
} // namespace imperative } // namespace imperative
......
#include <string>
#include <memory>
#include "megbrain/utils/json.h"
namespace mgb {
namespace imperative {
class ChromeTraceEvent {
public:
ChromeTraceEvent& name(std::string name) {
m_name = std::move(name);
return *this;
}
ChromeTraceEvent& tid(uint64_t tid) {
m_tid = std::move(tid);
return *this;
}
ChromeTraceEvent& cat(std::string cat) {
m_cat = std::move(cat);
return *this;
}
ChromeTraceEvent& pid(uint64_t pid) {
m_pid = pid;
return *this;
}
ChromeTraceEvent& id(uint64_t id) {
m_id = id;
return *this;
}
ChromeTraceEvent& idx(uint64_t idx) {
m_idx = idx;
return *this;
}
ChromeTraceEvent& ts(double ts) {
m_ts = ts;
return *this;
}
ChromeTraceEvent& dur(double dur) {
m_dur = dur;
return *this;
}
ChromeTraceEvent& ph(char ph) {
m_ph = ph;
return *this;
}
ChromeTraceEvent& bp(char bp) {
m_bp = bp;
return *this;
}
ChromeTraceEvent& args(std::shared_ptr<json::Object> args) {
m_args = std::move(args);
return *this;
}
ChromeTraceEvent& arg(std::string key, std::string value) {
if (!m_args) {
m_args = json::Object::make();
}
(*m_args)[key] = json::String::make(value);
return *this;
}
ChromeTraceEvent& arg(std::string key, double value) {
if (!m_args) {
m_args = json::Object::make();
}
(*m_args)[key] = json::Number::make(value);
return *this;
}
ChromeTraceEvent& arg(std::string key, std::shared_ptr<json::Value> value) {
if (!m_args) {
m_args = json::Object::make();
}
(*m_args)[key] = value;
return *this;
}
std::shared_ptr<json::Object> to_json() const {
auto result = json::Object::make();
auto prop_str = [&](auto key, auto value) {
if (value.empty()) {
return;
}
(*result)[key] = json::String::make(value);
};
auto prop_num = [&](auto key, auto value) {
if (!value) {
return;
}
(*result)[key] = json::Number::make(value.value());
};
auto prop_char = [&](auto key, auto value) {
if (!value) {
return;
}
(*result)[key] = json::String::make(std::string{} + value.value());
};
prop_str("name", m_name);
prop_num("tid", m_tid);
prop_str("cat", m_cat);
prop_num("pid", m_pid);
prop_num("id", m_id);
prop_num("idx", m_idx);
prop_num("ts", m_ts);
prop_num("dur", m_dur);
prop_char("ph", m_ph);
prop_char("bp", m_bp);
if (m_args) {
(*result)["args"] = m_args;
}
return result;
}
private:
std::string m_name;
std::string m_cat;
std::optional<uint64_t> m_tid;
std::optional<uint64_t> m_pid;
std::optional<uint64_t> m_id;
std::optional<uint64_t> m_idx;
std::optional<double> m_ts;
std::optional<double> m_dur;
std::optional<char> m_ph;
std::optional<char> m_bp;
std::shared_ptr<json::Object> m_args;
};
class ChromeTraceEventList {
public:
ChromeTraceEvent& new_event() {
m_content.emplace_back();
return m_content.back();
}
std::shared_ptr<json::Array> to_json() const {
auto result = json::Array::make();
for (auto&& event: m_content) {
result->add(event.to_json());
}
return result;
}
private:
std::vector<ChromeTraceEvent> m_content;
};
} // namespace imperative
} // namespace mgb
...@@ -11,65 +11,176 @@ ...@@ -11,65 +11,176 @@
#pragma once #pragma once
#include "./commands.h" #include "megbrain/utils/small_vector.h"
#include "./tensor_info.h"
namespace mgb::imperative::interpreter::intl { #include "../op_trait.h"
namespace mgb::imperative::profiler {
enum class TensorProp {
InvalidProp, Device, Shape, DType, DevValue, HostValue,
};
using OpParams = std::unordered_map<std::string, std::string>;
}
namespace mgb::imperative {
template <>
struct ToStringTrait<profiler::TensorProp>{
using TensorProp = profiler::TensorProp;
std::string operator()(TensorProp prop) const {
switch(prop) {
case TensorProp::DType:
return "dtype";
case TensorProp::DevValue:
return "dev_value";
case TensorProp::Device:
return "device";
case TensorProp::HostValue:
return "host_value";
case TensorProp::Shape:
return "shape";
default:
return "unknown";
}
}
};
}
namespace mgb::imperative::profiler {
#define DEF_EVENT(X, ...) struct X##Event __VA_ARGS__; #define DEF_EVENT(X, ...) struct X##Event __VA_ARGS__;
#define DEF_DUR_EVENT(X, ...) struct X##Event __VA_ARGS__; struct X##FinishEvent __VA_ARGS__; #define DEF_DUR_EVENT(X, ...) struct X##Event __VA_ARGS__; struct X##FinishEvent __VA_ARGS__;
DEF_EVENT(Command, { DEF_EVENT(OpDispatch, {
IdentifiedCommand icmd; uint64_t op_id;
std::string op_name;
std::function<OpParams()> op_params;
SmallVector<uint64_t> inputs;
SmallVector<uint64_t> outputs;
});
DEF_DUR_EVENT(OpInput, {
uint64_t tensor_id;
TensorShape shape;
});
DEF_DUR_EVENT(OpDel, {
uint64_t tensor_id;
TensorShape shape;
});
DEF_DUR_EVENT(OpOutput, {
uint64_t tensor_id;
TensorShape shape;
}); });
DEF_EVENT(CommandEnqueue, :CommandEvent {});
DEF_EVENT(CommandExecute, :CommandEvent {});
DEF_EVENT(CommandFinish, :CommandEvent {});
DEF_DUR_EVENT(OpExecute, { DEF_DUR_EVENT(OpExecute, {
uint64_t id; uint64_t op_id;
std::shared_ptr<OpDef> op; });
SmallVector<uint64_t> inputs;
SmallVector<uint64_t> outputs; DEF_DUR_EVENT(OpPostExecute, {
uint64_t op_id;
}); });
DEF_DUR_EVENT(KernelExecute, { DEF_DUR_EVENT(KernelExecute, {
uint64_t id; uint64_t op_id;
std::shared_ptr<OpDef> op; uint64_t kernel_id;
SmallVector<uint64_t> inputs; std::shared_ptr<CompNode::Event> event;
SmallVector<uint64_t> outputs;
}); });
DEF_EVENT(TensorDeclare, { DEF_EVENT(TensorDeclare, {
uint64_t tensor_id; uint64_t tensor_id;
std::string name;
}); });
DEF_EVENT(TensorProduce, { DEF_EVENT(TensorProduce, {
uint64_t tensor_id; uint64_t tensor_id;
TensorLayout layout; TensorLayout layout;
CompNode device; CompNode device;
void* ptr;
}); });
DEF_EVENT(TensorUsage, {
uint64_t tensor_id;
});
DEF_EVENT(TensorRelease, {
uint64_t tensor_id;
});
DEF_EVENT(TensorErase, { DEF_EVENT(TensorErase, {
uint64_t tensor_id; uint64_t tensor_id;
size_t use_count;
}); });
DEF_EVENT(TensorGetProp, { DEF_EVENT(TensorGetProp, {
uint64_t tensor_id; uint64_t tensor_id;
TensorInfo::Prop prop; TensorProp prop;
std::string prop_desc; });
DEF_EVENT(TensorNotifyProp, {
uint64_t tensor_id;
uint64_t wait_id;
TensorProp prop;
}); });
DEF_DUR_EVENT(TensorWaitProp, {
DEF_EVENT(TensorWaitProp, {
uint64_t tensor_id; uint64_t tensor_id;
TensorInfo::Prop prop; uint64_t wait_id;
std::string prop_desc; TensorProp prop;
}); });
DEF_EVENT(TensorNotifyProp, {
DEF_EVENT(TensorWaitPropFinish, {
uint64_t tensor_id; uint64_t tensor_id;
TensorInfo::Prop prop; uint64_t wait_id;
std::string prop_desc; TensorProp prop;
bool notified;
}); });
DEF_DUR_EVENT(Sync, {});
DEF_DUR_EVENT(SampleDevice, {
CompNode device;
size_t total_memory;
size_t free_memory;
});
DEF_EVENT(WorkerException, {});
DEF_EVENT(ShapeInfer, {
bool success;
});
DEF_DUR_EVENT(Scope, { DEF_DUR_EVENT(Scope, {
std::string name; std::string name;
}); });
DEF_DUR_EVENT(DeviceScope, { DEF_DUR_EVENT(DeviceScope, {
std::string name; std::string name;
std::shared_ptr<CompNode::Event> event;
});
DEF_DUR_EVENT(Sync, {});
DEF_DUR_EVENT(StartProfile, {
size_t capture_count;
});
DEF_DUR_EVENT(StopProfile, {
size_t escape_count;
}); });
DEF_DUR_EVENT(TensorCommand, {
enum Kind {
Put, Del, SwapIn, SwapOut, Drop, ReGen, RecFree, GetValue
};
uint64_t tensor_id;
Kind kind;
});
#undef DEF_EVENT
#undef DEF_DUR_EVENT
} }
/** /**
* \file imperative/src/impl/interpreter/profiler.cpp * \file imperative/src/impl/interpreter/profiler.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
* *
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
...@@ -9,22 +9,12 @@ ...@@ -9,22 +9,12 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#include "./profiler.h" #pragma once
#include <sstream> #include <unordered_set>
#include <cinttypes>
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) #include "megbrain/imperative/profiler.h"
#include <unistd.h>
#elif defined(_WIN32)
#include <process.h>
#else
#error Unsupported platform
#endif
#include "../op_trait.h"
namespace mgb::imperative::interpreter::intl {
namespace mgb::imperative::profiler {
} }
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
#include "megbrain/tensor.h" #include "megbrain/tensor.h"
#include "./events.h"
namespace mgb::imperative::profiler { namespace mgb::imperative::profiler {
struct ProfileDeviceState { struct ProfileDeviceState {
...@@ -53,6 +55,7 @@ struct ProfileStaticsState { ...@@ -53,6 +55,7 @@ struct ProfileStaticsState {
struct ProfileOperatorState { struct ProfileOperatorState {
uint64_t id; uint64_t id;
std::string name; std::string name;
OpParams params;
SmallVector<uint64_t> inputs; SmallVector<uint64_t> inputs;
SmallVector<uint64_t> outputs; SmallVector<uint64_t> outputs;
CompNode device; CompNode device;
......
...@@ -47,8 +47,8 @@ struct Interpreter { ...@@ -47,8 +47,8 @@ struct Interpreter {
virtual size_t get_option(std::string name) = 0; virtual size_t get_option(std::string name) = 0;
virtual void set_option(std::string name, size_t value) = 0; virtual void set_option(std::string name, size_t value) = 0;
virtual void start_profile(std::unordered_map<std::string, int> option) = 0; virtual void start_profile() = 0;
virtual void stop_profile(std::string basename, std::string format) = 0; virtual void stop_profile() = 0;
virtual void push_scope(std::string name) = 0; virtual void push_scope(std::string name) = 0;
virtual void pop_scope(std::string name) = 0; virtual void pop_scope(std::string name) = 0;
......
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
#include <fstream> #include <fstream>
#include <chrono> #include <chrono>
#include <bitset> #include <bitset>
#include <deque>
#include <any>
#include <typeindex>
#include "megbrain/comp_node.h" #include "megbrain/comp_node.h"
#include "megbrain/graph/event.h" #include "megbrain/graph/event.h"
...@@ -29,165 +32,188 @@ ...@@ -29,165 +32,188 @@
namespace mgb { namespace mgb {
namespace imperative { namespace imperative {
class DeviceTimer { class Timer {
public:
using SharedEvent = std::shared_ptr<CompNode::Event>;
DeviceTimer() = default;
SharedEvent get_device_time(CompNode device);
SmallVector<SharedEvent> get_all(SmallVector<CompNode> device_list);
};
class HostTimer {
public: public:
void reset(); void reset();
double get_msecs(); uint64_t get_nsecs();
double get_started_at(); uint64_t get_started_at();
static std::shared_ptr<CompNode::Event> record_event(CompNode device);
private: private:
decltype(std::chrono::steady_clock::now()) m_start; decltype(std::chrono::steady_clock::now()) m_start;
double m_started_at; uint64_t m_started_at;
}; };
class ProfilerBase { class Profiler {
public: public:
using Host = std::thread::id; struct Record {
using Device = CompNode; uint64_t id;
uint64_t time; //in ns
struct HostInstant { std::any data;
Host tid;
double time;
void wait() const {}
}; };
enum Status: uint8_t {
struct DeviceInstant { Running = 0,
double before; Recording = 1,
std::shared_ptr<CompNode::Event> event; Collecting = 2,
double after;
void wait() const {
event->host_wait();
}
}; };
using ProfileCollector = std::function<void(std::thread::id, Record)>;
using option_t = uint64_t;
using options_t = std::unordered_map<std::string, option_t>;
using result_t = std::pair<std::thread::id, Record>;
using results_t = std::vector<result_t>;
using thread_dict_t = std::unordered_map<std::thread::id, std::string>;
private:
std::thread::id m_thread_id;
std::vector<Record> m_records;
std::atomic<Status> m_status = Running;
uint64_t m_last_time = 0;
std::string m_thread_name;
static options_t sm_profile_options;
static std::mutex sm_mutex;
static std::unordered_map<std::thread::id, Profiler*> sm_profilers;
static Timer sm_timer;
static std::atomic_uint64_t sm_last_id;
static std::atomic_size_t sm_preferred_capacity;
static bool sm_profiling;
static constexpr bool sm_debug = false;
thread_local static std::unique_ptr<Profiler> tm_profiler;
public:
Profiler() {
m_thread_id = std::this_thread::get_id();
MGB_LOCK_GUARD(sm_mutex);
if (sm_profilers.size() == 0) {
reset();
}
mgb_assert(sm_profilers.count(m_thread_id) == 0);
sm_profilers[m_thread_id] = this;
}
~Profiler() {
MGB_LOCK_GUARD(sm_mutex);
mgb_assert(sm_profilers.count(m_thread_id) == 1);
sm_profilers.erase(m_thread_id);
}
public:
static Profiler& get_instance() {
return *tm_profiler;
}
using Instant = std::variant<HostInstant, DeviceInstant>; static void reset() {
mgb_assert(sm_profilers.size() == 0, "profiler already running");
sm_timer.reset();
}
template <typename TEvent> static uint64_t next_id() {
struct EventRecord { return sm_last_id++;
Instant instant; }
TEvent data;
const HostInstant& host() const { template <typename T, typename... TArgs>
return std::get<HostInstant>(instant); static uint64_t record(TArgs&&... args) {
auto& profiler = get_instance();
auto last_time = profiler.m_last_time;
if constexpr (sm_debug) {
Status expected = Running;
mgb_assert(profiler.m_status.compare_exchange_strong(expected, Recording));
}
uint64_t id = next_id();
uint64_t time = sm_timer.get_nsecs();
time = std::max(time, last_time + 2000);
profiler.m_last_time = time;
profiler.m_records.push_back({id, time, T{std::forward<TArgs>(args)...}});
if constexpr (sm_debug) {
Status expected = Recording;
mgb_assert(profiler.m_status.compare_exchange_strong(expected, Running));
}
return id;
} }
const DeviceInstant& device() const { static results_t collect() {
return std::get<DeviceInstant>(instant); MGB_LOCK_GUARD(sm_mutex);
if constexpr (sm_debug) {
for (auto&& [tid, profiler]: sm_profilers) {
Status expected = Running;
mgb_assert(profiler->m_status.compare_exchange_strong(expected, Collecting));
}
}
std::vector<std::pair<std::thread::id, Record>> profile_data;
for (auto&& [tid, profiler]: sm_profilers) {
sm_preferred_capacity = std::max(sm_preferred_capacity.load(), profiler->m_records.size());
for (auto& record: profiler->m_records) {
profile_data.push_back({tid, std::move(record)});
}
profiler->m_records.clear();
profiler->m_records.reserve(sm_preferred_capacity);
}
std::sort(profile_data.begin(), profile_data.end(), [](auto& lhs, auto& rhs){
return lhs.second.id < rhs.second.id;
});
if constexpr (sm_debug) {
for (auto&& [tid, profiler]: sm_profilers) {
Status expected = Collecting;
mgb_assert(profiler->m_status.compare_exchange_strong(expected, Running));
}
}
return profile_data;
} }
void wait() const { static option_t get_option(std::string key, option_t default_val) {
std::visit([&](const auto& instant){ instant.wait(); }, instant); if (!sm_profile_options.count(key)) {
return default_val;
}
return sm_profile_options.at(key);
} }
};
protected:
HostInstant record_host() {
return {std::this_thread::get_id(), m_host_timer.get_msecs()};
}
DeviceInstant record_device(Device device) {
auto before = m_host_timer.get_msecs();
auto event = m_device_timer.get_device_time(device);
auto after = m_host_timer.get_msecs();
return {before, event, after};
}
protected:
std::atomic_int64_t m_last_id = 0;
HostTimer m_host_timer;
DeviceTimer m_device_timer;
Spinlock m_lock;
};
static void load_options(options_t options) {
sm_profile_options = std::move(options);
}
template <typename... TEvents> static options_t get_options() {
class Profiler: public ProfilerBase { return sm_profile_options;
public: }
using Record = std::variant<EventRecord<TEvents>...>;
using Mask = std::bitset<sizeof...(TEvents)>;
struct Data { static bool is_profiling() {
std::vector<Record> records; return sm_profiling;
double started_at; }
};
template <typename TEvent, size_t index = 0> static void start_profile() {
static constexpr size_t index_of() { mgb_assert(!sm_profiling);
if constexpr (index == std::variant_size_v<Record>) { sm_profiling = true;
return index;
} else if constexpr (std::is_same_v<EventRecord<TEvent>, std::variant_alternative_t<index, Record>>) {
return index;
} else {
return index_of<TEvent, index+1>();
} }
};
template <typename... TEvents2> static void stop_profile() {
static Mask mask_of() { mgb_assert(sm_profiling);
return Mask{} | (Mask{}.set(index_of<TEvents2>()) |...); sm_profiling = false;
} }
enum Status { static thread_dict_t get_thread_dict();
NotStarted, Profiling, Stopped
}; static void dump_profile(std::string basename, std::string format, results_t results, options_t options);
};
class ProfileDataCollector {
public: public:
template <typename TEvent, typename... TArgs> template <typename T>
void record_host(TArgs&&... args) { using SubCollector = std::function<void(uint64_t, std::thread::id, uint64_t, T)>;
MGB_LOCK_GUARD(m_lock); private:
if (!m_event_mask.test(index_of<TEvent>())) { std::unordered_map<std::type_index, SubCollector<std::any>> m_collectors;
return; public:
} template <typename T>
mgb_assert(m_status != Stopped, "record after stop"); ProfileDataCollector& handle(SubCollector<T> collector) {
auto instant = HostInstant{std::this_thread::get_id(), m_host_timer.get_msecs()}; auto erased = [collector](uint64_t id, std::thread::id tid, uint64_t time, std::any data){
m_record_list.emplace_back(EventRecord<TEvent>{std::move(instant), {std::forward<TArgs>(args)...}}); collector(id, tid, time, std::any_cast<T>(std::move(data)));
};
m_collectors[typeid(T)] = erased;
return *this;
} }
template <typename TEvent, typename... TArgs> void operator()(uint64_t id, std::thread::id tid, uint64_t time, std::any event) {
void record_device(Device device, TArgs&&... args) { std::type_index type = event.type();
MGB_LOCK_GUARD(m_lock); if (m_collectors.count(type) == 0) {
if (!m_event_mask.test(index_of<TEvent>())) {
return; return;
} }
mgb_assert(m_status != Stopped, "record after stop"); auto& handler = m_collectors.at(type);
auto before = m_host_timer.get_msecs(); handler(id, tid, time, std::move(event));
auto event = m_device_timer.get_device_time(device); }
auto after = m_host_timer.get_msecs();
auto instant = DeviceInstant{before, event, after};
m_record_list.emplace_back(EventRecord<TEvent>{std::move(instant), {std::forward<TArgs>(args)...}});
}
// unsafe
bool is_profiling() {
return m_status == Profiling;
}
void start(Mask mask) {
MGB_LOCK_GUARD(m_lock);
mgb_assert(m_status == NotStarted, "profiler already started");
m_status = Profiling;
m_event_mask = mask;
m_host_timer.reset();
}
Data stop() {
MGB_LOCK_GUARD(m_lock);
mgb_assert(m_status == Profiling, "profiler not active");
m_status = Stopped;
for (auto&& record: m_record_list) {
std::visit([&](const auto& record){
record.wait();
}, record);
}
auto records = std::move(m_record_list);
return { records, m_host_timer.get_started_at() };
}
protected:
std::vector<Record> m_record_list;
Mask m_event_mask;
std::atomic<Status> m_status = NotStarted;
}; };
} // namespace imperative } // namespace imperative
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册