Commit 30d6b4f6 authored by Megvii Engine Team

fix(mge): fix scalar parameters turning into 1-dim parameters after backward and optimizer step

GitOrigin-RevId: 1794369a71251475cbe8f839cbf35f91a3adee99
Parent cf27dd64
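
Note: a minimal repro sketch of the symptom this commit targets (the module and optimizer mirror the test diff below; the shapes shown are the expected behavior, not output captured from this commit). A 0-dim scalar Parameter should keep its scalar shape through backward() and an optimizer step instead of coming back as a 1-dim, 1-element parameter.

    import numpy as np
    from megengine import Parameter, tensor
    from megengine.autodiff import GradManager
    from megengine.module import Module
    from megengine.optimizer import SGD

    class Simple(Module):
        def __init__(self):
            super().__init__()
            self.a = Parameter(1.23, dtype=np.float32)  # scalar (0-dim) parameter

        def forward(self, x):
            return x * self.a

    net = Simple()
    gm = GradManager().attach(net.parameters())
    opt = SGD(net.parameters(), lr=0.01)

    with gm:
        loss = net(tensor([1.0, 2.0])).sum()
        gm.backward(loss)
    opt.step()

    # Expected after this fix: the parameter is still 0-dim, i.e. shape ().
    print(net.a.shape)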
@@ -279,6 +279,8 @@ class GradManager:
                     if tensor.grad is None:
                         tensor.grad = grad
                     else:
                         tensor.grad += grad
+                    if tensor.isscalar() and tensor.grad is not None:
+                        tensor.grad.setscalar()
         finally:
             self.release()
             backwarding_grad_manager = cache
......
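
Note: the two added lines above re-mark the accumulated grad as scalar, since grad assignment and `+=` can otherwise leave a scalar tensor with a 1-dim grad. A hedged check of that intent (`isscalar`/`setscalar` are internal APIs of this MegEngine version; the shapes are the expected result):

    import numpy as np
    from megengine import Parameter, tensor
    from megengine.autodiff import GradManager

    w = Parameter(2.0, dtype=np.float32)   # scalar parameter
    gm = GradManager().attach([w])
    with gm:
        y = (w * tensor([1.0, 2.0, 3.0])).sum()
        gm.backward(y)

    # Expected with this fix: both the parameter and its grad stay 0-dim.
    print(w.shape, w.grad.shape)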
@@ -12,4 +12,8 @@ from ..core.ops.builtin import InplaceAdd
 def _inplace_add_(dest, delta, alpha, beta):
-    return dest._reset(apply(InplaceAdd(), dest, delta, alpha, beta)[0])
+    isscalar = dest.isscalar()
+    dest._reset(apply(InplaceAdd(), dest, delta, alpha, beta)[0])
+    if isscalar:
+        dest.setscalar()
+    return dest
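
Note: a hedged usage sketch of the fixed helper (the import path is an assumption based on the `..core.ops.builtin` import above, and the alpha/beta tensors are purely illustrative); the point is only that a scalar `dest` keeps its scalar flag across the in-place update:

    import numpy as np
    from megengine import tensor
    # assumed location of the helper patched above
    from megengine.functional.inplace import _inplace_add_

    dest = tensor(np.float32(1.0))   # scalar accumulator
    delta = tensor(np.float32(0.5))
    one = tensor(np.float32(1.0))

    out = _inplace_add_(dest, delta, alpha=one, beta=one)
    print(out.shape)  # expected: () — the scalar flag survives the update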
@@ -61,16 +61,19 @@ class Adadelta(Optimizer):
         rho = param_group["rho"]
         eps = param_group["eps"]

+        def make_scalar(val):
+            return tensor(val)
+
         # since `convert_inputs` is disabled for param updates,
         # scalar should be explicitly transferred to tensor
-        _lr = tensor([lr])
-        _weight_decay = tensor([weight_decay])
-        _rho = tensor([rho])
-        _eps = tensor([eps])
-
-        c05 = tensor([0.5])
-        c1 = tensor([1.0])
-        c2 = tensor([2.0])
+        _lr = make_scalar(lr)
+        _weight_decay = make_scalar(weight_decay)
+        _rho = make_scalar(rho)
+        _eps = make_scalar(eps)
+
+        c1, c2, c05 = map(make_scalar, (1.0, 2.0, 0.5))

         for param in param_group["params"]:
             if param.grad is None:
......
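
Note: the switch from `tensor([val])` to `tensor(val)` is what actually stops scalar parameters from growing a dimension: a (1,)-shaped hyperparameter broadcasts a 0-dim parameter up to shape (1,) during the update. A hedged illustration (the shapes are the expected broadcasting behavior):

    from megengine import tensor

    p = tensor(1.0)          # stand-in for a 0-dim parameter
    lr_1d = tensor([0.01])   # old style: 1-dim, 1-element learning rate
    lr_0d = tensor(0.01)     # new style: 0-dim learning rate

    print((p - lr_1d * p).shape)  # expected: (1,) — the scalar is promoted
    print((p - lr_0d * p).shape)  # expected: ()   — the scalar shape is kept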
@@ -60,16 +60,18 @@ class Adagrad(Optimizer):
         weight_decay = param_group["weight_decay"]
         eps = param_group["eps"]

+        def make_scalar(val):
+            return tensor(val)
+
         # since `convert_inputs` is disabled for param updates,
         # scalar should be explicitly transferred to tensor
-        _lr = tensor([lr])
-        _lr_decay = tensor([lr_decay])
-        _weight_decay = tensor([weight_decay])
-        _eps = tensor([eps])
-
-        c05 = tensor([0.5])
-        c1 = tensor([1.0])
-        c2 = tensor([2.0])
+        _lr, _lr_decay = map(make_scalar, (lr, lr_decay))
+        _weight_decay = make_scalar(weight_decay)
+        _eps = make_scalar(eps)
+
+        c1, c2, c05 = map(make_scalar, (1.0, 2.0, 0.5))

         for param in param_group["params"]:
             if param.grad is None:
......
@@ -61,7 +61,7 @@ class Adam(Optimizer):
         beta0, beta1 = param_group["betas"]

         def make_scalar(val):
-            return tensor([val])
+            return tensor(val)

         # since `convert_inputs` is disabled for param updates,
         # scalar should be explicitly transferred to tensor
......
@@ -57,13 +57,13 @@ class SGD(Optimizer):
         # since `convert_inputs` is disabled for param updates,
         # scalar should be explicitly transferred to tensor
-        _lr = tensor([lr])
-        _weight_decay = tensor([weight_decay])
-        _momentum = tensor([momentum])
+        _lr = tensor(lr)
+        _weight_decay = tensor(weight_decay)
+        _momentum = tensor(momentum)

         inplace_mode = int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0"))
         if inplace_mode:
-            _neg_lr = tensor([-lr])
+            _neg_lr = tensor(-lr)
             c1 = tensor([1.0])

         for param in param_group["params"]:
......
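
Note: the in-place branch above is opt-in; since `os.getenv` is read at update time in the code shown, setting the variable in-process before stepping is assumed to be enough (a hedged sketch):

    import os

    # Opt in to the in-place parameter update path gated above.
    os.environ["MEGENGINE_INPLACE_UPDATE"] = "1"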
@@ -32,7 +32,7 @@ class MLP(Module):
 class Simple(Module):
     def __init__(self):
         super().__init__()
-        self.a = Parameter([1.23], dtype=np.float32)
+        self.a = Parameter(1.23, dtype=np.float32)

     def forward(self, x):
         x = x * self.a
@@ -64,6 +64,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
         ori_params = {}
         for param in net.parameters():
+            assert param._tuple_shape is ()
             ori_params[param] = np.copy(param.numpy())
         opt.step()
         step += 1
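
Note: the `assert param._tuple_shape is ()` added above (and in the next hunk) pins the invariant under test: every parameter of the Simple net is scalar and must stay scalar across optimizer steps. A hedged sketch of the distinction (`_tuple_shape` is an internal attribute of this MegEngine version):

    import numpy as np
    from megengine import Parameter

    a0 = Parameter(1.23, dtype=np.float32)    # scalar parameter, as in Simple above
    a1 = Parameter([1.23], dtype=np.float32)  # 1-element, 1-dim parameter
    print(a0.shape, a1.shape)                 # expected: () and (1,)
    assert a0._tuple_shape is ()              # the check the test repeats after each step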
@@ -95,6 +96,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
         ori_params = {}
         for param in net.parameters():
+            assert param._tuple_shape is ()
             ori_params[param] = np.copy(param.numpy())
         train_func(
@@ -121,7 +123,9 @@ def test_sgd():
                 delta = -self.lr * self.slots[param]
             else:
                 delta = -self.lr * grad
-            np.testing.assert_almost_equal(param.numpy(), ori_params[param] + delta)
+            np.testing.assert_almost_equal(
+                param.numpy(), ori_params[param] + delta, decimal=6
+            )

     cases = [
         {"momentum": 0.9, "lr": 0.01},  # SGD with momentum
@@ -157,7 +161,7 @@ def test_adam():
                 np.sqrt(v / (1 - self.betas[1] ** step)) + self.eps
             )
             np.testing.assert_almost_equal(
-                param.numpy(), ori_params[param] - self.lr * delta
+                param.numpy(), ori_params[param] - self.lr * delta, decimal=6
             )

     cases = [
@@ -189,7 +193,9 @@ def test_adagrad():
             self.s_slots[param] += grad ** 2
             delta = grad / (self.s_slots[param] + self.eps) ** 0.5
             delta *= -(self.lr / (1 + (step - 1) * self.lr_decay))
-            np.testing.assert_almost_equal(param.numpy(), ori_params[param] + delta)
+            np.testing.assert_almost_equal(
+                param.numpy(), ori_params[param] + delta, decimal=6
+            )

     cases = [
         {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.01},
@@ -232,7 +238,9 @@ def test_adadelta():
                 1 - self.rho
             )
             delta *= -self.lr
-            np.testing.assert_almost_equal(param.numpy(), ori_params[param] + delta)
+            np.testing.assert_almost_equal(
+                param.numpy(), ori_params[param] + delta, decimal=6
+            )

     cases = [
         {"lr": 1.0, "eps": 1e-06, "rho": 0.9},
......
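
Note: the assertions in the test hunks above also pass `decimal=6` because the parameters are float32, where the seventh decimal place is already rounding noise relative to the float64 reference math. A small NumPy illustration (values are arbitrary):

    import numpy as np

    a = np.float32(0.1) * np.float32(3)   # float32 arithmetic
    b = 0.1 * 3                           # float64 reference
    np.testing.assert_almost_equal(a, b, decimal=6)   # passes
    # np.testing.assert_almost_equal(a, b, decimal=8) # would raise: ~1.2e-8 mismatch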