From ba8bd01023bd3c08a172c943925698c1206d4fc2 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 1 Nov 2021 11:59:56 +0800 Subject: [PATCH] fix(mge/traced_module): fix insert module GitOrigin-RevId: 755e1c68f60b0fc994eec56697d0515a8343e9f5 --- .../megengine/traced_module/traced_module.py | 22 ++++++------- .../unit/traced_module/test_modification.py | 31 ++++++++++++++++--- 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index 7e1a57f7a..591b60d3f 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -293,19 +293,10 @@ class _InsertExprs: module = self.graph.inputs[0].owner - for mod, parent in module.modules(with_parent=True): - name = mod._name - if isinstance(mod, TracedModuleBuilder): - mod = mod.build() - if hasattr(mod, "argdef_graph_map"): - for g in mod.argdef_graph_map.values(): - for n in g.nodes(False): - if isinstance(n, TensorNode): - n.value = None - setattr(parent, name, mod) - - for node in self.global_scope.nodes(False): - node.value = None + for k, v in module.__dict__.items(): + if isinstance(v, TracedModuleBuilder): + v = v.build() + setattr(module, k, v) extra_inp_nodes = set(self.global_scope.inputs) max_inp_expr_idx = -1 @@ -334,6 +325,9 @@ class _InsertExprs: self.graph._namespace.merge(self.global_scope._namespace) self.root_graph._total_ids = (Node._get_next_id(), Expr._get_next_id()) self.root_graph.inputs[0].owner._update_ref() + for node in self.root_graph.nodes(): + if isinstance(node, TensorNode): + node.value = None return True @@ -1519,6 +1513,7 @@ class TracedModuleBuilder(NodeMixin): return active_module_tracer().patcher.wrap_fn(attr) if isinstance(attr, (List, Dict)): + flag = _set_convert_node_flag(False) unset_module_tracing() has_module, m_container = replace_container_with_module_container(attr) if m_container: @@ -1529,6 +1524,7 @@ class TracedModuleBuilder(NodeMixin): " Module and Non-Module objects." ) set_module_tracing() + _set_convert_node_flag(flag) if isinstance(attr, Module): attr = TracedModuleBuilder(attr) diff --git a/imperative/python/test/unit/traced_module/test_modification.py b/imperative/python/test/unit/traced_module/test_modification.py index 9ac537306..0fed2217f 100644 --- a/imperative/python/test/unit/traced_module/test_modification.py +++ b/imperative/python/test/unit/traced_module/test_modification.py @@ -16,7 +16,7 @@ import megengine.module as M from megengine.module.identity import Identity from megengine.traced_module import trace_module from megengine.traced_module.expr import CallFunction, CallMethod, Expr, GetAttr, Input -from megengine.traced_module.node import ModuleNode, Node +from megengine.traced_module.node import ModuleNode, Node, TensorNode class IdentityMod(M.Module): @@ -159,21 +159,44 @@ def test_insert(): def test_insert_module(): class Neg(M.Module): + def __init__(self, name): + super().__init__(name) + self.identity = M.Identity() + self.identity_list = [M.Identity(), M.Identity()] + self.identity_dict = {"0": M.Identity(), "1": M.Identity()} + self.param = F.zeros((1,)) + def forward(self, x): - return F.neg(x) + x = self.identity(x) + for m in self.identity_dict: + x = self.identity_dict[m](x) + for m in self.identity_list: + x = m(x) + return F.neg(x) + self.param traced_module, x, expect = _init_block() graph = traced_module.graph relu_out = graph.get_function_by_type(F.relu).as_unique().outputs[0] self = graph.inputs[0] - setattr(traced_module, "neg", Neg()) + setattr(traced_module, "neg", Neg(name="neg")) + setattr(traced_module, "neg2", Neg(name="neg")) + setattr(traced_module, "param", F.zeros((1,))) + with graph.insert_exprs(): neg_out = self.neg(relu_out) + neg_out = self.neg2(relu_out) + neg_out = neg_out + self.param graph.replace_node({relu_out: neg_out}) graph.compile() + np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6) assert traced_module.neg.graph is not None - assert len(traced_module.neg.graph._exprs) == 1 + assert traced_module.neg2.graph is not None + assert traced_module.neg2.param is not None + assert len(traced_module.neg.graph._exprs) == 13 + for n in traced_module.graph.nodes(): + if isinstance(n, TensorNode): + assert n.value is None def test_add_input_and_output(): -- GitLab