Commit e9d6361e authored by: M Megvii Engine Team

fix(mge/optimizer): replace in-place grad accumulation (`grad += ...`) with out-of-place addition (`grad = grad + ...`) in the optimizer update step, so applying weight decay no longer mutates `param.grad`

GitOrigin-RevId: d677d1ca6b37bf94b89305a6102d2f1a11d6c872
Parent: a3e098c8
...@@ -84,7 +84,7 @@ class Adadelta(Optimizer): ...@@ -84,7 +84,7 @@ class Adadelta(Optimizer):
step += c1 step += c1
grad = param.grad grad = param.grad
if weight_decay != 0.0: if weight_decay != 0.0:
grad += param * _weight_decay grad = grad + param * _weight_decay
square_avg = states["square_avg"] square_avg = states["square_avg"]
acc_delta = states["acc_delta"] acc_delta = states["acc_delta"]
......
...@@ -82,7 +82,7 @@ class Adagrad(Optimizer): ...@@ -82,7 +82,7 @@ class Adagrad(Optimizer):
step += c1 step += c1
grad = param.grad grad = param.grad
if weight_decay != 0.0: if weight_decay != 0.0:
grad += param * _weight_decay grad = grad + param * _weight_decay
square_avg = states["square_avg"] square_avg = states["square_avg"]
square_avg += grad ** c2 square_avg += grad ** c2
......
...@@ -85,7 +85,7 @@ class Adam(Optimizer): ...@@ -85,7 +85,7 @@ class Adam(Optimizer):
grad = param.grad grad = param.grad
if weight_decay != 0.0: if weight_decay != 0.0:
grad += param * _weight_decay grad = grad + param * _weight_decay
states = self._state[param] states = self._state[param]
......
...@@ -72,7 +72,7 @@ class SGD(Optimizer): ...@@ -72,7 +72,7 @@ class SGD(Optimizer):
grad = param.grad grad = param.grad
if weight_decay != 0.0: if weight_decay != 0.0:
grad += param * _weight_decay grad = grad + param * _weight_decay
if inplace_mode: if inplace_mode:
if momentum: if momentum:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册