Commit ffe7ceb3 authored by Megvii Engine Team

feat(mge): support get python backtrace

GitOrigin-RevId: 52a5e406ca702a10e8da6992f34413606f8808e5
Parent bf1a0fb7
import sys

from ..core._imperative_rt.core2 import pop_scope, push_scope, record_scope

class AutoNaming:
@@ -21,6 +23,7 @@ class AutoNaming:
    def push_scope(cls, scope):
        if scope is not None:
            push_scope(scope)
            record_scope(sys._getframe().f_back.f_back, scope)
        cls.scopes.append(scope)

    @classmethod
...
@@ -15,6 +15,7 @@ from ..core._imperative_rt.core2 import (
    full_sync,
    pop_scope,
    push_scope,
    set_python_backtrace_enabled,
    start_profile,
    stop_profile,
    sync,
@@ -30,6 +31,7 @@ class Profiler(ContextDecorator):
    Args:
        path: default path prefix for profiler to dump.
        with_backtrace: whether to record backtrace information for ops.

    Examples:
@@ -66,6 +68,7 @@ class Profiler(ContextDecorator):
        path: str = "profile",
        format: str = "chrome_timeline.json",
        formats: List[str] = None,
        with_backtrace: bool = False,
        **kwargs
    ) -> None:
        if not formats:
@@ -90,6 +93,7 @@ class Profiler(ContextDecorator):
            enable_cupti()
        else:
            get_logger().warning("CuPTI unavailable")
        self.with_backtrace = with_backtrace

    @property
    def path(self):
@@ -116,6 +120,7 @@ class Profiler(ContextDecorator):
        _running_profiler = self
        self._pid = os.getpid()
        start_profile(self._options)
        self._origin_enable_bt = set_python_backtrace_enabled(self.with_backtrace)
        return self

    def stop(self):
@@ -127,6 +132,7 @@ class Profiler(ContextDecorator):
        self._dump_callback = stop_profile()
        self._pid = os.getpid()
        _living_profilers.add(self)
        set_python_backtrace_enabled(self._origin_enable_bt)

    def dump(self):
        if self._dump_callback is not None:
...
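
A minimal usage sketch of the new option (the output path below is hypothetical; with_backtrace just toggles set_python_backtrace_enabled for the duration of the profile, as the diff above shows):

from megengine import tensor
from megengine.utils.profiler import Profiler

# Record the Python call stack of every op dispatched inside the context;
# it shows up as a "python_backtrace" entry in the dumped op details.
with Profiler(path="profile_bt", with_backtrace=True):
    x = tensor([1.0, 2.0])
    y = x * 2 + 1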
#include "./backtrace.h"
#include <cstdint>
#include "megbrain/common.h"
#include "megbrain/imperative/transformation.h"
namespace mgb::imperative::python {
static bool enable_py_bt = false;
static bool enable_trans_bt = false;
std::unordered_map<ptrdiff_t, PyObjRefKeeper> FrameInfo::code_ref_keeper = {};
std::pair<FrameInfoPtr, int> FrameInfo::make(
PyFrameObject* frame, FrameInfoCache* cache) {
if (frame == NULL) {
return std::make_pair(nullptr, -1);
}
auto* keywrapper = TraceKeyWrapper::try_cast(frame->f_trace);
auto key = keywrapper ? keywrapper->key : -1;
#if PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 10
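    // Before CPython 3.10 this stores the raw bytecode offset (f_lasti); it is
    // mapped to a real source line later in traceback() via PyCode_Addr2Line.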
int lineno = frame->f_lasti;
#else
int lineno = PyFrame_GetLineNumber(frame);
#endif
FrameInfoPtr cache_finfo;
if (key != -1 && cache != nullptr && key < cache->size() &&
(*cache)[key]->lineno == lineno) {
cache_finfo = (*cache)[key];
}
if (cache_finfo) {
return std::make_pair(cache_finfo, key);
} else {
PyCodeObject* code = frame->f_code;
FrameInfoPtr f = std::make_shared<FrameInfo>(code, lineno);
if (keywrapper) {
f->scope = keywrapper->scope;
}
return std::make_pair(f, -1);
}
}
std::string FrameInfo::traceback() {
FrameInfoPtr cur = shared_from_this();
std::list<FrameInfo*> frames;
while (cur) {
frames.push_front(cur.get());
cur = cur->prev_frame;
}
std::string logs;
for (auto&& f : frames) {
auto code = py::handle((PyObject*)f->code_obj);
#if PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 10
int lineno = PyCode_Addr2Line(f->code_obj, f->lineno);
#else
int lineno = f->lineno;
#endif
if (f->scope != "")
logs += "scope: <" + f->scope + ">\n";
py::object filename = py::getattr(code, "co_filename");
logs += py::str(filename);
logs += " , line ";
logs += std::to_string(lineno);
logs += ", in ";
logs += py::str(py::getattr(code, "co_name"));
logs += '\n';
}
return logs;
}
#if PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION > 6
static Py_tss_t tss_key = Py_tss_NEEDS_INIT;
#else
static int tss_key;
#endif
static bool support_tss = false;
void init_backtrace_tss_key() {
#if PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION > 6
int result = PyThread_tss_create(&tss_key);
support_tss = result == 0;
#else
tss_key = PyThread_create_key();
support_tss = tss_key != -1;
#endif
}
FrameInfoCache* FrameInfoCache::get_instance() {
mgb_assert(support_tss);
constexpr int max_cache_size = 10;
static FrameInfoCache caches[max_cache_size];
static uintptr_t tid = 1;
static std::list<std::pair<uintptr_t, FrameInfoCache*>> cache_list;
static std::unordered_map<uintptr_t, decltype(cache_list)::iterator> kv_map;
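    // LRU bookkeeping: cache_list orders (thread-id, cache) pairs by recency and
    // kv_map gives O(1) lookup from a thread id to its list node; once more than
    // max_cache_size threads have been seen, the least recently used cache is
    // cleared and handed to the new thread.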
auto get_cache = [](uintptr_t key) {
if (kv_map.find(key) != kv_map.end()) {
auto it = kv_map[key];
auto rst = it->second;
if (it != cache_list.begin()) {
cache_list.push_front(*it);
cache_list.erase(it);
kv_map[key] = cache_list.begin();
}
return rst;
}
if (cache_list.size() < max_cache_size) {
auto* rst = &caches[key % max_cache_size];
cache_list.emplace_front(key, rst);
kv_map[key] = cache_list.begin();
return rst;
} else {
auto it = --cache_list.end();
auto empty_cache = *it;
cache_list.erase(it);
empty_cache.second->stack_cache.clear();
cache_list.push_front(empty_cache);
kv_map[key] = cache_list.begin();
return empty_cache.second;
}
};
#if PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION > 6
auto* id = PyThread_tss_get(&tss_key);
if (id == NULL) {
mgb_assert(PyThread_tss_set(&tss_key, (void*)tid) == 0);
return get_cache(tid++);
} else {
auto cache_tid = (uintptr_t)id;
return get_cache(cache_tid);
}
#else
auto* id = PyThread_get_key_value(tss_key);
if (id == NULL) {
mgb_assert(PyThread_set_key_value(tss_key, (void*)tid) == 0);
return get_cache(tid++);
} else {
auto cache_tid = (uintptr_t)id;
return get_cache(cache_tid);
}
#endif
}
void FrameInfoCache::update_cache(
int key,
const SmallVector<std::pair<PyFrameObject*, FrameInfoPtr>, 100>& frames) {
stack_cache.resize(key + frames.size() + 1);
auto it = frames.rbegin();
auto cur_key = key + 1;
for (; it != frames.rend(); it++, cur_key++) {
auto&& [frame, finfo] = *it;
stack_cache[cur_key] = finfo;
if (auto* key_ptr = TraceKeyWrapper::try_cast(frame->f_trace)) {
key_ptr->key = cur_key;
} else {
auto* py_key = TraceKeyWrapper::make(cur_key, frame->f_trace);
frame->f_trace = py_key;
}
}
}
int FrameInfoCache::get_frame_key(PyFrameObject* frame) {
auto* key = TraceKeyWrapper::try_cast(frame->f_trace);
if (key == nullptr) {
return -1;
} else {
return key->key;
}
}
FrameInfoPtr get_frameinfo_from_pyframe(PyFrameObject* frame) {
auto* cache = FrameInfoCache::get_instance();
auto&& [cur_info, key] = FrameInfo::make(frame, cache);
auto rst = cur_info;
SmallVector<std::pair<PyFrameObject*, FrameInfoPtr>, 100> frames;
py::object cur_frame = py::reinterpret_borrow<py::object>((PyObject*)frame);
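    // Walk f_back towards the stack root, stopping as soon as a frame with a
    // cached key is found; freshly visited frames are collected so update_cache()
    // can assign them keys. Generator frames (f_gen != NULL) are left out of the
    // key-assignment list, since they can be resumed from a different call stack.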
while (key == -1) {
if (((PyFrameObject*)cur_frame.ptr())->f_gen == NULL)
frames.push_back({(PyFrameObject*)cur_frame.ptr(), cur_info});
auto prev_frame = py::getattr(py::handle(cur_frame), "f_back");
if (prev_frame.is_none())
break;
auto&& [prev_info, prev_key] =
FrameInfo::make((PyFrameObject*)prev_frame.ptr(), cache);
cur_info->prev_frame = prev_info;
cur_info = prev_info, key = prev_key;
cur_frame = std::move(prev_frame);
}
if (cache != nullptr)
cache->update_cache(key, frames);
return rst;
}
bool set_python_backtrace_enabled(bool enabled) {
std::swap(enable_py_bt, enabled);
return enabled;
}
bool set_transformation_backtrace_enabled(bool enabled) {
std::swap(enable_trans_bt, enabled);
return enabled;
}
void record_py_backtrace() {
auto& context = Transformation::get_context();
FrameInfoPtr info;
if (enable_py_bt) {
auto frame = PyEval_GetFrame();
info = get_frameinfo_from_pyframe(frame);
}
context.bt = std::make_shared<BackTraceInfo>(std::move(info));
context.record_bt_trans_id = context.next_transformation;
context.record_trans_bt = enable_trans_bt;
}
void record_scope(PyFrameObject* frame, std::string scope) {
if (enable_py_bt) {
frame->f_trace = TraceKeyWrapper::make(-1, frame->f_trace, std::move(scope));
}
}
std::string get_py_backtrace() {
auto frame = PyEval_GetFrame();
return get_frameinfo_from_pyframe(frame)->traceback();
}
} // namespace mgb::imperative::python
\ No newline at end of file
#pragma once
#include <Python.h>
#include <frameobject.h>
#include <cstdint>
#include <memory>
#include <string>
#include "./helper.h"
#include "./pyext17.h"
#include "megbrain/common.h"
#include "megbrain/imperative/backtrace.h"
#include "megbrain/utils/metahelper.h"
#include "megbrain/utils/small_vector.h"
#include "pybind11/pybind11.h"
namespace py = pybind11;
namespace mgb::imperative::python {
struct FrameInfoCache;
struct FrameInfo;
using FrameInfoPtr = std::shared_ptr<FrameInfo>;
struct FrameInfo : public PyFrameInfo, public std::enable_shared_from_this<FrameInfo> {
PyCodeObject* code_obj;
int lineno;
std::string scope;
std::shared_ptr<FrameInfo> prev_frame;
static std::unordered_map<ptrdiff_t, PyObjRefKeeper> code_ref_keeper;
FrameInfo(PyCodeObject* code, int lineno) : code_obj{code}, lineno{lineno} {
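        // Pin the code object: a FrameInfo may outlive the Python frame it was
        // taken from, so take one reference per code object and keep it in
        // code_ref_keeper so the raw PyCodeObject* above stays valid.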
if (code_ref_keeper.find((ptrdiff_t)code_obj) == code_ref_keeper.end()) {
Py_INCREF(code);
code_ref_keeper[(ptrdiff_t)code_obj] = {(PyObject*)code_obj};
}
}
std::string traceback() override;
static std::pair<FrameInfoPtr, int> make(
PyFrameObject* frame, FrameInfoCache* cache);
};
struct FrameInfoCache {
std::vector<FrameInfoPtr> stack_cache;
void update_cache(
int key,
const SmallVector<std::pair<PyFrameObject*, FrameInfoPtr>, 100>& frames);
size_t size() { return stack_cache.size(); }
FrameInfoPtr& operator[](int key) { return stack_cache[key]; }
static int get_frame_key(PyFrameObject* frame);
static FrameInfoCache* get_instance();
};
struct TraceKeyWrapper {
int key;
std::string scope;
py::object orig_func;
TraceKeyWrapper(int key, PyObject* func, std::string scope = "")
: key{key}, scope{std::move(scope)} {
if (func != NULL) {
orig_func = py::reinterpret_steal<py::object>(func);
}
}
static constexpr auto tp_name = pybind11::detail::_("TraceKeyWrapper");
using wrap_t = pyext17::wrap<TraceKeyWrapper>;
friend wrap_t;
inline static TraceKeyWrapper* cast(PyObject* obj) {
return reinterpret_cast<wrap_t*>(obj)->inst();
}
inline static TraceKeyWrapper* try_cast(PyObject* obj) {
if (obj == NULL || !wrap_t::type().isinstance(obj))
return nullptr;
return cast(obj);
}
template <typename... Args>
static PyObject* make(Args&&... args) {
return wrap_t::cnew(std::forward<Args>(args)...);
}
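    // f_trace is hijacked to carry the cache key and scope name; actual trace
    // calls are delegated to the original trace function (if any) so that user
    // tracing keeps working.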
PyObject* tp_call(PyObject* args, PyObject* kwargs) {
if (orig_func.ptr() != nullptr) {
return PyObject_Call(orig_func.ptr(), args, kwargs);
}
Py_RETURN_NONE;
}
};
FrameInfoPtr get_frameinfo_from_pyframe(PyFrameObject* frame);
void record_py_backtrace();
void record_scope(PyFrameObject*, std::string);
std::string get_py_backtrace();
bool set_python_backtrace_enabled(bool);
bool set_transformation_backtrace_enabled(bool);
void init_backtrace_tss_key();
} // namespace mgb::imperative::python
#include "megbrain/common.h" #include "megbrain/common.h"
#include "megbrain/dtype.h" #include "megbrain/dtype.h"
#include "megbrain/imperative/backtrace.h"
#include "megbrain/imperative/cpp_cupti.h" #include "megbrain/imperative/cpp_cupti.h"
#include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/backward_graph.h" #include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/utility.h" #include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/profiler.h" #include "megbrain/imperative/profiler.h"
#include "megbrain/imperative/transformation.h"
#include "megbrain/imperative/transformations/dim_expansion.h" #include "megbrain/imperative/transformations/dim_expansion.h"
#include "megbrain/imperative/transformations/dtype_promote.h" #include "megbrain/imperative/transformations/dtype_promote.h"
#include "megbrain/imperative/transformations/eval.h" #include "megbrain/imperative/transformations/eval.h"
...@@ -42,6 +44,7 @@ ...@@ -42,6 +44,7 @@
#include <unordered_map> #include <unordered_map>
#include "../../src/impl/mgb_cg_impl.h" #include "../../src/impl/mgb_cg_impl.h"
#include "./backtrace.h"
namespace py = pybind11; namespace py = pybind11;
namespace views = ranges::views; namespace views = ranges::views;
@@ -97,6 +100,7 @@ PyObject* py_apply(
            }
            HostTensorND ht(target_cn);
            ht = npy::np2tensor(args[i], npy::Meth::copy_into(&ht), target_dtype);
            record_py_backtrace();
            if (PyArray_Check(args[i]) || PyList_Check(args[i])) {  // non scalar
                // py_tuple is not allowed here because of tracing
                return imperative::apply(
@@ -125,7 +129,7 @@ PyObject* py_apply(
                return nullptr;
            }
        }
        record_py_backtrace();
        auto outputs = [&] { return imperative::apply(*op, tensors); }();
        size_t nout = outputs.size();
        auto ret = py::tuple(nout);
@@ -137,6 +141,11 @@ PyObject* py_apply(
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

FrameInfoPtr get_current_frameinfo() {
    auto frame = PyEval_GetFrame();
    auto frameinfo = get_frameinfo_from_pyframe(frame);
    return frameinfo;
}

namespace {
@@ -740,6 +749,7 @@ struct TensorWeakRef {
        auto size = PyTuple_GET_SIZE(args); \
        return FUNC(self, arr, size); \
    }

WRAP_FUNC_PY35(py_apply);
WRAP_FUNC_PY35(dtype_promotion);
WRAP_FUNC_PY35(get_device);
@@ -768,7 +778,7 @@ WRAP_FUNC_PY35(pixel_shuffle_cpp);
void init_tensor(py::module m) {
    imperative::Tensor::static_initialize();
    init_backtrace_tss_key();

    // Transformations
    static auto& transformations = TransformationManager::get_instance();
@@ -866,6 +876,10 @@ void init_tensor(py::module m) {
    if (!tensor_type)
        throw py::error_already_set();
    py::setattr(m, "Tensor", tensor_type);

    auto* tracekey_type = TraceKeyWrapper::wrap_t::type().finalize();
    py::setattr(m, "tracekey", tracekey_type);

    py::enum_<Format::Type>(m, "FormatType")
            .value("DEFAULT", Format::Type::DEFAULT)
            .value("NCHW", Format::Type::NCHW)
@@ -923,6 +937,10 @@ void init_tensor(py::module m) {
        Transformation::push_scope(name);
        channel->push_scope(name);
    });
    m.def("record_scope", [](py::object frame, std::string name) {
        mgb_assert(PyFrame_Check(frame.ptr()));
        record_scope((PyFrameObject*)frame.ptr(), std::move(name));
    });
    m.def("pop_scope", [channel](std::string name) {
        channel->pop_scope(name);
        Transformation::pop_scope(name);
@@ -1298,7 +1316,12 @@ void init_tensor(py::module m) {
    m.def("unset_module_tracing", [=] { get_module_trace()->disable(); });
    m.def("is_tracing_module", [=] { return get_module_trace()->enabled(); });
    m.def("set_python_backtrace_enabled", &set_python_backtrace_enabled);
    m.def("set_transformation_backtrace_enabled", &set_transformation_backtrace_enabled);
    m.def("_mge_backtrace", &get_py_backtrace);
    m.def("_get_frame_cache_id", []() { return (size_t)FrameInfoCache::get_instance(); });
    m.def("set_module_trace_hook", [](py::function function) {
        module_trace_hook = function;
        module_trace_hook.inc_ref();
@@ -1306,7 +1329,6 @@ void init_tensor(py::module m) {
    auto atexit = py::module::import("atexit");
    atexit.attr("register")(py::cpp_function([]() { module_trace_hook = {}; }));
    m.def("begin_record_values", [] { Value::begin_record_values(); });
    m.def("end_record_values", [] {
...
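
With these bindings in place, the current Python backtrace can also be queried directly from Python. A small sketch (note that _mge_backtrace and _get_frame_cache_id are underscore-prefixed internals and may change):

from megengine.core._imperative_rt.core2 import _mge_backtrace

def f():
    # one "<co_filename> , line N, in <co_name>" entry per frame,
    # outermost frame first
    print(_mge_backtrace())

f()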
# -*- coding: utf-8 -*-
import json
import os
import sys
import tempfile
import threading

import pytest

from megengine import Parameter
from megengine import distributed as dist
from megengine import tensor
from megengine.core._imperative_rt.core2 import _get_frame_cache_id, _mge_backtrace
from megengine.jit import trace
from megengine.module import Module
from megengine.utils.profiler import Profiler, scope
@@ -97,3 +100,43 @@ def test_profiler_dist(format, trace_mode):
    assert os.path.exists(profile_path), "profiling results not found"
    assert len(os.listdir(tempdir.name)) == n_gpus + 1

def test_backtrace():
    bts = []
    cur_frame = sys._getframe(0)
    start_lineno = cur_frame.f_lineno
    cur_frame.f_trace = lambda *_: 1

    def gen():
        for i in range(10):
            bts.append(_mge_backtrace())
            yield i

    ge = gen()
    next(ge)
    next(ge)

    def func():
        return next(ge)

    func()
    bt_lines = []
    for bt in bts:
        lines = []
        for path in bt.split("\n")[-3:-1]:
            lines.append(int(path.split(",")[1].split()[1]))
        bt_lines.append([i - start_lineno for i in lines])

    # test if backtrace line numbers are correct
    assert cur_frame.f_trace(0, 0, 0) == 1
    assert bt_lines == [[9, 5], [10, 5], [13, 5]]

    cache_id = []

    def get_cache_id():
        cache_id.append(_get_frame_cache_id())

    threads = [threading.Thread(target=get_cache_id) for i in range(10)]
    for t in threads:
        t.start()
        t.join()
    assert len(set(cache_id)) == 10
@@ -3,18 +3,39 @@
#include "megbrain/imperative/utils/debug.h"
#include "megbrain/imperative/utils/helper.h"
#include "megbrain/imperative/utils/map.h"

namespace mgb {
namespace imperative {
namespace {

ValueRefList apply_release(const Operator& op, Span<ValueRef> inputs) {
    auto& context = Transformation::get_context();
    ValueRefList result;
    size_t& depth = context.next_transformation;
    mgb_assert(depth < context.transformations.size());
    auto& transformation = *context.transformations[depth++];
    CleanupGuard _{[&] { --depth; }};
    if (context.bt != nullptr && context.record_trans_bt) {
        std::vector<std::string> types;
        for (size_t i = 0; i < inputs.size(); i++) {
            types.push_back(inputs[i].raw_type());
        }
        context.bt->trans_stack_info.push_back(TransformationCallInfo{
                depth, op.raw_type(), transformation.name(),
                TransformationCallInfo::get_op_attr(op), std::move(types)});
    }
    result = transformation.apply_transformation(op, inputs);
    if (context.bt != nullptr && context.record_trans_bt) {
        std::vector<std::string> types;
        for (size_t i = 0; i < result.size(); i++) {
            types.push_back(result[i].raw_type());
        }
        context.bt->trans_stack_info.push_back(
                TransformationReturnInfo{depth, std::move(types)});
    }
    if (depth - 1 == context.record_bt_trans_id) {
        context.bt = nullptr;
    }
    return result;
}

MGB_NOINLINE ValueRefList apply_debug(const Operator& op, Span<ValueRef> inputs) {
...
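
Note the lifetime handling above: record_py_backtrace() installs context.bt at the start of a top-level dispatch and stores the depth at that moment in record_bt_trans_id, and the `depth - 1 == context.record_bt_trans_id` check releases the BackTraceInfo once dispatch unwinds back to that depth, so each recorded trace covers exactly one outermost apply.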
@@ -4,6 +4,7 @@
#include <unordered_set>
#include <variant>

#include "megbrain/imperative/backtrace.h"
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/utils/to_string.h"
#include "megbrain/tensor.h"
@@ -39,6 +40,7 @@ struct ApplyOp {
    SmallVector<TensorInfo*> inputs;
    SmallVector<TensorInfo*> outputs;
    bool validated = false;
    BackTraceInfoPtr bt = nullptr;

    template <typename TFunctor>
    void get_props(TFunctor&& functor) const {
...
@@ -13,7 +13,7 @@
#include "../blob_manager_impl.h"
#include "../event_pool.h"
#include "../op_trait.h"
#include "megbrain/imperative/backtrace.h"

using namespace mgb;
using namespace imperative;
using namespace interpreter;
@@ -328,12 +328,19 @@ void ChannelImpl::dispatch_default_cpu(
        output_infos.push_back(info);
        outputs->push_back(reinterpret_cast<Handle>(info));
    }
    auto& bt = get_backtrace();
    auto op_info_getter = [op, bt] {
        std::unordered_map<std::string, std::string> op_info;
        auto props = OpDef::props(*op);
        for (auto&& [key, value] : props) {
            op_info[key] = value;
        }
        if (bt != nullptr) {
            if (bt->py_stack_info != nullptr)
                op_info["python_backtrace"] = bt->py_traceback();
            if (bt->trans_stack_info.size() > 0)
                op_info["transformation_backtrace"] = bt->transformation_traceback();
        }
        return op_info;
    };
    MGB_RECORD_EVENT(
@@ -374,16 +381,23 @@ void ChannelImpl::dispatch_kernel(
        output_infos.push_back(info);
        outputs->push_back(reinterpret_cast<Handle>(info));
    }
    auto& bt = get_backtrace();
    ApplyOp cmd{Profiler::next_id(), std::move(op), std::move(input_infos),
                std::move(output_infos), validated, bt};
    if (Profiler::is_profiling()) {
        auto op_info_getter = [op = cmd.op, bt = cmd.bt] {
            std::unordered_map<std::string, std::string> op_info;
            auto props = OpDef::props(*op);
            for (auto&& [key, value] : props) {
                op_info[key] = value;
            }
            if (bt != nullptr) {
                if (bt->py_stack_info != nullptr)
                    op_info["python_backtrace"] = bt->py_traceback();
                if (bt->trans_stack_info.size() > 0)
                    op_info["transformation_backtrace"] =
                            bt->transformation_traceback();
            }
            return op_info;
        };
        MGB_RECORD_EVENT(
@@ -1024,7 +1038,18 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
    mgb_assert(!m_waitee, "duplicate waitee");
    m_waitee = info;
    m_waitee_id = Profiler::next_id();
    auto backtrace_getter = [bt = get_backtrace()]() {
        std::unordered_map<std::string, std::string> infos;
        if (bt != nullptr) {
            if (bt->py_stack_info != nullptr)
                infos["python_backtrace"] = bt->py_traceback();
            if (bt->trans_stack_info.size() > 0)
                infos["transformation_backtrace"] = bt->transformation_traceback();
        }
        return infos;
    };
    MGB_RECORD_EVENT(
            TensorWaitPropEvent, info->id, m_waitee_id, prop, backtrace_getter);
    bool require_host = prop == TensorProp::HostValue;
    bool require_dev = prop == TensorProp::DevValue;
    auto host_available = [&] { return info->ptr && info->ptr->value_fetched(); };
@@ -1073,7 +1098,8 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
            return require_host ? host_available() : static_cast<bool>(info->ptr);
        });
    }
    MGB_RECORD_EVENT(
            TensorWaitPropFinishEvent, info->id, m_waitee_id, prop, backtrace_getter);
    m_waitee = nullptr;
    if (wait_host) {
        auto err = info->ptr->comp_node().check_async_error();
@@ -1446,6 +1472,18 @@ void ChannelImpl::pop_scope(std::string name) {
    }
}

BackTraceInfoPtr& ChannelImpl::get_backtrace() {
    return m_bt;
}

void ChannelImpl::set_backtrace(BackTraceInfoPtr bt) {
    m_bt = std::move(bt);
}

void ChannelImpl::clear_backtrace() {
    m_bt = nullptr;
}

bool ChannelImpl::worker_started() const {
    return m_worker.worker_started();
}
...
@@ -18,6 +18,7 @@
#include "./tensor_info.h"

#include "../profiler/events.h"
#include "megbrain/imperative/backtrace.h"

namespace mgb::imperative::interpreter::intl {
@@ -66,6 +67,10 @@ struct ChannelImpl : Interpreter::Channel, NonCopyableObj, NonMoveableObj {
    void push_scope(std::string) override;
    void pop_scope(std::string) override;

    BackTraceInfoPtr& get_backtrace() override;
    void set_backtrace(BackTraceInfoPtr bt) override;
    void clear_backtrace() override;

    bool worker_started() const;
    void update_status_to_forked(void);
    void assert_available() const;
@@ -133,6 +138,7 @@ private:
    MemPool<TensorInfo> m_pool;
    std::unordered_set<Handle> m_valid_handle;
    TensorInfo* m_waitee = nullptr;
    BackTraceInfoPtr m_bt = nullptr;
    Spinlock m_pool_spin;
    Spinlock m_info_spin;
    uint64_t m_waitee_id = 0;
...
@@ -106,6 +106,10 @@ std::string OpDef::to_string() const {
    return builder + "}";
}

std::string OpDef::name() const {
    return trait()->name;
}

size_t OpDef::hash() const {
    return trait()->hash(*this);
}
...
@@ -317,8 +317,12 @@ struct ChromeTimelineEventVisitor : EventVisitor<ChromeTimelineEventVisitor> {
                    .bp('e')
                    .cat("TensorProp")
                    .scope(pid_str);
            auto args = current_tensor->detail(current->time);
            auto params = event.param();
            for (auto&& [name, value] : params) {
                args[name] = value;
            }
            new_host_event("TensorWaitProp", 'E').args(args);
        } else if constexpr (std::is_same_v<TEvent, TensorNotifyPropEvent>) {
            new_host_event(pid_str, 's')
                    .id(event.tensor_id)
...
@@ -126,6 +126,7 @@ DEF_DUR_EVENT(TensorWaitProp, {
    uint64_t tensor_id;
    uint64_t wait_id;
    TensorProp prop;
    std::function<OpParams()> param;
});

DEF_DUR_EVENT(SampleDevice, {
...
@@ -40,6 +40,7 @@ ValueRefList InterpreterTransformation::apply_op(
    for (auto input : inputs) {
        input_handles.push_back(input.cast(m_value_type).handle()->handle());
    }
    m_channel->set_backtrace(Transformation::get_context().bt);
    output_handles =
            m_channel->apply_op(apply_op.op().shared_from_this(), input_handles);
    ValueRefList outputs(output_handles.size());
@@ -48,6 +49,7 @@ ValueRefList InterpreterTransformation::apply_op(
        output_handles[i] = nullptr;
    }
    output_handles.clear();
    m_channel->clear_backtrace();
    return outputs;
}
@@ -55,6 +57,7 @@ ValueRefList InterpreterTransformation::apply_get_attr(
        const GetAttr& get_attr, Span<ValueRef> inputs) {
    auto& input = inputs.item().cast(m_value_type);
    ValueRef output;
    m_channel->set_backtrace(Transformation::get_context().bt);
    switch (get_attr.attr()) {
        case GetAttr::DType:
            output = input.dtype();
@@ -77,6 +80,7 @@ ValueRefList InterpreterTransformation::apply_get_attr(
                    MegBrainError, "Interpreter: malformed GetAttr: %s",
                    get_attr.to_string().c_str());
    }
    m_channel->clear_backtrace();
    return {output};
}
...
@@ -2,7 +2,6 @@
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"

namespace mgb {
namespace imperative {
...
@@ -110,7 +110,14 @@ std::string ValueRef::raw_type() const {
    if (!m_storage) {
        return "null";
    }
    return this->storage()->type().name();
}

const IType* ValueRef::type() const {
    if (!m_storage) {
        return nullptr;
    }
    return &m_storage->type();
}

bool ValueRef::watching() const {
...
#pragma once
#include <memory>
#include <string>
#include <variant>
#include <vector>
#include "./basic_operators.h"
#include "./operator.h"
#include "./value.h"
#include "megbrain/common.h"
namespace mgb::imperative {
struct BackTraceInfo;
using BackTraceInfoPtr = std::shared_ptr<BackTraceInfo>;
struct PyFrameInfo {
virtual std::string traceback() = 0;
virtual ~PyFrameInfo() {}
};
using PyFrameInfoPtr = std::shared_ptr<PyFrameInfo>;
using OpAttrInfo = std::variant<std::monostate, std::string, GetAttr::Attr>;
struct TransformationCallInfo {
size_t depth;
std::string op;
std::string transform;
OpAttrInfo attrs;
std::vector<std::string> inp_types;
std::string to_string() {
static const char tabs[] = "\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t";
const char* prefix = tabs + (sizeof(tabs) / sizeof(char)) - depth - 1;
std::string inps = "";
for (auto i : inp_types) {
inps += i + ", ";
}
std::string opinfo = op;
std::visit(
[&opinfo](auto&& i) {
using T = std::decay_t<decltype(i)>;
if constexpr (std::is_same_v<T, std::string>) {
opinfo += "(" + i + ")";
} else if constexpr (std::is_same_v<T, GetAttr::Attr>) {
switch (i) {
case GetAttr::Attr::Data:
opinfo += "(data)";
break;
case GetAttr::Attr::Shape:
opinfo += "(shape)";
break;
case GetAttr::Attr::DType:
opinfo += "(dtype)";
break;
case GetAttr::Attr::Device:
opinfo += "(device)";
break;
case GetAttr::Attr::Value:
opinfo += "(value)";
break;
case GetAttr::Attr::None:
opinfo += "(none)";
break;
default:
break;
}
}
},
attrs);
return ssprintf(
"%s %s: Apply (%s, %s)", prefix, transform.c_str(), opinfo.c_str(),
inps.c_str());
}
static OpAttrInfo get_op_attr(const Operator& op) {
if (op.is<GetAttr>()) {
return op.as<GetAttr>()->attr();
} else if (op.is<ApplyOp>()) {
auto& opdef = op.as<ApplyOp>()->op();
return opdef.name();
} else {
return {};
}
}
};
struct TransformationReturnInfo {
size_t depth;
std::vector<std::string> return_types;
std::string to_string() {
static const char tabs[] = "\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t";
const char* prefix = tabs + (sizeof(tabs) / sizeof(char)) - depth - 1;
std::string returns = "";
for (auto i : return_types) {
returns += i + ", ";
}
return ssprintf("%s return: %s", prefix, returns.c_str());
}
};
struct BackTraceInfo {
std::vector<std::variant<TransformationCallInfo, TransformationReturnInfo>>
trans_stack_info;
PyFrameInfoPtr py_stack_info;
BackTraceInfo(PyFrameInfoPtr info) : py_stack_info{std::move(info)} {}
std::string py_traceback() {
return "Python Backtrace: " + py_stack_info->traceback();
}
std::string transformation_traceback() {
std::string trace_info = "Dispatch Transformation Backtrace: ";
for (auto&& i : trans_stack_info) {
std::visit(
[&trace_info](auto& i) { trace_info += "\n" + i.to_string(); }, i);
}
return trace_info;
}
};
} // namespace mgb::imperative
\ No newline at end of file
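
For reference, a transformation backtrace rendered by the to_string() methods above looks roughly like the following (the transformation and value-type names here are purely illustrative; the leading tabs mirror the dispatch depth):

Dispatch Transformation Backtrace:
	 Grad: Apply (ApplyOp(Elemwise), GradValue, GradValue, )
		 Interpreter: Apply (ApplyOp(Elemwise), Handle, Handle, )
		 return: Handle,
	 return: GradValue,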
@@ -28,6 +28,7 @@ public:
    const OpDef& op() const { return m_op; }

    std::string to_string() const override;
    std::string raw_type() const { return "ApplyOp"; }
};

/**
@@ -54,7 +55,7 @@ public:
    }

    Attr attr() const { return m_attr; }
    std::string raw_type() const { return "GetAttr"; }
    std::string to_string() const;
};
@@ -104,6 +105,7 @@ public:
    DType dtype() const { return m_dtype; }
    ValueShape shape() const { return m_shape; }
    Format format() const { return m_format; }
    std::string raw_type() const { return "CreateTensor"; }

    std::string to_string() const override;
};
@@ -122,6 +124,7 @@ public:
    DTRCommand(Kind kind) : m_kind(kind) {}

    Kind kind() const { return m_kind; }
    std::string raw_type() const { return "DTRCommand"; }

    std::string to_string() const override;
@@ -132,6 +135,7 @@ public:
class GetName final : public OperatorImpl<GetName, Operator::GetAttrLike> {
public:
    std::string to_string() const override;
    std::string raw_type() const { return "GetName"; }

    ValueRefList fallback(Span<ValueRef> inputs) const override { return {ValueRef()}; }
};
@@ -148,6 +152,7 @@ public:
    RenameValue(std::string name) : m_name(name) {}

    std::string name() const { return m_name; }
    std::string raw_type() const { return "RenameValue"; }

    std::string to_string() const override;
@@ -160,11 +165,13 @@ class IsScalar final : public OperatorImpl<IsScalar, Operator::GetAttrLike> {
private:
public:
    std::string to_string() const override;
    std::string raw_type() const { return "IsScalar"; }
};

class GetFormat final : public OperatorImpl<GetFormat, Operator::GetAttrLike> {
public:
    std::string to_string() const override;
    std::string raw_type() const { return "GetFormat"; }
};

class SetFormat final : public OperatorImpl<SetFormat, Operator::IdentityLike> {
@@ -175,6 +182,7 @@ public:
    SetFormat(std::string format) : m_format(format) {}

    Format format() const { return m_format; }
    std::string raw_type() const { return "SetFormat"; }
    std::string to_string() const override;
};
@@ -182,6 +190,7 @@ public:
class GetVarVal final : public OperatorImpl<GetVarVal, Operator::GetAttrLike> {
public:
    std::string to_string() const override;
    std::string raw_type() const { return "GetVarVal"; }
};

class CreateNode final : public OperatorImpl<CreateNode> {
@@ -192,6 +201,7 @@ public:
    CreateNode(cg::VarNode* node) : m_node(node) {}

    cg::VarNode* node() const { return m_node; }
    std::string raw_type() const { return "CreateNode"; }
    std::string to_string() const override;
};
@@ -199,6 +209,7 @@ public:
class DupTensor final : public OperatorImpl<DupTensor, Operator::IdentityLike> {
public:
    std::string to_string() const override { return "DupTensor"; }
    std::string raw_type() const { return "DupTensor"; }
};

} // namespace imperative
...
@@ -3,6 +3,7 @@
#include <any>
#include <atomic>

#include "./backtrace.h"
#include "megbrain/imperative/op_def.h"

namespace mgb::imperative::interpreter {
@@ -63,10 +64,13 @@ struct Interpreter {
        virtual void push_scope(std::string name) = 0;
        virtual void pop_scope(std::string name) = 0;
        virtual BackTraceInfoPtr& get_backtrace() = 0;
        virtual void set_backtrace(BackTraceInfoPtr bt) = 0;
        virtual void clear_backtrace() = 0;
    };

    virtual std::unique_ptr<Channel> create_channel() = 0;

    static Interpreter& inst();

protected:
...
@@ -74,6 +74,8 @@ public:
    std::string to_string() const;

    std::string name() const;

    const std::string scope() const;

    const std::string make_name() const;
...
@@ -61,7 +61,7 @@ public:
    }

    virtual std::string to_string() const = 0;
    virtual std::string raw_type() const = 0;

    /**
     * \brief fallback implementation of this. Not all operators have a fallback
     * implementation.
     */
@@ -86,6 +86,7 @@ public:
    static inline size_t TYPE_CODE = [] { return register_type(typeid(T)); }();

    std::string to_string() const override = 0;
    std::string raw_type() const override = 0;
};

} // namespace imperative
...
@@ -6,6 +6,7 @@
#include <vector>

#include "megbrain/common.h"
#include "megbrain/imperative/backtrace.h"
#include "megbrain/imperative/subgraph.h"
#include "megbrain/imperative/utils/allocator.h"
#include "megbrain/imperative/utils/local_ptr.h"
@@ -35,6 +36,9 @@ struct TransformationContext {
    size_t next_transformation = 0;
    std::vector<TransformationFrame> frames;
    ForwardAllocator<ValueRef> allocator;
    size_t record_bt_trans_id;
    bool record_trans_bt = false;
    BackTraceInfoPtr bt;
};

/**
...
@@ -318,7 +318,7 @@ private:
    // TODO: identified by GradKey
public:
    std::string to_string() const override { return "DetachValue"; }
    std::string raw_type() const override { return "DetachGrad"; }

    ValueRefList fallback(Span<ValueRef> inputs) const override {
        return {inputs.as_array<1>()[0]};
    }
@@ -335,6 +335,8 @@ public:
    std::string to_string() const override {
        return ssprintf("AttachGradValue{key=%s}", m_key->name().c_str());
    }

    std::string raw_type() const { return "AttachGrad"; }
};

class GradBackward : public OperatorImpl<GradBackward, Operator::GetAttrLike> {
@@ -349,6 +351,8 @@ public:
    std::string to_string() const override {
        return ssprintf("GradBackwardValue{key=%s}", m_key->name().c_str());
    }

    std::string raw_type() const { return "GradBackward"; }
};

class IsAttachedTo : public OperatorImpl<IsAttachedTo, Operator::GetAttrLike> {
@@ -363,6 +367,8 @@ public:
        return ssprintf("IsAttachedToValue{key=%s}", m_key->name().c_str());
    }

    std::string raw_type() const { return "IsAttachedTo"; }

    ValueRefList fallback(Span<ValueRef> inputs) const override {
        return {BoolValue::make(false)};
    }
@@ -382,7 +388,7 @@ public:
    size_t nr_inputs() const { return m_nr_inputs; }

    std::string to_string() const override { return ssprintf("SetGradValue{}"); }
    std::string raw_type() const { return "SetGrad"; }

    ValueRefList fallback(Span<ValueRef> inputs) const override {
        auto outputs = inputs.sub(m_nr_inputs, inputs.size() - m_nr_inputs);
        return {outputs.begin(), outputs.end()};
@@ -394,7 +400,7 @@ public:
    GetGradKey() = default;

    std::string to_string() const override { return ssprintf("GetGradKeyValue{}"); }
    std::string raw_type() const { return "GetGradKey"; }
    ValueRefList fallback(Span<ValueRef> inputs) const override { return {ValueRef()}; }
};
@@ -411,6 +417,7 @@ public:
    std::string to_string() const override {
        return ssprintf("GetBackwardClosure{key=%s}", m_key->name().c_str());
    }
    std::string raw_type() const { return "GetBackwardClosure"; }
};

} // namespace mgb::imperative
@@ -87,6 +87,8 @@ public:
    std::string to_string() const override {
        return ssprintf("TraceMarkVar{mark=%s}", imperative::quoted(m_mark).c_str());
    }

    std::string raw_type() const { return "TraceMarkVar"; }
};

class TracingValue final : public ObjectValue<TracingValue> {
...
@@ -231,6 +231,7 @@ public:
    ValueRef unwrap() const;
    std::string to_string() const;
    std::string raw_type() const;
    const IType* type() const;
    uint64_t id() const { return m_id; }
    size_t hash() const { return id(); }
...