import numpy as np
import pytest

import megengine as mge
import megengine.autodiff as ad
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
from megengine import tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops.builtin import LAMBUpdate


def lamb_update(
    param_group, step, exp_avg, exp_avg_sq, param, grad, bias_correction, always_adapt
):
    """Reference implementation of a single LAMB step; the builtin
    LAMBUpdate op is checked against this in test_lamb below."""
    lr = param_group["lr"]
    weight_decay = param_group["weight_decay"]
    eps = param_group["eps"]
    beta0, beta1 = param_group["betas"]

    # since `convert_inputs` is disabled for param updates,
    # scalars must be explicitly converted to tensors

    _lr, _neg_lr = map(tensor, (lr, -lr))
    _weight_decay = tensor(weight_decay)
    _eps = tensor(eps)
    _beta0, _beta1 = map(tensor, (beta0, beta1))

    c1, c05, c0 = map(tensor, (1.0, 0.5, 0.0))

    def norm(vec):
        # L2 norm, computed as a single on-device reduction via F.sum
        # rather than the element-by-element Python builtin sum
        return F.sum(vec * vec) ** c05

    p_norm = norm(param.flatten())

    # step = step + c1
    step += c1

    # exp_avg = _beta0 * exp_avg + grad * (c1 - _beta0)
    exp_avg *= _beta0
    exp_avg += grad * (c1 - _beta0)

    # exp_avg_sq = _beta1 * exp_avg_sq + (c1 - _beta1) * (grad * grad)
    exp_avg_sq *= _beta1
    exp_avg_sq += (c1 - _beta1) * (grad * grad)

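    # debias the first and second moment estimates (Adam-style), then form
    # the adaptive update direction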
    bias_correction1 = c1 - _beta0 ** step if bias_correction else c1
    bias_correction2 = c1 - _beta1 ** step if bias_correction else c1
    delta = (exp_avg / bias_correction1) / (
        (exp_avg_sq / bias_correction2) ** c05 + _eps
    )
    if weight_decay != 0.0:
        delta += param * _weight_decay

    d_norm = norm(delta.flatten())
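    # LAMB trust ratio: rescale the step by ||param|| / ||update|| so the
    # effective learning rate adapts per layer; fall back to 1 when either
    # norm is zero or adaptation is not requested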
    trust_ratio = (
        p_norm / d_norm
        if (always_adapt or weight_decay > 0) and p_norm > c0 and d_norm > c0
        else c1
    )
    new_param = param - _lr * trust_ratio * delta
    return exp_avg, exp_avg_sq, new_param
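

# A minimal pure-NumPy sketch of the same LAMB step, included only to spell
# out the update rule without MegEngine tensor plumbing. The helper name
# `lamb_update_np` is ours, not part of MegEngine.
def lamb_update_np(
    lr, weight_decay, eps, beta0, beta1, step,
    exp_avg, exp_avg_sq, param, grad,
    bias_correction=True, always_adapt=False,
):
    step = step + 1
    # first / second moment running averages
    exp_avg = beta0 * exp_avg + (1.0 - beta0) * grad
    exp_avg_sq = beta1 * exp_avg_sq + (1.0 - beta1) * grad * grad
    bc1 = 1.0 - beta0 ** step if bias_correction else 1.0
    bc2 = 1.0 - beta1 ** step if bias_correction else 1.0
    delta = (exp_avg / bc1) / (np.sqrt(exp_avg_sq / bc2) + eps)
    if weight_decay != 0.0:
        delta = delta + weight_decay * param
    p_norm = np.linalg.norm(param.ravel())
    d_norm = np.linalg.norm(delta.ravel())
    trust_ratio = (
        p_norm / d_norm
        if (always_adapt or weight_decay > 0) and p_norm > 0 and d_norm > 0
        else 1.0
    )
    return exp_avg, exp_avg_sq, param - lr * trust_ratio * delta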


@pytest.mark.skip(reason="pytest aborted, the same as groupnorm")
def test_lamb():
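    # LAMBUpdate positional args line up with the hyper-parameters used for
    # the reference below:
    # (beta_1, beta_2, step, lr, weight_decay, eps, bias_correction, always_adapt)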
    op = LAMBUpdate(0.9, 0.999, 1, 1e-3, 0.4, 1e-8, True, False)
    m_t_1 = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float32)
    v_t_1 = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float32)
    params = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float32)
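    # grad is float16 while the optimizer states are float32; the loose atol
    # below leaves headroom for the reduced gradient precision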
    grad = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float16)
    (new_m_t, new_v_t, new_param) = apply(op, m_t_1, v_t_1, params, grad)

    param_group = {
        "betas": (0.9, 0.999),
        "step": 1,
        "lr": 1e-3,
        "weight_decay": 0.4,
        "eps": 1e-8,
    }
    gt_m_t, gt_v_t, gt_new_param = lamb_update(
        param_group, 1, m_t_1, v_t_1, params, grad, True, False
    )
    np.testing.assert_allclose(new_m_t.numpy(), gt_m_t.numpy(), atol=1e-2)
    np.testing.assert_allclose(new_v_t.numpy(), gt_v_t.numpy(), atol=1e-2)
    np.testing.assert_allclose(new_param.numpy(), gt_new_param.numpy(), atol=1e-2)
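

# How this op is reached from user code: a minimal training-loop sketch. It
# assumes megengine.optimizer exposes a LAMB optimizer with the usual
# (params, lr, betas, eps, weight_decay) signature; if your MegEngine build
# lacks it, any other optim.Optimizer subclass slots in the same way.
def _lamb_training_step_sketch():
    net = M.Linear(256, 256)
    opt = optim.LAMB(
        net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.4
    )
    gm = ad.GradManager().attach(net.parameters())
    x = mge.tensor(np.random.uniform(size=(8, 256)), dtype=np.float32)
    with gm:
        loss = F.loss.square_loss(net(x), x)
        gm.backward(loss)
    opt.step().clear_grad()  # one LAMB update per parameter tensor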