未验证 提交 aca450f6 编写于 作者: S ShenLiang 提交者: GitHub

fix the localsgd optimizer (#27094)

* fix the localsgd
上级 c1a88687
......@@ -14,8 +14,8 @@
from __future__ import print_function
import paddle
from paddle.fluid import program_guard, layers, default_main_program
from paddle.fluid.optimizer import Momentum, SGD
from .meta_optimizer_base import MetaOptimizerBase
from .common import OpRole, OP_ROLE_KEY, CollectiveHelper, is_update_op
......@@ -35,8 +35,10 @@ class LocalSGDOptimizer(MetaOptimizerBase):
if self.role_maker.worker_num() <= 1:
return False
return isinstance(self.inner_opt, Momentum) \
or isinstance(self.inner_opt, SGD)
return isinstance(self.inner_opt, paddle.optimizer.momentum.Momentum) \
or isinstance(self.inner_opt, paddle.fluid.optimizer.Momentum) \
or isinstance(self.inner_opt, paddle.optimizer.sgd.SGD) \
or isinstance(self.inner_opt, paddle.fluid.optimizer.SGD)
def _disable_strategy(self, dist_strategy):
dist_strategy.localsgd = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册