提交 ba8bd010 编写于 作者: M Megvii Engine Team

fix(mge/traced_module): fix insert module

GitOrigin-RevId: 755e1c68f60b0fc994eec56697d0515a8343e9f5
上级 b8316de5
......@@ -293,19 +293,10 @@ class _InsertExprs:
module = self.graph.inputs[0].owner
for mod, parent in module.modules(with_parent=True):
name = mod._name
if isinstance(mod, TracedModuleBuilder):
mod = mod.build()
if hasattr(mod, "argdef_graph_map"):
for g in mod.argdef_graph_map.values():
for n in g.nodes(False):
if isinstance(n, TensorNode):
n.value = None
setattr(parent, name, mod)
for node in self.global_scope.nodes(False):
node.value = None
for k, v in module.__dict__.items():
if isinstance(v, TracedModuleBuilder):
v = v.build()
setattr(module, k, v)
extra_inp_nodes = set(self.global_scope.inputs)
max_inp_expr_idx = -1
......@@ -334,6 +325,9 @@ class _InsertExprs:
self.graph._namespace.merge(self.global_scope._namespace)
self.root_graph._total_ids = (Node._get_next_id(), Expr._get_next_id())
self.root_graph.inputs[0].owner._update_ref()
for node in self.root_graph.nodes():
if isinstance(node, TensorNode):
node.value = None
return True
......@@ -1519,6 +1513,7 @@ class TracedModuleBuilder(NodeMixin):
return active_module_tracer().patcher.wrap_fn(attr)
if isinstance(attr, (List, Dict)):
flag = _set_convert_node_flag(False)
unset_module_tracing()
has_module, m_container = replace_container_with_module_container(attr)
if m_container:
......@@ -1529,6 +1524,7 @@ class TracedModuleBuilder(NodeMixin):
" Module and Non-Module objects."
)
set_module_tracing()
_set_convert_node_flag(flag)
if isinstance(attr, Module):
attr = TracedModuleBuilder(attr)
......
......@@ -16,7 +16,7 @@ import megengine.module as M
from megengine.module.identity import Identity
from megengine.traced_module import trace_module
from megengine.traced_module.expr import CallFunction, CallMethod, Expr, GetAttr, Input
from megengine.traced_module.node import ModuleNode, Node
from megengine.traced_module.node import ModuleNode, Node, TensorNode
class IdentityMod(M.Module):
......@@ -159,21 +159,44 @@ def test_insert():
def test_insert_module():
class Neg(M.Module):
def __init__(self, name):
super().__init__(name)
self.identity = M.Identity()
self.identity_list = [M.Identity(), M.Identity()]
self.identity_dict = {"0": M.Identity(), "1": M.Identity()}
self.param = F.zeros((1,))
def forward(self, x):
return F.neg(x)
x = self.identity(x)
for m in self.identity_dict:
x = self.identity_dict[m](x)
for m in self.identity_list:
x = m(x)
return F.neg(x) + self.param
traced_module, x, expect = _init_block()
graph = traced_module.graph
relu_out = graph.get_function_by_type(F.relu).as_unique().outputs[0]
self = graph.inputs[0]
setattr(traced_module, "neg", Neg())
setattr(traced_module, "neg", Neg(name="neg"))
setattr(traced_module, "neg2", Neg(name="neg"))
setattr(traced_module, "param", F.zeros((1,)))
with graph.insert_exprs():
neg_out = self.neg(relu_out)
neg_out = self.neg2(relu_out)
neg_out = neg_out + self.param
graph.replace_node({relu_out: neg_out})
graph.compile()
np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)
assert traced_module.neg.graph is not None
assert len(traced_module.neg.graph._exprs) == 1
assert traced_module.neg2.graph is not None
assert traced_module.neg2.param is not None
assert len(traced_module.neg.graph._exprs) == 13
for n in traced_module.graph.nodes():
if isinstance(n, TensorNode):
assert n.value is None
def test_add_input_and_output():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册