From 2ad8c5e1e918f8fa9fad0910e3f82c78b78463b4 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 25 Nov 2020 14:49:52 +0800
Subject: [PATCH] fix(mge/io_remote): fix remote send/recv gradient at trace

GitOrigin-RevId: 7886efd0c124b1a6f60046c9f876e457eb683b1d
---
 .../python/megengine/core/autodiff/grad.py    |  8 ++++++-
 imperative/python/megengine/jit/tracing.py    |  2 +-
 .../test/unit/autodiff/test_grad_manger.py    | 21 +++++++++++--------
 src/opr-mm/impl/io_remote.cpp                 | 19 ++++++++++++-----
 src/opr-mm/include/megbrain/opr/io_remote.h   |  3 +++
 5 files changed, 37 insertions(+), 16 deletions(-)

diff --git a/imperative/python/megengine/core/autodiff/grad.py b/imperative/python/megengine/core/autodiff/grad.py
index ae998761..f14080d8 100644
--- a/imperative/python/megengine/core/autodiff/grad.py
+++ b/imperative/python/megengine/core/autodiff/grad.py
@@ -16,7 +16,7 @@ import numpy as np
 
 import megengine as mge
 
-from ..ops.builtin import Elemwise, OpDef
+from ..ops.builtin import Elemwise, OpDef, RemoteSend
 from ..ops.special import Const
 from ..tensor.core import TensorBase, TensorWrapperBase, apply
 from ..tensor.function import Function
@@ -84,6 +84,9 @@ class Grad:
         # ops forms the computational graph
         self.ops = []
 
+        # save remote_send output for backward
+        self.remote_send_cache = []
+
         self._attached_tensors = weakref.WeakSet()
         self._enabled = True
 
@@ -144,6 +147,7 @@ class Grad:
             o.clear()
         for i in self._attached_tensors:
             i._extra_data.pop(self, None)
+        self.remote_send_cache = []
 
     def __exit__(self, *_):
         self._exit()
@@ -398,6 +402,8 @@ def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
         return
 
     opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)
+    if isinstance(op, RemoteSend):
+        manager.remote_send_cache.append(opnode)
     opnode.backward = backward
 
     outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)]
diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py
index a6450919..7f56e1a3 100644
--- a/imperative/python/megengine/jit/tracing.py
+++ b/imperative/python/megengine/jit/tracing.py
@@ -588,7 +588,7 @@ class trace:
             graph.options.graph_opt_level = self._graph_opt_level
         else:
             graph.options.graph_opt_level = 2
-        graph.compile(*readers)
+        graph.compile(*readers, *links)
 
     def _reset_exec_env(self):
         for opnode in self._need_reset_nodes:
diff --git a/imperative/python/test/unit/autodiff/test_grad_manger.py b/imperative/python/test/unit/autodiff/test_grad_manger.py
index 8e0cc901..e7c59fab 100644
--- a/imperative/python/test/unit/autodiff/test_grad_manger.py
+++ b/imperative/python/test/unit/autodiff/test_grad_manger.py
@@ -111,7 +111,6 @@ def test_remote_grad():
         gm = GradManager().attach(m.parameters())
         opt = optim.SGD(m.parameters(), 1e-3, momentum=0.9)
 
-        @trace(symbolic=True)
         def train_func(x):
             with gm:
                 if rank != 0:
@@ -120,18 +119,22 @@ def test_remote_grad():
                     )
                 y = m(x)
                 if rank != size - 1:
-                    y = dist.functional.remote_send(y, dest_rank=rank + 1)
-                if rank == size - 1:
+                    dist.functional.remote_send(y, dest_rank=rank + 1)
+                    gm.backward()
+                else:
                     y = y.mean()
                     gm.backward(y)
-                else:
-                    gm.backward()
                 opt.step().clear_grad()
 
-        for i in range(3):
-            train_func(x)
+        train_funcs = [
+            train_func,
+            trace(symbolic=False)(train_func),
+            trace(symbolic=True)(train_func),
+        ]
 
-        for param in m.parameters():
-            param.numpy()
+        for func in train_funcs:
+            for i in range(3):
+                func(x)
+            sync()
 
     worker()
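
[Reviewer note] The remote_send_cache added in grad.py appears to exist because a bare
remote_send call (as in the updated test) produces an output that no downstream op
consumes; if the tracer's graph holds the opnode only weakly, it can be garbage-collected
before backward() runs, losing the gradient hand-off. A minimal self-contained sketch of
that failure mode — OpNode, record, and graph here are hypothetical stand-ins, not
MegEngine's actual classes:

# Sketch only: hypothetical stand-ins for the tracer's weakref graph.
import gc
import weakref

class OpNode:
    def __init__(self, name):
        self.name = name
        self.backward = lambda: print("backward of", name)

graph = []  # the tracer keeps only weak references to op nodes

def record(opnode):
    graph.append(weakref.ref(opnode))

# An op whose output nothing consumes (like a bare remote_send call)
# has no strong referent once the local variable goes away.
send = OpNode("RemoteSend")
record(send)
del send
gc.collect()
print([r() for r in graph])  # [None] -> its backward is unreachable

# The fix sketched: pin such opnodes in a strong-reference cache,
# cleared when the Grad context exits (as _exit does in the patch).
remote_send_cache = []
send = OpNode("RemoteSend")
record(send)
remote_send_cache.append(send)
del send
gc.collect()
print([r() for r in graph])  # the second entry survives until the cache is cleared
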
diff --git a/src/opr-mm/impl/io_remote.cpp b/src/opr-mm/impl/io_remote.cpp
index 81dd1a00..b0285ba1 100644
--- a/src/opr-mm/impl/io_remote.cpp
+++ b/src/opr-mm/impl/io_remote.cpp
@@ -266,11 +266,20 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_recv(
         const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
         const OperatorNodeConfig& config) {
     auto&& opr = opr_.cast_final_safe<RemoteRecv>();
-    return RemoteRecv::make(opr.key(), *opr.owner_graph(),
-                            opr.group_client(), config, inputs[0]->shape(),
-                            inputs[0]->dtype())
-            .node()
-            ->owner_opr();
+    if (inputs.size() == 1) {
+        return RemoteRecv::make(opr.key(), inputs[0], *opr.owner_graph(),
+                                opr.group_client(), config, opr.shape(),
+                                opr.dtype())
+                .node()
+                ->owner_opr();
+    } else {
+        mgb_assert(inputs.size() == 0, "recv should have 1 or 0 input");
+        return RemoteRecv::make(opr.key(), *opr.owner_graph(),
+                                opr.group_client(), config, opr.shape(),
+                                opr.dtype())
+                .node()
+                ->owner_opr();
+    }
 }
 
 MGB_REG_OPR_SHALLOW_COPY(RemoteRecv, opr_shallow_copy_remote_recv);
diff --git a/src/opr-mm/include/megbrain/opr/io_remote.h b/src/opr-mm/include/megbrain/opr/io_remote.h
index 335e2648..a5bb9ec6 100644
--- a/src/opr-mm/include/megbrain/opr/io_remote.h
+++ b/src/opr-mm/include/megbrain/opr/io_remote.h
@@ -94,6 +94,9 @@ MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // {
             const OperatorNodeConfig& config, const TensorShape& shape,
             DType dtype);
 
+    const TensorShape& shape() const { return m_shape; }
+    const DType& dtype() const { return m_dtype; }
+
 private:
     const TensorShape m_shape;
     const DType m_dtype;
-- 
GitLab
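
[Reviewer note] The one-line tracing.py change (`graph.compile(*readers, *links)`) fits the
same theme on the graph side: a traced function is compiled from the readers of its
outputs, and an op like RemoteSend whose result is never read locally would otherwise risk
being pruned as dead code, so nothing would actually be sent. Passing the links as extra
endpoints keeps such side-effect ops alive. A toy reachability sketch of that pruning rule
— compile_endpoints, deps, readers, and links are hypothetical names, not MegEngine's
actual compiler:

def compile_endpoints(deps, readers, links=()):
    """Return the ops kept after dead-code elimination: everything
    reachable (via dependencies) from the requested endpoints."""
    keep, stack = set(), list(readers) + list(links)
    while stack:
        op = stack.pop()
        if op not in keep:
            keep.add(op)
            stack.extend(deps.get(op, ()))
    return keep

# op -> ops it depends on
deps = {"loss": ["linear"], "linear": ["recv"], "send": ["linear"]}

print(compile_endpoints(deps, readers=["loss"]))
# {'loss', 'linear', 'recv'} -- send is pruned, so nothing is sent
print(compile_endpoints(deps, readers=["loss"], links=["send"]))
# send (and its inputs) survive, mirroring graph.compile(*readers, *links)
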