Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
50601501
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
50601501
编写于
3月 04, 2019
作者:
Q
Qiao Longfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
improve communicator
上级
c2cce6ba
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
70 addition
and
25 deletion
+70
-25
paddle/fluid/operators/distributed/CMakeLists.txt
paddle/fluid/operators/distributed/CMakeLists.txt
+1
-1
paddle/fluid/operators/distributed/communicator.cc
paddle/fluid/operators/distributed/communicator.cc
+46
-23
paddle/fluid/operators/distributed/communicator.h
paddle/fluid/operators/distributed/communicator.h
+15
-1
paddle/fluid/operators/distributed/rpc_common.h
paddle/fluid/operators/distributed/rpc_common.h
+8
-0
未找到文件。
paddle/fluid/operators/distributed/CMakeLists.txt
浏览文件 @
50601501
...
...
@@ -54,7 +54,7 @@ cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
cc_library
(
parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory
)
cc_library
(
parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory
)
cc_library
(
parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory
)
cc_library
(
communicator SRCS communicator.cc DEPS scope selected_rows tensor variable_helper selected_rows_functor
)
cc_library
(
communicator SRCS communicator.cc DEPS scope selected_rows tensor variable_helper selected_rows_functor
simple_threadpool
)
if
(
WITH_GPU
)
cc_test
(
collective_server_test SRCS collective_server_test.cc
DEPS sendrecvop_rpc executor
${
RPC_DEPS
}
...
...
paddle/fluid/operators/distributed/communicator.cc
浏览文件 @
50601501
...
...
@@ -25,9 +25,9 @@ namespace paddle {
namespace
operators
{
namespace
distributed
{
static
void
MergeVars
(
const
std
::
string
&
var_name
,
const
std
::
vector
<
std
::
shared_ptr
<
Variable
>>
&
vars
,
Scope
*
scope
)
{
static
inline
void
MergeVars
(
const
std
::
string
&
var_name
,
const
std
::
vector
<
std
::
shared_ptr
<
Variable
>>
&
vars
,
Scope
*
scope
)
{
PADDLE_ENFORCE
(
!
vars
.
empty
(),
"should have value to merge!"
);
auto
cpu_place
=
platform
::
CPUPlace
();
auto
&
var0
=
vars
[
0
];
...
...
@@ -62,31 +62,53 @@ static void MergeVars(const std::string &var_name,
}
void
Communicator
::
SendThread
()
{
for
(
auto
&
iter
:
send_varname_to_queue_
)
{
auto
&
var_name
=
iter
.
first
;
VLOG
(
3
)
<<
"merge var "
<<
var_name
<<
" and send"
;
auto
&
var_queue
=
iter
.
second
;
std
::
vector
<
std
::
shared_ptr
<
Variable
>>
vars
;
const
size_t
max_merge_var_num
=
20
;
size_t
merged_var_num
=
0
;
while
(
var_queue
->
Size
()
>
0
&&
merged_var_num
<
max_merge_var_num
)
{
vars
.
push_back
(
var_queue
->
Pop
());
merged_var_num
++
;
while
(
running_
)
{
std
::
vector
<
std
::
future
<
void
>>
task_futures
;
task_futures
.
reserve
(
send_varname_to_ctx_
.
size
());
for
(
auto
&
iter
:
send_varname_to_queue_
)
{
auto
send_task
=
[
this
,
&
iter
]
{
auto
&
var_name
=
iter
.
first
;
VLOG
(
3
)
<<
"merge var "
<<
var_name
<<
" and send"
;
auto
&
var_queue
=
iter
.
second
;
std
::
vector
<
std
::
shared_ptr
<
Variable
>>
vars
;
const
size_t
max_merge_var_num
=
20
;
size_t
merged_var_num
=
0
;
while
(
var_queue
->
Size
()
>
0
&&
merged_var_num
<
max_merge_var_num
)
{
vars
.
push_back
(
var_queue
->
Pop
());
merged_var_num
++
;
}
MergeVars
(
var_name
,
vars
,
send_scope_
.
get
());
auto
send_functor
=
distributed
::
ParameterSend
<
float
>
();
auto
&
ctx
=
send_varname_to_ctx_
.
at
(
var_name
);
send_functor
(
ctx
,
*
send_scope_
,
true
);
};
task_futures
.
emplace_back
(
send_threadpool_
->
enqueue
(
std
::
move
(
send_task
)));
}
for
(
auto
&
task_f
:
task_futures
)
{
task_f
.
wait
();
}
MergeVars
(
var_name
,
vars
,
send_scope_
.
get
());
// auto send_functor = distributed::ParameterSend<float>();
// send_functor(var_name, send_varname_to_ctx_[var_name], exe_ctx,
// send_scope_, true);
}
}
void
Communicator
::
RecvThread
()
{
// parallel run recv graph
for
(
auto
&
iter
:
recv_varname_to_ctx_
)
{
auto
&
var_name
=
iter
.
first
;
VLOG
(
3
)
<<
"recv var "
<<
iter
.
first
;
// auto recv_functor = distributed::ParameterRecv<float>();
// recv_functor(var_name, iter.second, exe_ctx, recv_scope_);
while
(
running_
)
{
// parallel run recv graph
std
::
vector
<
std
::
future
<
void
>>
task_futures
;
task_futures
.
reserve
(
recv_varname_to_ctx_
.
size
());
for
(
auto
&
iter
:
recv_varname_to_ctx_
)
{
auto
recv_task
=
[
this
,
&
iter
]
{
auto
&
var_name
=
iter
.
first
;
VLOG
(
3
)
<<
"recv var "
<<
var_name
;
auto
recv_functor
=
distributed
::
ParameterRecv
<
float
>
();
recv_functor
(
iter
.
second
,
*
recv_scope_
);
};
task_futures
.
emplace_back
(
recv_threadpool_
->
enqueue
(
std
::
move
(
recv_task
)));
}
for
(
auto
&
task
:
task_futures
)
{
task
.
wait
();
}
}
}
...
...
@@ -101,6 +123,7 @@ void Communicator::Send(const std::string &var_name,
}
void
Communicator
::
Start
()
{
running_
=
true
;
// start send and recv thread
send_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
Communicator
::
SendThread
,
this
)));
...
...
paddle/fluid/operators/distributed/communicator.h
浏览文件 @
50601501
...
...
@@ -19,6 +19,8 @@ limitations under the License. */
#include <string>
#include <vector>
#include <ThreadPool.h>
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
...
...
@@ -100,9 +102,18 @@ class Communicator {
send_varname_to_queue_
[
iter
.
first
]
=
std
::
make_shared
<
BlockingQueue
<
std
::
shared_ptr
<
Variable
>>>
(
10
);
}
// TODO(qiao): default 5, need to config
send_threadpool_
.
reset
(
new
::
ThreadPool
(
5
));
recv_threadpool_
.
reset
(
new
::
ThreadPool
(
5
));
}
~
Communicator
()
{}
~
Communicator
()
{
VLOG
(
3
)
<<
"~Communicator"
;
running_
=
false
;
send_thread_
->
join
();
recv_thread_
->
join
();
VLOG
(
3
)
<<
"~Communicator done"
;
}
void
Start
();
...
...
@@ -113,6 +124,7 @@ class Communicator {
void
SendThread
();
void
RecvThread
();
bool
running_
=
false
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
BlockingQueue
<
std
::
shared_ptr
<
Variable
>>>>
send_varname_to_queue_
;
...
...
@@ -122,6 +134,8 @@ class Communicator {
std
::
unique_ptr
<
std
::
thread
>
recv_thread_
;
Scope
*
recv_scope_
;
// should be global scope
std
::
unique_ptr
<
Scope
>
send_scope_
;
// an independent scope
std
::
unique_ptr
<::
ThreadPool
>
send_threadpool_
{
nullptr
};
std
::
unique_ptr
<::
ThreadPool
>
recv_threadpool_
{
nullptr
};
};
}
// namespace distributed
...
...
paddle/fluid/operators/distributed/rpc_common.h
浏览文件 @
50601501
...
...
@@ -29,6 +29,14 @@ struct RpcContext {
splited_var_names
(
names
),
epmap
(
emap
),
height_sections
(
sections
)
{}
RpcContext
(
const
RpcContext
&
ctx
)
{
var_name
=
ctx
.
var_name
;
splited_var_names
=
ctx
.
splited_var_names
;
epmap
=
ctx
.
epmap
;
height_sections
=
ctx
.
height_sections
;
}
std
::
string
var_name
;
std
::
vector
<
std
::
string
>
splited_var_names
;
std
::
vector
<
std
::
string
>
epmap
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录