diff --git a/imperative/python/src/module_trace.h b/imperative/python/src/module_trace.h index 4102dd3a3002d7bdf22628a5e3fd4cce3d594b03..c9f195a8d682562162474c2cf26e8fbe881ede89 100644 --- a/imperative/python/src/module_trace.h +++ b/imperative/python/src/module_trace.h @@ -30,12 +30,22 @@ private: } public: + inline static WeakKeyMap module_trace_info_map; ModuleTraceTransformation(py::function hook_fn) : m_hook_fn(hook_fn) {} ValueRefList apply_transformation( const Operator& op, Span inputs) override { if (op.is() && m_enabled > 0) { auto outputs = apply_module_trace_hook(op.cast().op(), inputs); return outputs; + } else if (op.is()) { + auto outputs = imperative::apply(op, inputs); + if (auto module_trace_info = module_trace_info_map.try_get(inputs[0])) { + if (module_trace_info->ptr()) { + auto node = module_trace_info.value(); + module_trace_info_map[outputs[0]] = module_trace_info.value(); + } + } + return outputs; } else { return imperative::apply(op, inputs); } diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index e76ca5a3208107b0bc12d07cfa74cc5b75fac5e5..e24bd598c99f9742fb4cc6f90509ef474259460b 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -47,10 +47,6 @@ namespace views = ranges::views; namespace mgb::imperative::python { -namespace { -WeakKeyMap module_trace_info_map; -} // namespace - interpreter::Interpreter::Channel* interpreter_for_py = nullptr; PyTypeObject* py_tensor_type = nullptr; PyTypeObject* py_varnode_type = nullptr; @@ -594,7 +590,9 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { } PyObject* TensorWrapper::module_trace_info() { - if (auto module_trace_info = module_trace_info_map.try_get(m_tensor->data())) { + if (auto module_trace_info = + ModuleTraceTransformation::module_trace_info_map.try_get( + m_tensor->data())) { if (module_trace_info->ptr()) { return module_trace_info->inc_ref().ptr(); } @@ -608,7 +606,8 @@ PyObject* TensorWrapper::module_trace_info() { void TensorWrapper::set_module_trace_info(PyObject* obj) { // TODO: erase when obj == nullptr - module_trace_info_map[m_tensor->data()] = py::reinterpret_borrow(obj); + ModuleTraceTransformation::module_trace_info_map[m_tensor->data()] = + py::reinterpret_borrow(obj); } void TensorWrapper::_set_format(PyObject* dest) { @@ -620,6 +619,7 @@ void TensorWrapper::_set_format(PyObject* dest) { void TensorWrapper::_set_name(PyObject* dest) { auto py_dest = py::reinterpret_borrow(dest); auto name = py_dest.cast(); + m_tensor->set_name(name); } diff --git a/imperative/python/test/unit/traced_module/test_trace_module.py b/imperative/python/test/unit/traced_module/test_trace_module.py index 43d3f492eb4782364ae0a8cdf347342b5543c20b..5bf42c876c117e61797046c4fdb1b022cf0d9a61 100644 --- a/imperative/python/test/unit/traced_module/test_trace_module.py +++ b/imperative/python/test/unit/traced_module/test_trace_module.py @@ -9,7 +9,7 @@ from megengine.core._imperative_rt.core2 import apply from megengine.core.ops import builtin from megengine.module import Module from megengine.traced_module import TracedModule, enable_expr_checker, trace_module -from megengine.traced_module.expr import Apply, CallFunction, Constant +from megengine.traced_module.expr import Apply, CallFunction, CallMethod, Constant class MyModule1(M.Module): @@ -59,6 +59,14 @@ class MyModule4(M.Module): return self.add(x, y) +class MyModule5(M.Module): + def forward(self, x): + a = x + x + b = x * a + b.name = "result" + return b + + def test_trace_module(): enable_expr_checker() x = Tensor(1) @@ -157,3 +165,9 @@ def test_trace_module_2(): traced_model.graph._exprs[2].opdef, builtin.Elemwise ) assert int(traced_model(Tensor([1, 2]))[0]) == 3 + + +def test_rename(): + model = MyModule5() + tm_model = trace_module(model, Tensor(1)) + assert isinstance(tm_model.graph.outputs[0].expr, CallMethod)