Commit 66b6daf7 authored by Megvii Engine Team

test(mge/optimizer): fix test for new optimizer api

GitOrigin-RevId: 482ee6265224f7fcc4c21300a597d36b4333cca3
Parent e9104ef1
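
For context: the tests move from the old Optimizer-centric API (opt.zero_grad(), opt.record(), opt.backward()) to the split API in which the optimizer only updates parameters (opt.clear_grad(), opt.step()) while gradient recording and backprop live in GradManager. Below is a minimal sketch of that training-loop pattern; it is not part of this commit, it only assumes the GradManager calls visible in the hunks, and the Net module, data shape, and hyperparameters are placeholders.

import numpy as np

import megengine.autodiff as ad
import megengine.module as M
import megengine.optimizer as optim
from megengine import tensor


class Net(M.Module):
    # placeholder toy module, used only to have parameters to optimize
    def __init__(self):
        super().__init__()
        self.fc = M.Linear(4, 1)

    def forward(self, x):
        return self.fc(x)


net = Net()
opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
# gradient bookkeeping now belongs to GradManager, not to the optimizer
gm = ad.GradManager().register(net.parameters())

for _ in range(3):
    data = tensor(np.random.random((8, 4)).astype(np.float32))
    opt.clear_grad()       # renamed from zero_grad()
    with gm.record():      # was opt.record()
        loss = net(data).sum()
        gm.backward(loss)  # was opt.backward(loss)
    opt.step()
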
@@ -40,7 +40,7 @@ __all__ = [
 ]
 
 
-@apply.add
+@apply.register()
 def _(op: RemoteSend, *args: Tensor):
     ret = tensor_apply(op, *args)

@@ -133,7 +133,7 @@ def update_model(model_path):
     data = Tensor(checkpoint["data"], dtype=np.float32)
     label = Tensor(checkpoint["label"], dtype=np.int32)
-    opt.zero_grad()
+    opt.clear_grad()
     loss = train(data, label, net=net, opt=opt)
     opt.step()

@@ -73,17 +73,18 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
     for symbolic in (False, True):
 
         @trace(symbolic=symbolic)
-        def train_func(data, *, opt=None):
-            opt.zero_grad()
-            with opt.record():
+        def train_func(data, *, opt=None, gm=None):
+            opt.clear_grad()
+            with gm.record():
                 pred = net(data)
                 loss = pred.sum()
-                opt.backward(loss)
+                gm.backward(loss)
             opt.step()
 
         # reset net and opt
         net = Simple()
         opt = getattr(optimizer, opt_str)(net.parameters(), **test_case)
+        gm = ad.GradManager().register(net.parameters())
         check_func = check_class(net, **test_case)
         step = 0
         for i in range(iter_num):
@@ -96,7 +97,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
             for param in net.parameters():
                 ori_params[param] = np.copy(param.numpy())
-            train_func(np.random.random(data_shape).astype(np.float32), opt=opt)
+            train_func(np.random.random(data_shape).astype(np.float32), opt=opt, gm=gm)
             step += 1
             check_func(ori_params, net.parameters(), step)

@@ -67,23 +67,24 @@ def test_sgd_momentum_trace():
     for symbolic in (True, False):
 
         @trace(symbolic=symbolic)
-        def train_func(data, *, model=None, optim=None):
-            optim.zero_grad()
-            with optim.record():
+        def train_func(data, *, model=None, optim=None, gm=None):
+            optim.clear_grad()
+            with gm.record():
                 loss = net(data)
-                optim.backward(loss)
+                gm.backward(loss)
             optim.step()
             return loss
 
         @trace(symbolic=symbolic)
-        def eval_func(data, *, model=None, optim=None):
+        def eval_func(data, *, model=None, optim=None, gm=None):
             loss = net(data)
             return loss
 
         net = Simple()
         optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
+        gm = ad.GradManager().register(net.parameters())
         data = tensor([2.34])
-        train_func(data, model=net, optim=optim)
+        train_func(data, model=net, optim=optim, gm=gm)
         np.testing.assert_almost_equal(
             optim._state[net.a]["momentum_buffer"].numpy(), 2.34
         )
@@ -97,7 +98,7 @@ def test_sgd_momentum_trace():
         )
 
         # do a step of train
-        train_func(data, model=net, optim=optim)
+        train_func(data, model=net, optim=optim, gm=gm)
         np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
         np.testing.assert_almost_equal(
             optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34
         )

@@ -17,6 +17,7 @@ import megengine.functional as F
 import megengine.module as M
 import megengine.optimizer as optim
 from megengine import tensor
+from megengine.autodiff import GradManager
 from megengine.jit import trace
@@ -61,17 +62,18 @@ class XORNet(M.Module):
 def test_xornet_trace_dump():
     net = XORNet()
     opt = optim.SGD(net.parameters(requires_grad=True), lr=0.01, momentum=0.9)
+    gm = GradManager().register(net.parameters(requires_grad=True))
     batch_size = 64
     train_dataset = minibatch_generator(batch_size)
     val_dataset = minibatch_generator(batch_size)
 
     @trace
     def train_fun(data, label):
-        with opt.record():
+        with gm.record():
             net.train()
             pred = net(data)
             loss = F.cross_entropy_with_softmax(pred, label)
-            opt.backward(loss)
+            gm.backward(loss)
         return pred, loss
 
     @trace
@@ -95,7 +97,7 @@ def test_xornet_trace_dump():
             break
         data = tensor(minibatch["data"])
         label = tensor(minibatch["label"])
-        opt.zero_grad()
+        opt.clear_grad()
         _, loss = train_fun(data, label)
         train_loss.append((step, loss.numpy()))
         if step % 50 == 0: