Commit 094601e8 authored by Megvii Engine Team

feat(mge/distributed): allow remote grad by using grad manager

GitOrigin-RevId: a890c206a51be021735f9d2693d0b58b8d4fe7b9
Parent f7b2bdae
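
In effect, a GradManager can now drive back-propagation across process boundaries: remote_send / remote_recv pairs are registered with the manager, and a rank whose outputs were all sent to another process calls gm.backward() with no arguments and receives its gradients from the downstream rank. A rough per-rank sketch of the pattern, modeled on the test added at the end of this diff (m, x, gm, opt, rank, and size are illustrative names taken from that test):

    with gm:
        y = m(x)                                    # forward on this rank
        if rank != size - 1:
            y = dist.functional.remote_send(y, dest_rank=rank + 1)
            gm.backward()                           # no local loss; gradient arrives via the recorded remote_recv
        else:
            gm.backward(y.mean())                   # the last rank owns the loss
        opt.step().clear_grad()
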
@@ -127,7 +127,7 @@ class GradManager:
         self._after_backward_callback.append(callback)
         return self
 
-    def backward(self, ys, dys=None):
+    def backward(self, ys=None, dys=None):
         r"""
         Performs back-propagation and computes gradients.
 
@@ -146,6 +146,8 @@ class GradManager:
                 "call a method that clears the history?"
             )
         assert self._grad is not None
+        if ys is None:
+            ys = []
         if not isinstance(ys, (tuple, list)):
             ys = [ys]
         if dys is None:
...
@@ -14,6 +14,8 @@ import weakref
 
 import numpy as np
 
+import megengine as mge
+
 from ..ops.builtin import Elemwise, OpDef
 from ..ops.special import Const
 from ..tensor.core import TensorBase, TensorWrapperBase, apply
@@ -167,6 +169,8 @@ class Grad:
            for i in dys:
                if isinstance(i, TensorWrapperBase):
                    return type(i)
+            # use Tensor as default wrapper
+            return mge.Tensor
 
        Wrapper = check_wrapper()
...
@@ -59,7 +59,11 @@ def _(op: RemoteSend, inputs, outputs, input_requires_grad):
     def backward(*args):
         return [
             remote_recv(
-                op.rank_to, inputs[0].shape, inputs[0].dtype, str(inputs[0].device)
+                op.rank_to,
+                inputs[0].shape,
+                inputs[0].dtype,
+                device=str(inputs[0].device),
+                inp=inputs[0],
             )
         ]
@@ -275,7 +279,11 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
 
 def remote_recv(
-    src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None
+    src_rank: int,
+    shape: Tuple[int],
+    dtype: type,
+    device: Optional[str] = None,
+    inp=None,
 ) -> Tensor:
     """
     Receive a Tensor from a remote process.
@@ -284,13 +292,15 @@
     :param shape: the shape of the tensor to receive.
     :param dtype: the data type of the tensor to receive.
     :param device: the device to place the received tensor.
+    :param inp: dummy input to determine the received tensor type
     """
     key = "{}->{}".format(src_rank, get_rank())
     if device is None:
         device = get_default_device()
-    # dummpy input
-    inp = tensor([0])
+    # dummy input
+    if inp == None:
+        inp = tensor([0])
     tracer_set = get_client().check_remote_tracer(key)
     for grad_manager in get_grad_managers():
         if grad_manager.name in tracer_set:
...
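
The new inp argument exists mainly for the RemoteSend backward rule above, which passes the tensor that was originally sent so the received gradient can take on the same tensor type; ordinary callers can omit it, in which case a dummy tensor([0]) is created internally. An illustrative receive call in the style of the test below (the shape and dtype are placeholder values):

    # sketch: rank r pulls the activation produced on rank r - 1; inp is omitted,
    # so remote_recv falls back to the internal dummy tensor([0])
    x = dist.functional.remote_recv(rank - 1, shape=(1, 4), dtype=np.float32)
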
@@ -5,12 +5,19 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import platform
+
 import numpy as np
 import pytest
 
 import megengine as mge
+import megengine.distributed as dist
 import megengine.functional as F
+import megengine.module as M
+import megengine.optimizer as optim
 from megengine.autodiff import GradManager
+from megengine.core._imperative_rt.imperative import sync
+from megengine.distributed.helper import get_device_count_by_fork
 
 
 def test_basic():
@@ -48,3 +55,47 @@ def test_attach_in_with_block():
         c = b + 1
         gm.backward(c)
     assert int(b.grad.numpy()) == 1
+
+
+@pytest.mark.skipif(
+    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
+)
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
+)
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
+@pytest.mark.isolated_distributed
+def test_remote_grad():
+    @dist.launcher
+    def worker():
+        rank = dist.get_rank()
+        size = dist.get_world_size()
+        x = mge.tensor(np.random.randn(1, rank * 2 + 2), dtype=np.float32)
+        m = M.Linear(rank * 2 + 2, rank * 2 + 4)
+        gm = GradManager().attach(m.parameters())
+        opt = optim.SGD(m.parameters(), 1e-3, momentum=0.9)
+
+        def train_func(x):
+            if rank != 0:
+                x = dist.functional.remote_recv(
+                    rank - 1, shape=(1, rank * 2 + 2), dtype=np.float32
+                )
+            print(rank, "x", x)
+            y = m(x)
+            print(rank, "y", y)
+            if rank != size - 1:
+                y = dist.functional.remote_send(y, dest_rank=rank + 1)
+            return y
+
+        with gm:
+            y = train_func(x)
+            if rank == size - 1:
+                y = y.mean()
+                gm.backward(y)
+            else:
+                gm.backward()
+            opt.step().clear_grad()
+
+        # sync because send is the last job
+        sync()
+
+    worker()