Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
535fefb7
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
535fefb7
编写于
1月 15, 2018
作者:
G
gongweibao
提交者:
GitHub
1月 15, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix grpc bugs (#7435)
Fix grpc bugs
上级
448fee3d
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
35 addition
and
25 deletion
+35
-25
cmake/external/grpc.cmake
cmake/external/grpc.cmake
+1
-1
paddle/operators/detail/grpc_client.cc
paddle/operators/detail/grpc_client.cc
+11
-5
paddle/operators/detail/grpc_client.h
paddle/operators/detail/grpc_client.h
+1
-1
paddle/operators/detail/grpc_server.cc
paddle/operators/detail/grpc_server.cc
+19
-16
paddle/operators/detail/grpc_server.h
paddle/operators/detail/grpc_server.h
+0
-1
paddle/operators/recv_op.cc
paddle/operators/recv_op.cc
+2
-0
paddle/operators/send_op.cc
paddle/operators/send_op.cc
+1
-1
未找到文件。
cmake/external/grpc.cmake
浏览文件 @
535fefb7
...
...
@@ -33,7 +33,7 @@ ExternalProject_Add(
extern_grpc
DEPENDS protobuf zlib
GIT_REPOSITORY
"https://github.com/grpc/grpc.git"
GIT_TAG
"v1.
7
.x"
GIT_TAG
"v1.
8
.x"
PREFIX
${
GRPC_SOURCES_DIR
}
UPDATE_COMMAND
""
CONFIGURE_COMMAND
""
...
...
paddle/operators/detail/grpc_client.cc
浏览文件 @
535fefb7
...
...
@@ -87,7 +87,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
return
true
;
}
bool
RPCClient
::
w
ait
()
{
bool
RPCClient
::
W
ait
()
{
bool
ok
=
true
;
while
(
true
)
{
...
...
@@ -96,7 +96,6 @@ bool RPCClient::wait() {
}
if
(
!
Proceed
())
{
LOG
(
ERROR
)
<<
"Get meets CompletionQueue error"
;
return
false
;
}
}
...
...
@@ -110,9 +109,9 @@ bool RPCClient::Proceed() {
// request counts.
if
(
!
cq_
.
Next
(
&
tag
,
&
ok
))
{
LOG
(
ERROR
)
<<
"Get meets CompletionQueue error"
;
return
false
;
}
req_count_
--
;
GPR_ASSERT
(
ok
);
PADDLE_ENFORCE
(
tag
);
...
...
@@ -120,12 +119,15 @@ bool RPCClient::Proceed() {
// TODO(gongwb): add more retries.
ClientBase
*
c
=
static_cast
<
ClientBase
*>
(
tag
);
if
(
!
c
->
status_
.
ok
())
{
LOG
(
ERROR
)
<<
"proc param error:"
<<
c
->
var_h_
.
String
()
<<
" grpc error:"
<<
c
->
status_
.
error_message
();
delete
c
;
return
tru
e
;
return
fals
e
;
}
c
->
Process
();
delete
c
;
req_count_
--
;
return
true
;
}
...
...
@@ -135,8 +137,12 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
return
it
->
second
;
}
grpc
::
ChannelArguments
args
;
args
.
SetMaxSendMessageSize
(
std
::
numeric_limits
<
int
>::
max
());
args
.
SetMaxReceiveMessageSize
(
std
::
numeric_limits
<
int
>::
max
());
auto
ch
=
std
::
shared_ptr
<
grpc
::
Channel
>
(
grpc
::
CreateC
hannel
(
ep
,
grpc
::
InsecureChannelCredentials
()
));
grpc
::
CreateC
ustomChannel
(
ep
,
grpc
::
InsecureChannelCredentials
(),
args
));
channels_
[
ep
]
=
ch
;
return
ch
;
...
...
paddle/operators/detail/grpc_client.h
浏览文件 @
535fefb7
...
...
@@ -130,7 +130,7 @@ class RPCClient {
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
600
*
1000
);
bool
w
ait
();
bool
W
ait
();
private:
bool
Proceed
();
...
...
paddle/operators/detail/grpc_server.cc
浏览文件 @
535fefb7
...
...
@@ -28,12 +28,15 @@ class RequestBase {
public:
explicit
RequestBase
(
sendrecv
::
SendRecvService
::
AsyncService
*
service
,
grpc
::
ServerCompletionQueue
*
cq
)
:
service_
(
service
),
cq_
(
cq
),
status_
(
PROCESS
)
{}
:
service_
(
service
),
cq_
(
cq
),
status_
(
PROCESS
)
{
PADDLE_ENFORCE
(
cq_
);
}
virtual
~
RequestBase
()
{}
virtual
void
Process
()
{
assert
(
false
);
}
CallStatus
Status
()
{
return
status_
;
}
void
SetStatus
(
CallStatus
status
)
{
status_
=
status
;
}
virtual
std
::
string
GetReqName
()
{
assert
(
false
);
}
protected:
grpc
::
ServerContext
ctx_
;
...
...
@@ -56,12 +59,14 @@ class RequestSend final : public RequestBase {
virtual
~
RequestSend
()
{}
virtual
std
::
string
GetReqName
()
{
return
request_
.
varname
();
}
virtual
void
Process
()
{
MessageWithName
msg_with_name
=
std
::
make_pair
(
request_
.
varname
(),
std
::
move
(
request_
));
queue_
->
Push
(
std
::
move
(
msg_with_name
));
// TODO(gongwb): check var's info.
responder_
.
Finish
(
reply_
,
grpc
::
Status
::
OK
,
this
);
status_
=
FINISH
;
}
protected:
...
...
@@ -81,6 +86,8 @@ class RequestGet final : public RequestBase {
virtual
~
RequestGet
()
{}
virtual
std
::
string
GetReqName
()
{
return
request_
.
varname
();
}
virtual
void
Process
()
{
// proc request.
std
::
string
var_name
=
request_
.
varname
();
...
...
@@ -88,6 +95,7 @@ class RequestGet final : public RequestBase {
SerializeToMessage
(
var_name
,
var
,
platform
::
CPUDeviceContext
(),
&
reply_
);
// TODO(gongwb): check var's info.
responder_
.
Finish
(
reply_
,
grpc
::
Status
::
OK
,
this
);
status_
=
FINISH
;
}
protected:
...
...
@@ -100,6 +108,8 @@ class RequestGet final : public RequestBase {
void
AsyncGRPCServer
::
RunSyncUpdate
()
{
grpc
::
ServerBuilder
builder
;
builder
.
AddListeningPort
(
address_
,
grpc
::
InsecureServerCredentials
());
builder
.
SetMaxSendMessageSize
(
std
::
numeric_limits
<
int
>::
max
());
builder
.
SetMaxReceiveMessageSize
(
std
::
numeric_limits
<
int
>::
max
());
builder
.
RegisterService
(
&
service_
);
cq_send_
=
builder
.
AddCompletionQueue
();
...
...
@@ -159,18 +169,6 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
VLOG
(
4
)
<<
"create Requestget status:"
<<
get
->
Status
();
}
void
AsyncGRPCServer
::
SetFinishOrDelete
(
RequestBase
*&
last
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
if
(
is_shut_down_
)
{
delete
last
;
last
=
NULL
;
return
;
}
last
->
SetStatus
(
FINISH
);
return
;
}
void
AsyncGRPCServer
::
HandleRequest
(
bool
wait
,
grpc
::
ServerCompletionQueue
*
cq
,
std
::
string
cq_name
,
std
::
function
<
void
()
>
TryToRegisterNewOne
)
{
...
...
@@ -184,13 +182,19 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
break
;
}
PADDLE_ENFORCE
(
tag
);
if
(
wait
&&
!
done_
)
{
Wait
();
}
RequestBase
*
base
=
(
RequestBase
*
)
tag
;
// reference:
// https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
// https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
if
(
!
ok
)
{
VLOG
(
4
)
<<
cq_name
<<
" recv no regular event"
;
LOG
(
WARNING
)
<<
cq_name
<<
" recv no regular event:argument name"
<<
base
->
GetReqName
();
TryToRegisterNewOne
();
delete
base
;
continue
;
...
...
@@ -201,7 +205,6 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
VLOG
(
4
)
<<
cq_name
<<
" status:"
<<
base
->
Status
();
TryToRegisterNewOne
();
base
->
Process
();
SetFinishOrDelete
(
base
);
break
;
}
case
FINISH
:
{
...
...
paddle/operators/detail/grpc_server.h
浏览文件 @
535fefb7
...
...
@@ -60,7 +60,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
std
::
function
<
void
()
>
TryToRegisterNewOne
);
void
TryToRegisterNewSendOne
();
void
TryToRegisterNewGetOne
();
void
SetFinishOrDelete
(
RequestBase
*&
last
);
void
ShutdownQueue
();
private:
...
...
paddle/operators/recv_op.cc
浏览文件 @
535fefb7
...
...
@@ -96,6 +96,8 @@ class RecvOp : public framework::OperatorBase {
rpc_service_
->
Reset
();
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool
exit_flag
=
false
;
VLOG
(
4
)
<<
"param_count:"
<<
param_count
<<
" trainer_count:"
<<
trainer_count
;
while
(
!
exit_flag
)
{
// TODO(gognwb): simply this loop.
// Get from multiple trainers, we don't care about order in which
...
...
paddle/operators/send_op.cc
浏览文件 @
535fefb7
...
...
@@ -48,7 +48,7 @@ class SendOp : public framework::OperatorBase {
client_
.
AsyncGetVariable
(
epmap
[
i
],
ctx
,
scope
,
outs
[
i
]);
}
client_
.
wait
(
);
PADDLE_ENFORCE
(
client_
.
Wait
()
);
}
private:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录