Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
0d598cf9
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0d598cf9
编写于
5月 23, 2018
作者:
X
Xin Pan
提交者:
GitHub
5月 23, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #10822 from panyx0718/dist_opt
multi-thread handlerequest
上级
397a69d9
2643868c
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
169 addition
and
85 deletion
+169
-85
benchmark/cluster/vgg16/vgg16_fluid.py
benchmark/cluster/vgg16/vgg16_fluid.py
+14
-10
paddle/fluid/operators/detail/grpc_client.cc
paddle/fluid/operators/detail/grpc_client.cc
+7
-1
paddle/fluid/operators/detail/grpc_server.cc
paddle/fluid/operators/detail/grpc_server.cc
+110
-59
paddle/fluid/operators/detail/grpc_server.h
paddle/fluid/operators/detail/grpc_server.h
+18
-6
paddle/fluid/operators/detail/grpc_service.h
paddle/fluid/operators/detail/grpc_service.h
+2
-0
paddle/fluid/operators/detail/send_recv.proto
paddle/fluid/operators/detail/send_recv.proto
+3
-3
paddle/fluid/operators/detail/sendrecvop_utils.cc
paddle/fluid/operators/detail/sendrecvop_utils.cc
+7
-1
paddle/fluid/operators/detail/variable_response.cc
paddle/fluid/operators/detail/variable_response.cc
+6
-4
paddle/fluid/platform/device_tracer.cc
paddle/fluid/platform/device_tracer.cc
+0
-1
paddle/fluid/platform/profiler.h
paddle/fluid/platform/profiler.h
+2
-0
未找到文件。
benchmark/cluster/vgg16/vgg16_fluid.py
浏览文件 @
0d598cf9
...
...
@@ -38,7 +38,7 @@ def str2bool(v):
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
1
28
,
help
=
"Batch size for training."
)
'--batch_size'
,
type
=
int
,
default
=
1
6
,
help
=
"Batch size for training."
)
parser
.
add_argument
(
'--learning_rate'
,
type
=
float
,
...
...
@@ -61,7 +61,7 @@ parser.add_argument(
parser
.
add_argument
(
'--data_set'
,
type
=
str
,
default
=
'
cifar10
'
,
default
=
'
flowers
'
,
choices
=
[
'cifar10'
,
'flowers'
],
help
=
'Optional dataset for benchmark.'
)
parser
.
add_argument
(
...
...
@@ -200,26 +200,30 @@ def main():
fetch_list
=
[
avg_cost
,
batch_acc
,
batch_size
])
return
loss
,
acc
,
b_size
if
args
.
profile
and
args
.
task_index
==
0
:
# warmup.
for
batch_id
,
data
in
enumerate
(
train_reader
()):
if
batch_id
>
5
:
break
run_step
(
batch_id
,
data
)
with
profiler
.
profiler
(
'All'
,
'total'
,
'/tmp/profile_vgg'
):
if
args
.
profile
:
with
profiler
.
profiler
(
'All'
,
'total'
,
'/tmp/profile_vgg_%d'
%
args
.
task_index
):
for
batch_id
,
data
in
enumerate
(
train_reader
()):
if
batch_id
>
5
:
break
run_step
(
batch_id
,
data
)
total_time
=
0.0
count
=
0
for
batch_id
,
data
in
enumerate
(
train_reader
()):
ts
=
time
.
time
()
loss
,
acc
,
b_size
=
run_step
(
batch_id
,
data
)
iters
+=
1
num_samples
+=
len
(
data
)
train_pass_acc
.
add
(
value
=
acc
,
weight
=
b_size
)
duration
=
time
.
time
()
-
ts
total_time
+=
duration
count
+=
len
(
data
)
print
(
"Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, "
"Speed = %.2f img/s"
%
(
pass_id
,
iters
,
loss
,
acc
,
len
(
data
)
/
(
time
.
time
()
-
ts
))
"Speed = %.2f (%.2f) img/s"
%
(
pass_id
,
iters
,
loss
,
acc
,
len
(
data
)
/
duration
,
count
/
total_time
)
)
# The accuracy is the accumulation of batches, but not the current batch.
pass_elapsed
=
time
.
time
()
-
start_time
...
...
paddle/fluid/operators/detail/grpc_client.cc
浏览文件 @
0d598cf9
...
...
@@ -19,6 +19,7 @@ limitations under the License. */
#include <limits>
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -196,9 +197,14 @@ bool RPCClient::Wait() {
const
size_t
kReqCnt
=
req_count_
;
bool
a
[
kReqCnt
];
std
::
vector
<
std
::
future
<
void
>>
waits
(
req_count_
);
std
::
mutex
mu
;
for
(
int
i
=
0
;
i
<
req_count_
;
i
++
)
{
waits
[
i
]
=
framework
::
AsyncIO
([
i
,
&
a
,
this
]
{
a
[
i
]
=
Proceed
();
});
waits
[
i
]
=
framework
::
AsyncIO
([
i
,
&
a
,
&
mu
,
this
]
{
bool
ret
=
Proceed
();
std
::
lock_guard
<
std
::
mutex
>
l
(
mu
);
a
[
i
]
=
ret
;
});
}
for
(
int
i
=
0
;
i
<
req_count_
;
i
++
)
{
...
...
paddle/fluid/operators/detail/grpc_server.cc
浏览文件 @
0d598cf9
...
...
@@ -19,10 +19,16 @@ limitations under the License. */
using
::
grpc
::
ServerAsyncResponseWriter
;
DEFINE_int32
(
rpc_server_handle_send_threads
,
20
,
"Number of threads used to handle send at rpc server."
);
DEFINE_int32
(
rpc_server_handle_get_threads
,
20
,
"Number of threads used to handle get at rpc server."
);
DEFINE_int32
(
rpc_server_handle_prefetch_threads
,
1
,
"Number of threads used to handle prefetch at rpc server."
);
namespace
paddle
{
namespace
operators
{
namespace
detail
{
enum
CallStatus
{
PROCESS
=
0
,
FINISH
};
// reference:
...
...
@@ -63,18 +69,20 @@ class RequestSend final : public RequestBase {
explicit
RequestSend
(
GrpcService
::
AsyncService
*
service
,
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
framework
::
Scope
*
scope
,
ReceivedQueue
*
queue
,
const
platform
::
DeviceContext
*
dev_ctx
)
const
platform
::
DeviceContext
*
dev_ctx
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
queue_
(
queue
),
responder_
(
&
ctx_
)
{
responder_
(
&
ctx_
),
req_id_
(
req_id
)
{
if
(
sync_mode_
)
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
false
));
}
else
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
true
));
}
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kSendVariable
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
this
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id
)));
}
virtual
~
RequestSend
()
{}
...
...
@@ -86,15 +94,17 @@ class RequestSend final : public RequestBase {
VLOG
(
3
)
<<
"RequestSend "
<<
var_name
;
queue_
->
Push
(
std
::
make_pair
(
var_name
,
request_
));
sendrecv
::
VoidMessage
reply
;
responder_
.
Finish
(
reply
,
::
grpc
::
Status
::
OK
,
this
);
status_
=
FINISH
;
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id_
)));
}
protected:
sendrecv
::
VoidMessage
reply_
;
std
::
shared_ptr
<
VariableResponse
>
request_
;
ReceivedQueue
*
queue_
;
ServerAsyncResponseWriter
<
sendrecv
::
VoidMessage
>
responder_
;
int
req_id_
;
};
class
RequestGet
final
:
public
RequestBase
{
...
...
@@ -103,14 +113,17 @@ class RequestGet final : public RequestBase {
::
grpc
::
ServerCompletionQueue
*
cq
,
bool
sync_mode
,
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
*
dev_ctx
,
framework
::
BlockingQueue
<
MessageWithName
>*
queue
)
framework
::
BlockingQueue
<
MessageWithName
>*
queue
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
responder_
(
&
ctx_
),
scope_
(
scope
),
queue_
(
queue
)
{
queue_
(
queue
),
req_id_
(
req_id
)
{
auto
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kGetVariable
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
&
request_
,
&
responder_
,
cq_
,
cq_
,
this
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
&
request_
,
&
responder_
,
cq_
,
cq_
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id_
)));
}
virtual
~
RequestGet
()
{}
...
...
@@ -123,13 +136,13 @@ class RequestGet final : public RequestBase {
VLOG
(
3
)
<<
"RequestGet "
<<
var_name
;
auto
*
var
=
scope_
->
FindVar
(
var_name
);
::
grpc
::
ByteBuffer
reply
;
if
(
var_name
!=
FETCH_BARRIER_MESSAGE
)
{
SerializeToByteBuffer
(
var_name
,
var
,
*
dev_ctx_
,
&
reply
);
SerializeToByteBuffer
(
var_name
,
var
,
*
dev_ctx_
,
&
reply
_
);
}
responder_
.
Finish
(
reply
,
::
grpc
::
Status
::
OK
,
this
);
status_
=
FINISH
;
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id_
)));
if
(
var_name
==
FETCH_BARRIER_MESSAGE
)
{
sendrecv
::
VariableMessage
msg
;
...
...
@@ -140,9 +153,11 @@ class RequestGet final : public RequestBase {
protected:
sendrecv
::
VariableMessage
request_
;
::
grpc
::
ByteBuffer
reply_
;
ServerAsyncResponseWriter
<::
grpc
::
ByteBuffer
>
responder_
;
framework
::
Scope
*
scope_
;
framework
::
BlockingQueue
<
MessageWithName
>*
queue_
;
int
req_id_
;
};
class
RequestPrefetch
final
:
public
RequestBase
{
...
...
@@ -153,21 +168,24 @@ class RequestPrefetch final : public RequestBase {
const
platform
::
DeviceContext
*
dev_ctx
,
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
ExecutorPrepareContext
*
prefetch_ctx
)
framework
::
ExecutorPrepareContext
*
prefetch_ctx
,
int
req_id
)
:
RequestBase
(
service
,
cq
,
sync_mode
,
dev_ctx
),
responder_
(
&
ctx_
),
scope_
(
scope
),
executor_
(
executor
),
program_
(
program
),
prefetch_ctx_
(
prefetch_ctx
)
{
prefetch_ctx_
(
prefetch_ctx
),
req_id_
(
req_id
)
{
if
(
sync_mode_
)
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
false
));
}
else
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
true
));
}
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kPrefetchVariable
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
this
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id_
)));
}
virtual
~
RequestPrefetch
()
{}
...
...
@@ -176,7 +194,6 @@ class RequestPrefetch final : public RequestBase {
virtual
void
Process
()
{
// prefetch process...
::
grpc
::
ByteBuffer
reply
;
std
::
string
var_name
=
request_
->
OutVarname
();
VLOG
(
3
)
<<
"RequestPrefetch "
<<
var_name
;
...
...
@@ -186,19 +203,22 @@ class RequestPrefetch final : public RequestBase {
InitializeVariable
(
var
,
var_desc
->
GetType
());
executor_
->
RunPreparedContext
(
prefetch_ctx_
,
scope_
);
SerializeToByteBuffer
(
var_name
,
var
,
*
dev_ctx_
,
&
reply
);
SerializeToByteBuffer
(
var_name
,
var
,
*
dev_ctx_
,
&
reply
_
);
responder_
.
Finish
(
reply
,
::
grpc
::
Status
::
OK
,
this
);
status_
=
FINISH
;
responder_
.
Finish
(
reply_
,
::
grpc
::
Status
::
OK
,
reinterpret_cast
<
void
*>
(
static_cast
<
intptr_t
>
(
req_id_
)));
}
protected:
std
::
shared_ptr
<
VariableResponse
>
request_
;
::
grpc
::
ByteBuffer
reply_
;
ServerAsyncResponseWriter
<::
grpc
::
ByteBuffer
>
responder_
;
framework
::
Scope
*
scope_
;
framework
::
Executor
*
executor_
;
framework
::
ProgramDesc
*
program_
;
framework
::
ExecutorPrepareContext
*
prefetch_ctx_
;
int
req_id_
;
};
void
AsyncGRPCServer
::
WaitClientGet
(
int
count
)
{
...
...
@@ -232,24 +252,39 @@ void AsyncGRPCServer::RunSyncUpdate() {
LOG
(
INFO
)
<<
"Server listening on "
<<
address_
<<
" selected port: "
<<
selected_port_
;
std
::
function
<
void
()
>
send_register
=
std
::
bind
(
&
AsyncGRPCServer
::
TryToRegisterNewSendOne
,
this
);
std
::
function
<
void
()
>
get_register
=
std
::
bind
(
&
AsyncGRPCServer
::
TryToRegisterNewGetOne
,
this
);
std
::
function
<
void
()
>
prefetch_register
=
std
::
bind
(
&
AsyncGRPCServer
::
TryToRegisterNewPrefetchOne
,
this
);
// TODO(wuyi): Run these "HandleRequest" in thread pool
t_send_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
AsyncGRPCServer
::
HandleRequest
,
this
,
cq_send_
.
get
(),
"cq_send"
,
send_register
)));
t_get_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
AsyncGRPCServer
::
HandleRequest
,
this
,
cq_get_
.
get
(),
"cq_get"
,
get_register
)));
t_prefetch_
.
reset
(
new
std
::
thread
(
std
::
bind
(
&
AsyncGRPCServer
::
HandleRequest
,
this
,
cq_prefetch_
.
get
(),
"cq_prefetch"
,
prefetch_register
)));
std
::
function
<
void
(
int
)
>
send_register
=
std
::
bind
(
&
AsyncGRPCServer
::
TryToRegisterNewSendOne
,
this
,
std
::
placeholders
::
_1
);
std
::
function
<
void
(
int
)
>
get_register
=
std
::
bind
(
&
AsyncGRPCServer
::
TryToRegisterNewGetOne
,
this
,
std
::
placeholders
::
_1
);
std
::
function
<
void
(
int
)
>
prefetch_register
=
std
::
bind
(
&
AsyncGRPCServer
::
TryToRegisterNewPrefetchOne
,
this
,
std
::
placeholders
::
_1
);
for
(
int
i
=
0
;
i
<
kSendReqsBufSize
;
++
i
)
{
TryToRegisterNewSendOne
(
i
);
}
for
(
int
i
=
0
;
i
<
kGetReqsBufSize
;
++
i
)
{
TryToRegisterNewGetOne
(
i
);
}
for
(
int
i
=
0
;
i
<
kPrefetchReqsBufSize
;
++
i
)
{
TryToRegisterNewPrefetchOne
(
i
);
}
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_send_threads
;
++
i
)
{
t_sends_
.
emplace_back
(
new
std
::
thread
(
std
::
bind
(
&
AsyncGRPCServer
::
HandleRequest
,
this
,
cq_send_
.
get
(),
"cq_send"
,
send_register
)));
}
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_get_threads
;
++
i
)
{
t_gets_
.
emplace_back
(
new
std
::
thread
(
std
::
bind
(
&
AsyncGRPCServer
::
HandleRequest
,
this
,
cq_get_
.
get
(),
"cq_get"
,
get_register
)));
}
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_prefetch_threads
;
++
i
)
{
t_prefetchs_
.
emplace_back
(
new
std
::
thread
(
std
::
bind
(
&
AsyncGRPCServer
::
HandleRequest
,
this
,
cq_prefetch_
.
get
(),
"cq_prefetch"
,
prefetch_register
)));
}
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mutex_ready_
);
ready_
=
1
;
...
...
@@ -257,9 +292,15 @@ void AsyncGRPCServer::RunSyncUpdate() {
condition_ready_
.
notify_all
();
// wait server
server_
->
Wait
();
t_send_
->
join
();
t_get_
->
join
();
t_prefetch_
->
join
();
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_send_threads
;
++
i
)
{
t_sends_
[
i
]
->
join
();
}
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_get_threads
;
++
i
)
{
t_gets_
[
i
]
->
join
();
}
for
(
int
i
=
0
;
i
<
FLAGS_rpc_server_handle_prefetch_threads
;
++
i
)
{
t_prefetchs_
[
i
]
->
join
();
}
}
void
AsyncGRPCServer
::
ShutdownQueue
()
{
...
...
@@ -276,47 +317,48 @@ void AsyncGRPCServer::ShutDown() {
server_
->
Shutdown
();
}
void
AsyncGRPCServer
::
TryToRegisterNewSendOne
()
{
void
AsyncGRPCServer
::
TryToRegisterNewSendOne
(
int
i
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
if
(
is_shut_down_
)
{
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewSendOne"
;
return
;
}
RequestSend
*
send
=
new
RequestSend
(
&
service_
,
cq_send_
.
get
(),
sync_mode_
,
scope_
,
&
var_recv_queue_
,
dev_ctx_
);
scope_
,
&
var_recv_queue_
,
dev_ctx_
,
i
);
send_reqs_
[
i
]
=
static_cast
<
RequestBase
*>
(
send
);
VLOG
(
4
)
<<
"Create RequestSend status:"
<<
send
->
Status
();
}
void
AsyncGRPCServer
::
TryToRegisterNewGetOne
()
{
void
AsyncGRPCServer
::
TryToRegisterNewGetOne
(
int
req_id
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
if
(
is_shut_down_
)
{
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewGetOne"
;
return
;
}
RequestGet
*
get
=
new
RequestGet
(
&
service_
,
cq_get_
.
get
(),
sync_mode_
,
scope_
,
dev_ctx_
,
&
var_get_queue_
);
dev_ctx_
,
&
var_get_queue_
,
req_id
);
get_reqs_
[
req_id
]
=
static_cast
<
RequestBase
*>
(
get
);
VLOG
(
4
)
<<
"Create RequestGet status:"
<<
get
->
Status
();
}
void
AsyncGRPCServer
::
TryToRegisterNewPrefetchOne
()
{
void
AsyncGRPCServer
::
TryToRegisterNewPrefetchOne
(
int
req_id
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
cq_mutex_
);
if
(
is_shut_down_
)
{
VLOG
(
3
)
<<
"shutdown, do not TryToRegisterNewPrefetchOne"
;
return
;
}
RequestPrefetch
*
prefetch
=
new
RequestPrefetch
(
&
service_
,
cq_prefetch_
.
get
(),
sync_mode_
,
scope_
,
dev_ctx_
,
executor_
,
program_
,
prefetch_ctx_
.
get
());
RequestPrefetch
*
prefetch
=
new
RequestPrefetch
(
&
service_
,
cq_prefetch_
.
get
(),
sync_mode_
,
scope_
,
dev_ctx_
,
executor_
,
program_
,
prefetch_ctx_
.
get
(),
req_id
);
prefetch_reqs_
[
req_id
]
=
static_cast
<
RequestBase
*>
(
prefetch
);
VLOG
(
4
)
<<
"Create RequestPrefetch status:"
<<
prefetch
->
Status
();
}
// FIXME(typhoonzero): change cq_name to enum.
void
AsyncGRPCServer
::
HandleRequest
(
::
grpc
::
ServerCompletionQueue
*
cq
,
const
std
::
string
&
cq_name
,
std
::
function
<
void
()
>
TryToRegisterNewOne
)
{
TryToRegisterNewOne
();
void
AsyncGRPCServer
::
HandleRequest
(
::
grpc
::
ServerCompletionQueue
*
cq
,
const
std
::
string
&
cq_name
,
std
::
function
<
void
(
int
)
>
TryToRegisterNewOne
)
{
void
*
tag
=
NULL
;
bool
ok
=
false
;
...
...
@@ -327,8 +369,7 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
break
;
}
VLOG
(
3
)
<<
"HandleRequest for "
<<
cq_name
<<
" get Next"
;
PADDLE_ENFORCE
(
tag
);
int
req_id
=
static_cast
<
int
>
(
reinterpret_cast
<
intptr_t
>
(
tag
));
if
(
sync_mode_
)
{
// FIXME(typhoonzero): de-couple the barriers with recv_op
...
...
@@ -337,7 +378,17 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
VLOG
(
3
)
<<
"HandleRequest for "
<<
cq_name
<<
" after WaitCond"
;
}
RequestBase
*
base
=
reinterpret_cast
<
RequestBase
*>
(
tag
);
RequestBase
*
base
=
nullptr
;
{
std
::
lock_guard
<
std
::
mutex
>
l
(
cq_mutex_
);
if
(
cq_name
==
"cq_get"
)
{
base
=
get_reqs_
[
req_id
];
}
else
if
(
cq_name
==
"cq_send"
)
{
base
=
send_reqs_
[
req_id
];
}
else
if
(
cq_name
==
"cq_prefetch"
)
{
base
=
prefetch_reqs_
[
req_id
];
}
}
// reference:
// https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
...
...
@@ -345,19 +396,19 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
if
(
!
ok
)
{
LOG
(
WARNING
)
<<
cq_name
<<
" recv no regular event:argument name["
<<
base
->
GetReqName
()
<<
"]"
;
TryToRegisterNewOne
();
TryToRegisterNewOne
(
req_id
);
delete
base
;
continue
;
}
switch
(
base
->
Status
())
{
case
PROCESS
:
{
TryToRegisterNewOne
();
base
->
Process
();
VLOG
(
4
)
<<
cq_name
<<
" PROCESS status:"
<<
base
->
Status
();
break
;
}
case
FINISH
:
{
TryToRegisterNewOne
(
req_id
);
VLOG
(
4
)
<<
cq_name
<<
" FINISH status:"
<<
base
->
Status
();
delete
base
;
break
;
...
...
paddle/fluid/operators/detail/grpc_server.h
浏览文件 @
0d598cf9
...
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <string>
#include <thread> // NOLINT
#include <utility>
#include <vector>
#include "grpc++/grpc++.h"
#include "paddle/fluid/framework/blocking_queue.h"
...
...
@@ -30,6 +31,7 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/platform/profiler.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -82,19 +84,27 @@ class AsyncGRPCServer final {
protected:
void
HandleRequest
(
::
grpc
::
ServerCompletionQueue
*
cq
,
const
std
::
string
&
cq_name
,
std
::
function
<
void
()
>
TryToRegisterNewOne
);
void
TryToRegisterNewSendOne
();
void
TryToRegisterNewGetOne
();
void
TryToRegisterNewPrefetchOne
();
std
::
function
<
void
(
int
)
>
TryToRegisterNewOne
);
void
TryToRegisterNewSendOne
(
int
req_id
);
void
TryToRegisterNewGetOne
(
int
req_id
);
void
TryToRegisterNewPrefetchOne
(
int
req_id
);
void
ShutdownQueue
();
private:
static
const
int
kSendReqsBufSize
=
100
;
static
const
int
kGetReqsBufSize
=
100
;
static
const
int
kPrefetchReqsBufSize
=
10
;
std
::
mutex
cq_mutex_
;
volatile
bool
is_shut_down_
=
false
;
std
::
unique_ptr
<::
grpc
::
ServerCompletionQueue
>
cq_send_
;
std
::
unique_ptr
<::
grpc
::
ServerCompletionQueue
>
cq_get_
;
std
::
unique_ptr
<::
grpc
::
ServerCompletionQueue
>
cq_prefetch_
;
RequestBase
*
send_reqs_
[
kSendReqsBufSize
];
RequestBase
*
get_reqs_
[
kGetReqsBufSize
];
RequestBase
*
prefetch_reqs_
[
kPrefetchReqsBufSize
];
GrpcService
::
AsyncService
service_
;
std
::
unique_ptr
<::
grpc
::
Server
>
server_
;
...
...
@@ -113,8 +123,10 @@ class AsyncGRPCServer final {
mutable
int
barrier_cond_step_
;
std
::
condition_variable
barrier_condition_
;
std
::
unique_ptr
<
std
::
thread
>
t_send_
;
std
::
unique_ptr
<
std
::
thread
>
t_get_
;
std
::
vector
<
std
::
unique_ptr
<
std
::
thread
>>
t_sends_
;
std
::
vector
<
std
::
unique_ptr
<
std
::
thread
>>
t_gets_
;
std
::
vector
<
std
::
unique_ptr
<
std
::
thread
>>
t_prefetchs_
;
std
::
unique_ptr
<
std
::
thread
>
t_prefetch_
;
std
::
unique_ptr
<
framework
::
ExecutorPrepareContext
>
prefetch_ctx_
;
...
...
paddle/fluid/operators/detail/grpc_service.h
浏览文件 @
0d598cf9
...
...
@@ -25,6 +25,8 @@
#include <grpc++/support/byte_buffer.h>
#include "paddle/fluid/operators/detail/variable_response.h"
#include "paddle/fluid/platform/profiler.h"
// NOTE: This method was originally created by tensorflow
// (https://github.com/tensorflow/tensorflow/) we borrow this
// method and did some modifications so that we can parse gRPC
...
...
paddle/fluid/operators/detail/send_recv.proto
浏览文件 @
0d598cf9
...
...
@@ -70,10 +70,10 @@ message VariableMessage {
bytes
rows
=
9
;
// Look up table block execution output variable name.
string
out_varname
=
10
;
// If
true
, the ps server will start profiling, the ps
// If
1
, the ps server will start profiling, the ps
// server stops profiling and generates a profile to /tmp/profile_ps_*
// when profile switches from
true to false
.
bool
profile
=
11
;
// when profile switches from
1 to 2
.
int64
profile
=
11
;
}
message
VoidMessage
{}
paddle/fluid/operators/detail/sendrecvop_utils.cc
浏览文件 @
0d598cf9
...
...
@@ -123,7 +123,13 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
// 1 trainer returns true for ShouldSendProfileState(). It tells PS
// servers the trainer's profiling state so that PS can follow the
// trainer.
request
.
set_profile
(
platform
::
IsProfileEnabled
());
if
(
platform
::
ShouldSendProfileState
())
{
if
(
platform
::
IsProfileEnabled
())
{
request
.
set_profile
(
platform
::
kEnableProfiler
);
}
else
{
request
.
set_profile
(
platform
::
kDisableProfiler
);
}
}
if
(
!
out_name
.
empty
())
{
request
.
set_out_varname
(
out_name
);
}
...
...
paddle/fluid/operators/detail/variable_response.cc
浏览文件 @
0d598cf9
...
...
@@ -449,8 +449,8 @@ int VariableResponse::Parse(Source* source) {
break
;
}
case
sendrecv
::
VariableMessage
::
kProfileFieldNumber
:
{
bool
profiling
;
if
(
!
input
.
Read
Raw
(
reinterpret_cast
<
void
*>
(
&
profiling
),
1
))
{
uint64_t
profiling
=
0
;
if
(
!
input
.
Read
Varint64
(
&
profiling
))
{
return
tag
;
}
meta_
.
set_profile
(
profiling
);
...
...
@@ -458,9 +458,11 @@ int VariableResponse::Parse(Source* source) {
if
(
listener_id
<=
0
)
{
break
;
}
if
(
profiling
&&
!
platform
::
IsProfileEnabled
())
{
if
(
profiling
==
platform
::
kEnableProfiler
&&
!
platform
::
IsProfileEnabled
())
{
platform
::
EnableProfiler
(
platform
::
ProfilerState
::
kCPU
);
}
else
if
(
!
profiling
&&
platform
::
IsProfileEnabled
())
{
}
else
if
(
profiling
==
platform
::
kDisableProfiler
&&
platform
::
IsProfileEnabled
())
{
// TODO(panyx0718): Should we allow to customize file dir.
platform
::
DisableProfiler
(
platform
::
EventSortingKey
::
kDefault
,
...
...
paddle/fluid/platform/device_tracer.cc
浏览文件 @
0d598cf9
...
...
@@ -245,7 +245,6 @@ class DeviceTracerImpl : public DeviceTracer {
void
Enable
()
{
std
::
lock_guard
<
std
::
mutex
>
l
(
trace_mu_
);
if
(
enabled_
)
{
fprintf
(
stderr
,
"DeviceTracer already enabled
\n
"
);
return
;
}
EnableActivity
();
...
...
paddle/fluid/platform/profiler.h
浏览文件 @
0d598cf9
...
...
@@ -116,6 +116,8 @@ void ResetProfiler();
void
DisableProfiler
(
EventSortingKey
sorted_key
,
const
std
::
string
&
profile_path
);
const
int
kEnableProfiler
=
1
;
const
int
kDisableProfiler
=
2
;
// Test if the profiler is currently enabled.
bool
IsProfileEnabled
();
// Whether the trainer should send profiling state to PS.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录