未验证 提交 d549a9b1 编写于 作者: Q Qinghe JING 提交者: GitHub

【paddle.fleet】Set default value to strategy in distributed_optimizer (#26246)

* set default value to strategy in distributed_optimizer test=develop
上级 672578a7
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from __future__ import print_function from __future__ import print_function
import paddle import paddle
from .strategy_compiler import StrategyCompiler from .strategy_compiler import StrategyCompiler
from .distributed_strategy import DistributedStrategy
from .meta_optimizer_factory import MetaOptimizerFactory from .meta_optimizer_factory import MetaOptimizerFactory
from .runtime_factory import RuntimeFactory from .runtime_factory import RuntimeFactory
from .util_factory import UtilFactory from .util_factory import UtilFactory
...@@ -209,7 +210,7 @@ class Fleet(object): ...@@ -209,7 +210,7 @@ class Fleet(object):
assert self._runtime_handle is not None assert self._runtime_handle is not None
self._runtime_handle._stop_worker() self._runtime_handle._stop_worker()
def distributed_optimizer(self, optimizer, strategy): def distributed_optimizer(self, optimizer, strategy=None):
""" """
distirbuted_optimizer distirbuted_optimizer
Returns: Returns:
...@@ -225,6 +226,8 @@ class Fleet(object): ...@@ -225,6 +226,8 @@ class Fleet(object):
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
""" """
self.user_defined_optimizer = optimizer self.user_defined_optimizer = optimizer
if strategy == None:
strategy = DistributedStrategy()
self.user_defined_strategy = strategy self.user_defined_strategy = strategy
self.valid_strategy = None self.valid_strategy = None
return self return self
......
...@@ -145,9 +145,9 @@ class TestFleetBase(unittest.TestCase): ...@@ -145,9 +145,9 @@ class TestFleetBase(unittest.TestCase):
import paddle.fluid.incubate.fleet.base.role_maker as role_maker import paddle.fluid.incubate.fleet.base.role_maker as role_maker
role = role_maker.PaddleCloudRoleMaker(is_collective=True) role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role) fleet.init(role)
strategy = fleet.DistributedStrategy()
optimizer = paddle.optimizer.SGD(learning_rate=0.001) optimizer = paddle.optimizer.SGD(learning_rate=0.001)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) optimizer = fleet.distributed_optimizer(optimizer)
def test_minimize(self): def test_minimize(self):
import paddle import paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册