# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os

import numpy as np
import pytest

import megengine.autodiff as ad
import megengine.functional as F
from megengine import Parameter, optimizer
from megengine.jit import trace
from megengine.module import Linear, Module
from megengine.tensor import Tensor


class MLP(Module):
    def __init__(self):
        super().__init__()
        self.dense0 = Linear(28, 50)
        self.dense1 = Linear(50, 20)

    def forward(self, x):
        x = self.dense0(x)
        x = F.relu(x)
        x = self.dense1(x)
        return x


class Simple(Module):
    def __init__(self):
        super().__init__()
        self.a = Parameter(1.23, dtype=np.float32)

    def forward(self, x):
        x = x * self.a
        return x


def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
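    """Run a few optimizer steps on a Simple net, checking every parameter
    update against the hand-written reference rule in `check_class`, first
    eagerly and then under jit.trace with symbolic=False and symbolic=True.
    """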
    iter_num = 3
    net = Simple()
    opt = getattr(optimizer, opt_str)(net.parameters(), **test_case)
    check_func = check_class(net, **test_case)
    gm = ad.GradManager().attach(net.parameters())

    step = 0
    data_shape = (2, 28)

    for i in range(iter_num):
        if update_lr and i == 1:  # change learning rate
            for group in opt.param_groups:
                group["lr"] += 0.01
            check_func.lr += 0.01
        data = Tensor(np.random.random(data_shape).astype(np.float32))

        opt.clear_grad()
        with gm:
            pred = net(data)
            loss = pred.sum()
            gm.backward(loss)

        ori_params = {}
        for param in net.parameters():
            assert param._tuple_shape == ()  # Simple only has scalar parameters
            ori_params[param] = np.copy(param.numpy())
        opt.step()
        step += 1
        check_func(ori_params, net.parameters(), step)

    # static graph: repeat the same checks under jit.trace (symbolic=False/True)
    for symbolic in (False, True):

        @trace(symbolic=symbolic)
        def train_func(data, *, opt=None, gm=None):
            opt.clear_grad()
            with gm:
                pred = net(data)
                loss = pred.sum()
                gm.backward(loss)
            opt.step()

        # reset net and opt
        net = Simple()
        opt = getattr(optimizer, opt_str)(net.parameters(), **test_case)
        gm = ad.GradManager().attach(net.parameters())
        check_func = check_class(net, **test_case)
        step = 0
        for i in range(iter_num):
            if update_lr and i == 1:  # change learning rate
                for group in opt.param_groups:
                    group["lr"] += 0.01
                check_func.lr += 0.01

            ori_params = {}
            for param in net.parameters():
                assert param._tuple_shape == ()
                ori_params[param] = np.copy(param.numpy())

            train_func(
                Tensor(np.random.random(data_shape).astype(np.float32)), opt=opt, gm=gm
            )
            step += 1
            check_func(ori_params, net.parameters(), step)
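            # Smoke-test checkpointing: building state_dict for both the net
            # and the optimizer should still work mid-training under trace.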
            try_state_dict = {
                "net": net.state_dict(),
                "opt": opt.state_dict(),
            }


@pytest.mark.parametrize(
    "case",
    [
        {"momentum": 0.9, "lr": 0.01},  # SGD with momentum
        {"lr": 0.01},  # simple SGD
        {"weight_decay": 0.1, "lr": 0.01},  # with weight_decay
    ],
)
@pytest.mark.parametrize("update_lr", [False, True])
@pytest.mark.parametrize("inplace_mode", [False, True])
def test_sgd(monkeypatch, case, update_lr, inplace_mode):
    class CheckValue:
        def __init__(self, net, **kwarg):
            self.slots = {}
            for param in net.parameters():
                self.slots[param] = np.zeros(param.shape).astype(np.float32)
            for k, v in kwarg.items():
                setattr(self, k, v)

        def __call__(self, ori_params, new_params, step):
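            # Reference SGD rule: with momentum, slot <- momentum * slot + grad
            # and param <- param - lr * slot; otherwise param <- param - lr * grad.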
            for param in new_params:
                grad = param.grad.numpy()
                if hasattr(self, "momentum"):
                    self.slots[param] = grad + self.slots[param] * self.momentum
                    delta = -self.lr * self.slots[param]
                else:
                    delta = -self.lr * grad
                np.testing.assert_almost_equal(
                    param.numpy(), ori_params[param] + delta, decimal=6
                )

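    # Run under both update modes: MEGENGINE_INPLACE_UPDATE toggles in-place
    # parameter updates in the optimizer.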
    with monkeypatch.context() as mk:
        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
        _test_optimizer("SGD", case, CheckValue, update_lr=update_lr)


@pytest.mark.parametrize(
    "case",
    [
        {"betas": (0.8, 0.9), "eps": 1e-04, "lr": 0.01},
        {
            "betas": (0.8, 0.9),
            "eps": 1e-04,
            "lr": 0.01,
            "weight_decay": 0.1,
        },  # with weight_decay
    ],
)
@pytest.mark.parametrize("update_lr", [False, True])
@pytest.mark.parametrize("inplace_mode", [False, True])
def test_adam(monkeypatch, case, update_lr, inplace_mode):
    class CheckValue:
        def __init__(self, net, **kwarg):
            self.m_slots = {}
            self.v_slots = {}
            for param in net.parameters():
                self.m_slots[param] = np.zeros(param.shape).astype(np.float32)
                self.v_slots[param] = np.zeros(param.shape).astype(np.float32)
            for k, v in kwarg.items():
                setattr(self, k, v)

        def __call__(self, ori_params, new_params, step):
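            # Reference Adam rule with bias correction:
            #   m <- b1 * m + (1 - b1) * grad, v <- b2 * v + (1 - b2) * grad^2
            #   param <- param - lr * m_hat / (sqrt(v_hat) + eps),
            # where m_hat = m / (1 - b1^t) and v_hat = v / (1 - b2^t).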
            for param in new_params:
                grad = param.grad.numpy()
                m = self.m_slots[param]
                v = self.v_slots[param]
                m *= self.betas[0]
                m += (1 - self.betas[0]) * grad
                v *= self.betas[1]
                v += (1 - self.betas[1]) * grad * grad
                delta = (m / (1 - self.betas[0] ** step)) / (
                    np.sqrt(v / (1 - self.betas[1] ** step)) + self.eps
                )
                np.testing.assert_almost_equal(
                    param.numpy(), ori_params[param] - self.lr * delta, decimal=6
                )

    with monkeypatch.context() as mk:
        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
        _test_optimizer("Adam", case, CheckValue, update_lr=update_lr)


@pytest.mark.parametrize(
    "case",
    [
        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.01},
        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.0},  # without lr_decay
        {
            "lr": 0.01,
            "eps": 1e-06,
            "lr_decay": 0.01,
            "weight_decay": 0.1,
        },  # with weight_decay
    ],
)
@pytest.mark.parametrize("update_lr", [False, True])
@pytest.mark.parametrize("inplace_mode", [False, True])
def test_adagrad(monkeypatch, case, update_lr, inplace_mode):
    class CheckValue:
        def __init__(self, net, **kwarg):
            self.s_slots = {}
            for param in net.parameters():
                self.s_slots[param] = np.zeros(param.shape).astype(np.float32)
            for k, v in kwarg.items():
                setattr(self, k, v)

        def __call__(self, ori_params, new_params, step):
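            # Reference Adagrad rule: accumulate s <- s + grad^2, then
            # param <- param - lr / (1 + (t - 1) * lr_decay) * grad / sqrt(s + eps).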
            for param in new_params:
                grad = param.grad.numpy()
                self.s_slots[param] += grad ** 2
                delta = grad / (self.s_slots[param] + self.eps) ** 0.5
                delta *= -(self.lr / (1 + (step - 1) * self.lr_decay))
                np.testing.assert_almost_equal(
                    param.numpy(), ori_params[param] + delta, decimal=6
                )

    with monkeypatch.context() as mk:
        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
        _test_optimizer("Adagrad", case, CheckValue, update_lr=update_lr)


@pytest.mark.parametrize(
    "case",
    [
        {"lr": 1.0, "eps": 1e-06, "rho": 0.9},
        {"lr": 1.0, "eps": 1e-06, "rho": 0.9, "weight_decay": 0.9},  # with weight_decay
    ],
)
@pytest.mark.parametrize("update_lr", [False, True])
@pytest.mark.parametrize("inplace_mode", [False, True])
def test_adadelta(monkeypatch, case, update_lr, inplace_mode):
    class CheckValue:
        def __init__(self, net, **kwarg):
            self.s_slots = {}
            self.a_slots = {}
            for param in net.parameters():
                self.s_slots[param] = np.zeros(param.shape).astype(np.float32)
                self.a_slots[param] = np.zeros(param.shape).astype(np.float32)
            for k, v in kwarg.items():
                setattr(self, k, v)

        def __call__(self, ori_params, new_params, step):
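            # Reference Adadelta rule: s <- rho * s + (1 - rho) * grad^2,
            # delta = grad * sqrt(a + eps) / sqrt(s + eps),
            # a <- rho * a + (1 - rho) * delta^2, param <- param - lr * delta.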
            for param in new_params:
                grad = param.grad.numpy()
                self.s_slots[param] = self.s_slots[param] * self.rho + grad ** 2 * (
                    1 - self.rho
                )
                delta = (
                    grad
                    * ((self.a_slots[param] + self.eps) ** 0.5)
                    / (self.s_slots[param] + self.eps) ** 0.5
                )
                self.a_slots[param] = self.a_slots[param] * self.rho + delta ** 2 * (
                    1 - self.rho
                )
                delta *= -self.lr
                np.testing.assert_almost_equal(
                    param.numpy(), ori_params[param] + delta, decimal=6
                )

    with monkeypatch.context() as mk:
        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
        _test_optimizer("Adadelta", case, CheckValue, update_lr=update_lr)


@pytest.mark.parametrize(
    "case",
    [
        {"betas": (0.8, 0.9), "eps": 1e-08, "lr": 0.01},
        {
            "betas": (0.8, 0.9),
            "eps": 1e-08,
            "lr": 0.01,
            "weight_decay": 0.1,
        },  # with weight_decay
    ],
)
@pytest.mark.parametrize("update_lr", [False, True])
@pytest.mark.parametrize("inplace_mode", [False, True])
def test_adamw(monkeypatch, case, update_lr, inplace_mode):
    class CheckValue:
        def __init__(self, net, **kwarg):
            self.m_slots = {}
            self.v_slots = {}
            for param in net.parameters():
                self.m_slots[param] = np.zeros(param.shape).astype(np.float32)
                self.v_slots[param] = np.zeros(param.shape).astype(np.float32)
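            # AdamW decays weights even when the case omits weight_decay; start
            # from 0.01 (assumed to match the optimizer's default) and let the
            # test case override it below.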
            self.weight_decay = 0.01
            for k, v in kwarg.items():
                setattr(self, k, v)

        def __call__(self, ori_params, new_params, step):
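            # Reference AdamW rule: the bias-corrected Adam step plus decoupled
            # weight decay, param <- param - lr * (adam_delta + weight_decay * param).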
            step = np.array(step).astype(np.float32)
            for param in new_params:
                grad = param.grad.numpy()
                m = self.m_slots[param]
                v = self.v_slots[param]
                m *= self.betas[0]
                m += (1 - self.betas[0]) * grad
                v *= self.betas[1]
                v += (1 - self.betas[1]) * grad * grad
                delta = (m / (1 - self.betas[0] ** step)) / (
                    np.sqrt(v / (1 - self.betas[1] ** step)) + self.eps
                )
                delta += ori_params[param] * self.weight_decay
                np.testing.assert_almost_equal(
                    param.numpy(), ori_params[param] - self.lr * delta, decimal=6
                )

    with monkeypatch.context() as mk:
        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
        _test_optimizer("AdamW", case, CheckValue, update_lr=update_lr)