diff --git a/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py index 3ac2fd374a80fcfa4115187cc946069e02f633ea..3c1318301bb37bea71b896c220eb4a2090b334bf 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py @@ -14,8 +14,8 @@ from __future__ import print_function +import paddle from paddle.fluid import program_guard, layers, default_main_program -from paddle.fluid.optimizer import Momentum, SGD from .meta_optimizer_base import MetaOptimizerBase from .common import OpRole, OP_ROLE_KEY, CollectiveHelper, is_update_op @@ -35,8 +35,10 @@ class LocalSGDOptimizer(MetaOptimizerBase): if self.role_maker.worker_num() <= 1: return False - return isinstance(self.inner_opt, Momentum) \ - or isinstance(self.inner_opt, SGD) + return isinstance(self.inner_opt, paddle.optimizer.momentum.Momentum) \ + or isinstance(self.inner_opt, paddle.fluid.optimizer.Momentum) \ + or isinstance(self.inner_opt, paddle.optimizer.sgd.SGD) \ + or isinstance(self.inner_opt, paddle.fluid.optimizer.SGD) def _disable_strategy(self, dist_strategy): dist_strategy.localsgd = False