Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c000f8a2
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c000f8a2
编写于
2月 17, 2020
作者:
T
tangwei12
提交者:
GitHub
2月 17, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add texttable for pretty flag output (#22584) (#22626)
pretty print for communicator flag
上级
f517fb60
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
98 addition
and
53 deletion
+98
-53
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py
...eter_server/distribute_transpiler/distributed_strategy.py
+92
-45
python/paddle/fluid/tests/unittests/test_distributed_strategy.py
...paddle/fluid/tests/unittests/test_distributed_strategy.py
+6
-8
未找到文件。
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py
浏览文件 @
c000f8a2
...
...
@@ -24,51 +24,35 @@ from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerCo
class
TrainerRuntimeConfig
(
object
):
def
__init__
(
self
):
self
.
max_merge_var_num
=
os
.
getenv
(
"FLAGS_communicator_max_merge_var_num"
,
"20"
)
self
.
send_queue_size
=
os
.
getenv
(
"FLAGS_communicator_send_queue_size"
,
"20"
)
self
.
independent_recv_thread
=
os
.
getenv
(
"FLAGS_communicator_independent_recv_thread"
,
"1"
)
self
.
min_send_grad_num_before_recv
=
os
.
getenv
(
"FLAGS_communicator_min_send_grad_num_before_recv"
,
"20"
)
self
.
thread_pool_size
=
os
.
getenv
(
"FLAGS_communicator_thread_pool_size"
,
"5"
)
self
.
send_wait_times
=
os
.
getenv
(
"FLAGS_communicator_send_wait_times"
,
"5"
)
self
.
fake_rpc
=
os
.
getenv
(
"FLAGS_communicator_fake_rpc"
,
"0"
)
self
.
merge_sparse_grad
=
os
.
getenv
(
"FLAGS_communicator_merge_sparse_grad"
,
"1"
)
self
.
is_sgd_optimizer
=
os
.
getenv
(
"FLAGS_communicator_is_sgd_optimizer"
,
"1"
)
self
.
runtime_configs
=
{}
# not used
self
.
_rpc_deadline
=
os
.
getenv
(
"FLAGS_rpc_deadline"
,
"180000"
)
self
.
_rpc_retry_times
=
os
.
getenv
(
"FLAGS_rpc_retry_times"
,
"3"
)
self
.
runtime_configs
[
'rpc_deadline'
]
=
os
.
getenv
(
"FLAGS_rpc_deadline"
,
"180000"
)
self
.
runtime_configs
[
'rpc_retry_times'
]
=
os
.
getenv
(
"FLAGS_rpc_retry_times"
,
"3"
)
def
get_communicator_flags
(
self
):
_communicator_flags
=
dict
()
_communicator_flags
[
"communicator_max_merge_var_num"
]
=
str
(
self
.
max_merge_var_num
)
_communicator_flags
[
"communicator_send_queue_size"
]
=
str
(
self
.
send_queue_size
)
_communicator_flags
[
"communicator_independent_recv_thread"
]
=
str
(
self
.
independent_recv_thread
)
_communicator_flags
[
"communicator_min_send_grad_num_before_recv"
]
=
str
(
self
.
min_send_grad_num_before_recv
)
_communicator_flags
[
"communicator_thread_pool_size"
]
=
str
(
self
.
thread_pool_size
)
_communicator_flags
[
"communicator_send_wait_times"
]
=
str
(
self
.
send_wait_times
)
_communicator_flags
[
"communicator_is_sgd_optimizer"
]
=
str
(
self
.
is_sgd_optimizer
)
return
_communicator_flags
return
self
.
runtime_configs
def
__repr__
(
self
):
_str
=
"please check that TrainerRuntimeConfig is as expected:
\n
"
_communicator_flags
=
self
.
get_communicator_flags
()
for
key
in
_communicator_flags
:
_str
+=
"{}: {}
\n
"
.
format
(
key
,
_communicator_flags
[
key
])
raw0
,
raw1
,
length
=
45
,
5
,
50
h_format
=
"{:^45s}{:<5s}
\n
"
l_format
=
"{:<45s}{:<5s}
\n
"
border
=
""
.
join
([
"="
]
*
length
)
line
=
""
.
join
([
"-"
]
*
length
)
draws
=
""
draws
+=
border
+
"
\n
"
draws
+=
h_format
.
format
(
"TrainerRuntimeConfig Overview"
,
"Value"
)
draws
+=
line
+
"
\n
"
for
k
,
v
in
self
.
get_communicator_flags
().
items
():
draws
+=
l_format
.
format
(
k
,
v
)
draws
+=
border
_str
=
"
\n
{}
\n
"
.
format
(
draws
)
return
_str
...
...
@@ -77,9 +61,11 @@ class DistributedStrategy(object):
self
.
_program_config
=
DistributeTranspilerConfig
()
self
.
_trainer_runtime_config
=
TrainerRuntimeConfig
()
self
.
_server_runtime_config
=
ServerRuntimeConfig
()
num_threads
=
int
(
os
.
getenv
(
"CPU_NUM"
,
"1"
))
self
.
_execute_strategy
=
fluid
.
ExecutionStrategy
()
self
.
_build_strategy
=
fluid
.
BuildStrategy
()
num_threads
=
int
(
os
.
getenv
(
"CPU_NUM"
,
"1"
))
self
.
_execute_strategy
.
num_threads
=
num_threads
if
num_threads
>
1
:
self
.
_build_strategy
.
reduce_strategy
=
fluid
.
BuildStrategy
.
ReduceStrategy
.
Reduce
...
...
@@ -110,9 +96,9 @@ class DistributedStrategy(object):
if
isinstance
(
config
,
TrainerRuntimeConfig
):
self
.
_trainer_runtime_config
=
config
elif
isinstance
(
config
,
dict
):
for
key
in
config
:
if
hasattr
(
self
.
_trainer_runtime_config
,
key
)
:
se
tattr
(
self
.
_trainer_runtime_config
,
key
,
config
[
key
])
for
key
,
Value
in
config
.
items
()
:
if
key
in
self
.
_trainer_runtime_config
.
runtime_configs
:
se
lf
.
_trainer_runtime_config
.
runtime_configs
[
key
]
=
Value
else
:
raise
ValueError
(
"TrainerRuntimeConfig doesn't have key: {}"
.
format
(
key
))
...
...
@@ -182,6 +168,21 @@ class SyncStrategy(DistributedStrategy):
self
.
_program_config
.
runtime_split_send_recv
=
False
self
.
_build_strategy
.
async_mode
=
False
num_threads
=
os
.
getenv
(
"CPU_NUM"
,
"1"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_max_merge_var_num'
]
=
os
.
getenv
(
"FLAGS_communicator_max_merge_var_num"
,
num_threads
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_send_wait_times'
]
=
os
.
getenv
(
"FLAGS_communicator_send_wait_times"
,
"5"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_thread_pool_size'
]
=
os
.
getenv
(
"FLAGS_communicator_thread_pool_size"
,
"10"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_send_queue_size'
]
=
os
.
getenv
(
"FLAGS_communicator_send_queue_size"
,
num_threads
)
class
AsyncStrategy
(
DistributedStrategy
):
def
__init__
(
self
):
...
...
@@ -190,6 +191,30 @@ class AsyncStrategy(DistributedStrategy):
self
.
_program_config
.
runtime_split_send_recv
=
True
self
.
_build_strategy
.
async_mode
=
True
num_threads
=
os
.
getenv
(
"CPU_NUM"
,
"1"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_max_merge_var_num'
]
=
os
.
getenv
(
"FLAGS_communicator_max_merge_var_num"
,
num_threads
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_independent_recv_thread'
]
=
os
.
getenv
(
"FLAGS_communicator_independent_recv_thread"
,
"0"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_min_send_grad_num_before_recv'
]
=
os
.
getenv
(
"FLAGS_communicator_min_send_grad_num_before_recv"
,
num_threads
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_thread_pool_size'
]
=
os
.
getenv
(
"FLAGS_communicator_thread_pool_size"
,
"10"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_send_wait_times'
]
=
os
.
getenv
(
"FLAGS_communicator_send_wait_times"
,
"5"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_is_sgd_optimizer'
]
=
os
.
getenv
(
"FLAGS_communicator_is_sgd_optimizer"
,
"1"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_send_queue_size'
]
=
os
.
getenv
(
"FLAGS_communicator_send_queue_size"
,
num_threads
)
class
HalfAsyncStrategy
(
DistributedStrategy
):
def
__init__
(
self
):
...
...
@@ -200,15 +225,37 @@ class HalfAsyncStrategy(DistributedStrategy):
self
.
_build_strategy
.
async_mode
=
True
self
.
_execute_strategy
.
use_thread_barrier
=
True
num_threads
=
os
.
getenv
(
"CPU_NUM"
,
"1"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_max_merge_var_num'
]
=
os
.
getenv
(
"FLAGS_communicator_max_merge_var_num"
,
num_threads
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_send_wait_times'
]
=
os
.
getenv
(
"FLAGS_communicator_send_wait_times"
,
"5"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_thread_pool_size'
]
=
os
.
getenv
(
"FLAGS_communicator_thread_pool_size"
,
"10"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_send_queue_size'
]
=
os
.
getenv
(
"FLAGS_communicator_send_queue_size"
,
num_threads
)
class
GeoStrategy
(
DistributedStrategy
):
def
__init__
(
self
,
update_frequency
=
100
):
super
(
GeoStrategy
,
self
).
__init__
()
self
.
_program_config
.
sync_mode
=
False
self
.
_program_config
.
runtime_split_send_recv
=
True
self
.
_build_strategy
.
async_mode
=
True
self
.
_program_config
.
geo_sgd_mode
=
True
self
.
_program_config
.
geo_sgd_need_push_nums
=
update_frequency
self
.
_build_strategy
.
async_mode
=
True
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_thread_pool_size'
]
=
os
.
getenv
(
"FLAGS_communicator_thread_pool_size"
,
"10"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_send_wait_times'
]
=
os
.
getenv
(
"FLAGS_communicator_send_wait_times"
,
"5"
)
class
StrategyFactory
(
object
):
...
...
python/paddle/fluid/tests/unittests/test_distributed_strategy.py
浏览文件 @
c000f8a2
...
...
@@ -84,22 +84,20 @@ class TestStrategyFactor(unittest.TestCase):
build_strategy_illegal
)
def
test_async_strategy
(
self
):
os
.
environ
[
"CPU_NUM"
]
=
'100'
strategy
=
StrategyFactory
.
create_async_strategy
()
self
.
assertEqual
(
strategy
.
_program_config
.
sync_mode
,
False
)
self
.
assertEqual
(
strategy
.
_program_config
.
runtime_split_send_recv
,
True
)
self
.
assertEqual
(
strategy
.
_build_strategy
.
async_mode
,
True
)
# test set_trainer_runtime_config using TrainerRuntimeConfig
trainer_runtime_config_class
=
TrainerRuntimeConfig
()
trainer_runtime_config_class
.
send_queue_size
=
50
print
(
trainer_runtime_config_class
)
strategy
.
set_trainer_runtime_config
(
trainer_runtime_config_class
)
trainer_runtime_config
=
strategy
.
get_trainer_runtime_config
()
self
.
assertEqual
(
trainer_runtime_config
.
send_queue_size
,
50
)
self
.
assertEqual
(
trainer_runtime_config
.
runtime_configs
[
'communicator_send_queue_size'
],
'100'
)
# test set_trainer_runtime_config using dict
trainer_runtime_config_dict
=
dict
()
trainer_runtime_config_dict
[
'
send_queue_size'
]
=
100
trainer_runtime_config_dict
[
'
communicator_send_queue_size'
]
=
'20'
strategy
.
set_trainer_runtime_config
(
trainer_runtime_config_dict
)
trainer_runtime_config
=
strategy
.
get_trainer_runtime_config
()
trainer_communicator_flags
=
trainer_runtime_config
.
get_communicator_flags
(
...
...
@@ -107,7 +105,7 @@ class TestStrategyFactor(unittest.TestCase):
self
.
assertIn
(
'communicator_send_queue_size'
,
trainer_communicator_flags
)
self
.
assertEqual
(
trainer_communicator_flags
[
'communicator_send_queue_size'
],
'
10
0'
)
trainer_communicator_flags
[
'communicator_send_queue_size'
],
'
2
0'
)
# test set_trainer_runtime_config exception
trainer_runtime_config_dict
[
'unknown'
]
=
None
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录