diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 845e331d55712afc54128a55c4d5323dc850188c..695fd01909c0e6eecb37b34a120f156fb5fed090 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -15,6 +15,7 @@ from __future__ import print_function import paddle from .strategy_compiler import StrategyCompiler +from .distributed_strategy import DistributedStrategy from .meta_optimizer_factory import MetaOptimizerFactory from .runtime_factory import RuntimeFactory from .util_factory import UtilFactory @@ -209,7 +210,7 @@ class Fleet(object): assert self._runtime_handle is not None self._runtime_handle._stop_worker() - def distributed_optimizer(self, optimizer, strategy): + def distributed_optimizer(self, optimizer, strategy=None): """ distirbuted_optimizer Returns: @@ -225,6 +226,8 @@ class Fleet(object): optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) """ self.user_defined_optimizer = optimizer + if strategy == None: + strategy = DistributedStrategy() self.user_defined_strategy = strategy self.valid_strategy = None return self diff --git a/python/paddle/fluid/tests/unittests/test_fleet_base.py b/python/paddle/fluid/tests/unittests/test_fleet_base.py index 8019841f72eccd35cb76ece582f89179a880e892..3a79b694cad5b0cb3fe0a08b6a18506510eead5b 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_base.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_base.py @@ -145,9 +145,9 @@ class TestFleetBase(unittest.TestCase): import paddle.fluid.incubate.fleet.base.role_maker as role_maker role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) - strategy = fleet.DistributedStrategy() + optimizer = paddle.optimizer.SGD(learning_rate=0.001) - optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + optimizer = fleet.distributed_optimizer(optimizer) def test_minimize(self): import paddle