Commit 23d3929a in Crayon鑫/Paddle (forked from PaddlePaddle/Paddle)
Authored Mar 12, 2019 by Qiao Longfei
Parent: d3a14377

optimize merge vars

Showing 1 changed file with 63 additions and 22 deletions:

paddle/fluid/operators/distributed/communicator.cc (+63, -22)
@@ -18,12 +18,15 @@ limitations under the License. */
 #include <chrono>  // NOLINT
 #include <thread>  // NOLINT
+#include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/framework/variable_helper.h"
 #include "paddle/fluid/operators/distributed/parameter_recv.h"
 #include "paddle/fluid/operators/distributed/parameter_send.h"
+#include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
+#include "paddle/fluid/platform/device_context.h"
 
 DEFINE_bool(communicator_independent_recv_thread, true,
             "use an independent to recv vars from parameter server");
@@ -40,28 +43,54 @@ namespace paddle {
 namespace operators {
 namespace distributed {
 
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+
+inline double GetCurrentUS() {
+  struct timeval time;
+  gettimeofday(&time, NULL);
+  return 1e+6 * time.tv_sec + time.tv_usec;
+}
+
 static inline void MergeVars(const std::string &var_name,
                              const std::vector<std::shared_ptr<Variable>> &vars,
                              Scope *scope) {
-  VLOG(3) << "merge " << vars.size() << " vars " << var_name << " to 1";
   PADDLE_ENFORCE(!vars.empty(), "should have value to merge!");
   auto cpu_place = platform::CPUPlace();
   auto &var0 = vars[0];
   auto *out_var = scope->Var(var_name);
   if (var0->IsType<framework::LoDTensor>()) {
+    VLOG(3) << "merge " << var_name << " LoDTensor"
+            << var0->Get<framework::LoDTensor>().dims();
+
+    // init output tensor
     auto *out_t = out_var->GetMutable<framework::LoDTensor>();
     auto *out_ptr = out_t->mutable_data<float>(
         var0->Get<framework::LoDTensor>().dims(), cpu_place);
     auto numel = out_t->numel();
-    for (auto i = 0; i < numel; ++i) {
-      out_ptr[i] = 0;
+
+    // check the input dims
     for (auto &var : vars) {
       auto &var_t = var->Get<framework::LoDTensor>();
       PADDLE_ENFORCE_EQ(var_t.numel(), numel, "should have the same dims");
-      out_ptr[i] += var_t.data<float>()[i];
     }
-    }
+
+    // set output tensor to 0.
+    auto cpu_ctx = paddle::platform::CPUDeviceContext();
+    math::SetConstant<paddle::platform::CPUDeviceContext, float>
+        constant_functor;
+    constant_functor(cpu_ctx, out_t, static_cast<float>(0));
+
+    // sum all vars to out
+    auto result = EigenVector<float>::Flatten(*out_t);
+    for (auto &var : vars) {
+      auto &in_t = var->Get<framework::LoDTensor>();
+      auto in = EigenVector<float>::Flatten(in_t);
+      result.device(*cpu_ctx.eigen_device()) = result + in;
+    }
   } else if (var0->IsType<framework::SelectedRows>()) {
+    auto &slr0 = var0->Get<framework::SelectedRows>();
     auto *out_slr = out_var->GetMutable<framework::SelectedRows>();
     out_slr->mutable_rows()->clear();
     out_slr->mutable_value()->mutable_data<float>({{}}, cpu_place);
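The heart of the optimization is in the LoDTensor branch above: instead of zeroing and accumulating the output one scalar at a time in a nested loop, the new code zeroes the buffer once with math::SetConstant and then folds each input into the result as a vectorized Eigen expression. A minimal standalone sketch of the same zero-then-accumulate pattern in plain Eigen (names and shapes here are illustrative, not Paddle's API):

#include <iostream>
#include <vector>
#include <Eigen/Dense>

int main() {
  // Three "gradient" buffers standing in for the queued copies of one var.
  std::vector<Eigen::VectorXf> grads(3, Eigen::VectorXf::Constant(4, 1.0f));

  // Zero the output once (math::SetConstant in the patch) ...
  Eigen::VectorXf out = Eigen::VectorXf::Zero(4);

  // ... then fold each input in as one vectorized expression
  // (result.device(...) = result + in, in the patch).
  for (const auto &g : grads) {
    out += g;
  }
  std::cout << out.transpose() << std::endl;  // prints: 3 3 3 3
}

Eigen evaluates out += g as a single assignment over the whole buffer, which is the win over the old per-element loop.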
@@ -74,6 +103,8 @@ static inline void MergeVars(const std::string &var_name,
         merge_add;
     auto dev_ctx = paddle::platform::CPUDeviceContext();
     merge_add(dev_ctx, inputs, out_slr, false);
+    VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
+            << " dims: " << slr0.value().dims();
   } else {
     PADDLE_THROW("unsupported var type!");
   }
 }
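In the SelectedRows branch, merging means summing the values of rows that share a row id across all queued copies, which is what the merge_add functor invoked above does (its full declaration falls just outside this hunk). A conceptual standalone sketch of that row-wise merge, using plain containers with one float per row rather than Paddle's API:

#include <cstdint>
#include <iostream>
#include <map>
#include <utility>
#include <vector>

// One (row id, value) pair per sparse entry, as in a SelectedRows var.
using SparseVar = std::vector<std::pair<int64_t, float>>;

int main() {
  std::vector<SparseVar> vars = {{{0, 1.0f}, {3, 2.0f}},
                                 {{3, 4.0f}, {7, 1.0f}}};
  std::map<int64_t, float> merged;  // entries with the same row id are summed
  for (const auto &var : vars) {
    for (const auto &entry : var) {
      merged[entry.first] += entry.second;
    }
  }
  for (const auto &kv : merged) {
    std::cout << "row " << kv.first << " -> " << kv.second << "\n";
  }
  // row 0 -> 1, row 3 -> 6, row 7 -> 1
}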
@@ -123,12 +154,13 @@ void Communicator::SendThread() {
   std::vector<std::future<void>> task_futures;
   task_futures.reserve(send_varname_to_ctx_.size());
   VLOG(3) << "run send graph";
+  auto before_run_send_graph = GetCurrentUS();
   for (auto &iter : send_varname_to_queue_) {
     auto &var_name = iter.first;
     auto &var_queue = iter.second;
     if (var_queue->Size() > 0) {
       auto send_task = [this, &var_name, &var_queue] {
-        VLOG(3) << "merge var " << var_name << " and send";
+        VLOG(3) << var_name << " merge and send";
         std::vector<std::shared_ptr<Variable>> vars;
         size_t merged_var_num = 0;
         while (var_queue->Size() > 0 &&
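SendThread fans out one send_task per variable, enqueues each on send_threadpool_, keeps the returned futures, and later waits on all of them. The same fan-out/join shape with only the standard library (std::async standing in for Paddle's thread pool, and the variable names invented for the example):

#include <future>
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> var_names = {"w0@GRAD", "w1@GRAD", "w2@GRAD"};
  std::vector<std::future<void>> task_futures;
  task_futures.reserve(var_names.size());
  for (const auto &var_name : var_names) {
    // One task per variable, as with send_task in SendThread().
    task_futures.emplace_back(std::async(std::launch::async, [var_name] {
      std::cout << var_name << " merge and send\n";
    }));
  }
  for (auto &task_f : task_futures) {
    task_f.wait();  // join all tasks before timing the whole graph
  }
}

Collecting futures rather than detaching threads is what lets the caller measure after_run_send_graph - before_run_send_graph across the whole batch.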
@@ -136,12 +168,19 @@ void Communicator::SendThread() {
           vars.push_back(var_queue->Pop());
           merged_var_num++;
         }
+        auto before_merge = GetCurrentUS();
         MergeVars(var_name, vars, send_scope_.get());
+        auto after_merge = GetCurrentUS();
+        VLOG(3) << "merge " << var_name << " use time "
+                << after_merge - before_merge;
         auto send_functor = distributed::ParameterSend<float>();
         auto &ctx = send_varname_to_ctx_.at(var_name);
         if (!FLAGS_communicator_fake_rpc) {
           send_functor(ctx, *send_scope_, true);
         }
+        auto after_send = GetCurrentUS();
+        VLOG(3) << "send " << var_name << " use time "
+                << after_send - after_merge;
       };
       task_futures.emplace_back(
           send_threadpool_->enqueue(std::move(send_task)));
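Every timer added in this commit follows the same pattern: sample GetCurrentUS() before and after a phase and VLOG the difference in microseconds. A standalone sketch using the exact helper the commit introduces (the busy-work body is just a stand-in for MergeVars):

#include <sys/time.h>

#include <cstdio>

// The helper added by this commit: wall-clock time in microseconds.
inline double GetCurrentUS() {
  struct timeval time;
  gettimeofday(&time, NULL);
  return 1e+6 * time.tv_sec + time.tv_usec;
}

int main() {
  auto before_merge = GetCurrentUS();
  double sink = 0;
  for (int i = 0; i < 1000000; ++i) sink += i;  // stand-in for MergeVars(...)
  auto after_merge = GetCurrentUS();
  std::printf("sink=%g, merge use time %.0f us\n", sink,
              after_merge - before_merge);
  return 0;
}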
@@ -152,7 +191,9 @@ void Communicator::SendThread() {
     for (auto &task_f : task_futures) {
       task_f.wait();
     }
-    VLOG(3) << "run send graph done";
+    auto after_run_send_graph = GetCurrentUS();
+    VLOG(3) << "run send graph use time "
+            << after_run_send_graph - before_run_send_graph;
     if (!FLAGS_communicator_independent_recv_thread) {
       RecvAll();
     }
@@ -161,6 +202,7 @@ void Communicator::SendThread() {
 void Communicator::RecvAll() {
   VLOG(3) << "parallel run recv graph";
+  auto before_send = GetCurrentUS();
   std::vector<std::future<void>> task_futures;
   task_futures.reserve(recv_varname_to_ctx_.size());
   for (auto &iter : recv_varname_to_ctx_) {
@@ -177,7 +219,8 @@ void Communicator::RecvAll() {
   for (auto &task : task_futures) {
     task.wait();
   }
-  VLOG(3) << "run recv graph done";
+  auto after_recv = GetCurrentUS();
+  VLOG(3) << "run recv graph use time " << after_recv - before_send;
 }
 
 void Communicator::RecvThread() {
@@ -191,17 +234,15 @@ void Communicator::RecvThread() {
 void Communicator::Send(const std::string &var_name,
                         const framework::Scope &scope) {
-  if (!FLAGS_communicator_fake_rpc) {
   VLOG(3) << "communicator send " << var_name;
   // push var into send queue by var_name
   auto *grad_var = scope.FindVar(var_name);
   PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited");
   auto tmp_grad_var = std::make_shared<Variable>();
   framework::CopyVariable(*grad_var, tmp_grad_var.get());
   auto &queue = send_varname_to_queue_.at(var_name);
   VLOG(3) << "send " << var_name << " queue size " << queue->Size();
   queue->Push(tmp_grad_var);
-  }
 }
 
 Communicator *Communicator::GetInstance() { return communicator_.get(); }

(Body lines are shown once with the new indentation; the whitespace-only re-indentation is hidden, matching the page's hide-whitespace view.)
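One behavioral note on this last hunk: with the FLAGS_communicator_fake_rpc guard removed from Send(), gradients are now always copied and queued even in fake-RPC mode, and only the actual send_functor call in SendThread() stays guarded (see the earlier hunk), presumably so the merge path can be exercised without a real parameter server. The copy-then-push shape itself, sketched with standard containers (the mutex-guarded queue is a stand-in for Paddle's blocking queue; names are illustrative):

#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <vector>

using Variable = std::vector<float>;  // stand-in for framework::Variable

std::queue<std::shared_ptr<Variable>> queue_;
std::mutex mu_;

// Deep-copy first (framework::CopyVariable in the patch), then push the
// snapshot: the trainer can keep mutating grad_var while the send thread
// owns an independent copy.
void Send(const Variable &grad_var) {
  auto tmp_grad_var = std::make_shared<Variable>(grad_var);
  std::lock_guard<std::mutex> lock(mu_);
  queue_.push(tmp_grad_var);
}

int main() {
  Send({1.0f, 2.0f});
  std::cout << "queue size " << queue_.size() << "\n";  // prints: queue size 1
}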