diff --git a/imperative/python/megengine/utils/network.py b/imperative/python/megengine/utils/network.py index 91fc8ec1cd3be7f523b5ec3656354309c6125e15..7209ca06df599ba15336a8425a9c995b77f64ec4 100644 --- a/imperative/python/megengine/utils/network.py +++ b/imperative/python/megengine/utils/network.py @@ -399,6 +399,7 @@ class Network: var.owner = repl_dict[opr] var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) var.var = repl_dict[opr].outputs[ind].var + repl_dict[opr].outputs = opr.outputs self._compile() def get_opr_by_type(self, oprcls, unique=True): diff --git a/imperative/python/test/unit/utils/test_network.py b/imperative/python/test/unit/utils/test_network.py index 7c55d91492f280021bbe97f8c6dd4c2b09d40a86..ebf62fa029a2043057d175695736e22ffc8ce1d9 100644 --- a/imperative/python/test/unit/utils/test_network.py +++ b/imperative/python/test/unit/utils/test_network.py @@ -119,8 +119,16 @@ def test_replace_opr(): out1 = graph.add_dep_oprs(out1) orig_opr = graph.opr_filter.has_input(vara).as_unique() - repl_dict = {orig_opr: out1[0].owner} + new_opr = out1[0].owner + repl_dict = {orig_opr: new_opr} graph.replace_oprs(repl_dict) + + var_out = orig_opr.outputs + + for idx, node in enumerate(var_out): + assert node.owner is new_opr + assert node.owner.outputs[idx] is node + modified_model1 = io.BytesIO() graph.dump(modified_model1) modified_model1.seek(0)