diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 845e331d55712afc54128a55c4d5323dc850188c..57aa21b05d24a87e2209be5cf63f83fe7842b174 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -16,6 +16,7 @@ from __future__ import print_function import paddle from .strategy_compiler import StrategyCompiler from .meta_optimizer_factory import MetaOptimizerFactory +from .distributed_strategy import DistributedStrategy from .runtime_factory import RuntimeFactory from .util_factory import UtilFactory @@ -209,7 +210,7 @@ class Fleet(object): assert self._runtime_handle is not None self._runtime_handle._stop_worker() - def distributed_optimizer(self, optimizer, strategy): + def distributed_optimizer(self, optimizer, strategy=None): """ distirbuted_optimizer Returns: @@ -225,6 +226,8 @@ class Fleet(object): optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) """ self.user_defined_optimizer = optimizer + if strategy == None: + strategy = DistributedStrategy() self.user_defined_strategy = strategy self.valid_strategy = None return self