Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
80642bee
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
80642bee
编写于
6月 28, 2017
作者:
W
wanghaoshuang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix_xmap and refine flowers dataset
上级
633082ad
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
72 addition
and
67 deletion
+72
-67
python/paddle/v2/dataset/__init__.py
python/paddle/v2/dataset/__init__.py
+2
-1
python/paddle/v2/dataset/flowers.py
python/paddle/v2/dataset/flowers.py
+35
-32
python/paddle/v2/dataset/tests/flowers_test.py
python/paddle/v2/dataset/tests/flowers_test.py
+2
-2
python/paddle/v2/reader/decorator.py
python/paddle/v2/reader/decorator.py
+23
-24
python/paddle/v2/reader/tests/decorator_test.py
python/paddle/v2/reader/tests/decorator_test.py
+10
-8
未找到文件。
python/paddle/v2/dataset/__init__.py
浏览文件 @
80642bee
...
@@ -25,8 +25,9 @@ import uci_housing
...
@@ -25,8 +25,9 @@ import uci_housing
import
sentiment
import
sentiment
import
wmt14
import
wmt14
import
mq2007
import
mq2007
import
flowers
__all__
=
[
__all__
=
[
'mnist'
,
'imikolov'
,
'imdb'
,
'cifar'
,
'movielens'
,
'conll05'
,
'sentiment'
'mnist'
,
'imikolov'
,
'imdb'
,
'cifar'
,
'movielens'
,
'conll05'
,
'sentiment'
'uci_housing'
,
'wmt14'
,
'mq2007'
'uci_housing'
,
'wmt14'
,
'mq2007'
,
'flowers'
]
]
python/paddle/v2/dataset/flowers.py
浏览文件 @
80642bee
...
@@ -34,9 +34,9 @@ from common import download
...
@@ -34,9 +34,9 @@ from common import download
import
tarfile
import
tarfile
import
scipy.io
as
scio
import
scipy.io
as
scio
from
paddle.v2.image
import
*
from
paddle.v2.image
import
*
from
paddle.v2.reader
import
*
import
os
import
os
import
numpy
as
np
import
numpy
as
np
import
paddle.v2
as
paddle
from
multiprocessing
import
cpu_count
from
multiprocessing
import
cpu_count
__all__
=
[
'train'
,
'test'
,
'valid'
]
__all__
=
[
'train'
,
'test'
,
'valid'
]
...
@@ -53,8 +53,8 @@ def default_mapper(sample):
...
@@ -53,8 +53,8 @@ def default_mapper(sample):
map image bytes data to type needed by model input layer
map image bytes data to type needed by model input layer
'''
'''
img
,
label
=
sample
img
,
label
=
sample
img
=
paddle
.
image
.
load_image_bytes
(
img
)
img
=
load_image_bytes
(
img
)
img
=
paddle
.
image
.
simple_transform
(
img
,
256
,
224
,
True
)
img
=
simple_transform
(
img
,
256
,
224
,
True
)
return
img
.
flatten
().
astype
(
'float32'
),
label
return
img
.
flatten
().
astype
(
'float32'
),
label
...
@@ -63,7 +63,8 @@ def reader_creator(data_file,
...
@@ -63,7 +63,8 @@ def reader_creator(data_file,
setid_file
,
setid_file
,
dataset_name
,
dataset_name
,
mapper
=
default_mapper
,
mapper
=
default_mapper
,
buffered_size
=
1024
):
buffered_size
=
1024
,
useXmap
=
True
):
'''
'''
1. read images from tar file and
1. read images from tar file and
merge images into batch files in 102flowers.tgz_batch/
merge images into batch files in 102flowers.tgz_batch/
...
@@ -105,11 +106,13 @@ def reader_creator(data_file,
...
@@ -105,11 +106,13 @@ def reader_creator(data_file,
for
sample
,
label
in
itertools
.
izip
(
data
,
batch
[
'label'
]):
for
sample
,
label
in
itertools
.
izip
(
data
,
batch
[
'label'
]):
yield
sample
,
int
(
label
)
yield
sample
,
int
(
label
)
return
paddle
.
reader
.
xmap_readers
(
mapper
,
reader
,
if
useXmap
:
cpu_count
(),
buffered_size
)
return
xmap_readers
(
mapper
,
reader
,
cpu_count
(),
buffered_size
)
else
:
return
map_readers
(
mapper
,
reader
)
def
train
(
mapper
=
default_mapper
,
buffered_size
=
1024
):
def
train
(
mapper
=
default_mapper
,
buffered_size
=
1024
,
useXmap
=
True
):
'''
'''
Create flowers training set reader.
Create flowers training set reader.
It returns a reader, each sample in the reader is
It returns a reader, each sample in the reader is
...
@@ -128,11 +131,11 @@ def train(mapper=default_mapper, buffered_size=1024):
...
@@ -128,11 +131,11 @@ def train(mapper=default_mapper, buffered_size=1024):
return
reader_creator
(
return
reader_creator
(
download
(
DATA_URL
,
'flowers'
,
DATA_MD5
),
download
(
DATA_URL
,
'flowers'
,
DATA_MD5
),
download
(
LABEL_URL
,
'flowers'
,
LABEL_MD5
),
download
(
LABEL_URL
,
'flowers'
,
LABEL_MD5
),
download
(
SETID_URL
,
'flowers'
,
SETID_MD5
),
't
rn
id'
,
mapper
,
download
(
SETID_URL
,
'flowers'
,
SETID_MD5
),
't
st
id'
,
mapper
,
buffered_size
)
buffered_size
,
useXmap
)
def
test
(
mapper
=
default_mapper
,
buffered_size
=
1024
):
def
test
(
mapper
=
default_mapper
,
buffered_size
=
1024
,
useXmap
=
True
):
'''
'''
Create flowers test set reader.
Create flowers test set reader.
It returns a reader, each sample in the reader is
It returns a reader, each sample in the reader is
...
@@ -151,11 +154,11 @@ def test(mapper=default_mapper, buffered_size=1024):
...
@@ -151,11 +154,11 @@ def test(mapper=default_mapper, buffered_size=1024):
return
reader_creator
(
return
reader_creator
(
download
(
DATA_URL
,
'flowers'
,
DATA_MD5
),
download
(
DATA_URL
,
'flowers'
,
DATA_MD5
),
download
(
LABEL_URL
,
'flowers'
,
LABEL_MD5
),
download
(
LABEL_URL
,
'flowers'
,
LABEL_MD5
),
download
(
SETID_URL
,
'flowers'
,
SETID_MD5
),
't
st
id'
,
mapper
,
download
(
SETID_URL
,
'flowers'
,
SETID_MD5
),
't
rn
id'
,
mapper
,
buffered_size
)
buffered_size
,
useXmap
)
def
valid
(
mapper
=
default_mapper
,
buffered_size
=
1024
):
def
valid
(
mapper
=
default_mapper
,
buffered_size
=
1024
,
useXmap
=
True
):
'''
'''
Create flowers validation set reader.
Create flowers validation set reader.
It returns a reader, each sample in the reader is
It returns a reader, each sample in the reader is
...
@@ -175,7 +178,7 @@ def valid(mapper=default_mapper, buffered_size=1024):
...
@@ -175,7 +178,7 @@ def valid(mapper=default_mapper, buffered_size=1024):
download
(
DATA_URL
,
'flowers'
,
DATA_MD5
),
download
(
DATA_URL
,
'flowers'
,
DATA_MD5
),
download
(
LABEL_URL
,
'flowers'
,
LABEL_MD5
),
download
(
LABEL_URL
,
'flowers'
,
LABEL_MD5
),
download
(
SETID_URL
,
'flowers'
,
SETID_MD5
),
'valid'
,
mapper
,
download
(
SETID_URL
,
'flowers'
,
SETID_MD5
),
'valid'
,
mapper
,
buffered_size
)
buffered_size
,
useXmap
)
def
fetch
():
def
fetch
():
...
...
python/paddle/v2/dataset/tests/flowers_test.py
浏览文件 @
80642bee
...
@@ -31,13 +31,13 @@ class TestFlowers(unittest.TestCase):
...
@@ -31,13 +31,13 @@ class TestFlowers(unittest.TestCase):
def
test_train
(
self
):
def
test_train
(
self
):
instances
,
max_label_value
=
self
.
check_reader
(
instances
,
max_label_value
=
self
.
check_reader
(
paddle
.
v2
.
dataset
.
flowers
.
train
())
paddle
.
v2
.
dataset
.
flowers
.
train
())
self
.
assertEqual
(
instances
,
1020
)
self
.
assertEqual
(
instances
,
6149
)
self
.
assertEqual
(
max_label_value
,
102
)
self
.
assertEqual
(
max_label_value
,
102
)
def
test_test
(
self
):
def
test_test
(
self
):
instances
,
max_label_value
=
self
.
check_reader
(
instances
,
max_label_value
=
self
.
check_reader
(
paddle
.
v2
.
dataset
.
flowers
.
test
())
paddle
.
v2
.
dataset
.
flowers
.
test
())
self
.
assertEqual
(
instances
,
6149
)
self
.
assertEqual
(
instances
,
1020
)
self
.
assertEqual
(
max_label_value
,
102
)
self
.
assertEqual
(
max_label_value
,
102
)
def
test_valid
(
self
):
def
test_valid
(
self
):
...
...
python/paddle/v2/reader/decorator.py
浏览文件 @
80642bee
...
@@ -248,9 +248,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
...
@@ -248,9 +248,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
:rtype: callable
:rtype: callable
"""
"""
end
=
XmapEndSignal
()
end
=
XmapEndSignal
()
in_queue
=
Queue
(
buffer_size
)
out_queue
=
Queue
(
buffer_size
)
out_order
=
[
0
]
# define a worker to read samples from reader to in_queue
# define a worker to read samples from reader to in_queue
def
read_worker
(
reader
,
in_queue
):
def
read_worker
(
reader
,
in_queue
):
...
@@ -266,12 +263,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
...
@@ -266,12 +263,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
in_order
+=
1
in_order
+=
1
in_queue
.
put
(
end
)
in_queue
.
put
(
end
)
# start a read worker in a thread
target
=
order_read_worker
if
order
else
read_worker
t
=
Thread
(
target
=
target
,
args
=
(
reader
,
in_queue
))
t
.
daemon
=
True
t
.
start
()
# define a worker to handle samples from in_queue by mapper
# define a worker to handle samples from in_queue by mapper
# and put mapped samples into out_queue
# and put mapped samples into out_queue
def
handle_worker
(
in_queue
,
out_queue
,
mapper
):
def
handle_worker
(
in_queue
,
out_queue
,
mapper
):
...
@@ -298,6 +289,15 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
...
@@ -298,6 +289,15 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
in_queue
.
put
(
end
)
in_queue
.
put
(
end
)
out_queue
.
put
(
end
)
out_queue
.
put
(
end
)
def
xreader
():
in_queue
=
Queue
(
buffer_size
)
out_queue
=
Queue
(
buffer_size
)
out_order
=
[
0
]
# start a read worker in a thread
target
=
order_read_worker
if
order
else
read_worker
t
=
Thread
(
target
=
target
,
args
=
(
reader
,
in_queue
))
t
.
daemon
=
True
t
.
start
()
# start several handle_workers
# start several handle_workers
target
=
order_handle_worker
if
order
else
handle_worker
target
=
order_handle_worker
if
order
else
handle_worker
args
=
(
in_queue
,
out_queue
,
mapper
,
out_order
)
if
order
else
(
args
=
(
in_queue
,
out_queue
,
mapper
,
out_order
)
if
order
else
(
...
@@ -310,7 +310,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
...
@@ -310,7 +310,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
for
w
in
workers
:
for
w
in
workers
:
w
.
start
()
w
.
start
()
def
xreader
():
sample
=
out_queue
.
get
()
sample
=
out_queue
.
get
()
while
not
isinstance
(
sample
,
XmapEndSignal
):
while
not
isinstance
(
sample
,
XmapEndSignal
):
yield
sample
yield
sample
...
...
python/paddle/v2/reader/tests/decorator_test.py
浏览文件 @
80642bee
...
@@ -132,10 +132,12 @@ class TestXmap(unittest.TestCase):
...
@@ -132,10 +132,12 @@ class TestXmap(unittest.TestCase):
for
order
in
orders
:
for
order
in
orders
:
for
tNum
in
thread_nums
:
for
tNum
in
thread_nums
:
for
size
in
buffered_size
:
for
size
in
buffered_size
:
result
=
[]
reader
=
paddle
.
v2
.
reader
.
xmap_readers
(
mapper
,
for
i
in
paddle
.
v2
.
reader
.
xmap_readers
(
mapper
,
reader_creator_10
(
0
),
reader_creator_10
(
0
),
tNum
,
size
,
order
)():
tNum
,
size
,
order
)
for
n
in
xrange
(
3
):
result
=
[]
for
i
in
reader
():
result
.
append
(
i
)
result
.
append
(
i
)
if
not
order
:
if
not
order
:
result
.
sort
()
result
.
sort
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录