diff --git a/imperative/python/megengine/core/autodiff/grad.py b/imperative/python/megengine/core/autodiff/grad.py
index ae998761254780abf0c8ab2c7ea5f072f64bce50..f14080d80ad2555b00781ef37363743eb7f4f7b4 100644
--- a/imperative/python/megengine/core/autodiff/grad.py
+++ b/imperative/python/megengine/core/autodiff/grad.py
@@ -16,7 +16,7 @@ import numpy as np
 
 import megengine as mge
 
-from ..ops.builtin import Elemwise, OpDef
+from ..ops.builtin import Elemwise, OpDef, RemoteSend
 from ..ops.special import Const
 from ..tensor.core import TensorBase, TensorWrapperBase, apply
 from ..tensor.function import Function
@@ -84,6 +84,9 @@ class Grad:
         # ops forms the computational graph
         self.ops = []
 
+        # save remote_send output for backward
+        self.remote_send_cache = []
+
         self._attached_tensors = weakref.WeakSet()
         self._enabled = True
 
@@ -144,6 +147,7 @@ class Grad:
             o.clear()
         for i in self._attached_tensors:
             i._extra_data.pop(self, None)
+        self.remote_send_cache = []
 
     def __exit__(self, *_):
         self._exit()
@@ -398,6 +402,8 @@ def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
         return
 
     opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)
+    if isinstance(op, RemoteSend):
+        manager.remote_send_cache.append(opnode)
     opnode.backward = backward
 
     outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)]
diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py
index a6450919bc19c3c05d32d84543907fc697aa3aea..7f56e1a345c0652bc295a8411521f47bb55655ca 100644
--- a/imperative/python/megengine/jit/tracing.py
+++ b/imperative/python/megengine/jit/tracing.py
@@ -588,7 +588,7 @@ class trace:
             graph.options.graph_opt_level = self._graph_opt_level
         else:
             graph.options.graph_opt_level = 2
-        graph.compile(*readers)
+        graph.compile(*readers, *links)
 
     def _reset_exec_env(self):
         for opnode in self._need_reset_nodes:
diff --git a/imperative/python/test/unit/autodiff/test_grad_manger.py b/imperative/python/test/unit/autodiff/test_grad_manger.py
index 8e0cc901e77b7fb93bc2921296477f087be8fa32..e7c59fab99b47e0bd2166519717078d48f36f42e 100644
--- a/imperative/python/test/unit/autodiff/test_grad_manger.py
+++ b/imperative/python/test/unit/autodiff/test_grad_manger.py
@@ -111,7 +111,6 @@ def test_remote_grad():
         gm = GradManager().attach(m.parameters())
         opt = optim.SGD(m.parameters(), 1e-3, momentum=0.9)
 
-        @trace(symbolic=True)
        def train_func(x):
             with gm:
                 if rank != 0:
@@ -120,18 +119,22 @@ def test_remote_grad():
                     )
                 y = m(x)
                 if rank != size - 1:
-                    y = dist.functional.remote_send(y, dest_rank=rank + 1)
-                if rank == size - 1:
+                    dist.functional.remote_send(y, dest_rank=rank + 1)
+                    gm.backward()
+                else:
                     y = y.mean()
                     gm.backward(y)
-                else:
-                    gm.backward()
             opt.step().clear_grad()
 
-        for i in range(3):
-            train_func(x)
+        train_funcs = [
+            train_func,
+            trace(symbolic=False)(train_func),
+            trace(symbolic=True)(train_func),
+        ]
 
-        for param in m.parameters():
-            param.numpy()
+        for func in train_funcs:
+            for i in range(3):
+                func(x)
+            sync()
 
     worker()
diff --git a/src/opr-mm/impl/io_remote.cpp b/src/opr-mm/impl/io_remote.cpp
index 81dd1a0088e82a62cdee959bb800c971e5068b3a..b0285ba177fc05607b681f34e17677d74bdf7907 100644
--- a/src/opr-mm/impl/io_remote.cpp
+++ b/src/opr-mm/impl/io_remote.cpp
@@ -266,11 +266,20 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_recv(
         const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
         const OperatorNodeConfig& config) {
     auto&& opr = opr_.cast_final_safe<RemoteRecv>();
-    return RemoteRecv::make(opr.key(), *opr.owner_graph(),
-                            opr.group_client(), config, inputs[0]->shape(),
-                            inputs[0]->dtype())
-            .node()
-            ->owner_opr();
+    if (inputs.size() == 1) {
+        return RemoteRecv::make(opr.key(), inputs[0], *opr.owner_graph(),
+                                opr.group_client(), config, opr.shape(),
+                                opr.dtype())
+                .node()
+                ->owner_opr();
+    } else {
+        mgb_assert(inputs.size() == 0, "recv should have 1 or 0 input");
+        return RemoteRecv::make(opr.key(), *opr.owner_graph(),
+                                opr.group_client(), config, opr.shape(),
+                                opr.dtype())
+                .node()
+                ->owner_opr();
+    }
 }
 
 MGB_REG_OPR_SHALLOW_COPY(RemoteRecv, opr_shallow_copy_remote_recv);
diff --git a/src/opr-mm/include/megbrain/opr/io_remote.h b/src/opr-mm/include/megbrain/opr/io_remote.h
index 335e26489a34b7fcc919ae64d936fe9fe079eef7..a5bb9ec6dbb6164f6e37b587ea5a10d15361b168 100644
--- a/src/opr-mm/include/megbrain/opr/io_remote.h
+++ b/src/opr-mm/include/megbrain/opr/io_remote.h
@@ -94,6 +94,9 @@ MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // {
             const OperatorNodeConfig& config, const TensorShape& shape,
             DType dtype);
 
+    const TensorShape& shape() const { return m_shape; }
+    const DType& dtype() const { return m_dtype; }
+
 private:
     const TensorShape m_shape;
     const DType m_dtype;