Commit 30d6b4f6 authored by Megvii Engine Team

fix(mge): fix scalar parameters changing to 1-dim parameters after backward and optimizer step

GitOrigin-RevId: 1794369a71251475cbe8f839cbf35f91a3adee99
Parent cf27dd64
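The change is easiest to see end to end. Below is a minimal reproduction sketch of the symptom named in the commit title, written against the public MegEngine API (GradManager, SGD); the pre-fix rank of the gradient is inferred from the commit message rather than observed:

```python
import numpy as np
import megengine.optimizer as optim
from megengine import Parameter, tensor
from megengine.autodiff import GradManager

w = Parameter(1.23, dtype=np.float32)   # 0-dim (scalar) parameter, as in the test below
opt = optim.SGD([w], lr=0.01)
gm = GradManager().attach([w])

with gm:
    x = tensor(np.random.rand(4).astype(np.float32))
    loss = (x * w).sum()
    gm.backward(loss)   # before this commit, w.grad could come back as a 1-dim tensor
opt.step()
opt.clear_grad()

# with the fix, the parameter keeps its scalar rank after backward + optimizer step
print(w.numpy().ndim)   # expected: 0
```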
......
@@ -279,6 +279,8 @@ class GradManager:
tensor.grad = grad
else:
tensor.grad += grad
if tensor.isscalar() and tensor.grad is not None:
tensor.grad.setscalar()
finally:
self.release()
backwarding_grad_manager = cache
......
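The hunk above is the core of the fix: whenever a gradient is written or accumulated into `tensor.grad`, the scalar flag is re-applied. A hedged sketch of the accumulation case it covers (two backward calls without `clear_grad` in between; the concrete numbers are mine):

```python
import numpy as np
from megengine import Parameter
from megengine.autodiff import GradManager

w = Parameter(2.0, dtype=np.float32)
gm = GradManager().attach([w])

with gm:
    gm.backward(w * 3.0)    # first pass creates w.grad
with gm:
    gm.backward(w * 2.0)    # second pass hits the `tensor.grad += grad` branch

print(w.grad.numpy())       # expected: 5.0
print(w.grad.numpy().ndim)  # expected: 0 once the scalar flag is restored
```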
@@ -12,4 +12,8 @@ from ..core.ops.builtin import InplaceAdd
def _inplace_add_(dest, delta, alpha, beta):
return dest._reset(apply(InplaceAdd(), dest, delta, alpha, beta)[0])
isscalar = dest.isscalar()
dest._reset(apply(InplaceAdd(), dest, delta, alpha, beta)[0])
if isscalar:
dest.setscalar()
return dest
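`_reset`, `apply(InplaceAdd(), ...)` and `setscalar()` are internal helpers, so here is a hedged numpy-only sketch of the bookkeeping pattern the new body follows: remember whether the destination was 0-dim, run the fused update, then restore the 0-dim view.

```python
import numpy as np

def inplace_add_sketch(dest, delta, alpha, beta):
    was_scalar = dest.ndim == 0             # analogue of dest.isscalar()
    out = alpha * dest + beta * delta       # what InplaceAdd computes
    return out.reshape(()) if was_scalar else out   # analogue of dest.setscalar()

a = np.array(2.0, dtype=np.float32)    # 0-dim destination, like a scalar Parameter
d = np.array([0.5], dtype=np.float32)  # a 1-dim delta would otherwise promote the result
print((1.0 * a + 1.0 * d).shape)                 # (1,) without the bookkeeping
print(inplace_add_sketch(a, d, 1.0, 1.0).shape)  # ()  with it
```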
......
@@ -61,16 +61,19 @@ class Adadelta(Optimizer):
rho = param_group["rho"]
eps = param_group["eps"]
def make_scalar(val):
return tensor(val)
# since `convert_inputs` is disabled for param updates,
# scalars should be explicitly converted to tensors
_lr = tensor([lr])
_weight_decay = tensor([weight_decay])
_rho = tensor([rho])
_eps = tensor([eps])
c05 = tensor([0.5])
c1 = tensor([1.0])
c2 = tensor([2.0])
_lr = make_scalar(lr)
_weight_decay = make_scalar(weight_decay)
_rho = make_scalar(rho)
_eps = make_scalar(eps)
c1, c2, c05 = map(make_scalar, (1.0, 2.0, 0.5))
for param in param_group["params"]:
if param.grad is None:
......
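The reason `make_scalar` wraps hyper-parameters with `tensor(val)` rather than `tensor([val])` is the resulting rank: a Python float gives a 0-dim tensor, while a one-element list gives a (1,)-shaped tensor, and broadcasting the latter against a 0-dim parameter would promote the update to 1-dim (the broadcasting rationale is my inference from the commit title):

```python
import megengine as mge

print(mge.tensor(0.01).numpy().shape)    # ()   -> stays scalar
print(mge.tensor([0.01]).numpy().shape)  # (1,) -> would promote a scalar parameter
```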
@@ -60,16 +60,18 @@ class Adagrad(Optimizer):
weight_decay = param_group["weight_decay"]
eps = param_group["eps"]
def make_scalar(val):
return tensor(val)
# since `convert_inputs` is disabled for param updates,
# scalars should be explicitly converted to tensors
_lr = tensor([lr])
_lr_decay = tensor([lr_decay])
_weight_decay = tensor([weight_decay])
_eps = tensor([eps])
c05 = tensor([0.5])
c1 = tensor([1.0])
c2 = tensor([2.0])
_lr, _lr_decay = map(make_scalar, (lr, lr_decay))
_weight_decay = make_scalar(weight_decay)
_eps = make_scalar(eps)
c1, c2, c05 = map(make_scalar, (1.0, 2.0, 0.5))
for param in param_group["params"]:
if param.grad is None:
......
@@ -61,7 +61,7 @@ class Adam(Optimizer):
beta0, beta1 = param_group["betas"]
def make_scalar(val):
return tensor([val])
return tensor(val)
# since `convert_inputs` is disabled for param updates,
# scalars should be explicitly converted to tensors
......
@@ -57,13 +57,13 @@ class SGD(Optimizer):
# since `convert_inputs` is disabled for param updates,
# scalars should be explicitly converted to tensors
_lr = tensor([lr])
_weight_decay = tensor([weight_decay])
_momentum = tensor([momentum])
_lr = tensor(lr)
_weight_decay = tensor(weight_decay)
_momentum = tensor(momentum)
inplace_mode = int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0"))
if inplace_mode:
_neg_lr = tensor([-lr])
_neg_lr = tensor(-lr)
c1 = tensor([1.0])
for param in param_group["params"]:
......
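The SGD hunk keeps reading `MEGENGINE_INPLACE_UPDATE` to choose the in-place update path. A hedged sketch of exercising that path with the new scalar hyper-parameters (the env var comes from the diff above; the rest is ordinary MegEngine usage):

```python
import os
import numpy as np
import megengine.optimizer as optim
from megengine import Parameter
from megengine.autodiff import GradManager

os.environ["MEGENGINE_INPLACE_UPDATE"] = "1"   # read inside the update step

w = Parameter(1.0, dtype=np.float32)
opt = optim.SGD([w], lr=0.1, momentum=0.9)
gm = GradManager().attach([w])

with gm:
    gm.backward(w * 2.0)
opt.step()
print(w.numpy().ndim)   # expected to stay 0 with this commit applied
```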
@@ -32,7 +32,7 @@ class MLP(Module):
class Simple(Module):
def __init__(self):
super().__init__()
self.a = Parameter([1.23], dtype=np.float32)
self.a = Parameter(1.23, dtype=np.float32)
def forward(self, x):
x = x * self.a
......
@@ -64,6 +64,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
ori_params = {}
for param in net.parameters():
assert param._tuple_shape is ()
ori_params[param] = np.copy(param.numpy())
opt.step()
step += 1
......
@@ -95,6 +96,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
ori_params = {}
for param in net.parameters():
assert param._tuple_shape is ()
ori_params[param] = np.copy(param.numpy())
train_func(
......
@@ -121,7 +123,9 @@ def test_sgd():
delta = -self.lr * self.slots[param]
else:
delta = -self.lr * grad
np.testing.assert_almost_equal(param.numpy(), ori_params[param] + delta)
np.testing.assert_almost_equal(
param.numpy(), ori_params[param] + delta, decimal=6
)
cases = [
{"momentum": 0.9, "lr": 0.01}, # SGD with momentum
......
@@ -157,7 +161,7 @@ def test_adam():
np.sqrt(v / (1 - self.betas[1] ** step)) + self.eps
)
np.testing.assert_almost_equal(
param.numpy(), ori_params[param] - self.lr * delta
param.numpy(), ori_params[param] - self.lr * delta, decimal=6
)
cases = [
......
@@ -189,7 +193,9 def test_adagrad():
self.s_slots[param] += grad ** 2
delta = grad / (self.s_slots[param] + self.eps) ** 0.5
delta *= -(self.lr / (1 + (step - 1) * self.lr_decay))
np.testing.assert_almost_equal(param.numpy(), ori_params[param] + delta)
np.testing.assert_almost_equal(
param.numpy(), ori_params[param] + delta, decimal=6
)
cases = [
{"lr": 0.01, "eps": 1e-06, "lr_decay": 0.01},
......
@@ -232,7 +238,9 @@ def test_adadelta():
1 - self.rho
)
delta *= -self.lr
np.testing.assert_almost_equal(param.numpy(), ori_params[param] + delta)
np.testing.assert_almost_equal(
param.numpy(), ori_params[param] + delta, decimal=6
)
cases = [
{"lr": 1.0, "eps": 1e-06, "rho": 0.9},
......
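The assertions above switch from the default `decimal=7` to `decimal=6`, i.e. they now accept an absolute error below 1.5e-6. A quick numpy illustration of that tolerance:

```python
import numpy as np

np.testing.assert_almost_equal(1.0 + 1e-7, 1.0, decimal=6)   # passes: |diff| < 1.5e-6
try:
    np.testing.assert_almost_equal(1.0 + 1e-5, 1.0, decimal=6)  # fails: |diff| > 1.5e-6
except AssertionError:
    print("outside the decimal=6 tolerance")
```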