Commit 31031804 authored by Megvii Engine Team

fix(mge/optimizer): only disable convert inputs in built-in optimizers

GitOrigin-RevId: 1a48fe318dc29c80c8f244923d64de79d9abd6b8
Parent fd24dc8e
......@@ -48,6 +48,7 @@ class Adadelta(Optimizer):
defaults = dict(lr=lr, rho=rho, eps=eps, weight_decay=weight_decay)
super().__init__(params, defaults)
self._disable_type_convert = True
def _create_state(self, param_group):
for param in param_group["params"]:
......
......@@ -48,6 +48,7 @@ class Adagrad(Optimizer):
defaults = dict(lr=lr, lr_decay=lr_decay, eps=eps, weight_decay=weight_decay)
super().__init__(params, defaults)
self._disable_type_convert = True
def _create_state(self, param_group):
for param in param_group["params"]:
......
......@@ -47,6 +47,7 @@ class Adam(Optimizer):
defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps)
super().__init__(params, defaults)
self._disable_type_convert = True
def _create_state(self, param_group):
for param in param_group["params"]:
......
......@@ -47,6 +47,7 @@ class AdamW(Optimizer):
defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps)
super().__init__(params, defaults)
self._disable_type_convert = True
def _create_state(self, param_group):
for param in param_group["params"]:
......
......@@ -42,6 +42,7 @@ class Optimizer(metaclass=ABCMeta):
):
self._state = dict()
self._defaults = defaults
self._disable_type_convert = False
if isinstance(params, (Parameter, dict)):
params = [params]
......@@ -149,7 +150,8 @@ class Optimizer(metaclass=ABCMeta):
# set the global state `_enable_convert_inputs` to `False` to disable
# the `convert_inputs` for param updates
set_option("record_computing_path", 0)
backup = set_convert_inputs(False)
if self._disable_type_convert:
backup = set_convert_inputs(False)
for group in self.param_groups:
if isinstance(group["params"], set):
raise TypeError(
......@@ -160,8 +162,9 @@ class Optimizer(metaclass=ABCMeta):
push_scope("step")
self._updates(group)
pop_scope("step")
# restore the global state `_enable_convert_inputs`
set_convert_inputs(backup)
if self._disable_type_convert:
# restore the global state `_enable_convert_inputs`
set_convert_inputs(backup)
set_option("record_computing_path", 1)
return self
......
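For context (not part of this diff): a third-party optimizer that subclasses `Optimizer` and does not set `self._disable_type_convert = True` keeps the new base-class default of `False`, so `step()` no longer switches off implicit input dtype conversion around its parameter updates. Below is a minimal sketch assuming only the `Optimizer` interface visible in this diff; the class name `PlainSGD` and its schematic update rule are illustrative, not part of MegEngine:

```python
from megengine.optimizer import Optimizer


class PlainSGD(Optimizer):
    """Hypothetical user-defined optimizer (not part of MegEngine).

    It does not set ``self._disable_type_convert = True``, so after this
    change ``step()`` leaves the global ``convert_inputs`` behaviour enabled
    while running its updates.
    """

    def __init__(self, params, lr=0.01):
        defaults = dict(lr=lr)
        super().__init__(params, defaults)
        # _disable_type_convert keeps its base-class default of False here.

    def _create_state(self, param_group):
        pass  # stateless update rule, nothing to pre-allocate

    def _updates(self, param_group):
        lr = param_group["lr"]
        for param in param_group["params"]:
            if param.grad is None:
                continue
            # Schematic update; may rely on implicit dtype conversion when
            # the gradient's dtype differs from the parameter's dtype.
            param -= lr * param.grad
```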
......@@ -43,6 +43,7 @@ class SGD(Optimizer):
defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
super().__init__(params, defaults)
self._disable_type_convert = True
def _create_state(self, param_group):
if param_group["momentum"] != 0.0:
......
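By contrast, the built-in optimizers patched above opt in via `self._disable_type_convert = True`, so input conversion is switched off only for the duration of their own `step()` call and restored afterwards. A small usage sketch with the built-in `SGD` (standard MegEngine training-step boilerplate; nothing in it is introduced by this commit):

```python
import numpy as np

import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager
from megengine.optimizer import SGD

w = mge.Parameter(np.zeros((3,), dtype="float32"))
gm = GradManager().attach([w])
opt = SGD([w], lr=0.1, momentum=0.9)  # built-in: sets _disable_type_convert = True

target = mge.tensor(np.ones((3,), dtype="float32"))
with gm:
    loss = F.sum((w - target) ** 2)
    gm.backward(loss)

# convert_inputs is disabled only while this step() runs, then restored.
opt.step().clear_grad()
```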