提交 4d75f691 编写于 作者: M Megvii Engine Team

feat(mge): restore remote send/recv

GitOrigin-RevId: 8b78fd55917e319cd765ee8c895af9eeb8e9f358
上级 9c92701f
...@@ -71,7 +71,7 @@ if sys.platform == "win32": ...@@ -71,7 +71,7 @@ if sys.platform == "win32":
kernel32.SetErrorMode(old_error_mode) kernel32.SetErrorMode(old_error_mode)
from .core._imperative_rt.core2 import sync, release_trace_apply_func from .core._imperative_rt.core2 import release_trace_apply_func, sync
from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func from .core._imperative_rt.utils import _set_fork_exec_path_for_timed_func
from .device import * from .device import *
from .logger import enable_debug_log, get_logger, set_log_file, set_log_level from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
......
...@@ -46,9 +46,31 @@ def get_grad_managers(): ...@@ -46,9 +46,31 @@ def get_grad_managers():
return [_grad_manager_dict[key] for key in _grad_manager_dict] return [_grad_manager_dict[key] for key in _grad_manager_dict]
class GradKey(core2.GradKey):
def __init__(self, name=None):
if name:
self.name = name
def backward(self, ys, dys):
return core2.backward(self, ys, dys)
class Grad: class Grad:
def __init__(self): def __init__(self, name=None):
self._impl = core2.GradKey() global _grad_count
if name is None:
name = "grad_%d" % _grad_count
_grad_count += 1
self._refkeeper = []
self._impl = GradKey(name)
_grad_manager_dict[self._name] = self
@property
def _name(self):
return self._impl.name
def _is_attached_to(self, tensor):
return self._impl.is_attached_to(tensor)
def wrt(self, *tensors, callback=None): def wrt(self, *tensors, callback=None):
for x in tensors: for x in tensors:
...@@ -62,12 +84,16 @@ class Grad: ...@@ -62,12 +84,16 @@ class Grad:
ys = [ys] ys = [ys]
if not isinstance(dys, Sequence): if not isinstance(dys, Sequence):
dys = [dys] dys = [dys]
core2.backward(self._impl, ys, dys)
self._impl.backward(ys, dys)
self._refkeeper = None
def __enter__(self): def __enter__(self):
return self return self
def __exit__(self, _1, _2, _3): def __exit__(self, _1, _2, _3):
self._refkeeper = None
del self._impl del self._impl
......
...@@ -9,8 +9,8 @@ ...@@ -9,8 +9,8 @@
from typing import Optional, Tuple from typing import Optional, Tuple
from ..core._imperative_rt.core2 import apply from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import get_grad_managers from ..core.autodiff.grad import _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend from ..core.ops.builtin import CollectiveComm, Copy, PyOpBase, RemoteRecv, RemoteSend
from ..device import get_default_device from ..device import get_default_device
from ..tensor import Tensor from ..tensor import Tensor
from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank
...@@ -193,6 +193,48 @@ def all_to_all( ...@@ -193,6 +193,48 @@ def all_to_all(
return collective_comm(inp, mode, group, device) return collective_comm(inp, mode, group, device)
class _RemoteSend(PyOpBase):
def __init__(self, op: RemoteSend):
self.op = op
def _default_rule(self, data):
return apply(self.op, data)
def _grad_rule(self, data):
self.dtype = data.dtype
self.shape = data.shape
self.device = data.device
(self.dummy,) = self._default_rule(data)
return self.dummy, self.backward
def backward(self, grad):
assert grad is None
if get_client().check_is_grad(self.op.key):
return remote_recv(
self.op.rank_to,
self.shape,
self.dtype,
device=str(self.device),
inp=self.dummy,
)
class _RemoteRecv(PyOpBase):
def __init__(self, op: RemoteRecv):
self.op = op
def _default_rule(self, dummy):
return apply(self.op, dummy)
def _grad_rule(self, dummy):
return self._default_rule(dummy), self.backward
def backward(self, grad):
get_client().set_is_grad(self.op.key, grad is not None)
if grad is not None:
remote_send(grad, self.op.rank_from)
def remote_send(inp: Tensor, dest_rank: int) -> Tensor: def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
""" """
Send a Tensor to a remote process. Send a Tensor to a remote process.
...@@ -200,11 +242,21 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor: ...@@ -200,11 +242,21 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
:param inp: tensor to send. :param inp: tensor to send.
:param dest_rank: destination process rank. :param dest_rank: destination process rank.
""" """
key = "{}->{}".format(get_rank(), dest_rank)
grad_keys = {}
for n, g in _grad_manager_dict.items():
if g._is_attached_to(inp):
grad_keys[n] = g
get_client().set_remote_tracer(key, grad_keys)
op = RemoteSend() op = RemoteSend()
op.key = "{}->{}".format(get_rank(), dest_rank) op.key = key
op.addr, op.port = get_mm_server_addr() op.addr, op.port = get_mm_server_addr()
op.rank_to = dest_rank op.rank_to = dest_rank
return apply(op, inp)[0] (dummy,) = apply(_RemoteSend(op), inp)
for g in grad_keys.values():
g._refkeeper.append(dummy)
def remote_recv( def remote_recv(
...@@ -228,12 +280,14 @@ def remote_recv( ...@@ -228,12 +280,14 @@ def remote_recv(
if device is None: if device is None:
device = get_default_device() device = get_default_device()
# dummy input # dummy input
if inp == None: if inp is None:
inp = Tensor([0], device=device) inp = Tensor([0], device=device)
tracer_set = get_client().check_remote_tracer(key) tracer_set = get_client().check_remote_tracer(key)
for grad_manager in get_grad_managers(): for n in tracer_set:
if grad_manager.name in tracer_set: g = _grad_manager_dict.get(n)
grad_manager.wrt(inp) if g is not None:
g.wrt(inp)
g._refkeeper.append(inp)
op = RemoteRecv() op = RemoteRecv()
op.key = key op.key = key
...@@ -243,4 +297,5 @@ def remote_recv( ...@@ -243,4 +297,5 @@ def remote_recv(
op.addr, op.port = get_mm_server_addr() op.addr, op.port = get_mm_server_addr()
op.rank_from = src_rank op.rank_from = src_rank
return apply(op, inp)[0] (ret,) = apply(_RemoteRecv(op), inp)
return ret
...@@ -193,11 +193,15 @@ struct PythonBackward { ...@@ -193,11 +193,15 @@ struct PythonBackward {
args[i] = g ? ctx.wrap_tensor(g) : py::none(); args[i] = g ? ctx.wrap_tensor(g) : py::none();
} }
auto input_grads = py::reinterpret_steal<py::object>(PyObject_Call(pyfunc.ptr(), args.ptr(), nullptr)); auto input_grads = py::reinterpret_steal<py::object>(PyObject_Call(pyfunc.ptr(), args.ptr(), nullptr));
if (!input_grads) throw py::error_already_set();
if (input_grads.is_none()) return; if (input_grads.is_none()) return;
if (auto* tw = TensorWrapper::try_cast(input_grads.ptr())) { if (auto* tw = TensorWrapper::try_cast(input_grads.ptr())) {
if (input_size != 1) { if (input_size != 1) {
throw py::value_error("custom grad rule returned wrong number of grads"); throw py::value_error("custom grad rule returned wrong number of grads");
} }
if (!ctx.pytype) {
ctx.pytype = Py_TYPE(input_grads.ptr());
}
receiver(0, tw->m_tensor); receiver(0, tw->m_tensor);
return; return;
} }
...@@ -210,6 +214,9 @@ struct PythonBackward { ...@@ -210,6 +214,9 @@ struct PythonBackward {
if (!tw) { if (!tw) {
throw py::type_error("custom grad rule returned non-tensor"); throw py::type_error("custom grad rule returned non-tensor");
} }
if (!ctx.pytype) {
ctx.pytype = Py_TYPE(g.ptr());
}
receiver(i, tw->m_tensor); receiver(i, tw->m_tensor);
} }
} }
...@@ -321,6 +328,7 @@ apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) { ...@@ -321,6 +328,7 @@ apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
} }
auto grad_rule = py::getattr(op->obj, "_grad_rule"); auto grad_rule = py::getattr(op->obj, "_grad_rule");
auto pyret = py::reinterpret_steal<py::object>(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr)); auto pyret = py::reinterpret_steal<py::object>(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr));
if (!pyret) throw py::error_already_set();
auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret); auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret);
ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs); ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs);
if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) { if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) {
...@@ -507,8 +515,12 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr ...@@ -507,8 +515,12 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
~CleanupGuard() {owner->cleanup();} ~CleanupGuard() {owner->cleanup();}
} _cleanup_guard(this); } _cleanup_guard(this);
if (tape.empty() || grads.empty()) return; if (tape.empty()) return;
PyTypeObject* pytype = Py_TYPE(grads[0]->self().ptr());
BackwardContext bctx;
if (!grads.empty()) {
bctx.pytype = Py_TYPE(grads[0]->self().ptr());
}
for (size_t i = 0; i < tensors.size(); ++i) { for (size_t i = 0; i < tensors.size(); ++i) {
auto& grad_info = tensors[i]->m_tensor->m_grad_info; auto& grad_info = tensors[i]->m_tensor->m_grad_info;
...@@ -517,7 +529,6 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr ...@@ -517,7 +529,6 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
} }
} }
BackwardContext bctx{pytype};
std::vector<std::shared_ptr<GradFn>> ref_keeper; std::vector<std::shared_ptr<GradFn>> ref_keeper;
ref_keeper.reserve(tape.size()); ref_keeper.reserve(tape.size());
// back-propagation in reverse order // back-propagation in reverse order
...@@ -548,7 +559,7 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr ...@@ -548,7 +559,7 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
} }
if (!dst.producer_record.next && dst->callback && dst->grad) { if (!dst.producer_record.next && dst->callback && dst->grad) {
// I'm the last grad producer, invoke callback // I'm the last grad producer, invoke callback
dst->callback(TensorWrapper::make(pytype, dst->grad)); dst->callback(bctx.wrap_tensor(dst->grad));
} }
} }
grad_fn->clear(); grad_fn->clear();
...@@ -568,6 +579,31 @@ void GradKeyWrapper::backward(std::vector<TensorWrapper*> tensors, std::vector<T ...@@ -568,6 +579,31 @@ void GradKeyWrapper::backward(std::vector<TensorWrapper*> tensors, std::vector<T
m_key->backward(std::move(tensors), std::move(grads)); m_key->backward(std::move(tensors), std::move(grads));
} }
PyObject* GradKeyWrapper::get_name() {
return py::cast(m_key->name).release().ptr();
}
void GradKeyWrapper::set_name(py::handle name) {
m_key->name = py::cast<std::string>(name);
}
PyObject* GradKeyWrapper::is_attached_to(PyObject*const* args, size_t nargs) {
if (nargs != 1) {
PyErr_SetString(PyExc_TypeError, "expect 1 argument");
return nullptr;
}
auto* tw = TensorWrapper::try_cast(args[0]);
if (!tw) {
PyErr_SetString(PyExc_TypeError, "expect Tensor");
return nullptr;
}
auto&& grad_fn = tw->m_tensor->m_grad_info.grad_fn;
if (grad_fn && grad_fn->key.lock() == m_key) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
}
GradKey::~GradKey() { GradKey::~GradKey() {
cleanup(); cleanup();
} }
......
...@@ -41,8 +41,11 @@ struct GradKeyWrapper { ...@@ -41,8 +41,11 @@ struct GradKeyWrapper {
inline GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {} inline GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {}
PyObject* get_name();
void set_name(pybind11::handle name);
void attach(PyObject*const* args, size_t nargs); void attach(PyObject*const* args, size_t nargs);
void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>); void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
PyObject* is_attached_to(PyObject*const* args, size_t nargs);
}; };
struct BackwardContext { struct BackwardContext {
......
...@@ -733,15 +733,18 @@ void init_tensor(py::module m) { ...@@ -733,15 +733,18 @@ void init_tensor(py::module m) {
py_task_q.wait_all_task_finish(); py_task_q.wait_all_task_finish();
}, },
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
m.def("release_trace_apply_func", &release_trace_apply_func); m.def("release_trace_apply_func", &release_trace_apply_func);
py::handle grad_key_type = GradKeyWrapper::wrap_t::type() py::handle grad_key_type = GradKeyWrapper::wrap_t::type()
.def<&GradKeyWrapper::attach>("attach") .def<&GradKeyWrapper::attach>("attach")
.def<&GradKeyWrapper::is_attached_to>("is_attached_to")
.def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>("name")
.finalize(); .finalize();
if (!grad_key_type) throw py::error_already_set(); if (!grad_key_type) throw py::error_already_set();
py::setattr(m, "GradKey", grad_key_type); py::setattr(m, "GradKey", grad_key_type);
py::setattr(m, "backward", py::cpp_function(&GradKeyWrapper::backward)); m.def("backward", &GradKeyWrapper::backward);
m.def("set_cpp_apply_with_tracing", &set_cpp_apply_with_tracing); m.def("set_cpp_apply_with_tracing", &set_cpp_apply_with_tracing);
m.def("set_cpp_apply_const_with_tracing", &set_cpp_apply_const_with_tracing); m.def("set_cpp_apply_const_with_tracing", &set_cpp_apply_const_with_tracing);
m.def("set_cpp_apply_compiled_mode", &set_cpp_apply_compiled_mode); m.def("set_cpp_apply_compiled_mode", &set_cpp_apply_compiled_mode);
......
...@@ -141,6 +141,7 @@ def test_regression_1762(): ...@@ -141,6 +141,7 @@ def test_regression_1762():
) )
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device") @pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed @pytest.mark.isolated_distributed
@pytest.mark.skip(reason="FIXME: remote_send/recv")
def test_remote_grad(): def test_remote_grad():
@dist.launcher @dist.launcher
def worker(): def worker():
......
...@@ -16,9 +16,8 @@ import pytest ...@@ -16,9 +16,8 @@ import pytest
import megengine as mge import megengine as mge
import megengine.distributed as dist import megengine.distributed as dist
import megengine.functional as F import megengine.functional as F
from megengine.core._imperative_rt import TensorAttr, core2, imperative from megengine.core._imperative_rt import CompNode, TensorAttr, core2, imperative
from megengine.core._imperative_rt.core2 import TensorWeakRef, apply from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync
from megengine.core._imperative_rt.imperative import sync
from megengine.core.autodiff.grad import Grad from megengine.core.autodiff.grad import Grad
from megengine.core.ops.builtin import Elemwise from megengine.core.ops.builtin import Elemwise
from megengine.distributed.helper import get_device_count_by_fork from megengine.distributed.helper import get_device_count_by_fork
...@@ -73,7 +72,7 @@ def test_dist_grad(): ...@@ -73,7 +72,7 @@ def test_dist_grad():
x = as_tensor(x_np) x = as_tensor(x_np)
grad.wrt(x, callback=save_to(x)) grad.wrt(x, callback=save_to(x))
# need a placeholder to trace operator # need a placeholder to trace operator
send_x = remote_send(x, 1) remote_send(x, 1)
recv_x = remote_recv(1, x_np.shape, x_np.dtype) recv_x = remote_recv(1, x_np.shape, x_np.dtype)
y = recv_x * recv_x y = recv_x * recv_x
...@@ -83,13 +82,12 @@ def test_dist_grad(): ...@@ -83,13 +82,12 @@ def test_dist_grad():
grad = Grad() grad = Grad()
recv_x = remote_recv(0, x_np.shape, x_np.dtype) recv_x = remote_recv(0, x_np.shape, x_np.dtype)
send_x = remote_send(recv_x, 0) remote_send(recv_x, 0)
grad([], []) grad([], [])
worker() worker()
def test_grad(): def test_grad():
x_np = np.random.rand(10).astype("float32") x_np = np.random.rand(10).astype("float32")
x = as_tensor(x_np) x = as_tensor(x_np)
......
...@@ -14,6 +14,7 @@ import pytest ...@@ -14,6 +14,7 @@ import pytest
import megengine as mge import megengine as mge
import megengine.distributed as dist import megengine.distributed as dist
from megengine import Parameter, Tensor, tensor from megengine import Parameter, Tensor, tensor
from megengine.core._imperative_rt.core2 import sync
from megengine.device import get_default_device, set_default_device from megengine.device import get_default_device, set_default_device
from megengine.distributed.helper import get_device_count_by_fork from megengine.distributed.helper import get_device_count_by_fork
from megengine.functional.distributed import ( from megengine.functional.distributed import (
...@@ -333,8 +334,8 @@ def test_io_remote(): ...@@ -333,8 +334,8 @@ def test_io_remote():
rank = dist.get_rank() rank = dist.get_rank()
if rank == 0: # remote send if rank == 0: # remote send
x = Tensor(val, device="gpu0") x = Tensor(val, device="gpu0")
y = remote_send(x, 1) remote_send(x, 1)
assert y.numpy()[0] == 0 sync()
else: # remote recv else: # remote recv
y = remote_recv(0, val.shape, val.dtype) y = remote_recv(0, val.shape, val.dtype)
assert y.device == "gpu1" assert y.device == "gpu1"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册