diff --git a/imperative/python/megengine/optimizer/adadelta.py b/imperative/python/megengine/optimizer/adadelta.py index 1c321d2118d7f632e4969aca96520b34012f982a..81565a1c8b2cc0816463f632f780e8fc10988078 100644 --- a/imperative/python/megengine/optimizer/adadelta.py +++ b/imperative/python/megengine/optimizer/adadelta.py @@ -48,6 +48,7 @@ class Adadelta(Optimizer): defaults = dict(lr=lr, rho=rho, eps=eps, weight_decay=weight_decay) super().__init__(params, defaults) + self._disable_type_convert = True def _create_state(self, param_group): for param in param_group["params"]: diff --git a/imperative/python/megengine/optimizer/adagrad.py b/imperative/python/megengine/optimizer/adagrad.py index c983c7916bfde2cb57ef4f4b8c95de4828690b1c..fadbf48f763535c139f458c59d91fedfa0d1fd78 100644 --- a/imperative/python/megengine/optimizer/adagrad.py +++ b/imperative/python/megengine/optimizer/adagrad.py @@ -48,6 +48,7 @@ class Adagrad(Optimizer): defaults = dict(lr=lr, lr_decay=lr_decay, eps=eps, weight_decay=weight_decay) super().__init__(params, defaults) + self._disable_type_convert = True def _create_state(self, param_group): for param in param_group["params"]: diff --git a/imperative/python/megengine/optimizer/adam.py b/imperative/python/megengine/optimizer/adam.py index 40d5eec5bf24e28ffb59ed0a4215e71272fd16e1..9e51c90a2f7459c83a0854a7fdeb8e2aa5f5e91b 100644 --- a/imperative/python/megengine/optimizer/adam.py +++ b/imperative/python/megengine/optimizer/adam.py @@ -47,6 +47,7 @@ class Adam(Optimizer): defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps) super().__init__(params, defaults) + self._disable_type_convert = True def _create_state(self, param_group): for param in param_group["params"]: diff --git a/imperative/python/megengine/optimizer/adamw.py b/imperative/python/megengine/optimizer/adamw.py index aec655e0c706a30db4838304eb6fc7a23e2cdcac..cd3f2d918f6f8e3f525fbd0f2dfe8323c8087695 100644 --- a/imperative/python/megengine/optimizer/adamw.py +++ b/imperative/python/megengine/optimizer/adamw.py @@ -47,6 +47,7 @@ class AdamW(Optimizer): defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps) super().__init__(params, defaults) + self._disable_type_convert = True def _create_state(self, param_group): for param in param_group["params"]: diff --git a/imperative/python/megengine/optimizer/optimizer.py b/imperative/python/megengine/optimizer/optimizer.py index 8b2d485869fb2df23c62031db2425ae10aa96cfd..b6f60cd7d8dc4267e76b29440eae48bb25777766 100644 --- a/imperative/python/megengine/optimizer/optimizer.py +++ b/imperative/python/megengine/optimizer/optimizer.py @@ -42,6 +42,7 @@ class Optimizer(metaclass=ABCMeta): ): self._state = dict() self._defaults = defaults + self._disable_type_convert = False if isinstance(params, (Parameter, dict)): params = [params] @@ -149,7 +150,8 @@ class Optimizer(metaclass=ABCMeta): # set the globle state `_enable_convert_inputs` to `False` to disable # the `convert_inputs` for param updates set_option("record_computing_path", 0) - backup = set_convert_inputs(False) + if self._disable_type_convert: + backup = set_convert_inputs(False) for group in self.param_groups: if isinstance(group["params"], set): raise TypeError( @@ -160,8 +162,9 @@ class Optimizer(metaclass=ABCMeta): push_scope("step") self._updates(group) pop_scope("step") - # restore the globle state `_enable_convert_inputs` - set_convert_inputs(backup) + if self._disable_type_convert: + # restore the globle state `_enable_convert_inputs` + set_convert_inputs(backup) set_option("record_computing_path", 1) return self diff --git a/imperative/python/megengine/optimizer/sgd.py b/imperative/python/megengine/optimizer/sgd.py index 5ed256d209d6d4890525101c86b3a7cad6076220..9c939eb36c07a2966921a4484a13767575d84e92 100644 --- a/imperative/python/megengine/optimizer/sgd.py +++ b/imperative/python/megengine/optimizer/sgd.py @@ -43,6 +43,7 @@ class SGD(Optimizer): defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay) super().__init__(params, defaults) + self._disable_type_convert = True def _create_state(self, param_group): if param_group["momentum"] != 0.0: