提交 31031804 编写于 作者: M Megvii Engine Team

fix(mge/optimizer): only disable convert inputs in built-in optimizers

GitOrigin-RevId: 1a48fe318dc29c80c8f244923d64de79d9abd6b8
上级 fd24dc8e
...@@ -48,6 +48,7 @@ class Adadelta(Optimizer): ...@@ -48,6 +48,7 @@ class Adadelta(Optimizer):
defaults = dict(lr=lr, rho=rho, eps=eps, weight_decay=weight_decay) defaults = dict(lr=lr, rho=rho, eps=eps, weight_decay=weight_decay)
super().__init__(params, defaults) super().__init__(params, defaults)
self._disable_type_convert = True
def _create_state(self, param_group): def _create_state(self, param_group):
for param in param_group["params"]: for param in param_group["params"]:
......
...@@ -48,6 +48,7 @@ class Adagrad(Optimizer): ...@@ -48,6 +48,7 @@ class Adagrad(Optimizer):
defaults = dict(lr=lr, lr_decay=lr_decay, eps=eps, weight_decay=weight_decay) defaults = dict(lr=lr, lr_decay=lr_decay, eps=eps, weight_decay=weight_decay)
super().__init__(params, defaults) super().__init__(params, defaults)
self._disable_type_convert = True
def _create_state(self, param_group): def _create_state(self, param_group):
for param in param_group["params"]: for param in param_group["params"]:
......
...@@ -47,6 +47,7 @@ class Adam(Optimizer): ...@@ -47,6 +47,7 @@ class Adam(Optimizer):
defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps) defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps)
super().__init__(params, defaults) super().__init__(params, defaults)
self._disable_type_convert = True
def _create_state(self, param_group): def _create_state(self, param_group):
for param in param_group["params"]: for param in param_group["params"]:
......
...@@ -47,6 +47,7 @@ class AdamW(Optimizer): ...@@ -47,6 +47,7 @@ class AdamW(Optimizer):
defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps) defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps)
super().__init__(params, defaults) super().__init__(params, defaults)
self._disable_type_convert = True
def _create_state(self, param_group): def _create_state(self, param_group):
for param in param_group["params"]: for param in param_group["params"]:
......
...@@ -42,6 +42,7 @@ class Optimizer(metaclass=ABCMeta): ...@@ -42,6 +42,7 @@ class Optimizer(metaclass=ABCMeta):
): ):
self._state = dict() self._state = dict()
self._defaults = defaults self._defaults = defaults
self._disable_type_convert = False
if isinstance(params, (Parameter, dict)): if isinstance(params, (Parameter, dict)):
params = [params] params = [params]
...@@ -149,6 +150,7 @@ class Optimizer(metaclass=ABCMeta): ...@@ -149,6 +150,7 @@ class Optimizer(metaclass=ABCMeta):
# set the global state `_enable_convert_inputs` to `False` to disable # the `convert_inputs` for param updates
# the `convert_inputs` for param updates # set the global state `_enable_convert_inputs` to `False` to disable
set_option("record_computing_path", 0) set_option("record_computing_path", 0)
if self._disable_type_convert:
backup = set_convert_inputs(False) backup = set_convert_inputs(False)
for group in self.param_groups: for group in self.param_groups:
if isinstance(group["params"], set): if isinstance(group["params"], set):
...@@ -160,6 +162,7 @@ class Optimizer(metaclass=ABCMeta): ...@@ -160,6 +162,7 @@ class Optimizer(metaclass=ABCMeta):
push_scope("step") push_scope("step")
self._updates(group) self._updates(group)
pop_scope("step") pop_scope("step")
if self._disable_type_convert:
# restore the global state `_enable_convert_inputs` # restore the global state `_enable_convert_inputs`
set_convert_inputs(backup) set_convert_inputs(backup)
set_option("record_computing_path", 1) set_option("record_computing_path", 1)
......
...@@ -43,6 +43,7 @@ class SGD(Optimizer): ...@@ -43,6 +43,7 @@ class SGD(Optimizer):
defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay) defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
super().__init__(params, defaults) super().__init__(params, defaults)
self._disable_type_convert = True
def _create_state(self, param_group): def _create_state(self, param_group):
if param_group["momentum"] != 0.0: if param_group["momentum"] != 0.0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册