# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

import megengine
import megengine.autodiff as ad
import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.jit import trace
from megengine.module import Module


class Simple(Module):
    def __init__(self):
        super().__init__()
        self.a = Parameter(1.23, dtype=np.float32)

    def forward(self, x):
        x = x * self.a
        return x


def test_sgd_momentum():
    net = Simple()

    optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
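    # expected update rule, implied by the assertions below:
    #   buffer = momentum * buffer + grad;  param -= lr * buffer
    # with lr=1.0 and momentum=0.9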
    optim.clear_grad()
    gm = ad.GradManager().register(net.parameters())

    data = tensor([2.34])
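    # forward computes a * x, so the gradient w.r.t. a is the input itself: 2.34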

    # do one training step
    with gm.record():
        loss = net(data)
        gm.backward(loss)
    optim.step()
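    # first step: buffer = 0.9 * 0 + 2.34 = 2.34, and a becomes 1.23 - 1.0 * 2.34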

    np.testing.assert_almost_equal(optim._state[net.a]["momentum_buffer"].numpy(), 2.34)

    # do one inference step: the momentum buffer must stay unchanged
    loss = net(data)
    np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)

    np.testing.assert_almost_equal(optim._state[net.a]["momentum_buffer"].numpy(), 2.34)

    # do another training step
    optim.clear_grad()
    with gm.record():
        loss = net(data)
        gm.backward(loss)
    optim.step()
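    # second step: buffer = 0.9 * 2.34 + 2.34 (momentum decay plus the same gradient)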

    np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
    np.testing.assert_almost_equal(
        optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34
    )


def test_sgd_momentum_trace():
    for symbolic in (True, False):
        @trace(symbolic=symbolic)
        def train_func(data, *, model=None, optim=None):
            optim.clear_grad()
            with gm.record():
                loss = net(data)
                gm.backward(loss)
            optim.step()
            return loss

        @trace(symbolic=symbolic)
        def eval_func(data, *, model=None, optim=None):
            loss = net(data)
            return loss

        net = Simple()
        optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
        gm = ad.GradManager().register(net.parameters())
        data = tensor([2.34])
        train_func(data, model=net, optim=optim)
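        # as in test_sgd_momentum: after one step the buffer equals the gradient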
        np.testing.assert_almost_equal(
            optim._state[net.a]["momentum_buffer"].numpy(), 2.34
        )

        # do 3 inference steps: loss and buffer must stay fixed
        for _ in range(3):
            loss = eval_func(data)
            np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
            np.testing.assert_almost_equal(
                optim._state[net.a]["momentum_buffer"].numpy(), 2.34
            )

        # do one more training step
        loss = train_func(data, model=net, optim=optim)
        np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
        np.testing.assert_almost_equal(
            optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34
        )