Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
fe8f28c9
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
fe8f28c9
编写于
1月 25, 2019
作者:
G
gongweibao
提交者:
GitHub
1月 25, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add GetVariableNoBarrier on brpc. (#15488)
上级
981fc2bd
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
111 addition
and
24 deletion
+111
-24
paddle/fluid/operators/distributed/CMakeLists.txt
paddle/fluid/operators/distributed/CMakeLists.txt
+5
-3
paddle/fluid/operators/distributed/brpc/brpc_client.cc
paddle/fluid/operators/distributed/brpc/brpc_client.cc
+33
-13
paddle/fluid/operators/distributed/brpc/brpc_client.h
paddle/fluid/operators/distributed/brpc/brpc_client.h
+9
-0
paddle/fluid/operators/distributed/brpc/brpc_server.cc
paddle/fluid/operators/distributed/brpc/brpc_server.cc
+59
-6
paddle/scripts/paddle_build.sh
paddle/scripts/paddle_build.sh
+2
-1
python/paddle/fluid/transpiler/details/checkport.py
python/paddle/fluid/transpiler/details/checkport.py
+3
-1
未找到文件。
paddle/fluid/operators/distributed/CMakeLists.txt
浏览文件 @
fe8f28c9
...
...
@@ -20,7 +20,7 @@ if(WITH_GRPC)
collective_client.cc collective_server.cc
${
GRPC_SRCS
}
PROTO send_recv.proto
DEPS lod_tensor selected_rows_functor memory
)
DEPS lod_tensor selected_rows_functor memory
scope
${
GRPC_DEPS
}
)
set_source_files_properties
(
grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set
(
RPC_DEPS sendrecvop_rpc
${
GRPC_DEPS
}
)
...
...
@@ -32,15 +32,17 @@ else()
set
(
BRPC_SRCS brpc/brpc_client.cc brpc/brpc_server.cc brpc/brpc_sendrecvop_utils.cc brpc/brpc_variable_response.cc brpc/brpc_rdma_pool.cc
)
set_source_files_properties
(
${
BRPC_SRCS
}
parameter_prefetch.cc rpc_server_test.cc brpc/brpc_serde_test.cc collective_server.cc collective_server_test.cc collective_client.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set
(
BRPC_DEPS brpc ssl crypto protobuf leveldb snappystream snappy zlib
)
brpc_library
(
sendrecvop_rpc SRCS sendrecvop_utils.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc
variable_response.cc
collective_client.cc collective_server.cc
${
BRPC_SRCS
}
PROTO send_recv.proto
DEPS lod_tensor selected_rows memory
)
DEPS lod_tensor selected_rows memory
scope
${
BRPC_DEPS
}
)
set
(
RPC_DEPS sendrecvop_rpc
brpc ssl crypto protobuf leveldb snappystream snappy zlib
)
set
(
RPC_DEPS sendrecvop_rpc
${
BRPC_DEPS
}
)
cc_test
(
brpc_serde_test SRCS brpc/brpc_serde_test.cc
DEPS
${
RPC_DEPS
}
gflags glog executor proto_desc lookup_sparse_table_op SERIAL
)
endif
()
...
...
paddle/fluid/operators/distributed/brpc/brpc_client.cc
浏览文件 @
fe8f28c9
...
...
@@ -62,7 +62,7 @@ VarHandlePtr BRPCClient::AsyncSendVar(const std::string& ep,
const
std
::
string
var_name_val
=
var_name
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
auto
ch_ptr
=
GetChannel
(
ep_val
);
const
std
::
string
method
=
"SendRPC"
;
const
std
::
string
method
=
kSendRPC
;
VarHandlePtr
var_h
(
new
VarHandle
(
ep
,
method
,
var_name_val
,
p_ctx
,
p_scope
));
framework
::
AsyncIO
([
=
]
{
...
...
@@ -156,15 +156,18 @@ VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
method_name
,
int64_t
time_out
)
{
const
platform
::
DeviceContext
*
p_ctx
=
&
ctx
;
const
std
::
string
ep_val
=
ep
;
const
std
::
string
var_name_val
=
var_name
;
const
std
::
string
out_varname_val
=
out_var_name
;
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
auto
ch_ptr
=
GetChannel
(
ep_val
);
const
std
::
string
method
=
"GetRPC"
;
VarHandlePtr
var_h
(
new
VarHandle
(
ep
,
method
,
var_name_val
,
p_ctx
,
p_scope
));
const
std
::
string
method
=
kGetRPC
;
VarHandlePtr
var_h
(
new
VarHandle
(
ep
,
method
,
out_varname_val
,
p_ctx
,
p_scope
));
framework
::
AsyncIO
([
=
]
{
auto
ch_ctx
=
ch_ptr
->
Pop
();
...
...
@@ -175,6 +178,7 @@ VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
var_name_val
);
req
.
set_out_varname
(
out_varname_val
);
req
.
set_trainer_id
(
trainer_id_
);
google
::
protobuf
::
Closure
*
done
=
brpc
::
NewCallback
(
...
...
@@ -182,8 +186,10 @@ VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
platform
::
RecordRPCEvent
record_event
(
method
,
p_ctx
);
if
(
method_name
==
"GetMonomerVariable"
)
{
if
(
method_name
==
kGetMonomerRPC
)
{
ch_ctx
->
stub
->
GetMonomerVariable
(
cntl
,
&
req
,
response
,
done
);
}
else
if
(
method_name
==
kGetNoBarrierRPC
)
{
ch_ctx
->
stub
->
GetVariableNoBarrier
(
cntl
,
&
req
,
response
,
done
);
}
else
{
ch_ctx
->
stub
->
GetVariable
(
cntl
,
&
req
,
response
,
done
);
}
...
...
@@ -198,25 +204,39 @@ VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
return
var_h
;
}
VarHandlePtr
BRPCClient
::
AsyncGetVarNoBarrier
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_var_name
,
int64_t
time_out
)
{
std
::
string
var_name_no_barrier
=
string
::
Sprintf
(
"%s%s"
,
var_name
,
WITHOUT_BARRIER_MESSAGE
);
return
_AsyncGetVar
(
ep
,
ctx
,
scope
,
var_name_no_barrier
,
out_var_name
,
kGetNoBarrierRPC
,
time_out
);
}
VarHandlePtr
BRPCClient
::
AsyncGetMonomerVariable
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
)
{
return
_AsyncGetVar
(
ep
,
ctx
,
scope
,
var_name
,
"GetMonomerVariable"
,
time_out
);
return
_AsyncGetVar
(
ep
,
ctx
,
scope
,
var_name
,
var_name
,
kGetMonomerRPC
,
time_out
);
}
VarHandlePtr
BRPCClient
::
AsyncGetMonomerBarrier
(
const
std
::
string
&
ep
,
const
std
::
string
&
var_name
,
int64_t
time_out
)
{
return
AsyncSendMessage
(
ep
,
"GetMonomerBarrier"
,
var_name
,
time_out
);
return
AsyncSendMessage
(
ep
,
kSendMonomerFetchBarrierRPC
,
var_name
,
time_out
);
}
VarHandlePtr
BRPCClient
::
AsyncGetVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_var_name
,
int64_t
time_out
)
{
return
_AsyncGetVar
(
ep
,
ctx
,
scope
,
var_name
,
"GetVariable"
,
time_out
);
return
_AsyncGetVar
(
ep
,
ctx
,
scope
,
var_name
,
out_var_name
,
kGetRPC
,
time_out
);
}
VarHandlePtr
BRPCClient
::
AsyncPrefetchVar
(
const
std
::
string
&
ep
,
...
...
@@ -234,7 +254,7 @@ VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep,
const
framework
::
Scope
*
p_scope
=
&
scope
;
const
auto
ch_ptr
=
GetChannel
(
ep_val
);
const
std
::
string
method
=
"PrefetchRPC"
;
const
std
::
string
method
=
kPrefetchRPC
;
VarHandlePtr
var_h
(
new
VarHandle
(
ep
,
method
,
out_var_name_val
,
p_ctx
,
p_scope
));
...
...
@@ -270,7 +290,7 @@ VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep,
VarHandlePtr
BRPCClient
::
AsyncSendBatchBarrier
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
return
AsyncSendMessage
(
ep
,
"BatchBarrierRPC"
,
BATCH_BARRIER_MESSAGE
,
return
AsyncSendMessage
(
ep
,
kBatchBarrierRPC
,
BATCH_BARRIER_MESSAGE
,
time_out
);
}
...
...
@@ -286,7 +306,7 @@ VarHandlePtr BRPCClient::AsyncSendFetchBarrier(const std::string& ep,
sendrecv
::
VariableMessage
req
;
req
.
set_varname
(
FETCH_BARRIER_MESSAGE
);
const
std
::
string
method
=
"FetchBarrierRPC"
;
const
std
::
string
method
=
kFetchBarrierRPC
;
// var handle
VarHandlePtr
var_h
(
new
VarHandle
(
ep
,
method
,
FETCH_BARRIER_MESSAGE
,
nullptr
,
nullptr
));
...
...
@@ -367,7 +387,7 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
VarHandlePtr
BRPCClient
::
AsyncSendComplete
(
const
std
::
string
&
ep
,
int64_t
time_out
)
{
return
AsyncSendMessage
(
ep
,
"SendCompleteRPC"
,
COMPLETE_MESSAGE
,
time_out
);
return
AsyncSendMessage
(
ep
,
kSendCompleteRPC
,
COMPLETE_MESSAGE
,
time_out
);
}
void
BRPCClient
::
SendComplete
()
{
...
...
@@ -394,9 +414,9 @@ VarHandlePtr BRPCClient::AsyncSendVarMessage(
google
::
protobuf
::
Closure
*
done
=
brpc
::
NewCallback
(
&
HandleSendResponse
,
cntl
,
response
,
var_h
,
ch_ptr
,
ch_ctx
,
this
);
if
(
method_name
==
"CheckPointNotifyRPC"
)
{
if
(
method_name
==
kCheckPointNotifyRPC
)
{
ch_ctx
->
stub
->
CheckpointNotify
(
cntl
,
&
req
,
response
,
done
);
}
else
if
(
method_name
==
"GetMonomerBarrier"
)
{
}
else
if
(
method_name
==
kSendMonomerFetchBarrierRPC
)
{
ch_ctx
->
stub
->
GetMonomerBarrier
(
cntl
,
&
req
,
response
,
done
);
}
else
{
ch_ctx
->
stub
->
SendVariable
(
cntl
,
&
req
,
response
,
done
);
...
...
paddle/fluid/operators/distributed/brpc/brpc_client.h
浏览文件 @
fe8f28c9
...
...
@@ -65,6 +65,7 @@ class BRPCClient : public RPCClient {
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_var_name
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
VarHandlePtr
AsyncGetMonomerBarrier
(
...
...
@@ -76,6 +77,13 @@ class BRPCClient : public RPCClient {
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
int64_t
time_out
=
FLAGS_rpc_deadline
)
override
;
VarHandlePtr
AsyncGetVarNoBarrier
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_varname
,
int64_t
time_out
=
FLAGS_rpc_deadline
);
VarHandlePtr
AsyncPrefetchVar
(
const
std
::
string
&
ep
,
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
...
...
@@ -103,6 +111,7 @@ class BRPCClient : public RPCClient {
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Scope
&
scope
,
const
std
::
string
&
var_name
,
const
std
::
string
&
out_var_name
,
const
std
::
string
&
method_name
,
int64_t
time_out
=
FLAGS_rpc_deadline
);
...
...
paddle/fluid/operators/distributed/brpc/brpc_server.cc
浏览文件 @
fe8f28c9
...
...
@@ -45,6 +45,13 @@ class BRPCServiceImpl : public SendRecvService {
rpc_server_
->
GetThreadNum
(
distributed
::
kRequestGet
)));
}
it
=
rpc_call_map
.
find
(
distributed
::
kRequestGetNoBarrier
);
if
(
it
!=
rpc_call_map
.
end
())
{
request_getnobarrier_h_
=
it
->
second
;
getnobarrier_threads_
.
reset
(
new
paddle
::
framework
::
ThreadPool
(
rpc_server_
->
GetThreadNum
(
distributed
::
kRequestGetNoBarrier
)));
}
it
=
rpc_call_map
.
find
(
distributed
::
kRequestPrefetch
);
if
(
it
!=
rpc_call_map
.
end
())
{
request_prefetch_h_
=
it
->
second
;
...
...
@@ -112,6 +119,14 @@ class BRPCServiceImpl : public SendRecvService {
[
=
]
{
_GetVariable
(
cntl_butil
,
request
,
response
,
done
);
});
}
void
GetVariableNoBarrier
(
google
::
protobuf
::
RpcController
*
cntl_butil
,
const
VariableMessage
*
request
,
VariableMessage
*
response
,
google
::
protobuf
::
Closure
*
done
)
override
{
getnobarrier_threads_
->
Run
(
[
=
]
{
_GetVariableNoBarrier
(
cntl_butil
,
request
,
response
,
done
);
});
}
void
_GetVariable
(
google
::
protobuf
::
RpcController
*
cntl_butil
,
const
VariableMessage
*
request
,
VariableMessage
*
response
,
google
::
protobuf
::
Closure
*
done
)
{
...
...
@@ -122,23 +137,59 @@ class BRPCServiceImpl : public SendRecvService {
brpc
::
Controller
*
cntl
=
static_cast
<
brpc
::
Controller
*>
(
cntl_butil
);
std
::
string
varname
=
request
->
varname
();
std
::
string
out_varname
=
request
->
out_varname
();
VLOG
(
3
)
<<
"RequestGet varname:"
<<
varname
<<
", out_varname:"
<<
out_varname
<<
", trainer_id:"
<<
request
->
trainer_id
()
<<
", from:"
<<
cntl
->
remote_side
();
auto
scope
=
request_get_h_
->
scope
();
auto
invar
=
scope
->
FindVar
(
varname
);
paddle
::
framework
::
Variable
*
invar
=
nullptr
;
int
trainer_id
=
request
->
trainer_id
();
paddle
::
framework
::
Variable
*
outvar
=
nullptr
;
request_get_h_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
,
trainer_id
,
out_varname
);
if
(
outvar
)
{
distributed
::
SerializeToIOBuf
(
out_varname
,
outvar
,
*
request_get_h_
->
dev_ctx
(),
response
,
&
cntl
->
response_attachment
(),
""
,
false
);
}
}
void
_GetVariableNoBarrier
(
google
::
protobuf
::
RpcController
*
cntl_butil
,
const
VariableMessage
*
request
,
VariableMessage
*
response
,
google
::
protobuf
::
Closure
*
done
)
{
PADDLE_ENFORCE
(
request_getnobarrier_h_
!=
nullptr
,
"RequestGetNoBarrier handler should be registed first!"
);
brpc
::
ClosureGuard
done_guard
(
done
);
brpc
::
Controller
*
cntl
=
static_cast
<
brpc
::
Controller
*>
(
cntl_butil
);
std
::
string
varname
=
request
->
varname
();
std
::
string
out_varname
=
request
->
out_varname
();
int
trainer_id
=
request
->
trainer_id
();
VLOG
(
3
)
<<
"RequestGetNoBarrier varname:"
<<
varname
<<
", out_varname:"
<<
out_varname
<<
", trainer_id:"
<<
trainer_id
<<
", from:"
<<
cntl
->
remote_side
();
auto
scope
=
request_getnobarrier_h_
->
scope
();
paddle
::
framework
::
Variable
*
invar
=
nullptr
;
paddle
::
framework
::
Variable
*
outvar
=
nullptr
;
request_get_h_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
,
trainer_id
);
request_getnobarrier_h_
->
Handle
(
varname
,
scope
,
invar
,
&
outvar
,
trainer_id
,
out_varname
);
if
(
outvar
)
{
distributed
::
SerializeToIOBuf
(
varname
,
outvar
,
*
request_get_h_
->
dev_ctx
(),
response
,
&
cntl
->
response_attachment
(),
""
,
false
);
distributed
::
SerializeToIOBuf
(
out_varname
,
outvar
,
*
request_getnobarrier_h_
->
dev_ctx
(),
response
,
&
cntl
->
response_attachment
(),
""
,
false
);
}
}
void
PrefetchVariable
(
google
::
protobuf
::
RpcController
*
cntl_butil
,
const
VariableMessage
*
request
,
VariableMessage
*
response
,
...
...
@@ -282,6 +333,7 @@ class BRPCServiceImpl : public SendRecvService {
private:
distributed
::
RequestHandler
*
request_send_h_
{
nullptr
};
distributed
::
RequestHandler
*
request_get_h_
{
nullptr
};
distributed
::
RequestHandler
*
request_getnobarrier_h_
{
nullptr
};
distributed
::
RequestHandler
*
request_prefetch_h_
{
nullptr
};
distributed
::
RequestHandler
*
request_checkpoint_h_
{
nullptr
};
distributed
::
RequestHandler
*
request_get_monomer_handler_h_
{
nullptr
};
...
...
@@ -289,9 +341,10 @@ class BRPCServiceImpl : public SendRecvService {
distributed
::
RPCServer
*
rpc_server_
{
nullptr
};
// FIXME(gongwb): brpc should support process one rpc
e
use one threadpool.
// FIXME(gongwb): brpc should support process one rpc use one threadpool.
std
::
unique_ptr
<
paddle
::
framework
::
ThreadPool
>
send_threads_
;
std
::
unique_ptr
<
paddle
::
framework
::
ThreadPool
>
get_threads_
;
std
::
unique_ptr
<
paddle
::
framework
::
ThreadPool
>
getnobarrier_threads_
;
std
::
unique_ptr
<
paddle
::
framework
::
ThreadPool
>
prefetch_threads_
;
std
::
unique_ptr
<
paddle
::
framework
::
ThreadPool
>
checkpoint_notify_threads_
;
};
...
...
paddle/scripts/paddle_build.sh
浏览文件 @
fe8f28c9
...
...
@@ -328,7 +328,8 @@ function run_brpc_test() {
========================================
EOF
set
+x
declare
-a
other_tests
=(
"test_listen_and_serv_op"
"system_allocator_test"
)
declare
-a
other_tests
=(
"test_listen_and_serv_op"
"system_allocator_test"
\
"rpc_server_test"
"varhandle_test"
"collective_server_test"
"brpc_serde_test"
)
all_tests
=
`
ctest
-N
`
for
t
in
"
${
other_tests
[@]
}
"
...
...
python/paddle/fluid/transpiler/details/checkport.py
浏览文件 @
fe8f28c9
...
...
@@ -16,6 +16,7 @@ import sys
import
time
import
socket
from
contextlib
import
closing
from
six
import
string_types
def
wait_server_ready
(
endpoints
):
...
...
@@ -32,6 +33,7 @@ def wait_server_ready(endpoints):
wait_server_ready(["127.0.0.1:8080", "127.0.0.1:8081"])
"""
assert
not
isinstance
(
endpoints
,
string_types
)
while
True
:
all_ok
=
True
not_ready_endpoints
=
[]
...
...
@@ -45,7 +47,7 @@ def wait_server_ready(endpoints):
all_ok
=
False
not_ready_endpoints
.
append
(
ep
)
if
not
all_ok
:
sys
.
stderr
.
write
(
"
p
server not ready, wait 3 sec to retry...
\n
"
)
sys
.
stderr
.
write
(
"server not ready, wait 3 sec to retry...
\n
"
)
sys
.
stderr
.
write
(
"not ready endpoints:"
+
str
(
not_ready_endpoints
)
+
"
\n
"
)
sys
.
stderr
.
flush
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录