提交 9d313db0 编写于 作者: T tangwei12 提交者: tangwei

add texttable for pretty flag output (#22584)

pretty print for communicator flag
上级 926227c8
......@@ -24,51 +24,35 @@ from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerCo
class TrainerRuntimeConfig(object):
def __init__(self):
self.max_merge_var_num = os.getenv(
"FLAGS_communicator_max_merge_var_num", "20")
self.send_queue_size = os.getenv("FLAGS_communicator_send_queue_size",
"20")
self.independent_recv_thread = os.getenv(
"FLAGS_communicator_independent_recv_thread", "1")
self.min_send_grad_num_before_recv = os.getenv(
"FLAGS_communicator_min_send_grad_num_before_recv", "20")
self.thread_pool_size = os.getenv("FLAGS_communicator_thread_pool_size",
"5")
self.send_wait_times = os.getenv("FLAGS_communicator_send_wait_times",
"5")
self.fake_rpc = os.getenv("FLAGS_communicator_fake_rpc", "0")
self.merge_sparse_grad = os.getenv(
"FLAGS_communicator_merge_sparse_grad", "1")
self.is_sgd_optimizer = os.getenv("FLAGS_communicator_is_sgd_optimizer",
"1")
self.runtime_configs = {}
# not used
self._rpc_deadline = os.getenv("FLAGS_rpc_deadline", "180000")
self._rpc_retry_times = os.getenv("FLAGS_rpc_retry_times", "3")
self.runtime_configs['rpc_deadline'] = os.getenv("FLAGS_rpc_deadline",
"180000")
self.runtime_configs['rpc_retry_times'] = os.getenv(
"FLAGS_rpc_retry_times", "3")
def get_communicator_flags(self):
_communicator_flags = dict()
_communicator_flags["communicator_max_merge_var_num"] = str(
self.max_merge_var_num)
_communicator_flags["communicator_send_queue_size"] = str(
self.send_queue_size)
_communicator_flags["communicator_independent_recv_thread"] = str(
self.independent_recv_thread)
_communicator_flags["communicator_min_send_grad_num_before_recv"] = str(
self.min_send_grad_num_before_recv)
_communicator_flags["communicator_thread_pool_size"] = str(
self.thread_pool_size)
_communicator_flags["communicator_send_wait_times"] = str(
self.send_wait_times)
_communicator_flags["communicator_is_sgd_optimizer"] = str(
self.is_sgd_optimizer)
return _communicator_flags
return self.runtime_configs
def __repr__(self):
_str = "please check that TrainerRuntimeConfig is as expected:\n"
_communicator_flags = self.get_communicator_flags()
for key in _communicator_flags:
_str += "{}: {}\n".format(key, _communicator_flags[key])
raw0, raw1, length = 45, 5, 50
h_format = "{:^45s}{:<5s}\n"
l_format = "{:<45s}{:<5s}\n"
border = "".join(["="] * length)
line = "".join(["-"] * length)
draws = ""
draws += border + "\n"
draws += h_format.format("TrainerRuntimeConfig Overview", "Value")
draws += line + "\n"
for k, v in self.get_communicator_flags().items():
draws += l_format.format(k, v)
draws += border
_str = "\n{}\n".format(draws)
return _str
......@@ -77,9 +61,11 @@ class DistributedStrategy(object):
self._program_config = DistributeTranspilerConfig()
self._trainer_runtime_config = TrainerRuntimeConfig()
self._server_runtime_config = ServerRuntimeConfig()
num_threads = int(os.getenv("CPU_NUM", "1"))
self._execute_strategy = fluid.ExecutionStrategy()
self._build_strategy = fluid.BuildStrategy()
num_threads = int(os.getenv("CPU_NUM", "1"))
self._execute_strategy.num_threads = num_threads
if num_threads > 1:
self._build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
......@@ -110,9 +96,9 @@ class DistributedStrategy(object):
if isinstance(config, TrainerRuntimeConfig):
self._trainer_runtime_config = config
elif isinstance(config, dict):
for key in config:
if hasattr(self._trainer_runtime_config, key):
setattr(self._trainer_runtime_config, key, config[key])
for key, Value in config.items():
if key in self._trainer_runtime_config.runtime_configs:
self._trainer_runtime_config.runtime_configs[key] = Value
else:
raise ValueError(
"TrainerRuntimeConfig doesn't have key: {}".format(key))
......@@ -182,6 +168,21 @@ class SyncStrategy(DistributedStrategy):
self._program_config.runtime_split_send_recv = False
self._build_strategy.async_mode = False
num_threads = os.getenv("CPU_NUM", "1")
self._trainer_runtime_config.runtime_configs[
'communicator_max_merge_var_num'] = os.getenv(
"FLAGS_communicator_max_merge_var_num", num_threads)
self._trainer_runtime_config.runtime_configs[
'communicator_send_wait_times'] = os.getenv(
"FLAGS_communicator_send_wait_times", "5")
self._trainer_runtime_config.runtime_configs[
'communicator_thread_pool_size'] = os.getenv(
"FLAGS_communicator_thread_pool_size", "10")
self._trainer_runtime_config.runtime_configs[
'communicator_send_queue_size'] = os.getenv(
"FLAGS_communicator_send_queue_size", num_threads)
class AsyncStrategy(DistributedStrategy):
def __init__(self):
......@@ -190,6 +191,30 @@ class AsyncStrategy(DistributedStrategy):
self._program_config.runtime_split_send_recv = True
self._build_strategy.async_mode = True
num_threads = os.getenv("CPU_NUM", "1")
self._trainer_runtime_config.runtime_configs[
'communicator_max_merge_var_num'] = os.getenv(
"FLAGS_communicator_max_merge_var_num", num_threads)
self._trainer_runtime_config.runtime_configs[
'communicator_independent_recv_thread'] = os.getenv(
"FLAGS_communicator_independent_recv_thread", "0")
self._trainer_runtime_config.runtime_configs[
'communicator_min_send_grad_num_before_recv'] = os.getenv(
"FLAGS_communicator_min_send_grad_num_before_recv", num_threads)
self._trainer_runtime_config.runtime_configs[
'communicator_thread_pool_size'] = os.getenv(
"FLAGS_communicator_thread_pool_size", "10")
self._trainer_runtime_config.runtime_configs[
'communicator_send_wait_times'] = os.getenv(
"FLAGS_communicator_send_wait_times", "5")
self._trainer_runtime_config.runtime_configs[
'communicator_is_sgd_optimizer'] = os.getenv(
"FLAGS_communicator_is_sgd_optimizer", "1")
self._trainer_runtime_config.runtime_configs[
'communicator_send_queue_size'] = os.getenv(
"FLAGS_communicator_send_queue_size", num_threads)
class HalfAsyncStrategy(DistributedStrategy):
def __init__(self):
......@@ -200,15 +225,37 @@ class HalfAsyncStrategy(DistributedStrategy):
self._build_strategy.async_mode = True
self._execute_strategy.use_thread_barrier = True
num_threads = os.getenv("CPU_NUM", "1")
self._trainer_runtime_config.runtime_configs[
'communicator_max_merge_var_num'] = os.getenv(
"FLAGS_communicator_max_merge_var_num", num_threads)
self._trainer_runtime_config.runtime_configs[
'communicator_send_wait_times'] = os.getenv(
"FLAGS_communicator_send_wait_times", "5")
self._trainer_runtime_config.runtime_configs[
'communicator_thread_pool_size'] = os.getenv(
"FLAGS_communicator_thread_pool_size", "10")
self._trainer_runtime_config.runtime_configs[
'communicator_send_queue_size'] = os.getenv(
"FLAGS_communicator_send_queue_size", num_threads)
class GeoStrategy(DistributedStrategy):
def __init__(self, update_frequency=100):
super(GeoStrategy, self).__init__()
self._program_config.sync_mode = False
self._program_config.runtime_split_send_recv = True
self._build_strategy.async_mode = True
self._program_config.geo_sgd_mode = True
self._program_config.geo_sgd_need_push_nums = update_frequency
self._build_strategy.async_mode = True
self._trainer_runtime_config.runtime_configs[
'communicator_thread_pool_size'] = os.getenv(
"FLAGS_communicator_thread_pool_size", "10")
self._trainer_runtime_config.runtime_configs[
'communicator_send_wait_times'] = os.getenv(
"FLAGS_communicator_send_wait_times", "5")
class StrategyFactory(object):
......
......@@ -84,22 +84,20 @@ class TestStrategyFactor(unittest.TestCase):
build_strategy_illegal)
def test_async_strategy(self):
os.environ["CPU_NUM"] = '100'
strategy = StrategyFactory.create_async_strategy()
self.assertEqual(strategy._program_config.sync_mode, False)
self.assertEqual(strategy._program_config.runtime_split_send_recv, True)
self.assertEqual(strategy._build_strategy.async_mode, True)
# test set_trainer_runtime_config using TrainerRuntimeConfig
trainer_runtime_config_class = TrainerRuntimeConfig()
trainer_runtime_config_class.send_queue_size = 50
print(trainer_runtime_config_class)
strategy.set_trainer_runtime_config(trainer_runtime_config_class)
trainer_runtime_config = strategy.get_trainer_runtime_config()
self.assertEqual(trainer_runtime_config.send_queue_size, 50)
self.assertEqual(trainer_runtime_config.runtime_configs[
'communicator_send_queue_size'], '100')
# test set_trainer_runtime_config using dict
trainer_runtime_config_dict = dict()
trainer_runtime_config_dict['send_queue_size'] = 100
trainer_runtime_config_dict['communicator_send_queue_size'] = '20'
strategy.set_trainer_runtime_config(trainer_runtime_config_dict)
trainer_runtime_config = strategy.get_trainer_runtime_config()
trainer_communicator_flags = trainer_runtime_config.get_communicator_flags(
......@@ -107,7 +105,7 @@ class TestStrategyFactor(unittest.TestCase):
self.assertIn('communicator_send_queue_size',
trainer_communicator_flags)
self.assertEqual(
trainer_communicator_flags['communicator_send_queue_size'], '100')
trainer_communicator_flags['communicator_send_queue_size'], '20')
# test set_trainer_runtime_config exception
trainer_runtime_config_dict['unknown'] = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册