From aca450f6fb94f73aa420057f63ac5668d616a858 Mon Sep 17 00:00:00 2001 From: ShenLiang Date: Mon, 7 Sep 2020 20:06:54 +0800 Subject: [PATCH] fix the localsgd optimizer (#27094) * fix the localsgd --- .../fleet/meta_optimizers/localsgd_optimizer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py index 3ac2fd374a..3c1318301b 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py @@ -14,8 +14,8 @@ from __future__ import print_function +import paddle from paddle.fluid import program_guard, layers, default_main_program -from paddle.fluid.optimizer import Momentum, SGD from .meta_optimizer_base import MetaOptimizerBase from .common import OpRole, OP_ROLE_KEY, CollectiveHelper, is_update_op @@ -35,8 +35,10 @@ class LocalSGDOptimizer(MetaOptimizerBase): if self.role_maker.worker_num() <= 1: return False - return isinstance(self.inner_opt, Momentum) \ - or isinstance(self.inner_opt, SGD) + return isinstance(self.inner_opt, paddle.optimizer.momentum.Momentum) \ + or isinstance(self.inner_opt, paddle.fluid.optimizer.Momentum) \ + or isinstance(self.inner_opt, paddle.optimizer.sgd.SGD) \ + or isinstance(self.inner_opt, paddle.fluid.optimizer.SGD) def _disable_strategy(self, dist_strategy): dist_strategy.localsgd = False -- GitLab