提交 403a1e7b 编写于 作者: M Megvii Engine Team

feat(mge/imperative): add graph profiler

GitOrigin-RevId: c73563f33787916a83454dcfecbc6d571e2e7e79
上级 d06f248d
......@@ -2,12 +2,14 @@ import collections
import contextlib
import functools
import itertools
import json
import typing
import warnings
import weakref
import numpy as np
from ..core._imperative_rt import GraphProfiler
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
......@@ -85,11 +87,14 @@ class trace:
symbolic=False,
capture_as_const=False,
sublinear_memory_config: SublinearMemoryConfig = None,
profiling: bool = False,
):
self.__wrapped__ = function
self._symbolic = symbolic
self._capture_as_const = capture_as_const
self._sublinear_memory_config = sublinear_memory_config
self._profiling = profiling
self._profiler = None
self._untraced = True
self._tinfo = [] # handle -> TensorInfo
......@@ -308,6 +313,8 @@ class trace:
)
sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try
sublinear_config.num_worker = self._sublinear_memory_config.num_worker
if self._profiling:
self._profiler = GraphProfiler(graph)
def _compile(self):
graph = self._graph = G.Graph()
......@@ -581,6 +588,16 @@ class trace:
% (output_names and output_names[i] or i)
)
def get_profile(self):
    """
    Get profiling result for compiled trace.

    The profiler object is only attached when the trace sets up its graph,
    so a trace created with ``profiling=True`` still has no profiler until
    that has happened.

    :return: a json compatible object.
    :raises RuntimeError: if the trace was created without
        ``profiling=True``, or if profiling is enabled but no profiler has
        been attached yet.
    """
    if self._profiler is None:
        if not self._profiling:
            raise RuntimeError("trace is not set with profiling=True")
        # Profiling was requested, but the graph the profiler hooks into
        # has not been built yet; report that instead of blaming the flag.
        raise RuntimeError(
            "profiling result is not available yet: "
            "the trace has not been compiled"
        )
    # The profiler hands back a JSON document as text; decode it for the caller.
    return json.loads(self._profiler.get())
class CompiledTensorProxy(RawTensor):
"""
......
......@@ -11,18 +11,38 @@
#include "./graph_rt.h"

#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

#include "megbrain/graph/cg.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/imperative.h"
#include "./helper.h"
#include "megbrain/plugin/profiler.h"
namespace py = pybind11;
using namespace mgb;
using namespace imperative;
namespace {
/*!
 * \brief Pairs a computing graph with a GraphProfiler built on it, so both
 *      can be handed to Python as a single object.
 *
 * The graph is held through a shared_ptr for the lifetime of this object;
 * the profiler is constructed from the raw pointer of that same graph.
 */
class _CompGraphProfilerImpl {
    std::shared_ptr<ComputingGraph> m_comp_graph;
    // Must stay declared after m_comp_graph: it is initialised from
    // m_comp_graph.get(), and members are initialised in declaration order.
    GraphProfiler m_profiler;

    //! Reject a null graph handle before any member is constructed from it.
    static std::shared_ptr<ComputingGraph> require_graph(
            std::shared_ptr<ComputingGraph> cg) {
        if (!cg) {
            throw std::invalid_argument(
                    "GraphProfiler requires a non-null computing graph");
        }
        return cg;
    }

public:
    /*!
     * \param cg graph to profile; shared ownership is taken
     * \throws std::invalid_argument if \p cg is null
     */
    explicit _CompGraphProfilerImpl(std::shared_ptr<ComputingGraph> cg)
            : m_comp_graph{require_graph(std::move(cg))},
              m_profiler{m_comp_graph.get()} {}

    /*!
     * \brief Serialise the full profiling report of the graph's current
     *      computing sequence.
     * \return the report rendered by to_json_full(), as a string
     * \throws std::runtime_error if the graph has no current computing
     *      sequence
     */
    std::string _get_result() {
        auto comp_seq = m_comp_graph->current_comp_seq();
        // NOTE(review): assumed to be null until the graph has been
        // compiled -- confirm against ComputingGraph::current_comp_seq().
        if (!comp_seq) {
            throw std::runtime_error(
                    "GraphProfiler: graph has no current computing sequence");
        }
        auto json = m_profiler.to_json_full(comp_seq);
        return json->to_string();
    }
};
}
#define DEF_READWRITE(name) .def_readwrite(#name, &CURRENT_CLASS::name)
template<typename T>
......@@ -102,6 +122,12 @@ void init_graph_rt(py::module m) {
})
.def_property_readonly("options", py::overload_cast<>(&cg::ComputingGraph::options));
// Expose _CompGraphProfilerImpl to Python under the name "GraphProfiler";
// instances are held through std::shared_ptr.
py::class_<_CompGraphProfilerImpl, std::shared_ptr<_CompGraphProfilerImpl>>(m, "GraphProfiler")
// Constructor: takes the computing graph the profiler is built on.
.def(py::init([](std::shared_ptr<ComputingGraph> graph) {
return std::make_shared<_CompGraphProfilerImpl>(graph);
}))
// get(): returns the string produced by _get_result().
.def("get", [](_CompGraphProfilerImpl& profiler) { return profiler._get_result(); });
m.def("dump_graph", [](const std::vector<VarNode*>& dest_vars) {
using namespace mgb::serialization;
std::vector<uint8_t> buf;
......
......@@ -82,3 +82,22 @@ def test_dump():
file = io.BytesIO()
f.dump(file)
def test_trace_profiler():
    """
    A trace created with ``profiling=True`` must still return the same values
    as the undecorated function, and must expose a non-empty ``"profiler"``
    entry through ``get_profile()``. Both symbolic modes are exercised.
    """
    for symbolic in [False, True]:

        @trace(symbolic=symbolic, profiling=True)
        def f(x):
            op = ops.Elemwise(mode="negate")
            (y,) = apply(op, x)
            return y

        x = as_raw_tensor([1]).numpy()
        # Reference result, computed by the plain function without tracing.
        y = f.__wrapped__(as_raw_tensor(x)).numpy()

        # XXX: has to run twice before get_profile() is queried
        for _ in range(2):
            result = f(as_raw_tensor(x)).numpy()
            # Profiling must not change what the traced function computes.
            assert (result == y).all()

        out = f.get_profile()
        assert out.get("profiler")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册