From c000f8a29637d75abede3af9d7afecdc5727c3cc Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 17 Feb 2020 09:43:33 +0800 Subject: [PATCH] add texttable for pretty flag output (#22584) (#22626) pretty print for communicator flag --- .../distributed_strategy.py | 137 ++++++++++++------ .../unittests/test_distributed_strategy.py | 14 +- 2 files changed, 98 insertions(+), 53 deletions(-) diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py index 2eb69d76e4..b796e6ad1c 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py @@ -24,51 +24,35 @@ from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerCo class TrainerRuntimeConfig(object): def __init__(self): - self.max_merge_var_num = os.getenv( - "FLAGS_communicator_max_merge_var_num", "20") - self.send_queue_size = os.getenv("FLAGS_communicator_send_queue_size", - "20") - self.independent_recv_thread = os.getenv( - "FLAGS_communicator_independent_recv_thread", "1") - self.min_send_grad_num_before_recv = os.getenv( - "FLAGS_communicator_min_send_grad_num_before_recv", "20") - self.thread_pool_size = os.getenv("FLAGS_communicator_thread_pool_size", - "5") - self.send_wait_times = os.getenv("FLAGS_communicator_send_wait_times", - "5") - self.fake_rpc = os.getenv("FLAGS_communicator_fake_rpc", "0") - self.merge_sparse_grad = os.getenv( - "FLAGS_communicator_merge_sparse_grad", "1") - self.is_sgd_optimizer = os.getenv("FLAGS_communicator_is_sgd_optimizer", - "1") - + self.runtime_configs = {} # not used - self._rpc_deadline = os.getenv("FLAGS_rpc_deadline", "180000") - self._rpc_retry_times = os.getenv("FLAGS_rpc_retry_times", "3") + self.runtime_configs['rpc_deadline'] = os.getenv("FLAGS_rpc_deadline", + "180000") + self.runtime_configs['rpc_retry_times'] = os.getenv( + "FLAGS_rpc_retry_times", "3") def get_communicator_flags(self): - _communicator_flags = dict() - _communicator_flags["communicator_max_merge_var_num"] = str( - self.max_merge_var_num) - _communicator_flags["communicator_send_queue_size"] = str( - self.send_queue_size) - _communicator_flags["communicator_independent_recv_thread"] = str( - self.independent_recv_thread) - _communicator_flags["communicator_min_send_grad_num_before_recv"] = str( - self.min_send_grad_num_before_recv) - _communicator_flags["communicator_thread_pool_size"] = str( - self.thread_pool_size) - _communicator_flags["communicator_send_wait_times"] = str( - self.send_wait_times) - _communicator_flags["communicator_is_sgd_optimizer"] = str( - self.is_sgd_optimizer) - return _communicator_flags + return self.runtime_configs def __repr__(self): - _str = "please check that TrainerRuntimeConfig is as expected:\n" - _communicator_flags = self.get_communicator_flags() - for key in _communicator_flags: - _str += "{}: {}\n".format(key, _communicator_flags[key]) + raw0, raw1, length = 45, 5, 50 + h_format = "{:^45s}{:<5s}\n" + l_format = "{:<45s}{:<5s}\n" + + border = "".join(["="] * length) + line = "".join(["-"] * length) + + draws = "" + draws += border + "\n" + draws += h_format.format("TrainerRuntimeConfig Overview", "Value") + draws += line + "\n" + + for k, v in self.get_communicator_flags().items(): + draws += l_format.format(k, v) + + draws += border + + _str = "\n{}\n".format(draws) return _str @@ -77,9 +61,11 @@ class DistributedStrategy(object): self._program_config = DistributeTranspilerConfig() self._trainer_runtime_config = TrainerRuntimeConfig() self._server_runtime_config = ServerRuntimeConfig() + num_threads = int(os.getenv("CPU_NUM", "1")) + self._execute_strategy = fluid.ExecutionStrategy() self._build_strategy = fluid.BuildStrategy() - num_threads = int(os.getenv("CPU_NUM", "1")) + self._execute_strategy.num_threads = num_threads if num_threads > 1: self._build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce @@ -110,9 +96,9 @@ class DistributedStrategy(object): if isinstance(config, TrainerRuntimeConfig): self._trainer_runtime_config = config elif isinstance(config, dict): - for key in config: - if hasattr(self._trainer_runtime_config, key): - setattr(self._trainer_runtime_config, key, config[key]) + for key, Value in config.items(): + if key in self._trainer_runtime_config.runtime_configs: + self._trainer_runtime_config.runtime_configs[key] = Value else: raise ValueError( "TrainerRuntimeConfig doesn't have key: {}".format(key)) @@ -182,6 +168,21 @@ class SyncStrategy(DistributedStrategy): self._program_config.runtime_split_send_recv = False self._build_strategy.async_mode = False + num_threads = os.getenv("CPU_NUM", "1") + + self._trainer_runtime_config.runtime_configs[ + 'communicator_max_merge_var_num'] = os.getenv( + "FLAGS_communicator_max_merge_var_num", num_threads) + self._trainer_runtime_config.runtime_configs[ + 'communicator_send_wait_times'] = os.getenv( + "FLAGS_communicator_send_wait_times", "5") + self._trainer_runtime_config.runtime_configs[ + 'communicator_thread_pool_size'] = os.getenv( + "FLAGS_communicator_thread_pool_size", "10") + self._trainer_runtime_config.runtime_configs[ + 'communicator_send_queue_size'] = os.getenv( + "FLAGS_communicator_send_queue_size", num_threads) + class AsyncStrategy(DistributedStrategy): def __init__(self): @@ -190,6 +191,30 @@ class AsyncStrategy(DistributedStrategy): self._program_config.runtime_split_send_recv = True self._build_strategy.async_mode = True + num_threads = os.getenv("CPU_NUM", "1") + + self._trainer_runtime_config.runtime_configs[ + 'communicator_max_merge_var_num'] = os.getenv( + "FLAGS_communicator_max_merge_var_num", num_threads) + self._trainer_runtime_config.runtime_configs[ + 'communicator_independent_recv_thread'] = os.getenv( + "FLAGS_communicator_independent_recv_thread", "0") + self._trainer_runtime_config.runtime_configs[ + 'communicator_min_send_grad_num_before_recv'] = os.getenv( + "FLAGS_communicator_min_send_grad_num_before_recv", num_threads) + self._trainer_runtime_config.runtime_configs[ + 'communicator_thread_pool_size'] = os.getenv( + "FLAGS_communicator_thread_pool_size", "10") + self._trainer_runtime_config.runtime_configs[ + 'communicator_send_wait_times'] = os.getenv( + "FLAGS_communicator_send_wait_times", "5") + self._trainer_runtime_config.runtime_configs[ + 'communicator_is_sgd_optimizer'] = os.getenv( + "FLAGS_communicator_is_sgd_optimizer", "1") + self._trainer_runtime_config.runtime_configs[ + 'communicator_send_queue_size'] = os.getenv( + "FLAGS_communicator_send_queue_size", num_threads) + class HalfAsyncStrategy(DistributedStrategy): def __init__(self): @@ -200,15 +225,37 @@ class HalfAsyncStrategy(DistributedStrategy): self._build_strategy.async_mode = True self._execute_strategy.use_thread_barrier = True + num_threads = os.getenv("CPU_NUM", "1") + + self._trainer_runtime_config.runtime_configs[ + 'communicator_max_merge_var_num'] = os.getenv( + "FLAGS_communicator_max_merge_var_num", num_threads) + self._trainer_runtime_config.runtime_configs[ + 'communicator_send_wait_times'] = os.getenv( + "FLAGS_communicator_send_wait_times", "5") + self._trainer_runtime_config.runtime_configs[ + 'communicator_thread_pool_size'] = os.getenv( + "FLAGS_communicator_thread_pool_size", "10") + self._trainer_runtime_config.runtime_configs[ + 'communicator_send_queue_size'] = os.getenv( + "FLAGS_communicator_send_queue_size", num_threads) + class GeoStrategy(DistributedStrategy): def __init__(self, update_frequency=100): super(GeoStrategy, self).__init__() self._program_config.sync_mode = False self._program_config.runtime_split_send_recv = True - self._build_strategy.async_mode = True self._program_config.geo_sgd_mode = True self._program_config.geo_sgd_need_push_nums = update_frequency + self._build_strategy.async_mode = True + + self._trainer_runtime_config.runtime_configs[ + 'communicator_thread_pool_size'] = os.getenv( + "FLAGS_communicator_thread_pool_size", "10") + self._trainer_runtime_config.runtime_configs[ + 'communicator_send_wait_times'] = os.getenv( + "FLAGS_communicator_send_wait_times", "5") class StrategyFactory(object): diff --git a/python/paddle/fluid/tests/unittests/test_distributed_strategy.py b/python/paddle/fluid/tests/unittests/test_distributed_strategy.py index 797387a7f5..0267413663 100644 --- a/python/paddle/fluid/tests/unittests/test_distributed_strategy.py +++ b/python/paddle/fluid/tests/unittests/test_distributed_strategy.py @@ -84,22 +84,20 @@ class TestStrategyFactor(unittest.TestCase): build_strategy_illegal) def test_async_strategy(self): + os.environ["CPU_NUM"] = '100' + strategy = StrategyFactory.create_async_strategy() self.assertEqual(strategy._program_config.sync_mode, False) self.assertEqual(strategy._program_config.runtime_split_send_recv, True) self.assertEqual(strategy._build_strategy.async_mode, True) - # test set_trainer_runtime_config using TrainerRuntimeConfig - trainer_runtime_config_class = TrainerRuntimeConfig() - trainer_runtime_config_class.send_queue_size = 50 - print(trainer_runtime_config_class) - strategy.set_trainer_runtime_config(trainer_runtime_config_class) trainer_runtime_config = strategy.get_trainer_runtime_config() - self.assertEqual(trainer_runtime_config.send_queue_size, 50) + self.assertEqual(trainer_runtime_config.runtime_configs[ + 'communicator_send_queue_size'], '100') # test set_trainer_runtime_config using dict trainer_runtime_config_dict = dict() - trainer_runtime_config_dict['send_queue_size'] = 100 + trainer_runtime_config_dict['communicator_send_queue_size'] = '20' strategy.set_trainer_runtime_config(trainer_runtime_config_dict) trainer_runtime_config = strategy.get_trainer_runtime_config() trainer_communicator_flags = trainer_runtime_config.get_communicator_flags( @@ -107,7 +105,7 @@ class TestStrategyFactor(unittest.TestCase): self.assertIn('communicator_send_queue_size', trainer_communicator_flags) self.assertEqual( - trainer_communicator_flags['communicator_send_queue_size'], '100') + trainer_communicator_flags['communicator_send_queue_size'], '20') # test set_trainer_runtime_config exception trainer_runtime_config_dict['unknown'] = None -- GitLab