Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
09fcf5f2
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
09fcf5f2
编写于
4月 02, 2018
作者:
Y
Yancey
提交者:
GitHub
4月 02, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #9555 from jacquesqiao/improve-prefetch-on-server
Improve prefetch on server
上级
b9d8bbe4
04a5c037
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
55 addition
and
32 deletion
+55
-32
paddle/fluid/operators/detail/CMakeLists.txt
paddle/fluid/operators/detail/CMakeLists.txt
+1
-1
paddle/fluid/operators/detail/grpc_client.cc
paddle/fluid/operators/detail/grpc_client.cc
+10
-9
paddle/fluid/operators/detail/grpc_server.cc
paddle/fluid/operators/detail/grpc_server.cc
+15
-3
paddle/fluid/operators/detail/grpc_server.h
paddle/fluid/operators/detail/grpc_server.h
+3
-1
paddle/fluid/operators/detail/grpc_server_test.cc
paddle/fluid/operators/detail/grpc_server_test.cc
+16
-5
paddle/fluid/operators/detail/grpc_service.h
paddle/fluid/operators/detail/grpc_service.h
+3
-3
paddle/fluid/operators/listen_and_serv_op.cc
paddle/fluid/operators/listen_and_serv_op.cc
+7
-10
未找到文件。
paddle/fluid/operators/detail/CMakeLists.txt
浏览文件 @
09fcf5f2
...
...
@@ -2,7 +2,7 @@ if(WITH_DISTRIBUTE)
grpc_library
(
sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
serde_test.cc grpc_server_test PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
serde_test.cc grpc_server_test
.cc
PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
cc_test
(
serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
cares zlib protobuf sendrecvop_grpc
)
cc_test
(
grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf
)
...
...
paddle/fluid/operators/detail/grpc_client.cc
浏览文件 @
09fcf5f2
...
...
@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "grpc_client.h"
#include <sys/time.h>
#include "paddle/fluid/operators/detail/grpc_client.h"
#include <limits>
#include "paddle/fluid/framework/threadpool.h"
namespace
paddle
{
...
...
@@ -52,7 +54,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
auto
call
=
s
->
stub_g_
.
PrepareUnaryCall
(
s
->
context_
.
get
(),
"/sendrecv.SendRecvService/SendVariable"
,
req
,
&
cq_
);
call
->
StartCall
();
call
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
(
void
*
)
s
);
call
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
static_cast
<
void
*>
(
s
)
);
});
req_count_
++
;
...
...
@@ -70,8 +72,7 @@ void ProcGetResponse(const VarHandle& var_h,
template
<
typename
T
>
void
RequestToByteBuffer
(
const
T
&
proto
,
::
grpc
::
ByteBuffer
*
result
)
{
::
grpc
::
Slice
slice
(
proto
.
ByteSizeLong
());
proto
.
SerializeWithCachedSizesToArray
(
const_cast
<
uint8_t
*>
(
reinterpret_cast
<
const
uint8_t
*>
(
slice
.
begin
())));
proto
.
SerializeWithCachedSizesToArray
(
const_cast
<
uint8_t
*>
(
slice
.
begin
()));
::
grpc
::
ByteBuffer
tmp
(
&
slice
,
1
);
result
->
Swap
(
&
tmp
);
}
...
...
@@ -109,7 +110,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
auto
call
=
s
->
stub_g_
.
PrepareUnaryCall
(
s
->
context_
.
get
(),
"/sendrecv.SendRecvService/GetVariable"
,
buf
,
&
cq_
);
call
->
StartCall
();
call
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
(
void
*
)
s
);
call
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
static_cast
<
void
*>
(
s
)
);
});
req_count_
++
;
...
...
@@ -153,7 +154,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
s
->
context_
.
get
(),
"/sendrecv.SendRecvService/PrefetchVariable"
,
req
,
&
cq_
);
call
->
StartCall
();
call
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
(
void
*
)
s
);
call
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
static_cast
<
void
*>
(
s
)
);
});
req_count_
++
;
...
...
@@ -169,7 +170,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
BATCH_BARRIER_MESSAGE
);
auto
rpc
=
s
->
stub_
->
AsyncSendVariable
(
s
->
context_
.
get
(),
req
,
&
cq_
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
(
void
*
)
s
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
static_cast
<
void
*>
(
s
)
);
req_count_
++
;
}
...
...
@@ -181,7 +182,7 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
FETCH_BARRIER_MESSAGE
);
auto
rpc
=
s
->
stub_
->
AsyncGetVariable
(
s
->
context_
.
get
(),
req
,
&
cq_
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
(
void
*
)
s
);
rpc
->
Finish
(
&
s
->
reply_
,
&
s
->
status_
,
static_cast
<
void
*>
(
s
)
);
req_count_
++
;
}
...
...
paddle/fluid/operators/detail/grpc_server.cc
浏览文件 @
09fcf5f2
...
...
@@ -14,6 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_server.h"
#include <limits>
#include <string>
using
::
grpc
::
ServerAsyncResponseWriter
;
namespace
paddle
{
...
...
@@ -156,6 +159,8 @@ class RequestPrefetch final : public RequestBase {
::
grpc
::
ByteBuffer
reply
;
// TODO(Yancey1989): execute the Block which containers prefetch ops
VLOG
(
3
)
<<
"RequestPrefetch Process in"
;
responder_
.
Finish
(
reply
,
::
grpc
::
Status
::
OK
,
this
);
status_
=
FINISH
;
}
...
...
@@ -221,6 +226,7 @@ void AsyncGRPCServer::ShutdownQueue() {
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
cq_send_
->
Shutdown
();
cq_get_
->
Shutdown
();
cq_prefetch_
->
Shutdown
();
}
// This URL explains why shutdown is complicate:
...
...
@@ -233,6 +239,7 @@ void AsyncGRPCServer::ShutDown() {
void
AsyncGRPCServer
::
TryToRegisterNewSendOne
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
if
(
is_shut_down_
)
{
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewSendOne"
;
return
;
}
RequestSend
*
send
=
new
RequestSend
(
&
service_
,
cq_send_
.
get
(),
scope_
,
...
...
@@ -243,6 +250,7 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
void
AsyncGRPCServer
::
TryToRegisterNewGetOne
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
if
(
is_shut_down_
)
{
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewGetOne"
;
return
;
}
RequestGet
*
get
=
new
RequestGet
(
&
service_
,
cq_get_
.
get
(),
scope_
,
dev_ctx_
,
...
...
@@ -253,6 +261,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
void
AsyncGRPCServer
::
TryToRegisterNewPrefetchOne
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
if
(
is_shut_down_
)
{
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewPrefetchOne"
;
return
;
}
RequestPrefetch
*
prefetch
=
...
...
@@ -270,25 +279,28 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
void
*
tag
=
NULL
;
bool
ok
=
false
;
while
(
true
)
{
VLOG
(
3
)
<<
"HandleRequest for "
<<
cq_name
<<
" while in"
;
if
(
!
cq
->
Next
(
&
tag
,
&
ok
))
{
LOG
(
INFO
)
<<
cq_name
<<
" CompletionQueue shutdown!"
;
break
;
}
VLOG
(
3
)
<<
"HandleRequest for "
<<
cq_name
<<
" while after Next"
;
PADDLE_ENFORCE
(
tag
);
// FIXME(typhoonzero): de-couple the barriers with recv_op
if
(
!
is_shut_down_
&&
cq_name
==
"cq_get"
)
WaitCond
(
1
);
if
(
!
is_shut_down_
&&
cq_name
==
"cq_send"
)
WaitCond
(
0
);
RequestBase
*
base
=
(
RequestBase
*
)
tag
;
RequestBase
*
base
=
reinterpret_cast
<
RequestBase
*>
(
tag
)
;
// reference:
// https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
// https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
if
(
!
ok
)
{
LOG
(
WARNING
)
<<
cq_name
<<
" recv no regular event:argument name"
<<
base
->
GetReqName
();
LOG
(
WARNING
)
<<
cq_name
<<
" recv no regular event:argument name
[
"
<<
base
->
GetReqName
()
<<
"]"
;
TryToRegisterNewOne
();
delete
base
;
continue
;
...
...
paddle/fluid/operators/detail/grpc_server.h
浏览文件 @
09fcf5f2
...
...
@@ -15,7 +15,8 @@ limitations under the License. */
#pragma once
#include <grpc++/grpc++.h>
#include <thread>
#include <string>
#include <utility>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
...
...
@@ -93,6 +94,7 @@ class AsyncGRPCServer final {
// received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue
<
MessageWithName
>
var_get_queue_
;
// client send variable to this queue.
ReceivedQueue
var_recv_queue_
;
// condition of the sub program
...
...
paddle/fluid/operators/detail/grpc_server_test.cc
浏览文件 @
09fcf5f2
...
...
@@ -28,6 +28,7 @@ std::unique_ptr<detail::AsyncGRPCServer> rpc_service_;
void
StartServer
(
const
std
::
string
&
endpoint
)
{
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
));
rpc_service_
->
RunSyncUpdate
();
}
TEST
(
PREFETCH
,
CPU
)
{
...
...
@@ -39,13 +40,23 @@ TEST(PREFETCH, CPU) {
platform
::
CPUPlace
place
;
platform
::
CPUDeviceContext
ctx
(
place
);
// create var on local scope
std
::
string
var_name
(
"tmp_0"
);
auto
var
=
scope
.
Var
(
var_name
);
auto
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
tensor
->
Resize
({
10
,
10
});
std
::
string
in_var_name
(
"in"
);
std
::
string
out_var_name
(
"out"
);
auto
*
in_var
=
scope
.
Var
(
in_var_name
);
auto
*
in_tensor
=
in_var
->
GetMutable
<
framework
::
LoDTensor
>
();
in_tensor
->
Resize
({
10
,
10
});
VLOG
(
3
)
<<
"before mutable_data"
;
in_tensor
->
mutable_data
<
int
>
(
place
);
scope
.
Var
(
out_var_name
);
VLOG
(
3
)
<<
"before fetch"
;
detail
::
RPCClient
client
;
client
.
AsyncPrefetchVariable
(
"127.0.0.1:8889"
,
ctx
,
scope
,
var_name
,
""
);
client
.
AsyncPrefetchVariable
(
"127.0.0.1:8889"
,
ctx
,
scope
,
in_var_name
,
out_var_name
);
client
.
Wait
();
rpc_service_
->
ShutDown
();
server_thread
.
join
();
rpc_service_
.
reset
(
nullptr
);
}
paddle/fluid/operators/detail/grpc_service.h
浏览文件 @
09fcf5f2
...
...
@@ -80,7 +80,7 @@ enum class GrpcMethod {
};
static
const
int
kGrpcNumMethods
=
static_cast
<
int
>
(
GrpcMethod
::
k
Get
Variable
)
+
1
;
static_cast
<
int
>
(
GrpcMethod
::
k
Prefetch
Variable
)
+
1
;
inline
const
char
*
GrpcMethodName
(
GrpcMethod
id
)
{
switch
(
id
)
{
...
...
@@ -89,7 +89,7 @@ inline const char* GrpcMethodName(GrpcMethod id) {
case
GrpcMethod
::
kGetVariable
:
return
"/sendrecv.SendRecvService/GetVariable"
;
case
GrpcMethod
::
kPrefetchVariable
:
return
"/sendrecv.SendR
E
cvService/PrefetchVariable"
;
return
"/sendrecv.SendR
e
cvService/PrefetchVariable"
;
}
// Shouldn't be reached.
...
...
@@ -117,5 +117,5 @@ class GrpcService final {
};
}
// namespace detail
}
// namespace operator
}
// namespace operator
s
}
// namespace paddle
paddle/fluid/operators/listen_and_serv_op.cc
浏览文件 @
09fcf5f2
...
...
@@ -13,22 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <stdint.h>
#include <sys/stat.h>
#include <ostream>
#include <thread>
#include <unistd.h>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/simple_block_queue.h"
#include "paddle/fluid/string/printf.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -111,6 +102,11 @@ class ListenAndServOp : public framework::OperatorBase {
framework
::
Executor
executor
(
dev_place
);
// TODO(qiao) set proper fields for table lookup and update
rpc_service_
->
SetExecutor
(
&
executor
);
rpc_service_
->
SetPrefetchBlkdId
(
0
);
rpc_service_
->
SetProgram
(
program
);
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool
exit_flag
=
false
;
// Record received sparse variables, so that
...
...
@@ -173,7 +169,8 @@ class ListenAndServOp : public framework::OperatorBase {
}
ParallelExecuteBlocks
(
parallel_blkids
,
&
executor
,
program
,
&
recv_scope
);
VLOG
(
2
)
<<
"run all blocks spent (ms) "
<<
detail
::
GetTimestamp
()
-
ts
;
VLOG
(
3
)
<<
"run all blocks spent "
<<
detail
::
GetTimestamp
()
-
ts
<<
"(ms)"
;
// Reset the received sparse variables, the sum operator would not
// sum the input sparse variables which rows is empty at the next
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录