Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
0f9d4081
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0f9d4081
编写于
3月 03, 2020
作者:
1
123malin
提交者:
GitHub
3月 03, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
test=develop, optimize distributedstrategy (#22677)
* test=develop, optimize distributedstrategy
上级
5ee29c67
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
200 addition
and
66 deletion
+200
-66
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py
...eter_server/distribute_transpiler/distributed_strategy.py
+173
-66
python/paddle/fluid/tests/unittests/test_distributed_strategy.py
...paddle/fluid/tests/unittests/test_distributed_strategy.py
+27
-0
未找到文件。
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py
浏览文件 @
0f9d4081
...
...
@@ -19,12 +19,32 @@ __all__ = [
import
os
import
paddle.fluid
as
fluid
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributeTranspilerConfig
,
ServerRuntimeConfig
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributeTranspilerConfig
,
ServerRuntimeConfig
,
DistributedMode
class
TrainerRuntimeConfig
(
object
):
def
__init__
(
self
):
self
.
mode
=
None
num_threads
=
os
.
getenv
(
"CPU_NUM"
,
"1"
)
self
.
runtime_configs
=
{}
self
.
runtime_configs
[
'communicator_max_merge_var_num'
]
=
os
.
getenv
(
"FLAGS_communicator_max_merge_var_num"
,
num_threads
)
self
.
runtime_configs
[
'communicator_send_queue_size'
]
=
os
.
getenv
(
"FLAGS_communicator_send_queue_size"
,
num_threads
)
self
.
runtime_configs
[
'communicator_independent_recv_thread'
]
=
os
.
getenv
(
"FLAGS_communicator_independent_recv_thread"
,
"1"
)
self
.
runtime_configs
[
'communicator_min_send_grad_num_before_recv'
]
=
os
.
getenv
(
"FLAGS_communicator_min_send_grad_num_before_recv"
,
num_threads
)
self
.
runtime_configs
[
'communicator_thread_pool_size'
]
=
os
.
getenv
(
"FLAGS_communicator_thread_pool_size"
,
"5"
)
self
.
runtime_configs
[
'communicator_send_wait_times'
]
=
os
.
getenv
(
"FLAGS_communicator_send_wait_times"
,
"5"
)
self
.
runtime_configs
[
'communicator_is_sgd_optimizer'
]
=
os
.
getenv
(
"FLAGS_communicator_is_sgd_optimizer"
,
"1"
)
# not used
self
.
runtime_configs
[
'rpc_deadline'
]
=
os
.
getenv
(
"FLAGS_rpc_deadline"
,
"180000"
)
...
...
@@ -32,9 +52,54 @@ class TrainerRuntimeConfig(object):
"FLAGS_rpc_retry_times"
,
"3"
)
def
get_communicator_flags
(
self
):
return
self
.
runtime_configs
def
__repr__
(
self
):
need_keys
=
[]
num_threads
=
os
.
getenv
(
"CPU_NUM"
,
"1"
)
mode_str
=
""
if
self
.
mode
is
None
or
self
.
mode
==
DistributedMode
.
ASYNC
:
need_keys
=
self
.
runtime_configs
.
keys
()
mode_str
=
"async"
elif
self
.
mode
==
DistributedMode
.
SYNC
or
self
.
mode
==
DistributedMode
.
HALF_ASYNC
:
mode_str
=
"sync or half_async"
need_keys
=
[
'communicator_max_merge_var_num'
,
'communicator_send_wait_times'
,
'communicator_thread_pool_size'
,
'communicator_send_queue_size'
]
elif
self
.
mode
==
DistributedMode
.
GEO
:
mode_str
=
"GEO"
need_keys
=
[
'communicator_thread_pool_size'
,
'communicator_send_wait_times'
]
else
:
raise
ValueError
(
"Unsupported Mode"
)
if
self
.
mode
==
DistributedMode
.
SYNC
or
self
.
mode
==
DistributedMode
.
HALF_ASYNC
:
max_merge_var_num
=
self
.
runtime_configs
[
'communicator_max_merge_var_num'
]
send_queue_size
=
self
.
runtime_configs
[
'communicator_send_queue_size'
]
if
max_merge_var_num
!=
num_threads
:
print
(
'WARNING: In {} mode, communicator_max_merge_var_num '
'must be equal to CPU_NUM. But received, '
'communicator_max_merge_var_num = {}, CPU_NUM = '
'{}. communicator_max_merge_var_num will be fored to {}.'
.
format
(
mode_str
,
max_merge_var_num
,
num_threads
,
num_threads
))
self
.
runtime_configs
[
'communicator_max_merge_var_num'
]
=
num_threads
if
send_queue_size
!=
num_threads
:
print
(
'WARNING: In {} mode, communicator_send_queue_size '
'must be equal to CPU_NUM. But received, '
'communicator_send_queue_size = {}, CPU_NUM = '
'{}. communicator_send_queue_size will be fored to {}.'
.
format
(
mode_str
,
send_queue_size
,
num_threads
,
num_threads
))
self
.
runtime_configs
[
'communicator_send_queue_size'
]
=
num_threads
return
dict
((
key
,
str
(
self
.
runtime_configs
[
key
]))
for
key
in
need_keys
)
def
display
(
self
,
configs
):
raw0
,
raw1
,
length
=
45
,
5
,
50
h_format
=
"{:^45s}{:<5s}
\n
"
l_format
=
"{:<45s}{:<5s}
\n
"
...
...
@@ -47,7 +112,7 @@ class TrainerRuntimeConfig(object):
draws
+=
h_format
.
format
(
"TrainerRuntimeConfig Overview"
,
"Value"
)
draws
+=
line
+
"
\n
"
for
k
,
v
in
self
.
get_communicator_flags
()
.
items
():
for
k
,
v
in
configs
.
items
():
draws
+=
l_format
.
format
(
k
,
v
)
draws
+=
border
...
...
@@ -55,6 +120,9 @@ class TrainerRuntimeConfig(object):
_str
=
"
\n
{}
\n
"
.
format
(
draws
)
return
_str
def
__repr__
(
self
):
return
self
.
display
(
self
.
get_communicator_flags
())
class
DistributedStrategy
(
object
):
def
__init__
(
self
):
...
...
@@ -105,6 +173,12 @@ class DistributedStrategy(object):
raise
TypeError
(
"program_config only accept input type: dict or DistributeTranspilerConfig"
)
self
.
check_program_config
()
def
check_program_config
(
self
):
raise
NotImplementedError
(
"check_program_config must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
)
def
get_trainer_runtime_config
(
self
):
return
self
.
_trainer_runtime_config
...
...
@@ -123,6 +197,12 @@ class DistributedStrategy(object):
raise
TypeError
(
"trainer_runtime_config only accept input type: dict or TrainerRuntimeConfig"
)
self
.
check_trainer_runtime_config
()
def
check_trainer_runtime_config
(
self
):
raise
NotImplementedError
(
"check_trainer_runtime_config must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
)
def
get_server_runtime_config
(
self
):
return
self
.
_server_runtime_config
...
...
@@ -141,6 +221,12 @@ class DistributedStrategy(object):
raise
TypeError
(
"server_runtime_config only accept input type: dict or ServerRuntimeConfig"
)
self
.
check_server_runtime_config
()
def
check_server_runtime_config
(
self
):
raise
NotImplementedError
(
"check_server_runtime_config must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
)
def
get_execute_strategy
(
self
):
return
self
.
_execute_strategy
...
...
@@ -159,6 +245,12 @@ class DistributedStrategy(object):
raise
TypeError
(
"execute_strategy only accept input type: dict or ExecutionStrategy"
)
self
.
check_execute_strategy
()
def
check_execute_strategy
(
self
):
raise
NotImplementedError
(
"check_execute_strategy must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
)
def
get_build_strategy
(
self
):
return
self
.
_build_strategy
...
...
@@ -176,106 +268,121 @@ class DistributedStrategy(object):
else
:
raise
TypeError
(
"build_strategy only accept input type: dict or BuildStrategy"
)
self
.
check_build_strategy
()
def
check_build_strategy
(
self
):
raise
NotImplementedError
(
"check_build_strategy must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
)
class
SyncStrategy
(
DistributedStrategy
):
def
__init__
(
self
):
super
(
SyncStrategy
,
self
).
__init__
()
self
.
check_program_config
()
self
.
check_trainer_runtime_config
()
self
.
check_server_runtime_config
()
self
.
check_build_strategy
()
self
.
check_execute_strategy
()
def
check_trainer_runtime_config
(
self
):
self
.
_trainer_runtime_config
.
mode
=
DistributedMode
.
SYNC
def
check_program_config
(
self
):
self
.
_program_config
.
sync_mode
=
False
self
.
_program_config
.
runtime_split_send_recv
=
True
self
.
_build_strategy
.
async_mode
=
True
self
.
_program_config
.
half_async
=
True
self
.
_program_config
.
completely_not_async
=
True
self
.
_execute_strategy
.
use_thread_barrier
=
True
num_threads
=
os
.
getenv
(
"CPU_NUM"
,
"1"
)
def
check_server_runtime_config
(
self
):
pass
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_max_merge_var_num'
]
=
os
.
getenv
(
"FLAGS_communicator_max_merge_var_num"
,
num_threads
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_send_wait_times'
]
=
os
.
getenv
(
"FLAGS_communicator_send_wait_times"
,
"5"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_thread_pool_size'
]
=
os
.
getenv
(
"FLAGS_communicator_thread_pool_size"
,
"10"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_send_queue_size'
]
=
os
.
getenv
(
"FLAGS_communicator_send_queue_size"
,
num_threads
)
def
check_execute_strategy
(
self
):
self
.
_execute_strategy
.
use_thread_barrier
=
True
def
check_build_strategy
(
self
):
self
.
_build_strategy
.
async_mode
=
True
class
AsyncStrategy
(
DistributedStrategy
):
def
__init__
(
self
):
super
(
AsyncStrategy
,
self
).
__init__
()
self
.
check_program_config
()
self
.
check_trainer_runtime_config
()
self
.
check_server_runtime_config
()
self
.
check_build_strategy
()
self
.
check_execute_strategy
()
def
check_trainer_runtime_config
(
self
):
self
.
_trainer_runtime_config
.
mode
=
DistributedMode
.
ASYNC
def
check_program_config
(
self
):
self
.
_program_config
.
sync_mode
=
False
self
.
_program_config
.
runtime_split_send_recv
=
True
self
.
_build_strategy
.
async_mode
=
True
num_threads
=
os
.
getenv
(
"CPU_NUM"
,
"1"
)
def
check_server_runtime_config
(
self
):
pass
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_max_merge_var_num'
]
=
os
.
getenv
(
"FLAGS_communicator_max_merge_var_num"
,
num_threads
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_independent_recv_thread'
]
=
os
.
getenv
(
"FLAGS_communicator_independent_recv_thread"
,
"0"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_min_send_grad_num_before_recv'
]
=
os
.
getenv
(
"FLAGS_communicator_min_send_grad_num_before_recv"
,
num_threads
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_thread_pool_size'
]
=
os
.
getenv
(
"FLAGS_communicator_thread_pool_size"
,
"10"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_send_wait_times'
]
=
os
.
getenv
(
"FLAGS_communicator_send_wait_times"
,
"5"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_is_sgd_optimizer'
]
=
os
.
getenv
(
"FLAGS_communicator_is_sgd_optimizer"
,
"1"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_send_queue_size'
]
=
os
.
getenv
(
"FLAGS_communicator_send_queue_size"
,
num_threads
)
def
check_execute_strategy
(
self
):
pass
def
check_build_strategy
(
self
):
self
.
_build_strategy
.
async_mode
=
True
class
HalfAsyncStrategy
(
DistributedStrategy
):
def
__init__
(
self
):
super
(
HalfAsyncStrategy
,
self
).
__init__
()
self
.
check_program_config
()
self
.
check_trainer_runtime_config
()
self
.
check_server_runtime_config
()
self
.
check_build_strategy
()
self
.
check_execute_strategy
()
def
check_trainer_runtime_config
(
self
):
self
.
_trainer_runtime_config
.
mode
=
DistributedMode
.
HALF_ASYNC
def
check_program_config
(
self
):
self
.
_program_config
.
sync_mode
=
False
self
.
_program_config
.
runtime_split_send_recv
=
True
self
.
_program_config
.
half_async
=
True
self
.
_build_strategy
.
async_mode
=
True
self
.
_execute_strategy
.
use_thread_barrier
=
True
num_threads
=
os
.
getenv
(
"CPU_NUM"
,
"1"
)
def
check_server_runtime_config
(
self
):
pass
def
check_execute_strategy
(
self
):
self
.
_execute_strategy
.
use_thread_barrier
=
True
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_max_merge_var_num'
]
=
os
.
getenv
(
"FLAGS_communicator_max_merge_var_num"
,
num_threads
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_send_wait_times'
]
=
os
.
getenv
(
"FLAGS_communicator_send_wait_times"
,
"5"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_thread_pool_size'
]
=
os
.
getenv
(
"FLAGS_communicator_thread_pool_size"
,
"10"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_send_queue_size'
]
=
os
.
getenv
(
"FLAGS_communicator_send_queue_size"
,
num_threads
)
def
check_build_strategy
(
self
):
self
.
_build_strategy
.
async_mode
=
True
class
GeoStrategy
(
DistributedStrategy
):
def
__init__
(
self
,
update_frequency
=
100
):
super
(
GeoStrategy
,
self
).
__init__
()
self
.
_program_config
.
geo_sgd_need_push_nums
=
update_frequency
self
.
check_program_config
()
self
.
check_trainer_runtime_config
()
self
.
check_server_runtime_config
()
self
.
check_build_strategy
()
self
.
check_execute_strategy
()
def
check_program_config
(
self
):
self
.
_program_config
.
sync_mode
=
False
self
.
_program_config
.
runtime_split_send_recv
=
True
self
.
_program_config
.
geo_sgd_mode
=
True
self
.
_program_config
.
geo_sgd_need_push_nums
=
update_frequency
self
.
_build_strategy
.
async_mode
=
True
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_thread_pool_size'
]
=
os
.
getenv
(
"FLAGS_communicator_thread_pool_size"
,
"10"
)
self
.
_trainer_runtime_config
.
runtime_configs
[
'communicator_send_wait_times'
]
=
os
.
getenv
(
"FLAGS_communicator_send_wait_times"
,
"5"
)
def
check_trainer_runtime_config
(
self
):
self
.
_trainer_runtime_config
.
mode
=
DistributedMode
.
GEO
def
check_server_runtime_config
(
self
):
pass
def
check_execute_strategy
(
self
):
pass
def
check_build_strategy
(
self
):
self
.
_build_strategy
.
async_mode
=
True
class
StrategyFactory
(
object
):
...
...
python/paddle/fluid/tests/unittests/test_distributed_strategy.py
浏览文件 @
0f9d4081
...
...
@@ -52,6 +52,15 @@ class TestStrategyFactor(unittest.TestCase):
self
.
assertRaises
(
Exception
,
strategy
.
set_program_config
,
program_config_illegal
)
trainer_runtime_config
=
strategy
.
get_trainer_runtime_config
()
trainer_runtime_config
.
runtime_configs
[
'communicator_send_queue_size'
]
=
'50'
runtime_configs
=
trainer_runtime_config
.
get_communicator_flags
()
self
.
assertIn
(
'communicator_send_queue_size'
,
runtime_configs
)
self
.
assertNotIn
(
'communicator_independent_recv_thread'
,
runtime_configs
)
self
.
assertEqual
(
runtime_configs
[
'communicator_send_queue_size'
],
'2'
)
def
test_geo_strategy
(
self
):
strategy
=
StrategyFactory
.
create_geo_strategy
(
5
)
self
.
assertEqual
(
strategy
.
_program_config
.
sync_mode
,
False
)
...
...
@@ -82,6 +91,14 @@ class TestStrategyFactor(unittest.TestCase):
self
.
assertRaises
(
Exception
,
strategy
.
set_build_strategy
,
build_strategy_illegal
)
os
.
environ
[
"CPU_NUM"
]
=
'100'
trainer_runtime_config
=
strategy
.
get_trainer_runtime_config
()
runtime_configs
=
trainer_runtime_config
.
get_communicator_flags
()
self
.
assertIn
(
'communicator_thread_pool_size'
,
runtime_configs
)
self
.
assertIn
(
'communicator_send_wait_times'
,
runtime_configs
)
self
.
assertNotIn
(
'communicator_independent_recv_thread'
,
runtime_configs
)
def
test_async_strategy
(
self
):
os
.
environ
[
"CPU_NUM"
]
=
'100'
...
...
@@ -164,6 +181,16 @@ class TestStrategyFactor(unittest.TestCase):
self
.
assertRaises
(
Exception
,
strategy
.
set_server_runtime_config
,
server_runtime_config_illegal
)
os
.
environ
[
"CPU_NUM"
]
=
'100'
trainer_runtime_config
=
strategy
.
get_trainer_runtime_config
()
trainer_runtime_config
.
runtime_configs
[
'communicator_send_queue_size'
]
=
'50'
runtime_configs
=
trainer_runtime_config
.
get_communicator_flags
()
self
.
assertIn
(
'communicator_send_queue_size'
,
runtime_configs
)
self
.
assertNotIn
(
'communicator_independent_recv_thread'
,
runtime_configs
)
self
.
assertEqual
(
runtime_configs
[
'communicator_send_queue_size'
],
'100'
)
class
TestCreateDefaultStrategy
(
unittest
.
TestCase
):
def
test_default_strategy
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录