diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py
index b37b37fc3cf3462941210db955d315cd8e75b35c..db03f23544c50a1abd9de2628c27b49f56a84e58 100644
--- a/imperative/python/megengine/autodiff/grad_manager.py
+++ b/imperative/python/megengine/autodiff/grad_manager.py
@@ -127,7 +127,7 @@ class GradManager:
         self._after_backward_callback.append(callback)
         return self
 
-    def backward(self, ys, dys=None):
+    def backward(self, ys=None, dys=None):
         r"""
         Performs back-propagation and computes gradients.
 
@@ -146,6 +146,8 @@ class GradManager:
                 "call a method that clears the history?"
             )
         assert self._grad is not None
+        if ys is None:
+            ys = []
         if not isinstance(ys, (tuple, list)):
             ys = [ys]
         if dys is None:
diff --git a/imperative/python/megengine/core/autodiff/grad.py b/imperative/python/megengine/core/autodiff/grad.py
index 12a46b97690b42033b954f8a3cf2108b314b0ba8..81ea827ea803a6dfaea7f804006f4f3e14dcbde8 100644
--- a/imperative/python/megengine/core/autodiff/grad.py
+++ b/imperative/python/megengine/core/autodiff/grad.py
@@ -14,6 +14,8 @@ import weakref
 
 import numpy as np
 
+import megengine as mge
+
 from ..ops.builtin import Elemwise, OpDef
 from ..ops.special import Const
 from ..tensor.core import TensorBase, TensorWrapperBase, apply
@@ -167,6 +169,8 @@ class Grad:
             for i in dys:
                 if isinstance(i, TensorWrapperBase):
                     return type(i)
+            # use Tensor as default wrapper
+            return mge.Tensor
 
         Wrapper = check_wrapper()
 
diff --git a/imperative/python/megengine/distributed/functional.py b/imperative/python/megengine/distributed/functional.py
index b4d2fe30bd28a6515136e52553407f34e56b73ba..5a9ce213c5f1a2429bd86935aaf541777fb7426b 100644
--- a/imperative/python/megengine/distributed/functional.py
+++ b/imperative/python/megengine/distributed/functional.py
@@ -59,7 +59,11 @@ def _(op: RemoteSend, inputs, outputs, input_requires_grad):
     def backward(*args):
         return [
             remote_recv(
-                op.rank_to, inputs[0].shape, inputs[0].dtype, str(inputs[0].device)
+                op.rank_to,
+                inputs[0].shape,
+                inputs[0].dtype,
+                device=str(inputs[0].device),
+                inp=inputs[0],
             )
         ]
 
@@ -275,7 +279,11 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
 
 
 def remote_recv(
-    src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None
+    src_rank: int,
+    shape: Tuple[int],
+    dtype: type,
+    device: Optional[str] = None,
+    inp=None,
 ) -> Tensor:
     """
     Receive a Tensor from a remote process.
@@ -284,13 +292,15 @@ def remote_recv(
     :param shape: the shape of the tensor to receive.
     :param dtype: the data type of the tensor to receive.
     :param device: the device to place the received tensor.
+    :param inp: dummy input to determine the received tensor type
     """
     key = "{}->{}".format(src_rank, get_rank())
 
     if device is None:
         device = get_default_device()
-    # dummpy input
-    inp = tensor([0])
+    # dummy input
+    if inp is None:
+        inp = tensor([0])
     tracer_set = get_client().check_remote_tracer(key)
     for grad_manager in get_grad_managers():
         if grad_manager.name in tracer_set:
diff --git a/imperative/python/test/unit/autodiff/test_grad_manger.py b/imperative/python/test/unit/autodiff/test_grad_manger.py
index f54bd02a1f48d2e66ed7ce70b2e4d5e6b1e19150..d285e573a1b29b96e18bb8d36c071e308e95b10c 100644
--- a/imperative/python/test/unit/autodiff/test_grad_manger.py
+++ b/imperative/python/test/unit/autodiff/test_grad_manger.py
@@ -5,12 +5,19 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import platform
+
 import numpy as np
 import pytest
 
 import megengine as mge
+import megengine.distributed as dist
 import megengine.functional as F
+import megengine.module as M
+import megengine.optimizer as optim
 from megengine.autodiff import GradManager
+from megengine.core._imperative_rt.imperative import sync
+from megengine.distributed.helper import get_device_count_by_fork
 
 
 def test_basic():
@@ -48,3 +55,47 @@ def test_attach_in_with_block():
         c = b + 1
     gm.backward(c)
     assert int(b.grad.numpy()) == 1
+
+
+@pytest.mark.skipif(
+    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
+)
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
+)
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
+@pytest.mark.isolated_distributed
+def test_remote_grad():
+    @dist.launcher
+    def worker():
+        rank = dist.get_rank()
+        size = dist.get_world_size()
+        x = mge.tensor(np.random.randn(1, rank * 2 + 2), dtype=np.float32)
+        m = M.Linear(rank * 2 + 2, rank * 2 + 4)
+        gm = GradManager().attach(m.parameters())
+        opt = optim.SGD(m.parameters(), 1e-3, momentum=0.9)
+
+        def train_func(x):
+            if rank != 0:
+                x = dist.functional.remote_recv(
+                    rank - 1, shape=(1, rank * 2 + 2), dtype=np.float32
+                )
+            print(rank, "x", x)
+            y = m(x)
+            print(rank, "y", y)
+            if rank != size - 1:
+                y = dist.functional.remote_send(y, dest_rank=rank + 1)
+            return y
+
+        with gm:
+            y = train_func(x)
+            if rank == size - 1:
+                y = y.mean()
+                gm.backward(y)
+            else:
+                gm.backward()
+        opt.step().clear_grad()
+        # sync because send is the last job
+        sync()
+
+    worker()