Unverified commit 03d4665f
Authored on Nov 30, 2020 by 123malin; committed by GitHub on Nov 30, 2020.
prefetch optimize (#29095)
* test=develop, optimize async prefetch
Parent: 7c61ba3a
Showing 4 changed files with 45 additions and 32 deletions (+45 −32)
paddle/fluid/operators/distributed/communicator.cc (+12 −0)
paddle/fluid/operators/distributed/grpc/grpc_client.cc (+31 −30)
paddle/fluid/operators/distributed/grpc/grpc_client.h (+1 −1)
paddle/fluid/operators/distributed/rpc_server.h (+1 −1)
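Read together, the four files implement one optimization: GRPCClient replaces its single Proceed() completion thread with a pool sized by a new rpc_client_threads flag, most of its Async* call sites move from framework::AsyncIO to framework::Async, AsyncCommunicator::SendByCommunicator now pulls a parameter back immediately after sending its @GRAD variable, and RPCServer's default per-RPC thread count drops from 5 to 1.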
paddle/fluid/operators/distributed/communicator.cc
@@ -162,6 +162,18 @@ void AsyncCommunicator::SendByCommunicator() {
       auto after_send = GetCurrentUS();
       VLOG(3) << "send " << var_name << " use time "
               << after_send - after_merge;
+
+      if (var_name.rfind("@GRAD") != var_name.size() - 5) return;
+
+      auto recv_param = var_name.substr(0, var_name.size() - 5);
+      if (recv_varname_to_ctx_.find(recv_param) == recv_varname_to_ctx_.end())
+        return;
+
+      auto recv_functor = distributed::ParameterRecv<float>();
+      recv_functor(recv_varname_to_ctx_.at(recv_param), *recv_scope_);
+      auto after_recv = GetCurrentUS();
+      VLOG(3) << "recv " << recv_param << " use time "
+              << after_recv - after_send;
     };
     task_futures.emplace_back(send_threadpool_->enqueue(std::move(send_task)));
   }
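The added block keys off Paddle's convention that the gradient of parameter <param> is named "<param>@GRAD": only gradient sends trigger an immediate recv of the freshly updated parameter. A standalone sketch of just the string handling (the variable name is hypothetical; this is illustrative, not part of the patch):

// Illustration of the "<param>@GRAD" naming convention the new block relies on.
#include <iostream>
#include <string>

int main() {
  const std::string kSuffix = "@GRAD";
  const std::string var_name = "fc_0.w_0@GRAD";  // hypothetical variable name
  // Same test as the patch: bail out unless the name ends in "@GRAD".
  if (var_name.rfind(kSuffix) != var_name.size() - kSuffix.size()) return 0;
  // Strip the suffix to recover the parameter to pull back from the server.
  const std::string recv_param =
      var_name.substr(0, var_name.size() - kSuffix.size());
  std::cout << recv_param << "\n";  // prints "fc_0.w_0"
  return 0;
}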
paddle/fluid/operators/distributed/grpc/grpc_client.cc
@@ -23,6 +23,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/port.h"
 #include "paddle/fluid/platform/profiler.h"

+DEFINE_int32(rpc_client_threads, 2, "");
 DECLARE_bool(rpc_disable_reuse_port);

 namespace paddle {
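DEFINE_int32 is a gflags definition, so the pool size should be settable at process launch (e.g. --rpc_client_threads=4) without recompiling; the default of 2 doubles the single Proceed thread the old code hard-wired.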
@@ -32,10 +33,11 @@ namespace distributed {
 void GRPCClient::InitImpl() {
   // start the client process thread
   // TODO(wuyi): can make this in a threadpool
-  PADDLE_ENFORCE_EQ(client_thread_ == nullptr, true,
-                    platform::errors::PreconditionNotMet(
-                        "please not re init proceed thread"));
-  client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
+  client_threads_.resize(FLAGS_rpc_client_threads);
+  for (int i = 0; i < FLAGS_rpc_client_threads; i++) {
+    client_threads_[i].reset(
+        new std::thread(std::bind(&GRPCClient::Proceed, this)));
+  }
 }

 void GRPCClient::SendComplete() {
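The new InitImpl()/destructor pair is a standard multi-consumer worker pool over one gRPC completion queue; grpc::CompletionQueue supports multiple threads calling Next() concurrently, which is what makes several Proceed() consumers safe. A minimal self-contained sketch of the same pattern (Workers and proceed are hypothetical names, not Paddle APIs):

// Minimal worker-pool sketch mirroring the new InitImpl()/~GRPCClient() logic:
// N threads all run the same drain loop and are joined on shutdown.
#include <functional>
#include <memory>
#include <thread>
#include <vector>

class Workers {
 public:
  Workers(int n, std::function<void()> proceed) {
    threads_.resize(n);
    for (int i = 0; i < n; i++) {
      // Each worker runs the same drain loop, like GRPCClient::Proceed.
      threads_[i].reset(new std::thread(proceed));
    }
  }
  ~Workers() {
    for (std::size_t i = 0; i < threads_.size(); i++) threads_[i]->join();
  }

 private:
  std::vector<std::unique_ptr<std::thread>> threads_;
};

// Usage (with a hypothetical drain loop):
//   Workers pool(2, [] { /* call cq.Next(...) until shutdown */ });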
@@ -62,7 +64,8 @@ GRPCClient::~GRPCClient() {
     }
     channels_.clear();
   }
-  client_thread_->join();
+  for (size_t i = 0; i < client_threads_.size(); i++)
+    client_threads_[i]->join();
 }

 VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
@@ -84,7 +87,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
   VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
   s->Prepare(h, time_out);

-  framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
+  framework::Async([var_name_val, p_scope, p_ctx, s, method, h, this] {
     auto* var = p_scope->FindVar(var_name_val);

     ::grpc::ByteBuffer req;
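In Paddle's framework/threadpool.h, Async() runs work on the global compute thread pool while AsyncIO() uses a separate IO pool; on that reading, this and the call sites below move request serialization and dispatch off the IO pool, leaving it free for the prefetch path that still uses it.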
@@ -206,8 +209,8 @@ VarHandlePtr GRPCClient::_AsyncGetVar(
   VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
   s->Prepare(h, time_out);

-  framework::AsyncIO([var_name_val, out_varname_val, table_name_val, s, method,
-                      p_ctx, h, rpc_path, this] {
+  framework::Async([var_name_val, out_varname_val, table_name_val, s, method,
+                    p_ctx, h, rpc_path, this] {
     // prepare input
     sendrecv::VariableMessage req;
     req.set_varname(var_name_val);
@@ -273,31 +276,29 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
   VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
   s->Prepare(h, kPrefetchTimeout);

   framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
                       s, method, h, table_name_val, this] {
     auto* var = p_scope->FindVar(in_var_name_val);
     ::grpc::ByteBuffer req;
     SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val,
                           0, table_name_val);

     VLOG(3) << s->GetVarHandlePtr()->String() << " begin";

     // stub context
     s->response_call_back_ = ProcGetResponse;

     platform::RecordRPCEvent record_event(method);

     auto call = s->stub_g_.PrepareUnaryCall(
         s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
         &cq_);
     call->StartCall();
     call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));

     if (UNLIKELY(platform::IsProfileEnabled())) {
       h->Wait();
     }
   });

   req_count_++;

   if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) {
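Unlike the other call sites in this file, AsyncPrefetchVar appears to stay on framework::AsyncIO; every line of its lambda body shows up in the diff as removed and re-added with identical visible text, so the +/- churn in this hunk looks like whitespace-only re-indentation, and the hunk is shown above with those whitespace changes hidden.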
@@ -467,7 +468,7 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify(
   VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
   s->Prepare(h, time_out);

-  framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
+  framework::Async([var_name_val, p_scope, p_ctx, s, method, h, this] {
     auto* var = p_scope->FindVar(var_name_val);

     ::grpc::ByteBuffer req;
@@ -523,8 +524,8 @@ VarHandlePtr GRPCClient::AsyncSendAndRecv(const std::string& ep,
   s->Prepare(h, time_out);
   s->RecvPrepare(h_recv);

-  framework::AsyncIO([send_var_name_val, recv_var_name_val, table_name_val,
-                      p_scope, p_ctx, s, method, h, this] {
+  framework::Async([send_var_name_val, recv_var_name_val, table_name_val,
+                    p_scope, p_ctx, s, method, h, this] {
     auto* send_var = p_scope->FindVar(send_var_name_val);
     send_var->GetMutable<framework::LoDTensor>()->set_lod({});
     ::grpc::ByteBuffer buf;
paddle/fluid/operators/distributed/grpc/grpc_client.h
@@ -297,7 +297,7 @@ class GRPCClient : public RPCClient {
  private:
   grpc::CompletionQueue cq_;
   std::unordered_map<std::string, std::shared_ptr<grpc::Channel>> channels_;
-  std::unique_ptr<std::thread> client_thread_{nullptr};
+  std::vector<std::unique_ptr<std::thread>> client_threads_;

   // mutex for Wait client sync
   std::mutex sync_mutex_;
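This is the storage behind the InitImpl() change above: the single nullable client_thread_ becomes a vector of threads sized by FLAGS_rpc_client_threads.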
paddle/fluid/operators/distributed/rpc_server.h
@@ -85,7 +85,7 @@ class RPCServer {
   // class, and auto generate a condition id for this call
   // to be used for the barrier.
   void RegisterRPC(const std::string& rpc_name, RequestHandler* handler,
-                   int thread_num = 5);
+                   int thread_num = 1);

   int GetThreadNum(const std::string& rpc_name) {
     return rpc_thread_num_[rpc_name];
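The companion server-side tweak: RegisterRPC's default thread_num drops from 5 to 1, so RPCs that need more handler threads must now ask for them explicitly.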