diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index 3ac8fbbd880581928933e72c36686becfd99ebd6..71974d6db605000de4f308d84d273eedb50be603 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -2,12 +2,14 @@ import collections import contextlib import functools import itertools +import json import typing import warnings import weakref import numpy as np +from ..core._imperative_rt import GraphProfiler from ..core.ops.special import Const from ..core.tensor import megbrain_graph as G from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply @@ -85,11 +87,14 @@ class trace: symbolic=False, capture_as_const=False, sublinear_memory_config: SublinearMemoryConfig = None, + profiling: bool = False, ): self.__wrapped__ = function self._symbolic = symbolic self._capture_as_const = capture_as_const self._sublinear_memory_config = sublinear_memory_config + self._profiling = profiling + self._profiler = None self._untraced = True self._tinfo = [] # handle -> TensorInfo @@ -308,6 +313,8 @@ class trace: ) sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try sublinear_config.num_worker = self._sublinear_memory_config.num_worker + if self._profiling: + self._profiler = GraphProfiler(graph) def _compile(self): graph = self._graph = G.Graph() @@ -581,6 +588,16 @@ class trace: % (output_names and output_names[i] or i) ) + def get_profile(self): + """ + Get profiling result for compiled trace. + + :return: a json compatible object. + """ + if not self._profiler: + raise RuntimeError("trace is not set with profiling=True") + return json.loads(self._profiler.get()) + class CompiledTensorProxy(RawTensor): """ diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index f827466f44f07e1967ffec9c25c80bed92899998..deb29bba1704c6f797e43ac3b2f71e83788ef15c 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -11,18 +11,38 @@ #include "./graph_rt.h" +#include "megbrain/graph/cg.h" #include "megbrain/serialization/serializer.h" #include "megbrain/imperative/opr_utility.h" #include "megbrain/opr/io.h" #include "megbrain/opr/basic_arith.h" #include "megbrain/imperative.h" #include "./helper.h" +#include "megbrain/plugin/profiler.h" namespace py = pybind11; using namespace mgb; using namespace imperative; +namespace { +class _CompGraphProfilerImpl { + std::shared_ptr m_comp_graph; + GraphProfiler m_profiler; + public: + _CompGraphProfilerImpl(std::shared_ptr cg): + m_comp_graph{cg}, + m_profiler{m_comp_graph.get()} + { + } + + std::string _get_result() { + auto json = m_profiler.to_json_full( + m_comp_graph->current_comp_seq()); + return json->to_string(); + } +}; +} #define DEF_READWRITE(name) .def_readwrite(#name, &CURRENT_CLASS::name) template @@ -102,6 +122,12 @@ void init_graph_rt(py::module m) { }) .def_property_readonly("options", py::overload_cast<>(&cg::ComputingGraph::options)); + py::class_<_CompGraphProfilerImpl, std::shared_ptr<_CompGraphProfilerImpl>>(m, "GraphProfiler") + .def(py::init([](std::shared_ptr graph) { + return std::make_shared<_CompGraphProfilerImpl>(graph); + })) + .def("get", [](_CompGraphProfilerImpl& profiler) { return profiler._get_result(); }); + m.def("dump_graph", [](const std::vector& dest_vars) { using namespace mgb::serialization; std::vector buf; diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index b7a6dacd657564513710003826e4f80e234eae32..45202c4e79cca825331caad971876a30f6505297 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -82,3 +82,22 @@ def test_dump(): file = io.BytesIO() f.dump(file) + + +def test_trace_profiler(): + for symbolic in [False, True]: + + @trace(symbolic=symbolic, profiling=True) + def f(x): + op = ops.Elemwise(mode="negate") + (y,) = apply(op, x) + return y + + x = as_raw_tensor([1]).numpy() + y = f.__wrapped__(as_raw_tensor(x)).numpy() + + f(as_raw_tensor(x)) + f(as_raw_tensor(x)) # XXX: has to run twice + + out = f.get_profile() + assert out.get("profiler")