Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
5f89ce7f
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5f89ce7f
编写于
1月 27, 2019
作者:
乔
乔龙飞 Qiao Longfei
提交者:
GitHub
1月 27, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #15536 from jacquesqiao/fix-prefetch-one-parameter
Fix prefetch one parameter
上级
d303270a
806658d7
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
34 addition
and
16 deletion
+34
-16
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+5
-0
paddle/fluid/operators/distributed/rpc_server.cc
paddle/fluid/operators/distributed/rpc_server.cc
+22
-15
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
+7
-1
未找到文件。
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
5f89ce7f
...
@@ -54,6 +54,11 @@ bool RequestSendHandler::Handle(const std::string& varname,
...
@@ -54,6 +54,11 @@ bool RequestSendHandler::Handle(const std::string& varname,
// Async
// Async
if
(
!
sync_mode_
)
{
if
(
!
sync_mode_
)
{
VLOG
(
3
)
<<
"async process var: "
<<
varname
;
VLOG
(
3
)
<<
"async process var: "
<<
varname
;
if
(
varname
==
BATCH_BARRIER_MESSAGE
)
{
PADDLE_THROW
(
"async mode should not recv BATCH_BARRIER_MESSAGE or "
"COMPLETE_MESSAGE"
);
}
try
{
try
{
executor_
->
RunPreparedContext
((
*
grad_to_prepared_ctx_
)[
varname
].
get
(),
executor_
->
RunPreparedContext
((
*
grad_to_prepared_ctx_
)[
varname
].
get
(),
scope
);
scope
);
...
...
paddle/fluid/operators/distributed/rpc_server.cc
浏览文件 @
5f89ce7f
...
@@ -39,27 +39,33 @@ void RPCServer::SavePort() const {
...
@@ -39,27 +39,33 @@ void RPCServer::SavePort() const {
port_file
.
open
(
file_path
);
port_file
.
open
(
file_path
);
port_file
<<
selected_port_
;
port_file
<<
selected_port_
;
port_file
.
close
();
port_file
.
close
();
VLOG
(
4
)
<<
"selected port written to "
<<
file_path
;
VLOG
(
3
)
<<
"selected port written to "
<<
file_path
;
}
}
void
RPCServer
::
WaitBarrier
(
const
std
::
string
&
rpc_name
)
{
void
RPCServer
::
WaitBarrier
(
const
std
::
string
&
rpc_name
)
{
VLOG
(
3
)
<<
"WaitBarrier in: "
<<
rpc_name
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_
);
barrier_cond_
.
wait
(
lock
,
[
this
,
&
rpc_name
]
{
barrier_cond_
.
wait
(
lock
,
[
this
,
&
rpc_name
]
{
return
((
barrier_counter_
[
rpc_name
]
==
client_num_
&&
client_num_
!=
0
)
||
return
((
barrier_counter_
[
rpc_name
]
==
client_num_
&&
client_num_
!=
0
)
||
exit_flag_
.
load
());
exit_flag_
.
load
());
});
});
VLOG
(
3
)
<<
"
batch_barrier_: "
<<
rpc_name
<<
" "
VLOG
(
3
)
<<
"
WaitBarrier out: "
<<
rpc_name
<<
barrier_counter_
[
rpc_name
];
<<
" counter: "
<<
barrier_counter_
[
rpc_name
];
}
}
void
RPCServer
::
IncreaseBatchBarrier
(
const
std
::
string
rpc_name
)
{
void
RPCServer
::
IncreaseBatchBarrier
(
const
std
::
string
rpc_name
)
{
VLOG
(
4
)
<<
"RPCServer begin IncreaseBatchBarrier "
<<
rpc_name
;
VLOG
(
3
)
<<
"RPCServer begin IncreaseBatchBarrier "
<<
rpc_name
;
// barrier msg should make sure that it's in the right cond(send|recv)
WaitCond
(
rpc_name
);
int
b
=
0
;
int
b
=
0
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
b
=
++
barrier_counter_
[
rpc_name
];
b
=
++
barrier_counter_
[
rpc_name
];
VLOG
(
3
)
<<
rpc_name
<<
" barrier_counter: "
<<
b
;
if
(
b
>=
client_num_
)
{
if
(
b
>=
client_num_
)
{
lock
.
unlock
();
lock
.
unlock
();
VLOG
(
3
)
<<
"BatchBarrier counter reach "
<<
client_num_
<<
" for "
<<
rpc_name
;
barrier_cond_
.
notify_all
();
barrier_cond_
.
notify_all
();
lock
.
lock
();
lock
.
lock
();
}
}
...
@@ -71,7 +77,7 @@ void RPCServer::Complete() {
...
@@ -71,7 +77,7 @@ void RPCServer::Complete() {
client_num_
--
;
client_num_
--
;
need_reset_all_vars_
=
true
;
need_reset_all_vars_
=
true
;
VLOG
(
4
)
<<
"decrease client_num to: "
<<
client_num_
;
VLOG
(
3
)
<<
"decrease client_num to: "
<<
client_num_
;
if
(
cur_cond_
.
load
()
==
rpc_cond_map_
[
kRequestGet
])
{
if
(
cur_cond_
.
load
()
==
rpc_cond_map_
[
kRequestGet
])
{
barrier_counter_
[
kRequestGet
]
--
;
barrier_counter_
[
kRequestGet
]
--
;
}
}
...
@@ -105,8 +111,8 @@ void RPCServer::RegisterRPC(const std::string& rpc_name,
...
@@ -105,8 +111,8 @@ void RPCServer::RegisterRPC(const std::string& rpc_name,
static
int
cond
=
-
1
;
static
int
cond
=
-
1
;
rpc_cond_map_
[
rpc_name
]
=
++
cond
;
rpc_cond_map_
[
rpc_name
]
=
++
cond
;
VLOG
(
4
)
<<
"RegisterRPC rpc_name:"
<<
rpc_name
<<
", handler:
"
<<
handler
VLOG
(
3
)
<<
"RegisterRPC rpc_name: "
<<
rpc_name
<<
", handler:
"
<<
handler
<<
", cond:"
<<
rpc_cond_map_
[
rpc_name
];
<<
", cond:
"
<<
rpc_cond_map_
[
rpc_name
];
}
}
void
RPCServer
::
SetCond
(
const
std
::
string
&
rpc_name
)
{
void
RPCServer
::
SetCond
(
const
std
::
string
&
rpc_name
)
{
...
@@ -120,7 +126,7 @@ void RPCServer::SetCond(const std::string& rpc_name) {
...
@@ -120,7 +126,7 @@ void RPCServer::SetCond(const std::string& rpc_name) {
}
}
void
RPCServer
::
WaitCond
(
const
std
::
string
&
rpc_name
)
{
void
RPCServer
::
WaitCond
(
const
std
::
string
&
rpc_name
)
{
VLOG
(
4
)
<<
"RPCServer WaitCond
"
<<
rpc_name
;
VLOG
(
3
)
<<
"RPCServer WaitCond in
"
<<
rpc_name
;
int
cond
=
0
;
int
cond
=
0
;
{
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
...
@@ -130,6 +136,7 @@ void RPCServer::WaitCond(const std::string& rpc_name) {
...
@@ -130,6 +136,7 @@ void RPCServer::WaitCond(const std::string& rpc_name) {
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
rpc_cond_
.
wait
(
rpc_cond_
.
wait
(
lock
,
[
=
]
{
return
(
cur_cond_
.
load
()
==
cond
||
exit_flag_
.
load
());
});
lock
,
[
=
]
{
return
(
cur_cond_
.
load
()
==
cond
||
exit_flag_
.
load
());
});
VLOG
(
3
)
<<
"RPCServer WaitCond out "
<<
rpc_name
;
}
}
void
RPCServer
::
RegisterVar
(
const
std
::
string
&
var_name
,
void
RPCServer
::
RegisterVar
(
const
std
::
string
&
var_name
,
...
@@ -151,7 +158,7 @@ void RPCServer::RegisterVar(const std::string& var_name,
...
@@ -151,7 +158,7 @@ void RPCServer::RegisterVar(const std::string& var_name,
}
}
rpc_cond_
.
notify_all
();
rpc_cond_
.
notify_all
();
VLOG
(
4
)
<<
"RegisterVar context:"
<<
h
.
String
();
VLOG
(
3
)
<<
"RegisterVar context:"
<<
h
.
String
();
}
}
void
RPCServer
::
IncreaseVarBarrier
(
const
std
::
string
&
var_name
)
{
void
RPCServer
::
IncreaseVarBarrier
(
const
std
::
string
&
var_name
)
{
...
@@ -167,11 +174,11 @@ void RPCServer::IncreaseVarBarrier(const std::string& var_name) {
...
@@ -167,11 +174,11 @@ void RPCServer::IncreaseVarBarrier(const std::string& var_name) {
barrier_cond_
.
notify_all
();
barrier_cond_
.
notify_all
();
}
}
VLOG
(
4
)
<<
"IncreaseVarBarrier context:"
<<
h
.
String
();
VLOG
(
3
)
<<
"IncreaseVarBarrier context:"
<<
h
.
String
();
}
}
void
RPCServer
::
WaitVarBarrier
(
const
std
::
string
&
var_name
)
{
void
RPCServer
::
WaitVarBarrier
(
const
std
::
string
&
var_name
)
{
VLOG
(
4
)
<<
"Wait
Barrier var_name:"
<<
var_name
;
VLOG
(
3
)
<<
"WaitVar
Barrier var_name:"
<<
var_name
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
barrier_cond_
.
wait
(
lock
,
[
&
]()
{
barrier_cond_
.
wait
(
lock
,
[
&
]()
{
...
@@ -179,11 +186,11 @@ void RPCServer::WaitVarBarrier(const std::string& var_name) {
...
@@ -179,11 +186,11 @@ void RPCServer::WaitVarBarrier(const std::string& var_name) {
exit_flag_
.
load
());
exit_flag_
.
load
());
});
});
VLOG
(
4
)
<<
"Wait
Barrier context: "
<<
var_map_
[
var_name
].
String
();
VLOG
(
3
)
<<
"WaitVar
Barrier context: "
<<
var_map_
[
var_name
].
String
();
}
}
void
RPCServer
::
SetVarCond
(
const
std
::
string
&
var_name
)
{
void
RPCServer
::
SetVarCond
(
const
std
::
string
&
var_name
)
{
VLOG
(
4
)
<<
"SetVarCond var_name:"
<<
var_name
;
VLOG
(
3
)
<<
"SetVarCond var_name:"
<<
var_name
;
{
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
if
(
var_map_
.
find
(
var_name
)
!=
var_map_
.
end
())
{
if
(
var_map_
.
find
(
var_name
)
!=
var_map_
.
end
())
{
...
@@ -193,14 +200,14 @@ void RPCServer::SetVarCond(const std::string& var_name) {
...
@@ -193,14 +200,14 @@ void RPCServer::SetVarCond(const std::string& var_name) {
}
}
void
RPCServer
::
WaitVarCond
(
const
std
::
string
&
var_name
)
{
void
RPCServer
::
WaitVarCond
(
const
std
::
string
&
var_name
)
{
VLOG
(
4
)
<<
"WaitVarCond var_name:"
<<
var_name
;
VLOG
(
3
)
<<
"WaitVarCond var_name:"
<<
var_name
;
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
rpc_cond_
.
wait
(
lock
,
[
=
]
{
rpc_cond_
.
wait
(
lock
,
[
=
]
{
return
(
var_map_
.
find
(
var_name
)
!=
var_map_
.
end
()
||
exit_flag_
.
load
());
return
(
var_map_
.
find
(
var_name
)
!=
var_map_
.
end
()
||
exit_flag_
.
load
());
});
});
VLOG
(
4
)
<<
"WaitVarCond var_name:"
<<
var_name
<<
" end"
;
VLOG
(
3
)
<<
"WaitVarCond var_name:"
<<
var_name
<<
" end"
;
}
}
MonomerHandle
RPCServer
::
GetMonomer
(
const
std
::
string
&
var_name
)
{
MonomerHandle
RPCServer
::
GetMonomer
(
const
std
::
string
&
var_name
)
{
...
...
paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc
浏览文件 @
5f89ce7f
...
@@ -137,7 +137,9 @@ void ListenAndServOp::RunSyncLoop(
...
@@ -137,7 +137,9 @@ void ListenAndServOp::RunSyncLoop(
while
(
true
)
{
while
(
true
)
{
// Get from multiple trainers, we don't care about the order in which
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
// the gradients arrives, just add suffix 0~n and merge the gradient.
VLOG
(
3
)
<<
"wait all clients to send gradient"
;
rpc_service_
->
SetCond
(
distributed
::
kRequestSend
);
rpc_service_
->
SetCond
(
distributed
::
kRequestSend
);
VLOG
(
3
)
<<
"wait all clients to send send_barrier"
;
rpc_service_
->
WaitBarrier
(
distributed
::
kRequestSend
);
rpc_service_
->
WaitBarrier
(
distributed
::
kRequestSend
);
if
(
rpc_service_
->
IsExit
())
{
if
(
rpc_service_
->
IsExit
())
{
...
@@ -168,12 +170,16 @@ void ListenAndServOp::RunSyncLoop(
...
@@ -168,12 +170,16 @@ void ListenAndServOp::RunSyncLoop(
}
}
ParallelExecuteBlocks
(
parallel_blkids
,
executor
,
optimize_prepared
,
program
,
ParallelExecuteBlocks
(
parallel_blkids
,
executor
,
optimize_prepared
,
program
,
recv_scope
);
recv_scope
);
VLOG
(
2
)
<<
"run all blocks spent "
<<
GetTimestamp
()
-
ts
<<
"(ms)"
;
VLOG
(
3
)
<<
"run all blocks spent "
<<
GetTimestamp
()
-
ts
<<
"(ms)"
;
VLOG
(
3
)
<<
"ResetReceivedVars"
;
ResetReceivedVars
(
recv_scope
,
dev_ctx
,
rpc_service_
->
NeedResetAllVars
());
ResetReceivedVars
(
recv_scope
,
dev_ctx
,
rpc_service_
->
NeedResetAllVars
());
VLOG
(
3
)
<<
"wait all clients to get parameters back"
;
rpc_service_
->
SetCond
(
distributed
::
kRequestGet
);
rpc_service_
->
SetCond
(
distributed
::
kRequestGet
);
VLOG
(
3
)
<<
"wait all clients to send fetch_barrier"
;
rpc_service_
->
WaitBarrier
(
distributed
::
kRequestGet
);
rpc_service_
->
WaitBarrier
(
distributed
::
kRequestGet
);
VLOG
(
3
)
<<
"ResetBarrierCounter"
;
rpc_service_
->
ResetBarrierCounter
();
rpc_service_
->
ResetBarrierCounter
();
}
// while(true)
}
// while(true)
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录