From 087ceb52ef20b7642bc33e298320e4a140ecfe7f Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Sun, 6 Sep 2020 13:52:03 +0800
Subject: [PATCH] feat(mge/imperative): add more optimizer trace tests

GitOrigin-RevId: 4127de1d22b97a4abbcb223575d7756200d163a1
---
 .../python/test/integration/test_optimizer.py | 32 ++++++++++++++
 .../test/integration/test_sgd_momentum.py     | 43 +++++++++++++++++++
 2 files changed, 75 insertions(+)

diff --git a/imperative/python/test/integration/test_optimizer.py b/imperative/python/test/integration/test_optimizer.py
index 388a8814..36ea5e95 100644
--- a/imperative/python/test/integration/test_optimizer.py
+++ b/imperative/python/test/integration/test_optimizer.py
@@ -10,6 +10,7 @@ import numpy as np
 
 import megengine.functional as F
 from megengine import Parameter, optimizer
+from megengine.jit import trace
 from megengine.module import Linear, Module
 from megengine.tensor import TensorDict, tensor
 
@@ -66,6 +67,37 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
         step += 1
         check_func(ori_params, net.parameters(), step)
 
+    # static graph
+    for symbolic in (False, True):
+
+        @trace(symbolic=symbolic)
+        def train_func(data, *, opt=None):
+            opt.zero_grad()
+            with opt.record():
+                pred = net(data)
+                loss = pred.sum()
+                opt.backward(loss)
+            opt.step()
+
+        # reset net and opt
+        net = Simple()
+        opt = getattr(optimizer, opt_str)(net.parameters(), **test_case)
+        check_func = check_class(net, **test_case)
+        step = 0
+        for i in range(iter_num):
+            if update_lr and i == 1:  # change learning rate
+                for group in opt.param_groups:
+                    group["lr"] += 0.01
+                check_func.lr += 0.01
+
+            ori_params = TensorDict()
+            for param in net.parameters():
+                ori_params[param] = np.copy(param.numpy())
+
+            train_func(np.random.random(data_shape).astype(np.float32), opt=opt)
+            step += 1
+            check_func(ori_params, net.parameters(), step)
+
 
 def test_sgd():
     class CheckValue:
diff --git a/imperative/python/test/integration/test_sgd_momentum.py b/imperative/python/test/integration/test_sgd_momentum.py
index 33944150..da60e003 100644
--- a/imperative/python/test/integration/test_sgd_momentum.py
+++ b/imperative/python/test/integration/test_sgd_momentum.py
@@ -11,6 +11,7 @@ import numpy as np
 import megengine
 import megengine.optimizer as optimizer
 from megengine import Parameter, tensor
+from megengine.jit import trace
 from megengine.module import Module
 
 
@@ -57,3 +58,45 @@ def test_sgd_momentum():
     np.testing.assert_almost_equal(
         optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34
     )
+
+
+def test_sgd_momentum_trace():
+
+    for symbolic in (True, False):
+
+        @trace(symbolic=symbolic)
+        def train_func(data, *, model=None, optim=None):
+            optim.zero_grad()
+            with optim.record():
+                loss = net(data)
+                optim.backward(loss)
+            optim.step()
+            return loss
+
+        @trace(symbolic=symbolic)
+        def eval_func(data, *, model=None, optim=None):
+            loss = net(data)
+            return loss
+
+        net = Simple()
+        optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
+        data = tensor([2.34])
+        train_func(data, model=net, optim=optim)
+        np.testing.assert_almost_equal(
+            optim._state[net.a]["momentum_buffer"].numpy(), 2.34
+        )
+
+        # do 3 steps of infer
+        for _ in range(3):
+            loss = eval_func(data)
+            np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
+            np.testing.assert_almost_equal(
+                optim._state[net.a]["momentum_buffer"].numpy(), 2.34
+            )
+
+        # do a step of train
+        train_func(data, model=net, optim=optim)
+        np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
+        np.testing.assert_almost_equal(
+            optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34
+        )
--
GitLab