提交 6ce1483b 编写于 作者: M Megvii Engine Team

fix(utils/network): fix replace var in different networks

GitOrigin-RevId: 8d1356ddb13c5cae1190e2db2814b7cb1d0a72fd
上级 73d25779
......@@ -371,7 +371,9 @@ class Network:
if repl_var is var:
continue
for opnode in var.users:
assert var in opnode.inputs
# use method 'is' instead of 'in' to avoid
# compare VarNode use elemwise equal
assert any([var is _ for _ in opnode.inputs])
opnode.inputs = [repl_var if var is i else i for i in opnode.inputs]
if opnode not in repl_var.users:
repl_var.users.append(opnode)
......
......@@ -511,3 +511,50 @@ def test_set_symbolic_shape():
set_symbolic_shape(False)
assert var_a.shape == var_a.partial_shape
set_symbolic_shape(saved_symbolic_shape)
def test_replace_var_in_different_network():
a = Tensor([1, 2])
b = Tensor([3, 4])
@trace(symbolic=True, capture_as_const=True)
def fwd(a, b):
return (a + b) * 2
@trace(symbolic=True, capture_as_const=True)
def fwd1(c, d):
return c + d
fwd(a, b)
orig_model = io.BytesIO()
fwd.dump(
orig_model, arg_names=["a", "b"], output_names="o", optimize_for_inference=False
)
orig_model.seek(0)
fwd1(a, b)
orig_model1 = io.BytesIO()
fwd1.dump(
orig_model1,
arg_names=["c", "d"],
output_names="o",
optimize_for_inference=False,
)
orig_model1.seek(0)
graph = Net.load(orig_model)
graph1 = Net.load(orig_model1)
vara = graph.var_filter.name("a").as_unique()
varb = graph.var_filter.name("b").as_unique()
varo = graph1.var_filter.name("o").as_unique()
graph.replace_vars({vara: varo, varb: varo})
modified_model = io.BytesIO()
graph.dump(modified_model)
modified_model.seek(0)
load_graph = GraphInference(modified_model)
out = load_graph.run(a, b)
np.testing.assert_equal(out["o"], [16, 24])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册