提交 c54a731b 编写于 作者: M Megvii Engine Team

fix(utils/network): fix replace oprs

GitOrigin-RevId: eba27e3dfb2603b185279ac6120d1dea69673bf7
上级 604bb2a5
...@@ -399,6 +399,7 @@ class Network: ...@@ -399,6 +399,7 @@ class Network:
var.owner = repl_dict[opr] var.owner = repl_dict[opr]
var.__dict__.update(repl_dict[opr].outputs[ind].__dict__) var.__dict__.update(repl_dict[opr].outputs[ind].__dict__)
var.var = repl_dict[opr].outputs[ind].var var.var = repl_dict[opr].outputs[ind].var
repl_dict[opr].outputs = opr.outputs
self._compile() self._compile()
def get_opr_by_type(self, oprcls, unique=True): def get_opr_by_type(self, oprcls, unique=True):
......
...@@ -119,8 +119,16 @@ def test_replace_opr(): ...@@ -119,8 +119,16 @@ def test_replace_opr():
out1 = graph.add_dep_oprs(out1) out1 = graph.add_dep_oprs(out1)
orig_opr = graph.opr_filter.has_input(vara).as_unique() orig_opr = graph.opr_filter.has_input(vara).as_unique()
repl_dict = {orig_opr: out1[0].owner} new_opr = out1[0].owner
repl_dict = {orig_opr: new_opr}
graph.replace_oprs(repl_dict) graph.replace_oprs(repl_dict)
var_out = orig_opr.outputs
for idx, node in enumerate(var_out):
assert node.owner is new_opr
assert node.owner.outputs[idx] is node
modified_model1 = io.BytesIO() modified_model1 = io.BytesIO()
graph.dump(modified_model1) graph.dump(modified_model1)
modified_model1.seek(0) modified_model1.seek(0)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册