From 31031804561b512e3fb1d98d993f37ab498a5216 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 12 Aug 2021 19:11:15 +0800 Subject: [PATCH] fix(mge/optimizer): only disable convert inputs in built-in optimizers GitOrigin-RevId: 1a48fe318dc29c80c8f244923d64de79d9abd6b8 --- imperative/python/megengine/optimizer/adadelta.py | 1 + imperative/python/megengine/optimizer/adagrad.py | 1 + imperative/python/megengine/optimizer/adam.py | 1 + imperative/python/megengine/optimizer/adamw.py | 1 + imperative/python/megengine/optimizer/optimizer.py | 9 ++++++--- imperative/python/megengine/optimizer/sgd.py | 1 + 6 files changed, 11 insertions(+), 3 deletions(-) diff --git a/imperative/python/megengine/optimizer/adadelta.py b/imperative/python/megengine/optimizer/adadelta.py index 1c321d211..81565a1c8 100644 --- a/imperative/python/megengine/optimizer/adadelta.py +++ b/imperative/python/megengine/optimizer/adadelta.py @@ -48,6 +48,7 @@ class Adadelta(Optimizer): defaults = dict(lr=lr, rho=rho, eps=eps, weight_decay=weight_decay) super().__init__(params, defaults) + self._disable_type_convert = True def _create_state(self, param_group): for param in param_group["params"]: diff --git a/imperative/python/megengine/optimizer/adagrad.py b/imperative/python/megengine/optimizer/adagrad.py index c983c7916..fadbf48f7 100644 --- a/imperative/python/megengine/optimizer/adagrad.py +++ b/imperative/python/megengine/optimizer/adagrad.py @@ -48,6 +48,7 @@ class Adagrad(Optimizer): defaults = dict(lr=lr, lr_decay=lr_decay, eps=eps, weight_decay=weight_decay) super().__init__(params, defaults) + self._disable_type_convert = True def _create_state(self, param_group): for param in param_group["params"]: diff --git a/imperative/python/megengine/optimizer/adam.py b/imperative/python/megengine/optimizer/adam.py index 40d5eec5b..9e51c90a2 100644 --- a/imperative/python/megengine/optimizer/adam.py +++ b/imperative/python/megengine/optimizer/adam.py @@ -47,6 +47,7 @@ class Adam(Optimizer): 
defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps) super().__init__(params, defaults) + self._disable_type_convert = True def _create_state(self, param_group): for param in param_group["params"]: diff --git a/imperative/python/megengine/optimizer/adamw.py b/imperative/python/megengine/optimizer/adamw.py index aec655e0c..cd3f2d918 100644 --- a/imperative/python/megengine/optimizer/adamw.py +++ b/imperative/python/megengine/optimizer/adamw.py @@ -47,6 +47,7 @@ class AdamW(Optimizer): defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps) super().__init__(params, defaults) + self._disable_type_convert = True def _create_state(self, param_group): for param in param_group["params"]: diff --git a/imperative/python/megengine/optimizer/optimizer.py b/imperative/python/megengine/optimizer/optimizer.py index 8b2d48586..b6f60cd7d 100644 --- a/imperative/python/megengine/optimizer/optimizer.py +++ b/imperative/python/megengine/optimizer/optimizer.py @@ -42,6 +42,7 @@ class Optimizer(metaclass=ABCMeta): ): self._state = dict() self._defaults = defaults + self._disable_type_convert = False if isinstance(params, (Parameter, dict)): params = [params] @@ -149,7 +150,8 @@ class Optimizer(metaclass=ABCMeta): # set the globle state `_enable_convert_inputs` to `False` to disable # the `convert_inputs` for param updates set_option("record_computing_path", 0) - backup = set_convert_inputs(False) + if self._disable_type_convert: + backup = set_convert_inputs(False) for group in self.param_groups: if isinstance(group["params"], set): raise TypeError( @@ -160,8 +162,9 @@ class Optimizer(metaclass=ABCMeta): push_scope("step") self._updates(group) pop_scope("step") - # restore the globle state `_enable_convert_inputs` - set_convert_inputs(backup) + if self._disable_type_convert: + # restore the globle state `_enable_convert_inputs` + set_convert_inputs(backup) set_option("record_computing_path", 1) return self diff --git 
a/imperative/python/megengine/optimizer/sgd.py b/imperative/python/megengine/optimizer/sgd.py index 5ed256d20..9c939eb36 100644 --- a/imperative/python/megengine/optimizer/sgd.py +++ b/imperative/python/megengine/optimizer/sgd.py @@ -43,6 +43,7 @@ class SGD(Optimizer): defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay) super().__init__(params, defaults) + self._disable_type_convert = True def _create_state(self, param_group): if param_group["momentum"] != 0.0: -- GitLab