From c54a731b9ba867e7a50d54d1ed38d7bde26d0f97 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 10 Aug 2021 17:41:03 +0800 Subject: [PATCH] fix(utils/network): fix replace oprs GitOrigin-RevId: eba27e3dfb2603b185279ac6120d1dea69673bf7 --- imperative/python/megengine/utils/network.py | 1 + imperative/python/test/unit/utils/test_network.py | 10 +++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/utils/network.py b/imperative/python/megengine/utils/network.py index 91fc8ec1..7209ca06 100644 --- a/imperative/python/megengine/utils/network.py +++ b/imperative/python/megengine/utils/network.py @@ -399,6 +399,7 @@ class Network: var.owner = repl_dict[opr] var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) var.var = repl_dict[opr].outputs[ind].var + repl_dict[opr].outputs = opr.outputs self._compile() def get_opr_by_type(self, oprcls, unique=True): diff --git a/imperative/python/test/unit/utils/test_network.py b/imperative/python/test/unit/utils/test_network.py index 7c55d914..ebf62fa0 100644 --- a/imperative/python/test/unit/utils/test_network.py +++ b/imperative/python/test/unit/utils/test_network.py @@ -119,8 +119,16 @@ def test_replace_opr(): out1 = graph.add_dep_oprs(out1) orig_opr = graph.opr_filter.has_input(vara).as_unique() - repl_dict = {orig_opr: out1[0].owner} + new_opr = out1[0].owner + repl_dict = {orig_opr: new_opr} graph.replace_oprs(repl_dict) + + var_out = orig_opr.outputs + + for idx, node in enumerate(var_out): + assert node.owner is new_opr + assert node.owner.outputs[idx] is node + modified_model1 = io.BytesIO() graph.dump(modified_model1) modified_model1.seek(0) -- GitLab