diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index c3328e1b90ba4ed65b8bc68b6476ebb5700a40d8..7e1a57f7a626c0679d6ebba9975a66b55751b362 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -2078,9 +2078,7 @@ class TracedModule(Module): for node, repl_node in repl_dict.items(): assert node in graph._inputs or node in graph._outputs - for i in node.users: - if i not in repl_node.users: - repl_node.users.append(i) + repl_node.users.extend(node.users) rename_blacklist = list(chain(call.inputs, call.outputs)) diff --git a/imperative/python/test/unit/traced_module/test_modification.py b/imperative/python/test/unit/traced_module/test_modification.py index 820987e3b3a278b5abe4c51707d99b883fb2c5ca..9ac53730674e99e3f16e318c56b7151c16f7b8fd 100644 --- a/imperative/python/test/unit/traced_module/test_modification.py +++ b/imperative/python/test/unit/traced_module/test_modification.py @@ -6,6 +6,7 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import pickle +from collections import defaultdict from itertools import chain import numpy as np @@ -52,6 +53,25 @@ class MyModule(M.Module): return x +class MyBlock1(M.Module): + def forward(self, a): + y = F.concat([a, a]) + return a, y + + +class MyModule1(M.Module): + def __init__(self): + super().__init__() + self.block0 = MyBlock1() + self.block1 = MyBlock1() + + def forward(self, a): + a, y1 = self.block0(a) + a = a + 1 + a, y2 = self.block1(a) + return a, y1 + y2 + + class NewModule(M.Module): def __init__(self, traced_module): super(NewModule, self).__init__() @@ -64,6 +84,17 @@ class NewModule(M.Module): return x +def _check_expr_users(traced_module): + node_user = defaultdict(list) + for expr in traced_module.graph._exprs: + for node in expr.inputs: + node_user[node].append(expr) + for node in traced_module.graph.nodes(): + node.users.sort(key=lambda m: m._id) + node_user[node].sort(key=lambda m: m._id) + assert node.users == node_user[node] + + def _init_cls(cls): module = cls() x = F.ones((1, 3, 3, 3)) @@ -201,6 +232,10 @@ def test_flatten(): assert len(traced_module.graph._exprs) == 12 np.testing.assert_equal(expect.numpy(), traced_module(x).numpy()) + traced_module, x, expect = _init_cls(MyModule1) + traced_module = traced_module.flatten() + _check_expr_users(traced_module) + def test_id_and_name(): def _check_id(traced_module):