# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

import megengine
import megengine.autodiff as ad
import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.jit import trace
from megengine.module import Module


class Simple(Module):
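    """A one-parameter module: forward(x) = x * a."""
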
    def __init__(self):
        super().__init__()
        self.a = Parameter([1.23], dtype=np.float32)

    def forward(self, x):
        x = x * self.a
        return x


def test_sgd_momentum():
    net = Simple()

    optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
    optim.clear_grad()
    gm = ad.GradManager().attach(net.parameters())

    data = tensor([2.34])

    # do a step of train
    with gm:
        loss = net(data)
        gm.backward(loss)
    optim.step()
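    # grad of loss w.r.t. a is x = 2.34, so with lr=1.0, momentum=0.9 the first
    # step sets momentum_buffer to 2.34 and updates a to 1.23 - 2.34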

    np.testing.assert_almost_equal(optim._state[net.a]["momentum_buffer"].numpy(), 2.34)

    # do a step of infer
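    # a forward pass alone must not change the parameter or the momentum buffer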
    loss = net(data)
    np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)

    np.testing.assert_almost_equal(optim._state[net.a]["momentum_buffer"].numpy(), 2.34)

    # do a step of train
    optim.clear_grad()
    with gm:
        loss = net(data)
        gm.backward(loss)
    optim.step()
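    # the second step accumulates: momentum_buffer = 0.9 * 2.34 + 2.34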

    np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
    np.testing.assert_almost_equal(
        optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34
    )


def test_sgd_momentum_trace():

    for symbolic in (True, False):

        @trace(symbolic=symbolic)
        def train_func(data, *, model=None, optim=None, gm=None):
            optim.clear_grad()
            with gm:
                loss = net(data)
                gm.backward(loss)
            optim.step()
            return loss

        @trace(symbolic=symbolic)
        def eval_func(data, *, model=None, optim=None, gm=None):
            loss = net(data)
            return loss

        net = Simple()
        optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
        gm = ad.GradManager().attach(net.parameters())
        data = tensor([2.34])
        train_func(data, model=net, optim=optim, gm=gm)
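        # first traced train step fills the momentum buffer with the raw gradient 2.34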
        np.testing.assert_almost_equal(
            optim._state[net.a]["momentum_buffer"].numpy(), 2.34
        )

        # do 3 steps of infer
        for _ in range(3):
            loss = eval_func(data)
            np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
            np.testing.assert_almost_equal(
                optim._state[net.a]["momentum_buffer"].numpy(), 2.34
            )

        # do a step of train
        train_func(data, model=net, optim=optim, gm=gm)
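        # second traced train step accumulates: momentum_buffer = 0.9 * 2.34 + 2.34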
        np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
        np.testing.assert_almost_equal(
            optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34
        )