From 7e2665c58b95c7ec98527ab64911355a38e0587d Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 3 Feb 2020 11:00:21 +0800 Subject: [PATCH] fix bug with half (#22378) * fix bug with half communicator --- .../distribute_transpiler/__init__.py | 22 ++++++++---- .../distributed_strategy.py | 28 +++++++-------- .../unittests/test_distributed_strategy.py | 36 ++++++++++++++++++- 3 files changed, 64 insertions(+), 22 deletions(-) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py index 73c127cbb3c..2a0fa372bc7 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py @@ -271,13 +271,21 @@ class DistributedTranspiler(Fleet): elif isinstance(config, DistributeTranspilerConfig): if config.sync_mode: self._transpile_config = SyncStrategy() - elif config.geo_sgd_mode: - self._transpile_config = GeoStrategy( - config.geo_sgd_need_push_nums) - elif config.runtime_split_send_recv and config.half_async: - self._transpile_config = HalfAsyncStrategy() else: - self._transpile_config = AsyncStrategy() + if config.runtime_split_send_recv: + if config.geo_sgd_mode: + self._transpile_config = GeoStrategy( + config.geo_sgd_need_push_nums) + elif config.half_async: + self._transpile_config = HalfAsyncStrategy() + else: + self._transpile_config = AsyncStrategy() + + else: + self._transpile_config = HalfAsyncStrategy() + # for half_async compatibility + config.half_async = True + config.runtime_split_send_recv = True self._transpile_config.set_program_config(config) else: raise TypeError( @@ -359,7 +367,7 @@ class TranspilerOptimizer(DistributedOptimizer): "In {} mode, strategy must be an instance of DistributeTranspilerConfig, SyncStrategy, HalfAsyncStrategy, AsyncStrategy, or GeoStrategy". format(fleet._mode)) else: - self._strategy = DistributedStrategy() + self._strategy = StrategyFactory.create_sync_strategy() def backward(self, loss, diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py index d8acb7430dd..05f920f426f 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py @@ -48,20 +48,20 @@ class TrainerRuntimeConfig(object): def get_communicator_flags(self): _communicator_flags = dict() - _communicator_flags[ - "communicator_max_merge_var_num"] = self.max_merge_var_num - _communicator_flags[ - "communicator_send_queue_size"] = self.send_queue_size - _communicator_flags[ - "communicator_independent_recv_thread"] = self.independent_recv_thread - _communicator_flags[ - "communicator_min_send_grad_num_before_recv"] = self.min_send_grad_num_before_recv - _communicator_flags[ - "communicator_thread_pool_size"] = self.thread_pool_size - _communicator_flags[ - "communicator_send_wait_times"] = self.send_wait_times - _communicator_flags[ - "communicator_is_sgd_optimizer"] = self.is_sgd_optimizer + _communicator_flags["communicator_max_merge_var_num"] = str( + self.max_merge_var_num) + _communicator_flags["communicator_send_queue_size"] = str( + self.send_queue_size) + _communicator_flags["communicator_independent_recv_thread"] = str( + self.independent_recv_thread) + _communicator_flags["communicator_min_send_grad_num_before_recv"] = str( + self.min_send_grad_num_before_recv) + _communicator_flags["communicator_thread_pool_size"] = str( + self.thread_pool_size) + _communicator_flags["communicator_send_wait_times"] = str( + self.send_wait_times) + _communicator_flags["communicator_is_sgd_optimizer"] = str( + self.is_sgd_optimizer) return _communicator_flags def __repr__(self): diff --git a/python/paddle/fluid/tests/unittests/test_distributed_strategy.py b/python/paddle/fluid/tests/unittests/test_distributed_strategy.py index 01b1b6e04cf..797387a7f5d 100644 --- a/python/paddle/fluid/tests/unittests/test_distributed_strategy.py +++ b/python/paddle/fluid/tests/unittests/test_distributed_strategy.py @@ -16,6 +16,8 @@ import unittest import paddle.fluid as fluid from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import TrainerRuntimeConfig, StrategyFactory +from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet +import paddle.fluid.incubate.fleet.base.role_maker as role_maker import os @@ -105,7 +107,7 @@ class TestStrategyFactor(unittest.TestCase): self.assertIn('communicator_send_queue_size', trainer_communicator_flags) self.assertEqual( - trainer_communicator_flags['communicator_send_queue_size'], 100) + trainer_communicator_flags['communicator_send_queue_size'], '100') # test set_trainer_runtime_config exception trainer_runtime_config_dict['unknown'] = None @@ -166,5 +168,37 @@ class TestStrategyFactor(unittest.TestCase): server_runtime_config_illegal) +class TestCreateDefaultStrategy(unittest.TestCase): + def test_default_strategy(self): + role = role_maker.UserDefinedRoleMaker( + current_id=0, + role=role_maker.Role.WORKER, + worker_num=2, + server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"]) + fleet.init(role) + + optimizer = fluid.optimizer.SGD(0.0001) + optimizer = fleet.distributed_optimizer(optimizer) + + +class TestHalfAsyncStrategy(unittest.TestCase): + def test_half_async_strategy(self): + role = role_maker.UserDefinedRoleMaker( + current_id=0, + role=role_maker.Role.WORKER, + worker_num=2, + server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"]) + fleet.init(role) + + half_async_config = DistributeTranspilerConfig() + + half_async_config.sync_mode = False + half_async_config.geo_sgd_mode = False + half_async_config.runtime_split_send_recv = False + + optimizer = fluid.optimizer.SGD(0.0001) + optimizer = fleet.distributed_optimizer(optimizer, half_async_config) + + if __name__ == '__main__': unittest.main() -- GitLab