提交 76f36796 编写于 作者: M Megvii Engine Team

fix(mge/trace): fix op order in symbolic

GitOrigin-RevId: fbf081a1999dec7b9401d8898b938eec19021e98
上级 10e942d9
......@@ -124,7 +124,8 @@ class trace:
self._graph = None
self._need_reset_nodes = None
self._lazy_eval_graph = None
self._lazy_eval_tensors = weakref.WeakSet()
self._lazy_eval_tensors = []
self._lazy_eval_tensor_count = 0
self._active_tensors = weakref.WeakSet()
self._tensor_remaps = None
self._inputs_to_restore = None
......@@ -283,12 +284,18 @@ class trace:
x._TraceMixin__restore()
if self._symbolic:
# eval lazy eval tensors
lazy_eval_tensors = tuple(self._lazy_eval_tensors)
if lazy_eval_tensors:
readers = [
G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
for x in lazy_eval_tensors
]
if self._lazy_eval_tensors:
lazy_eval_tensors = []
visited = set()
readers = []
for x in self._lazy_eval_tensors:
x = x()
if x is None or x in visited:
continue
reader = G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
readers.append(reader)
lazy_eval_tensors.append(x)
visited.add(x)
self._apply_graph_options(self._lazy_eval_graph)
self._lazy_eval_graph.compile(*readers)
self._lazy_eval_graph()
......@@ -844,7 +851,7 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
]
ovars = apply(op, *ivars)
outputs = [LazyEvalTensor(v) for v in ovars]
active_trace._lazy_eval_tensors.update(outputs)
active_trace._lazy_eval_tensors.extend(weakref.ref(oup) for oup in outputs)
return outputs
......@@ -855,7 +862,7 @@ apply.disable(apply_symbolic_mode)
def apply_const_symbolic_mode(op: Const, *args: RawTensor):
graph = active_trace._lazy_eval_graph
ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device))
active_trace._lazy_eval_tensors.add(ret)
active_trace._lazy_eval_tensors.append(weakref.ref(ret))
return (ret,)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册