提交 4d75f691 编写于 作者: M Megvii Engine Team

feat(mge): restore remote send/recv

GitOrigin-RevId: 8b78fd55917e319cd765ee8c895af9eeb8e9f358
上级 9c92701f
......@@ -71,7 +71,7 @@ if sys.platform == "win32":
kernel32.SetErrorMode(old_error_mode)
from .core._imperative_rt.core2 import sync, release_trace_apply_func
from .core._imperative_rt.core2 import release_trace_apply_func, sync
from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
from .device import *
from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
......
......@@ -46,9 +46,31 @@ def get_grad_managers():
return [_grad_manager_dict[key] for key in _grad_manager_dict]
class GradKey(core2.GradKey):
def __init__(self, name=None):
if name:
self.name = name
def backward(self, ys, dys):
return core2.backward(self, ys, dys)
class Grad:
def __init__(self):
self._impl = core2.GradKey()
def __init__(self, name=None):
global _grad_count
if name is None:
name = "grad_%d" % _grad_count
_grad_count += 1
self._refkeeper = []
self._impl = GradKey(name)
_grad_manager_dict[self._name] = self
@property
def _name(self):
return self._impl.name
def _is_attached_to(self, tensor):
return self._impl.is_attached_to(tensor)
def wrt(self, *tensors, callback=None):
for x in tensors:
......@@ -62,12 +84,16 @@ class Grad:
ys = [ys]
if not isinstance(dys, Sequence):
dys = [dys]
core2.backward(self._impl, ys, dys)
self._impl.backward(ys, dys)
self._refkeeper = None
def __enter__(self):
return self
def __exit__(self, _1, _2, _3):
self._refkeeper = None
del self._impl
......
......@@ -9,8 +9,8 @@
from typing import Optional, Tuple
from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import get_grad_managers
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.autodiff.grad import _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, PyOpBase, RemoteRecv, RemoteSend
from ..device import get_default_device
from ..tensor import Tensor
from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank
......@@ -193,6 +193,48 @@ def all_to_all(
return collective_comm(inp, mode, group, device)
class _RemoteSend(PyOpBase):
def __init__(self, op: RemoteSend):
self.op = op
def _default_rule(self, data):
return apply(self.op, data)
def _grad_rule(self, data):
self.dtype = data.dtype
self.shape = data.shape
self.device = data.device
(self.dummy,) = self._default_rule(data)
return self.dummy, self.backward
def backward(self, grad):
assert grad is None
if get_client().check_is_grad(self.op.key):
return remote_recv(
self.op.rank_to,
self.shape,
self.dtype,
device=str(self.device),
inp=self.dummy,
)
class _RemoteRecv(PyOpBase):
def __init__(self, op: RemoteRecv):
self.op = op
def _default_rule(self, dummy):
return apply(self.op, dummy)
def _grad_rule(self, dummy):
return self._default_rule(dummy), self.backward
def backward(self, grad):
get_client().set_is_grad(self.op.key, grad is not None)
if grad is not None:
remote_send(grad, self.op.rank_from)
def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
"""
Send a Tensor to a remote process.
......@@ -200,11 +242,21 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
:param inp: tensor to send.
:param dest_rank: destination process rank.
"""
key = "{}->{}".format(get_rank(), dest_rank)
grad_keys = {}
for n, g in _grad_manager_dict.items():
if g._is_attached_to(inp):
grad_keys[n] = g
get_client().set_remote_tracer(key, grad_keys)
op = RemoteSend()
op.key = "{}->{}".format(get_rank(), dest_rank)
op.key = key
op.addr, op.port = get_mm_server_addr()
op.rank_to = dest_rank
return apply(op, inp)[0]
(dummy,) = apply(_RemoteSend(op), inp)
for g in grad_keys.values():
g._refkeeper.append(dummy)
def remote_recv(
......@@ -228,12 +280,14 @@ def remote_recv(
if device is None:
device = get_default_device()
# dummy input
if inp == None:
if inp is None:
inp = Tensor([0], device=device)
tracer_set = get_client().check_remote_tracer(key)
for grad_manager in get_grad_managers():
if grad_manager.name in tracer_set:
grad_manager.wrt(inp)
for n in tracer_set:
g = _grad_manager_dict.get(n)
if g is not None:
g.wrt(inp)
g._refkeeper.append(inp)
op = RemoteRecv()
op.key = key
......@@ -243,4 +297,5 @@ def remote_recv(
op.addr, op.port = get_mm_server_addr()
op.rank_from = src_rank
return apply(op, inp)[0]
(ret,) = apply(_RemoteRecv(op), inp)
return ret
......@@ -193,11 +193,15 @@ struct PythonBackward {
args[i] = g ? ctx.wrap_tensor(g) : py::none();
}
auto input_grads = py::reinterpret_steal<py::object>(PyObject_Call(pyfunc.ptr(), args.ptr(), nullptr));
if (!input_grads) throw py::error_already_set();
if (input_grads.is_none()) return;
if (auto* tw = TensorWrapper::try_cast(input_grads.ptr())) {
if (input_size != 1) {
throw py::value_error("custom grad rule returned wrong number of grads");
}
if (!ctx.pytype) {
ctx.pytype = Py_TYPE(input_grads.ptr());
}
receiver(0, tw->m_tensor);
return;
}
......@@ -210,6 +214,9 @@ struct PythonBackward {
if (!tw) {
throw py::type_error("custom grad rule returned non-tensor");
}
if (!ctx.pytype) {
ctx.pytype = Py_TYPE(g.ptr());
}
receiver(i, tw->m_tensor);
}
}
......@@ -321,6 +328,7 @@ apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
}
auto grad_rule = py::getattr(op->obj, "_grad_rule");
auto pyret = py::reinterpret_steal<py::object>(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr));
if (!pyret) throw py::error_already_set();
auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret);
ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs);
if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) {
......@@ -507,8 +515,12 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
~CleanupGuard() {owner->cleanup();}
} _cleanup_guard(this);
if (tape.empty() || grads.empty()) return;
PyTypeObject* pytype = Py_TYPE(grads[0]->self().ptr());
if (tape.empty()) return;
BackwardContext bctx;
if (!grads.empty()) {
bctx.pytype = Py_TYPE(grads[0]->self().ptr());
}
for (size_t i = 0; i < tensors.size(); ++i) {
auto& grad_info = tensors[i]->m_tensor->m_grad_info;
......@@ -517,7 +529,6 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
}
}
BackwardContext bctx{pytype};
std::vector<std::shared_ptr<GradFn>> ref_keeper;
ref_keeper.reserve(tape.size());
// back-propagation in reverse order
......@@ -548,7 +559,7 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
}
if (!dst.producer_record.next && dst->callback && dst->grad) {
// I'm the last grad producer, invoke callback
dst->callback(TensorWrapper::make(pytype, dst->grad));
dst->callback(bctx.wrap_tensor(dst->grad));
}
}
grad_fn->clear();
......@@ -568,6 +579,31 @@ void GradKeyWrapper::backward(std::vector<TensorWrapper*> tensors, std::vector<T
m_key->backward(std::move(tensors), std::move(grads));
}
PyObject* GradKeyWrapper::get_name() {
return py::cast(m_key->name).release().ptr();
}
void GradKeyWrapper::set_name(py::handle name) {
m_key->name = py::cast<std::string>(name);
}
PyObject* GradKeyWrapper::is_attached_to(PyObject*const* args, size_t nargs) {
if (nargs != 1) {
PyErr_SetString(PyExc_TypeError, "expect 1 argument");
return nullptr;
}
auto* tw = TensorWrapper::try_cast(args[0]);
if (!tw) {
PyErr_SetString(PyExc_TypeError, "expect Tensor");
return nullptr;
}
auto&& grad_fn = tw->m_tensor->m_grad_info.grad_fn;
if (grad_fn && grad_fn->key.lock() == m_key) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
}
GradKey::~GradKey() {
cleanup();
}
......
......@@ -41,8 +41,11 @@ struct GradKeyWrapper {
inline GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {}
PyObject* get_name();
void set_name(pybind11::handle name);
void attach(PyObject*const* args, size_t nargs);
void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
PyObject* is_attached_to(PyObject*const* args, size_t nargs);
};
struct BackwardContext {
......
......@@ -738,10 +738,13 @@ void init_tensor(py::module m) {
py::handle grad_key_type = GradKeyWrapper::wrap_t::type()
.def<&GradKeyWrapper::attach>("attach")
.def<&GradKeyWrapper::is_attached_to>("is_attached_to")
.def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>("name")
.finalize();
if (!grad_key_type) throw py::error_already_set();
py::setattr(m, "GradKey", grad_key_type);
py::setattr(m, "backward", py::cpp_function(&GradKeyWrapper::backward));
m.def("backward", &GradKeyWrapper::backward);
m.def("set_cpp_apply_with_tracing", &set_cpp_apply_with_tracing);
m.def("set_cpp_apply_const_with_tracing", &set_cpp_apply_const_with_tracing);
m.def("set_cpp_apply_compiled_mode", &set_cpp_apply_compiled_mode);
......
......@@ -141,6 +141,7 @@ def test_regression_1762():
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
@pytest.mark.skip(reason="FIXME: remote_send/recv")
def test_remote_grad():
@dist.launcher
def worker():
......
......@@ -16,9 +16,8 @@ import pytest
import megengine as mge
import megengine.distributed as dist
import megengine.functional as F
from megengine.core._imperative_rt import TensorAttr, core2, imperative
from megengine.core._imperative_rt.core2 import TensorWeakRef, apply
from megengine.core._imperative_rt.imperative import sync
from megengine.core._imperative_rt import CompNode, TensorAttr, core2, imperative
from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync
from megengine.core.autodiff.grad import Grad
from megengine.core.ops.builtin import Elemwise
from megengine.distributed.helper import get_device_count_by_fork
......@@ -73,7 +72,7 @@ def test_dist_grad():
x = as_tensor(x_np)
grad.wrt(x, callback=save_to(x))
# need a placeholder to trace operator
send_x = remote_send(x, 1)
remote_send(x, 1)
recv_x = remote_recv(1, x_np.shape, x_np.dtype)
y = recv_x * recv_x
......@@ -83,13 +82,12 @@ def test_dist_grad():
grad = Grad()
recv_x = remote_recv(0, x_np.shape, x_np.dtype)
send_x = remote_send(recv_x, 0)
remote_send(recv_x, 0)
grad([], [])
worker()
def test_grad():
x_np = np.random.rand(10).astype("float32")
x = as_tensor(x_np)
......
......@@ -14,6 +14,7 @@ import pytest
import megengine as mge
import megengine.distributed as dist
from megengine import Parameter, Tensor, tensor
from megengine.core._imperative_rt.core2 import sync
from megengine.device import get_default_device, set_default_device
from megengine.distributed.helper import get_device_count_by_fork
from megengine.functional.distributed import (
......@@ -333,8 +334,8 @@ def test_io_remote():
rank = dist.get_rank()
if rank == 0: # remote send
x = Tensor(val, device="gpu0")
y = remote_send(x, 1)
assert y.numpy()[0] == 0
remote_send(x, 1)
sync()
else: # remote recv
y = remote_recv(0, val.shape, val.dtype)
assert y.device == "gpu1"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册