Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
ae19d2ea
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ae19d2ea
编写于
1月 18, 2018
作者:
T
typhoonzero
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix comm issues
上级
f233b936
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
48 addition
and
29 deletion
+48
-29
paddle/operators/detail/grpc_server.cc
paddle/operators/detail/grpc_server.cc
+28
-19
paddle/operators/detail/grpc_server.h
paddle/operators/detail/grpc_server.h
+9
-6
paddle/operators/recv_op.cc
paddle/operators/recv_op.cc
+11
-4
未找到文件。
paddle/operators/detail/grpc_server.cc
浏览文件 @
ae19d2ea
...
...
@@ -36,7 +36,10 @@ class RequestBase {
CallStatus
Status
()
{
return
status_
;
}
void
SetStatus
(
CallStatus
status
)
{
status_
=
status
;
}
virtual
std
::
string
GetReqName
()
{
assert
(
false
);
}
virtual
std
::
string
GetReqName
()
{
assert
(
false
);
return
""
;
}
protected:
grpc
::
ServerContext
ctx_
;
...
...
@@ -80,11 +83,13 @@ class RequestGet final : public RequestBase {
public:
explicit
RequestGet
(
sendrecv
::
SendRecvService
::
AsyncService
*
service
,
grpc
::
ServerCompletionQueue
*
cq
,
framework
::
Scope
*
scope
,
const
platform
::
DeviceContext
*
dev_ctx
)
const
platform
::
DeviceContext
*
dev_ctx
,
SimpleBlockQueue
<
char
>*
queue
)
:
RequestBase
(
service
,
cq
),
responder_
(
&
ctx_
),
scope_
(
scope
),
dev_ctx_
(
dev_ctx
)
{
dev_ctx_
(
dev_ctx
),
queue_
(
queue
)
{
service_
->
RequestGetVariable
(
&
ctx_
,
&
request_
,
&
responder_
,
cq_
,
cq_
,
this
);
}
...
...
@@ -100,6 +105,7 @@ class RequestGet final : public RequestBase {
// TODO(gongwb): check var's info.
responder_
.
Finish
(
reply_
,
grpc
::
Status
::
OK
,
this
);
status_
=
FINISH
;
queue_
->
Push
(
'c'
);
}
protected:
...
...
@@ -108,8 +114,15 @@ class RequestGet final : public RequestBase {
ServerAsyncResponseWriter
<
sendrecv
::
VariableMessage
>
responder_
;
framework
::
Scope
*
scope_
;
const
platform
::
DeviceContext
*
dev_ctx_
;
SimpleBlockQueue
<
char
>*
queue_
;
};
void
AsyncGRPCServer
::
WaitClientGet
(
int
count
)
{
for
(
int
i
=
0
;
i
<
count
;
++
i
)
{
var_get_queue_
.
Pop
();
}
}
void
AsyncGRPCServer
::
RunSyncUpdate
()
{
grpc
::
ServerBuilder
builder
;
builder
.
AddListeningPort
(
address_
,
grpc
::
InsecureServerCredentials
());
...
...
@@ -170,7 +183,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
if
(
is_shut_down_
)
{
return
;
}
RequestGet
*
get
=
new
RequestGet
(
&
service_
,
cq_get_
.
get
(),
scope_
,
dev_ctx_
);
RequestGet
*
get
=
new
RequestGet
(
&
service_
,
cq_get_
.
get
(),
scope_
,
dev_ctx_
,
&
var_get_queue_
);
VLOG
(
4
)
<<
"create Requestget status:"
<<
get
->
Status
();
}
...
...
@@ -188,9 +202,8 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
}
PADDLE_ENFORCE
(
tag
);
if
(
wait
&&
!
done_
)
{
Wait
();
}
if
(
cq_name
==
"cq_get"
)
WaitCond
(
2
);
if
(
cq_name
==
"cq_send"
)
WaitCond
(
0
);
RequestBase
*
base
=
(
RequestBase
*
)
tag
;
// reference:
...
...
@@ -222,22 +235,18 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
}
}
void
AsyncGRPCServer
::
Wait
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
mutex_
);
condition_
.
wait
(
lock
,
[
=
]
{
return
this
->
done_
==
true
;
});
}
void
AsyncGRPCServer
::
Reset
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mutex_
);
done_
=
false
;
void
AsyncGRPCServer
::
WaitCond
(
int
cond
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
this
->
barrier_mutex_
);
barrier_condition_
.
wait
(
lock
,
[
=
]
{
return
this
->
barrier_cond_step_
==
cond
;
});
}
void
AsyncGRPCServer
::
Done
(
)
{
void
AsyncGRPCServer
::
SetCond
(
int
cond
)
{
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
mutex_
);
done_
=
true
;
std
::
lock_guard
<
std
::
mutex
>
lock
(
this
->
barrier_
mutex_
);
barrier_cond_step_
=
cond
;
}
condition_
.
notify_all
();
barrier_
condition_
.
notify_all
();
}
}
// namespace detail
...
...
paddle/operators/detail/grpc_server.h
浏览文件 @
ae19d2ea
...
...
@@ -41,9 +41,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
void
RunSyncUpdate
();
void
Reset
();
// functions to sync server barrier status.
void
WaitStart
();
void
WaitDone
();
void
Start
();
void
Done
();
void
WaitClientGet
(
int
count
);
void
SetScope
(
framework
::
Scope
*
scope
)
{
scope_
=
scope
;
}
...
...
@@ -56,7 +59,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
void
ShutDown
();
protected:
void
Wait
();
void
HandleRequest
(
bool
wait
,
grpc
::
ServerCompletionQueue
*
cq
,
std
::
string
cq_name
,
std
::
function
<
void
()
>
TryToRegisterNewOne
);
...
...
@@ -78,11 +80,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
const
platform
::
DeviceContext
*
dev_ctx_
;
// received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue
<
MessageWithName
>
var_recv_queue_
;
SimpleBlockQueue
<
char
>
var_get_queue_
;
// condition of the sub program
std
::
mutex
mutex_
;
volatile
mutable
bool
done
_
;
std
::
condition_variable
condition_
;
std
::
mutex
barrier_
mutex_
;
mutable
int
barrier_cond_step
_
;
std
::
condition_variable
barrier_
condition_
;
std
::
unique_ptr
<
std
::
thread
>
t_send_
;
std
::
unique_ptr
<
std
::
thread
>
t_get_
;
...
...
paddle/operators/recv_op.cc
浏览文件 @
ae19d2ea
...
...
@@ -34,6 +34,10 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
constexpr
int
kCondStart
=
0
;
constexpr
int
kCondRunning
=
1
;
constexpr
int
kCondDone
=
2
;
void
RunServer
(
std
::
shared_ptr
<
detail
::
AsyncGRPCServer
>
service
)
{
service
->
RunSyncUpdate
();
VLOG
(
4
)
<<
"RunServer thread end"
;
...
...
@@ -101,12 +105,14 @@ class RecvOp : public framework::OperatorBase {
framework
::
ProgramDesc
program
(
program_desc
);
framework
::
Executor
executor
(
dev_place
);
rpc_service_
->
Reset
();
//
rpc_service_->Reset();
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool
exit_flag
=
false
;
while
(
!
exit_flag
)
{
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_
->
SetCond
(
kCondStart
);
VLOG
(
3
)
<<
"================ start get from service ==========="
;
for
(
size_t
i
=
0
;
i
<
param_count
*
fan_in
;
++
i
)
{
const
detail
::
MessageWithName
&
v
=
rpc_service_
->
Get
();
auto
grad_var_name
=
v
.
first
;
...
...
@@ -139,15 +145,16 @@ class RecvOp : public framework::OperatorBase {
if
(
exit_flag
)
{
break
;
}
rpc_service_
->
Reset
();
//
rpc_service_->Reset();
try
{
executor
.
Run
(
program
,
&
recv_scope
,
0
,
/*global_block*/
false
/*create_local_scope*/
,
false
/*create_vars*/
);
}
catch
(
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
}
rpc_service_
->
Done
();
VLOG
(
3
)
<<
"================ run sub program end ==========="
;
rpc_service_
->
SetCond
(
kCondDone
);
rpc_service_
->
WaitClientGet
(
param_count
*
fan_in
);
grads_counter_
.
clear
();
}
// while(true)
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录