Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
b20f528b
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b20f528b
编写于
9月 22, 2020
作者:
M
malin10
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
test=develop, bug fix
上级
b3526fb4
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
99 addition
and
77 deletion
+99
-77
paddle/fluid/operators/distributed/communicator.cc
paddle/fluid/operators/distributed/communicator.cc
+76
-56
paddle/fluid/operators/distributed/communicator.h
paddle/fluid/operators/distributed/communicator.h
+2
-1
python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py
.../fluid/incubate/fleet/parameter_server/ir/trainer_pass.py
+21
-20
未找到文件。
paddle/fluid/operators/distributed/communicator.cc
浏览文件 @
b20f528b
...
@@ -452,6 +452,7 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
...
@@ -452,6 +452,7 @@ void GeoCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
void
GeoCommunicator
::
Send
(
const
std
::
vector
<
std
::
string
>
&
var_names
,
void
GeoCommunicator
::
Send
(
const
std
::
vector
<
std
::
string
>
&
var_names
,
const
std
::
vector
<
std
::
string
>
&
var_tables
,
const
std
::
vector
<
std
::
string
>
&
var_tables
,
const
framework
::
Scope
&
scope
)
{
const
framework
::
Scope
&
scope
)
{
return
;
waiting_
=
false
;
waiting_
=
false
;
// PADDLE_ENFORCE_EQ(
// PADDLE_ENFORCE_EQ(
...
@@ -475,6 +476,7 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
...
@@ -475,6 +476,7 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
// << queue->Size();
// << queue->Size();
// queue->Push(tmp_var);
// queue->Push(tmp_var);
}
else
{
}
else
{
auto
p1
=
GetCurrentUS
();
auto
splited_var_nums
=
auto
splited_var_nums
=
recv_varname_to_ctx_
[
table_name
].
splited_varnames
.
size
();
recv_varname_to_ctx_
[
table_name
].
splited_varnames
.
size
();
if
(
ids_table
->
find
(
table_name
)
==
ids_table
->
end
())
{
if
(
ids_table
->
find
(
table_name
)
==
ids_table
->
end
())
{
...
@@ -484,18 +486,28 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
...
@@ -484,18 +486,28 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
table_name
,
table_name
,
std
::
vector
<
std
::
unordered_set
<
int64_t
>>
{
splited_var_nums
}));
std
::
vector
<
std
::
unordered_set
<
int64_t
>>
{
splited_var_nums
}));
}
}
auto
p2
=
GetCurrentUS
();
auto
*
var
=
scope
.
FindVar
(
var_names
[
i
]);
auto
*
var
=
scope
.
FindVar
(
var_names
[
i
]);
auto
&
rows
=
var
->
Get
<
framework
::
SelectedRows
>
().
rows
();
auto
var_tensor
=
var
->
Get
<
framework
::
LoDTensor
>
();
// split rows index into output sparse vars
int
element_number
=
var_tensor
.
numel
();
for
(
size_t
i
=
0
;
i
<
rows
.
size
();
++
i
)
{
int
*
var_mutable_data
=
var_tensor
.
mutable_data
<
int
>
(
var_tensor
.
place
());
auto
ep_idx
=
rows
[
i
]
%
splited_var_nums
;
auto
p3
=
GetCurrentUS
();
ids_table
->
at
(
table_name
)[
ep_idx
].
insert
(
rows
[
i
]);
// insert ids which has not been record
for
(
int
j
=
0
;
j
<
element_number
;
j
++
)
{
auto
ep_idx
=
var_mutable_data
[
j
]
%
splited_var_nums
;
ids_table
->
at
(
table_name
)[
ep_idx
].
insert
(
var_mutable_data
[
j
]);
}
}
auto
p4
=
GetCurrentUS
();
VLOG
(
1
)
<<
"table_name: "
<<
table_name
<<
"; p1-2: "
<<
(
p2
-
p1
)
<<
"; p2-3: "
<<
(
p3
-
p2
)
<<
"; p3-4: "
<<
(
p4
-
p3
);
}
}
}
}
auto
before_push
=
GetCurrentUS
();
need_push_queue_
->
Push
(
ids_table
);
need_push_queue_
->
Push
(
ids_table
);
auto
after_send
=
GetCurrentUS
();
auto
after_send
=
GetCurrentUS
();
VLOG
(
0
)
<<
"run send_op finish. using "
<<
(
after_send
-
before_send
);
VLOG
(
1
)
<<
"run send_op finish. using "
<<
(
before_push
-
before_send
)
<<
"; "
<<
(
after_send
-
before_push
);
}
}
void
GeoCommunicator
::
MainThread
()
{
void
GeoCommunicator
::
MainThread
()
{
...
@@ -532,15 +544,15 @@ void GeoCommunicator::MainThread() {
...
@@ -532,15 +544,15 @@ void GeoCommunicator::MainThread() {
if
(
ids_send_vec_
.
size
()
>=
static_cast
<
size_t
>
(
max_merge_var_num_
))
{
if
(
ids_send_vec_
.
size
()
>=
static_cast
<
size_t
>
(
max_merge_var_num_
))
{
auto
before_send_global_step
=
GetCurrentUS
();
auto
before_send_global_step
=
GetCurrentUS
();
VLOG
(
0
)
<<
"finish ins_send_vec using time "
VLOG
(
1
)
<<
"finish ins_send_vec using time "
<<
(
before_send_global_step
-
before_send_by_communicator
)
<<
(
before_send_global_step
-
before_send_by_communicator
)
<<
"; send_var_nums_ = "
<<
send_var_nums_
;
<<
"; send_var_nums_ = "
<<
send_var_nums_
;
SendGlobalStep
(
max_merge_var_num_
);
SendGlobalStep
(
max_merge_var_num_
);
auto
after_send_global_step
=
GetCurrentUS
();
auto
after_send_global_step
=
GetCurrentUS
();
VLOG
(
0
)
<<
"finish send global_step using "
VLOG
(
1
)
<<
"finish send global_step using "
<<
(
after_send_global_step
-
before_send_global_step
);
<<
(
after_send_global_step
-
before_send_global_step
);
for
(
auto
&
iter
:
send_varname_to_ctx_
)
{
for
(
auto
&
iter
:
send_varname_to_ctx_
)
{
VLOG
(
1
)
<<
"debug "
<<
iter
.
first
;
VLOG
(
2
)
<<
"debug "
<<
iter
.
first
;
auto
&
var_name
=
iter
.
first
;
auto
&
var_name
=
iter
.
first
;
auto
&
send_ctx
=
iter
.
second
;
auto
&
send_ctx
=
iter
.
second
;
int
pserver_num
=
static_cast
<
int
>
(
send_ctx
.
epmap
.
size
());
int
pserver_num
=
static_cast
<
int
>
(
send_ctx
.
epmap
.
size
());
...
@@ -556,11 +568,11 @@ void GeoCommunicator::MainThread() {
...
@@ -556,11 +568,11 @@ void GeoCommunicator::MainThread() {
if
(
var_name
==
STEP_COUNTER
)
{
if
(
var_name
==
STEP_COUNTER
)
{
return
;
return
;
}
}
SendSparse
(
var_name
,
ep_idx
);
// SendSparse(var_name, ep_idx, ids_send_vec_
);
auto
after_send_sparse
=
GetCurrentUS
();
auto
after_send_sparse
=
GetCurrentUS
();
RecvSparse
(
var_name
,
ep_idx
);
RecvSparse
(
var_name
,
ep_idx
);
auto
after_recv_sparse
=
GetCurrentUS
();
auto
after_recv_sparse
=
GetCurrentUS
();
VLOG
(
0
)
VLOG
(
1
)
<<
"send recv "
<<
"send recv "
<<
send_varname_to_ctx_
.
at
(
var_name
).
splited_varnames
[
ep_idx
]
<<
send_varname_to_ctx_
.
at
(
var_name
).
splited_varnames
[
ep_idx
]
<<
" finish, using "
<<
" finish, using "
...
@@ -596,57 +608,60 @@ void GeoCommunicator::MainThread() {
...
@@ -596,57 +608,60 @@ void GeoCommunicator::MainThread() {
ids_send_vec_
.
clear
();
ids_send_vec_
.
clear
();
auto
finish_one_comm
=
GetCurrentUS
();
auto
finish_one_comm
=
GetCurrentUS
();
VLOG
(
0
)
<<
"Finish SendByCommunicator "
VLOG
(
1
)
<<
"Finish SendByCommunicator "
<<
(
finish_one_comm
-
after_send_global_step
);
<<
(
finish_one_comm
-
after_send_global_step
);
}
}
}
}
}
}
void
GeoCommunicator
::
SendSparse
(
const
std
::
string
&
varname
,
int
ep_idx
)
{
void
GeoCommunicator
::
SendSparse
(
std
::
vector
<
int64_t
>
ids
;
const
std
::
string
&
varname
,
int
ep_idx
,
const
std
::
vector
<
SparseIdsMap
>
&
ids_send_vec
)
{
std
::
unordered_set
<
int64_t
>
ids_set
;
auto
debug1
=
GetCurrentUS
();
auto
&
rpc_ctx
=
send_varname_to_ctx_
.
at
(
varname
);
auto
&
rpc_ctx
=
send_varname_to_ctx_
.
at
(
varname
);
VLOG
(
1
)
<<
rpc_ctx
.
print
();
VLOG
(
2
)
<<
rpc_ctx
.
print
();
auto
send_varname
=
rpc_ctx
.
splited_varnames
[
ep_idx
];
auto
send_varname
=
rpc_ctx
.
splited_varnames
[
ep_idx
];
auto
trainer_id
=
rpc_ctx
.
trainer_id
;
auto
trainer_id
=
rpc_ctx
.
trainer_id
;
auto
endpoint
=
rpc_ctx
.
epmap
[
ep_idx
];
auto
endpoint
=
rpc_ctx
.
epmap
[
ep_idx
];
auto
pserver_num
=
rpc_ctx
.
epmap
.
size
();
auto
pserver_num
=
rpc_ctx
.
epmap
.
size
();
for
(
auto
ids_map
:
ids_send_vec_
)
{
int64_t
vector_size
=
0
;
std
::
copy
(
ids_map
[
varname
][
ep_idx
].
begin
(),
ids_map
[
varname
][
ep_idx
].
end
(),
for
(
auto
ids_map
:
ids_send_vec
)
{
back_inserter
(
ids
));
for
(
auto
id
:
ids_map
[
varname
][
ep_idx
])
{
ids_set
.
insert
(
id
);
vector_size
+=
1
;
if
(
vector_size
>
10
)
{
break
;
}
}
if
(
vector_size
>
10
)
{
break
;
}
}
}
VLOG
(
1
)
<<
"ids_vector_size: "
<<
ids
.
size
();
auto
size
=
ids
.
size
();
std
::
set
<
int64_t
>
st
(
ids
.
begin
(),
ids
.
end
());
auto
debug2
=
GetCurrentUS
();
ids
.
assign
(
st
.
begin
(),
st
.
end
());
VLOG
(
1
)
<<
"vector_size: "
<<
vector_size
<<
"; ids_set_size: "
<<
ids_set
.
size
()
<<
"; using time "
<<
(
debug2
-
debug1
);
std
::
stringstream
list_str
;
auto
size
=
ids_set
.
size
();
for
(
uint64_t
i
=
0
;
i
<
ids
.
size
();
i
++
)
{
list_str
<<
ids
[
i
]
<<
","
;
}
VLOG
(
1
)
<<
"SendSparse receive var: "
<<
send_varname
<<
" unset: "
<<
size
<<
" set: "
<<
ids
.
size
()
<<
": "
<<
list_str
.
str
();
if
(
ids
.
empty
()
)
{
if
(
size
==
0
)
{
LOG
(
WARNING
)
<<
"WARNING: GEO has nothing to send, return directly "
;
LOG
(
WARNING
)
<<
"WARNING: GEO has nothing to send, return directly "
;
return
;
return
;
}
}
std
::
vector
<
size_t
>
outs_rows_idx
;
std
::
vector
<
int64_t
>
new_rows
;
new_rows
.
insert
(
new_rows
.
begin
(),
ids_set
.
begin
(),
ids_set
.
end
());
if
(
!
rpc_ctx
.
is_distributed
)
{
// std::stringstream list_str;
for
(
size_t
i
=
0
;
i
<
ids
.
size
();
++
i
)
{
// for (uint64_t i = 0; i < ids.size(); i++) {
auto
id
=
ids
[
i
]
/
pserver_num
;
// list_str << ids[i] << ",";
outs_rows_idx
.
push_back
(
id
);
// }
}
auto
debug3
=
GetCurrentUS
();
}
else
{
VLOG
(
1
)
<<
"SendSparse receive var: "
<<
send_varname
for
(
size_t
i
=
0
;
i
<
ids
.
size
();
++
i
)
{
<<
" set: "
<<
ids_set
.
size
()
<<
", using time "
<<
(
debug3
-
debug1
);
outs_rows_idx
.
push_back
(
ids
[
i
]);
}
}
auto
*
var_latest
=
recv_scope_
->
FindVar
(
varname
);
auto
*
var_latest
=
recv_scope_
->
FindVar
(
varname
);
...
@@ -661,30 +676,35 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
...
@@ -661,30 +676,35 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
auto
*
var_delta
=
delta_scope_
->
Var
(
send_varname
);
auto
*
var_delta
=
delta_scope_
->
Var
(
send_varname
);
auto
*
t_delta
=
var_delta
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
*
t_delta
=
var_delta
->
GetMutable
<
framework
::
SelectedRows
>
();
t_delta
->
set_height
(
rpc_ctx
.
height_sections
[
ep_idx
]);
t_delta
->
mutable_rows
()
->
assign
(
outs_rows_idx
.
begin
(),
outs_rows_idx
.
end
());
auto
*
t_value
=
t_delta
->
mutable_value
();
auto
*
t_value
=
t_delta
->
mutable_value
();
t_value
->
mutable_data
<
float
>
(
t_value
->
mutable_data
<
float
>
(
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
id
s
.
size
()),
dims1
}),
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
new_row
s
.
size
()),
dims1
}),
cpu_ctx
.
GetPlace
());
cpu_ctx
.
GetPlace
());
std
::
vector
<
std
::
vector
<
std
::
vector
<
float
>
*>>
values
;
std
::
vector
<
std
::
vector
<
std
::
vector
<
float
>
*>>
values
;
auto
*
ins
=
distributed
::
LargeScaleKV
::
GetInstance
();
auto
*
ins
=
distributed
::
LargeScaleKV
::
GetInstance
();
ins
->
Get
(
varname
)
->
Get
(
id
s
,
{
"Param"
},
&
values
);
ins
->
Get
(
varname
)
->
Get
(
new_row
s
,
{
"Param"
},
&
values
);
auto
blas
=
math
::
GetBlas
<
platform
::
CPUDeviceContext
,
float
>
(
cpu_ctx
);
auto
blas
=
math
::
GetBlas
<
platform
::
CPUDeviceContext
,
float
>
(
cpu_ctx
);
float
coefficient
=
1.0
/
static_cast
<
float
>
(
trainers_
);
float
coefficient
=
1.0
/
static_cast
<
float
>
(
trainers_
);
for
(
auto
j
=
0
;
j
<
static_cast
<
int
>
(
id
s
.
size
());
++
j
)
{
for
(
auto
j
=
0
;
j
<
static_cast
<
int
>
(
new_row
s
.
size
());
++
j
)
{
blas
.
VSUB
(
dims1
,
t_latest
.
data
<
float
>
()
+
id
s
[
j
]
*
dims1
,
blas
.
VSUB
(
dims1
,
t_latest
.
data
<
float
>
()
+
new_row
s
[
j
]
*
dims1
,
values
[
j
][
0
]
->
data
(),
t_value
->
data
<
float
>
()
+
j
*
dims1
);
values
[
j
][
0
]
->
data
(),
t_value
->
data
<
float
>
()
+
j
*
dims1
);
blas
.
SCAL
(
dims1
,
coefficient
,
t_value
->
data
<
float
>
()
+
j
*
dims1
);
blas
.
SCAL
(
dims1
,
coefficient
,
t_value
->
data
<
float
>
()
+
j
*
dims1
);
blas
.
VADD
(
dims1
,
values
[
j
][
0
]
->
data
(),
t_value
->
data
<
float
>
()
+
j
*
dims1
,
blas
.
VADD
(
dims1
,
values
[
j
][
0
]
->
data
(),
t_value
->
data
<
float
>
()
+
j
*
dims1
,
values
[
j
][
0
]
->
data
());
values
[
j
][
0
]
->
data
());
}
}
VLOG
(
1
)
<<
"begin to real send "
<<
send_varname
;
std
::
vector
<
int64_t
>
send_rows
;
send_rows
.
reserve
(
new_rows
.
size
());
for
(
auto
idx
:
new_rows
)
{
send_rows
.
push_back
(
idx
/
pserver_num
);
}
t_delta
->
set_height
(
rpc_ctx
.
height_sections
[
ep_idx
]);
t_delta
->
set_rows
(
send_rows
);
VLOG
(
2
)
<<
"begin to real send "
<<
send_varname
;
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
cpu_ctx_send
=
*
pool
.
Get
(
platform
::
CPUPlace
());
auto
&
cpu_ctx_send
=
*
pool
.
Get
(
platform
::
CPUPlace
());
distributed
::
RPCClient
*
rpc_client
=
distributed
::
RPCClient
*
rpc_client
=
...
@@ -692,9 +712,9 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
...
@@ -692,9 +712,9 @@ void GeoCommunicator::SendSparse(const std::string &varname, int ep_idx) {
auto
ret
=
rpc_client
->
AsyncSendVar
(
endpoint
,
cpu_ctx_send
,
auto
ret
=
rpc_client
->
AsyncSendVar
(
endpoint
,
cpu_ctx_send
,
*
delta_scope_
.
get
(),
send_varname
);
*
delta_scope_
.
get
(),
send_varname
);
VLOG
(
1
)
<<
"need to wait for send "
<<
send_varname
;
VLOG
(
2
)
<<
"need to wait for send "
<<
send_varname
;
ret
->
Wait
();
ret
->
Wait
();
VLOG
(
1
)
<<
"finish to send "
<<
send_varname
;
VLOG
(
2
)
<<
"finish to send "
<<
send_varname
;
}
}
void
GeoCommunicator
::
SendDense
(
const
std
::
string
&
varname
)
{
void
GeoCommunicator
::
SendDense
(
const
std
::
string
&
varname
)
{
...
@@ -740,7 +760,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
...
@@ -740,7 +760,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
recv_varname_to_ctx_
.
at
(
varname
).
splited_varnames
[
ep_idx
];
recv_varname_to_ctx_
.
at
(
varname
).
splited_varnames
[
ep_idx
];
auto
pserver_num
=
recv_varname_to_ctx_
.
at
(
varname
).
epmap
.
size
();
auto
pserver_num
=
recv_varname_to_ctx_
.
at
(
varname
).
epmap
.
size
();
VLOG
(
1
)
<<
"Begin to RecvSparse receive var: "
<<
splited_var_name
;
VLOG
(
2
)
<<
"Begin to RecvSparse receive var: "
<<
splited_var_name
;
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
cpu_ctx_recv
=
*
pool
.
Get
(
platform
::
CPUPlace
());
auto
&
cpu_ctx_recv
=
*
pool
.
Get
(
platform
::
CPUPlace
());
...
@@ -753,7 +773,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
...
@@ -753,7 +773,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
splited_var_name
,
splited_var_name
);
splited_var_name
,
splited_var_name
);
handle
->
Wait
();
handle
->
Wait
();
VLOG
(
1
)
<<
"Finish to RecvSparse receive var: "
<<
splited_var_name
;
VLOG
(
2
)
<<
"Finish to RecvSparse receive var: "
<<
splited_var_name
;
auto
*
var_latest
=
recv_scope_
->
FindVar
(
varname
);
auto
*
var_latest
=
recv_scope_
->
FindVar
(
varname
);
...
@@ -766,7 +786,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
...
@@ -766,7 +786,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
ids
.
assign
(
var_psrever
->
Get
<
framework
::
SelectedRows
>
().
rows
().
begin
(),
ids
.
assign
(
var_psrever
->
Get
<
framework
::
SelectedRows
>
().
rows
().
begin
(),
var_psrever
->
Get
<
framework
::
SelectedRows
>
().
rows
().
end
());
var_psrever
->
Get
<
framework
::
SelectedRows
>
().
rows
().
end
());
VLOG
(
1
)
<<
"RecvSparse receive var: "
<<
splited_var_name
VLOG
(
2
)
<<
"RecvSparse receive var: "
<<
splited_var_name
<<
" ids Size: "
<<
ids
.
size
();
<<
" ids Size: "
<<
ids
.
size
();
auto
t_psrever
=
var_psrever
->
Get
<
framework
::
SelectedRows
>
().
value
();
auto
t_psrever
=
var_psrever
->
Get
<
framework
::
SelectedRows
>
().
value
();
...
@@ -796,7 +816,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
...
@@ -796,7 +816,7 @@ void GeoCommunicator::RecvSparse(const std::string &varname, int ep_idx) {
blas
.
VCOPY
(
dims1
,
t_psrever
.
data
<
float
>
()
+
j
*
dims1
,
blas
.
VCOPY
(
dims1
,
t_psrever
.
data
<
float
>
()
+
j
*
dims1
,
old_values
[
j
][
0
]
->
data
());
old_values
[
j
][
0
]
->
data
());
}
}
VLOG
(
1
)
<<
"receive finish"
;
VLOG
(
2
)
<<
"receive finish"
;
}
}
void
GeoCommunicator
::
RecvDense
(
const
std
::
string
&
varname
)
{
void
GeoCommunicator
::
RecvDense
(
const
std
::
string
&
varname
)
{
...
...
paddle/fluid/operators/distributed/communicator.h
浏览文件 @
b20f528b
...
@@ -426,7 +426,8 @@ class GeoCommunicator : public AsyncCommunicator {
...
@@ -426,7 +426,8 @@ class GeoCommunicator : public AsyncCommunicator {
// void SendByCommunicator(int batches) override;
// void SendByCommunicator(int batches) override;
void
SendSparse
(
const
std
::
string
&
varname
,
int
ep_idx
);
void
SendSparse
(
const
std
::
string
&
varname
,
int
ep_idx
,
const
std
::
vector
<
SparseIdsMap
>
&
ids_send_vec
);
void
SendDense
(
const
std
::
string
&
varname
);
void
SendDense
(
const
std
::
string
&
varname
);
...
...
python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py
浏览文件 @
b20f528b
...
@@ -169,16 +169,24 @@ def append_send_ops_pass(program, config, merge=False):
...
@@ -169,16 +169,24 @@ def append_send_ops_pass(program, config, merge=False):
trainer_id
=
config
.
get_role_id
()
trainer_id
=
config
.
get_role_id
()
pserver_endpoints
=
config
.
get_ps_endpoints
()
pserver_endpoints
=
config
.
get_ps_endpoints
()
def
_append_send_op
(
union_vars
,
queue
):
def
_append_send_op
():
send_input_vars
=
[]
sparse_var
=
[]
assert
(
len
(
queue
)
==
len
(
union_vars
))
sparse_tables
=
[]
for
i
in
range
(
len
(
queue
)):
unique_sparse_var
=
{}
if
queue
[
i
]
==
STEP_COUNTER
:
for
op
in
program
.
global_block
().
ops
:
send_input_vars
.
append
(
""
)
if
"is_sparse"
in
op
.
all_attrs
():
else
:
if
op
.
type
==
"lookup_table"
:
send_input_vars
.
append
(
program
.
global_block
().
vars
[
union_vars
[
op
.
_set_attr
(
'remote_prefetch'
,
False
)
i
]])
for
input_var_name
,
sparse_var_name
in
zip
(
op
.
input
(
"Ids"
),
op
.
input
(
"W"
)):
if
input_var_name
in
unique_sparse_var
:
if
unique_sparse_var
[
input_var_name
]
==
sparse_var_name
:
continue
input_var
=
program
.
global_block
().
var
(
input_var_name
)
sparse_var
.
append
(
input_var
)
sparse_tables
.
append
(
sparse_var_name
)
unique_sparse_var
[
input_var_name
]
=
sparse_var_name
dummy_output
=
[]
dummy_output
=
[]
if
mode
in
[
DistributedMode
.
SYNC
,
DistributedMode
.
HALF_ASYNC
]:
if
mode
in
[
DistributedMode
.
SYNC
,
DistributedMode
.
HALF_ASYNC
]:
...
@@ -187,10 +195,10 @@ def append_send_ops_pass(program, config, merge=False):
...
@@ -187,10 +195,10 @@ def append_send_ops_pass(program, config, merge=False):
program
.
global_block
().
append_op
(
program
.
global_block
().
append_op
(
type
=
"send"
,
type
=
"send"
,
inputs
=
{
"X"
:
s
end_input_vars
},
inputs
=
{
"X"
:
s
parse_var
},
outputs
=
{
"Out"
:
dummy_output
},
outputs
=
{
"Out"
:
dummy_output
},
attrs
=
{
attrs
=
{
"send_varnames"
:
queue
,
"send_varnames"
:
sparse_tables
,
"merge_add"
:
True
,
"merge_add"
:
True
,
"use_send_handler"
:
False
,
"use_send_handler"
:
False
,
"endpoints"
:
pserver_endpoints
,
"endpoints"
:
pserver_endpoints
,
...
@@ -216,17 +224,10 @@ def append_send_ops_pass(program, config, merge=False):
...
@@ -216,17 +224,10 @@ def append_send_ops_pass(program, config, merge=False):
sends
=
config
.
get_trainer_send_context
()
sends
=
config
.
get_trainer_send_context
()
if
merge
:
if
merge
:
origin_varnames
=
[]
dummys
.
append
(
_append_send_op
())
merged_names
=
[]
for
merged_name
,
send
in
sends
.
items
():
for
var
in
send
.
origin_varnames
():
origin_varnames
.
append
(
var
)
merged_names
.
append
(
merged_name
)
if
len
(
origin_varnames
)
>
0
:
dummys
.
append
(
_append_send_op
(
origin_varnames
,
merged_names
))
else
:
else
:
for
merged_name
,
send
in
sends
.
items
():
for
merged_name
,
send
in
sends
.
items
():
dummys
.
append
(
_append_send_op
(
send
.
origin_varnames
(),
merged_name
))
dummys
.
append
(
_append_send_op
())
if
mode
in
[
DistributedMode
.
SYNC
,
DistributedMode
.
HALF_ASYNC
]:
if
mode
in
[
DistributedMode
.
SYNC
,
DistributedMode
.
HALF_ASYNC
]:
_append_barrier_op
(
dummys
)
_append_barrier_op
(
dummys
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录