未验证 提交 0f9d4081 编写于 作者: 1 123malin 提交者: GitHub

test=develop, optimize distributedstrategy (#22677)

* test=develop, optimize distributedstrategy
上级 5ee29c67
...@@ -19,12 +19,32 @@ __all__ = [ ...@@ -19,12 +19,32 @@ __all__ = [
import os import os
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig, DistributedMode
class TrainerRuntimeConfig(object): class TrainerRuntimeConfig(object):
def __init__(self): def __init__(self):
self.mode = None
num_threads = os.getenv("CPU_NUM", "1")
self.runtime_configs = {} self.runtime_configs = {}
self.runtime_configs['communicator_max_merge_var_num'] = os.getenv(
"FLAGS_communicator_max_merge_var_num", num_threads)
self.runtime_configs['communicator_send_queue_size'] = os.getenv(
"FLAGS_communicator_send_queue_size", num_threads)
self.runtime_configs[
'communicator_independent_recv_thread'] = os.getenv(
"FLAGS_communicator_independent_recv_thread", "1")
self.runtime_configs[
'communicator_min_send_grad_num_before_recv'] = os.getenv(
"FLAGS_communicator_min_send_grad_num_before_recv", num_threads)
self.runtime_configs['communicator_thread_pool_size'] = os.getenv(
"FLAGS_communicator_thread_pool_size", "5")
self.runtime_configs['communicator_send_wait_times'] = os.getenv(
"FLAGS_communicator_send_wait_times", "5")
self.runtime_configs['communicator_is_sgd_optimizer'] = os.getenv(
"FLAGS_communicator_is_sgd_optimizer", "1")
# not used # not used
self.runtime_configs['rpc_deadline'] = os.getenv("FLAGS_rpc_deadline", self.runtime_configs['rpc_deadline'] = os.getenv("FLAGS_rpc_deadline",
"180000") "180000")
...@@ -32,9 +52,54 @@ class TrainerRuntimeConfig(object): ...@@ -32,9 +52,54 @@ class TrainerRuntimeConfig(object):
"FLAGS_rpc_retry_times", "3") "FLAGS_rpc_retry_times", "3")
def get_communicator_flags(self): def get_communicator_flags(self):
return self.runtime_configs need_keys = []
num_threads = os.getenv("CPU_NUM", "1")
def __repr__(self): mode_str = ""
if self.mode is None or self.mode == DistributedMode.ASYNC:
need_keys = self.runtime_configs.keys()
mode_str = "async"
elif self.mode == DistributedMode.SYNC or self.mode == DistributedMode.HALF_ASYNC:
mode_str = "sync or half_async"
need_keys = [
'communicator_max_merge_var_num',
'communicator_send_wait_times', 'communicator_thread_pool_size',
'communicator_send_queue_size'
]
elif self.mode == DistributedMode.GEO:
mode_str = "GEO"
need_keys = [
'communicator_thread_pool_size', 'communicator_send_wait_times'
]
else:
raise ValueError("Unsupported Mode")
if self.mode == DistributedMode.SYNC or self.mode == DistributedMode.HALF_ASYNC:
max_merge_var_num = self.runtime_configs[
'communicator_max_merge_var_num']
send_queue_size = self.runtime_configs[
'communicator_send_queue_size']
if max_merge_var_num != num_threads:
print('WARNING: In {} mode, communicator_max_merge_var_num '
'must be equal to CPU_NUM. But received, '
'communicator_max_merge_var_num = {}, CPU_NUM = '
'{}. communicator_max_merge_var_num will be fored to {}.'
.format(mode_str, max_merge_var_num, num_threads,
num_threads))
self.runtime_configs[
'communicator_max_merge_var_num'] = num_threads
if send_queue_size != num_threads:
print('WARNING: In {} mode, communicator_send_queue_size '
'must be equal to CPU_NUM. But received, '
'communicator_send_queue_size = {}, CPU_NUM = '
'{}. communicator_send_queue_size will be fored to {}.'
.format(mode_str, send_queue_size, num_threads,
num_threads))
self.runtime_configs[
'communicator_send_queue_size'] = num_threads
return dict((key, str(self.runtime_configs[key])) for key in need_keys)
def display(self, configs):
raw0, raw1, length = 45, 5, 50 raw0, raw1, length = 45, 5, 50
h_format = "{:^45s}{:<5s}\n" h_format = "{:^45s}{:<5s}\n"
l_format = "{:<45s}{:<5s}\n" l_format = "{:<45s}{:<5s}\n"
...@@ -47,7 +112,7 @@ class TrainerRuntimeConfig(object): ...@@ -47,7 +112,7 @@ class TrainerRuntimeConfig(object):
draws += h_format.format("TrainerRuntimeConfig Overview", "Value") draws += h_format.format("TrainerRuntimeConfig Overview", "Value")
draws += line + "\n" draws += line + "\n"
for k, v in self.get_communicator_flags().items(): for k, v in configs.items():
draws += l_format.format(k, v) draws += l_format.format(k, v)
draws += border draws += border
...@@ -55,6 +120,9 @@ class TrainerRuntimeConfig(object): ...@@ -55,6 +120,9 @@ class TrainerRuntimeConfig(object):
_str = "\n{}\n".format(draws) _str = "\n{}\n".format(draws)
return _str return _str
def __repr__(self):
return self.display(self.get_communicator_flags())
class DistributedStrategy(object): class DistributedStrategy(object):
def __init__(self): def __init__(self):
...@@ -105,6 +173,12 @@ class DistributedStrategy(object): ...@@ -105,6 +173,12 @@ class DistributedStrategy(object):
raise TypeError( raise TypeError(
"program_config only accept input type: dict or DistributeTranspilerConfig" "program_config only accept input type: dict or DistributeTranspilerConfig"
) )
self.check_program_config()
def check_program_config(self):
raise NotImplementedError(
"check_program_config must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
)
def get_trainer_runtime_config(self): def get_trainer_runtime_config(self):
return self._trainer_runtime_config return self._trainer_runtime_config
...@@ -123,6 +197,12 @@ class DistributedStrategy(object): ...@@ -123,6 +197,12 @@ class DistributedStrategy(object):
raise TypeError( raise TypeError(
"trainer_runtime_config only accept input type: dict or TrainerRuntimeConfig" "trainer_runtime_config only accept input type: dict or TrainerRuntimeConfig"
) )
self.check_trainer_runtime_config()
def check_trainer_runtime_config(self):
raise NotImplementedError(
"check_trainer_runtime_config must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
)
def get_server_runtime_config(self): def get_server_runtime_config(self):
return self._server_runtime_config return self._server_runtime_config
...@@ -141,6 +221,12 @@ class DistributedStrategy(object): ...@@ -141,6 +221,12 @@ class DistributedStrategy(object):
raise TypeError( raise TypeError(
"server_runtime_config only accept input type: dict or ServerRuntimeConfig" "server_runtime_config only accept input type: dict or ServerRuntimeConfig"
) )
self.check_server_runtime_config()
def check_server_runtime_config(self):
raise NotImplementedError(
"check_server_runtime_config must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
)
def get_execute_strategy(self): def get_execute_strategy(self):
return self._execute_strategy return self._execute_strategy
...@@ -159,6 +245,12 @@ class DistributedStrategy(object): ...@@ -159,6 +245,12 @@ class DistributedStrategy(object):
raise TypeError( raise TypeError(
"execute_strategy only accept input type: dict or ExecutionStrategy" "execute_strategy only accept input type: dict or ExecutionStrategy"
) )
self.check_execute_strategy()
def check_execute_strategy(self):
raise NotImplementedError(
"check_execute_strategy must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
)
def get_build_strategy(self): def get_build_strategy(self):
return self._build_strategy return self._build_strategy
...@@ -176,106 +268,121 @@ class DistributedStrategy(object): ...@@ -176,106 +268,121 @@ class DistributedStrategy(object):
else: else:
raise TypeError( raise TypeError(
"build_strategy only accept input type: dict or BuildStrategy") "build_strategy only accept input type: dict or BuildStrategy")
self.check_build_strategy()
def check_build_strategy(self):
raise NotImplementedError(
"check_build_strategy must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
)
class SyncStrategy(DistributedStrategy): class SyncStrategy(DistributedStrategy):
def __init__(self): def __init__(self):
super(SyncStrategy, self).__init__() super(SyncStrategy, self).__init__()
self.check_program_config()
self.check_trainer_runtime_config()
self.check_server_runtime_config()
self.check_build_strategy()
self.check_execute_strategy()
def check_trainer_runtime_config(self):
self._trainer_runtime_config.mode = DistributedMode.SYNC
def check_program_config(self):
self._program_config.sync_mode = False self._program_config.sync_mode = False
self._program_config.runtime_split_send_recv = True self._program_config.runtime_split_send_recv = True
self._build_strategy.async_mode = True
self._program_config.half_async = True self._program_config.half_async = True
self._program_config.completely_not_async = True self._program_config.completely_not_async = True
self._execute_strategy.use_thread_barrier = True
num_threads = os.getenv("CPU_NUM", "1") def check_server_runtime_config(self):
pass
self._trainer_runtime_config.runtime_configs[ def check_execute_strategy(self):
'communicator_max_merge_var_num'] = os.getenv( self._execute_strategy.use_thread_barrier = True
"FLAGS_communicator_max_merge_var_num", num_threads)
self._trainer_runtime_config.runtime_configs[ def check_build_strategy(self):
'communicator_send_wait_times'] = os.getenv( self._build_strategy.async_mode = True
"FLAGS_communicator_send_wait_times", "5")
self._trainer_runtime_config.runtime_configs[
'communicator_thread_pool_size'] = os.getenv(
"FLAGS_communicator_thread_pool_size", "10")
self._trainer_runtime_config.runtime_configs[
'communicator_send_queue_size'] = os.getenv(
"FLAGS_communicator_send_queue_size", num_threads)
class AsyncStrategy(DistributedStrategy): class AsyncStrategy(DistributedStrategy):
def __init__(self): def __init__(self):
super(AsyncStrategy, self).__init__() super(AsyncStrategy, self).__init__()
self.check_program_config()
self.check_trainer_runtime_config()
self.check_server_runtime_config()
self.check_build_strategy()
self.check_execute_strategy()
def check_trainer_runtime_config(self):
self._trainer_runtime_config.mode = DistributedMode.ASYNC
def check_program_config(self):
self._program_config.sync_mode = False self._program_config.sync_mode = False
self._program_config.runtime_split_send_recv = True self._program_config.runtime_split_send_recv = True
self._build_strategy.async_mode = True
num_threads = os.getenv("CPU_NUM", "1") def check_server_runtime_config(self):
pass
self._trainer_runtime_config.runtime_configs[ def check_execute_strategy(self):
'communicator_max_merge_var_num'] = os.getenv( pass
"FLAGS_communicator_max_merge_var_num", num_threads)
self._trainer_runtime_config.runtime_configs[ def check_build_strategy(self):
'communicator_independent_recv_thread'] = os.getenv( self._build_strategy.async_mode = True
"FLAGS_communicator_independent_recv_thread", "0")
self._trainer_runtime_config.runtime_configs[
'communicator_min_send_grad_num_before_recv'] = os.getenv(
"FLAGS_communicator_min_send_grad_num_before_recv", num_threads)
self._trainer_runtime_config.runtime_configs[
'communicator_thread_pool_size'] = os.getenv(
"FLAGS_communicator_thread_pool_size", "10")
self._trainer_runtime_config.runtime_configs[
'communicator_send_wait_times'] = os.getenv(
"FLAGS_communicator_send_wait_times", "5")
self._trainer_runtime_config.runtime_configs[
'communicator_is_sgd_optimizer'] = os.getenv(
"FLAGS_communicator_is_sgd_optimizer", "1")
self._trainer_runtime_config.runtime_configs[
'communicator_send_queue_size'] = os.getenv(
"FLAGS_communicator_send_queue_size", num_threads)
class HalfAsyncStrategy(DistributedStrategy): class HalfAsyncStrategy(DistributedStrategy):
def __init__(self): def __init__(self):
super(HalfAsyncStrategy, self).__init__() super(HalfAsyncStrategy, self).__init__()
self.check_program_config()
self.check_trainer_runtime_config()
self.check_server_runtime_config()
self.check_build_strategy()
self.check_execute_strategy()
def check_trainer_runtime_config(self):
self._trainer_runtime_config.mode = DistributedMode.HALF_ASYNC
def check_program_config(self):
self._program_config.sync_mode = False self._program_config.sync_mode = False
self._program_config.runtime_split_send_recv = True self._program_config.runtime_split_send_recv = True
self._program_config.half_async = True self._program_config.half_async = True
self._build_strategy.async_mode = True
self._execute_strategy.use_thread_barrier = True
num_threads = os.getenv("CPU_NUM", "1") def check_server_runtime_config(self):
pass
def check_execute_strategy(self):
self._execute_strategy.use_thread_barrier = True
self._trainer_runtime_config.runtime_configs[ def check_build_strategy(self):
'communicator_max_merge_var_num'] = os.getenv( self._build_strategy.async_mode = True
"FLAGS_communicator_max_merge_var_num", num_threads)
self._trainer_runtime_config.runtime_configs[
'communicator_send_wait_times'] = os.getenv(
"FLAGS_communicator_send_wait_times", "5")
self._trainer_runtime_config.runtime_configs[
'communicator_thread_pool_size'] = os.getenv(
"FLAGS_communicator_thread_pool_size", "10")
self._trainer_runtime_config.runtime_configs[
'communicator_send_queue_size'] = os.getenv(
"FLAGS_communicator_send_queue_size", num_threads)
class GeoStrategy(DistributedStrategy): class GeoStrategy(DistributedStrategy):
def __init__(self, update_frequency=100): def __init__(self, update_frequency=100):
super(GeoStrategy, self).__init__() super(GeoStrategy, self).__init__()
self._program_config.geo_sgd_need_push_nums = update_frequency
self.check_program_config()
self.check_trainer_runtime_config()
self.check_server_runtime_config()
self.check_build_strategy()
self.check_execute_strategy()
def check_program_config(self):
self._program_config.sync_mode = False self._program_config.sync_mode = False
self._program_config.runtime_split_send_recv = True self._program_config.runtime_split_send_recv = True
self._program_config.geo_sgd_mode = True self._program_config.geo_sgd_mode = True
self._program_config.geo_sgd_need_push_nums = update_frequency
self._build_strategy.async_mode = True
self._trainer_runtime_config.runtime_configs[ def check_trainer_runtime_config(self):
'communicator_thread_pool_size'] = os.getenv( self._trainer_runtime_config.mode = DistributedMode.GEO
"FLAGS_communicator_thread_pool_size", "10")
self._trainer_runtime_config.runtime_configs[ def check_server_runtime_config(self):
'communicator_send_wait_times'] = os.getenv( pass
"FLAGS_communicator_send_wait_times", "5")
def check_execute_strategy(self):
pass
def check_build_strategy(self):
self._build_strategy.async_mode = True
class StrategyFactory(object): class StrategyFactory(object):
......
...@@ -52,6 +52,15 @@ class TestStrategyFactor(unittest.TestCase): ...@@ -52,6 +52,15 @@ class TestStrategyFactor(unittest.TestCase):
self.assertRaises(Exception, strategy.set_program_config, self.assertRaises(Exception, strategy.set_program_config,
program_config_illegal) program_config_illegal)
trainer_runtime_config = strategy.get_trainer_runtime_config()
trainer_runtime_config.runtime_configs[
'communicator_send_queue_size'] = '50'
runtime_configs = trainer_runtime_config.get_communicator_flags()
self.assertIn('communicator_send_queue_size', runtime_configs)
self.assertNotIn('communicator_independent_recv_thread',
runtime_configs)
self.assertEqual(runtime_configs['communicator_send_queue_size'], '2')
def test_geo_strategy(self): def test_geo_strategy(self):
strategy = StrategyFactory.create_geo_strategy(5) strategy = StrategyFactory.create_geo_strategy(5)
self.assertEqual(strategy._program_config.sync_mode, False) self.assertEqual(strategy._program_config.sync_mode, False)
...@@ -82,6 +91,14 @@ class TestStrategyFactor(unittest.TestCase): ...@@ -82,6 +91,14 @@ class TestStrategyFactor(unittest.TestCase):
self.assertRaises(Exception, strategy.set_build_strategy, self.assertRaises(Exception, strategy.set_build_strategy,
build_strategy_illegal) build_strategy_illegal)
os.environ["CPU_NUM"] = '100'
trainer_runtime_config = strategy.get_trainer_runtime_config()
runtime_configs = trainer_runtime_config.get_communicator_flags()
self.assertIn('communicator_thread_pool_size', runtime_configs)
self.assertIn('communicator_send_wait_times', runtime_configs)
self.assertNotIn('communicator_independent_recv_thread',
runtime_configs)
def test_async_strategy(self): def test_async_strategy(self):
os.environ["CPU_NUM"] = '100' os.environ["CPU_NUM"] = '100'
...@@ -164,6 +181,16 @@ class TestStrategyFactor(unittest.TestCase): ...@@ -164,6 +181,16 @@ class TestStrategyFactor(unittest.TestCase):
self.assertRaises(Exception, strategy.set_server_runtime_config, self.assertRaises(Exception, strategy.set_server_runtime_config,
server_runtime_config_illegal) server_runtime_config_illegal)
os.environ["CPU_NUM"] = '100'
trainer_runtime_config = strategy.get_trainer_runtime_config()
trainer_runtime_config.runtime_configs[
'communicator_send_queue_size'] = '50'
runtime_configs = trainer_runtime_config.get_communicator_flags()
self.assertIn('communicator_send_queue_size', runtime_configs)
self.assertNotIn('communicator_independent_recv_thread',
runtime_configs)
self.assertEqual(runtime_configs['communicator_send_queue_size'], '100')
class TestCreateDefaultStrategy(unittest.TestCase): class TestCreateDefaultStrategy(unittest.TestCase):
def test_default_strategy(self): def test_default_strategy(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册