diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py index 0d868f6109e48d108b3e718720e5535348f95428..92d07c97da46568f31d86a99f20f0b8fe071b031 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py @@ -19,12 +19,32 @@ __all__ = [ import os import paddle.fluid as fluid -from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig +from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig, DistributedMode class TrainerRuntimeConfig(object): def __init__(self): + self.mode = None + num_threads = os.getenv("CPU_NUM", "1") + self.runtime_configs = {} + self.runtime_configs['communicator_max_merge_var_num'] = os.getenv( + "FLAGS_communicator_max_merge_var_num", num_threads) + self.runtime_configs['communicator_send_queue_size'] = os.getenv( + "FLAGS_communicator_send_queue_size", num_threads) + self.runtime_configs[ + 'communicator_independent_recv_thread'] = os.getenv( + "FLAGS_communicator_independent_recv_thread", "1") + self.runtime_configs[ + 'communicator_min_send_grad_num_before_recv'] = os.getenv( + "FLAGS_communicator_min_send_grad_num_before_recv", num_threads) + self.runtime_configs['communicator_thread_pool_size'] = os.getenv( + "FLAGS_communicator_thread_pool_size", "5") + self.runtime_configs['communicator_send_wait_times'] = os.getenv( + "FLAGS_communicator_send_wait_times", "5") + self.runtime_configs['communicator_is_sgd_optimizer'] = os.getenv( + "FLAGS_communicator_is_sgd_optimizer", "1") + # not used self.runtime_configs['rpc_deadline'] = os.getenv("FLAGS_rpc_deadline", "180000") @@ -32,9 +52,54 @@ class TrainerRuntimeConfig(object): "FLAGS_rpc_retry_times", "3") def get_communicator_flags(self): - return self.runtime_configs - - def __repr__(self): + need_keys = [] + num_threads = os.getenv("CPU_NUM", "1") + mode_str = "" + if self.mode is None or self.mode == DistributedMode.ASYNC: + need_keys = self.runtime_configs.keys() + mode_str = "async" + elif self.mode == DistributedMode.SYNC or self.mode == DistributedMode.HALF_ASYNC: + mode_str = "sync or half_async" + need_keys = [ + 'communicator_max_merge_var_num', + 'communicator_send_wait_times', 'communicator_thread_pool_size', + 'communicator_send_queue_size' + ] + elif self.mode == DistributedMode.GEO: + mode_str = "GEO" + need_keys = [ + 'communicator_thread_pool_size', 'communicator_send_wait_times' + ] + else: + raise ValueError("Unsupported Mode") + + if self.mode == DistributedMode.SYNC or self.mode == DistributedMode.HALF_ASYNC: + max_merge_var_num = self.runtime_configs[ + 'communicator_max_merge_var_num'] + send_queue_size = self.runtime_configs[ + 'communicator_send_queue_size'] + if max_merge_var_num != num_threads: + print('WARNING: In {} mode, communicator_max_merge_var_num ' + 'must be equal to CPU_NUM. But received, ' + 'communicator_max_merge_var_num = {}, CPU_NUM = ' + '{}. communicator_max_merge_var_num will be fored to {}.' + .format(mode_str, max_merge_var_num, num_threads, + num_threads)) + self.runtime_configs[ + 'communicator_max_merge_var_num'] = num_threads + if send_queue_size != num_threads: + print('WARNING: In {} mode, communicator_send_queue_size ' + 'must be equal to CPU_NUM. But received, ' + 'communicator_send_queue_size = {}, CPU_NUM = ' + '{}. communicator_send_queue_size will be fored to {}.' + .format(mode_str, send_queue_size, num_threads, + num_threads)) + self.runtime_configs[ + 'communicator_send_queue_size'] = num_threads + + return dict((key, str(self.runtime_configs[key])) for key in need_keys) + + def display(self, configs): raw0, raw1, length = 45, 5, 50 h_format = "{:^45s}{:<5s}\n" l_format = "{:<45s}{:<5s}\n" @@ -47,7 +112,7 @@ class TrainerRuntimeConfig(object): draws += h_format.format("TrainerRuntimeConfig Overview", "Value") draws += line + "\n" - for k, v in self.get_communicator_flags().items(): + for k, v in configs.items(): draws += l_format.format(k, v) draws += border @@ -55,6 +120,9 @@ class TrainerRuntimeConfig(object): _str = "\n{}\n".format(draws) return _str + def __repr__(self): + return self.display(self.get_communicator_flags()) + class DistributedStrategy(object): def __init__(self): @@ -105,6 +173,12 @@ class DistributedStrategy(object): raise TypeError( "program_config only accept input type: dict or DistributeTranspilerConfig" ) + self.check_program_config() + + def check_program_config(self): + raise NotImplementedError( + "check_program_config must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy." + ) def get_trainer_runtime_config(self): return self._trainer_runtime_config @@ -123,6 +197,12 @@ class DistributedStrategy(object): raise TypeError( "trainer_runtime_config only accept input type: dict or TrainerRuntimeConfig" ) + self.check_trainer_runtime_config() + + def check_trainer_runtime_config(self): + raise NotImplementedError( + "check_trainer_runtime_config must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy." + ) def get_server_runtime_config(self): return self._server_runtime_config @@ -141,6 +221,12 @@ class DistributedStrategy(object): raise TypeError( "server_runtime_config only accept input type: dict or ServerRuntimeConfig" ) + self.check_server_runtime_config() + + def check_server_runtime_config(self): + raise NotImplementedError( + "check_server_runtime_config must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy." + ) def get_execute_strategy(self): return self._execute_strategy @@ -159,6 +245,12 @@ class DistributedStrategy(object): raise TypeError( "execute_strategy only accept input type: dict or ExecutionStrategy" ) + self.check_execute_strategy() + + def check_execute_strategy(self): + raise NotImplementedError( + "check_execute_strategy must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy." + ) def get_build_strategy(self): return self._build_strategy @@ -176,106 +268,121 @@ class DistributedStrategy(object): else: raise TypeError( "build_strategy only accept input type: dict or BuildStrategy") + self.check_build_strategy() + + def check_build_strategy(self): + raise NotImplementedError( + "check_build_strategy must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy." + ) class SyncStrategy(DistributedStrategy): def __init__(self): super(SyncStrategy, self).__init__() + self.check_program_config() + self.check_trainer_runtime_config() + self.check_server_runtime_config() + self.check_build_strategy() + self.check_execute_strategy() + + def check_trainer_runtime_config(self): + self._trainer_runtime_config.mode = DistributedMode.SYNC + + def check_program_config(self): self._program_config.sync_mode = False self._program_config.runtime_split_send_recv = True - self._build_strategy.async_mode = True self._program_config.half_async = True self._program_config.completely_not_async = True - self._execute_strategy.use_thread_barrier = True - num_threads = os.getenv("CPU_NUM", "1") + def check_server_runtime_config(self): + pass - self._trainer_runtime_config.runtime_configs[ - 'communicator_max_merge_var_num'] = os.getenv( - "FLAGS_communicator_max_merge_var_num", num_threads) - self._trainer_runtime_config.runtime_configs[ - 'communicator_send_wait_times'] = os.getenv( - "FLAGS_communicator_send_wait_times", "5") - self._trainer_runtime_config.runtime_configs[ - 'communicator_thread_pool_size'] = os.getenv( - "FLAGS_communicator_thread_pool_size", "10") - self._trainer_runtime_config.runtime_configs[ - 'communicator_send_queue_size'] = os.getenv( - "FLAGS_communicator_send_queue_size", num_threads) + def check_execute_strategy(self): + self._execute_strategy.use_thread_barrier = True + + def check_build_strategy(self): + self._build_strategy.async_mode = True class AsyncStrategy(DistributedStrategy): def __init__(self): super(AsyncStrategy, self).__init__() + self.check_program_config() + self.check_trainer_runtime_config() + self.check_server_runtime_config() + self.check_build_strategy() + self.check_execute_strategy() + + def check_trainer_runtime_config(self): + self._trainer_runtime_config.mode = DistributedMode.ASYNC + + def check_program_config(self): self._program_config.sync_mode = False self._program_config.runtime_split_send_recv = True - self._build_strategy.async_mode = True - num_threads = os.getenv("CPU_NUM", "1") + def check_server_runtime_config(self): + pass - self._trainer_runtime_config.runtime_configs[ - 'communicator_max_merge_var_num'] = os.getenv( - "FLAGS_communicator_max_merge_var_num", num_threads) - self._trainer_runtime_config.runtime_configs[ - 'communicator_independent_recv_thread'] = os.getenv( - "FLAGS_communicator_independent_recv_thread", "0") - self._trainer_runtime_config.runtime_configs[ - 'communicator_min_send_grad_num_before_recv'] = os.getenv( - "FLAGS_communicator_min_send_grad_num_before_recv", num_threads) - self._trainer_runtime_config.runtime_configs[ - 'communicator_thread_pool_size'] = os.getenv( - "FLAGS_communicator_thread_pool_size", "10") - self._trainer_runtime_config.runtime_configs[ - 'communicator_send_wait_times'] = os.getenv( - "FLAGS_communicator_send_wait_times", "5") - self._trainer_runtime_config.runtime_configs[ - 'communicator_is_sgd_optimizer'] = os.getenv( - "FLAGS_communicator_is_sgd_optimizer", "1") - self._trainer_runtime_config.runtime_configs[ - 'communicator_send_queue_size'] = os.getenv( - "FLAGS_communicator_send_queue_size", num_threads) + def check_execute_strategy(self): + pass + + def check_build_strategy(self): + self._build_strategy.async_mode = True class HalfAsyncStrategy(DistributedStrategy): def __init__(self): super(HalfAsyncStrategy, self).__init__() + self.check_program_config() + self.check_trainer_runtime_config() + self.check_server_runtime_config() + self.check_build_strategy() + self.check_execute_strategy() + + def check_trainer_runtime_config(self): + self._trainer_runtime_config.mode = DistributedMode.HALF_ASYNC + + def check_program_config(self): self._program_config.sync_mode = False self._program_config.runtime_split_send_recv = True self._program_config.half_async = True - self._build_strategy.async_mode = True - self._execute_strategy.use_thread_barrier = True - num_threads = os.getenv("CPU_NUM", "1") + def check_server_runtime_config(self): + pass + + def check_execute_strategy(self): + self._execute_strategy.use_thread_barrier = True - self._trainer_runtime_config.runtime_configs[ - 'communicator_max_merge_var_num'] = os.getenv( - "FLAGS_communicator_max_merge_var_num", num_threads) - self._trainer_runtime_config.runtime_configs[ - 'communicator_send_wait_times'] = os.getenv( - "FLAGS_communicator_send_wait_times", "5") - self._trainer_runtime_config.runtime_configs[ - 'communicator_thread_pool_size'] = os.getenv( - "FLAGS_communicator_thread_pool_size", "10") - self._trainer_runtime_config.runtime_configs[ - 'communicator_send_queue_size'] = os.getenv( - "FLAGS_communicator_send_queue_size", num_threads) + def check_build_strategy(self): + self._build_strategy.async_mode = True class GeoStrategy(DistributedStrategy): def __init__(self, update_frequency=100): super(GeoStrategy, self).__init__() + self._program_config.geo_sgd_need_push_nums = update_frequency + self.check_program_config() + self.check_trainer_runtime_config() + self.check_server_runtime_config() + self.check_build_strategy() + self.check_execute_strategy() + + def check_program_config(self): self._program_config.sync_mode = False self._program_config.runtime_split_send_recv = True self._program_config.geo_sgd_mode = True - self._program_config.geo_sgd_need_push_nums = update_frequency - self._build_strategy.async_mode = True - self._trainer_runtime_config.runtime_configs[ - 'communicator_thread_pool_size'] = os.getenv( - "FLAGS_communicator_thread_pool_size", "10") - self._trainer_runtime_config.runtime_configs[ - 'communicator_send_wait_times'] = os.getenv( - "FLAGS_communicator_send_wait_times", "5") + def check_trainer_runtime_config(self): + self._trainer_runtime_config.mode = DistributedMode.GEO + + def check_server_runtime_config(self): + pass + + def check_execute_strategy(self): + pass + + def check_build_strategy(self): + self._build_strategy.async_mode = True class StrategyFactory(object): diff --git a/python/paddle/fluid/tests/unittests/test_distributed_strategy.py b/python/paddle/fluid/tests/unittests/test_distributed_strategy.py index 25940022b8b35d875b24fe4486e46e629433c46e..8dbe2f398f210b43454ae6a984650bd9f7c5dc43 100644 --- a/python/paddle/fluid/tests/unittests/test_distributed_strategy.py +++ b/python/paddle/fluid/tests/unittests/test_distributed_strategy.py @@ -52,6 +52,15 @@ class TestStrategyFactor(unittest.TestCase): self.assertRaises(Exception, strategy.set_program_config, program_config_illegal) + trainer_runtime_config = strategy.get_trainer_runtime_config() + trainer_runtime_config.runtime_configs[ + 'communicator_send_queue_size'] = '50' + runtime_configs = trainer_runtime_config.get_communicator_flags() + self.assertIn('communicator_send_queue_size', runtime_configs) + self.assertNotIn('communicator_independent_recv_thread', + runtime_configs) + self.assertEqual(runtime_configs['communicator_send_queue_size'], '2') + def test_geo_strategy(self): strategy = StrategyFactory.create_geo_strategy(5) self.assertEqual(strategy._program_config.sync_mode, False) @@ -82,6 +91,14 @@ class TestStrategyFactor(unittest.TestCase): self.assertRaises(Exception, strategy.set_build_strategy, build_strategy_illegal) + os.environ["CPU_NUM"] = '100' + trainer_runtime_config = strategy.get_trainer_runtime_config() + runtime_configs = trainer_runtime_config.get_communicator_flags() + self.assertIn('communicator_thread_pool_size', runtime_configs) + self.assertIn('communicator_send_wait_times', runtime_configs) + self.assertNotIn('communicator_independent_recv_thread', + runtime_configs) + def test_async_strategy(self): os.environ["CPU_NUM"] = '100' @@ -164,6 +181,16 @@ class TestStrategyFactor(unittest.TestCase): self.assertRaises(Exception, strategy.set_server_runtime_config, server_runtime_config_illegal) + os.environ["CPU_NUM"] = '100' + trainer_runtime_config = strategy.get_trainer_runtime_config() + trainer_runtime_config.runtime_configs[ + 'communicator_send_queue_size'] = '50' + runtime_configs = trainer_runtime_config.get_communicator_flags() + self.assertIn('communicator_send_queue_size', runtime_configs) + self.assertNotIn('communicator_independent_recv_thread', + runtime_configs) + self.assertEqual(runtime_configs['communicator_send_queue_size'], '100') + class TestCreateDefaultStrategy(unittest.TestCase): def test_default_strategy(self):