From 2c9fa7f650e52f39058dc957f2f131941e7c46ae Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 24 Sep 2020 15:20:09 +0800 Subject: [PATCH] feat(mge/profiler): add compatible and graphviz mode GitOrigin-RevId: 40ca9ea80e995547b03f2ca24c8bf2428bd326b5 --- imperative/python/megengine/utils/profiler.py | 292 ++++++++++++------ imperative/python/src/utils.cpp | 26 +- .../imperative => impl}/function_hook.h | 18 +- imperative/src/impl/profiler.cpp | 153 ++++++++- .../include/megbrain/imperative/profiler.h | 80 +++-- 5 files changed, 435 insertions(+), 134 deletions(-) rename imperative/src/{include/megbrain/imperative => impl}/function_hook.h (77%) diff --git a/imperative/python/megengine/utils/profiler.py b/imperative/python/megengine/utils/profiler.py index 84d84321c..668ed9b42 100644 --- a/imperative/python/megengine/utils/profiler.py +++ b/imperative/python/megengine/utils/profiler.py @@ -9,13 +9,155 @@ import base64 import json import os -from typing import List, Optional +import re +from typing import Iterable, List, Optional from ..core._imperative_rt import OperatorNodeConfig, ProfileEntry from ..core._imperative_rt import ProfilerImpl as _Profiler from ..core._imperative_rt.imperative import sync from ..core._imperative_rt.ops import CollectiveCommMode -from ..core.ops.builtin import GetVarShape + + +def _make_dict(**kwargs): + unused_keys = [] + for k, v in kwargs.items(): + if v is None: + unused_keys.append(k) + for k in unused_keys: + del kwargs[k] + return kwargs + + +def _print_opnode_config(config): + return _make_dict( + name=config.name, dtype=config.dtype, comp_node_arr=config.comp_node_arr, + ) + + +def _dump_chrome_timeline(entries: List[ProfileEntry], path: str): + pid = os.getpid() + trace_events = [] + + def append_event(**kwargs): + trace_events.append(_make_dict(**kwargs)) + + for id, entry in enumerate(entries): + op = entry.op + name = type(op).__name__ + host_begin, host_end = entry.host + device_list = entry.device_list + args = Profiler.fetch_attrs(op) + args["__id__"] = "[{}]".format(id) + cat = name + for ts, ph in [(host_begin, "B"), (host_end, "E")]: + append_event( + name=name, ph=ph, ts=ts * 1000, pid=pid, tid="host", args=args, cat=cat, + ) + for device, device_begin, device_end in device_list: + for ts, ph in [(device_begin(), "B"), (device_end(), "E")]: + append_event( + name=name, ph=ph, ts=ts * 1000, pid=pid, tid=str(device), args=args, + ) + with open("{}.chrome_timeline.json".format(path), "w") as f: + json.dump(trace_events, f, indent=2) + + +def _dump_compatible(entries: List[ProfileEntry], path: str): + obj = { + "graph_exec": {"var": [], "operator": {}}, + "profiler": {"device": {}, "host": {}, "opr_footprint": {}}, + } + var_list = obj["graph_exec"]["var"] + operator_dict = obj["graph_exec"]["operator"] + device_dict = obj["profiler"]["device"] + host_dict = obj["profiler"]["host"] + opr_foot_print_dict = obj["profiler"]["opr_footprint"] + + def add_var(var) -> int: + var_id = len(var_list) + var_list.append( + {"comp_node": str(var[2]),} + ) + return var_id + + for op_id, entry in enumerate(entries): + operator_dict[op_id] = { + "input": [add_var(var) for var in entry.inputs], + "output": [add_var(var) for var in entry.outputs], + "name": str(entry.op.ctype()), + "type": "imperative", + "id": entry.id, + } + op_device_dict = {} + for device, device_begin, device_end in entry.device_list: + op_device_dict[str(device)] = { + "start": device_begin(), + "kern": device_begin(), + "end": device_end(), + } + device_dict[op_id] = op_device_dict + host_begin, host_end = entry.host + host_dict[op_id] = { + "host": {"start": host_begin, "kern": host_begin, "end": host_end} + } + opr_footprint = { + "out_shapes": [oup[1] for oup in entry.outputs], + "in_shapes": [inp[1] for inp in entry.inputs], + "params": {}, + } + if entry.memory > 0: + opr_footprint["memory"] = entry.memory + if entry.computation > 0: + opr_footprint["computation"] = entry.computation + opr_foot_print_dict[op_id] = opr_footprint + with open("{}.compatible.json".format(path), "w") as f: + json.dump(obj, f, indent=2) + + +def _dump_graphviz(entries: List[ProfileEntry], path: str): + import graphviz + import json + + graph = graphviz.Digraph() + graph.graph_attr["ordering"] = "out" + var_cache = {} + + def cache_var(var_id, var_shape): + if var_id not in var_cache: + var_name = "var({})".format(var_id) + var_label = "{}\nshape:{}\n".format(var_name, shape) + graph.node(var_name, var_label) + var_cache[var_id] = var_name + return var_cache[var_id] + + for op_id, entry in enumerate(entries): + op = entry.op + op_name = "op({})".format(op_id) + op_type = type(op).__name__ + op_attrs = Profiler.fetch_attrs(op) + label_lines = [] + if "param" in op_attrs: + del op_attrs["param"] + label_lines.append("{}:{}".format(op_name, op_type)) + for k, v in op_attrs.items(): + label_lines.append("attr[{}]: {}".format(k, v)) + op_param_str = entry.param + if len(op_param_str) > 0: + op_param = json.loads(op_param_str) + for k, v in op_param.items(): + label_lines.append("param[{}]:{}".format(k, v)) + host_begin, host_end = entry.host + label_lines.append("time[host]: {:f}ms".format(host_end - host_begin)) + for device, device_begin, device_end in entry.device_list: + device_time = device_end() - device_begin() + label_lines.append("time[{}]: {:f}ms".format(device, device_time)) + op_label = "\n".join(label_lines) + graph.node(op_name, op_label, shape="rectangle") + for var_id, shape, device in entry.inputs: + graph.edge(cache_var(var_id, shape), op_name) + for var_id, shape, device in entry.outputs: + graph.edge(op_name, cache_var(var_id, shape)) + graph.save("{}.graphviz.dot".format(path)) class Profiler: @@ -23,7 +165,7 @@ class Profiler: Profile graph execution in imperative mode. :type path: Optional[str] - :param path: default path for profiler to dump. + :param path: default path prefix for profiler to dump. Examples: @@ -31,59 +173,67 @@ class Profiler: import megengine as mge import megengine.module as M - import megengine.utils.profiler.Profiler + from megengine.utils.profiler import Profiler # With Learnable Parameters for iter in range(0, 10): # Only profile record of last iter would be saved - with Profiler("profile.json"): + with Profiler("profile"): # your code here # Then open the profile file in chrome timeline window """ - # see https://github.com/catapult-project/catapult/blob/master/tracing/tracing/base/color_scheme.html - GOOD = "good" - BAD = "bad" - TERRIBLE = "terrible" + CHROME_TIMELINE = "chrome_timeline" + COMPATIBLE = "compatible" + GRAPHVIZ = "graphviz" + + WITH_FOOTPRINT = 1 - BLACK = "black" - GREY = "grey" - WHITE = "white" - YELLOW = "yellow" - OLIVE = "olive" + _type_map = { + OperatorNodeConfig: lambda x: _print_opnode_config(x), + bytes: lambda x: base64.encodebytes(x).decode("ascii"), + CollectiveCommMode: lambda x: str(x), + } - def __init__(self, path: str = "profile.json"): + _dumper_map = { + CHROME_TIMELINE: _dump_chrome_timeline, + COMPATIBLE: _dump_compatible, + GRAPHVIZ: _dump_graphviz, + } + + def __init__( + self, + path: str = "profile", + *, + formats: Iterable[str] = (CHROME_TIMELINE,), + type_filter: str = ".*", + exit_dump: bool = True + ) -> None: self._impl = _Profiler() self._path = path - self._color_map = {} - self._type_map = { - OperatorNodeConfig: lambda x: self.print_opnode_config(x), - bytes: lambda x: base64.encodebytes(x).decode("ascii"), - CollectiveCommMode: lambda x: str(x), - } + + if isinstance(formats, str): + formats = (formats,) + + self._filter = type_filter + self._dumpers = [Profiler._dumper_map[fmt] for fmt in formats] + self._exit_dump = exit_dump def __enter__(self): sync() - self._impl.start() + self._impl.start(Profiler.WITH_FOOTPRINT) return self - def __exit__(self, val, type, trace): + def __exit__(self, val, tp, trace): + if self._exit_dump: + self.dump() sync() self._impl.stop() - if self._path is not None: - self.dump() - - def recolor(self, target: str, color: str): - self._color_map[target] = color - return self + self._impl.clear() - def print_opnode_config(self, config): - return self.make_dict( - name=config.name, dtype=config.dtype, comp_node_arr=config.comp_node_arr, - ) - - def fetch_attrs(self, op): + @classmethod + def fetch_attrs(cls, op): attrs = dir(op) results = {} for attr in attrs: @@ -93,61 +243,29 @@ class Profiler: if callable(value): continue value_type = type(value) - if value_type in self._type_map: - value = self._type_map[value_type](value) + if value_type in cls._type_map: + value = cls._type_map[value_type](value) results[attr] = value return results - def make_dict(self, **kwargs): - unused_keys = [] - for k, v in kwargs.items(): - if v is None: - unused_keys.append(k) - for k in unused_keys: - del kwargs[k] - return kwargs - def dump(self, path: Optional[str] = None): - pid = os.getpid() + sync() + raw = [ + entry + for entry in self._impl.dump() + if re.match(self._filter, type(entry.op).__name__) + ] if path is None: path = self._path - trace_events = [] - - def append_event(**kwargs): - trace_events.append(self.make_dict(**kwargs)) - - entries: List[ProfileEntry] = self._impl.dump() - - for id, entry in enumerate(entries): - op = entry.op - name = type(op).__name__ - host_begin, host_end = entry.host - device_list = entry.device_list - args = self.fetch_attrs(op) - args["__id__"] = "[{}]".format(id) - cname = self._color_map[name] if name in self._color_map else None - cat = name - for ts, ph in [(host_begin, "B"), (host_end, "E")]: - append_event( - name=name, - ph=ph, - ts=ts * 1000, - pid=pid, - tid="host", - args=args, - cname=cname, - cat=cat, - ) - for device, device_begin, device_end in device_list: - for ts, ph in [(device_begin(), "B"), (device_end(), "E")]: - append_event( - name=name, - ph=ph, - ts=ts * 1000, - pid=pid, - tid=str(device), - args=args, - cname=cname, - ) - with open(path, "w") as f: - json.dump(trace_events, f, indent=2) + for dumper in self._dumpers: + dumper(raw, path) + + def __call__(self, func): + def wrapper(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return wrapper + + +profile = Profiler diff --git a/imperative/python/src/utils.cpp b/imperative/python/src/utils.cpp index ed851af25..3169f2726 100644 --- a/imperative/python/src/utils.cpp +++ b/imperative/python/src/utils.cpp @@ -204,17 +204,27 @@ void init_utils(py::module m) { py::class_(m, "ProfileEntry") .def_readwrite("op", &ProfileEntry::op) .def_readwrite("host", &ProfileEntry::host) - .def_readwrite("device_list", &ProfileEntry::device_list); + .def_readwrite("device_list", &ProfileEntry::device_list) + .def_readwrite("inputs", &ProfileEntry::inputs) + .def_readwrite("outputs", &ProfileEntry::outputs) + .def_readwrite("id", &ProfileEntry::id) + .def_readwrite("parent", &ProfileEntry::parent) + .def_readwrite("memory", &ProfileEntry::memory) + .def_readwrite("computation", &ProfileEntry::computation) + .def_property_readonly("param", [](ProfileEntry& self)->std::string{ + if(self.param){ + return self.param->to_string(); + } else { + return {}; + } + }); py::class_(m, "ProfilerImpl") .def(py::init<>()) - .def("start", - [](mgb::imperative::Profiler& profiler) { profiler.start(); }) - .def("stop", - [](mgb::imperative::Profiler& profiler) { profiler.stop(); }) - .def("dump", [](mgb::imperative::Profiler& profiler) { - return profiler.get_profile(); - }); + .def("start", &mgb::imperative::Profiler::start) + .def("stop", &mgb::imperative::Profiler::stop) + .def("clear", &mgb::imperative::Profiler::clear) + .def("dump", &mgb::imperative::Profiler::get_profile); using mgb::imperative::TensorSanityCheck; py::class_(m, "TensorSanityCheckImpl") diff --git a/imperative/src/include/megbrain/imperative/function_hook.h b/imperative/src/impl/function_hook.h similarity index 77% rename from imperative/src/include/megbrain/imperative/function_hook.h rename to imperative/src/impl/function_hook.h index 64582f113..83cb65529 100644 --- a/imperative/src/include/megbrain/imperative/function_hook.h +++ b/imperative/src/impl/function_hook.h @@ -15,6 +15,7 @@ namespace mgb { namespace imperative { + template class FunctionHooker; @@ -22,13 +23,18 @@ template class FunctionHooker { public: using FunctionType = thin_function; + //Type of hooks. Hook should accept a real function as argument + //and invoke it on an appropriate time using HookType = thin_function; - explicit FunctionHooker(FunctionType* fptr) : m_fptr{fptr} {} + explicit FunctionHooker(FunctionType* fptr) : m_fptr{fptr} { + m_backup = {nullptr, [](FunctionType*){}}; + } public: FunctionHooker& apply_hook(HookType&& hook) { if (!m_backup) { FunctionType* backup = new FunctionType(*m_fptr); + //Restore hooked function, would be invoked when destructed std::function restorer = [fptr = m_fptr](FunctionType* bkp) -> void { *fptr = *bkp; @@ -36,9 +42,11 @@ public: }; m_backup = decltype(m_backup)(backup, restorer); } + //Replace with hooked version *m_fptr = [func = *m_fptr, hook](TArgs&&... args) -> TRet { return hook(func, std::forward(args)...); }; + //Convinent for chain call return *this; } @@ -47,9 +55,15 @@ private: std::unique_ptr> m_backup; }; +//Helps to deduce template args template FunctionHooker(thin_function* f) ->FunctionHooker; -} // namespace imperative +template +auto make_shared_hook(thin_function* fptr){ + return std::make_shared>(fptr); +} + +} // namespace imperative } // namespace mgb diff --git a/imperative/src/impl/profiler.cpp b/imperative/src/impl/profiler.cpp index 4987ce535..681cccbea 100644 --- a/imperative/src/impl/profiler.cpp +++ b/imperative/src/impl/profiler.cpp @@ -11,19 +11,20 @@ #include "megbrain/imperative/profiler.h" -#include - +#include "./function_hook.h" #include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/imperative/physical_tensor.h" +#include "megbrain/plugin/opr_footprint.h" + #include "./event_pool.h" #include "./op_trait.h" namespace mgb { - namespace imperative { namespace { + CompNode::UnorderedSet collect_comp_nodes( const OpDef& def, const SmallVector& inputs) { CompNode::UnorderedSet comp_nodes; @@ -36,37 +37,101 @@ CompNode::UnorderedSet collect_comp_nodes( return comp_nodes; } +DeviceTimer::SharedEvent alloc_recorded_event(CompNode device) { + auto event = EventPool::with_timer().alloc_shared(device); + event->record(); + return event; +} + +OprFootprint footprint{}; + } // namespace void DeviceTimer::reset(thin_function host_timer) { CompNode::foreach ([this, host_timer](CompNode device) { - auto base_event = EventPool::with_timer().alloc_shared(device); - base_event->record(); - m_base_event_table[device] = {std::move(base_event), host_timer()}; + m_base_event_table[device] = {alloc_recorded_event(device), host_timer()}; }); + m_host_timer = host_timer; } thin_function DeviceTimer::get_device_time(CompNode device) { auto event = EventPool::with_timer().alloc_shared(device); event->record(); + if(m_base_event_table.count(device) == 0) { + m_base_event_table[device] = {alloc_recorded_event(device), m_host_timer()}; + } auto base = m_base_event_table[device]; return [base, event] { auto [base_event, host_time] = base; - //TODO: sync once for each compnode + // TODO: sync once for each compnode event->host_wait(); return base_event->elapsed_time_until(*event) * 1000 + host_time; }; } -void Profiler::start() { +void DeviceTimer::clear() { + m_base_event_table.clear(); +} + +size_t TensorRecorder::record_tensor(const TensorPtr& tensor) { + if (m_tensor_map.count(tensor.get()) > 0) { + auto& [prev, id] = m_tensor_map[tensor.get()]; + if (prev.lock() != tensor) { + prev = tensor; + id = m_next_id++; + } + return id; + } else { + auto id = m_next_id++; + m_tensor_map.insert({tensor.get(), {std::weak_ptr{tensor}, id}}); + return id; + } +} + +void TensorRecorder::clear() { + m_next_id = 0; + m_tensor_map.clear(); +} + +Profile& Profiler::get_profile() { + for (auto& entry : m_profile) { + for (auto& [device, device_begin, device_end] : entry.device_list) { + MGB_MARK_USED_VAR(device); + device_begin = [value = device_begin()] { return value; }; + device_end = [value = device_end()] { return value; }; + } + } + return m_profile; +} + +void Profiler::start(uint32_t flags) { m_host_timer.reset(); - m_device_timer.reset([&]{ return m_host_timer.get_msecs();} ); - OpTrait::for_each_trait([this](OpTrait& trait) { - FunctionHooker hooker{&trait.apply_on_physical_tensor}; - hooker.apply_hook([this](auto&& apply, const OpDef& def, - const SmallVector& inputs) { + m_device_timer.reset([&] { return m_host_timer.get_msecs(); }); + OpTrait::for_each_trait([this, flags](OpTrait& trait) { + auto hook_apply_on_physical_tensor = + make_shared_hook(&trait.apply_on_physical_tensor); + auto hook_apply_on_var_node = + make_shared_hook(&trait.apply_on_var_node); + hook_apply_on_physical_tensor->apply_hook([this, flags] + (auto&& apply, const OpDef& def, const SmallVector& inputs) { + auto shape2vector = [](const TensorShape& shape) { + std::vector vector_shape; + for (size_t i = 0; i < shape.ndim; i++) { + vector_shape.push_back(shape[i]); + } + return vector_shape; + }; ProfileEntry entry; + entry.id = m_entry_count++; + // TODO: assign parent + entry.parent = 0; + // Record apply context and save to m_profile entry.op = def.copy(); + for (auto&& input : inputs) { + entry.inputs.push_back({m_tensor_recorder.record_tensor(input), + shape2vector(input->layout()), + input->comp_node()}); + } double host_begin = m_host_timer.get_msecs(); auto&& comp_nodes = collect_comp_nodes(def, inputs); for (auto&& comp_node : comp_nodes) { @@ -75,6 +140,11 @@ void Profiler::start() { m_device_timer.get_device_time(comp_node), {}}); } + if (flags & PROFILE_FOOTPRINT) { + MGB_LOCK_GUARD(m_lock); + m_entry_stack.push({&def, &entry, std::this_thread::get_id()}); + } + // Do real apply auto outputs = apply(def, inputs); for (auto& [cn, dev_begin, dev_end] : entry.device_list) { MGB_MARK_USED_VAR(cn); @@ -82,20 +152,71 @@ void Profiler::start() { dev_end = m_device_timer.get_device_time(cn); } entry.host = {host_begin, m_host_timer.get_msecs()}; - m_profile->push_back(std::move(entry)); + for (auto&& output : outputs) { + entry.outputs.push_back( + {m_tensor_recorder.record_tensor(output), + shape2vector(output->layout()), output->comp_node()}); + } + if (flags & PROFILE_FOOTPRINT) { + mgb_assert(std::get<1>(m_entry_stack.top()) == &entry); + MGB_LOCK_GUARD(m_lock); + m_entry_stack.pop(); + } + m_profile.push_back(std::move(entry)); return outputs; }); - m_hooker_list.push_back(std::move(hooker)); + if (flags & PROFILE_FOOTPRINT) { + hook_apply_on_var_node->apply_hook( + [this](auto&& apply, const OpDef& def, + VarNodeArray inputs) -> cg::OperatorNodeBase* { + auto* operator_node = apply(def, std::move(inputs)); + std::remove_reference_t + top; + { + MGB_LOCK_GUARD(m_lock); + if (m_entry_stack.empty()) { + return operator_node; + } + top = m_entry_stack.top(); + } + auto [current_op, current_entry, thread_id] = top; + if (current_op != &def || + thread_id != std::this_thread::get_id()) { + return operator_node; + } + auto&& footprint_result = + footprint.calc_footprint(operator_node); + current_entry->memory = footprint_result.memory; + current_entry->computation = + footprint_result.computation; +#if MGB_ENABLE_JSON + current_entry->param = footprint_result.param; +#endif + return operator_node; + }); + } + m_hooker_list.push_back(std::move(hook_apply_on_physical_tensor)); + m_hooker_list.push_back(std::move(hook_apply_on_var_node)); }); } void Profiler::stop() { m_hooker_list.clear(); - for (auto& entry : *m_profile) { + for (auto& entry : m_profile) { entry.wait_device(); } } +void Profiler::clear() { + mgb_assert(m_entry_stack.empty(), + "entry_stack should be empty after profile"); + mgb_assert(m_hooker_list.empty(), "hooks should be released"); + m_profile.clear(); + m_entry_count = 0; + m_device_timer.clear(); + m_tensor_recorder.clear(); +} + } // namespace imperative } // namespace mgb diff --git a/imperative/src/include/megbrain/imperative/profiler.h b/imperative/src/include/megbrain/imperative/profiler.h index bece82261..d1cd37deb 100644 --- a/imperative/src/include/megbrain/imperative/profiler.h +++ b/imperative/src/include/megbrain/imperative/profiler.h @@ -11,7 +11,10 @@ #pragma once -#include +#include +#include +#include +#include #include "megbrain/comp_node.h" #include "megbrain/graph/event.h" @@ -19,27 +22,39 @@ #include "megbrain/utils/timer.h" #include "megbrain/imperative/op_def.h" - -#include "megbrain/imperative/function_hook.h" +#include "megbrain/imperative/physical_tensor.h" namespace mgb { namespace imperative { -struct ProfileEntry{ +using ProfileTensor = std::tuple, CompNode>; + +struct ProfileEntry { using TimeClosure = std::function; + size_t id; + size_t parent; std::shared_ptr op; + //(host_begin, host_end) std::tuple host; + //[(device, device_begin, device_end)] std::vector> device_list; - void wait_device(){ - for(auto& [cn, begin, end]: device_list){ + std::vector inputs; + std::vector outputs; + ssize_t memory = 0; + ssize_t computation = 0; +#if MGB_ENABLE_JSON + std::shared_ptr param; +#endif + void wait_device() { + for (auto& [cn, begin, end] : device_list) { MGB_MARK_USED_VAR(cn); - begin = [begin=begin()]{ return begin; }; - end = [end = end()]{ return end; }; + begin = [begin = begin()] { return begin; }; + end = [end = end()] { return end; }; } } }; -using Profile = std::vector; +using Profile = std::list; class DeviceTimer { public: @@ -47,31 +62,54 @@ public: DeviceTimer() = default; void reset(thin_function host_timer); thin_function get_device_time(CompNode device); + void clear(); private: CompNode::UnorderedMap> m_base_event_table; + thin_function m_host_timer; +}; + +class TensorRecorder { +private: + // active tensors + std::unordered_map, size_t>> + m_tensor_map; + size_t m_next_id; + +public: + size_t record_tensor(const TensorPtr& tensor); + void clear(); }; class Profiler { public: - Profiler(Profile* profile = nullptr) { - if (!profile) { - m_owned_profile = std::make_unique(); - profile = m_owned_profile.get(); - } - m_profile = profile; - } - void start(); + enum Flags { + PROFILE_FOOTPRINT = 1, + }; + +public: + Profiler() = default; + // Start profiler by hook OpTrait + void start(uint32_t flags); + // Stop profiler and clean environment void stop(); - Profile& get_profile() { return *m_profile; } + void clear(); + Profile& get_profile(); private: DeviceTimer m_device_timer; RealTimer m_host_timer; - Profile* m_profile; + Profile m_profile; + TensorRecorder m_tensor_recorder; + std::stack> + m_entry_stack; + // Hold profile owned by this Profiler std::unique_ptr m_owned_profile; - std::vector> - m_hooker_list; + // Hold hooks, cleared when stop + std::vector m_hooker_list; + size_t m_entry_count = 0; + Spinlock m_lock; + std::unordered_map> m_recorded_tensors; }; } // namespace imperative -- GitLab