Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
fbf9564f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
fbf9564f
编写于
11月 24, 2020
作者:
1
123malin
提交者:
GitHub
11月 24, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【paddle.distributed.fleet】Optimize ParameterServer's Async Mode (#28442)
* test=develop, optimize global_step
上级
98adc8f0
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
138 addition
and
59 deletion
+138
-59
paddle/fluid/operators/distributed/communicator.cc
paddle/fluid/operators/distributed/communicator.cc
+130
-54
paddle/fluid/operators/distributed/communicator.h
paddle/fluid/operators/distributed/communicator.h
+7
-5
python/paddle/distributed/fleet/runtime/parameter_server_runtime.py
...dle/distributed/fleet/runtime/parameter_server_runtime.py
+1
-0
未找到文件。
paddle/fluid/operators/distributed/communicator.cc
浏览文件 @
fbf9564f
...
...
@@ -65,6 +65,7 @@ void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
}
else
{
send_scope_
.
reset
(
new
Scope
());
for
(
auto
&
iter
:
send_varname_to_ctx_
)
{
if
(
iter
.
first
==
STEP_COUNTER
&&
!
need_global_step_
)
continue
;
send_varname_to_queue_
[
iter
.
first
]
=
std
::
make_shared
<
BlockingQueue
<
std
::
shared_ptr
<
Variable
>>>
(
send_queue_size_
);
...
...
@@ -108,21 +109,87 @@ void AsyncCommunicator::SendGlobalStep(int batches) {
send_functor
(
ctx
,
*
send_scope_
,
true
,
1
);
}
void
AsyncCommunicator
::
SendByCommunicator
(
int
batches
)
{
void
AsyncCommunicator
::
SendByCommunicator
()
{
std
::
vector
<
std
::
future
<
void
>>
task_futures
;
task_futures
.
reserve
(
send_varname_to_ctx_
.
size
());
VLOG
(
3
)
<<
"run send graph"
;
auto
before_run_send_graph
=
GetCurrentUS
();
for
(
auto
&
iter
:
send_varname_to_queue_
)
{
auto
&
var_name
=
iter
.
first
;
auto
&
var_queue
=
iter
.
second
;
auto
send_task
=
[
this
,
batches
,
&
var_name
,
&
var_queue
]
{
auto
send_task
=
[
this
,
&
var_name
,
&
var_queue
]
{
VLOG
(
3
)
<<
var_name
<<
" merge and send; "
;
std
::
vector
<
std
::
shared_ptr
<
Variable
>>
vars
;
int
merged_var_num
=
0
;
int
wait_times
=
0
;
while
(
merged_var_num
<
max_merge_var_num_
)
{
if
(
var_queue
->
Size
()
==
0
)
{
VLOG
(
4
)
<<
"wait_times -> "
<<
wait_times
;
if
(
wait_times
>=
send_wait_times_
)
{
break
;
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
10
));
wait_times
++
;
continue
;
}
else
{
wait_times
=
0
;
vars
.
push_back
(
var_queue
->
Pop
());
merged_var_num
++
;
}
}
auto
before_merge
=
GetCurrentUS
();
if
(
var_name
==
STEP_COUNTER
)
{
SendGlobalStep
(
merged_var_num
);
auto
after_merge
=
GetCurrentUS
();
VLOG
(
3
)
<<
"merge and send "
<<
merged_var_num
<<
" "
<<
var_name
<<
" use time "
<<
after_merge
-
before_merge
;
return
;
}
VLOG
(
3
)
<<
var_name
<<
" merge and send"
;
auto
&
ctx
=
send_varname_to_ctx_
.
at
(
var_name
);
MergeVars
<
float
>
(
var_name
,
vars
,
send_scope_
.
get
(),
ctx
.
merge_add
);
auto
after_merge
=
GetCurrentUS
();
VLOG
(
3
)
<<
"merge "
<<
merged_var_num
<<
" "
<<
var_name
<<
" use time "
<<
after_merge
-
before_merge
;
auto
send_functor
=
distributed
::
ParameterSend
<
float
>
();
send_functor
(
ctx
,
*
send_scope_
,
true
,
1
);
auto
after_send
=
GetCurrentUS
();
VLOG
(
3
)
<<
"send "
<<
var_name
<<
" use time "
<<
after_send
-
after_merge
;
};
task_futures
.
emplace_back
(
send_threadpool_
->
enqueue
(
std
::
move
(
send_task
)));
}
for
(
auto
&
task_f
:
task_futures
)
{
task_f
.
wait
();
}
auto
after_run_send_graph
=
GetCurrentUS
();
VLOG
(
3
)
<<
"run send graph use time "
<<
(
after_run_send_graph
-
before_run_send_graph
);
}
void
HalfAsyncCommunicator
::
SendByCommunicator
()
{
std
::
vector
<
std
::
future
<
void
>>
task_futures
;
task_futures
.
reserve
(
send_varname_to_ctx_
.
size
());
VLOG
(
3
)
<<
"run send graph"
;
int
batches
=
BatchesCounter
();
if
(
batches
<=
0
)
return
;
auto
before_run_send_graph
=
GetCurrentUS
();
for
(
auto
&
iter
:
send_varname_to_queue_
)
{
auto
&
var_name
=
iter
.
first
;
auto
&
var_queue
=
iter
.
second
;
auto
send_task
=
[
this
,
batches
,
&
var_name
,
&
var_queue
]
{
VLOG
(
3
)
<<
var_name
<<
" merge and send; "
;
auto
before_task
=
GetCurrentUS
();
std
::
vector
<
std
::
shared_ptr
<
Variable
>>
vars
;
vars
.
reserve
(
batches
);
...
...
@@ -130,6 +197,14 @@ void AsyncCommunicator::SendByCommunicator(int batches) {
vars
.
push_back
(
var_queue
->
Pop
());
}
if
(
var_name
==
STEP_COUNTER
)
{
SendGlobalStep
(
batches
);
auto
end_task
=
GetCurrentUS
();
VLOG
(
3
)
<<
"merge "
<<
batches
<<
" "
<<
var_name
<<
" use time "
<<
end_task
-
before_task
;
return
;
}
auto
&
ctx
=
send_varname_to_ctx_
.
at
(
var_name
);
auto
before_merge
=
GetCurrentUS
();
...
...
@@ -142,7 +217,20 @@ void AsyncCommunicator::SendByCommunicator(int batches) {
send_functor
(
ctx
,
*
send_scope_
,
true
,
1
);
auto
after_send
=
GetCurrentUS
();
VLOG
(
3
)
<<
"send "
<<
var_name
<<
" use time "
<<
after_send
-
after_merge
;
<<
after_send
-
before_task
;
if
(
var_name
.
rfind
(
"@GRAD"
)
!=
var_name
.
size
()
-
5
)
return
;
auto
recv_param
=
var_name
.
substr
(
0
,
var_name
.
size
()
-
5
);
if
(
recv_varname_to_ctx_
.
find
(
recv_param
)
==
recv_varname_to_ctx_
.
end
())
return
;
auto
recv_functor
=
distributed
::
ParameterRecv
<
float
>
();
recv_functor
(
recv_varname_to_ctx_
.
at
(
recv_param
),
*
recv_scope_
);
auto
after_recv
=
GetCurrentUS
();
VLOG
(
3
)
<<
"recv "
<<
recv_param
<<
" use time "
<<
after_recv
-
after_send
;
return
;
};
task_futures
.
emplace_back
(
send_threadpool_
->
enqueue
(
std
::
move
(
send_task
)));
}
...
...
@@ -152,7 +240,7 @@ void AsyncCommunicator::SendByCommunicator(int batches) {
auto
after_run_send_graph
=
GetCurrentUS
();
VLOG
(
3
)
<<
"run send graph use time "
<<
after_run_send_graph
-
before_run_send_graph
;
<<
(
after_run_send_graph
-
before_run_send_graph
)
;
}
void
AsyncCommunicator
::
MainThread
()
{
...
...
@@ -164,20 +252,28 @@ void AsyncCommunicator::MainThread() {
}
while
(
running_
)
{
int
batches
=
BatchesCounter
();
SendByCommunicator
();
BarrierSend
();
}
VLOG
(
3
)
<<
"communicator stopped, send thread exit"
;
}
if
(
batches
>
0
)
{
SendGlobalStep
(
batches
);
SendByCommunicator
(
batches
);
void
HalfAsyncCommunicator
::
MainThread
()
{
VLOG
(
3
)
<<
"MainThread start and wait"
;
while
(
waiting_
&&
running_
)
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
100
));
VLOG
(
3
)
<<
"wait for running"
;
}
while
(
running_
)
{
SendByCommunicator
();
BarrierSend
();
RecvByCommunicator
();
BarrierRecv
();
BarrierWeakUp
();
}
else
{
VLOG
(
1
)
<<
"get nothing from sending queue, will skip send/recv"
;
}
}
VLOG
(
1
)
<<
"communicator stopped, send thread exit"
;
VLOG
(
3
)
<<
"communicator stopped, send thread exit"
;
}
void
AsyncCommunicator
::
RecvByCommunicator
()
{
...
...
@@ -193,10 +289,13 @@ void AsyncCommunicator::RecvNoBarrier() {
for
(
auto
&
iter
:
recv_varname_to_ctx_
)
{
auto
recv_task
=
[
this
,
&
iter
]
{
auto
before_task
=
GetCurrentUS
();
auto
&
var_name
=
iter
.
first
;
VLOG
(
4
)
<<
"recv var "
<<
var_name
;
auto
recv_functor
=
distributed
::
ParameterRecv
<
float
>
();
recv_functor
(
iter
.
second
,
*
recv_scope_
);
auto
end_task
=
GetCurrentUS
();
VLOG
(
1
)
<<
"recv var "
<<
var_name
<<
" use time "
<<
(
end_task
-
before_task
);
};
task_futures
.
emplace_back
(
recv_threadpool_
->
enqueue
(
std
::
move
(
recv_task
)));
}
...
...
@@ -206,37 +305,12 @@ void AsyncCommunicator::RecvNoBarrier() {
}
}
int
AsyncCommunicator
::
BatchesCounter
()
{
auto
&
step_queue
=
send_varname_to_queue_
.
at
(
STEP_COUNTER
);
size_t
merged_var_num
=
0
;
size_t
wait_times
=
0
;
while
(
merged_var_num
<
static_cast
<
size_t
>
(
max_merge_var_num_
))
{
if
(
step_queue
->
Size
()
==
0
)
{
VLOG
(
3
)
<<
"wait_times -> "
<<
wait_times
;
if
(
wait_times
>=
static_cast
<
size_t
>
(
send_wait_times_
))
{
break
;
}
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
10
));
wait_times
++
;
continue
;
}
else
{
step_queue
->
Pop
();
wait_times
=
0
;
merged_var_num
++
;
}
}
return
merged_var_num
;
}
void
AsyncCommunicator
::
Start
()
{
VLOG
(
1
)
<<
"Communicator start"
;
VLOG
(
3
)
<<
"Communicator start"
;
if
(
!
communicator_
)
{
VLOG
(
0
)
<<
"Communicator is not inited, do nothing"
;
}
else
{
VLOG
(
1
)
<<
"start send thread and recv thread"
;
VLOG
(
3
)
<<
"start send thread and recv thread"
;
waiting_
=
true
;
running_
=
true
;
BarrierTriggerReset
(
max_merge_var_num_
);
...
...
@@ -247,18 +321,18 @@ void AsyncCommunicator::Start() {
}
void
AsyncCommunicator
::
Stop
()
{
VLOG
(
1
)
<<
"Communicator stop"
;
VLOG
(
3
)
<<
"Communicator stop"
;
running_
=
false
;
if
(
!
communicator_
)
{
VLOG
(
0
)
<<
"Communicator is not inited, do nothing"
;
}
else
{
if
(
main_thread_
)
{
VLOG
(
1
)
<<
"stop send thread"
;
VLOG
(
3
)
<<
"stop send thread"
;
main_thread_
->
join
();
main_thread_
.
reset
(
nullptr
);
}
}
VLOG
(
1
)
<<
"Communicator stop done"
;
VLOG
(
3
)
<<
"Communicator stop done"
;
}
void
AsyncCommunicator
::
Send
(
const
std
::
vector
<
std
::
string
>
&
var_names
,
...
...
@@ -271,6 +345,10 @@ void AsyncCommunicator::Send(const std::vector<std::string> &var_names,
platform
::
errors
::
InvalidArgument
(
"var_tables.size() == 1 is permitted"
));
auto
table_name
=
var_tables
[
0
];
if
(
table_name
==
STEP_COUNTER
&&
!
need_global_step_
)
return
;
auto
before_send_op
=
GetCurrentUS
();
auto
&
queue
=
send_varname_to_queue_
.
at
(
table_name
);
if
(
table_name
==
STEP_COUNTER
)
{
...
...
@@ -279,7 +357,6 @@ void AsyncCommunicator::Send(const std::vector<std::string> &var_names,
tensor
->
Resize
(
framework
::
make_ddim
({
1
}));
auto
*
out_d
=
tensor
->
mutable_data
<
int64_t
>
(
platform
::
CPUPlace
());
out_d
[
0
]
=
1
;
VLOG
(
3
)
<<
"send to "
<<
table_name
<<
" with queue size "
<<
queue
->
Size
();
queue
->
Push
(
tmp_var
);
}
else
{
PADDLE_ENFORCE_GE
(
var_names
.
size
(),
1
,
...
...
@@ -295,21 +372,20 @@ void AsyncCommunicator::Send(const std::vector<std::string> &var_names,
auto
tmp_var
=
std
::
make_shared
<
Variable
>
();
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
framework
::
CopyVariable
(
*
var
,
tmp_var
.
get
());
VLOG
(
3
)
<<
"send to "
<<
table_name
<<
" with queue size "
<<
queue
->
Size
();
queue
->
Push
(
tmp_var
);
}
else
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
// push var into send queue by var_name
auto
var_name
=
var_names
[
0
];
framework
::
CopyVariable
(
*
var
,
tmp_var
.
get
());
VLOG
(
3
)
<<
"send to "
<<
table_name
<<
" with queue size "
<<
queue
->
Size
();
queue
->
Push
(
tmp_var
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"unknown var type to copy, only support LoDTensor/SelectedRows"
));
}
}
auto
after_send_op
=
GetCurrentUS
();
VLOG
(
3
)
<<
"send to "
<<
table_name
<<
" with queue size "
<<
queue
->
Size
()
<<
", use time "
<<
(
after_send_op
-
before_send_op
);
}
void
HalfAsyncCommunicator
::
Clean
()
{
...
...
paddle/fluid/operators/distributed/communicator.h
浏览文件 @
fbf9564f
...
...
@@ -302,16 +302,13 @@ class AsyncCommunicator : public Communicator {
const
std
::
vector
<
std
::
string
>
&
var_tables
,
const
framework
::
Scope
&
scope
)
override
;
virtual
void
SendByCommunicator
(
int
batches
);
virtual
void
SendByCommunicator
();
virtual
void
SendGlobalStep
(
int
batches
);
virtual
void
RecvByCommunicator
();
virtual
void
RecvNoBarrier
();
virtual
int
BatchesCounter
();
virtual
void
BarrierSend
()
{}
virtual
void
BarrierRecv
()
{}
...
...
@@ -359,6 +356,10 @@ class HalfAsyncCommunicator : public AsyncCommunicator {
VLOG
(
0
)
<<
"HalfAsyncCommunicator Initialized"
;
}
void
MainThread
()
override
;
void
SendByCommunicator
()
override
;
void
Clean
()
override
;
void
Barrier
()
override
;
...
...
@@ -438,7 +439,7 @@ class GeoCommunicator : public AsyncCommunicator {
const
std
::
vector
<
std
::
string
>
&
var_tables
,
const
framework
::
Scope
&
scope
)
override
;
void
SendByCommunicator
(
int
batches
)
{
return
;
}
void
SendByCommunicator
()
{
return
;
}
std
::
vector
<
int64_t
>
MergeSparseIds
(
const
std
::
string
&
send_varname
);
...
...
@@ -475,6 +476,7 @@ class GeoCommunicator : public AsyncCommunicator {
std
::
shared_ptr
<
Scope
>
pserver_scope_
;
int
send_var_nums_
=
0
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
SparseValue
>>
old_sparses_
;
std
::
unordered_map
<
...
...
python/paddle/distributed/fleet/runtime/parameter_server_runtime.py
浏览文件 @
fbf9564f
...
...
@@ -207,6 +207,7 @@ class ParameterServerRuntime(RuntimeBase):
SyncStrategy
,
GeoStrategy
trainer_config
=
self
.
async_strategy
.
get_trainer_runtime_config
()
print
(
trainer_config
)
dist_strategy
=
self
.
context
[
"valid_strategy"
]
launch_barrier
=
dist_strategy
.
a_sync_configs
[
"launch_barrier"
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录