Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
c083ee70
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c083ee70
编写于
4月 18, 2018
作者:
F
fengjiayi
提交者:
GitHub
4月 18, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #9950 from JiayiFeng/add_parallel_executor_tests
Add parallel executor tests
上级
61f4baa1
e84d3a7f
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
58 addition
and
14 deletion
+58
-14
python/paddle/fluid/parallel_executor.py
python/paddle/fluid/parallel_executor.py
+6
-5
python/paddle/fluid/tests/unittests/test_parallel_executor.py
...on/paddle/fluid/tests/unittests/test_parallel_executor.py
+52
-9
未找到文件。
python/paddle/fluid/parallel_executor.py
浏览文件 @
c083ee70
...
@@ -16,6 +16,7 @@ import core
...
@@ -16,6 +16,7 @@ import core
import
multiprocessing
import
multiprocessing
import
framework
import
framework
import
executor
import
executor
import
warnings
import
sys
import
sys
__all__
=
[
'ParallelExecutor'
]
__all__
=
[
'ParallelExecutor'
]
...
@@ -62,8 +63,8 @@ class ParallelExecutor(object):
...
@@ -62,8 +63,8 @@ class ParallelExecutor(object):
main_program=test_program,
main_program=test_program,
share_vars_from=train_exe)
share_vars_from=train_exe)
train_loss, = train_exe.run([loss.name], feed
_dict
=feed_dict)
train_loss, = train_exe.run([loss.name], feed=feed_dict)
test_loss, = test_exe.run([loss.name], feed
_dict
=feed_dict)
test_loss, = test_exe.run([loss.name], feed=feed_dict)
"""
"""
self
.
_places
=
[]
self
.
_places
=
[]
...
@@ -103,8 +104,8 @@ class ParallelExecutor(object):
...
@@ -103,8 +104,8 @@ class ParallelExecutor(object):
self
.
persistable_vars
=
[
self
.
persistable_vars
=
[
v
.
name
v
.
name
for
v
in
filter
(
lambda
var
:
\
for
v
in
filter
(
var
.
persistable
and
var
.
type
!=
core
.
VarDesc
.
VarType
.
RAW
,
lambda
var
:
var
.
persistable
and
var
.
type
!=
core
.
VarDesc
.
VarType
.
RAW
,
main
.
list_vars
())
main
.
list_vars
())
]
]
...
@@ -163,7 +164,7 @@ class ParallelExecutor(object):
...
@@ -163,7 +164,7 @@ class ParallelExecutor(object):
Returns: fetched result list.
Returns: fetched result list.
"""
"""
if
feed
is
None
:
if
feed
is
None
and
feed_dict
is
not
None
:
feed
=
feed_dict
feed
=
feed_dict
print
>>
sys
.
stderr
,
"`feed_dict` is deprecated. Please use `feed=`"
print
>>
sys
.
stderr
,
"`feed_dict` is deprecated. Please use `feed=`"
...
...
python/paddle/fluid/tests/unittests/test_parallel_executor.py
浏览文件 @
c083ee70
...
@@ -200,14 +200,29 @@ class TestParallelExecutorBase(unittest.TestCase):
...
@@ -200,14 +200,29 @@ class TestParallelExecutorBase(unittest.TestCase):
def
check_network_convergence
(
self
,
def
check_network_convergence
(
self
,
method
,
method
,
memory_opt
=
True
,
memory_opt
=
True
,
iter
=
1
0
,
iter
=
5
0
,
batch_size
=
None
,
batch_size
=
None
,
allow_op_delay
=
False
,
allow_op_delay
=
False
,
feed_dict
=
None
):
feed_dict
=
None
,
seed
=
None
,
use_parallel_executor
=
True
):
def
run_executor
(
exe
,
feed
,
fetch_list
,
program
=
None
):
if
isinstance
(
exe
,
fluid
.
ParallelExecutor
):
res
=
exe
.
run
(
fetch_list
=
fetch_list
,
feed
=
feed
)
elif
isinstance
(
exe
,
fluid
.
Executor
):
if
program
is
None
:
program
=
fluid
.
default_main_program
()
res
=
exe
.
run
(
program
=
program
,
feed
=
feed
,
fetch_list
=
fetch_list
)
else
:
raise
ValueError
(
'Unkown type exe'
)
return
res
main
=
fluid
.
Program
()
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
startup
.
random_seed
=
1
# Fix random seed
startup
.
random_seed
=
1
# Fix random seed
with
fluid
.
program_guard
(
main
,
startup
):
with
fluid
.
program_guard
(
main
,
startup
):
if
seed
is
not
None
:
startup
.
random_seed
=
seed
loss
=
method
(
use_feed
=
feed_dict
is
not
None
)
loss
=
method
(
use_feed
=
feed_dict
is
not
None
)
adam
=
fluid
.
optimizer
.
Adam
()
adam
=
fluid
.
optimizer
.
Adam
()
adam
.
minimize
(
loss
)
adam
.
minimize
(
loss
)
...
@@ -217,18 +232,24 @@ class TestParallelExecutorBase(unittest.TestCase):
...
@@ -217,18 +232,24 @@ class TestParallelExecutorBase(unittest.TestCase):
startup_exe
=
fluid
.
Executor
(
place
)
startup_exe
=
fluid
.
Executor
(
place
)
startup_exe
.
run
(
startup
)
startup_exe
.
run
(
startup
)
if
use_parallel_executor
:
exe
=
fluid
.
ParallelExecutor
(
exe
=
fluid
.
ParallelExecutor
(
True
,
loss_name
=
loss
.
name
,
allow_op_delay
=
allow_op_delay
)
True
,
loss_name
=
loss
.
name
,
allow_op_delay
=
allow_op_delay
)
else
:
exe
=
fluid
.
Executor
(
place
=
place
)
if
batch_size
is
not
None
:
if
batch_size
is
not
None
:
batch_size
*=
fluid
.
core
.
get_cuda_device_count
()
batch_size
*=
fluid
.
core
.
get_cuda_device_count
()
begin
=
time
.
time
()
begin
=
time
.
time
()
first_loss
,
=
exe
.
run
([
loss
.
name
],
feed
=
feed_dict
)
first_loss
,
=
run_executor
(
exe
=
exe
,
feed
=
feed_dict
,
fetch_list
=
[
loss
.
name
])
first_loss
=
numpy
.
array
(
first_loss
)
first_loss
=
numpy
.
array
(
first_loss
)
for
i
in
xrange
(
iter
):
for
i
in
xrange
(
iter
):
exe
.
run
([],
feed
=
feed_dict
)
run_executor
(
exe
=
exe
,
feed
=
feed_dict
,
fetch_list
=
[]
)
last_loss
,
=
exe
.
run
([
loss
.
name
],
feed
=
feed_dict
)
last_loss
,
=
run_executor
(
exe
=
exe
,
feed
=
feed_dict
,
fetch_list
=
[
loss
.
name
])
end
=
time
.
time
()
end
=
time
.
time
()
if
batch_size
is
not
None
:
if
batch_size
is
not
None
:
...
@@ -239,6 +260,7 @@ class TestParallelExecutorBase(unittest.TestCase):
...
@@ -239,6 +260,7 @@ class TestParallelExecutorBase(unittest.TestCase):
print
first_loss
,
last_loss
print
first_loss
,
last_loss
# self.assertGreater(first_loss[0], last_loss[0])
# self.assertGreater(first_loss[0], last_loss[0])
return
first_loss
,
last_loss
class
TestMNIST
(
TestParallelExecutorBase
):
class
TestMNIST
(
TestParallelExecutorBase
):
...
@@ -268,6 +290,27 @@ class TestMNIST(TestParallelExecutorBase):
...
@@ -268,6 +290,27 @@ class TestMNIST(TestParallelExecutorBase):
simple_fc_net
,
feed_dict
=
{
"image"
:
img
,
simple_fc_net
,
feed_dict
=
{
"image"
:
img
,
"label"
:
label
})
"label"
:
label
})
def
test_simple_fc_parallel_accuracy
(
self
):
img
=
numpy
.
zeros
(
shape
=
[
32
,
784
],
dtype
=
'float32'
)
label
=
numpy
.
ones
(
shape
=
[
32
,
1
],
dtype
=
'int64'
)
single_first_loss
,
single_last_loss
=
self
.
check_network_convergence
(
method
=
simple_fc_net
,
seed
=
1000
,
feed_dict
=
{
"image"
:
img
,
"label"
:
label
},
use_parallel_executor
=
False
)
parallel_first_loss
,
parallel_last_loss
=
self
.
check_network_convergence
(
method
=
simple_fc_net
,
seed
=
1000
,
feed_dict
=
{
"image"
:
img
,
"label"
:
label
},
use_parallel_executor
=
True
)
for
p_f
in
parallel_first_loss
:
self
.
assertAlmostEquals
(
p_f
,
single_first_loss
[
0
],
delta
=
1e-6
)
for
p_l
in
parallel_last_loss
:
self
.
assertAlmostEquals
(
p_l
,
single_last_loss
[
0
],
delta
=
1e-6
)
def
test_batchnorm_fc
(
self
):
def
test_batchnorm_fc
(
self
):
self
.
check_network_convergence
(
fc_with_batchnorm
)
self
.
check_network_convergence
(
fc_with_batchnorm
)
img
=
numpy
.
zeros
(
shape
=
[
32
,
784
],
dtype
=
'float32'
)
img
=
numpy
.
zeros
(
shape
=
[
32
,
784
],
dtype
=
'float32'
)
...
@@ -496,10 +539,10 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
...
@@ -496,10 +539,10 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
share_vars_from
=
train_exe
)
share_vars_from
=
train_exe
)
for
i
in
xrange
(
5
):
for
i
in
xrange
(
5
):
test_loss
,
=
test_exe
.
run
([
loss
.
name
],
feed
_dict
=
feed_dict
)
test_loss
,
=
test_exe
.
run
([
loss
.
name
],
feed
=
feed_dict
)
test_loss
=
numpy
.
array
(
test_loss
)
test_loss
=
numpy
.
array
(
test_loss
)
train_loss
,
=
train_exe
.
run
([
loss
.
name
],
feed
_dict
=
feed_dict
)
train_loss
,
=
train_exe
.
run
([
loss
.
name
],
feed
=
feed_dict
)
train_loss
=
numpy
.
array
(
train_loss
)
train_loss
=
numpy
.
array
(
train_loss
)
self
.
assertTrue
(
self
.
assertTrue
(
numpy
.
allclose
(
numpy
.
allclose
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录