Commit 094601e8 authored by Megvii Engine Team

feat(mge/distributed): allow remote grad by using grad manager

GitOrigin-RevId: a890c206a51be021735f9d2693d0b58b8d4fe7b9
Parent f7b2bdae
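
Before the diff, a minimal two-rank sketch of the training pattern this commit enables. This is illustrative only, not code from the commit, and assumes at least two GPUs; worker, w, and loss are made-up names. The sender rank has no local loss, so it calls gm.backward() with no arguments, and its parameter gradient is delivered by the backward rule of RemoteSend changed below.

import numpy as np
import megengine as mge
import megengine.distributed as dist
from megengine.autodiff import GradManager
from megengine.core._imperative_rt.imperative import sync

@dist.launcher
def worker():
    rank = dist.get_rank()
    # each rank holds one "stage" parameter of a two-stage model
    w = mge.Parameter(np.ones((1, 4), dtype=np.float32))
    gm = GradManager().attach([w])
    with gm:
        if rank == 0:
            y = w * 2
            y = dist.functional.remote_send(y, dest_rank=1)
            gm.backward()  # no local loss: w.grad is received from rank 1
        else:
            x = dist.functional.remote_recv(0, shape=(1, 4), dtype=np.float32)
            loss = (x * w).sum()
            gm.backward(loss)  # the loss only exists on the last rank
    # as in the test below: sync before exiting, since a send may be the last job
    sync()

worker()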
@@ -127,7 +127,7 @@ class GradManager:
        self._after_backward_callback.append(callback)
        return self

-    def backward(self, ys, dys=None):
+    def backward(self, ys=None, dys=None):
        r"""
        Performs back-propagation and computes gradients.
@@ -146,6 +146,8 @@
                "call a method that clears the history?"
            )
        assert self._grad is not None
+        if ys is None:
+            ys = []
        if not isinstance(ys, (tuple, list)):
            ys = [ys]
        if dys is None:
......
@@ -14,6 +14,8 @@ import weakref
import numpy as np

+import megengine as mge
+
from ..ops.builtin import Elemwise, OpDef
from ..ops.special import Const
from ..tensor.core import TensorBase, TensorWrapperBase, apply
@@ -167,6 +169,8 @@ class Grad:
            for i in dys:
                if isinstance(i, TensorWrapperBase):
                    return type(i)
+            # use Tensor as default wrapper
+            return mge.Tensor

        Wrapper = check_wrapper()
......
@@ -59,7 +59,11 @@ def _(op: RemoteSend, inputs, outputs, input_requires_grad):
    def backward(*args):
        return [
            remote_recv(
-                op.rank_to, inputs[0].shape, inputs[0].dtype, str(inputs[0].device)
+                op.rank_to,
+                inputs[0].shape,
+                inputs[0].dtype,
+                device=str(inputs[0].device),
+                inp=inputs[0],
            )
        ]
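
Review note: the gradient of a send is a receive from the same peer, so RemoteSend's backward issues a remote_recv from op.rank_to. The new inp=inputs[0] argument hands the forward-pass tensor to remote_recv so that the received gradient is built with the same wrapper type and is traced by the GradManager that recorded the send; this is what makes gm.backward() with no arguments work on sender ranks. A usage sketch follows the remote_recv hunk below.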
@@ -275,7 +279,11 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
def remote_recv(
-    src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None
+    src_rank: int,
+    shape: Tuple[int],
+    dtype: type,
+    device: Optional[str] = None,
+    inp=None,
) -> Tensor:
    """
    Receive a Tensor from a remote process.
@@ -284,13 +292,15 @@
    :param shape: the shape of the tensor to receive.
    :param dtype: the data type of the tensor to receive.
    :param device: the device to place the received tensor.
+    :param inp: dummy input used to determine the type of the received tensor.
    """
    key = "{}->{}".format(src_rank, get_rank())
    if device is None:
        device = get_default_device()
-    # dummpy input
-    inp = tensor([0])
+    # dummy input
+    if inp is None:
+        inp = tensor([0])
    tracer_set = get_client().check_remote_tracer(key)
    for grad_manager in get_grad_managers():
        if grad_manager.name in tracer_set:
......
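
To illustrate the new parameter, a hedged sketch of a call that mirrors the RemoteSend backward above; grad_y, y_sent, and dest_rank are illustrative names, not part of the API:

# Sketch: receive the gradient for a tensor y_sent previously sent to dest_rank.
# Passing y_sent as `inp` ties the received tensor into the autodiff graph
# recorded on this rank; only its type and tracing state are used.
grad_y = remote_recv(
    dest_rank,
    y_sent.shape,
    y_sent.dtype,
    device=str(y_sent.device),
    inp=y_sent,
)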
@@ -5,12 +5,19 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import platform
+
import numpy as np
import pytest

import megengine as mge
+import megengine.distributed as dist
import megengine.functional as F
+import megengine.module as M
+import megengine.optimizer as optim
from megengine.autodiff import GradManager
+from megengine.core._imperative_rt.imperative import sync
+from megengine.distributed.helper import get_device_count_by_fork


def test_basic():
@@ -48,3 +55,47 @@ def test_attach_in_with_block():
        c = b + 1
        gm.backward(c)
        assert int(b.grad.numpy()) == 1
+
+
+@pytest.mark.skipif(
+    platform.system() == "Darwin", reason="GPU mode is not implemented on macOS yet"
+)
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="Windows builds disable MGB_ENABLE_OPR_MM"
+)
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu devices")
+@pytest.mark.isolated_distributed
+def test_remote_grad():
+    @dist.launcher
+    def worker():
+        rank = dist.get_rank()
+        size = dist.get_world_size()
+        x = mge.tensor(np.random.randn(1, rank * 2 + 2), dtype=np.float32)
+        m = M.Linear(rank * 2 + 2, rank * 2 + 4)
+        gm = GradManager().attach(m.parameters())
+        opt = optim.SGD(m.parameters(), 1e-3, momentum=0.9)
+
+        def train_func(x):
+            if rank != 0:
+                x = dist.functional.remote_recv(
+                    rank - 1, shape=(1, rank * 2 + 2), dtype=np.float32
+                )
+            print(rank, "x", x)
+            y = m(x)
+            print(rank, "y", y)
+            if rank != size - 1:
+                y = dist.functional.remote_send(y, dest_rank=rank + 1)
+            return y
+
+        with gm:
+            y = train_func(x)
+            if rank == size - 1:
+                y = y.mean()
+                gm.backward(y)
+            else:
+                gm.backward()
+        opt.step().clear_grad()
+        # sync because send is the last job
+        sync()
+
+    worker()