From d549a9b1fe9528967ae019915601b02db7a9f7ad Mon Sep 17 00:00:00 2001 From: Qinghe JING Date: Mon, 17 Aug 2020 17:53:51 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90paddle.fleet=E3=80=91Set=20default=20v?= =?UTF-8?q?alue=20to=20strategy=20in=20distributed=5Foptimizer=20(#26246)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * set default value to strategy in distributed_optimizer test=develop --- python/paddle/distributed/fleet/base/fleet_base.py | 5 ++++- python/paddle/fluid/tests/unittests/test_fleet_base.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 845e331d557..695fd01909c 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -15,6 +15,7 @@ from __future__ import print_function import paddle from .strategy_compiler import StrategyCompiler +from .distributed_strategy import DistributedStrategy from .meta_optimizer_factory import MetaOptimizerFactory from .runtime_factory import RuntimeFactory from .util_factory import UtilFactory @@ -209,7 +210,7 @@ class Fleet(object): assert self._runtime_handle is not None self._runtime_handle._stop_worker() - def distributed_optimizer(self, optimizer, strategy): + def distributed_optimizer(self, optimizer, strategy=None): """ distirbuted_optimizer Returns: @@ -225,6 +226,8 @@ class Fleet(object): optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) """ self.user_defined_optimizer = optimizer + if strategy == None: + strategy = DistributedStrategy() self.user_defined_strategy = strategy self.valid_strategy = None return self diff --git a/python/paddle/fluid/tests/unittests/test_fleet_base.py b/python/paddle/fluid/tests/unittests/test_fleet_base.py index 8019841f72e..3a79b694cad 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_base.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_base.py @@ -145,9 +145,9 @@ class TestFleetBase(unittest.TestCase): import paddle.fluid.incubate.fleet.base.role_maker as role_maker role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) - strategy = fleet.DistributedStrategy() + optimizer = paddle.optimizer.SGD(learning_rate=0.001) - optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + optimizer = fleet.distributed_optimizer(optimizer) def test_minimize(self): import paddle -- GitLab