提交 526c82c8 编写于 作者: M Megvii Engine Team

feat(traced_module): delete value of node when it will not be used by any expr

GitOrigin-RevId: 3fb7350d018ca7e32e7552fc4644eb5f34c59913
上级 db9ca196
......@@ -756,24 +756,32 @@ class InternalGraph:
return not end_nodes_set
return False
ref_count = lambda n: len(n.users) + (1 if n in self._outputs else 0)
for n, v in zip(self._inputs, inputs):
node2value[n] = v
if ref_count(n) > 0:
node2value[n] = [v, ref_count(n)]
if n in self._watch_point:
self._rst[n].append(v)
if n in self._end_point and get_all_endnode_val(n, v):
return list(endnode2value[i] for i in self._end_point)
for expr in self._exprs:
values = expr.interpret(*list(node2value[i] for i in expr.inputs))
values = expr.interpret(*list(node2value[i][0] for i in expr.inputs))
for n in expr.inputs:
node2value[n][1] -= 1
if node2value[n][1] == 0:
node2value.pop(n)
if values is not None:
for n, v in zip(expr.outputs, values):
node2value[n] = v
if ref_count(n) > 0:
node2value[n] = [v, ref_count(n)]
if n in self._watch_point:
self._rst[n] = v
if self._end_point and get_all_endnode_val(n, v):
return list(endnode2value[i] for i in self._end_point)
return list(node2value[i] for i in self._outputs)
return list(node2value[i][0] for i in self._outputs)
def eval(self, *inputs):
assert len(inputs) == len(self._inputs) - 1
......@@ -1575,6 +1583,7 @@ class TracedModule(Module):
for index, inp in enumerate(expr.inputs):
if inp is call_out:
expr.inputs[index] = repl_dict[out]
repl_dict[out].users.append(expr)
continue
repl_dict[out] = call.outputs[ind]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册