提交 49547295 编写于 作者: M Megvii Engine Team

fix(trace): link io-op to avoid deadlock

GitOrigin-RevId: 872cb6b7153e7ade20cf5b905eeb0fb16f1e4a65
上级 064c774f
......@@ -300,7 +300,7 @@ def remote_recv(
device = get_default_device()
# dummy input
if inp == None:
inp = tensor([0])
inp = tensor([0], device=device)
tracer_set = get_client().check_remote_tracer(key)
for grad_manager in get_grad_managers():
if grad_manager.name in tracer_set:
......
......@@ -18,7 +18,13 @@ import weakref
import numpy as np
from ..core._imperative_rt import GraphProfiler
from ..core._imperative_rt.ops import OprAttr
from ..core._imperative_rt.ops import (
CollectiveComm,
OprAttr,
RemoteRecv,
RemoteSend,
VirtualDep,
)
from ..core._trace_option import set_symbolic_shape
from ..core._wrap import device as as_device
from ..core.ops.special import Const
......@@ -92,6 +98,9 @@ class TensorInfo:
self.data_reader = None
_io_op_types = {CollectiveComm, RemoteSend, RemoteRecv}
class trace:
"""
Wraps a callable and provide:
......@@ -143,8 +152,8 @@ class trace:
self._graph = None
self._need_reset_nodes = None
self._lazy_eval_graph = None
self._lazy_eval_tensors = []
self._lazy_eval_tensor_count = 0
self._lazy_eval_tensors = weakref.WeakSet()
self._lazy_eval_links = None
self._active_tensors = weakref.WeakSet()
self._tensor_remaps = None
self._inputs_to_restore = None
......@@ -286,27 +295,22 @@ class trace:
apply.enable(apply_const_symbolic_mode)
self._lazy_eval_graph = G.Graph()
self._apply_graph_options(self._lazy_eval_graph)
self._lazy_eval_links = ()
def _take_escaped_tensors(self):
escaped_tensors = tuple(self._active_tensors)
self._active_tensors.clear()
return escaped_tensors
def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors):
active_lazy_eval_tensors = []
visited = set()
readers = []
for x in lazy_eval_tensors:
x = x()
if x is None or x in visited:
continue
reader = G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
readers.append(reader)
active_lazy_eval_tensors.append(x)
visited.add(x)
lazy_eval_graph.compile(*readers)
def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
readers = [
G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
for x in lazy_eval_tensors
]
self._apply_graph_options(lazy_eval_graph)
lazy_eval_graph.compile(*lazy_eval_links, *readers)
lazy_eval_graph()
for r, x in zip(readers, active_lazy_eval_tensors):
for r, x in zip(readers, lazy_eval_tensors):
assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))
@contextlib.contextmanager
......@@ -333,11 +337,18 @@ class trace:
if self._inputs_to_restore:
for x in self._inputs_to_restore:
x._TraceMixin__restore()
if self._symbolic and self._lazy_eval_tensors:
if self._symbolic and (
self._lazy_eval_tensors or self._lazy_eval_links
):
# eval lazy eval tensors
self._lazy_eval(self._lazy_eval_graph, self._lazy_eval_tensors)
self._lazy_eval(
self._lazy_eval_graph,
tuple(self._lazy_eval_tensors),
self._lazy_eval_links,
)
self._lazy_eval_graph = None
self._lazy_eval_tensors = None
self._lazy_eval_links = None
self._untraced = False
else:
# compiled_tensor leaks
......@@ -438,8 +449,10 @@ class trace:
links += opnode.outputs[1:]
for op, ihandles, ohandles in self._seq:
require_links = type(op) in _io_op_types
ivars = []
for h in ihandles:
for i, h in enumerate(ihandles):
info = self._tinfo[h]
if not hasattr(info, "varnode"):
assert info.external
......@@ -455,9 +468,14 @@ class trace:
)
need_reset_nodes.append(opnode)
info.varnode, *links = opnode.outputs
if require_links and i == 0 and len(links) > 0:
info.varnode = apply(VirtualDep(), info.varnode, *links)[0]
links = (info.varnode,)
ivars.append(info.varnode)
ovars = apply(op, *ivars)
if require_links and len(ovars) > 0:
links = (ovars[0],)
assert len(ovars) == len(ohandles)
for h, v in zip(ohandles, ovars):
info = self._tinfo[h]
......@@ -502,6 +520,8 @@ class trace:
info.data_read = True
def __call__(self, *args, **kwargs):
if is_tracing():
return self.__wrapped__(*args, **kwargs)
with self._setup():
if self._capture_as_const:
self._process_inputs(*args, **kwargs)
......@@ -938,9 +958,21 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
or graph.make_const(x._dev_tensor())
for x in args
]
require_links = type(op) in _io_op_types
if require_links and active_trace._lazy_eval_links:
assert len(ivars) > 0, "op should has at least one input"
ivars[0] = apply(VirtualDep(), ivars[0], *active_trace._lazy_eval_links)[0]
active_trace._lazy_eval_links = (ivars[0],)
ovars = apply(op, *ivars)
if require_links:
active_trace._lazy_eval_links = (ovars[0],)
outputs = [LazyEvalTensor(v) for v in ovars]
active_trace._lazy_eval_tensors.extend(weakref.ref(oup) for oup in outputs)
active_trace._lazy_eval_tensors.update(outputs)
return outputs
......@@ -951,7 +983,7 @@ apply.disable(apply_symbolic_mode)
def apply_const_symbolic_mode(op: Const, *args: RawTensor):
graph = active_trace._lazy_eval_graph
ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device))
active_trace._lazy_eval_tensors.append(weakref.ref(ret))
active_trace._lazy_eval_tensors.add(ret)
return (ret,)
......
......@@ -18,6 +18,7 @@ import megengine.optimizer as optim
from megengine.autodiff import GradManager
from megengine.core._imperative_rt.imperative import sync
from megengine.distributed.helper import get_device_count_by_fork
from megengine.jit import trace
def test_basic():
......@@ -75,27 +76,27 @@ def test_remote_grad():
gm = GradManager().attach(m.parameters())
opt = optim.SGD(m.parameters(), 1e-3, momentum=0.9)
@trace(symbolic=True)
def train_func(x):
if rank != 0:
x = dist.functional.remote_recv(
rank - 1, shape=(1, rank * 2 + 2), dtype=np.float32
)
print(rank, "x", x)
y = m(x)
print(rank, "y", y)
if rank != size - 1:
y = dist.functional.remote_send(y, dest_rank=rank + 1)
return y
with gm:
y = train_func(x)
if rank == size - 1:
y = y.mean()
gm.backward(y)
else:
gm.backward()
opt.step().clear_grad()
# sync because send is the last job
sync()
with gm:
if rank != 0:
x = dist.functional.remote_recv(
rank - 1, shape=(1, rank * 2 + 2), dtype=np.float32
)
y = m(x)
if rank != size - 1:
y = dist.functional.remote_send(y, dest_rank=rank + 1)
if rank == size - 1:
y = y.mean()
gm.backward(y)
else:
gm.backward()
opt.step().clear_grad()
for i in range(3):
train_func(x)
for param in m.parameters():
param.numpy()
worker()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册