机器未来 / Paddle (fork of PaddlePaddle / Paddle)

Commit c083ee70 (unverified)
Authored by fengjiayi on Apr 18, 2018; committed via GitHub on Apr 18, 2018.
Parents: 61f4baa1, e84d3a7f

Merge pull request #9950 from JiayiFeng/add_parallel_executor_tests

Add parallel executor tests
Showing 2 changed files with 58 additions and 14 deletions (+58, -14):

python/paddle/fluid/parallel_executor.py (+6, -5)
python/paddle/fluid/tests/unittests/test_parallel_executor.py (+52, -9)

python/paddle/fluid/parallel_executor.py
@@ -16,6 +16,7 @@ import core
 import multiprocessing
 import framework
 import executor
 import warnings
 import sys

 __all__ = ['ParallelExecutor']
@@ -62,8 +63,8 @@ class ParallelExecutor(object):
                              main_program=test_program,
                              share_vars_from=train_exe)
-            train_loss, = train_exe.run([loss.name], feed_dict=feed_dict)
-            test_loss, = test_exe.run([loss.name], feed_dict=feed_dict)
+            train_loss, = train_exe.run([loss.name], feed=feed_dict)
+            test_loss, = test_exe.run([loss.name], feed=feed_dict)
         """

         self._places = []
@@ -103,8 +104,8 @@ class ParallelExecutor(object):
         self.persistable_vars = [
             v.name
-            for v in filter(lambda var: \
-                var.persistable and var.type != core.VarDesc.VarType.RAW,
+            for v in filter(
+                lambda var: var.persistable and var.type != core.VarDesc.VarType.RAW,
                 main.list_vars())
         ]
@@ -163,7 +164,7 @@ class ParallelExecutor(object):
         Returns: fetched result list.
         """
-        if feed is None:
+        if feed is None and feed_dict is not None:
             feed = feed_dict
             print >> sys.stderr, "`feed_dict` is deprecated. Please use `feed=`"
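
For call sites outside this test suite, a minimal migration sketch for the keyword change above, assuming train_exe is an already-constructed fluid.ParallelExecutor and loss is the loss variable of its program (both names are illustrative, not taken from this diff):

    import numpy

    # Illustrative feed data only; shapes mirror the unit test below.
    feed = {"image": numpy.zeros([32, 784], dtype='float32'),
            "label": numpy.ones([32, 1], dtype='int64')}

    # Old spelling: still accepted, but now reported as deprecated on stderr.
    # train_loss, = train_exe.run([loss.name], feed_dict=feed)

    # New spelling after this commit:
    train_loss, = train_exe.run([loss.name], feed=feed)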

python/paddle/fluid/tests/unittests/test_parallel_executor.py
@@ -200,14 +200,29 @@ class TestParallelExecutorBase(unittest.TestCase):
     def check_network_convergence(self,
                                   method,
                                   memory_opt=True,
-                                  iter=10,
+                                  iter=50,
                                   batch_size=None,
                                   allow_op_delay=False,
-                                  feed_dict=None):
+                                  feed_dict=None,
+                                  seed=None,
+                                  use_parallel_executor=True):
+        def run_executor(exe, feed, fetch_list, program=None):
+            if isinstance(exe, fluid.ParallelExecutor):
+                res = exe.run(fetch_list=fetch_list, feed=feed)
+            elif isinstance(exe, fluid.Executor):
+                if program is None:
+                    program = fluid.default_main_program()
+                res = exe.run(program=program, feed=feed, fetch_list=fetch_list)
+            else:
+                raise ValueError('Unkown type exe')
+            return res
+
         main = fluid.Program()
         startup = fluid.Program()
         startup.random_seed = 1  # Fix random seed
         with fluid.program_guard(main, startup):
+            if seed is not None:
+                startup.random_seed = seed
             loss = method(use_feed=feed_dict is not None)
             adam = fluid.optimizer.Adam()
             adam.minimize(loss)
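
The run_executor helper added above exists because the two executor types take different run() signatures in this version of Fluid; a rough sketch of the difference it hides, with parallel_exe, plain_exe, loss, and feed assumed to be set up as in the test:

    # fluid.ParallelExecutor.run: fetch_list plus feed; no program argument.
    parallel_vals = parallel_exe.run(fetch_list=[loss.name], feed=feed)

    # fluid.Executor.run: the program is passed explicitly (or defaulted).
    plain_vals = plain_exe.run(program=fluid.default_main_program(),
                               feed=feed,
                               fetch_list=[loss.name])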
@@ -217,18 +232,24 @@ class TestParallelExecutorBase(unittest.TestCase):
             startup_exe = fluid.Executor(place)
             startup_exe.run(startup)
-            exe = fluid.ParallelExecutor(
-                True, loss_name=loss.name, allow_op_delay=allow_op_delay)
+            if use_parallel_executor:
+                exe = fluid.ParallelExecutor(
+                    True, loss_name=loss.name, allow_op_delay=allow_op_delay)
+            else:
+                exe = fluid.Executor(place=place)
+
             if batch_size is not None:
                 batch_size *= fluid.core.get_cuda_device_count()
             begin = time.time()
-            first_loss, = exe.run([loss.name], feed=feed_dict)
+            first_loss, = run_executor(
+                exe=exe, feed=feed_dict, fetch_list=[loss.name])
             first_loss = numpy.array(first_loss)

             for i in xrange(iter):
-                exe.run([], feed=feed_dict)
+                run_executor(exe=exe, feed=feed_dict, fetch_list=[])

-            last_loss, = exe.run([loss.name], feed=feed_dict)
+            last_loss, = run_executor(
+                exe=exe, feed=feed_dict, fetch_list=[loss.name])
             end = time.time()

             if batch_size is not None:
@@ -239,6 +260,7 @@ class TestParallelExecutorBase(unittest.TestCase):
             print first_loss, last_loss
             # self.assertGreater(first_loss[0], last_loss[0])
+            return first_loss, last_loss


 class TestMNIST(TestParallelExecutorBase):
@@ -268,6 +290,27 @@ class TestMNIST(TestParallelExecutorBase):
             simple_fc_net, feed_dict={"image": img, "label": label})

+    def test_simple_fc_parallel_accuracy(self):
+        img = numpy.zeros(shape=[32, 784], dtype='float32')
+        label = numpy.ones(shape=[32, 1], dtype='int64')
+        single_first_loss, single_last_loss = self.check_network_convergence(
+            method=simple_fc_net,
+            seed=1000,
+            feed_dict={"image": img,
+                       "label": label},
+            use_parallel_executor=False)
+        parallel_first_loss, parallel_last_loss = self.check_network_convergence(
+            method=simple_fc_net,
+            seed=1000,
+            feed_dict={"image": img,
+                       "label": label},
+            use_parallel_executor=True)
+
+        for p_f in parallel_first_loss:
+            self.assertAlmostEquals(p_f, single_first_loss[0], delta=1e-6)
+        for p_l in parallel_last_loss:
+            self.assertAlmostEquals(p_l, single_last_loss[0], delta=1e-6)
+
     def test_batchnorm_fc(self):
         self.check_network_convergence(fc_with_batchnorm)
         img = numpy.zeros(shape=[32, 784], dtype='float32')
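
The new test_simple_fc_parallel_accuracy case pins the startup random seed, trains the same simple_fc_net once on a plain Executor and once on a ParallelExecutor, and requires every per-device loss from the parallel run to match the single-device loss. A standalone sketch of that final comparison (the array values are made up purely for illustration):

    import numpy

    single_first_loss = numpy.array([2.3026])             # single device
    parallel_first_loss = numpy.array([2.3026, 2.3026])   # one entry per CUDA device

    # Same check as the assertAlmostEquals loops with delta=1e-6:
    assert numpy.allclose(parallel_first_loss, single_first_loss[0], rtol=0, atol=1e-6)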
@@ -496,10 +539,10 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
                 share_vars_from=train_exe)

             for i in xrange(5):
-                test_loss, = test_exe.run([loss.name], feed_dict=feed_dict)
+                test_loss, = test_exe.run([loss.name], feed=feed_dict)
                 test_loss = numpy.array(test_loss)

-                train_loss, = train_exe.run([loss.name], feed_dict=feed_dict)
+                train_loss, = train_exe.run([loss.name], feed=feed_dict)
                 train_loss = numpy.array(train_loss)
                 self.assertTrue(
                     numpy.allclose(