# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import platform

import numpy as np
import pytest

import megengine as mge
import megengine.distributed as dist
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
from megengine.autodiff import GradManager
from megengine.core._imperative_rt.imperative import sync
from megengine.distributed.helper import get_device_count_by_fork
from megengine.jit import trace


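# Basic GradManager usage: attach tensors, record, run backward, and verify the
# gradients.  For y = x @ w + b: dy/dw = x^T and dy/db = 1.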
def test_basic():
    x = mge.tensor([1.0, 3.0, 5.0]).reshape(1, 3)
    w = mge.tensor([2.0, 4.0, 6.0]).reshape(3, 1)
    b = mge.tensor(-1.0)

    gm = GradManager().attach([w, b])
    gm.record()

    p = F.matmul(x, w)
    y = p + b

    gm.backward(y)
    gm.release()  # not necessary here; backward() already ends the recording
    np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]])
    np.testing.assert_equal(b.grad.numpy(), [1])

    gm.clear_grad()
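
    # The same computation, this time using GradManager as a context manager;
    # entering the block starts recording and exiting it releases resources.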
    with gm:
        p = F.matmul(x, w)
        y = p + b
        gm.backward(y)

    np.testing.assert_equal(w.grad.numpy(), [[1], [3], [5]])
    np.testing.assert_equal(b.grad.numpy(), [1])


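# attach() can also be called inside the recording block, on an intermediate
# tensor rather than a Parameter; here dc/db == 1.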
def test_attach_in_with_block():
    a = mge.Parameter([1.0])
    gm = GradManager()
    with gm:
        b = a * 3
        gm.attach(b)
        c = b + 1
        gm.backward(c)
    assert int(b.grad.numpy()) == 1


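# Pipeline-style distributed test: every rank owns one Linear layer, forwards its
# activation to the next rank via remote_send/remote_recv, and the last rank kicks
# off the backward pass so gradients flow back across all ranks.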
@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
    platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need at least 2 gpu devices")
@pytest.mark.isolated_distributed
def test_remote_grad():
    @dist.launcher
    def worker():
        rank = dist.get_rank()
        size = dist.get_world_size()
        x = mge.tensor(np.random.randn(1, rank * 2 + 2), dtype=np.float32)
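        # The layer maps rank*2+2 -> rank*2+4 features, so rank r's output width
        # (r*2+4) equals rank r+1's input width ((r+1)*2+2).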
        m = M.Linear(rank * 2 + 2, rank * 2 + 4)
        gm = GradManager().attach(m.parameters())
        opt = optim.SGD(m.parameters(), 1e-3, momentum=0.9)

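        # One full training step (forward, backward, optimizer update),
        # compiled into a static graph by trace(symbolic=True).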
        @trace(symbolic=True)
        def train_func(x):
            with gm:
                if rank != 0:
                    x = dist.functional.remote_recv(
                        rank - 1, shape=(1, rank * 2 + 2), dtype=np.float32
                    )
                y = m(x)
                if rank != size - 1:
                    y = dist.functional.remote_send(y, dest_rank=rank + 1)
                if rank == size - 1:
                    y = y.mean()
                    gm.backward(y)
                else:
                    gm.backward()
                opt.step().clear_grad()

        for i in range(3):
            train_func(x)

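        # Reading each parameter back as a numpy array forces any pending
        # computation (and any error from it) to surface.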
        for param in m.parameters():
            param.numpy()

    worker()