机器未来 / Paddle (forked from PaddlePaddle / Paddle)
Commit 1a438287
Authored Apr 18, 2018 by qiaolongfei
Parent: 79a1a7cd

implement main logic

Showing 1 changed file with 42 additions and 67 deletions:

paddle/fluid/operators/async_listen_and_serv_op.cc (+42, −67)
```diff
@@ -19,6 +19,8 @@ limitations under the License. */
 #include "paddle/fluid/operators/async_listen_and_serv_op.h"
+#include "paddle/utils/StringUtil.h"
 
 namespace paddle {
 namespace operators {
```
```diff
@@ -27,46 +29,18 @@ void RunServer(std::shared_ptr<detail::SyncGRPCServer> service) {
   VLOG(4) << "RunServer thread end";
 }
 
-static void CreateTensorFromMessageType(framework::Variable *var,
-                                        sendrecv::VarType var_type) {
-  if (var_type == sendrecv::VarType::LOD_TENSOR) {
-    var->GetMutable<framework::LoDTensor>();
-  } else if (var_type == sendrecv::VarType::SELECTED_ROWS) {
-    var->GetMutable<framework::SelectedRows>();
-  } else {
-    PADDLE_THROW(
-        "VariableMessage type %d is not in "
-        "[LoDTensor, SelectedRows]",
-        var_type);
-  }
-}
-
-static void ParallelExecuteBlocks(
-    const std::vector<size_t> &parallel_blkids, framework::Executor *executor,
-    const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
-        &prepared,
-    framework::ProgramDesc *program, framework::Scope *scope) {
-  std::vector<std::future<void>> fs;
-  for (size_t idx : parallel_blkids) {
-    fs.push_back(
-        framework::Async([&executor, &prepared, &program, &scope, idx]() {
-          int run_block = idx;  // thread local
-          try {
-            executor->RunPreparedContext(prepared[run_block].get(), scope,
-                                         false, false);
-          } catch (std::exception &e) {
-            LOG(ERROR) << "run sub program error " << e.what();
-          }
-        }));
-  }
-  for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
-}
+static void AsyncExecuteBlock(framework::Executor *executor,
+                              framework::ExecutorPrepareContext *prepared,
+                              framework::Scope *scope) {
+  framework::Async([&executor, &prepared, &scope]() {
+    try {
+      executor->RunPreparedContext(prepared, scope, false, false);
+    } catch (std::exception &e) {
+      LOG(ERROR) << "run sub program error " << e.what();
+    }
+  });
+}
 
-static void AsyncExecuteBlock(
-    size_t block_id, framework::Executor *executor,
-    std::shared_ptr<framework::ExecutorPrepareContext> ctx,
-    framework::ProgramDesc *program, framework::Scope *scope) {}
-
 AsyncListenAndServOp::AsyncListenAndServOp(
     const std::string &type, const framework::VariableNameMap &inputs,
     const framework::VariableNameMap &outputs,
```
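The replacement for ParallelExecuteBlocks is deliberately fire-and-forget: the try/catch sits inside the worker because nothing ever joins the async task, so an escaping exception would have nowhere to go. A minimal standalone sketch of that pattern, using a detached std::thread to stand in for framework::Async (an assumption made only so the example is self-contained):

```cpp
#include <chrono>
#include <iostream>
#include <stdexcept>
#include <thread>

// Stand-in for framework::Async running Executor::RunPreparedContext.
// Errors are logged inside the worker rather than propagated, because
// the caller never waits on the result.
void AsyncExecuteBlockSketch(int block_id) {
  std::thread([block_id]() {
    try {
      if (block_id < 0) throw std::runtime_error("bad block id");
      std::cout << "ran optimize block " << block_id << "\n";
    } catch (std::exception &e) {
      // mirrors: LOG(ERROR) << "run sub program error " << e.what();
      std::cerr << "run sub program error " << e.what() << "\n";
    }
  }).detach();  // fire-and-forget, like the un-waited framework::Async call
}

int main() {
  AsyncExecuteBlockSketch(1);
  AsyncExecuteBlockSketch(-1);  // error is logged, not rethrown
  // Give the detached workers time to finish before the process exits.
  std::this_thread::sleep_for(std::chrono::milliseconds(100));
  return 0;
}
```

The old ParallelExecuteBlocks waited on every future before returning; dropping that wait is what makes the new per-gradient execution asynchronous.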
```diff
@@ -93,6 +67,21 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope,
     rpc_service_.reset(new detail::SyncGRPCServer(endpoint));
   }
 
+  // grad name to block id
+  std::unordered_map<std::string, int32_t> grad_to_id;
+  std::unordered_map<int32_t, std::string> id_to_grad;
+
+  auto grad_map_str = Attr<std::vector<std::string>>("grad_map");
+  for (auto &grad_and_id : grad_map_str) {
+    std::vector<std::string> pieces;
+    paddle::str::split(grad_and_id, ' ', &pieces);
+    PADDLE_ENFORCE_EQ(pieces.size(), 2);
+    PADDLE_ENFORCE_EQ(grad_to_id.count(pieces[0]), 0);
+    int block_id = std::stoi(pieces[1]);
+    grad_to_id[pieces[0]] = block_id;
+    id_to_grad[block_id] = pieces[0];
+  }
+
   auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
   auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
   auto *program = optimize_block->Program();
```
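RunImpl parses the grad_map attribute into a bidirectional lookup between gradient names and optimize block ids; note the code splits on ' ' even though the attribute's example strings (later in this diff) use a 'name:id' form. A self-contained sketch of that parsing, with a local Split helper as a hypothetical stand-in for paddle::str::split:

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

// Local stand-in for paddle::str::split from StringUtil.h.
static void Split(const std::string &s, char sep,
                  std::vector<std::string> *out) {
  std::stringstream ss(s);
  std::string piece;
  while (std::getline(ss, piece, sep)) out->push_back(piece);
}

int main() {
  // Each entry pairs a grad name with its optimize block id, split on ' '.
  std::vector<std::string> grad_map_str = {"param1@GRAD 1", "param2@GRAD 2"};

  std::unordered_map<std::string, int32_t> grad_to_id;
  std::unordered_map<int32_t, std::string> id_to_grad;
  for (auto &grad_and_id : grad_map_str) {
    std::vector<std::string> pieces;
    Split(grad_and_id, ' ', &pieces);
    assert(pieces.size() == 2);              // PADDLE_ENFORCE_EQ equivalents
    assert(grad_to_id.count(pieces[0]) == 0);
    int block_id = std::stoi(pieces[1]);
    grad_to_id[pieces[0]] = block_id;
    id_to_grad[block_id] = pieces[0];
  }

  std::cout << "param1@GRAD -> block " << grad_to_id["param1@GRAD"] << "\n";
  std::cout << "block 2 -> " << id_to_grad[2] << "\n";
  return 0;
}
```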
```diff
@@ -108,10 +97,13 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope,
     }
   }
   auto optimize_prepared = executor.Prepare(*program, block_list);
   // Insert placeholder for block0 which holds current op itself.
   optimize_prepared.insert(
       optimize_prepared.begin(),
       std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
 
+  std::unordered_map<std::string,
+                     std::shared_ptr<framework::ExecutorPrepareContext>>
+      grad_to_prepared;
+  for (size_t i = 0; i < block_list.size(); ++i) {
+    grad_to_prepared[id_to_grad[block_list[i]]] = optimize_prepared[i];
+  }
+
   rpc_service_->SetScope(&recv_scope);
   rpc_service_->SetDevCtx(&dev_ctx);
```
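Executor.Prepare only returns contexts for the requested blocks, so a null placeholder is inserted at position 0 (block 0 holds the listen-and-serv op itself) to keep the vector aligned with block ids. A toy sketch of that alignment; PreparedCtx is a hypothetical stand-in for framework::ExecutorPrepareContext, and unlike the diff's loop over positions `i`, this sketch indexes the vector directly by block id, which is what the placeholder makes possible:

```cpp
#include <cstddef>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for framework::ExecutorPrepareContext.
struct PreparedCtx {
  int block_id;
};

int main() {
  // Suppose the optimize blocks are ids 1 and 2 (block 0 is the op itself).
  std::vector<std::size_t> block_list = {1, 2};
  std::unordered_map<int, std::string> id_to_grad = {{1, "param1@GRAD"},
                                                     {2, "param2@GRAD"}};

  // What Prepare would hand back: one context per requested block, in order.
  std::vector<std::shared_ptr<PreparedCtx>> optimize_prepared;
  for (std::size_t id : block_list)
    optimize_prepared.push_back(
        std::make_shared<PreparedCtx>(PreparedCtx{static_cast<int>(id)}));

  // Placeholder for block 0 so that positions line up with block ids.
  optimize_prepared.insert(optimize_prepared.begin(),
                           std::shared_ptr<PreparedCtx>(nullptr));

  // With the placeholder in place, a block id indexes the vector directly.
  std::unordered_map<std::string, std::shared_ptr<PreparedCtx>> grad_to_prepared;
  for (std::size_t id : block_list)
    grad_to_prepared[id_to_grad[static_cast<int>(id)]] = optimize_prepared[id];

  std::cout << "param2@GRAD runs block "
            << grad_to_prepared["param2@GRAD"]->block_id << "\n";
  return 0;
}
```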
```diff
@@ -122,6 +114,7 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope,
   rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get());
   prefetch_prepared.release();
+  rpc_service_->SetProgram(program);
   // start the server listening after all member initialized.
   server_thread_.reset(new std::thread(RunServer, rpc_service_));
   VLOG(3) << "wait server thread to become ready...";
```
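The server thread is started only after every member (scope, device context, program, prepared contexts) has been set, and the op then blocks until the thread reports ready. A sketch of that startup handshake with a condition variable; FakeServer and its methods are invented for the example, since the diff only shows the log line:

```cpp
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

// Minimal stand-in for the gRPC service: Run() flips a flag once the
// listener would be up, so the caller can safely proceed.
struct FakeServer {
  std::mutex mu;
  std::condition_variable cv;
  bool ready = false;

  void Run() {  // plays the role of RunServer(rpc_service_)
    {
      std::lock_guard<std::mutex> lk(mu);
      ready = true;  // the listening socket would be bound here
    }
    cv.notify_all();
  }
  void WaitReady() {
    std::unique_lock<std::mutex> lk(mu);
    cv.wait(lk, [this] { return ready; });
  }
};

int main() {
  FakeServer server;
  // start the server listening after all member initialized.
  std::thread server_thread([&server] { server.Run(); });
  std::cout << "wait server thread to become ready...\n";
  server.WaitReady();
  std::cout << "server is ready\n";
  server_thread.join();
  return 0;
}
```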
```diff
@@ -133,9 +126,6 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope,
   port_file.close();
 
   bool exit_flag = false;
-  // Record received sparse variables, so that
-  // we could reset those after execute optimize program
-  std::vector<framework::Variable *> sparse_vars;
   while (!exit_flag) {
     const detail::ReceivedMessage v = rpc_service_->Get();
     auto recv_var_name = v.first;
```
```diff
@@ -150,36 +140,17 @@ void AsyncListenAndServOp::RunImpl(const framework::Scope &scope,
       LOG(ERROR) << "Can not find server side var: " << recv_var_name;
       PADDLE_THROW("Can not find server side var");
     }
+      AsyncExecuteBlock(&executor, grad_to_prepared[recv_var_name].get(),
+                        &recv_scope);
-      if (var->IsType<framework::SelectedRows>()) {
-        sparse_vars.push_back(var);
-        var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
-      }
-      AsyncExecuteBlock();
     }
     if (exit_flag) {
       rpc_service_->ShutDown();
       break;
     }
-    // NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
-    // and this will still work.
-    // The optimize blocks which have the same parent ID would run parallel
-    // TODO(Yancey1989): need to use ParallelExecutor for future
-    int32_t last_parent_blkid = program->Block(1).Parent();
-    VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts
-            << "(ms)";
-    // Reset the received sparse variables, the sum operator would not
-    // sum the input sparse variables which rows is empty at the next
-    // mini-batch.
-    // TODO(Yancey1989): move the reset action into an operator, we couldn't
-    // have any hide logic in the operator.
-    for (auto &var : sparse_vars) {
-      var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
-    }
-    // FIXME(typhoonzero): use another condition to sync wait clients get.
-    sparse_vars.clear();
   }  // while(true)
 }
```
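Taken together, the hunks above form the "main logic" the commit message refers to: block on rpc_service_->Get(), look up the prepared context for the received gradient, and dispatch its optimize block asynchronously. A compressed, queue-backed sketch of that consumer loop; ReceivedMessage is modeled here as a name/payload pair (matching v.first above), and the empty-name exit signal is an assumption standing in for whatever flips exit_flag in the real server:

```cpp
#include <iostream>
#include <queue>
#include <string>
#include <unordered_map>
#include <utility>

using ReceivedMessage = std::pair<std::string, int>;  // name -> fake payload

int main() {
  // Pretend two gradient updates arrived from trainers, then an exit signal.
  std::queue<ReceivedMessage> inbox;
  inbox.push({"param1@GRAD", 10});
  inbox.push({"param2@GRAD", 20});
  inbox.push({"", 0});

  // grad name -> optimize block id, as built from the grad_map attribute.
  std::unordered_map<std::string, int> grad_to_prepared = {
      {"param1@GRAD", 1}, {"param2@GRAD", 2}};

  bool exit_flag = false;
  while (!exit_flag) {
    ReceivedMessage v = inbox.front();  // rpc_service_->Get() blocks here
    inbox.pop();
    auto recv_var_name = v.first;
    if (recv_var_name.empty()) {
      exit_flag = true;  // rpc_service_->ShutDown(); break;
      continue;
    }
    auto it = grad_to_prepared.find(recv_var_name);
    if (it == grad_to_prepared.end()) {
      std::cerr << "Can not find server side var: " << recv_var_name << "\n";
      continue;  // PADDLE_THROW in the real op
    }
    // AsyncExecuteBlock(&executor, <prepared ctx>, &recv_scope) would go here.
    std::cout << "dispatch " << recv_var_name << " to block " << it->second
              << "\n";
  }
  return 0;
}
```

The design choice worth noting: each gradient triggers only its own optimize block as soon as it arrives, instead of the old barrier-style ParallelExecuteBlocks pass over all blocks per mini-batch.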
```diff
@@ -199,6 +170,10 @@ from send_op and send back variables to recv_op.
                          "IP address to listen on.")
         .SetDefault("127.0.0.1:6164")
         .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
+    AddAttr<std::vector<std::string>>(
+        "grad_map(['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'])",
+        "a map from grad name to it's optimize block id")
+        .SetDefault({});
     AddAttr<framework::BlockDesc *>(kOptimizeBlock,
                                     "BlockID to run on server side.");
     AddAttr<framework::BlockDesc *>(kPrefetchBlock,
     ...
```
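The endpoint attribute chains a default value with a custom checker lambda. A toy sketch of that validate-on-set pattern outside Paddle's attribute machinery; OpAttr and every other name here are invented for the illustration:

```cpp
#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>

// Toy version of the AddAttr(...).SetDefault(...).AddCustomChecker(...) chain.
template <typename T>
struct OpAttr {
  T value;
  std::function<bool(const T &)> checker = [](const T &) { return true; };

  OpAttr &SetDefault(T v) {
    value = std::move(v);
    return *this;
  }
  OpAttr &AddCustomChecker(std::function<bool(const T &)> c) {
    checker = std::move(c);
    return *this;
  }
  void Validate() const {
    if (!checker(value)) throw std::invalid_argument("attribute check failed");
  }
};

int main() {
  OpAttr<std::string> endpoint;
  endpoint.SetDefault("127.0.0.1:6164")
      .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
  endpoint.Validate();  // passes: the default is non-empty
  std::cout << "endpoint = " << endpoint.value << "\n";
  return 0;
}
```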