提交 76f36796 编写于 作者: M Megvii Engine Team

fix(mge/trace): fix op order in symbolic

GitOrigin-RevId: fbf081a1999dec7b9401d8898b938eec19021e98
上级 10e942d9
...@@ -124,7 +124,8 @@ class trace: ...@@ -124,7 +124,8 @@ class trace:
self._graph = None self._graph = None
self._need_reset_nodes = None self._need_reset_nodes = None
self._lazy_eval_graph = None self._lazy_eval_graph = None
self._lazy_eval_tensors = weakref.WeakSet() self._lazy_eval_tensors = []
self._lazy_eval_tensor_count = 0
self._active_tensors = weakref.WeakSet() self._active_tensors = weakref.WeakSet()
self._tensor_remaps = None self._tensor_remaps = None
self._inputs_to_restore = None self._inputs_to_restore = None
...@@ -283,12 +284,18 @@ class trace: ...@@ -283,12 +284,18 @@ class trace:
x._TraceMixin__restore() x._TraceMixin__restore()
if self._symbolic: if self._symbolic:
# eval lazy eval tensors # eval lazy eval tensors
lazy_eval_tensors = tuple(self._lazy_eval_tensors) if self._lazy_eval_tensors:
if lazy_eval_tensors: lazy_eval_tensors = []
readers = [ visited = set()
G.OutputNode(x._LazyEvalTensor__varnode).outputs[0] readers = []
for x in lazy_eval_tensors for x in self._lazy_eval_tensors:
] x = x()
if x is None or x in visited:
continue
reader = G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
readers.append(reader)
lazy_eval_tensors.append(x)
visited.add(x)
self._apply_graph_options(self._lazy_eval_graph) self._apply_graph_options(self._lazy_eval_graph)
self._lazy_eval_graph.compile(*readers) self._lazy_eval_graph.compile(*readers)
self._lazy_eval_graph() self._lazy_eval_graph()
...@@ -844,7 +851,7 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): ...@@ -844,7 +851,7 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
] ]
ovars = apply(op, *ivars) ovars = apply(op, *ivars)
outputs = [LazyEvalTensor(v) for v in ovars] outputs = [LazyEvalTensor(v) for v in ovars]
active_trace._lazy_eval_tensors.update(outputs) active_trace._lazy_eval_tensors.extend(weakref.ref(oup) for oup in outputs)
return outputs return outputs
...@@ -855,7 +862,7 @@ apply.disable(apply_symbolic_mode) ...@@ -855,7 +862,7 @@ apply.disable(apply_symbolic_mode)
def apply_const_symbolic_mode(op: Const, *args: RawTensor): def apply_const_symbolic_mode(op: Const, *args: RawTensor):
graph = active_trace._lazy_eval_graph graph = active_trace._lazy_eval_graph
ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device)) ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device))
active_trace._lazy_eval_tensors.add(ret) active_trace._lazy_eval_tensors.append(weakref.ref(ret))
return (ret,) return (ret,)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册