From c44005f0f83ca8ba4dda60f15424ac01a80a449f Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Thu, 9 Feb 2023 11:02:49 +0800 Subject: [PATCH] fixoptminizer _set_auxiliary_var bug (#50335) --- .../fleet/meta_optimizers/ascend/ascend_optimizer.py | 4 ++++ .../distributed/fleet/meta_optimizers/meta_optimizer_base.py | 4 ++++ .../meta_parallel/sharding/group_sharded_optimizer_stage2.py | 4 ++++ python/paddle/incubate/optimizer/lookahead.py | 4 ++++ 4 files changed, 16 insertions(+) diff --git a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py index e30a84e128..8e88a213b5 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/ascend/ascend_optimizer.py @@ -236,6 +236,10 @@ class AscendOptimizer(Optimizer): ret_list.append(var) return ret_list + def _set_auxiliary_var(self, key, val): + super()._set_auxiliary_var(key, val) + self.inner_opt._set_auxiliary_var(key, val) + def minimize( self, loss, diff --git a/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py b/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py index 87085a322c..9a7660ebd7 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py +++ b/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py @@ -25,6 +25,10 @@ class MetaOptimizerBase(Optimizer): self.meta_optimizers_white_list = [] self.meta_optimizers_black_list = [] + def _set_auxiliary_var(self, key, val): + super()._set_auxiliary_var(key, val) + self.inner_opt._set_auxiliary_var(key, val) + def _set_basic_info( self, loss, role_maker, user_defined_optimizer, user_defined_strategy ): diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py index 00ec12a523..639bdf79ac 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py @@ -203,6 +203,10 @@ class GroupShardedOptimizerStage2(Optimizer): # Update optimizer parameters and adjust parameter storage and use according to rank. self._update_opt_status() + def _set_auxiliary_var(self, key, val): + super()._set_auxiliary_var(key, val) + self._optim._set_auxiliary_var(key, val) + @paddle.autograd.no_grad() def _sync_params_and_buffers(self): """ diff --git a/python/paddle/incubate/optimizer/lookahead.py b/python/paddle/incubate/optimizer/lookahead.py index b1ad5f3ecb..bfa08c4055 100644 --- a/python/paddle/incubate/optimizer/lookahead.py +++ b/python/paddle/incubate/optimizer/lookahead.py @@ -144,6 +144,10 @@ class LookAhead(Optimizer): self._global_step_var = None self._k_var = None + def _set_auxiliary_var(self, key, val): + super()._set_auxiliary_var(key, val) + self.inner_optimizer._set_auxiliary_var(key, val) + @framework.dygraph_only @imperative_base.no_grad def step(self): -- GitLab