Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b54d1ba9
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b54d1ba9
编写于
6月 20, 2018
作者:
Y
Yancey1989
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix pserver sub-blocks
上级
59729902
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
39 addition
and
20 deletion
+39
-20
paddle/fluid/operators/listen_and_serv_op.cc
paddle/fluid/operators/listen_and_serv_op.cc
+9
-1
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+30
-19
未找到文件。
paddle/fluid/operators/listen_and_serv_op.cc
浏览文件 @
b54d1ba9
...
...
@@ -101,13 +101,16 @@ void ListenAndServOp::RunSyncLoop(
framework
::
Scope
*
recv_scope
,
const
std
::
vector
<
int
>
&
prefetch_block_id_list
)
const
{
size_t
num_blocks
=
program
->
Size
();
auto
skip_sub_blks
=
Attr
<
std
::
vector
<
int
>>
(
"skip_sub_blks"
);
PADDLE_ENFORCE_GE
(
num_blocks
,
2
,
"server program should have at least 2 blocks"
);
std
::
vector
<
int
>
optimize_block_id_list
;
for
(
int
blkid
=
1
;
blkid
<
num_blocks
;
++
blkid
)
{
if
(
std
::
find
(
prefetch_block_id_list
.
begin
(),
prefetch_block_id_list
.
end
(),
blkid
)
==
prefetch_block_id_list
.
end
())
{
blkid
)
==
prefetch_block_id_list
.
end
()
&&
std
::
find
(
skip_sub_blks
.
begin
(),
skip_sub_blks
.
end
(),
blkid
)
==
skip_sub_blks
.
end
())
{
optimize_block_id_list
.
push_back
(
blkid
);
}
}
...
...
@@ -344,6 +347,11 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
.
SetDefault
({});
AddAttr
<
int
>
(
"Fanin"
,
"How many clients send to this server."
)
.
SetDefault
(
1
);
AddAttr
<
std
::
vector
<
int
>>
(
"skip_sub_blks"
,
"do not parallel execute the specify sub blocks, "
"it's used for the op which has"
"condition blocks"
)
.
SetDefault
({});
}
};
...
...
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
b54d1ba9
...
...
@@ -250,19 +250,14 @@ class DistributeTranspiler:
split_method
=
RoundRobin
,
sync_mode
=
True
):
"""
:param trainer_id: one unique id for each trainer in a job.
:type trainer_id: int
:param program: program to transpile, default is default_main_program
:type program: Program
:param pservers: parameter server endpoints like "m1:6174,m2:6174"
:type pservers: string
:param trainers: total number of workers/trainers in the job
:type trainers: int
:param split_method: A function to determin how to split variables
Args:
trainer_id(int): one unique id for each trainer in a job.
program(Program): program to transpile, default is default_main_program
pservers(string): parameter server endpoints like "m1:6174,m2:6174"
trainers(int): total number of workers/trainers in the job
split_method(PSDispatcher): A function to determin how to split variables
to different servers equally.
:type split_method: function
:type sync_mode: boolean default True
:param sync_mode: if sync_mode is set True, it means that dist transpiler
sync_mode(boolean): if sync_mode is set True, it means that dist transpiler
will transpile the program into sync_mode pserver and trainer program.
"""
assert
(
split_method
.
__bases__
[
0
]
==
PSDispatcher
)
...
...
@@ -403,6 +398,11 @@ class DistributeTranspiler:
NOTE: assume blocks of the same variable is not distributed
on the same pserver, only change param/grad varnames for
trainers to fetch.
Args:
endpoint(string): the endpoint for the current pserver instance.
Returns(Program): the pserver program
"""
# step1
pserver_program
=
Program
()
...
...
@@ -479,9 +479,9 @@ class DistributeTranspiler:
return
varname
return
""
def
__clone_lr_op_sub_block__
(
op
,
program
,
new_block
):
def
__clone_lr_op_sub_block__
(
op
,
program
,
new_block
,
skip_sub_blks
):
if
not
op
.
has_attr
(
'sub_block'
):
return
return
-
1
origin_block_desc
=
op
.
attr
(
'sub_block'
)
origin_block
=
self
.
origin_program
.
block
(
origin_block_desc
.
id
)
...
...
@@ -489,6 +489,7 @@ class DistributeTranspiler:
# we put the new sub block to new block to follow the block
# hierarchy of the original blocks
new_sub_block
=
program
.
create_block
(
new_block
.
idx
)
skip_sub_blks
(
new_sub_block
.
idx
)
# clone vars
for
var
in
origin_block
.
vars
:
...
...
@@ -498,20 +499,24 @@ class DistributeTranspiler:
for
op
in
origin_block
.
ops
:
self
.
_clone_lr_op
(
program
,
new_sub_block
,
op
)
# clone sub_block of op
__clone_lr_op_sub_block__
(
op
,
program
,
new_sub_block
)
__clone_lr_op_sub_block__
(
op
,
program
,
new_sub_block
,
skip_sub_blks
)
# reset the block of op
op
.
set_attr
(
'sub_block'
,
new_sub_block
)
return
new_sub_block
.
idx
# append lr decay ops to the child block if exists
lr_ops
=
self
.
_get_lr_ops
()
skip_sub_blks
=
[]
if
len
(
lr_ops
)
>
0
:
lr_decay_block
=
pserver_program
.
create_block
(
pserver_program
.
num_blocks
-
1
)
for
_
,
op
in
enumerate
(
lr_ops
):
self
.
_append_pserver_non_opt_ops
(
lr_decay_block
,
op
)
# append sub blocks to pserver_program in lr_decay_op
__clone_lr_op_sub_block__
(
op
,
pserver_program
,
lr_decay_block
)
__clone_lr_op_sub_block__
(
op
,
pserver_program
,
lr_decay_block
,
skip_sub_blks
)
# append op to the current block
grad_to_block_id
=
[]
...
...
@@ -561,7 +566,8 @@ class DistributeTranspiler:
"endpoint"
:
endpoint
,
"Fanin"
:
self
.
trainer_num
,
"sync_mode"
:
self
.
sync_mode
,
"grad_to_block_id"
:
grad_to_block_id
"grad_to_block_id"
:
grad_to_block_id
,
"skip_sub_blks"
:
skip_sub_blks
}
if
len
(
prefetch_var_name_to_block_id
)
>
0
:
attrs
[
'prefetch_var_name_to_block_id'
]
\
...
...
@@ -582,6 +588,11 @@ class DistributeTranspiler:
Get startup program for current parameter server.
Modify operator input variables if there are variables that
were split to several blocks.
Args:
endpoint(string): the endpoint for the current pserver instance.
pserver_program(Program): the program for pserver to execute.
Returns(Program): the startup program for pserver
"""
s_prog
=
Program
()
orig_s_prog
=
default_startup_program
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录