Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
036a90f1
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
036a90f1
编写于
6月 05, 2018
作者:
W
Wu Yi
提交者:
gongweibao
6月 05, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine rpc client wait sync (#11132)
上级
a3858036
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
80 addition
and
90 deletion
+80
-90
paddle/fluid/operators/detail/grpc_client.cc
paddle/fluid/operators/detail/grpc_client.cc
+41
-55
paddle/fluid/operators/detail/grpc_client.h
paddle/fluid/operators/detail/grpc_client.h
+18
-5
paddle/fluid/operators/detail/grpc_server.cc
paddle/fluid/operators/detail/grpc_server.cc
+1
-6
paddle/fluid/operators/detail/grpc_server_test.cc
paddle/fluid/operators/detail/grpc_server_test.cc
+4
-7
paddle/fluid/operators/fetch_barrier_op.cc
paddle/fluid/operators/fetch_barrier_op.cc
+2
-2
paddle/fluid/operators/prefetch_op.cc
paddle/fluid/operators/prefetch_op.cc
+1
-1
paddle/fluid/operators/recv_op.cc
paddle/fluid/operators/recv_op.cc
+1
-1
paddle/fluid/operators/send_barrier_op.cc
paddle/fluid/operators/send_barrier_op.cc
+2
-2
paddle/fluid/operators/send_op.cc
paddle/fluid/operators/send_op.cc
+4
-4
paddle/fluid/operators/test_send_nccl_id.cc
paddle/fluid/operators/test_send_nccl_id.cc
+6
-7
未找到文件。
paddle/fluid/operators/detail/grpc_client.cc
浏览文件 @
036a90f1
...
...
@@ -38,6 +38,25 @@ void RPCClient::Init() {
if
(
rpc_client_
.
get
()
==
nullptr
)
{
rpc_client_
.
reset
(
new
RPCClient
());
}
rpc_client_
->
InitEventLoop
();
}
void
RPCClient
::
InitEventLoop
()
{
// start the client process thread
// TODO(wuyi): can make this in a threadpool
client_thread_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
RPCClient
::
Proceed
,
this
)));
}
RPCClient
::~
RPCClient
()
{
Wait
();
cq_
.
Shutdown
();
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
chan_mutex_
);
for
(
auto
&
it
:
channels_
)
{
it
.
second
.
reset
();
}
}
client_thread_
->
join
();
}
bool
RPCClient
::
AsyncSendVariable
(
const
std
::
string
&
ep
,
...
...
@@ -204,70 +223,37 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
req_count_
++
;
}
bool
RPCClient
::
Wait
()
{
VLOG
(
3
)
<<
"RPCClient begin Wait()"
<<
" req_count_:"
<<
req_count_
;
if
(
req_count_
<=
0
)
{
return
true
;
}
const
size_t
kReqCnt
=
req_count_
;
bool
a
[
kReqCnt
];
std
::
vector
<
std
::
future
<
void
>>
waits
(
req_count_
);
std
::
mutex
mu
;
for
(
int
i
=
0
;
i
<
req_count_
;
i
++
)
{
waits
[
i
]
=
framework
::
AsyncIO
([
i
,
&
a
,
&
mu
,
this
]
{
bool
ret
=
Proceed
();
std
::
lock_guard
<
std
::
mutex
>
l
(
mu
);
a
[
i
]
=
ret
;
});
}
for
(
int
i
=
0
;
i
<
req_count_
;
i
++
)
{
waits
[
i
].
wait
();
}
int
last_req_count
=
req_count_
;
req_count_
=
0
;
for
(
int
i
=
0
;
i
<
last_req_count
;
i
++
)
{
if
(
!
a
[
i
])
{
return
false
;
}
}
return
true
;
void
RPCClient
::
Wait
()
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
sync_mutex_
);
sync_cond_
.
wait
(
lk
,
[
this
]
{
return
req_count_
==
0
;
});
}
bool
RPCClient
::
Proceed
()
{
void
*
tag
=
NULL
;
void
RPCClient
::
Proceed
()
{
void
*
tag
=
nullptr
;
bool
ok
=
false
;
// request counts.
if
(
!
cq_
.
Next
(
&
tag
,
&
ok
))
{
LOG
(
ERROR
)
<<
"Get meets CompletionQueue error"
;
return
false
;
}
GPR_ASSERT
(
ok
);
PADDLE_ENFORCE
(
tag
);
// TODO(gongwb): add more retries.
BaseProcessor
*
c
=
static_cast
<
BaseProcessor
*>
(
tag
);
if
(
!
c
->
status_
.
ok
())
{
LOG
(
ERROR
)
<<
"proc param error:"
<<
c
->
var_h_
.
String
()
<<
" grpc error:"
<<
c
->
status_
.
error_message
();
while
(
cq_
.
Next
(
&
tag
,
&
ok
))
{
BaseProcessor
*
c
=
static_cast
<
BaseProcessor
*>
(
tag
);
GPR_ASSERT
(
ok
);
PADDLE_ENFORCE
(
c
);
if
(
c
->
status_
.
ok
())
{
c
->
Process
();
}
else
{
LOG
(
ERROR
)
<<
"var: "
<<
c
->
var_h_
.
String
()
<<
" grpc error:"
<<
c
->
status_
.
error_message
();
}
delete
c
;
return
false
;
{
std
::
lock_guard
<
std
::
mutex
>
lk
(
sync_mutex_
);
req_count_
--
;
}
sync_cond_
.
notify_all
();
}
c
->
Process
();
delete
c
;
return
true
;
}
std
::
shared_ptr
<
grpc
::
Channel
>
RPCClient
::
GetChannel
(
const
std
::
string
&
ep
)
{
// TODO(Yancey1989): make grpc client completely thread-safe
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
std
::
lock_guard
<
std
::
mutex
>
guard
(
chan_
mutex_
);
auto
it
=
channels_
.
find
(
ep
);
if
(
it
!=
channels_
.
end
())
{
return
it
->
second
;
...
...
paddle/fluid/operators/detail/grpc_client.h
浏览文件 @
036a90f1
...
...
@@ -16,15 +16,18 @@ limitations under the License. */
#include <time.h>
#include <chrono> // NOLINT
#include <chrono> // NOLINT
#include <condition_variable> // NOLINT
#include <ctime>
#include <functional>
#include <iostream>
#include <map>
#include <mutex> // NOLINT
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "grpc++/channel.h"
#include "grpc++/generic/generic_stub.h"
#include "grpc++/grpc++.h"
#include "grpc++/support/byte_buffer.h"
...
...
@@ -164,6 +167,7 @@ class FetchBarrierProcessor : public BaseProcessor {
class
RPCClient
{
public:
RPCClient
()
{}
~
RPCClient
();
static
RPCClient
*
GetInstance
();
...
...
@@ -192,19 +196,28 @@ class RPCClient {
void
AsyncSendFetchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
=
600
*
1000
);
bool
Wait
();
void
Wait
();
// InitEventLoop should only be called by Init()
void
InitEventLoop
();
private:
bool
Proceed
();
void
Proceed
();
std
::
shared_ptr
<
grpc
::
Channel
>
GetChannel
(
const
std
::
string
&
ep
);
// Init is called by GetInstance.
static
void
Init
();
private:
grpc
::
CompletionQueue
cq_
;
std
::
map
<
std
::
string
,
std
::
shared_ptr
<
grpc
::
Channel
>>
channels_
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
grpc
::
Channel
>>
channels_
;
std
::
unique_ptr
<
std
::
thread
>
client_thread_
;
// mutex for Wait client sync
std
::
mutex
sync_mutex_
;
std
::
condition_variable
sync_cond_
;
std
::
atomic
<
int64_t
>
req_count_
{
0
};
std
::
mutex
mutex_
;
// mutex for GetChannel thread safety
std
::
mutex
chan_mutex_
;
static
std
::
unique_ptr
<
RPCClient
>
rpc_client_
;
static
std
::
once_flag
init_flag_
;
DISABLE_COPY_AND_ASSIGN
(
RPCClient
);
...
...
paddle/fluid/operators/detail/grpc_server.cc
浏览文件 @
036a90f1
...
...
@@ -68,9 +68,7 @@ class RequestSend final : public RequestBase {
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
)));
}
virtual
~
RequestSend
()
{}
std
::
string
GetReqName
()
override
{
return
request_
->
Varname
();
}
void
Process
()
override
{
...
...
@@ -82,7 +80,6 @@ class RequestSend final : public RequestBase {
framework
::
Variable
*
outvar
=
nullptr
;
request_handler_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
);
status_
=
FINISH
;
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id_
)));
...
...
@@ -125,7 +122,6 @@ class RequestGet final : public RequestBase {
SerializeToByteBuffer
(
varname
,
outvar
,
*
request_handler_
->
dev_ctx
(),
&
reply_
);
}
status_
=
FINISH
;
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id_
)));
...
...
@@ -170,10 +166,9 @@ class RequestPrefetch final : public RequestBase {
SerializeToByteBuffer
(
varname
,
outvar
,
*
request_handler_
->
dev_ctx
(),
&
reply_
);
status_
=
FINISH
;
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id_
)));
status_
=
FINISH
;
}
protected:
...
...
paddle/fluid/operators/detail/grpc_server_test.cc
浏览文件 @
036a90f1
...
...
@@ -113,10 +113,6 @@ void StartServer() {
std
::
thread
server_thread
(
std
::
bind
(
&
detail
::
AsyncGRPCServer
::
StartServer
,
g_rpc_service
.
get
()));
// FIXME(gongwb): don't use hard time.
sleep
(
10
);
LOG
(
INFO
)
<<
"got nccl id and stop server..."
;
g_rpc_service
->
ShutDown
();
server_thread
.
join
();
}
...
...
@@ -127,7 +123,7 @@ TEST(PREFETCH, CPU) {
std
::
thread
server_thread
(
StartServer
);
g_rpc_service
->
WaitServerReady
();
detail
::
RPCClient
client
;
detail
::
RPCClient
*
client
=
detail
::
RPCClient
::
GetInstance
()
;
int
port
=
g_rpc_service
->
GetSelectedPort
();
std
::
string
ep
=
paddle
::
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
...
...
@@ -141,8 +137,8 @@ TEST(PREFETCH, CPU) {
std
::
string
in_var_name
(
"ids"
);
std
::
string
out_var_name
(
"out"
);
client
.
AsyncPrefetchVariable
(
ep
,
ctx
,
scope
,
in_var_name
,
out_var_name
);
client
.
Wait
();
client
->
AsyncPrefetchVariable
(
ep
,
ctx
,
scope
,
in_var_name
,
out_var_name
);
client
->
Wait
();
auto
var
=
scope
.
Var
(
out_var_name
);
auto
value
=
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
value
();
auto
ptr
=
value
.
mutable_data
<
float
>
(
place
);
...
...
@@ -152,6 +148,7 @@ TEST(PREFETCH, CPU) {
}
}
g_rpc_service
->
ShutDown
();
server_thread
.
join
();
LOG
(
INFO
)
<<
"begin reset"
;
g_rpc_service
.
reset
(
nullptr
);
...
...
paddle/fluid/operators/fetch_barrier_op.cc
浏览文件 @
036a90f1
...
...
@@ -45,13 +45,13 @@ class FetchBarrierOp : public framework::OperatorBase {
auto
rpc_client
=
detail
::
RPCClient
::
GetInstance
();
PADDLE_ENFORCE
(
rpc_client
->
Wait
()
);
rpc_client
->
Wait
(
);
for
(
auto
&
ep
:
eps
)
{
VLOG
(
3
)
<<
"fetch barrier, ep: "
<<
ep
;
rpc_client
->
AsyncSendFetchBarrier
(
ep
);
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
()
);
rpc_client
->
Wait
(
);
}
};
...
...
paddle/fluid/operators/prefetch_op.cc
浏览文件 @
036a90f1
...
...
@@ -53,7 +53,7 @@ class PrefetchOp : public framework::OperatorBase {
VLOG
(
3
)
<<
"don't send no-initialied variable: "
<<
ins
[
i
];
}
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
()
);
rpc_client
->
Wait
(
);
}
};
...
...
paddle/fluid/operators/recv_op.cc
浏览文件 @
036a90f1
...
...
@@ -51,7 +51,7 @@ class RecvOp : public framework::OperatorBase {
rpc_client
->
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
}
if
(
sync_mode
)
{
PADDLE_ENFORCE
(
rpc_client
->
Wait
()
);
rpc_client
->
Wait
(
);
}
}
};
...
...
paddle/fluid/operators/send_barrier_op.cc
浏览文件 @
036a90f1
...
...
@@ -49,13 +49,13 @@ class SendBarrierOp : public framework::OperatorBase {
VLOG
(
3
)
<<
"SendBarrierOp sync_mode:"
<<
sync_mode
;
// need to wait before sending send_barrier message
PADDLE_ENFORCE
(
rpc_client
->
Wait
()
);
rpc_client
->
Wait
(
);
if
(
sync_mode
)
{
for
(
auto
&
ep
:
eps
)
{
VLOG
(
3
)
<<
"send barrier, ep: "
<<
ep
;
rpc_client
->
AsyncSendBatchBarrier
(
ep
);
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
()
);
rpc_client
->
Wait
(
);
}
}
};
...
...
paddle/fluid/operators/send_op.cc
浏览文件 @
036a90f1
...
...
@@ -59,14 +59,14 @@ class SendOp : public framework::OperatorBase {
VLOG
(
3
)
<<
"don't send no-initialied variable: "
<<
ins
[
i
];
}
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
()
);
rpc_client
->
Wait
(
);
if
(
sync_mode
)
{
for
(
auto
&
ep
:
endpoints
)
{
VLOG
(
3
)
<<
"batch barrier, ep: "
<<
ep
;
rpc_client
->
AsyncSendBatchBarrier
(
ep
);
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
()
);
rpc_client
->
Wait
(
);
}
if
(
outs
.
size
()
>
0
)
{
...
...
@@ -74,13 +74,13 @@ class SendOp : public framework::OperatorBase {
VLOG
(
2
)
<<
"getting "
<<
outs
[
i
]
<<
" from "
<<
epmap
[
i
];
rpc_client
->
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
()
);
rpc_client
->
Wait
(
);
// tell pservers that current trainer have called fetch
for
(
auto
&
ep
:
endpoints
)
{
VLOG
(
2
)
<<
"send fetch barrier, ep: "
<<
ep
;
rpc_client
->
AsyncSendFetchBarrier
(
ep
);
}
PADDLE_ENFORCE
(
rpc_client
->
Wait
()
);
rpc_client
->
Wait
(
);
}
}
};
...
...
paddle/fluid/operators/test_send_nccl_id.cc
浏览文件 @
036a90f1
...
...
@@ -61,7 +61,6 @@ void StartServer() {
std
::
bind
(
&
detail
::
AsyncGRPCServer
::
StartServer
,
g_rpc_service
.
get
()));
g_rpc_service
->
SetCond
(
detail
::
kRequestSend
);
std
::
cout
<<
"before WaitFanInOfSend"
<<
std
::
endl
;
g_rpc_service
->
WaitBarrier
(
detail
::
kRequestSend
);
LOG
(
INFO
)
<<
"got nccl id and stop server..."
;
...
...
@@ -88,12 +87,12 @@ TEST(SendNcclId, GrpcServer) {
int
port
=
g_rpc_service
->
GetSelectedPort
();
std
::
string
ep
=
string
::
Sprintf
(
"127.0.0.1:%d"
,
port
);
detail
::
RPCClient
client
;
LOG
(
INFO
)
<<
"connect to server"
<<
ep
;
client
.
AsyncSendVariable
(
ep
,
dev_ctx
,
scope
,
NCCL_ID_VARNAME
);
client
.
Wait
();
client
.
AsyncSendBatchBarrier
(
ep
);
client
.
Wait
();
detail
::
RPCClient
*
client
=
detail
::
RPCClient
::
GetInstance
()
;
LOG
(
INFO
)
<<
"connect to server
"
<<
ep
;
client
->
AsyncSendVariable
(
ep
,
dev_ctx
,
scope
,
NCCL_ID_VARNAME
);
client
->
Wait
();
client
->
AsyncSendBatchBarrier
(
ep
);
client
->
Wait
();
server_thread
.
join
();
g_rpc_service
.
reset
(
nullptr
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录