Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
11fe3c79
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
11fe3c79
编写于
5月 22, 2018
作者:
X
Xin Pan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
clean up
上级
b4dd4c04
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
27 addition
and
41 deletion
+27
-41
benchmark/cluster/vgg16/vgg16_fluid.py
benchmark/cluster/vgg16/vgg16_fluid.py
+1
-1
cmake/external/grpc.cmake
cmake/external/grpc.cmake
+1
-1
paddle/fluid/operators/detail/grpc_server.cc
paddle/fluid/operators/detail/grpc_server.cc
+25
-39
未找到文件。
benchmark/cluster/vgg16/vgg16_fluid.py
浏览文件 @
11fe3c79
...
...
@@ -204,7 +204,7 @@ def main():
with
profiler
.
profiler
(
'All'
,
'total'
,
'/tmp/profile_vgg_%d'
%
args
.
task_index
):
for
batch_id
,
data
in
enumerate
(
train_reader
()):
if
batch_id
>
4
:
break
if
batch_id
>
5
:
break
run_step
(
batch_id
,
data
)
total_time
=
0.0
...
...
cmake/external/grpc.cmake
浏览文件 @
11fe3c79
...
...
@@ -33,7 +33,7 @@ ExternalProject_Add(
extern_grpc
DEPENDS protobuf zlib
GIT_REPOSITORY
"https://github.com/grpc/grpc.git"
GIT_TAG
"v1.
8
.x"
GIT_TAG
"v1.
10
.x"
PREFIX
${
GRPC_SOURCES_DIR
}
UPDATE_COMMAND
""
CONFIGURE_COMMAND
""
...
...
paddle/fluid/operators/detail/grpc_server.cc
浏览文件 @
11fe3c79
...
...
@@ -66,11 +66,11 @@ class RequestSend final : public RequestBase {
explicit
RequestSend
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
framework
::
Scope
*
scope
,
ReceivedQueue
*
queue
,
const
platform
::
DeviceContext
*
dev_ctx
,
int
i
)
const
platform
::
DeviceContext
*
dev_ctx
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
queue_
(
queue
),
responder_
(
&
ctx_
),
i_
(
i
)
{
req_id_
(
req_id
)
{
if
(
sync_mode_
)
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
false
));
}
else
{
...
...
@@ -79,7 +79,7 @@ class RequestSend final : public RequestBase {
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kSendVariable
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
i
)));
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
)));
}
virtual
~
RequestSend
()
{}
...
...
@@ -93,7 +93,7 @@ class RequestSend final : public RequestBase {
status_
=
FINISH
;
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
i
_
)));
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
_
)));
}
protected:
...
...
@@ -101,7 +101,7 @@ class RequestSend final : public RequestBase {
std
::
shared_ptr
<
VariableResponse
>
request_
;
ReceivedQueue
*
queue_
;
ServerAsyncResponseWriter
<
sendrecv
::
VoidMessage
>
responder_
;
int
i
_
;
int
req_id
_
;
};
class
RequestGet
final
:
public
RequestBase
{
...
...
@@ -110,16 +110,17 @@ class RequestGet final : public RequestBase {
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
*
dev_ctx
,
framework
::
BlockingQueue
<
MessageWithName
>*
queue
,
int
i
)
framework
::
BlockingQueue
<
MessageWithName
>*
queue
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
responder_
(
&
ctx_
),
scope_
(
scope
),
queue_
(
queue
),
i_
(
i
)
{
req_id_
(
req_id
)
{
auto
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kGetVariable
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
&
request_
,
&
responder_
,
cq_
,
cq_
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
i
)));
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id_
)));
}
virtual
~
RequestGet
()
{}
...
...
@@ -138,7 +139,7 @@ class RequestGet final : public RequestBase {
status_
=
FINISH
;
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
i
_
)));
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
_
)));
if
(
var_name
==
FETCH_BARRIER_MESSAGE
)
{
sendrecv
::
VariableMessage
msg
;
...
...
@@ -153,7 +154,7 @@ class RequestGet final : public RequestBase {
ServerAsyncResponseWriter
<::
grpc
::
ByteBuffer
>
responder_
;
framework
::
Scope
*
scope_
;
framework
::
BlockingQueue
<
MessageWithName
>*
queue_
;
int
i
_
;
int
req_id
_
;
};
class
RequestPrefetch
final
:
public
RequestBase
{
...
...
@@ -165,14 +166,14 @@ class RequestPrefetch final : public RequestBase {
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
ExecutorPrepareContext
*
prefetch_ctx
,
int
i
)
int
req_id
)
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
responder_
(
&
ctx_
),
scope_
(
scope
),
executor_
(
executor
),
program_
(
program
),
prefetch_ctx_
(
prefetch_ctx
),
i_
(
i
)
{
req_id_
(
req_id
)
{
if
(
sync_mode_
)
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
false
));
}
else
{
...
...
@@ -202,7 +203,7 @@ class RequestPrefetch final : public RequestBase {
SerializeToByteBuffer
(
var_name
,
var
,
*
dev_ctx_
,
&
reply
);
responder_
.
Finish
(
reply
,
::
grpc
::
Status
::
OK
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
i
_
)));
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
_
)));
status_
=
FINISH
;
}
...
...
@@ -213,7 +214,7 @@ class RequestPrefetch final : public RequestBase {
framework
::
Executor
*
executor_
;
framework
::
ProgramDesc
*
program_
;
framework
::
ExecutorPrepareContext
*
prefetch_ctx_
;
int
i
_
;
int
req_id
_
;
};
void
AsyncGRPCServer
::
WaitClientGet
(
int
count
)
{
...
...
@@ -291,21 +292,6 @@ void AsyncGRPCServer::RunSyncUpdate() {
for
(
int
i
=
0
;
i
<
kNumHandleGetThreads
;
++
i
)
{
t_gets_
[
i
]
->
join
();
}
{
std
::
lock_guard
<
std
::
mutex
>
l
(
cq_mutex_
);
for
(
int
i
=
0
;
i
<
kSendReqsBufSize
;
++
i
)
{
if
(
send_reqs_
[
i
])
{
delete
send_reqs_
[
i
];
send_reqs_
[
i
]
=
nullptr
;
}
}
for
(
int
i
=
0
;
i
<
kGetReqsBufSize
;
++
i
)
{
if
(
get_reqs_
[
i
])
{
delete
get_reqs_
[
i
];
get_reqs_
[
i
]
=
nullptr
;
}
}
}
t_prefetch_
->
join
();
}
...
...
@@ -335,19 +321,19 @@ void AsyncGRPCServer::TryToRegisterNewSendOne(int i) {
VLOG
(
4
)
<<
"Create RequestSend status:"
<<
send
->
Status
();
}
void
AsyncGRPCServer
::
TryToRegisterNewGetOne
(
int
i
)
{
void
AsyncGRPCServer
::
TryToRegisterNewGetOne
(
int
req_id
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
if
(
is_shut_down_
)
{
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewGetOne"
;
return
;
}
RequestGet
*
get
=
new
RequestGet
(
&
service_
,
cq_get_
.
get
(),
sync_mode_
,
scope_
,
dev_ctx_
,
&
var_get_queue_
,
i
);
get_reqs_
[
i
]
=
static_cast
<
RequestBase
*>
(
get
);
dev_ctx_
,
&
var_get_queue_
,
req_id
);
get_reqs_
[
req_id
]
=
static_cast
<
RequestBase
*>
(
get
);
VLOG
(
4
)
<<
"Create RequestGet status:"
<<
get
->
Status
();
}
void
AsyncGRPCServer
::
TryToRegisterNewPrefetchOne
(
int
i
)
{
void
AsyncGRPCServer
::
TryToRegisterNewPrefetchOne
(
int
req_id
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
if
(
is_shut_down_
)
{
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewPrefetchOne"
;
...
...
@@ -355,7 +341,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int i) {
}
RequestPrefetch
*
prefetch
=
new
RequestPrefetch
(
&
service_
,
cq_prefetch_
.
get
(),
sync_mode_
,
scope_
,
dev_ctx_
,
executor_
,
program_
,
prefetch_ctx_
.
get
(),
i
);
program_
,
prefetch_ctx_
.
get
(),
req_id
);
VLOG
(
4
)
<<
"Create RequestPrefetch status:"
<<
prefetch
->
Status
();
}
...
...
@@ -374,7 +360,7 @@ void AsyncGRPCServer::HandleRequest(
break
;
}
VLOG
(
3
)
<<
"HandleRequest for "
<<
cq_name
<<
" get Next"
;
int
i
=
static_cast
<
int
>
(
reinterpret_cast
<
intptr_t
>
(
tag
));
int
req_id
=
static_cast
<
int
>
(
reinterpret_cast
<
intptr_t
>
(
tag
));
if
(
sync_mode_
)
{
// FIXME(typhoonzero): de-couple the barriers with recv_op
...
...
@@ -387,9 +373,9 @@ void AsyncGRPCServer::HandleRequest(
{
std
::
lock_guard
<
std
::
mutex
>
l
(
cq_mutex_
);
if
(
cq_name
==
"cq_get"
)
{
base
=
get_reqs_
[
i
];
base
=
get_reqs_
[
req_id
];
}
else
if
(
cq_name
==
"cq_send"
)
{
base
=
send_reqs_
[
i
];
base
=
send_reqs_
[
req_id
];
}
else
{
CHECK
(
false
);
}
...
...
@@ -401,7 +387,7 @@ void AsyncGRPCServer::HandleRequest(
if
(
!
ok
)
{
LOG
(
WARNING
)
<<
cq_name
<<
" recv no regular event:argument name["
<<
base
->
GetReqName
()
<<
"]"
;
TryToRegisterNewOne
(
i
);
TryToRegisterNewOne
(
req_id
);
delete
base
;
continue
;
}
...
...
@@ -413,7 +399,7 @@ void AsyncGRPCServer::HandleRequest(
break
;
}
case
FINISH
:
{
TryToRegisterNewOne
(
i
);
TryToRegisterNewOne
(
req_id
);
VLOG
(
4
)
<<
cq_name
<<
" FINISH status:"
<<
base
->
Status
();
delete
base
;
break
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录