未验证 提交 c44005f0 编写于 作者: W wanghuancoder 提交者: GitHub

fixoptminizer _set_auxiliary_var bug (#50335)

上级 0036316e
...@@ -236,6 +236,10 @@ class AscendOptimizer(Optimizer): ...@@ -236,6 +236,10 @@ class AscendOptimizer(Optimizer):
ret_list.append(var) ret_list.append(var)
return ret_list return ret_list
def _set_auxiliary_var(self, key, val):
super()._set_auxiliary_var(key, val)
self.inner_opt._set_auxiliary_var(key, val)
def minimize( def minimize(
self, self,
loss, loss,
......
...@@ -25,6 +25,10 @@ class MetaOptimizerBase(Optimizer): ...@@ -25,6 +25,10 @@ class MetaOptimizerBase(Optimizer):
self.meta_optimizers_white_list = [] self.meta_optimizers_white_list = []
self.meta_optimizers_black_list = [] self.meta_optimizers_black_list = []
def _set_auxiliary_var(self, key, val):
super()._set_auxiliary_var(key, val)
self.inner_opt._set_auxiliary_var(key, val)
def _set_basic_info( def _set_basic_info(
self, loss, role_maker, user_defined_optimizer, user_defined_strategy self, loss, role_maker, user_defined_optimizer, user_defined_strategy
): ):
......
...@@ -203,6 +203,10 @@ class GroupShardedOptimizerStage2(Optimizer): ...@@ -203,6 +203,10 @@ class GroupShardedOptimizerStage2(Optimizer):
# Update optimizer parameters and adjust parameter storage and use according to rank. # Update optimizer parameters and adjust parameter storage and use according to rank.
self._update_opt_status() self._update_opt_status()
def _set_auxiliary_var(self, key, val):
super()._set_auxiliary_var(key, val)
self._optim._set_auxiliary_var(key, val)
@paddle.autograd.no_grad() @paddle.autograd.no_grad()
def _sync_params_and_buffers(self): def _sync_params_and_buffers(self):
""" """
......
...@@ -144,6 +144,10 @@ class LookAhead(Optimizer): ...@@ -144,6 +144,10 @@ class LookAhead(Optimizer):
self._global_step_var = None self._global_step_var = None
self._k_var = None self._k_var = None
def _set_auxiliary_var(self, key, val):
super()._set_auxiliary_var(key, val)
self.inner_optimizer._set_auxiliary_var(key, val)
@framework.dygraph_only @framework.dygraph_only
@imperative_base.no_grad @imperative_base.no_grad
def step(self): def step(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册