提交 9d397727 编写于 作者: M Megvii Engine Team

fix(traced_module): fix bug of renaming tesnor

GitOrigin-RevId: 468a996bddf5af9820626e207cf18b6420262814
上级 31f31cef
......@@ -30,12 +30,22 @@ private:
}
public:
inline static WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map;
ModuleTraceTransformation(py::function hook_fn) : m_hook_fn(hook_fn) {}
ValueRefList apply_transformation(
const Operator& op, Span<ValueRef> inputs) override {
if (op.is<ApplyOp>() && m_enabled > 0) {
auto outputs = apply_module_trace_hook(op.cast<ApplyOp>().op(), inputs);
return outputs;
} else if (op.is<RenameValue>()) {
auto outputs = imperative::apply(op, inputs);
if (auto module_trace_info = module_trace_info_map.try_get(inputs[0])) {
if (module_trace_info->ptr()) {
auto node = module_trace_info.value();
module_trace_info_map[outputs[0]] = module_trace_info.value();
}
}
return outputs;
} else {
return imperative::apply(op, inputs);
}
......
......@@ -47,10 +47,6 @@ namespace views = ranges::views;
namespace mgb::imperative::python {
namespace {
WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map;
} // namespace
interpreter::Interpreter::Channel* interpreter_for_py = nullptr;
PyTypeObject* py_tensor_type = nullptr;
PyTypeObject* py_varnode_type = nullptr;
......@@ -594,7 +590,9 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
}
PyObject* TensorWrapper::module_trace_info() {
if (auto module_trace_info = module_trace_info_map.try_get(m_tensor->data())) {
if (auto module_trace_info =
ModuleTraceTransformation::module_trace_info_map.try_get(
m_tensor->data())) {
if (module_trace_info->ptr()) {
return module_trace_info->inc_ref().ptr();
}
......@@ -608,7 +606,8 @@ PyObject* TensorWrapper::module_trace_info() {
void TensorWrapper::set_module_trace_info(PyObject* obj) {
// TODO: erase when obj == nullptr
module_trace_info_map[m_tensor->data()] = py::reinterpret_borrow<py::object>(obj);
ModuleTraceTransformation::module_trace_info_map[m_tensor->data()] =
py::reinterpret_borrow<py::object>(obj);
}
void TensorWrapper::_set_format(PyObject* dest) {
......@@ -620,6 +619,7 @@ void TensorWrapper::_set_format(PyObject* dest) {
void TensorWrapper::_set_name(PyObject* dest) {
auto py_dest = py::reinterpret_borrow<py::object>(dest);
auto name = py_dest.cast<std::string>();
m_tensor->set_name(name);
}
......
......@@ -9,7 +9,7 @@ from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin
from megengine.module import Module
from megengine.traced_module import TracedModule, enable_expr_checker, trace_module
from megengine.traced_module.expr import Apply, CallFunction, Constant
from megengine.traced_module.expr import Apply, CallFunction, CallMethod, Constant
class MyModule1(M.Module):
......@@ -59,6 +59,14 @@ class MyModule4(M.Module):
return self.add(x, y)
class MyModule5(M.Module):
def forward(self, x):
a = x + x
b = x * a
b.name = "result"
return b
def test_trace_module():
enable_expr_checker()
x = Tensor(1)
......@@ -157,3 +165,9 @@ def test_trace_module_2():
traced_model.graph._exprs[2].opdef, builtin.Elemwise
)
assert int(traced_model(Tensor([1, 2]))[0]) == 3
def test_rename():
model = MyModule5()
tm_model = trace_module(model, Tensor(1))
assert isinstance(tm_model.graph.outputs[0].expr, CallMethod)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册