未验证 提交 d549a9b1 编写于 作者: Q Qinghe JING 提交者: GitHub

【paddle.fleet】Set default value to strategy in distributed_optimizer (#26246)

* set default value to strategy in distributed_optimizer test=develop
上级 672578a7
......@@ -15,6 +15,7 @@
from __future__ import print_function
import paddle
from .strategy_compiler import StrategyCompiler
from .distributed_strategy import DistributedStrategy
from .meta_optimizer_factory import MetaOptimizerFactory
from .runtime_factory import RuntimeFactory
from .util_factory import UtilFactory
......@@ -209,7 +210,7 @@ class Fleet(object):
assert self._runtime_handle is not None
self._runtime_handle._stop_worker()
def distributed_optimizer(self, optimizer, strategy):
def distributed_optimizer(self, optimizer, strategy=None):
"""
distirbuted_optimizer
Returns:
......@@ -225,6 +226,8 @@ class Fleet(object):
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
"""
self.user_defined_optimizer = optimizer
if strategy == None:
strategy = DistributedStrategy()
self.user_defined_strategy = strategy
self.valid_strategy = None
return self
......
......@@ -145,9 +145,9 @@ class TestFleetBase(unittest.TestCase):
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
strategy = fleet.DistributedStrategy()
optimizer = paddle.optimizer.SGD(learning_rate=0.001)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer = fleet.distributed_optimizer(optimizer)
def test_minimize(self):
import paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册