Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
260bf5ac
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
260bf5ac
编写于
4月 20, 2018
作者:
Q
qiaolongfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add sync_mode
上级
63fbdcf9
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
37 addition
and
20 deletion
+37
-20
paddle/fluid/operators/detail/grpc_server.cc
paddle/fluid/operators/detail/grpc_server.cc
+31
-17
paddle/fluid/operators/detail/grpc_server.h
paddle/fluid/operators/detail/grpc_server.h
+3
-1
paddle/fluid/operators/detail/grpc_server_test.cc
paddle/fluid/operators/detail/grpc_server_test.cc
+1
-1
paddle/fluid/operators/listen_and_serv_op.cc
paddle/fluid/operators/listen_and_serv_op.cc
+2
-1
未找到文件。
paddle/fluid/operators/detail/grpc_server.cc
浏览文件 @
260bf5ac
...
@@ -30,9 +30,13 @@ enum CallStatus { PROCESS = 0, FINISH };
...
@@ -30,9 +30,13 @@ enum CallStatus { PROCESS = 0, FINISH };
class
RequestBase
{
class
RequestBase
{
public:
public:
explicit
RequestBase
(
GrpcService
::
AsyncService
*
service
,
explicit
RequestBase
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
const
platform
::
DeviceContext
*
dev_ctx
)
const
platform
::
DeviceContext
*
dev_ctx
)
:
service_
(
service
),
cq_
(
cq
),
status_
(
PROCESS
),
dev_ctx_
(
dev_ctx
)
{
:
service_
(
service
),
cq_
(
cq
),
sync_mode_
(
sync_mode
),
status_
(
PROCESS
),
dev_ctx_
(
dev_ctx
)
{
PADDLE_ENFORCE
(
cq_
);
PADDLE_ENFORCE
(
cq_
);
}
}
virtual
~
RequestBase
()
{}
virtual
~
RequestBase
()
{}
...
@@ -49,6 +53,7 @@ class RequestBase {
...
@@ -49,6 +53,7 @@ class RequestBase {
::
grpc
::
ServerContext
ctx_
;
::
grpc
::
ServerContext
ctx_
;
GrpcService
::
AsyncService
*
service_
;
GrpcService
::
AsyncService
*
service_
;
::
grpc
::
ServerCompletionQueue
*
cq_
;
::
grpc
::
ServerCompletionQueue
*
cq_
;
const
bool
sync_mode_
;
CallStatus
status_
;
CallStatus
status_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
};
};
...
@@ -56,11 +61,17 @@ class RequestBase {
...
@@ -56,11 +61,17 @@ class RequestBase {
class
RequestSend
final
:
public
RequestBase
{
class
RequestSend
final
:
public
RequestBase
{
public:
public:
explicit
RequestSend
(
GrpcService
::
AsyncService
*
service
,
explicit
RequestSend
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
framework
::
Scope
*
scope
,
ReceivedQueue
*
queue
,
framework
::
Scope
*
scope
,
ReceivedQueue
*
queue
,
const
platform
::
DeviceContext
*
dev_ctx
)
const
platform
::
DeviceContext
*
dev_ctx
)
:
RequestBase
(
service
,
cq
,
dev_ctx
),
queue_
(
queue
),
responder_
(
&
ctx_
)
{
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
request_
.
reset
(
new
VariableResponse
(
false
,
scope
,
dev_ctx_
));
queue_
(
queue
),
responder_
(
&
ctx_
)
{
if
(
sync_mode_
)
{
request_
.
reset
(
new
VariableResponse
(
false
,
scope
,
dev_ctx_
));
}
else
{
request_
.
reset
(
new
VariableResponse
(
true
,
scope
,
dev_ctx_
));
}
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kSendVariable
);
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kSendVariable
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
this
);
cq_
,
cq_
,
this
);
...
@@ -87,11 +98,11 @@ class RequestSend final : public RequestBase {
...
@@ -87,11 +98,11 @@ class RequestSend final : public RequestBase {
class
RequestGet
final
:
public
RequestBase
{
class
RequestGet
final
:
public
RequestBase
{
public:
public:
explicit
RequestGet
(
GrpcService
::
AsyncService
*
service
,
explicit
RequestGet
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
framework
::
Scope
*
scope
,
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
*
dev_ctx
,
const
platform
::
DeviceContext
*
dev_ctx
,
SimpleBlockQueue
<
MessageWithName
>*
queue
)
SimpleBlockQueue
<
MessageWithName
>*
queue
)
:
RequestBase
(
service
,
cq
,
dev_ctx
),
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
responder_
(
&
ctx_
),
responder_
(
&
ctx_
),
scope_
(
scope
),
scope_
(
scope
),
queue_
(
queue
)
{
queue_
(
queue
)
{
...
@@ -134,19 +145,23 @@ class RequestGet final : public RequestBase {
...
@@ -134,19 +145,23 @@ class RequestGet final : public RequestBase {
class
RequestPrefetch
final
:
public
RequestBase
{
class
RequestPrefetch
final
:
public
RequestBase
{
public:
public:
explicit
RequestPrefetch
(
GrpcService
::
AsyncService
*
service
,
explicit
RequestPrefetch
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
framework
::
Scope
*
scope
,
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
*
dev_ctx
,
const
platform
::
DeviceContext
*
dev_ctx
,
framework
::
Executor
*
executor
,
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
ProgramDesc
*
program
,
framework
::
ExecutorPrepareContext
*
prefetch_ctx
)
framework
::
ExecutorPrepareContext
*
prefetch_ctx
)
:
RequestBase
(
service
,
cq
,
dev_ctx
),
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
responder_
(
&
ctx_
),
responder_
(
&
ctx_
),
scope_
(
scope
),
scope_
(
scope
),
executor_
(
executor
),
executor_
(
executor
),
program_
(
program
),
program_
(
program
),
prefetch_ctx_
(
prefetch_ctx
)
{
prefetch_ctx_
(
prefetch_ctx
)
{
request_
.
reset
(
new
VariableResponse
(
false
,
scope
,
dev_ctx_
));
if
(
sync_mode_
)
{
request_
.
reset
(
new
VariableResponse
(
false
,
scope
,
dev_ctx_
));
}
else
{
request_
.
reset
(
new
VariableResponse
(
true
,
scope
,
dev_ctx_
));
}
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kPrefetchVariable
);
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kPrefetchVariable
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
this
);
cq_
,
cq_
,
this
);
...
@@ -181,7 +196,6 @@ class RequestPrefetch final : public RequestBase {
...
@@ -181,7 +196,6 @@ class RequestPrefetch final : public RequestBase {
framework
::
Executor
*
executor_
;
framework
::
Executor
*
executor_
;
framework
::
ProgramDesc
*
program_
;
framework
::
ProgramDesc
*
program_
;
framework
::
ExecutorPrepareContext
*
prefetch_ctx_
;
framework
::
ExecutorPrepareContext
*
prefetch_ctx_
;
int
blkid_
;
};
};
void
AsyncGRPCServer
::
WaitClientGet
(
int
count
)
{
void
AsyncGRPCServer
::
WaitClientGet
(
int
count
)
{
...
@@ -254,8 +268,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
...
@@ -254,8 +268,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewSendOne"
;
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewSendOne"
;
return
;
return
;
}
}
RequestSend
*
send
=
new
RequestSend
(
&
service_
,
cq_send_
.
get
(),
s
cop
e_
,
RequestSend
*
send
=
new
RequestSend
(
&
service_
,
cq_send_
.
get
(),
s
ync_mod
e_
,
&
var_recv_queue_
,
dev_ctx_
);
scope_
,
&
var_recv_queue_
,
dev_ctx_
);
VLOG
(
4
)
<<
"Create RequestSend status:"
<<
send
->
Status
();
VLOG
(
4
)
<<
"Create RequestSend status:"
<<
send
->
Status
();
}
}
...
@@ -265,8 +279,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
...
@@ -265,8 +279,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewGetOne"
;
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewGetOne"
;
return
;
return
;
}
}
RequestGet
*
get
=
new
RequestGet
(
&
service_
,
cq_get_
.
get
(),
s
cope_
,
dev_ctx
_
,
RequestGet
*
get
=
new
RequestGet
(
&
service_
,
cq_get_
.
get
(),
s
ync_mode_
,
scope
_
,
&
var_get_queue_
);
dev_ctx_
,
&
var_get_queue_
);
VLOG
(
4
)
<<
"Create RequestGet status:"
<<
get
->
Status
();
VLOG
(
4
)
<<
"Create RequestGet status:"
<<
get
->
Status
();
}
}
...
@@ -277,8 +291,8 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
...
@@ -277,8 +291,8 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
return
;
return
;
}
}
RequestPrefetch
*
prefetch
=
RequestPrefetch
*
prefetch
=
new
RequestPrefetch
(
&
service_
,
cq_prefetch_
.
get
(),
s
cope_
,
dev_ctx
_
,
new
RequestPrefetch
(
&
service_
,
cq_prefetch_
.
get
(),
s
ync_mode_
,
scope
_
,
executor_
,
program_
,
prefetch_ctx_
);
dev_ctx_
,
executor_
,
program_
,
prefetch_ctx_
);
VLOG
(
4
)
<<
"Create RequestPrefetch status:"
<<
prefetch
->
Status
();
VLOG
(
4
)
<<
"Create RequestPrefetch status:"
<<
prefetch
->
Status
();
}
}
...
...
paddle/fluid/operators/detail/grpc_server.h
浏览文件 @
260bf5ac
...
@@ -44,7 +44,8 @@ class RequestBase;
...
@@ -44,7 +44,8 @@ class RequestBase;
class
AsyncGRPCServer
final
{
class
AsyncGRPCServer
final
{
public:
public:
explicit
AsyncGRPCServer
(
const
std
::
string
&
address
)
:
address_
(
address
)
{}
explicit
AsyncGRPCServer
(
const
std
::
string
&
address
,
bool
sync_mode
)
:
address_
(
address
),
sync_mode_
(
sync_mode
)
{}
void
RunSyncUpdate
();
void
RunSyncUpdate
();
...
@@ -95,6 +96,7 @@ class AsyncGRPCServer final {
...
@@ -95,6 +96,7 @@ class AsyncGRPCServer final {
std
::
unique_ptr
<::
grpc
::
Server
>
server_
;
std
::
unique_ptr
<::
grpc
::
Server
>
server_
;
std
::
string
address_
;
std
::
string
address_
;
const
bool
sync_mode_
;
framework
::
Scope
*
scope_
;
framework
::
Scope
*
scope_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
...
...
paddle/fluid/operators/detail/grpc_server_test.cc
浏览文件 @
260bf5ac
...
@@ -89,7 +89,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
...
@@ -89,7 +89,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
}
}
void
StartServer
(
const
std
::
string
&
endpoint
)
{
void
StartServer
(
const
std
::
string
&
endpoint
)
{
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
));
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
,
true
));
framework
::
ProgramDesc
program
;
framework
::
ProgramDesc
program
;
framework
::
Scope
scope
;
framework
::
Scope
scope
;
platform
::
CPUPlace
place
;
platform
::
CPUPlace
place
;
...
...
paddle/fluid/operators/listen_and_serv_op.cc
浏览文件 @
260bf5ac
...
@@ -278,7 +278,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
...
@@ -278,7 +278,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
PADDLE_ENFORCE
(
!
rpc_service_
);
PADDLE_ENFORCE
(
!
rpc_service_
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
));
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
,
sync_mode
));
auto
*
optimize_block
=
Attr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
);
auto
*
optimize_block
=
Attr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
);
auto
*
prefetch_block
=
Attr
<
framework
::
BlockDesc
*>
(
kPrefetchBlock
);
auto
*
prefetch_block
=
Attr
<
framework
::
BlockDesc
*>
(
kPrefetchBlock
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录