diff --git a/imperative/python/test/integration/test_optimizer.py b/imperative/python/test/integration/test_optimizer.py
index 388a881485e4e63e69a6c84b0410ac430277c50a..36ea5e95d3316431822bf77b8ce28d713b301da8 100644
--- a/imperative/python/test/integration/test_optimizer.py
+++ b/imperative/python/test/integration/test_optimizer.py
@@ -10,6 +10,7 @@ import numpy as np
 
 import megengine.functional as F
 from megengine import Parameter, optimizer
+from megengine.jit import trace
 from megengine.module import Linear, Module
 from megengine.tensor import TensorDict, tensor
 
@@ -66,6 +67,37 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
         step += 1
         check_func(ori_params, net.parameters(), step)
 
+    # static graph
+    for symbolic in (False, True):
+
+        @trace(symbolic=symbolic)
+        def train_func(data, *, opt=None):
+            opt.zero_grad()
+            with opt.record():
+                pred = net(data)
+                loss = pred.sum()
+                opt.backward(loss)
+            opt.step()
+
+        # reset net and opt
+        net = Simple()
+        opt = getattr(optimizer, opt_str)(net.parameters(), **test_case)
+        check_func = check_class(net, **test_case)
+        step = 0
+        for i in range(iter_num):
+            if update_lr and i == 1:  # change learning rate
+                for group in opt.param_groups:
+                    group["lr"] += 0.01
+                check_func.lr += 0.01
+
+            ori_params = TensorDict()
+            for param in net.parameters():
+                ori_params[param] = np.copy(param.numpy())
+
+            train_func(np.random.random(data_shape).astype(np.float32), opt=opt)
+            step += 1
+            check_func(ori_params, net.parameters(), step)
+
 
 def test_sgd():
     class CheckValue:
diff --git a/imperative/python/test/integration/test_sgd_momentum.py b/imperative/python/test/integration/test_sgd_momentum.py
index 33944150e1de1cdebb37a3a0eb1e37a688f54fce..da60e003eb53c601e30c404ecfd2c4f95168af96 100644
--- a/imperative/python/test/integration/test_sgd_momentum.py
+++ b/imperative/python/test/integration/test_sgd_momentum.py
@@ -11,6 +11,7 @@ import numpy as np
 import megengine
 import megengine.optimizer as optimizer
 from megengine import Parameter, tensor
+from megengine.jit import trace
 from megengine.module import Module
 
 
@@ -57,3 +58,45 @@ def test_sgd_momentum():
     np.testing.assert_almost_equal(
         optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34
     )
+
+
+def test_sgd_momentum_trace():
+
+    for symbolic in (True, False):
+
+        @trace(symbolic=symbolic)
+        def train_func(data, *, model=None, optim=None):
+            optim.zero_grad()
+            with optim.record():
+                loss = net(data)
+                optim.backward(loss)
+            optim.step()
+            return loss
+
+        @trace(symbolic=symbolic)
+        def eval_func(data, *, model=None, optim=None):
+            loss = net(data)
+            return loss
+
+        net = Simple()
+        optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
+        data = tensor([2.34])
+        train_func(data, model=net, optim=optim)
+        np.testing.assert_almost_equal(
+            optim._state[net.a]["momentum_buffer"].numpy(), 2.34
+        )
+
+        # do 3 steps of infer
+        for _ in range(3):
+            loss = eval_func(data)
+            np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
+            np.testing.assert_almost_equal(
+                optim._state[net.a]["momentum_buffer"].numpy(), 2.34
+            )
+
+        # do a step of train
+        train_func(data, model=net, optim=optim)
+        np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
+        np.testing.assert_almost_equal(
+            optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34
+        )
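
Note for reviewers: both tests exercise the same pattern, running a full optimizer step under `@trace` in both eager (`symbolic=False`) and symbolic (`symbolic=True`) mode. Below is a minimal self-contained sketch of that pattern, using only the API calls that appear in the diff (`trace`, `Optimizer.zero_grad`/`record`/`backward`/`step`). The `Simple` module here is an illustrative stand-in reconstructed from the expected values asserted in `test_sgd_momentum` (`a = 1.23`, `forward(x) = x * a`), not necessarily the exact class in the test files.

```python
import numpy as np

import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.jit import trace
from megengine.module import Module


class Simple(Module):
    """Single-parameter module: forward(x) = x * a, with a = 1.23."""

    def __init__(self):
        super().__init__()
        self.a = Parameter(1.23, dtype=np.float32)

    def forward(self, x):
        return x * self.a


net = Simple()
optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)


@trace(symbolic=True)  # symbolic=False runs the same function eagerly
def train_func(data, *, optim=None):
    optim.zero_grad()
    with optim.record():       # record the forward pass for autodiff
        loss = net(data)
        optim.backward(loss)   # accumulate gradients inside the recorded region
    optim.step()               # apply the SGD-with-momentum update
    return loss


loss = train_func(tensor([2.34]), optim=optim)
# d(loss)/d(a) = 2.34, so after one step the momentum buffer holds 2.34,
# matching the first assertion in test_sgd_momentum_trace
np.testing.assert_almost_equal(
    optim._state[net.a]["momentum_buffer"].numpy(), 2.34
)
```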