Commit 2ad8c5e1 authored by Megvii Engine Team

fix(mge/io_remote): fix remote send/recv gradient at trace

GitOrigin-RevId: 7886efd0c124b1a6f60046c9f876e457eb683b1d
Parent f470df4f
@@ -16,7 +16,7 @@ import numpy as np
 import megengine as mge
 
-from ..ops.builtin import Elemwise, OpDef
+from ..ops.builtin import Elemwise, OpDef, RemoteSend
 from ..ops.special import Const
 from ..tensor.core import TensorBase, TensorWrapperBase, apply
 from ..tensor.function import Function
@@ -84,6 +84,9 @@ class Grad:
         # ops forms the computational graph
         self.ops = []
 
+        # save remote_send output for backward
+        self.remote_send_cache = []
+
         self._attached_tensors = weakref.WeakSet()
         self._enabled = True
@@ -144,6 +147,7 @@ class Grad:
                 o.clear()
         for i in self._attached_tensors:
             i._extra_data.pop(self, None)
+        self.remote_send_cache = []
 
     def __exit__(self, *_):
         self._exit()
@@ -398,6 +402,8 @@ def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
             return
 
     opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)
+    if isinstance(op, RemoteSend):
+        manager.remote_send_cache.append(opnode)
     opnode.backward = backward
 
     outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)]
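The cache added above works around the tracer's ownership model: a grad opnode normally stays alive because the tensors consuming the op's outputs hold references to it, but a RemoteSend has no consumer on the sending rank, so its opnode can be garbage-collected before backward runs and the gradient through the send is silently dropped. The toy sketch below (self-contained illustration, not MegEngine code; OpNode, GradGraph, and side_effect_cache are hypothetical names) shows the pattern of a weak registry of ops plus a strong cache for output-less ones:

import gc
import weakref

class OpNode:
    # Stands in for a grad-graph node that holds a backward callback.
    def __init__(self, name):
        self.name = name

class GradGraph:
    def __init__(self):
        # Nodes are normally kept alive by the output tensors that
        # reference them, so the graph itself only holds weak refs.
        self.ops = weakref.WeakSet()
        # Strong refs for ops whose outputs nobody holds on to
        # (e.g. a send whose result is unused on this rank).
        self.side_effect_cache = []

    def record(self, node, has_consumer):
        self.ops.add(node)
        if not has_consumer:
            self.side_effect_cache.append(node)

graph = GradGraph()
mul = OpNode("mul")
graph.record(mul, has_consumer=True)    # kept alive by `mul` (its output)
send = OpNode("send")
graph.record(send, has_consumer=False)  # the cache keeps it alive
del send                                # drop the last outside reference
gc.collect()
print(sorted(n.name for n in graph.ops))  # ['mul', 'send']

Resetting the cache in _exit, as the hunk above does with self.remote_send_cache = [], releases those nodes once the Grad context is finished.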
@@ -588,7 +588,7 @@ class trace:
             graph.options.graph_opt_level = self._graph_opt_level
         else:
             graph.options.graph_opt_level = 2
-        graph.compile(*readers)
+        graph.compile(*readers, *links)
 
     def _reset_exec_env(self):
         for opnode in self._need_reset_nodes:
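Passing *links to graph.compile matters because compilation only keeps operators reachable from the requested output callbacks (the readers); a RemoteSend produces no value any reader depends on, so unless it is listed explicitly it is pruned from the compiled graph and never executed. A minimal sketch of that pruning behaviour (toy code assuming a simple reachability-based scheduler; Node and compile_graph are illustrative names, not MegEngine's API):

class Node:
    def __init__(self, name, inputs=()):
        self.name = name
        self.inputs = list(inputs)

def compile_graph(*roots):
    # Collect every node reachable from the given roots, in dependency
    # order; anything unreachable is simply dropped from the schedule.
    schedule, seen = [], set()
    def visit(node):
        if node in seen:
            return
        seen.add(node)
        for i in node.inputs:
            visit(i)
        schedule.append(node.name)
    for r in roots:
        visit(r)
    return schedule

x = Node("x")
y = Node("y = m(x)", [x])
loss = Node("loss = y.mean()", [y])
send = Node("remote_send(y)", [y])  # side effect only: nobody reads it

readers = [loss]
links = [send]

print(compile_graph(*readers))          # send is pruned from the schedule
print(compile_graph(*readers, *links))  # send is kept and will execute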
@@ -111,7 +111,6 @@ def test_remote_grad():
         gm = GradManager().attach(m.parameters())
         opt = optim.SGD(m.parameters(), 1e-3, momentum=0.9)
 
-        @trace(symbolic=True)
         def train_func(x):
             with gm:
                 if rank != 0:
@@ -120,18 +119,22 @@
                     )
                 y = m(x)
                 if rank != size - 1:
-                    y = dist.functional.remote_send(y, dest_rank=rank + 1)
-                if rank == size - 1:
+                    dist.functional.remote_send(y, dest_rank=rank + 1)
+                    gm.backward()
+                else:
                     y = y.mean()
                     gm.backward(y)
-                else:
-                    gm.backward()
             opt.step().clear_grad()
 
-        for i in range(3):
-            train_func(x)
+        train_funcs = [
+            train_func,
+            trace(symbolic=False)(train_func),
+            trace(symbolic=True)(train_func),
+        ]
 
-        for param in m.parameters():
-            param.numpy()
+        for func in train_funcs:
+            for i in range(3):
+                func(x)
+                sync()
 
     worker()
@@ -266,11 +266,20 @@ cg::OperatorNodeBase* opr_shallow_copy_remote_recv(
         const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
         const OperatorNodeConfig& config) {
     auto&& opr = opr_.cast_final_safe<RemoteRecv>();
-    return RemoteRecv::make(opr.key(), *opr.owner_graph(),
-                            opr.group_client(), config, inputs[0]->shape(),
-                            inputs[0]->dtype())
-            .node()
-            ->owner_opr();
+    if (inputs.size() == 1) {
+        return RemoteRecv::make(opr.key(), inputs[0], *opr.owner_graph(),
+                                opr.group_client(), config, opr.shape(),
+                                opr.dtype())
+                .node()
+                ->owner_opr();
+    } else {
+        mgb_assert(inputs.size() == 0, "recv should have 1 or 0 input");
+        return RemoteRecv::make(opr.key(), *opr.owner_graph(),
+                                opr.group_client(), config, opr.shape(),
+                                opr.dtype())
+                .node()
+                ->owner_opr();
+    }
 }
 
 MGB_REG_OPR_SHALLOW_COPY(RemoteRecv, opr_shallow_copy_remote_recv);
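The rewritten shallow copy no longer reads shape and dtype from inputs[0]: a RemoteRecv may be copied with zero inputs, and when one input is present it is a dependency var used only for execution ordering, not a tensor whose shape describes the received value. A rough Python rendering of the same dispatch (illustrative only; RemoteRecvOp and shallow_copy_remote_recv here are hypothetical stand-ins for the C++ above):

from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class RemoteRecvOp:
    key: str
    shape: Tuple[int, ...]
    dtype: str
    dep: Optional[object] = None  # optional execution-order dependency

def shallow_copy_remote_recv(opr: RemoteRecvOp, inputs: list) -> RemoteRecvOp:
    # Shape/dtype come from the operator itself, not from inputs[0]:
    # the input, when present, is only a dependency link.
    if len(inputs) == 1:
        return RemoteRecvOp(opr.key, opr.shape, opr.dtype, dep=inputs[0])
    assert len(inputs) == 0, "recv should have 1 or 0 input"
    return RemoteRecvOp(opr.key, opr.shape, opr.dtype)

recv = RemoteRecvOp("k0", (1, 4), "float32")
print(shallow_copy_remote_recv(recv, []))       # zero-input copy
print(shallow_copy_remote_recv(recv, ["dep"]))  # copy carrying a dependency var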
@@ -94,6 +94,9 @@ MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // {
             const OperatorNodeConfig& config, const TensorShape& shape,
             DType dtype);
 
+    const TensorShape& shape() const { return m_shape; }
+    const DType& dtype() const { return m_dtype; }
+
 private:
     const TensorShape m_shape;
     const DType m_dtype;