Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
0a881a1e
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0a881a1e
编写于
4月 20, 2018
作者:
Q
qiaolongfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
init RunAsyncUpdate
上级
36083018
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
116 addition
and
1 deletion
+116
-1
paddle/fluid/operators/listen_and_serv_op.cc
paddle/fluid/operators/listen_and_serv_op.cc
+111
-1
paddle/fluid/operators/listen_and_serv_op.h
paddle/fluid/operators/listen_and_serv_op.h
+5
-0
未找到文件。
paddle/fluid/operators/listen_and_serv_op.cc
浏览文件 @
0a881a1e
...
...
@@ -27,6 +27,36 @@ void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
VLOG
(
4
)
<<
"RunServer thread end"
;
}
static
void
split
(
const
std
::
string
&
str
,
char
sep
,
std
::
vector
<
std
::
string
>
*
pieces
)
{
pieces
->
clear
();
if
(
str
.
empty
())
{
return
;
}
size_t
pos
=
0
;
size_t
next
=
str
.
find
(
sep
,
pos
);
while
(
next
!=
std
::
string
::
npos
)
{
pieces
->
push_back
(
str
.
substr
(
pos
,
next
-
pos
));
pos
=
next
+
1
;
next
=
str
.
find
(
sep
,
pos
);
}
if
(
!
str
.
substr
(
pos
).
empty
())
{
pieces
->
push_back
(
str
.
substr
(
pos
));
}
}
static
void
AsyncExecuteBlock
(
framework
::
Executor
*
executor
,
framework
::
ExecutorPrepareContext
*
prepared
,
framework
::
Scope
*
scope
)
{
framework
::
Async
([
&
executor
,
&
prepared
,
&
scope
]()
{
try
{
executor
->
RunPreparedContext
(
prepared
,
scope
,
false
,
false
);
}
catch
(
std
::
exception
&
e
)
{
LOG
(
ERROR
)
<<
"run sub program error "
<<
e
.
what
();
}
});
}
static
void
ParallelExecuteBlocks
(
const
std
::
vector
<
size_t
>
&
parallel_blkids
,
framework
::
Executor
*
executor
,
const
std
::
vector
<
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>
...
...
@@ -168,12 +198,82 @@ void ListenAndServOp::RunSyncUpdate(
}
// while(true)
}
void
ListenAndServOp
::
RunAsyncUpdate
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
Scope
*
recv_scope
,
framework
::
BlockDesc
*
prefetch_block
)
const
{
// grad name to block id
std
::
unordered_map
<
std
::
string
,
int32_t
>
grad_to_id
;
std
::
unordered_map
<
int32_t
,
std
::
string
>
id_to_grad
;
auto
grad_map_str
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"grad_map"
);
for
(
auto
&
grad_and_id
:
grad_map_str
)
{
std
::
vector
<
std
::
string
>
pieces
;
split
(
grad_and_id
,
' '
,
&
pieces
);
PADDLE_ENFORCE_EQ
(
pieces
.
size
(),
2
);
PADDLE_ENFORCE_EQ
(
grad_to_id
.
count
(
pieces
[
0
]),
0
);
int
block_id
=
std
::
stoi
(
pieces
[
1
]);
grad_to_id
[
pieces
[
0
]]
=
block_id
;
id_to_grad
[
block_id
]
=
pieces
[
0
];
}
size_t
num_blocks
=
program
->
Size
();
PADDLE_ENFORCE_GE
(
num_blocks
,
2
,
"server program should have at least 2 blocks"
);
std
::
vector
<
int
>
block_list
;
for
(
size_t
blkid
=
1
;
blkid
<
num_blocks
;
++
blkid
)
{
if
(
blkid
!=
static_cast
<
size_t
>
(
prefetch_block
->
ID
()))
{
block_list
.
push_back
(
blkid
);
}
}
PADDLE_ENFORCE_EQ
(
grad_map_str
.
size
(),
block_list
.
size
(),
"grad num should be equal to optimize block num"
);
auto
optimize_prepared
=
executor
->
Prepare
(
*
program
,
block_list
);
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
framework
::
ExecutorPrepareContext
>>
grad_to_prepared
;
for
(
size_t
i
=
0
;
i
<
block_list
.
size
();
++
i
)
{
grad_to_prepared
[
id_to_grad
[
block_list
[
i
]]]
=
optimize_prepared
[
i
];
}
bool
exit_flag
=
false
;
while
(
!
exit_flag
)
{
const
detail
::
ReceivedMessage
v
=
rpc_service_
->
Get
();
auto
recv_var_name
=
v
.
first
;
if
(
recv_var_name
==
LISTEN_TERMINATE_MESSAGE
)
{
LOG
(
INFO
)
<<
"received terminate message and exit"
;
exit_flag
=
true
;
break
;
}
else
{
VLOG
(
3
)
<<
"received grad: "
<<
recv_var_name
;
auto
var
=
v
.
second
->
GetVar
();
if
(
var
==
nullptr
)
{
LOG
(
ERROR
)
<<
"Can not find server side var: "
<<
recv_var_name
;
PADDLE_THROW
(
"Can not find server side var"
);
}
AsyncExecuteBlock
(
executor
,
grad_to_prepared
[
recv_var_name
].
get
(),
recv_scope
);
// TODO(qiao): explain why
if
(
var
->
IsType
<
framework
::
SelectedRows
>
())
{
var
->
GetMutable
<
framework
::
SelectedRows
>
()
->
mutable_rows
()
->
clear
();
}
}
if
(
exit_flag
)
{
rpc_service_
->
ShutDown
();
break
;
}
}
// while(true)
}
void
ListenAndServOp
::
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
{
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
dev_place
);
framework
::
Scope
&
recv_scope
=
scope
.
NewScope
();
bool
sync_mode
=
Attr
<
bool
>
(
"sync_mode"
);
PADDLE_ENFORCE
(
!
rpc_service_
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
rpc_service_
.
reset
(
new
detail
::
AsyncGRPCServer
(
endpoint
));
...
...
@@ -201,7 +301,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
sleep
(
5
);
// Write to a file of server selected port for python use.
SavePort
(
rpc_service_
);
if
(
sync_mode
)
{
RunSyncUpdate
(
&
executor
,
program
,
&
recv_scope
,
prefetch_block
);
}
else
{
RunAsyncUpdate
(
&
executor
,
program
,
&
recv_scope
,
prefetch_block
);
}
}
class
ListenAndServOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
...
@@ -220,6 +324,12 @@ from send_op and send back variables to recv_op.
"IP address to listen on."
)
.
SetDefault
(
"127.0.0.1:6164"
)
.
AddCustomChecker
([](
const
std
::
string
&
ip
)
{
return
!
ip
.
empty
();
});
AddAttr
<
std
::
vector
<
std
::
string
>>
(
"grad_map(['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'])"
,
"a map from grad name to it's optimize block id"
)
.
SetDefault
({});
AddAttr
<
bool
>
(
"sync_mode"
,
"if works at sync_mode or not"
)
.
SetDefault
(
false
);
AddAttr
<
framework
::
BlockDesc
*>
(
kOptimizeBlock
,
"BlockID to run on server side."
);
AddAttr
<
framework
::
BlockDesc
*>
(
kPrefetchBlock
,
...
...
paddle/fluid/operators/listen_and_serv_op.h
浏览文件 @
0a881a1e
...
...
@@ -46,6 +46,11 @@ class ListenAndServOp : public framework::OperatorBase {
framework
::
Scope
*
recv_scope
,
framework
::
BlockDesc
*
prefetch_block
)
const
;
void
RunAsyncUpdate
(
framework
::
Executor
*
executor
,
framework
::
ProgramDesc
*
program
,
framework
::
Scope
*
recv_scope
,
framework
::
BlockDesc
*
prefetch_block
)
const
;
void
Stop
()
override
;
void
RunImpl
(
const
framework
::
Scope
&
scope
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录