Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
0f9d4081
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0f9d4081
编写于
3月 03, 2020
作者:
1
123malin
提交者:
GitHub
3月 03, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
test=develop, optimize distributedstrategy (#22677)
* test=develop, optimize distributedstrategy
上级
5ee29c67
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
200 addition
and
66 deletion
+200
-66
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py
...eter_server/distribute_transpiler/distributed_strategy.py
+173
-66
python/paddle/fluid/tests/unittests/test_distributed_strategy.py
...paddle/fluid/tests/unittests/test_distributed_strategy.py
+27
-0
未找到文件。
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py
浏览文件 @
0f9d4081
...
...
@@ -19,12 +19,32 @@ __all__ = [
import
os
import
paddle.fluid
as
fluid
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributeTranspilerConfig
,
ServerRuntimeConfig
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributeTranspilerConfig
,
ServerRuntimeConfig
,
DistributedMode
class
TrainerRuntimeConfig
(
object
):
def
__init__
(
self
):
self
.
mode
=
None
num_threads
=
os
.
getenv
(
"CPU_NUM"
,
"1"
)
self
.
runtime_configs
=
{}
self
.
runtime_configs
[
'communicator_max_merge_var_num'
]
=
os
.
getenv
(
"FLAGS_communicator_max_merge_var_num"
,
num_threads
)
self
.
runtime_configs
[
'communicator_send_queue_size'
]
=
os
.
getenv
(
"FLAGS_communicator_send_queue_size"
,
num_threads
)
self
.
runtime_configs
[
'communicator_independent_recv_thread'
]
=
os
.
getenv
(
"FLAGS_communicator_independent_recv_thread"
,
"1"
)
self
.
runtime_configs
[
'communicator_min_send_grad_num_before_recv'
]
=
os
.
getenv
(
"FLAGS_communicator_min_send_grad_num_before_recv"
,
num_threads
)
self
.
runtime_configs
[
'communicator_thread_pool_size'
]
=
os
.
getenv
(
"FLAGS_communicator_thread_pool_size"
,
"5"
)
self
.
runtime_configs
[
'communicator_send_wait_times'
]
=
os
.
getenv
(
"FLAGS_communicator_send_wait_times"
,
"5"
)
self
.
runtime_configs
[
'communicator_is_sgd_optimizer'
]
=
os
.
getenv
(
"FLAGS_communicator_is_sgd_optimizer"
,
"1"
)
# not used
self
.
runtime_configs
[
'rpc_deadline'
]
=
os
.
getenv
(
"FLAGS_rpc_deadline"
,
"180000"
)
...
...
@@ -32,9 +52,54 @@ class TrainerRuntimeConfig(object):
"FLAGS_rpc_retry_times"
,
"3"
)
def
get_communicator_flags
(
self
):
return
self
.
runtime_configs
def
__repr__
(
self
):
need_keys
=
[]
num_threads
=
os
.
getenv
(
"CPU_NUM"
,
"1"
)
mode_str
=
""
if
self
.
mode
is
None
or
self
.
mode
==
DistributedMode
.
ASYNC
:
need_keys
=
self
.
runtime_configs
.
keys
()
mode_str
=
"async"
elif
self
.
mode
==
DistributedMode
.
SYNC
or
self
.
mode
==
DistributedMode
.
HALF_ASYNC
:
mode_str
=
"sync or half_async"
need_keys
=
[
'communicator_max_merge_var_num'
,
'communicator_send_wait_times'
,
'communicator_thread_pool_size'
,
'communicator_send_queue_size'
]
elif
self
.
mode
==
DistributedMode
.
GEO
:
mode_str
=
"GEO"
need_keys
=
[
'communicator_thread_pool_size'
,
'communicator_send_wait_times'
]
else
:
raise
ValueError
(
"Unsupported Mode"
)
if
self
.
mode
==
DistributedMode
.
SYNC
or
self
.
mode
==
DistributedMode
.
HALF_ASYNC
:
max_merge_var_num
=
self
.
runtime_configs
[
'communicator_max_merge_var_num'
]
send_queue_size
=
self
.
runtime_configs
[
'communicator_send_queue_size'
]
if
max_merge_var_num
!=
num_threads
:
print
(
'WARNING: In {} mode, communicator_max_merge_var_num '
'must be equal to CPU_NUM. But received, '
'communicator_max_merge_var_num = {}, CPU_NUM = '
'{}. communicator_max_merge_var_num will be fored to {}.'
.
format
(
mode_str
,
max_merge_var_num
,
num_threads
,
num_threads
))
self
.
runtime_configs
[
'communicator_max_merge_var_num'
]
=
num_threads
if
send_queue_size
!=
num_threads
:
print
(
'WARNING: In {} mode, communicator_send_queue_size '
'must be equal to CPU_NUM. But received, '
'communicator_send_queue_size = {}, CPU_NUM = '
'{}. communicator_send_queue_size will be fored to {}.'
.
format
(
mode_str
,
send_queue_size
,
num_threads
,
num_threads
))
self
.
runtime_configs
[
'communicator_send_queue_size'
]
=
num_threads
return
dict
((
key
,
str
(
self
.
runtime_configs
[
key
]))
for
key
in
need_keys
)
def
display
(
self
,
configs
):
raw0
,
raw1
,
length
=
45
,
5
,
50
h_format
=
"{:^45s}{:<5s}
\n
"
l_format
=
"{:<45s}{:<5s}
\n
"
...
...
@@ -47,7 +112,7 @@ class TrainerRuntimeConfig(object):
draws
+=
h_format
.
format
(
"TrainerRuntimeConfig Overview"
,
"Value"
)
draws
+=
line
+
"
\n
"
for
k
,
v
in
self
.
get_communicator_flags
()
.
items
():
for
k
,
v
in
configs
.
items
():
draws
+=
l_format
.
format
(
k
,
v
)
draws
+=
border
...
...
@@ -55,6 +120,9 @@ class TrainerRuntimeConfig(object):
_str
=
"
\n
{}
\n
"
.
format
(
draws
)
return
_str
def
__repr__
(
self
):
return
self
.
display
(
self
.
get_communicator_flags
())
class
DistributedStrategy
(
object
):
def
__init__
(
self
):
...
...
@@ -105,6 +173,12 @@ class DistributedStrategy(object):
raise
TypeError
(
"program_config only accept input type: dict or DistributeTranspilerConfig"
)
self
.
check_program_config
()
def
check_program_config
(
self
):
raise
NotImplementedError
(
"check_program_config must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
)
def
get_trainer_runtime_config
(
self
):
return
self
.
_trainer_runtime_config
...
...
@@ -123,6 +197,12 @@ class DistributedStrategy(object):
raise
TypeError
(
"trainer_runtime_config only accept input type: dict or TrainerRuntimeConfig"
)
self
.
check_trainer_runtime_config
()
def
check_trainer_runtime_config
(
self
):
raise
NotImplementedError
(
"check_trainer_runtime_config must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
)
def
get_server_runtime_config
(
self
):
return
self
.
_server_runtime_config
...
...
@@ -141,6 +221,12 @@ class DistributedStrategy(object):
raise
TypeError
(
"server_runtime_config only accept input type: dict or ServerRuntimeConfig"
)
self
.
check_server_runtime_config
()
def
check_server_runtime_config
(
self
):
raise
NotImplementedError
(
"check_server_runtime_config must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
)
def
get_execute_strategy
(
self
):
return
self
.
_execute_strategy
...
...
@@ -159,6 +245,12 @@ class DistributedStrategy(object):
raise
TypeError
(
"execute_strategy only accept input type: dict or ExecutionStrategy"
)
self
.
check_execute_strategy
()
def
check_execute_strategy
(
self
):
raise
NotImplementedError
(
"check_execute_strategy must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
)
def
get_build_strategy
(
self
):
return
self
.
_build_strategy
...
...
@@ -176,106 +268,121 @@ class DistributedStrategy(object):
else
:
raise
TypeError
(
"build_strategy only accept input type: dict or BuildStrategy"
)
self
.
check_build_strategy
()
def
check_build_strategy
(
self
):
raise
NotImplementedError
(
"check_build_strategy must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
)
class
SyncStrategy
(
DistributedStrategy
):
def
__init__
(
self
):
super
(
SyncStrategy
,
self
).
__init__
()
self
.
check_program_config
()
self
.
check_trainer_runtime_config
()
self
.
check_server_runtime_config
()
self
.
check_build_strategy
()
self
.
check_execute_strategy
()
def
check_trainer_runtime_config
(
self
):
self
.
_trainer_runtime_config
.
mode
=
DistributedMode
.
SYNC
def
check_program_config
(
self
):
self
.
_program_config
.
sync_mode
=
False
self
.
_program_config
.
runtime_split_send_recv
=
True
self
.
_build_strategy
.
async_mode
=
True
self
.
_program_config
.
half_async
=
True
self
.
_program_config
.
completely_not_async
=
True
self
.
_execute_strategy
.
use_thread_barrier
=
True
num_threads
=
os
.
getenv
(
"CPU_NUM"
,
"1"
)
def
check_server_runtime_config
(
self
):
pass
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_max_merge_var_num'
]
=
os
.
getenv
(
"FLAGS_communicator_max_merge_var_num"
,
num_threads
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_send_wait_times'
]
=
os
.
getenv
(
"FLAGS_communicator_send_wait_times"
,
"5"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_thread_pool_size'
]
=
os
.
getenv
(
"FLAGS_communicator_thread_pool_size"
,
"10"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_send_queue_size'
]
=
os
.
getenv
(
"FLAGS_communicator_send_queue_size"
,
num_threads
)
def
check_execute_strategy
(
self
):
self
.
_execute_strategy
.
use_thread_barrier
=
True
def
check_build_strategy
(
self
):
self
.
_build_strategy
.
async_mode
=
True
class
AsyncStrategy
(
DistributedStrategy
):
def
__init__
(
self
):
super
(
AsyncStrategy
,
self
).
__init__
()
self
.
check_program_config
()
self
.
check_trainer_runtime_config
()
self
.
check_server_runtime_config
()
self
.
check_build_strategy
()
self
.
check_execute_strategy
()
def
check_trainer_runtime_config
(
self
):
self
.
_trainer_runtime_config
.
mode
=
DistributedMode
.
ASYNC
def
check_program_config
(
self
):
self
.
_program_config
.
sync_mode
=
False
self
.
_program_config
.
runtime_split_send_recv
=
True
self
.
_build_strategy
.
async_mode
=
True
num_threads
=
os
.
getenv
(
"CPU_NUM"
,
"1"
)
def
check_server_runtime_config
(
self
):
pass
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_max_merge_var_num'
]
=
os
.
getenv
(
"FLAGS_communicator_max_merge_var_num"
,
num_threads
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_independent_recv_thread'
]
=
os
.
getenv
(
"FLAGS_communicator_independent_recv_thread"
,
"0"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_min_send_grad_num_before_recv'
]
=
os
.
getenv
(
"FLAGS_communicator_min_send_grad_num_before_recv"
,
num_threads
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_thread_pool_size'
]
=
os
.
getenv
(
"FLAGS_communicator_thread_pool_size"
,
"10"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_send_wait_times'
]
=
os
.
getenv
(
"FLAGS_communicator_send_wait_times"
,
"5"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_is_sgd_optimizer'
]
=
os
.
getenv
(
"FLAGS_communicator_is_sgd_optimizer"
,
"1"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_send_queue_size'
]
=
os
.
getenv
(
"FLAGS_communicator_send_queue_size"
,
num_threads
)
def
check_execute_strategy
(
self
):
pass
def
check_build_strategy
(
self
):
self
.
_build_strategy
.
async_mode
=
True
class
HalfAsyncStrategy
(
DistributedStrategy
):
def
__init__
(
self
):
super
(
HalfAsyncStrategy
,
self
).
__init__
()
self
.
check_program_config
()
self
.
check_trainer_runtime_config
()
self
.
check_server_runtime_config
()
self
.
check_build_strategy
()
self
.
check_execute_strategy
()
def
check_trainer_runtime_config
(
self
):
self
.
_trainer_runtime_config
.
mode
=
DistributedMode
.
HALF_ASYNC
def
check_program_config
(
self
):
self
.
_program_config
.
sync_mode
=
False
self
.
_program_config
.
runtime_split_send_recv
=
True
self
.
_program_config
.
half_async
=
True
self
.
_build_strategy
.
async_mode
=
True
self
.
_execute_strategy
.
use_thread_barrier
=
True
num_threads
=
os
.
getenv
(
"CPU_NUM"
,
"1"
)
def
check_server_runtime_config
(
self
):
pass
def
check_execute_strategy
(
self
):
self
.
_execute_strategy
.
use_thread_barrier
=
True
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_max_merge_var_num'
]
=
os
.
getenv
(
"FLAGS_communicator_max_merge_var_num"
,
num_threads
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_send_wait_times'
]
=
os
.
getenv
(
"FLAGS_communicator_send_wait_times"
,
"5"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_thread_pool_size'
]
=
os
.
getenv
(
"FLAGS_communicator_thread_pool_size"
,
"10"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_send_queue_size'
]
=
os
.
getenv
(
"FLAGS_communicator_send_queue_size"
,
num_threads
)
def
check_build_strategy
(
self
):
self
.
_build_strategy
.
async_mode
=
True
class
GeoStrategy
(
DistributedStrategy
):
def
__init__
(
self
,
update_frequency
=
100
):
super
(
GeoStrategy
,
self
).
__init__
()
self
.
_program_config
.
geo_sgd_need_push_nums
=
update_frequency
self
.
check_program_config
()
self
.
check_trainer_runtime_config
()
self
.
check_server_runtime_config
()
self
.
check_build_strategy
()
self
.
check_execute_strategy
()
def
check_program_config
(
self
):
self
.
_program_config
.
sync_mode
=
False
self
.
_program_config
.
runtime_split_send_recv
=
True
self
.
_program_config
.
geo_sgd_mode
=
True
self
.
_program_config
.
geo_sgd_need_push_nums
=
update_frequency
self
.
_build_strategy
.
async_mode
=
True
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_thread_pool_size'
]
=
os
.
getenv
(
"FLAGS_communicator_thread_pool_size"
,
"10"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_send_wait_times'
]
=
os
.
getenv
(
"FLAGS_communicator_send_wait_times"
,
"5"
)
def
check_trainer_runtime_config
(
self
):
self
.
_trainer_runtime_config
.
mode
=
DistributedMode
.
GEO
def
check_server_runtime_config
(
self
):
pass
def
check_execute_strategy
(
self
):
pass
def
check_build_strategy
(
self
):
self
.
_build_strategy
.
async_mode
=
True
class
StrategyFactory
(
object
):
...
...
python/paddle/fluid/tests/unittests/test_distributed_strategy.py
浏览文件 @
0f9d4081
...
...
@@ -52,6 +52,15 @@ class TestStrategyFactor(unittest.TestCase):
self
.
assertRaises
(
Exception
,
strategy
.
set_program_config
,
program_config_illegal
)
trainer_runtime_config
=
strategy
.
get_trainer_runtime_config
()
trainer_runtime_config
.
runtime_configs
[
'communicator_send_queue_size'
]
=
'50'
runtime_configs
=
trainer_runtime_config
.
get_communicator_flags
()
self
.
assertIn
(
'communicator_send_queue_size'
,
runtime_configs
)
self
.
assertNotIn
(
'communicator_independent_recv_thread'
,
runtime_configs
)
self
.
assertEqual
(
runtime_configs
[
'communicator_send_queue_size'
],
'2'
)
def
test_geo_strategy
(
self
):
strategy
=
StrategyFactory
.
create_geo_strategy
(
5
)
self
.
assertEqual
(
strategy
.
_program_config
.
sync_mode
,
False
)
...
...
@@ -82,6 +91,14 @@ class TestStrategyFactor(unittest.TestCase):
self
.
assertRaises
(
Exception
,
strategy
.
set_build_strategy
,
build_strategy_illegal
)
os
.
environ
[
"CPU_NUM"
]
=
'100'
trainer_runtime_config
=
strategy
.
get_trainer_runtime_config
()
runtime_configs
=
trainer_runtime_config
.
get_communicator_flags
()
self
.
assertIn
(
'communicator_thread_pool_size'
,
runtime_configs
)
self
.
assertIn
(
'communicator_send_wait_times'
,
runtime_configs
)
self
.
assertNotIn
(
'communicator_independent_recv_thread'
,
runtime_configs
)
def
test_async_strategy
(
self
):
os
.
environ
[
"CPU_NUM"
]
=
'100'
...
...
@@ -164,6 +181,16 @@ class TestStrategyFactor(unittest.TestCase):
self
.
assertRaises
(
Exception
,
strategy
.
set_server_runtime_config
,
server_runtime_config_illegal
)
os
.
environ
[
"CPU_NUM"
]
=
'100'
trainer_runtime_config
=
strategy
.
get_trainer_runtime_config
()
trainer_runtime_config
.
runtime_configs
[
'communicator_send_queue_size'
]
=
'50'
runtime_configs
=
trainer_runtime_config
.
get_communicator_flags
()
self
.
assertIn
(
'communicator_send_queue_size'
,
runtime_configs
)
self
.
assertNotIn
(
'communicator_independent_recv_thread'
,
runtime_configs
)
self
.
assertEqual
(
runtime_configs
[
'communicator_send_queue_size'
],
'100'
)
class
TestCreateDefaultStrategy
(
unittest
.
TestCase
):
def
test_default_strategy
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录