Commit 49547295 authored by Megvii Engine Team

fix(trace): link io-op to avoid deadlock

GitOrigin-RevId: 872cb6b7153e7ade20cf5b905eeb0fb16f1e4a65
Parent 064c774f
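Background for this fix: send/recv and collective ops block until their peer runs the matching op. When a trace compiles them into a graph, the executor may reorder ops that share no data dependency, so two ranks can end up waiting on each other forever. The change below threads an artificial dependency, a "link", through every such io op so the compiled graph preserves trace order. A minimal, self-contained sketch of that chaining idea (Node and link_io_ops are illustrative names, not MegEngine APIs):

    # Toy sketch: every io op gains a dependency edge on the previous io op,
    # so any scheduler that honors dependencies preserves trace order.
    class Node:
        def __init__(self, name, deps=()):
            self.name = name
            self.deps = tuple(deps)

    def link_io_ops(names):
        link = None
        chained = []
        for name in names:
            node = Node(name, (link,) if link is not None else ())
            chained.append(node)
            link = node  # the next io op must wait on this one
        return chained

    for n in link_io_ops(["RemoteSend#0", "RemoteRecv#1", "CollectiveComm#2"]):
        print(n.name, "depends on", [d.name for d in n.deps])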
@@ -300,7 +300,7 @@ def remote_recv(
         device = get_default_device()
     # dummy input
     if inp == None:
-        inp = tensor([0])
+        inp = tensor([0], device=device)
     tracer_set = get_client().check_remote_tracer(key)
     for grad_manager in get_grad_managers():
         if grad_manager.name in tracer_set:
...
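A note on the hunk above: tensor([0]) alone would place the dummy input on the default device, which need not match the device the recv was asked to use; passing device=device pins it. A toy illustration of the difference (make_placeholder and the device names are hypothetical):

    def make_placeholder(device=None, default_device="xpux"):
        # omitting `device` falls back to the default device, which need not
        # match the device the recv is supposed to land on
        return {"value": [0], "device": device or default_device}

    print(make_placeholder())               # lands on the default device
    print(make_placeholder(device="gpu1"))  # pinned to the requested device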
@@ -18,7 +18,13 @@ import weakref
 import numpy as np

 from ..core._imperative_rt import GraphProfiler
-from ..core._imperative_rt.ops import OprAttr
+from ..core._imperative_rt.ops import (
+    CollectiveComm,
+    OprAttr,
+    RemoteRecv,
+    RemoteSend,
+    VirtualDep,
+)
 from ..core._trace_option import set_symbolic_shape
 from ..core._wrap import device as as_device
 from ..core.ops.special import Const
@@ -92,6 +98,9 @@ class TensorInfo:
         self.data_reader = None


+_io_op_types = {CollectiveComm, RemoteSend, RemoteRecv}
+
+
 class trace:
     """
     Wraps a callable and provide:
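The new _io_op_types set drives the linking logic below via type(op) in _io_op_types, an exact-type membership test, so only these three op classes receive links. A sketch of that pattern with stand-in classes (not the real ops):

    class CollectiveComm: pass
    class RemoteSend: pass
    class RemoteRecv: pass
    class SubclassedSend(RemoteSend): pass

    _io_op_types = {CollectiveComm, RemoteSend, RemoteRecv}

    print(type(RemoteSend()) in _io_op_types)      # True
    print(type(SubclassedSend()) in _io_op_types)  # False: exact type, not isinstance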
@@ -143,8 +152,8 @@ class trace:
         self._graph = None
         self._need_reset_nodes = None
         self._lazy_eval_graph = None
-        self._lazy_eval_tensors = []
-        self._lazy_eval_tensor_count = 0
+        self._lazy_eval_tensors = weakref.WeakSet()
+        self._lazy_eval_links = None
         self._active_tensors = weakref.WeakSet()
         self._tensor_remaps = None
         self._inputs_to_restore = None
@@ -286,27 +295,22 @@ class trace:
             apply.enable(apply_const_symbolic_mode)
             self._lazy_eval_graph = G.Graph()
             self._apply_graph_options(self._lazy_eval_graph)
+            self._lazy_eval_links = ()

     def _take_escaped_tensors(self):
         escaped_tensors = tuple(self._active_tensors)
         self._active_tensors.clear()
         return escaped_tensors

-    def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors):
-        active_lazy_eval_tensors = []
-        visited = set()
-        readers = []
-        for x in lazy_eval_tensors:
-            x = x()
-            if x is None or x in visited:
-                continue
-            reader = G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
-            readers.append(reader)
-            active_lazy_eval_tensors.append(x)
-            visited.add(x)
-        lazy_eval_graph.compile(*readers)
+    def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
+        readers = [
+            G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
+            for x in lazy_eval_tensors
+        ]
+        self._apply_graph_options(lazy_eval_graph)
+        lazy_eval_graph.compile(*lazy_eval_links, *readers)
         lazy_eval_graph()
-        for r, x in zip(readers, active_lazy_eval_tensors):
+        for r, x in zip(readers, lazy_eval_tensors):
             assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))

     @contextlib.contextmanager
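Why _lazy_eval gains a lazy_eval_links argument and passes it to compile(): a RemoteSend produces no output that any reader consumes, so a compiler that keeps only nodes reachable from the readers would drop it, and the peer's matching recv would then hang. A toy reachability check (the graph shape here is made up):

    # Toy graph: edges map a node to what it depends on. A send that nothing
    # reads survives only if it is itself a compile target.
    edges = {"loss": ["matmul"], "matmul": ["recv"], "send": ["matmul"]}

    def reachable(targets):
        seen, stack = set(), list(targets)
        while stack:
            n = stack.pop()
            if n not in seen:
                seen.add(n)
                stack.extend(edges.get(n, ()))
        return seen

    print(sorted(reachable(["loss"])))          # send is dropped -> peer blocks
    print(sorted(reachable(["loss", "send"])))  # the link keeps send scheduled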
@@ -333,11 +337,18 @@ class trace:
             if self._inputs_to_restore:
                 for x in self._inputs_to_restore:
                     x._TraceMixin__restore()
-            if self._symbolic and self._lazy_eval_tensors:
+            if self._symbolic and (
+                self._lazy_eval_tensors or self._lazy_eval_links
+            ):
                 # eval lazy eval tensors
-                self._lazy_eval(self._lazy_eval_graph, self._lazy_eval_tensors)
+                self._lazy_eval(
+                    self._lazy_eval_graph,
+                    tuple(self._lazy_eval_tensors),
+                    self._lazy_eval_links,
+                )
                 self._lazy_eval_graph = None
                 self._lazy_eval_tensors = None
+                self._lazy_eval_links = None
             self._untraced = False
         else:
             # compiled_tensor leaks
@@ -438,8 +449,10 @@ class trace:
             links += opnode.outputs[1:]

         for op, ihandles, ohandles in self._seq:
+            require_links = type(op) in _io_op_types
+
             ivars = []
-            for h in ihandles:
+            for i, h in enumerate(ihandles):
                 info = self._tinfo[h]
                 if not hasattr(info, "varnode"):
                     assert info.external
@@ -455,9 +468,14 @@ class trace:
                         )
                         need_reset_nodes.append(opnode)
                         info.varnode, *links = opnode.outputs

+                if require_links and i == 0 and len(links) > 0:
+                    info.varnode = apply(VirtualDep(), info.varnode, *links)[0]
+                    links = (info.varnode,)
                 ivars.append(info.varnode)
             ovars = apply(op, *ivars)
+            if require_links and len(ovars) > 0:
+                links = (ovars[0],)
             assert len(ovars) == len(ohandles)
             for h, v in zip(ohandles, ovars):
                 info = self._tinfo[h]
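One plausible reading of the VirtualDep pattern in this hunk: the op forwards its first input unchanged while making the result depend on each extra input, so the io op that consumes the result inherits the ordering edges held in links. A value-level toy (Node and virtual_dep are stand-ins, not the real op):

    class Node:
        def __init__(self, name, deps=()):
            self.name, self.deps = name, tuple(deps)

    def virtual_dep(inp, *extra):
        # same payload, extra ordering edges
        return Node(inp.name, (inp,) + extra)

    recv_input = Node("recv_input")
    prev_link = Node("RemoteSend#0")
    tied = virtual_dep(recv_input, prev_link)
    print(tied.name, [d.name for d in tied.deps])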
@@ -502,6 +520,8 @@ class trace:
             info.data_read = True

     def __call__(self, *args, **kwargs):
+        if is_tracing():
+            return self.__wrapped__(*args, **kwargs)
         with self._setup():
             if self._capture_as_const:
                 self._process_inputs(*args, **kwargs)
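The new is_tracing() guard makes traced functions re-entrant: when a trace is already active, a nested traced function runs its wrapped Python body instead of opening a second trace session. A minimal sketch of the guard pattern (a toy decorator, not megengine.jit.trace):

    _tracing = False

    def trace(fn):
        def wrapper(*args, **kwargs):
            global _tracing
            if _tracing:                   # nested call: fall through to the body
                return fn(*args, **kwargs)
            _tracing = True
            try:
                return fn(*args, **kwargs)  # a real trace would record ops here
            finally:
                _tracing = False
        return wrapper

    @trace
    def inner(x):
        return x + 1

    @trace
    def outer(x):
        return inner(x) * 2  # inner's body runs directly while tracing

    print(outer(3))  # 8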
@@ -938,9 +958,21 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
         or graph.make_const(x._dev_tensor())
         for x in args
     ]

+    require_links = type(op) in _io_op_types
+
+    if require_links and active_trace._lazy_eval_links:
+        assert len(ivars) > 0, "op should has at least one input"
+        ivars[0] = apply(VirtualDep(), ivars[0], *active_trace._lazy_eval_links)[0]
+        active_trace._lazy_eval_links = (ivars[0],)
+
     ovars = apply(op, *ivars)

+    if require_links:
+        active_trace._lazy_eval_links = (ovars[0],)
+
     outputs = [LazyEvalTensor(v) for v in ovars]
-    active_trace._lazy_eval_tensors.extend(weakref.ref(oup) for oup in outputs)
+    active_trace._lazy_eval_tensors.update(outputs)
     return outputs
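The symbolic-mode path keeps its chain in active_trace._lazy_eval_links: an io op is first tied to the current link through VirtualDep, and its own first output then becomes the new link. A toy version of that update protocol (thread_link stands in for apply):

    link = None

    def thread_link(op_name):
        global link
        deps = (link,) if link is not None else ()
        out = (op_name, deps)  # stand-in for apply(op, *ivars)
        link = op_name         # this op's output becomes the new link
        return out

    for op in ["RemoteSend", "RemoteRecv", "CollectiveComm"]:
        print(thread_link(op))
    # ('RemoteSend', ())
    # ('RemoteRecv', ('RemoteSend',))
    # ('CollectiveComm', ('RemoteRecv',))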
@@ -951,7 +983,7 @@ apply.disable(apply_symbolic_mode)


 def apply_const_symbolic_mode(op: Const, *args: RawTensor):
     graph = active_trace._lazy_eval_graph
     ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device))
-    active_trace._lazy_eval_tensors.append(weakref.ref(ret))
+    active_trace._lazy_eval_tensors.add(ret)
     return (ret,)
...
@@ -18,6 +18,7 @@ import megengine.optimizer as optim
 from megengine.autodiff import GradManager
 from megengine.core._imperative_rt.imperative import sync
 from megengine.distributed.helper import get_device_count_by_fork
+from megengine.jit import trace


 def test_basic():
@@ -75,27 +76,27 @@ def test_remote_grad():
         gm = GradManager().attach(m.parameters())
         opt = optim.SGD(m.parameters(), 1e-3, momentum=0.9)

+        @trace(symbolic=True)
         def train_func(x):
-            if rank != 0:
-                x = dist.functional.remote_recv(
-                    rank - 1, shape=(1, rank * 2 + 2), dtype=np.float32
-                )
-            print(rank, "x", x)
-            y = m(x)
-            print(rank, "y", y)
-            if rank != size - 1:
-                y = dist.functional.remote_send(y, dest_rank=rank + 1)
-            return y
-
-        with gm:
-            y = train_func(x)
-            if rank == size - 1:
-                y = y.mean()
-                gm.backward(y)
-            else:
-                gm.backward()
-            opt.step().clear_grad()
-        # sync because send is the last job
-        sync()
+            with gm:
+                if rank != 0:
+                    x = dist.functional.remote_recv(
+                        rank - 1, shape=(1, rank * 2 + 2), dtype=np.float32
+                    )
+                y = m(x)
+                if rank != size - 1:
+                    y = dist.functional.remote_send(y, dest_rank=rank + 1)
+                if rank == size - 1:
+                    y = y.mean()
+                    gm.backward(y)
+                else:
+                    gm.backward()
+                opt.step().clear_grad()
+
+        for i in range(3):
+            train_func(x)
+
+        for param in m.parameters():
+            param.numpy()

     worker()
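A note on the reworked test: with @trace(symbolic=True) the first call records train_func and later calls replay the compiled graph, so looping three times exercises both paths; reading each parameter back with param.numpy() blocks until pending device work finishes, taking over the job of the removed explicit sync(). A rough sketch of the call phases (the phase names are descriptive, not MegEngine terminology):

    def trace_phase(call_index):
        # call 0 records the op sequence; later calls replay the compiled graph
        return "record" if call_index == 0 else "replay"

    print([trace_phase(i) for i in range(3)])  # ['record', 'replay', 'replay']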