Commit 6f581906 authored by Megvii Engine Team

refactor(mge/profiler): refactor profiler

GitOrigin-RevId: 279aa779a69e503e38e91ce06a973028824eb0aa
Parent cc952b2b
@@ -6,24 +6,148 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import base64
import json
import os
from typing import List, Optional

from ..core._imperative_rt import OperatorNodeConfig, ProfileEntry
from ..core._imperative_rt import ProfilerImpl as _Profiler
from ..core._imperative_rt.imperative import sync
from ..core._imperative_rt.ops import CollectiveCommMode
from ..core.ops.builtin import GetVarShape
class Profiler:
    r"""
    Profile graph execution in imperative mode.

    :type path: Optional[str]
    :param path: default path for profiler to dump

    Examples:

    .. testcode::

        import megengine as mge
        import megengine.module as M
        from megengine.utils.profiler import Profiler

        for iter in range(0, 10):
            # only the profile record of the last iteration is saved
            with Profiler("profile.json"):
                pass  # your code here

        # then open the profile file in the chrome://tracing timeline window
    """
# see https://github.com/catapult-project/catapult/blob/master/tracing/tracing/base/color_scheme.html
GOOD = "good"
BAD = "bad"
TERRIBLE = "terrible"
BLACK = "black"
GREY = "grey"
WHITE = "white"
YELLOW = "yellow"
OLIVE = "olive"
def __init__(self, path: str = "profile.json"):
self._impl = _Profiler()
self._path = path
self._color_map = {}
self._type_map = {
OperatorNodeConfig: lambda x: self.print_opnode_config(x),
bytes: lambda x: base64.encodebytes(x).decode("ascii"),
CollectiveCommMode: lambda x: str(x),
}
    def __enter__(self):
        sync()
        self._impl.start()
        return self

    def __exit__(self, val, type, trace):
        sync()
        self._impl.stop()
        if self._path is not None:
            self.dump()
def recolor(self, target: str, color: str):
self._color_map[target] = color
return self
def print_opnode_config(self, config):
return self.make_dict(
name=config.name, dtype=config.dtype, comp_node_arr=config.comp_node_arr,
)
def fetch_attrs(self, op):
attrs = dir(op)
results = {}
for attr in attrs:
if attr.startswith("_"):
continue
value = op.__getattribute__(attr)
if callable(value):
continue
value_type = type(value)
if value_type in self._type_map:
value = self._type_map[value_type](value)
results[attr] = value
return results
def make_dict(self, **kwargs):
unused_keys = []
for k, v in kwargs.items():
if v is None:
unused_keys.append(k)
for k in unused_keys:
del kwargs[k]
return kwargs
    def dump(self, path: Optional[str] = None):
        pid = os.getpid()
if path is None:
path = self._path
trace_events = []
def append_event(**kwargs):
trace_events.append(self.make_dict(**kwargs))
entries: List[ProfileEntry] = self._impl.dump()
for id, entry in enumerate(entries):
op = entry.op
name = type(op).__name__
host_begin, host_end = entry.host
device_list = entry.device_list
args = self.fetch_attrs(op)
args["__id__"] = "[{}]".format(id)
cname = self._color_map[name] if name in self._color_map else None
cat = name
for ts, ph in [(host_begin, "B"), (host_end, "E")]:
append_event(
name=name,
ph=ph,
ts=ts * 1000,
pid=pid,
tid="host",
args=args,
cname=cname,
cat=cat,
)
for device, device_begin, device_end in device_list:
for ts, ph in [(device_begin(), "B"), (device_end(), "E")]:
append_event(
name=name,
ph=ph,
ts=ts * 1000,
pid=pid,
tid=str(device),
args=args,
cname=cname,
)
with open(path, "w") as f:
json.dump(trace_events, f, indent=2)
@@ -651,9 +651,14 @@ PyObject* npy::dtype_mgb2np(mgb::DType dtype) {
    // https://docs.scipy.org/doc/numpy/reference/c-api.array.html#c.PyArray_TypeObjectFromType
    // the following is equivalent to PyArray_TypeObjectFromType for built-in
    // types.
    if (!dtype.valid()) {
        Py_XINCREF(Py_None);
        return Py_None;
    }
    auto descr = dtype_mgb2np_descr(dtype);
    if (descr == nullptr) {
        Py_XINCREF(Py_None);
        return Py_None;
    }
    if (dtype.has_param()) {
        return reinterpret_cast<PyObject*>(descr.release());
...
@@ -199,32 +199,22 @@ void init_utils(py::module m) {
    m.def("_get_device_count", &mgb::CompNode::get_device_count,
          "Get total number of specific devices on this system");

    using mgb::imperative::ProfileEntry;

    py::class_<ProfileEntry>(m, "ProfileEntry")
            .def_readwrite("op", &ProfileEntry::op)
            .def_readwrite("host", &ProfileEntry::host)
            .def_readwrite("device_list", &ProfileEntry::device_list);

    py::class_<mgb::imperative::Profiler>(m, "ProfilerImpl")
            .def(py::init<>())
            .def("start",
                 [](mgb::imperative::Profiler& profiler) { profiler.start(); })
            .def("stop",
                 [](mgb::imperative::Profiler& profiler) { profiler.stop(); })
            .def("dump", [](mgb::imperative::Profiler& profiler) {
                return profiler.get_profile();
            });

    using mgb::imperative::TensorSanityCheck;
    py::class_<TensorSanityCheck>(m, "TensorSanityCheckImpl")
...
#include "./event_pool.h"
namespace mgb {
namespace imperative {
EventPool::EventPool(size_t flags) : m_flags{flags} {}
EventPool& EventPool::with_timer() {
static Spinlock lock;
static std::unique_ptr<EventPool> ptr;
MGB_LOCK_GUARD(lock);
if (!ptr || ptr->is_finalized()) {
ptr.reset(new EventPool(CompNode::Event::NEED_TIMER));
}
return *ptr;
}
EventPool& EventPool::without_timer() {
static Spinlock lock;
static std::unique_ptr<EventPool> ptr;
MGB_LOCK_GUARD(lock);
if (!ptr || ptr->is_finalized()) {
ptr.reset(new EventPool());
}
return *ptr;
}
CompNode::Event* EventPool::alloc(CompNode cn) {
CompNode::EventPool* pool;
{
MGB_LOCK_GUARD(m_lock);
auto iter = m_cn2pool.find(cn);
if (iter == m_cn2pool.end()) {
iter = m_cn2pool
.emplace(std::piecewise_construct,
std::forward_as_tuple(cn),
std::forward_as_tuple(cn, m_flags))
.first;
}
pool = &iter->second;
}
return pool->alloc();
}
std::shared_ptr<CompNode::Event> EventPool::alloc_shared(CompNode cn) {
auto* raw_event = alloc(cn);
return {raw_event, [this](CompNode::Event* event){ this->free(event); }};
}
void EventPool::free(CompNode::Event* event) {
CompNode::EventPool* pool;
{
MGB_LOCK_GUARD(m_lock);
pool = &m_cn2pool.at(event->comp_node());
}
pool->free(event);
}
std::shared_ptr<void> EventPool::on_comp_node_finalize() {
MGB_LOCK_GUARD(m_lock);
for (auto&& i : m_cn2pool) {
i.second.assert_all_freed();
}
return {};
}
EventPool::~EventPool() {
for (auto&& i : m_cn2pool) {
i.second.assert_all_freed();
}
}
} // namespace imperative
} // namespace mgb
#pragma once
#include "megbrain/comp_node.h"
namespace mgb {
namespace imperative {
class EventPool : CompNodeDepedentObject {
CompNode::UnorderedMap<CompNode::EventPool> m_cn2pool;
Spinlock m_lock;
size_t m_flags;
EventPool(size_t flags = 0);
public:
static EventPool& with_timer();
static EventPool& without_timer();
CompNode::Event* alloc(CompNode cn);
std::shared_ptr<CompNode::Event> alloc_shared(CompNode cn);
void free(CompNode::Event* event);
std::shared_ptr<void> on_comp_node_finalize();
~EventPool();
};
} // namespace imperative
} // namespace mgb
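The new EventPool keeps one CompNode::EventPool per computing node behind two process-wide singletons: with_timer() hands out events created with CompNode::Event::NEED_TIMER (so elapsed-time queries work), while without_timer() hands out plain events for synchronization only, and alloc_shared() returns a shared_ptr whose deleter gives the event back to the pool. A minimal sketch of timing a device-side gap with this API; the measure_gap_ms helper is an illustrative assumption and the sketch presumes compilation inside the imperative runtime source tree:

#include "./event_pool.h"

using namespace mgb;
using namespace mgb::imperative;

// Measure the device-side gap between two recorded events on one CompNode.
double measure_gap_ms(CompNode cn) {
    auto begin = EventPool::with_timer().alloc_shared(cn);  // timed event
    begin->record();
    // ... launch some kernels on `cn` here ...
    auto end = EventPool::with_timer().alloc_shared(cn);
    end->record();
    end->host_wait();  // block until the second event has been reached
    // elapsed_time_until() reports seconds; convert to milliseconds.
    return begin->elapsed_time_until(*end) * 1000;
}   // both shared_ptrs release their events back to the pool here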
@@ -11,6 +11,7 @@
#include "megbrain/imperative.h"
#include "megbrain/imperative/blob_manager.h"

#include "./event_pool.h"

#include <mutex>

namespace mgb {
@@ -18,86 +19,31 @@ namespace imperative {

namespace {
class EventPool : CompNodeDepedentObject {
CompNode::UnorderedMap<CompNode::EventPool> m_cn2pool;
Spinlock m_lock;
EventPool() = default;
public:
static EventPool& inst() {
static Spinlock lock;
static std::unique_ptr<EventPool> ptr;
MGB_LOCK_GUARD(lock);
if (!ptr || ptr->is_finalized()) {
ptr.reset(new EventPool());
}
return *ptr;
}
CompNode::Event* alloc(CompNode cn) {
CompNode::EventPool *pool;
{
MGB_LOCK_GUARD(m_lock);
auto iter = m_cn2pool.find(cn);
if (iter == m_cn2pool.end()) {
iter = m_cn2pool.emplace(
std::piecewise_construct,
std::forward_as_tuple(cn),
std::forward_as_tuple(cn)).first;
}
pool = &iter->second;
}
return pool->alloc();
}
void free(CompNode::Event* event) {
CompNode::EventPool* pool;
{
MGB_LOCK_GUARD(m_lock);
pool = &m_cn2pool.at(event->comp_node());
}
pool->free(event);
}
std::shared_ptr<void> on_comp_node_finalize() override {
MGB_LOCK_GUARD(m_lock);
for (auto&& i : m_cn2pool) {
i.second.assert_all_freed();
}
return {};
}
~EventPool() {
for (auto&& i : m_cn2pool) {
i.second.assert_all_freed();
}
}
};
class AsyncReleaser : public CompNodeDepedentObject {
    struct WaiterParam {
        CompNode cn;
        CompNode::Event* event;
        BlobPtr blob;
        HostTensorStorage::RawStorage storage;
    };
    class Waiter final : public AsyncQueueSC<WaiterParam, Waiter> {
        AsyncReleaser* m_par_releaser;

    public:
        Waiter(AsyncReleaser* releaser) : m_par_releaser(releaser) {}

        void process_one_task(WaiterParam& param) {
            if (param.event->finished()) {
                param.blob.reset();
                param.storage.reset();
                EventPool::without_timer().free(param.event);
                return;
            }
            using namespace std::literals;
            std::this_thread::sleep_for(1us);
            add_task(std::move(param));
        }
    };
    Waiter m_waiter{this};
@@ -113,20 +59,17 @@ public:
        return &releaser;
    }

    ~AsyncReleaser() { m_waiter.wait_task_queue_empty(); }

    void add(BlobPtr blob, CompNode cn) { add(cn, std::move(blob), {}); }

    void add(const HostTensorND& hv) {
        add(hv.comp_node(), {}, hv.storage().raw_storage());
    }

    void add(CompNode cn, BlobPtr blob,
             HostTensorStorage::RawStorage storage = {}) {
        auto event = EventPool::without_timer().alloc(cn);
        event->record();
        m_waiter.add_task({cn, event, std::move(blob), std::move(storage)});
    }
@@ -290,10 +233,10 @@ struct MultiCNConstTensorCache : CompNodeDepedentObject {
MultiCNConstTensorCache const_tensor_cache;

} // namespace
void EventDeleter::operator()(CompNode::Event* event) {
    EventPool::without_timer().free(event);
}

Blob::Blob(const DeviceTensorStorage& s):
@@ -373,7 +316,7 @@ void Tensor::fetch_value() {
    MGB_LOCK_GUARD(m_mtx);
    if (m_value.empty()) {
        m_value.copy_from(dev_tensor());
        m_value_ready.reset(EventPool::without_timer().alloc(comp_node()));
        m_value_ready->record();
    }
}
@@ -421,7 +364,7 @@ CompNode::Event* Tensor::get_or_create_event() {
        return e;
    }
}

} // namespace imperative
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -11,63 +11,18 @@
#include "megbrain/imperative/profiler.h"
#if defined(_MSC_VER) || defined(WIN32)
#include <windows.h>
#define getpid GetCurrentProcessId
#else
#include <sys/unistd.h>
#endif
#if defined(__APPLE__) || defined(__MACOSX)
#include <unistd.h>
#endif
#include <variant>

#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/physical_tensor.h"

#include "./event_pool.h"
#include "./op_trait.h"

namespace mgb {
namespace imperative {
class OpDefInfo{
public:
size_t id;
std::string name;
};
class ProfilerEntry {
public:
ProfilerEntry(size_t index, Profiler::EventKind type, std::unique_ptr<CompNode::Event> device)
: index{index}, type{type}, device{std::move(device)}{
}
ProfilerEntry(size_t index, Profiler::EventKind type, double host): index{index}, type{type}, host{host}{
}
size_t index;
Profiler::EventKind type;
std::unique_ptr<CompNode::Event> device = nullptr;
double host = 0;
};
class ProfilerPrivate {
public:
std::vector<OpDefInfo> op_list;
std::vector<ProfilerEntry> entry_list;
std::vector<std::unique_ptr<CompNode::Event>> event_list;
std::vector<std::tuple<OpTrait*, std::unique_ptr<ApplyOnPhysicalTensor>>>
hook_list;
ThinHashMap<CompNode, std::tuple<CompNode::Event*, double>>
comp_node_begin_map;
ThinHashMap<CompNode, CompNode::Event*> comp_node_end_map;
RealTimer timer;
size_t dump_count = 0;
bool enabled = false;
std::string path;
};
namespace {

CompNode::UnorderedSet collect_comp_nodes(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
@@ -80,145 +35,65 @@ CompNode::UnorderedSet collect_comp_nodes(
    }
    return comp_nodes;
}
} // namespace
std::unique_ptr<CompNode::Event> Profiler::create_event(CompNode comp_node){
auto event = comp_node.create_event(CompNode::Event::NEED_TIMER);
event->record();
auto& [begin, time] = m_private->comp_node_begin_map[comp_node];
if (begin == nullptr) {
begin = event.get();
time = m_private->timer.get_msecs();
}
return event;
}
double Profiler::get_host_time_now(){
return m_private->timer.get_msecs();
}
double Profiler::get_device_time(CompNode::Event& event) {
auto [base_event, host_time] =
m_private->comp_node_begin_map[event.comp_node()];
if (base_event == &event) {
return host_time;
} else {
return host_time + base_event->elapsed_time_until(event) * 1000;
}
}
size_t Profiler::get_dump_count(){
    return m_private->dump_count;
}
Profiler::Profiler() {
m_private = std::make_unique<ProfilerPrivate>();
}
Profiler::Profiler(const std::string& path): Profiler() {
m_private->path = path;
}
void DeviceTimer::reset(thin_function<double()> host_timer) {
    CompNode::foreach ([this, host_timer](CompNode device) {
        auto base_event = EventPool::with_timer().alloc_shared(device);
        base_event->record();
        m_base_event_table[device] = {std::move(base_event), host_timer()};
    });
}

thin_function<double()> DeviceTimer::get_device_time(CompNode device) {
    auto event = EventPool::with_timer().alloc_shared(device);
    event->record();
    auto base = m_base_event_table[device];
    return [base, event] {
        auto [base_event, host_time] = base;
        // TODO: sync once for each compnode
        event->host_wait();
        return base_event->elapsed_time_until(*event) * 1000 + host_time;
    };
}

void Profiler::start() {
    m_host_timer.reset();
    m_device_timer.reset([&] { return m_host_timer.get_msecs(); });
    OpTrait::for_each_trait([this](OpTrait& trait) {
        FunctionHooker hooker{&trait.apply_on_physical_tensor};
        hooker.apply_hook([this](auto&& apply, const OpDef& def,
                                 const SmallVector<TensorPtr>& inputs) {
            ProfileEntry entry;
            entry.op = def.copy();
            double host_begin = m_host_timer.get_msecs();
            auto&& comp_nodes = collect_comp_nodes(def, inputs);
            for (auto&& comp_node : comp_nodes) {
                entry.device_list.push_back(
                        {comp_node,
                         m_device_timer.get_device_time(comp_node),
                         {}});
            }
            auto outputs = apply(def, inputs);
            for (auto& [cn, dev_begin, dev_end] : entry.device_list) {
                MGB_MARK_USED_VAR(cn);
                MGB_MARK_USED_VAR(dev_begin);
                dev_end = m_device_timer.get_device_time(cn);
            }
            entry.host = {host_begin, m_host_timer.get_msecs()};
            m_profile->push_back(std::move(entry));
            return outputs;
        });
        m_hooker_list.push_back(std::move(hooker));
    });
}

void Profiler::stop() {
    m_hooker_list.clear();
    for (auto& entry : *m_profile) {
        entry.wait_device();
    }
}

} // namespace imperative
...
/**
* \file imperative/src/include/megbrain/imperative/function_hook.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/utils/thin/function.h"
namespace mgb {
namespace imperative {
template <typename TFunction>
class FunctionHooker;
template <typename TRet, typename... TArgs>
class FunctionHooker<TRet(TArgs...)> {
public:
using FunctionType = thin_function<TRet(TArgs&&...)>;
using HookType = thin_function<TRet(FunctionType, TArgs&&...)>;
explicit FunctionHooker(FunctionType* fptr) : m_fptr{fptr} {}
public:
FunctionHooker& apply_hook(HookType&& hook) {
if (!m_backup) {
FunctionType* backup = new FunctionType(*m_fptr);
std::function<void(FunctionType*)> restorer =
[fptr = m_fptr](FunctionType* bkp) -> void {
*fptr = *bkp;
delete bkp;
};
m_backup = decltype(m_backup)(backup, restorer);
}
*m_fptr = [func = *m_fptr, hook](TArgs&&... args) -> TRet {
return hook(func, std::forward<TArgs>(args)...);
};
return *this;
}
private:
FunctionType* m_fptr;
std::unique_ptr<FunctionType, std::function<void(FunctionType*)>> m_backup;
};
template <typename TRet, typename... TArgs>
FunctionHooker(thin_function<TRet(TArgs...)>* f)
->FunctionHooker<TRet(TArgs...)>;
} // namespace imperative
} // namespace mgb
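FunctionHooker wraps an existing thin_function in place and restores the original when the hooker is destroyed; this is exactly how Profiler::start() installs its per-op hook and how clearing m_hooker_list in Profiler::stop() removes it again. A minimal sketch of the wrap-and-restore pattern under that assumption; the free-standing `add` function, its logging body, and the `demo` wrapper are illustrative and not part of this commit:

#include <cstdio>

#include "megbrain/imperative/function_hook.h"

using namespace mgb;
using namespace mgb::imperative;

// Reference parameters, mirroring the style of apply_on_physical_tensor.
thin_function<int(const int&, const int&)> add =
        [](const int& a, const int& b) { return a + b; };

void demo() {
    {
        FunctionHooker hooker{&add};
        // `apply` receives the original function; the hook runs around it.
        hooker.apply_hook([](auto&& apply, const int& a, const int& b) -> int {
            std::printf("calling add(%d, %d)\n", a, b);
            return apply(a, b);
        });
        add(1, 2);  // hooked call: logs, then returns 3
    }  // hooker destroyed here: the original `add` is restored
    add(1, 2);  // plain, unhooked call again
}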
@@ -11,6 +11,8 @@
#pragma once

#include <variant>

#include "megbrain/comp_node.h"
#include "megbrain/graph/event.h"
#include "megbrain/utils/json.h"
@@ -18,37 +20,59 @@
#include "megbrain/imperative/op_def.h"

#include "megbrain/imperative/function_hook.h"

namespace mgb {
namespace imperative {

struct ProfileEntry {
    using TimeClosure = std::function<double()>;
    std::shared_ptr<OpDef> op;
    std::tuple<double, double> host;
    std::vector<std::tuple<CompNode, TimeClosure, TimeClosure>> device_list;
    void wait_device() {
        for (auto& [cn, begin, end] : device_list) {
            MGB_MARK_USED_VAR(cn);
            begin = [begin = begin()] { return begin; };
            end = [end = end()] { return end; };
        }
    }
};

using Profile = std::vector<ProfileEntry>;

class DeviceTimer {
public:
    using SharedEvent = std::shared_ptr<CompNode::Event>;
    DeviceTimer() = default;
    void reset(thin_function<double()> host_timer);
    thin_function<double()> get_device_time(CompNode device);

private:
    CompNode::UnorderedMap<std::tuple<SharedEvent, double>> m_base_event_table;
};

class Profiler {
public:
    Profiler(Profile* profile = nullptr) {
        if (!profile) {
            m_owned_profile = std::make_unique<Profile>();
            profile = m_owned_profile.get();
        }
        m_profile = profile;
    }
    void start();
    void stop();
    Profile& get_profile() { return *m_profile; }

private:
    DeviceTimer m_device_timer;
    RealTimer m_host_timer;
    Profile* m_profile;
    std::unique_ptr<Profile> m_owned_profile;
    std::vector<FunctionHooker<decltype(OpDef::apply_on_physical_tensor)>>
            m_hooker_list;
};

} // namespace imperative
} // namespace mgb
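The refactored Profiler collects one ProfileEntry per executed op: `host` holds the begin/end host timestamps in milliseconds, and `device_list` holds per-CompNode time closures that are resolved by wait_device() when profiling stops. A minimal sketch of driving this C++ interface directly; the `profile_block` wrapper and the elided op executions are assumptions, and the sketch presumes it is built against the MegBrain imperative runtime:

#include <cstdio>

#include "megbrain/imperative/profiler.h"

using namespace mgb::imperative;

// Profile a block of imperative op executions and print per-op host times.
void profile_block() {
    Profiler profiler;   // owns its Profile when none is supplied
    profiler.start();    // hooks apply_on_physical_tensor for every op trait
    // ... run imperative ops here ...
    profiler.stop();     // unhooks and resolves the device time closures

    for (auto& entry : profiler.get_profile()) {
        auto [host_begin, host_end] = entry.host;
        std::printf("%s: %.3f ms on host, %zu device(s)\n",
                    entry.op->dyn_typeinfo()->name,
                    host_end - host_begin,
                    entry.device_list.size());
    }
}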
@@ -89,8 +89,8 @@ namespace {
/* ==================== EventPool ==================== */
CompNode::EventPool::EventPool(CompNode cn, size_t flags):
        m_cn{cn}, m_flags{flags}
{
}
@@ -105,7 +105,7 @@ CompNode::Event* CompNode::EventPool::alloc() {
        m_free.pop_back();
        return rst;
    }
    m_allocated.push_back(m_cn.create_event(m_flags));
    return m_allocated.back().get();
}
...
@@ -643,9 +643,10 @@ class CompNode::EventPool {
    std::vector<std::unique_ptr<CompNode::Event>> m_allocated;
    std::vector<CompNode::Event*> m_free;
    Spinlock m_lock;
    size_t m_flags;

public:
    explicit EventPool(CompNode cn, size_t flags = 0);
    ~EventPool();
    CompNode::Event* alloc();
...