diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py index e30a84e12826ab6763bed911b33f5c2e643d6a8c..8e88a213b545447939447c4a0533583389004377 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py @@ -236,6 +236,10 @@ class AscendOptimizer(Optimizer): ret_list.append(var) return ret_list + def _set_auxiliary_var(self, key, val): + super()._set_auxiliary_var(key, val) + self.inner_opt._set_auxiliary_var(key, val) + def minimize( self, loss, diff --git a/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py b/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py index 87085a322c30370f7e67868e07423d09980d7de8..9a7660ebd7dc1fd85ac8386bdaf17f95710d0f98 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py +++ b/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py @@ -25,6 +25,10 @@ class MetaOptimizerBase(Optimizer): self.meta_optimizers_white_list = [] self.meta_optimizers_black_list = [] + def _set_auxiliary_var(self, key, val): + super()._set_auxiliary_var(key, val) + self.inner_opt._set_auxiliary_var(key, val) + def _set_basic_info( self, loss, role_maker, user_defined_optimizer, user_defined_strategy ): diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py index 00ec12a523f919a19c50c2ef357ae892a248469a..639bdf79ac9aa094ca634a20ad3eca00d1ec0eb4 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py @@ -203,6 +203,10 @@ class GroupShardedOptimizerStage2(Optimizer): # Update optimizer parameters and adjust parameter storage and use according to rank. self._update_opt_status() + def _set_auxiliary_var(self, key, val): + super()._set_auxiliary_var(key, val) + self._optim._set_auxiliary_var(key, val) + @paddle.autograd.no_grad() def _sync_params_and_buffers(self): """ diff --git a/python/paddle/incubate/optimizer/lookahead.py b/python/paddle/incubate/optimizer/lookahead.py index b1ad5f3ecb0b5cdee8b263ea121a8fa63ece31fc..bfa08c40556beca1111d250818a395956faa168e 100644 --- a/python/paddle/incubate/optimizer/lookahead.py +++ b/python/paddle/incubate/optimizer/lookahead.py @@ -144,6 +144,10 @@ class LookAhead(Optimizer): self._global_step_var = None self._k_var = None + def _set_auxiliary_var(self, key, val): + super()._set_auxiliary_var(key, val) + self.inner_optimizer._set_auxiliary_var(key, val) + @framework.dygraph_only @imperative_base.no_grad def step(self):