提交 cf5e9488 编写于 作者: M Megvii Engine Team

fix(traced_module): fix module trace transformation

GitOrigin-RevId: ce11fe5e093d89cd444595f673a849c16dadbe4a
上级 97c90d91
...@@ -606,7 +606,8 @@ class Apply(Expr): ...@@ -606,7 +606,8 @@ class Apply(Expr):
def apply_module_trace_hook(cls, opdef, *inputs): def apply_module_trace_hook(cls, opdef, *inputs):
for i in inputs: for i in inputs:
node = NodeMixin.get(i, None) node = NodeMixin.get(i, None)
assert node is not None if node is None: # capture as constant
NodeMixin.wrap_safe(i, Constant.make(i))
if isinstance(opdef, FakeQuant): if isinstance(opdef, FakeQuant):
inp_nodes = [NodeMixin.get(inputs[0])] inp_nodes = [NodeMixin.get(inputs[0])]
...@@ -627,7 +628,6 @@ class Apply(Expr): ...@@ -627,7 +628,6 @@ class Apply(Expr):
unset_module_tracing() unset_module_tracing()
outputs = apply(opdef, *inputs) outputs = apply(opdef, *inputs)
outputs = list(map(Tensor, outputs))
set_module_tracing() set_module_tracing()
apply_node.add_outputs(outputs) apply_node.add_outputs(outputs)
...@@ -741,12 +741,8 @@ class Constant(Expr): ...@@ -741,12 +741,8 @@ class Constant(Expr):
assert isinstance(c, (RawTensor, Module)) assert isinstance(c, (RawTensor, Module))
if isinstance(c, Module): if isinstance(c, Module):
assert module_tracer.is_builtin(c) or c.is_qat assert module_tracer.is_builtin(c) or c.is_qat
if isinstance(c, RawTensor): if type(c) is RawTensor:
if is_tracing_module(): with _exclude_from_trace():
unset_module_tracing()
c = Tensor(c)
set_module_tracing()
else:
c = Tensor(c) c = Tensor(c)
self.value = c self.value = c
self.name = name self.name = name
......
...@@ -52,6 +52,12 @@ public: ...@@ -52,6 +52,12 @@ public:
} }
} }
void enable() { m_enabled = 1; }
void disable() { m_enabled = 0; }
bool enabled() const { return m_enabled; }
ValueRef unwrap(ValueRef value) override { return value; } ValueRef unwrap(ValueRef value) override { return value; }
std::string name() const override { return "ModuleTraceTransformation"; } std::string name() const override { return "ModuleTraceTransformation"; }
......
...@@ -219,17 +219,19 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { ...@@ -219,17 +219,19 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
PyObject* TensorWrapper::module_trace_info() { PyObject* TensorWrapper::module_trace_info() {
if (auto module_trace_info = module_trace_info_map.try_get(m_tensor->data())) { if (auto module_trace_info = module_trace_info_map.try_get(m_tensor->data())) {
if (module_trace_info->ptr()) {
return module_trace_info->inc_ref().ptr(); return module_trace_info->inc_ref().ptr();
} else { }
}
PyErr_SetString( PyErr_SetString(
PyExc_AttributeError, PyExc_AttributeError,
"Has no attribute named \'_NodeMixin__node\', please " "Has no attribute named \'_NodeMixin__node\', please "
"set it first"); "set it first");
return nullptr; return nullptr;
}
} }
void TensorWrapper::set_module_trace_info(PyObject* obj) { void TensorWrapper::set_module_trace_info(PyObject* obj) {
// TODO: erase when obj == nullptr
module_trace_info_map[m_tensor->data()] = py::reinterpret_borrow<py::object>(obj); module_trace_info_map[m_tensor->data()] = py::reinterpret_borrow<py::object>(obj);
} }
...@@ -1031,29 +1033,23 @@ void init_tensor(py::module m) { ...@@ -1031,29 +1033,23 @@ void init_tensor(py::module m) {
static py::function module_trace_hook; static py::function module_trace_hook;
static auto get_module_trace = [] {
static std::shared_ptr<ModuleTraceTransformation> module_trace_transformation; static std::shared_ptr<ModuleTraceTransformation> module_trace_transformation;
static int module_tracing = 0;
m.def("set_module_tracing", [=] {
if (!module_trace_transformation) { if (!module_trace_transformation) {
mgb_assert(module_trace_hook); mgb_assert(module_trace_hook);
module_trace_transformation = module_trace_transformation =
std::make_shared<ModuleTraceTransformation>(module_trace_hook); std::make_shared<ModuleTraceTransformation>(module_trace_hook);
} transformations.register_at<Segment::ModuleTrace>(
if (++module_tracing == 1) {
transformations.register_at<TransformationManager::ModuleTrace>(
module_trace_transformation); module_trace_transformation);
} }
}); return module_trace_transformation;
};
m.def("unset_module_tracing", [=] { m.def("set_module_tracing", [=] { get_module_trace()->enable(); });
if (--module_tracing == 0) {
transformations.unregister<TransformationManager::ModuleTrace>( m.def("unset_module_tracing", [=] { get_module_trace()->disable(); });
module_trace_transformation);
}
});
m.def("is_tracing_module", [=] { return module_tracing > 0; }); m.def("is_tracing_module", [=] { return get_module_trace()->enabled(); });
m.def("set_module_trace_hook", m.def("set_module_trace_hook",
[](py::function function) { module_trace_hook = function; }); [](py::function function) { module_trace_hook = function; });
......
...@@ -5,9 +5,11 @@ import numpy as np ...@@ -5,9 +5,11 @@ import numpy as np
import megengine.functional as F import megengine.functional as F
import megengine.module as M import megengine.module as M
from megengine import Tensor from megengine import Tensor
from megengine.module.module import Module from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin
from megengine.module import Module
from megengine.traced_module import TracedModule, enable_expr_checker, trace_module from megengine.traced_module import TracedModule, enable_expr_checker, trace_module
from megengine.traced_module.expr import CallFunction from megengine.traced_module.expr import Apply, CallFunction, Constant
class MyModule1(M.Module): class MyModule1(M.Module):
...@@ -133,3 +135,25 @@ def test_trace_module(): ...@@ -133,3 +135,25 @@ def test_trace_module():
tm6 = trace_module(MyModule5(), a, b) tm6 = trace_module(MyModule5(), a, b)
assert tm6.m1.argspec is None assert tm6.m1.argspec is None
assert tm6.m1._is_top is False assert tm6.m1._is_top is False
def test_trace_module_2():
class Model(M.Module):
def __init__(self):
super().__init__()
def forward(self, x):
out = x.shape
out = apply(builtin.Elemwise(mode="ADD"), out, Tensor(1))
return out
traced_model = trace_module(Model(), Tensor(([1,])))
assert isinstance(traced_model.graph._exprs[0], Apply) and isinstance(
traced_model.graph._exprs[0].opdef, builtin.GetVarShape
)
assert isinstance(traced_model.graph._exprs[1], Constant)
assert isinstance(traced_model.graph._exprs[2], Apply) and isinstance(
traced_model.graph._exprs[2].opdef, builtin.Elemwise
)
assert int(traced_model(Tensor([1, 2]))[0]) == 3
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册