Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
0d5de244
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0d5de244
编写于
12月 27, 2017
作者:
武
武毅
提交者:
GitHub
12月 27, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #7064 from typhoonzero/fix_dist_transpiler
fix dist train trainspiler bugs
上级
35c1683e
3e703e91
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
23 addition
and
13 deletion
+23
-13
python/paddle/v2/fluid/distribute_transpiler.py
python/paddle/v2/fluid/distribute_transpiler.py
+5
-3
python/paddle/v2/fluid/framework.py
python/paddle/v2/fluid/framework.py
+1
-1
python/paddle/v2/fluid/tests/book_distribute/notest_recognize_digits_conv_dist.py
...ests/book_distribute/notest_recognize_digits_conv_dist.py
+17
-9
未找到文件。
python/paddle/v2/fluid/distribute_transpiler.py
浏览文件 @
0d5de244
...
...
@@ -95,7 +95,9 @@ class DistributeTranspiler:
"""
if
program
is
None
:
program
=
default_main_program
()
self
.
program
=
program
self
.
trainers
=
trainers
self
.
optimize_ops
=
optimize_ops
self
.
_optimize_distributed
(
optimize_ops
,
program
,
...
...
@@ -156,9 +158,10 @@ class DistributeTranspiler:
attrs
=
{
"endpoints"
:
pserver_endpoints
,
"epmap"
:
epmap
})
def
get_trainer_program
(
optimize_ops
,
program
):
def
get_trainer_program
(
self
):
# remove optimize ops and add a send op to main_program
program
.
global_block
().
delete_ops
(
optimize_ops
)
self
.
program
.
global_block
().
delete_ops
(
self
.
optimize_ops
)
return
self
.
program
def
_create_var_for_trainers
(
self
,
block
,
var
,
trainers
):
var_list
=
[]
...
...
@@ -210,7 +213,6 @@ class DistributeTranspiler:
if
opt_op
.
inputs
.
has_key
(
"Grad"
):
if
opt_op
.
inputs
[
"Grad"
].
name
in
grad_var_names
:
print
"appending "
,
opt_op
.
type
,
opt_op
.
inputs
optimize_sub_program
.
global_block
().
append_op
(
type
=
opt_op
.
type
,
inputs
=
opt_op
.
inputs
,
...
...
python/paddle/v2/fluid/framework.py
浏览文件 @
0d5de244
...
...
@@ -663,7 +663,7 @@ class Block(object):
end
=
list
(
self
.
ops
).
index
(
ops
[
-
1
])
except
Exception
,
e
:
raise
e
self
.
desc
.
remove_op
(
start
,
end
)
self
.
desc
.
remove_op
(
start
,
end
+
1
)
def
prepend_op
(
self
,
*
args
,
**
kwargs
):
op_desc
=
self
.
desc
.
prepend_op
()
...
...
python/paddle/v2/fluid/tests/book/notest_recognize_digits_conv_dist.py
→
python/paddle/v2/fluid/tests/book
_distribute
/notest_recognize_digits_conv_dist.py
浏览文件 @
0d5de244
...
...
@@ -38,35 +38,43 @@ train_reader = paddle.batch(
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
t
=
fluid
.
DistributeTranspiler
()
# all parameter server endpoints list for spliting parameters
pserver_endpoints
=
os
.
getenv
(
"PSERVERS"
)
# server endpoint for current node
current_endpoint
=
os
.
getenv
(
"SERVER_ENDPOINT"
)
# run as trainer or parameter server
training_role
=
os
.
getenv
(
"TRAINING_ROLE"
,
"TRAINER"
)
# get the training role: trainer/pserver
t
.
transpile
(
optimize_ops
,
params_grads
,
pservers
=
pserver_endpoints
,
trainers
=
1
)
t
.
transpile
(
optimize_ops
,
params_grads
,
pservers
=
pserver_endpoints
,
trainers
=
2
)
if
training_role
==
"PSERVER"
:
pserver_prog
=
t
.
get_pserver_program
(
pserver_endpoints
,
optimize_ops
)
if
not
current_endpoint
:
print
(
"need env SERVER_ENDPOINT"
)
exit
(
1
)
pserver_prog
=
t
.
get_pserver_program
(
current_endpoint
,
optimize_ops
)
exe
.
run
(
fluid
.
default_startup_program
())
exe
.
run
(
pserver_prog
)
elif
training_role
==
"TRAINER"
:
trainer_prog
=
t
.
get_trainer_program
()
feeder
=
fluid
.
DataFeeder
(
feed_list
=
[
images
,
label
],
place
=
place
)
exe
.
run
(
fluid
.
default_startup_program
())
for
pass_id
in
range
(
PASS_NUM
):
accuracy
.
reset
(
exe
)
batch_id
=
0
for
data
in
train_reader
():
loss
,
acc
=
exe
.
run
(
fluid
.
default_main_program
()
,
loss
,
acc
=
exe
.
run
(
trainer_prog
,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
[
avg_cost
]
+
accuracy
.
metrics
)
pass_acc
=
accuracy
.
eval
(
exe
)
# print loss, acc
if
loss
<
10.0
and
pass_acc
>
0.9
:
# if avg cost less than 10.0 and accuracy is larger than 0.9, we think our code is good.
exit
(
0
)
if
batch_id
%
100
==
0
:
print
(
"batch_id %d, loss: %f, acc: %f"
%
(
batch_id
,
loss
,
pass_acc
))
batch_id
+=
1
pass_acc
=
accuracy
.
eval
(
exe
)
print
(
"pass_id="
+
str
(
pass_id
)
+
" pass_acc="
+
str
(
pass_acc
))
else
:
print
(
"environment var TRAINER_ROLE should be TRAINER os PSERVER"
)
exit
(
1
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录