From 495472954d791bd6bde37c32dca032d434b58b6b Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 18 Nov 2020 11:39:59 +0800
Subject: [PATCH] fix(trace): link io-op to avoid deadlock

GitOrigin-RevId: 872cb6b7153e7ade20cf5b905eeb0fb16f1e4a65
---
 .../megengine/distributed/functional.py    |  2 +-
 imperative/python/megengine/jit/tracing.py | 76 +++++++++++++------
 .../test/unit/autodiff/test_grad_manger.py | 43 ++++++-----
 3 files changed, 77 insertions(+), 44 deletions(-)

diff --git a/imperative/python/megengine/distributed/functional.py b/imperative/python/megengine/distributed/functional.py
index 5a9ce213..0c0b8d2b 100644
--- a/imperative/python/megengine/distributed/functional.py
+++ b/imperative/python/megengine/distributed/functional.py
@@ -300,7 +300,7 @@ def remote_recv(
         device = get_default_device()
     # dummy input
     if inp == None:
-        inp = tensor([0])
+        inp = tensor([0], device=device)
     tracer_set = get_client().check_remote_tracer(key)
     for grad_manager in get_grad_managers():
         if grad_manager.name in tracer_set:
diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py
index 927b954c..b0c32d77 100644
--- a/imperative/python/megengine/jit/tracing.py
+++ b/imperative/python/megengine/jit/tracing.py
@@ -18,7 +18,13 @@ import weakref
 import numpy as np
 
 from ..core._imperative_rt import GraphProfiler
-from ..core._imperative_rt.ops import OprAttr
+from ..core._imperative_rt.ops import (
+    CollectiveComm,
+    OprAttr,
+    RemoteRecv,
+    RemoteSend,
+    VirtualDep,
+)
 from ..core._trace_option import set_symbolic_shape
 from ..core._wrap import device as as_device
 from ..core.ops.special import Const
@@ -92,6 +98,9 @@ class TensorInfo:
         self.data_reader = None
 
 
+_io_op_types = {CollectiveComm, RemoteSend, RemoteRecv}
+
+
 class trace:
     """
     Wraps a callable and provide:
@@ -143,8 +152,8 @@
         self._graph = None
         self._need_reset_nodes = None
         self._lazy_eval_graph = None
-        self._lazy_eval_tensors = []
-        self._lazy_eval_tensor_count = 0
+        self._lazy_eval_tensors = weakref.WeakSet()
+        self._lazy_eval_links = None
         self._active_tensors = weakref.WeakSet()
         self._tensor_remaps = None
         self._inputs_to_restore = None
@@ -286,27 +295,22 @@
             apply.enable(apply_const_symbolic_mode)
             self._lazy_eval_graph = G.Graph()
             self._apply_graph_options(self._lazy_eval_graph)
+            self._lazy_eval_links = ()
 
     def _take_escaped_tensors(self):
         escaped_tensors = tuple(self._active_tensors)
         self._active_tensors.clear()
         return escaped_tensors
 
-    def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors):
-        active_lazy_eval_tensors = []
-        visited = set()
-        readers = []
-        for x in lazy_eval_tensors:
-            x = x()
-            if x is None or x in visited:
-                continue
-            reader = G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
-            readers.append(reader)
-            active_lazy_eval_tensors.append(x)
-            visited.add(x)
-        lazy_eval_graph.compile(*readers)
+    def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
+        readers = [
+            G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
+            for x in lazy_eval_tensors
+        ]
+        self._apply_graph_options(lazy_eval_graph)
+        lazy_eval_graph.compile(*lazy_eval_links, *readers)
         lazy_eval_graph()
-        for r, x in zip(readers, active_lazy_eval_tensors):
+        for r, x in zip(readers, lazy_eval_tensors):
             assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))
 
     @contextlib.contextmanager
@@ -333,11 +337,18 @@
             if self._inputs_to_restore:
                 for x in self._inputs_to_restore:
                     x._TraceMixin__restore()
-            if self._symbolic and self._lazy_eval_tensors:
+            if self._symbolic and (
+                self._lazy_eval_tensors or self._lazy_eval_links
+            ):
                 # eval lazy eval tensors
-                self._lazy_eval(self._lazy_eval_graph, self._lazy_eval_tensors)
+                self._lazy_eval(
+                    self._lazy_eval_graph,
+                    tuple(self._lazy_eval_tensors),
+                    self._lazy_eval_links,
+                )
                 self._lazy_eval_graph = None
                 self._lazy_eval_tensors = None
+                self._lazy_eval_links = None
             self._untraced = False
         else:
             # compiled_tensor leaks
@@ -438,8 +449,10 @@
                 links += opnode.outputs[1:]
 
         for op, ihandles, ohandles in self._seq:
+            require_links = type(op) in _io_op_types
+
             ivars = []
-            for h in ihandles:
+            for i, h in enumerate(ihandles):
                 info = self._tinfo[h]
                 if not hasattr(info, "varnode"):
                     assert info.external
@@ -455,9 +468,14 @@
                     )
                     need_reset_nodes.append(opnode)
                     info.varnode, *links = opnode.outputs
+                    if require_links and i == 0 and len(links) > 0:
+                        info.varnode = apply(VirtualDep(), info.varnode, *links)[0]
+                        links = (info.varnode,)
 
                 ivars.append(info.varnode)
             ovars = apply(op, *ivars)
+            if require_links and len(ovars) > 0:
+                links = (ovars[0],)
             assert len(ovars) == len(ohandles)
             for h, v in zip(ohandles, ovars):
                 info = self._tinfo[h]
@@ -502,6 +520,8 @@
                     info.data_read = True
 
     def __call__(self, *args, **kwargs):
+        if is_tracing():
+            return self.__wrapped__(*args, **kwargs)
         with self._setup():
             if self._capture_as_const:
                 self._process_inputs(*args, **kwargs)
@@ -938,9 +958,21 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
         or graph.make_const(x._dev_tensor())
         for x in args
     ]
+
+    require_links = type(op) in _io_op_types
+
+    if require_links and active_trace._lazy_eval_links:
+        assert len(ivars) > 0, "op should have at least one input"
+        ivars[0] = apply(VirtualDep(), ivars[0], *active_trace._lazy_eval_links)[0]
+        active_trace._lazy_eval_links = (ivars[0],)
+
     ovars = apply(op, *ivars)
+
+    if require_links:
+        active_trace._lazy_eval_links = (ovars[0],)
+
     outputs = [LazyEvalTensor(v) for v in ovars]
-    active_trace._lazy_eval_tensors.extend(weakref.ref(oup) for oup in outputs)
+    active_trace._lazy_eval_tensors.update(outputs)
     return outputs
 
 
@@ -951,7 +983,7 @@ apply.disable(apply_symbolic_mode)
 def apply_const_symbolic_mode(op: Const, *args: RawTensor):
     graph = active_trace._lazy_eval_graph
     ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device))
-    active_trace._lazy_eval_tensors.append(weakref.ref(ret))
+    active_trace._lazy_eval_tensors.add(ret)
     return (ret,)
 
 
diff --git a/imperative/python/test/unit/autodiff/test_grad_manger.py b/imperative/python/test/unit/autodiff/test_grad_manger.py
index d285e573..f47e618b 100644
--- a/imperative/python/test/unit/autodiff/test_grad_manger.py
+++ b/imperative/python/test/unit/autodiff/test_grad_manger.py
@@ -18,6 +18,7 @@ import megengine.optimizer as optim
 from megengine.autodiff import GradManager
 from megengine.core._imperative_rt.imperative import sync
 from megengine.distributed.helper import get_device_count_by_fork
+from megengine.jit import trace
 
 
 def test_basic():
@@ -75,27 +76,27 @@
         gm = GradManager().attach(m.parameters())
         opt = optim.SGD(m.parameters(), 1e-3, momentum=0.9)
 
+        @trace(symbolic=True)
         def train_func(x):
-            if rank != 0:
-                x = dist.functional.remote_recv(
-                    rank - 1, shape=(1, rank * 2 + 2), dtype=np.float32
-                )
-                print(rank, "x", x)
-            y = m(x)
-            print(rank, "y", y)
-            if rank != size - 1:
-                y = dist.functional.remote_send(y, dest_rank=rank + 1)
-            return y
-
-        with gm:
-            y = train_func(x)
-            if rank == size - 1:
-                y = y.mean()
-                gm.backward(y)
-            else:
-                gm.backward()
-            opt.step().clear_grad()
-        # sync because send is the last job
-        sync()
+        with gm:
+            if rank != 0:
+                x = dist.functional.remote_recv(
+                    rank - 1, shape=(1, rank * 2 + 2), dtype=np.float32
+                )
+            y = m(x)
+            if rank != size - 1:
+                y = dist.functional.remote_send(y, dest_rank=rank + 1)
+            if rank == size - 1:
+                y = y.mean()
+                gm.backward(y)
+            else:
+                gm.backward()
+            opt.step().clear_grad()
+
+        for i in range(3):
+            train_func(x)
+
+        for param in m.parameters():
+            param.numpy()
 
     worker()
--
GitLab
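Why the linking matters: CollectiveComm, RemoteSend and RemoteRecv block until the peer rank reaches the matching op, so if the traced graph is free to schedule them in any order, two ranks can each wait on a different op and the job deadlocks. The patch therefore threads a VirtualDep link from each io op to the next, both in the lazy-eval graph (apply_symbolic_mode) and in the compiled replay graph (the loop over self._seq), which pins the io ops to their recorded order. The sketch below restates that idea outside MegEngine; Op and link_io_ops are hypothetical names used only for this illustration, not MegEngine APIs.

# Illustrative sketch of the io-op linking idea; Op and link_io_ops are
# invented for this example and are not part of the MegEngine API.
from typing import List, Optional, Tuple


class Op:
    def __init__(self, name: str, is_io: bool = False) -> None:
        self.name = name
        self.is_io = is_io                # models CollectiveComm / RemoteSend / RemoteRecv
        self.deps: Tuple["Op", ...] = ()  # explicit execution-order dependencies


def link_io_ops(seq: List[Op]) -> None:
    """Chain every io op to the io op recorded before it (the VirtualDep idea).

    With the chain in place the executor must run io ops in recorded order, so
    every rank issues its communication ops in the same relative order and no
    rank can block on an op that another rank has not reached yet.
    """
    link: Optional[Op] = None
    for op in seq:
        if not op.is_io:
            continue
        if link is not None:
            op.deps += (link,)            # analogous to apply(VirtualDep(), ...)
        link = op                         # the next io op will depend on this one


if __name__ == "__main__":
    ops = [Op("conv"), Op("allreduce", is_io=True), Op("relu"), Op("send", is_io=True)]
    link_io_ops(ops)
    for op in ops:
        print(op.name, "depends on", [d.name for d in op.deps])

Because each io op gains only a single extra dependency on its predecessor, the chain imposes ordering without copying any data, which is also why the patch rewrites only the first input of each io op (the i == 0 / ivars[0] cases in the diff).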