Commit 087ceb52 authored by Megvii Engine Team

feat(mge/imperative): add more optimizer trace tests

GitOrigin-RevId: 4127de1d22b97a4abbcb223575d7756200d163a1
Parent 38a5c1c9
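
In short, this commit makes the optimizer tests also run each training step under megengine.jit.trace, covering both imperative (symbolic=False) and static-graph (symbolic=True) execution. Below is a minimal sketch of the traced train-step pattern the new tests exercise; the Simple module, shapes, and hyperparameters here are illustrative assumptions, not taken from the diff:

# Hedged sketch of the pattern under test; `Simple`, the shapes, and the
# hyperparameters below are illustrative assumptions.
import numpy as np

from megengine import optimizer
from megengine.jit import trace
from megengine.module import Linear, Module


class Simple(Module):
    def __init__(self):
        super().__init__()
        self.fc = Linear(2, 1)

    def forward(self, x):
        return self.fc(x)


net = Simple()
opt = optimizer.SGD(net.parameters(), lr=0.01, momentum=0.9)


@trace(symbolic=True)  # symbolic=False traces eagerly; True builds a static graph
def train_func(data, *, opt=None):
    opt.zero_grad()                 # clear gradients from the previous step
    with opt.record():              # record forward + backward for this step
        loss = net(data).sum()
        opt.backward(loss)          # compute gradients of loss w.r.t. parameters
    opt.step()                      # apply the SGD-with-momentum update


train_func(np.random.random((4, 2)).astype(np.float32), opt=opt)

The diff itself follows.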
...
@@ -10,6 +10,7 @@ import numpy as np
import megengine.functional as F
from megengine import Parameter, optimizer
from megengine.jit import trace
from megengine.module import Linear, Module
from megengine.tensor import TensorDict, tensor
...
@@ -66,6 +67,37 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
        step += 1
        check_func(ori_params, net.parameters(), step)

    # static graph
    for symbolic in (False, True):

        @trace(symbolic=symbolic)
        def train_func(data, *, opt=None):
            opt.zero_grad()
            with opt.record():
                pred = net(data)
                loss = pred.sum()
                opt.backward(loss)
            opt.step()

        # reset net and opt
        net = Simple()
        opt = getattr(optimizer, opt_str)(net.parameters(), **test_case)
        check_func = check_class(net, **test_case)
        step = 0
        for i in range(iter_num):
            if update_lr and i == 1:  # change learning rate
                for group in opt.param_groups:
                    group["lr"] += 0.01
                check_func.lr += 0.01
            ori_params = TensorDict()
            for param in net.parameters():
                ori_params[param] = np.copy(param.numpy())
            train_func(np.random.random(data_shape).astype(np.float32), opt=opt)
            step += 1
            check_func(ori_params, net.parameters(), step)

def test_sgd():
    class CheckValue:
...
...
@@ -11,6 +11,7 @@ import numpy as np
import megengine
import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.jit import trace
from megengine.module import Module
...
@@ -57,3 +58,45 @@ def test_sgd_momentum():
    np.testing.assert_almost_equal(
        optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34
    )

def test_sgd_momentum_trace():
    for symbolic in (True, False):

        @trace(symbolic=symbolic)
        def train_func(data, *, model=None, optim=None):
            optim.zero_grad()
            with optim.record():
                loss = net(data)
                optim.backward(loss)
            optim.step()
            return loss

        @trace(symbolic=symbolic)
        def eval_func(data, *, model=None, optim=None):
            loss = net(data)
            return loss

        net = Simple()
        optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
        data = tensor([2.34])
        train_func(data, model=net, optim=optim)
        np.testing.assert_almost_equal(
            optim._state[net.a]["momentum_buffer"].numpy(), 2.34
        )

        # do 3 steps of infer
        for _ in range(3):
            loss = eval_func(data)
            np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
            np.testing.assert_almost_equal(
                optim._state[net.a]["momentum_buffer"].numpy(), 2.34
            )

        # do a step of train
        train_func(data, model=net, optim=optim)
        np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
        np.testing.assert_almost_equal(
            optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34
        )
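
A note on where the asserted constants come from, assuming (as the assertions and the test_sgd_momentum test above imply) that Simple holds a single parameter a initialized to 1.23 and computes loss = a * x:

    grad = d(a * x)/da = x = 2.34
    after train step 1: momentum_buffer = 2.34, and a = 1.23 - lr * 2.34 with lr = 1.0
    eval loss = a * x = (1.23 - 2.34) * 2.34
    inference runs no backward()/step(), so the buffer must stay 2.34 across the three eval calls
    after train step 2: momentum_buffer = 0.9 * 2.34 + 2.34

The loss returned by the second train_func call is computed before the parameter update, so it still equals 2.34 * (1.23 - 2.34).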