Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
5f4d9130
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5f4d9130
编写于
1月 18, 2018
作者:
T
typhoonzero
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
merge codes
上级
ae19d2ea
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
10 addition
and
16 deletion
+10
-16
paddle/operators/detail/grpc_server.cc
paddle/operators/detail/grpc_server.cc
+3
-2
paddle/operators/detail/grpc_server.h
paddle/operators/detail/grpc_server.h
+2
-4
paddle/operators/recv_op.cc
paddle/operators/recv_op.cc
+5
-10
未找到文件。
paddle/operators/detail/grpc_server.cc
浏览文件 @
5f4d9130
...
...
@@ -162,7 +162,6 @@ void AsyncGRPCServer::ShutdownQueue() {
}
// This URL explains why shutdown is complicate:
// https://stackoverflow.com/questions/35708348/grpc-what-is-the-recommended-way-to-shut-down-an-asynchronous-server-in-c
void
AsyncGRPCServer
::
ShutDown
()
{
server_
->
Shutdown
();
ShutdownQueue
();
...
...
@@ -188,6 +187,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
VLOG
(
4
)
<<
"create Requestget status:"
<<
get
->
Status
();
}
// FIXME(typhoonzero): remove wait argument and change cq_name to enum.
void
AsyncGRPCServer
::
HandleRequest
(
bool
wait
,
grpc
::
ServerCompletionQueue
*
cq
,
std
::
string
cq_name
,
std
::
function
<
void
()
>
TryToRegisterNewOne
)
{
...
...
@@ -202,7 +202,8 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
}
PADDLE_ENFORCE
(
tag
);
if
(
cq_name
==
"cq_get"
)
WaitCond
(
2
);
// FIXME(typhoonzero): de-couple the barriers with recv_op
if
(
cq_name
==
"cq_get"
)
WaitCond
(
1
);
if
(
cq_name
==
"cq_send"
)
WaitCond
(
0
);
RequestBase
*
base
=
(
RequestBase
*
)
tag
;
...
...
paddle/operators/detail/grpc_server.h
浏览文件 @
5f4d9130
...
...
@@ -42,10 +42,8 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
void
RunSyncUpdate
();
// functions to sync server barrier status.
void
WaitStart
();
void
WaitDone
();
void
Start
();
void
Done
();
void
WaitCond
(
int
cond
);
void
SetCond
(
int
cond
);
void
WaitClientGet
(
int
count
);
void
SetScope
(
framework
::
Scope
*
scope
)
{
scope_
=
scope
;
}
...
...
paddle/operators/recv_op.cc
浏览文件 @
5f4d9130
...
...
@@ -105,15 +105,14 @@ class RecvOp : public framework::OperatorBase {
framework
::
ProgramDesc
program
(
program_desc
);
framework
::
Executor
executor
(
dev_place
);
// rpc_service_->Reset();
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool
exit_flag
=
false
;
int64_t
barrier_size
=
param_count
*
fan_in
;
while
(
!
exit_flag
)
{
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_
->
SetCond
(
kCondStart
);
VLOG
(
3
)
<<
"================ start get from service ==========="
;
for
(
size_t
i
=
0
;
i
<
param_count
*
fan_in
;
++
i
)
{
rpc_service_
->
SetCond
(
0
);
for
(
size_t
i
=
0
;
i
<
barrier_size
;
++
i
)
{
const
detail
::
MessageWithName
&
v
=
rpc_service_
->
Get
();
auto
grad_var_name
=
v
.
first
;
if
(
grad_var_name
==
LISTEN_TERMINATE_MESSAGE
)
{
...
...
@@ -130,8 +129,6 @@ class RecvOp : public framework::OperatorBase {
}
VLOG
(
3
)
<<
"recved grad: "
<<
grad_var_name
<<
" updating param: "
<<
param_var_name
;
// Assume grad_var_name must appear in global scope.
std
::
string
grad_var_name_trainer
;
if
(
fan_in
>
1
)
{
grad_var_name
=
this
->
GetGradVarNameForTrainer
(
grad_var_name
);
}
...
...
@@ -145,16 +142,14 @@ class RecvOp : public framework::OperatorBase {
if
(
exit_flag
)
{
break
;
}
// rpc_service_->Reset();
try
{
executor
.
Run
(
program
,
&
recv_scope
,
0
,
/*global_block*/
false
/*create_local_scope*/
,
false
/*create_vars*/
);
}
catch
(
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
}
VLOG
(
3
)
<<
"================ run sub program end ==========="
;
rpc_service_
->
SetCond
(
kCondDone
);
rpc_service_
->
WaitClientGet
(
param_count
*
fan_in
);
rpc_service_
->
SetCond
(
1
);
rpc_service_
->
WaitClientGet
(
barrier_size
);
grads_counter_
.
clear
();
}
// while(true)
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录