Commit 8be78b11 authored by Megvii Engine Team

test(mge/optimizer): refactor the unittest of optimizer

GitOrigin-RevId: 4754285713d6a8697a31056331b443ad6b1302af
Parent 01ac8bbd
@@ -12,249 +12,178 @@ import numpy as np
 from helpers import MLP, graph_mode
 
 import megengine.functional as F
-from megengine import load, save
+from megengine import load, optimizer, save
 from megengine.core import TensorDict, tensor
 from megengine.jit import trace
-from megengine.optimizer import SGD, Adam
 from megengine.test import assertTensorClose
 
 
 def get_input():
-    batch_size = 2
-    input_dim = 28
-    data_shape = (batch_size, input_dim)
-    label_shape = (batch_size,)
-    data = tensor()
-    label = tensor(dtype=np.int32)
+    batch_size, input_dim = 2, 28
+    data_shape, label_shape = (batch_size, input_dim), (batch_size,)
+    data, label = tensor(dtype=np.float32), tensor(dtype=np.int32)
     data.set_value(np.random.random(data_shape).astype(np.float32))
     label.set_value(np.random.randint(0, 10, label_shape))
     return data, data_shape, label, label_shape
 
 
-def test_sgd_simple():
-    data, data_shape, label, label_shape = get_input()
-    mlp = MLP()
-    opt = SGD(mlp.parameters(), lr=0.01, weight_decay=0.1)
-    for idx in range(3):
-        data.set_value(np.random.random(data_shape).astype(np.float32))
-        label.set_value(np.random.randint(0, 10, label_shape))
-        pred = mlp(data)
-        loss = F.square_loss(pred, label.reshape(-1, 1))
-        if idx % 2:
-            opt.zero_grad()
-        else:
-            mlp.zero_grad()
-        opt.backward(loss)
-        grads = TensorDict()
-        orig_params = TensorDict()
-        for param in mlp.parameters():
-            grad = F.grad(loss, param, use_virtual_grad=False)
-            assertTensorClose(grad.numpy(), param.grad.numpy())
-            grads[param] = np.copy(grad.numpy())
-            orig_params[param] = np.copy(param.numpy())
-        opt.step()
-        for param in mlp.parameters():
-            assertTensorClose(
-                param.numpy(), orig_params[param] * 0.999 - grads[param] * 0.01
-            )
-
-
-def test_sgd_momentum():
-    data, data_shape, label, label_shape = get_input()
-    mlp = MLP()
-    opt = SGD(mlp.parameters(), lr=0.01, momentum=0.9)
-    slots = TensorDict()
-    for param in mlp.parameters():
-        slots[param] = np.zeros(param.shape).astype(np.float32)
-    for _ in range(3):
-        data.set_value(np.random.random(data_shape).astype(np.float32))
-        label.set_value(np.random.randint(0, 10, label_shape))
-        pred = mlp(data)
-        loss = F.square_loss(pred, label.reshape(-1, 1))
-        opt.zero_grad()
-        opt.backward(loss)
-        orig_params = TensorDict()
-        grads = TensorDict()
-        for param in mlp.parameters():
-            orig_params[param] = np.copy(param.numpy())
-            grads[param] = np.copy(param.grad.numpy())
-        opt.step()
-        for param in mlp.parameters():
-            slot = slots[param]
-            orig_param = orig_params[param]
-            slot *= 0.9
-            slot -= param.grad.numpy() * 0.01
-            assertTensorClose(param.numpy(), orig_param + slot)
-
-
-# TODO: put opt.step() inside trace
-def test_sgd_momentum_static():
-    _, data_shape, _, label_shape = get_input()
-    mlp = MLP()
-    opt = SGD(mlp.parameters(), lr=0.01, momentum=0.9)
-
-    @trace
-    def f(data, label):
-        pred = mlp(data)
-        loss = F.square_loss(pred, label.reshape(-1, 1))
-        opt.zero_grad()
-        opt.backward(loss)
-
-    slots = TensorDict()
-    for param in mlp.parameters():
-        slots[param] = np.zeros(param.shape).astype(np.float32)
-    for _ in range(3):
-        f(
-            np.random.random(data_shape).astype(np.float32),
-            np.random.randint(0, 10, label_shape).astype(np.int32),
-        )
-        orig_params = TensorDict()
-        grads = TensorDict()
-        for param in mlp.parameters():
-            orig_params[param] = np.copy(param.numpy())
-            grads[param] = np.copy(param.grad.numpy())
-        opt.step()
-        for param in mlp.parameters():
-            slot = slots[param]
-            orig_param = orig_params[param]
-            slot *= 0.9
-            slot -= param.grad.numpy() * 0.01
-            assertTensorClose(param.numpy(), orig_param + slot)
-
-
-def test_update_lr():
-    data, data_shape, label, label_shape = get_input()
-    mlp = MLP()
-    opt = SGD(mlp.parameters(), lr=0.01)
-    pred = mlp(data)
-    loss = F.square_loss(pred, label.reshape(-1, 1))
-    opt.zero_grad()
-    opt.backward(loss)
-    opt.step()
-    for group in opt.param_groups:
-        group["lr"] += 0.02
-    for _ in range(3):
-        data.set_value(np.random.random(data_shape).astype(np.float32))
-        label.set_value(np.random.randint(0, 10, label_shape))
-        pred = mlp(data)
-        loss = F.square_loss(pred, label.reshape(-1, 1))
-        opt.zero_grad()
-        opt.backward(loss)
-        for param in mlp.parameters():
-            grad = F.grad(loss, param, use_virtual_grad=False)
-            assertTensorClose(grad.numpy(), param.grad.numpy())
-        orig_params = []
-        for param in mlp.parameters():
-            orig_params.append(np.copy(param.numpy()))
-        opt.step()
-        for param, orig_param in zip(mlp.parameters(), orig_params):
-            assertTensorClose(param.numpy(), orig_param - param.grad.numpy() * 0.03)
-
-
-def test_adam():
-    data, data_shape, label, label_shape = get_input()
-    mlp = MLP()
-    beta0 = 0.8
-    beta1 = 0.9
-    eps = 1e-4
-    opt = Adam(mlp.parameters(), lr=0.01, betas=(beta0, beta1), eps=eps)
-    m_slots = TensorDict()
-    v_slots = TensorDict()
-    for param in mlp.parameters():
-        m_slots[param] = np.zeros(param.shape).astype(np.float32)
-        v_slots[param] = np.zeros(param.shape).astype(np.float32)
-    step_size = 0
-
-    def check_value():
-        for param in mlp.parameters():
-            grad = param.grad.numpy()
-            orig_param = orig_params[param]
-            m = m_slots[param]
-            v = v_slots[param]
-            m *= beta0
-            m += (1 - beta0) * grad
-            v *= beta1
-            v += (1 - beta1) * grad * grad
-            update = (m / (1 - beta0 ** step_size)) / (
-                np.sqrt(v / (1 - beta1 ** step_size)) + eps
-            )
-            assertTensorClose(param.numpy(), orig_param - 0.01 * update)
-
-    # eager
-    for _ in range(3):
-        data.set_value(np.random.random(data_shape).astype(np.float32))
-        label.set_value(np.random.randint(0, 10, label_shape))
-        pred = mlp(data)
-        loss = F.square_loss(pred, label.reshape(-1, 1))
-        opt.zero_grad()
-        grads = opt.backward(loss)
-        orig_params = TensorDict()
-        for param in mlp.parameters():
-            orig_params[param] = np.copy(param.numpy())
-        opt.step()
-        step_size += 1
-        check_value()
-
-    # static
-    @trace
-    def f(data, label):
-        pred = mlp(data)
-        loss = F.square_loss(pred, label.reshape(-1, 1))
-        opt.backward(loss)
-
-    for _ in range(3):
-        opt.zero_grad()
-        orig_params = TensorDict()
-        for param in mlp.parameters():
-            orig_params[param] = np.copy(param.numpy())
-        f(
-            np.random.random(data_shape).astype(np.float32),
-            np.random.randint(0, 10, label_shape).astype(np.int32),
-        )
-        opt.step()
-        step_size += 1
-        check_value()
-
-
-@graph_mode("eager", "static")
-def test_optimizer_serialization():
-    data, data_shape, label, label_shape = get_input()
-    mlp = MLP()
-    opt = SGD(mlp.parameters(), lr=0.01, momentum=0.9)
-    slots = TensorDict()
-    for param in mlp.parameters():
-        slots[param] = np.zeros(param.shape).astype(np.float32)
-
-    pred = mlp(data)
-    loss = F.square_loss(pred, label.reshape(-1, 1))
-    opt.zero_grad()
-    opt.backward(loss)
-    opt.step()
-    for param in mlp.parameters():
-        slot = slots[param]
-        slot *= 0.9
-        slot -= param.grad.numpy() * 0.01
-
-    with BytesIO() as fout:
-        save(opt.state_dict(), fout)
-        fout.seek(0)
-        state_dict = load(fout)
-        opt1 = SGD(mlp.parameters(), lr=0.02, momentum=0.8)
-        opt1.load_state_dict(state_dict)
-
-        data.set_value(np.random.random(data_shape).astype(np.float32))
-        label.set_value(np.random.randint(0, 10, label_shape))
-        pred = mlp(data)
-        loss = F.square_loss(pred, label.reshape(-1, 1))
-        opt1.zero_grad()
-        opt1.backward(loss)
-        orig_params = TensorDict()
-        for param in mlp.parameters():
-            orig_params[param] = np.copy(param.numpy())
-        opt1.step()
-        for param in mlp.parameters():
-            orig_param = orig_params[param]
-            slot = slots[param]
-            slot *= 0.9
-            slot -= param.grad.numpy() * 0.01
-            assertTensorClose(param.numpy(), orig_param + slot)
+@graph_mode("eager", "static")
+def test_optimizer_serialization():
+    data, data_shape, label, label_shape = get_input()
+    mlp = MLP()
+    opt = optimizer.SGD(mlp.parameters(), lr=0.01, momentum=0.9)
+    slots = TensorDict()
+    for param in mlp.parameters():
+        slots[param] = np.zeros(param.shape).astype(np.float32)
+
+    pred = mlp(data)
+    loss = F.square_loss(pred, label.reshape(-1, 1))
+    opt.zero_grad()
+    opt.backward(loss)
+    opt.step()
+    for param in mlp.parameters():
+        slots[param] = slots[param] * 0.9 + param.grad.numpy()
+
+    with BytesIO() as fout:
+        save(opt.state_dict(), fout)
+        fout.seek(0)
+        state_dict = load(fout)
+        opt1 = optimizer.SGD(mlp.parameters(), lr=0.02, momentum=0.8)
+        opt1.load_state_dict(state_dict)
+
+        data.set_value(np.random.random(data_shape).astype(np.float32))
+        label.set_value(np.random.randint(0, 10, label_shape))
+        pred = mlp(data)
+        loss = F.square_loss(pred, label.reshape(-1, 1))
+        opt1.zero_grad()
+        opt1.backward(loss)
+        orig_params = TensorDict()
+        for param in mlp.parameters():
+            orig_params[param] = np.copy(param.numpy())
+        opt1.step()
+        for param in mlp.parameters():
+            orig_param = orig_params[param]
+            slots[param] = slots[param] * 0.9 + param.grad.numpy()
+            assertTensorClose(param.numpy(), orig_param - 0.01 * slots[param])
+
+
+def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
+    iter_num = 3
+    data, data_shape, label, label_shape = get_input()
+
+    net = MLP()
+    opt = getattr(optimizer, opt_str)(net.parameters(), **test_case)
+    check_func = check_class(net, **test_case)
+
+    step = 0
+
+    # eager graph
+    for i in range(iter_num):
+        if update_lr and i == 1:  # change learning rate
+            for group in opt.param_groups:
+                group["lr"] += 0.01
+            check_func.lr += 0.01
+        data.set_value(np.random.random(data_shape).astype(np.float32))
+        label.set_value(np.random.randint(0, 10, label_shape))
+        pred = net(data)
+        loss = F.square_loss(pred, label.reshape(-1, 1))
+        opt.zero_grad()
+        opt.backward(loss)
+        ori_params = TensorDict()
+        for param in net.parameters():
+            ori_params[param] = np.copy(param.numpy())
+        opt.step()
+        step += 1
+        check_func(ori_params, net.parameters(), step)
+
+    # static graph
+    @trace
+    def train_func(data, label):
+        pred = net(data)
+        loss = F.square_loss(pred, label.reshape(-1, 1))
+        opt.backward(loss)
+
+    for i in range(iter_num):
+        if update_lr and i == 1:  # change learning rate
+            for group in opt.param_groups:
+                group["lr"] += 0.01
+            check_func.lr += 0.01
+        opt.zero_grad()
+        ori_params = TensorDict()
+        for param in net.parameters():
+            ori_params[param] = np.copy(param.numpy())
+        train_func(
+            np.random.random(data_shape).astype(np.float32),
+            np.random.randint(0, 10, label_shape).astype(np.int32),
+        )
+        opt.step()
+        step += 1
+        check_func(ori_params, net.parameters(), step)
+
+
+def test_sgd():
+    class CheckValue:
+        def __init__(self, net, **kwarg):
+            self.slots = TensorDict()
+            for param in net.parameters():
+                self.slots[param] = np.zeros(param.shape).astype(np.float32)
+            for k, v in kwarg.items():
+                setattr(self, k, v)
+
+        def __call__(self, ori_params, new_params, step):
+            for param in new_params:
+                grad = param.grad.numpy()
+                if hasattr(self, "momentum"):
+                    self.slots[param] = grad + self.slots[param] * self.momentum
+                    delta = -self.lr * self.slots[param]
+                else:
+                    delta = -self.lr * grad
+                assertTensorClose(param.numpy(), ori_params[param] + delta)
+
+    cases = [
+        {"momentum": 0.9, "lr": 0.01},  # SGD with momentum
+        {"lr": 0.01},  # simple SGD
+        {"weight_decay": 0.1, "lr": 0.01},  # with weight_decay
+    ]
+    for case in cases:
+        _test_optimizer("SGD", case, CheckValue)
+        _test_optimizer("SGD", case, CheckValue, update_lr=True)
+
+
+def test_adam():
+    class CheckValue:
+        def __init__(self, net, **kwarg):
+            self.m_slots = TensorDict()
+            self.v_slots = TensorDict()
+            for param in net.parameters():
+                self.m_slots[param] = np.zeros(param.shape).astype(np.float32)
+                self.v_slots[param] = np.zeros(param.shape).astype(np.float32)
+            for k, v in kwarg.items():
+                setattr(self, k, v)
+
+        def __call__(self, ori_params, new_params, step):
+            for param in new_params:
+                grad = param.grad.numpy()
+                m = self.m_slots[param]
+                v = self.v_slots[param]
+                m *= self.betas[0]
+                m += (1 - self.betas[0]) * grad
+                v *= self.betas[1]
+                v += (1 - self.betas[1]) * grad * grad
+                delta = (m / (1 - self.betas[0] ** step)) / (
+                    np.sqrt(v / (1 - self.betas[1] ** step)) + self.eps
+                )
+                assertTensorClose(param.numpy(), ori_params[param] - self.lr * delta)
+
+    cases = [
+        {"betas": (0.8, 0.9), "eps": 1e-04, "lr": 0.01},
+        {
+            "betas": (0.8, 0.9),
+            "eps": 1e-04,
+            "lr": 0.01,
+            "weight_decay": 0.1,
+        },  # with weight_decay
+    ]
+    for case in cases:
+        _test_optimizer("Adam", case, CheckValue)
+        _test_optimizer("Adam", case, CheckValue, update_lr=True)
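For readers checking the math, the closed-form updates that the refactored CheckValue helpers assert against are the standard momentum-SGD and bias-corrected Adam rules. The sketch below is a minimal, NumPy-only illustration of those two rules, not part of this commit or of MegEngine's API; the function names are hypothetical, and the weight_decay cases are not modeled here because the tests rely on the framework-computed param.grad for that term.

import numpy as np

def sgd_momentum_step(param, grad, slot, lr=0.01, momentum=0.9):
    # momentum buffer: slot <- momentum * slot + grad, as in test_sgd's CheckValue
    slot = momentum * slot + grad
    # parameter update: param <- param - lr * slot
    return param - lr * slot, slot

def adam_step(param, grad, m, v, step, lr=0.01, betas=(0.8, 0.9), eps=1e-4):
    # first and second moment estimates
    m = betas[0] * m + (1 - betas[0]) * grad
    v = betas[1] * v + (1 - betas[1]) * grad * grad
    # bias-corrected update, matching test_adam's CheckValue
    delta = (m / (1 - betas[0] ** step)) / (np.sqrt(v / (1 - betas[1] ** step)) + eps)
    return param - lr * delta, m, v

# toy usage with random data
p = np.random.randn(3).astype(np.float32)
g = np.random.randn(3).astype(np.float32)
print(sgd_momentum_step(p, g, np.zeros_like(p))[0])
print(adam_step(p, g, np.zeros_like(p), np.zeros_like(p), step=1)[0])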