Commit 1fed5929 authored by Megvii Engine Team

perf(mge/optimizer): disable convert_inputs for optimizer step

GitOrigin-RevId: c710530d934e1be29e611322b75837c9b72a610c
Parent 1f75c7ad
......@@ -16,6 +16,25 @@ from ..ops.special import Const
from ..tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
from .dtype import is_equal, is_quantize
_enable_convert_inputs = True
def get_convert_inputs():
""" get the curerent state of `_enable_convert_inputs` """
return _enable_convert_inputs
def set_convert_inputs(flag):
""" This function is a temporary workaround for reducing the overhead of operator
invocations. The function `convert_inputs` is disabled if the global state
`_enable_convert_inputs` is set to `False`, otherwise enabled. This function is for
internal use only, and should be removed when the tensor-like system is refactored.
"""
global _enable_convert_inputs
backup = _enable_convert_inputs
_enable_convert_inputs = flag
return backup
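A minimal usage sketch of the toggle defined above (illustrative only): `set_convert_inputs` returns the previous state so that callers can restore it when they are done.

    backup = set_convert_inputs(False)  # disable; returns the prior state (True by default)
    assert get_convert_inputs() is False
    set_convert_inputs(backup)          # restore the saved state
    assert get_convert_inputs() is True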
def dtype_promotion(inputs):
"""
......@@ -129,6 +148,9 @@ def convert_single_value(v, inputs, *, dtype=None, device=None):
def convert_inputs(*args: TensorBase):
if not _enable_convert_inputs:
return args
dtype = dtype_promotion(args)
device = get_device(args)
......
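When the flag is enabled, `convert_inputs` still performs the dtype promotion and device unification shown above; when it is disabled, the arguments pass through untouched. A hedged sketch of the two paths (`a` and `b` are arbitrary tensors):

    x, y = convert_inputs(a, b)  # enabled: inputs unified to a common dtype and device
    set_convert_inputs(False)
    x, y = convert_inputs(a, b)  # disabled: (a, b) returned as-is, skipping both steps
    set_convert_inputs(True)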
......@@ -10,8 +10,8 @@ from typing import Iterable, Union
import numpy as np
from ..functional import sqrt
from ..tensor import Parameter
from ..core.tensor.tensor import Tensor
from ..tensor import Parameter, tensor
from .optimizer import Optimizer
......@@ -62,6 +62,16 @@ class Adadelta(Optimizer):
rho = param_group["rho"]
eps = param_group["eps"]
# since `convert_inputs` is disabled for param updates,
# scalars must be explicitly converted to tensors
_lr = tensor([lr])
_weight_decay = tensor([weight_decay])
_rho = tensor([rho])
_eps = tensor([eps])
c05 = tensor([0.5])
c1 = tensor([1.0])
c2 = tensor([2.0])
for param in param_group["params"]:
if param.grad is None:
......@@ -69,17 +79,17 @@ class Adadelta(Optimizer):
states = self._state[param]
step = states["step"]
step += 1.0
step += c1
grad = param.grad
if weight_decay != 0.0:
grad += param * weight_decay
grad += param * _weight_decay
square_avg = states["square_avg"]
acc_delta = states["acc_delta"]
square_avg = rho * square_avg + (1 - rho) * grad ** 2
std = sqrt(square_avg + eps)
delta = sqrt(acc_delta + eps) / std * grad
param -= lr * delta
acc_delta = rho * acc_delta + (1 - rho) * delta ** 2
square_avg = _rho * square_avg + (c1 - _rho) * grad ** c2
std = (square_avg + _eps) ** c05
delta = (acc_delta + _eps) ** c05 / std * grad
param -= _lr * delta
acc_delta = _rho * acc_delta + (c1 - _rho) * delta ** c2
states["square_avg"]._reset(square_avg)
states["acc_delta"]._reset(acc_delta)
......@@ -10,8 +10,8 @@ from typing import Iterable, Union
import numpy as np
from ..functional import sqrt
from ..tensor import Parameter
from ..core.tensor.tensor import Tensor
from ..tensor import Parameter, tensor
from .optimizer import Optimizer
......@@ -61,6 +61,16 @@ class Adagrad(Optimizer):
weight_decay = param_group["weight_decay"]
eps = param_group["eps"]
# since `convert_inputs` is disabled for param updates,
# scalars must be explicitly converted to tensors
_lr = tensor([lr])
_lr_decay = tensor([lr_decay])
_weight_decay = tensor([weight_decay])
_eps = tensor([eps])
c05 = tensor([0.5])
c1 = tensor([1.0])
c2 = tensor([2.0])
for param in param_group["params"]:
if param.grad is None:
......@@ -68,14 +78,14 @@ class Adagrad(Optimizer):
states = self._state[param]
step = states["step"]
step += 1.0
step += c1
grad = param.grad
if weight_decay != 0.0:
grad += param * weight_decay
grad += param * _weight_decay
square_avg = states["square_avg"]
square_avg += grad ** 2
delta = grad / sqrt(square_avg + eps)
clr = lr / (1 + (step - 1) * lr_decay)
square_avg += grad ** c2
delta = grad / (square_avg + _eps) ** c05
clr = _lr / (c1 + (step - c1) * _lr_decay)
param -= clr * delta
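Likewise, the Adagrad hunk implements the usual accumulator update with a decayed effective learning rate:

    square_avg <- square_avg + grad^2
    clr         = lr / (1 + (step - 1) * lr_decay)
    param      <- param - clr * grad / sqrt(square_avg + eps)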
......@@ -8,7 +8,8 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable, Tuple, Union
from ..tensor import Parameter
from ..core.tensor.tensor import Tensor
from ..tensor import Parameter, tensor
from .optimizer import Optimizer
......@@ -58,6 +59,15 @@ class Adam(Optimizer):
eps = param_group["eps"]
beta0, beta1 = param_group["betas"]
# since `convert_inputs` is disabled for param updates,
# scalars must be explicitly converted to tensors
_lr = tensor([lr])
_weight_decay = tensor([weight_decay])
_eps = tensor([eps])
_beta0, _beta1 = tensor([beta0]), tensor([beta1])
c1 = tensor([1.0])
c05 = tensor([0.5])
for param in param_group["params"]:
if param.grad is None:
......@@ -65,20 +75,20 @@ class Adam(Optimizer):
grad = param.grad
if weight_decay != 0.0:
grad += param * weight_decay
grad += param * _weight_decay
states = self._state[param]
step = states["step"]
step += 1.0
step += c1
exp_avg = states["exp_avg"]
exp_avg_sq = states["exp_avg_sq"]
exp_avg = beta0 * exp_avg + grad * (1 - beta0)
exp_avg_sq = beta1 * exp_avg_sq + (1 - beta1) * (grad * grad)
exp_avg = _beta0 * exp_avg + grad * (c1 - _beta0)
exp_avg_sq = _beta1 * exp_avg_sq + (c1 - _beta1) * (grad * grad)
delta = (exp_avg / (1 - beta0 ** step)) / (
(exp_avg_sq / (1 - beta1 ** step)) ** 0.5 + eps
delta = (exp_avg / (c1 - _beta0 ** step)) / (
(exp_avg_sq / (c1 - _beta1 ** step)) ** c05 + _eps
)
param -= lr * delta
param -= _lr * delta
# not an in-place change, so the underlying tensor handle in state must be updated
states["exp_avg"]._reset(exp_avg)
......
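The `_reset` call is needed because the arithmetic above is out-of-place: each assignment rebinds the local name to a fresh tensor while the optimizer state still holds the old one. A minimal sketch of the distinction, using the names from the code above:

    exp_avg = states["exp_avg"]        # local name and state entry share one tensor
    exp_avg = _beta0 * exp_avg + ...   # rebinding: the state entry still holds the old tensor
    states["exp_avg"]._reset(exp_avg)  # point the state entry at the new tensor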
......@@ -15,6 +15,7 @@ from typing import Union
import numpy as np
from ..core.tensor.utils import set_convert_inputs
from ..tensor import Parameter, Tensor
from ..utils.deprecation import deprecated
......@@ -143,6 +144,9 @@ class Optimizer(metaclass=ABCMeta):
Performs a single optimization step.
"""
# set the global state `_enable_convert_inputs` to `False` to disable
# `convert_inputs` for param updates
backup = set_convert_inputs(False)
for group in self.param_groups:
if isinstance(group["params"], set):
raise TypeError(
......@@ -151,6 +155,8 @@ class Optimizer(metaclass=ABCMeta):
"Please use a list instead."
)
self._updates(group)
# restore the global state `_enable_convert_inputs`
set_convert_inputs(backup)
return self
@deprecated(version="1.0", reason="use clear_grad instead")
......
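Note that the backup/restore pair above is not exception-safe: if `_updates` raises, the flag is left at `False`. A hedged variant (not part of this commit) would wrap the loop in try/finally:

    backup = set_convert_inputs(False)
    try:
        for group in self.param_groups:
            self._updates(group)
    finally:
        set_convert_inputs(backup)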
......@@ -8,7 +8,8 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable, Union
from ..tensor import Parameter
from ..core.tensor.tensor import Tensor
from ..tensor import Parameter, tensor
from .optimizer import Optimizer
......@@ -52,18 +53,24 @@ class SGD(Optimizer):
weight_decay = param_group["weight_decay"]
momentum = param_group["momentum"]
# since `convert_inputs` is disabled for param updates,
# scalars must be explicitly converted to tensors
_lr = tensor([lr])
_weight_decay = tensor([weight_decay])
_momentum = tensor([momentum])
for param in param_group["params"]:
if param.grad is None:
continue
grad = param.grad
if weight_decay != 0.0:
grad += param * weight_decay
grad += param * _weight_decay
if momentum:
v = self._state[param]["momentum_buffer"]
v = momentum * v + grad
param -= lr * v
v = _momentum * v + grad
param -= _lr * v
self._state[param]["momentum_buffer"]._reset(v)
else:
param -= lr * grad
param -= _lr * grad
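A hedged end-to-end sketch of driving the patched optimizer (`net` is a hypothetical module; the forward/backward pass that populates `param.grad` is elided):

    opt = SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
    # ... forward/backward pass filling param.grad ...
    opt.step()        # runs _updates with convert_inputs globally disabled
    opt.clear_grad()  # reset gradients for the next iteration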