Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9c8383cf
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9c8383cf
编写于
5月 10, 2018
作者:
Y
yuyang18
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Parallel Executor revised feeder
上级
28de0ea4
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
95 addition
and
0 deletion
+95
-0
python/paddle/fluid/data_feeder.py
python/paddle/fluid/data_feeder.py
+58
-0
python/paddle/fluid/tests/unittests/test_parallel_executor.py
...on/paddle/fluid/tests/unittests/test_parallel_executor.py
+37
-0
未找到文件。
python/paddle/fluid/data_feeder.py
浏览文件 @
9c8383cf
...
...
@@ -16,6 +16,7 @@ from __future__ import print_function
import
core
import
numpy
import
six.moves
as
six
import
multiprocessing
from
framework
import
Variable
,
default_main_program
...
...
@@ -116,3 +117,60 @@ class DataFeeder(object):
for
each_name
,
each_converter
in
six
.
zip
(
self
.
feed_names
,
converter
):
ret_dict
[
each_name
]
=
each_converter
.
done
()
return
ret_dict
def
feed_parallel
(
self
,
iterable
,
num_places
=
None
):
if
isinstance
(
self
.
place
,
core
.
CUDAPlace
):
places
=
[
core
.
CUDAPlace
(
i
)
for
i
in
six
.
xrange
(
self
.
_get_number_of_places_
(
num_places
))
]
else
:
places
=
[
core
.
CPUPlace
()
for
_
in
six
.
xrange
(
self
.
_get_number_of_places_
(
num_places
))
]
if
len
(
iterable
)
!=
len
(
places
):
raise
ValueError
(
"feed_parallel takes multiple mini-batches. Each "
"mini-batch will be feed on each device. The "
"number of devices and number of mini-batches "
"must be same."
)
place
=
self
.
place
for
p
,
batch
in
six
.
zip
(
places
,
iterable
):
self
.
place
=
p
yield
self
.
feed
(
batch
)
self
.
place
=
place
def
_get_number_of_places_
(
self
,
num_places
):
if
num_places
is
not
None
:
return
int
(
num_places
)
elif
isinstance
(
self
.
place
,
core
.
CUDAPlace
):
return
core
.
get_cuda_device_count
()
else
:
return
multiprocessing
.
cpu_count
()
def
decorate_reader
(
self
,
reader
,
multi_devices
,
num_places
=
None
,
drop_last
=
True
):
def
__reader_creator__
():
if
not
multi_devices
:
for
item
in
reader
():
yield
self
.
feed
(
item
)
else
:
num
=
self
.
_get_number_of_places_
(
num_places
)
item
=
[]
for
batch
in
reader
():
item
.
append
(
batch
)
if
len
(
item
)
==
num
:
yield
list
(
self
.
feed_parallel
(
item
,
num
))
item
=
[]
if
not
drop_last
and
len
(
item
)
!=
0
:
raise
ValueError
(
"The data batch which cannot fit for devices will be "
"dropped is not implementation. Other strategies are "
"not implemented"
)
return
__reader_creator__
python/paddle/fluid/tests/unittests/test_parallel_executor.py
浏览文件 @
9c8383cf
...
...
@@ -796,5 +796,42 @@ class TestFetchOp(unittest.TestCase):
self
.
parallel_exe
(
train_inputs
,
seed
=
1
)
class
TestFeedParallel
(
unittest
.
TestCase
):
def
test_main
(
self
):
main
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
startup
.
random_seed
=
1
with
fluid
.
scope_guard
(
fluid
.
core
.
Scope
()):
with
fluid
.
program_guard
(
main
,
startup
):
data
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
3
,
224
,
224
],
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
out
=
Lenet
(
data
,
class_dim
=
102
)
loss
=
fluid
.
layers
.
cross_entropy
(
input
=
out
,
label
=
label
)
loss
=
fluid
.
layers
.
mean
(
loss
)
opt
=
fluid
.
optimizer
.
Momentum
(
learning_rate
=
0.1
,
momentum
=
0.9
,
regularization
=
fluid
.
regularizer
.
L2Decay
(
1e-4
))
opt
.
minimize
(
loss
)
place
=
fluid
.
CUDAPlace
(
0
)
feeder
=
fluid
.
DataFeeder
(
place
=
place
,
feed_list
=
[
data
,
label
])
reader
=
feeder
.
decorate_reader
(
paddle
.
batch
(
flowers
.
train
(),
batch_size
=
16
),
multi_devices
=
True
)
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup
)
pe
=
fluid
.
ParallelExecutor
(
use_cuda
=
True
,
loss_name
=
loss
.
name
,
main_program
=
main
)
for
batch_id
,
data
in
enumerate
(
reader
()):
loss_np
=
np
.
array
(
pe
.
run
(
feed
=
data
,
fetch_list
=
[
loss
.
name
])[
0
])
print
batch_id
,
loss_np
if
batch_id
==
2
:
break
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录