Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
2799b0ec
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2799b0ec
编写于
5月 24, 2017
作者:
W
wanghaoshuang@baidu.com
提交者:
wanghaoshuang
6月 05, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add flowers dataset for image classification model
上级
b15b2637
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
409 addition
and
8 deletion
+409
-8
python/paddle/v2/dataset/flowers.py
python/paddle/v2/dataset/flowers.py
+255
-0
python/paddle/v2/dataset/tests/flowers_test.py
python/paddle/v2/dataset/tests/flowers_test.py
+51
-0
python/paddle/v2/image.py
python/paddle/v2/image.py
+29
-7
python/paddle/v2/reader/decorator.py
python/paddle/v2/reader/decorator.py
+74
-1
未找到文件。
python/paddle/v2/dataset/flowers.py
0 → 100644
浏览文件 @
2799b0ec
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
CIFAR dataset.
This module will download dataset from
http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html
and parse train/test set intopaddle reader creators.
This set contains images of flowers belonging to 102 different categories.
The images were acquired by searching the web and taking pictures. There are a
minimum of 40 images for each category.
The database was used in:
Nilsback, M-E. and Zisserman, A. Automated flower classification over a large
number of classes.Proceedings of the Indian Conference on Computer Vision,
Graphics and Image Processing (2008)
http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}.
"""
import
cPickle
import
itertools
from
common
import
download
import
tarfile
import
scipy.io
as
scio
from
image
import
*
import
os
from
multiprocessing
import
Process
from
multiprocessing
import
Pool
from
multiprocessing
import
cpu_count
import
numpy
as
np
import
paddle.v2
as
paddle
__all__
=
[
'train'
,
'test'
,
'valid'
]
DATA_URL
=
'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz'
LABEL_URL
=
'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat'
SETID_URL
=
'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat'
DATA_MD5
=
'52808999861908f626f3c1f4e79d11fa'
LABEL_MD5
=
'e0620be6f572b9609742df49c70aed4d'
SETID_MD5
=
'a5357ecc9cb78c4bef273ce3793fc85c'
def
extract_file
(
tarFile
):
'''
Extract tar file to tmp dir.
Example usage:
.. code-block:: python
tmp = extract_file("/home/work/test.tar.gz")
:param tarFile: target tar file
:type tarFile: string
:return: extracted dir. For example:
'/home/work/test/' while input is '/home/work/test.tar.gz'
:rtype: string
'''
base_dir
=
os
.
path
.
dirname
(
tarFile
)
base_name
=
os
.
path
.
basename
(
tarFile
)
if
'.'
in
base_name
:
base_name
=
base_name
.
split
(
'.'
,
1
)[
0
]
out_path
=
'/'
.
join
([
base_dir
,
base_name
])
if
not
os
.
path
.
exists
(
out_path
):
df
=
tarfile
.
open
(
tarFile
,
mode
=
'r'
)
df
.
extractall
(
path
=
out_path
)
df
.
close
()
return
out_path
def
default_mapper
(
sample
):
'''
map image bytes data to type needed by model input layer
'''
img
,
label
=
sample
img
=
paddle
.
image
.
load_image_bytes
(
img
)
img
=
paddle
.
image
.
simple_transform
(
img
,
256
,
224
,
True
)
return
img
.
flatten
().
astype
(
'float32'
),
label
def
reader_creator
(
data_file
,
label_file
,
setid_file
,
flag
,
mapper
=
default_mapper
):
'''
1. extract 102flowers.tgz to 102flowers/
2. merge images into batch files in 102flowers_batch/
3. get a reader to read sample from batch file
:param data_file: downloaded data file
:type data_file: string
:param label_file: downloaded label file
:type label_file: string
:param setid_file: downloaded setid file containing information
about how to split dataset
:type setid_file: string
:param flag: data set name (tstid|trnid|valid)
:type flag: string
:param mapper: a function to map image bytes data to type
needed by model input layer
:type mapper: callable
:return: data reader
:rtype: callable
'''
base_dir
=
os
.
path
.
dirname
(
data_file
)
tmp_dir
=
extract_file
(
data_file
)
file_list
=
create_batch
(
tmp_dir
,
label_file
,
setid_file
,
flag
)
def
reader
():
for
file
in
open
(
file_list
):
file
=
file
.
strip
()
batch
=
None
with
open
(
file
,
'r'
)
as
f
:
batch
=
cPickle
.
load
(
f
)
data
=
batch
[
'data'
]
labels
=
batch
[
'label'
]
for
sample
,
label
in
itertools
.
izip
(
data
,
batch
[
'label'
]):
yield
sample
,
int
(
label
)
return
paddle
.
reader
.
xmap
(
mapper
,
reader
,
cpu_count
(),
1024
*
8
)
def
create_batch
(
data_dir
,
label_file
,
setid_file
,
flag
,
numPerBatch
=
1024
,
nThread
=
16
):
batch_dir
=
data_dir
+
"_batch"
labels
=
scio
.
loadmat
(
label_file
)[
'labels'
][
0
]
indexes
=
scio
.
loadmat
(
setid_file
)[
flag
][
0
]
count
=
len
(
indexes
)
out_path
=
"%s/%s"
%
(
batch_dir
,
flag
)
meta_file
=
"%s/%s.txt"
%
(
batch_dir
,
flag
)
if
os
.
path
.
exists
(
out_path
):
return
meta_file
else
:
os
.
makedirs
(
out_path
)
def
batch
(
file_out
,
start
,
end
):
data
=
[]
labellist
=
[]
for
index
in
indexes
[
start
:
end
]:
img_name
=
"%s/jpg/image_%05d.jpg"
%
(
data_dir
,
index
)
with
open
(
img_name
,
'r'
)
as
f
:
data
.
append
(
f
.
read
())
labellist
.
append
(
labels
[
index
-
1
])
output
=
{}
output
[
'label'
]
=
labellist
output
[
'data'
]
=
data
cPickle
.
dump
(
output
,
open
(
file_out
,
'w'
),
protocol
=
cPickle
.
HIGHEST_PROTOCOL
)
cur_id
=
0
file_id
=
0
while
cur_id
<
count
:
thread
=
[]
for
i
in
xrange
(
nThread
):
end_id
=
min
(
cur_id
+
numPerBatch
,
count
)
batch_file_name
=
"%s/batch_%05d"
%
(
out_path
,
file_id
)
w
=
Process
(
target
=
batch
,
args
=
(
batch_file_name
,
cur_id
,
end_id
))
w
.
daemon
=
True
thread
.
append
(
w
)
cur_id
=
end_id
file_id
+=
1
if
cur_id
==
count
:
break
for
t
in
thread
:
t
.
start
()
for
t
in
thread
:
t
.
join
()
with
open
(
meta_file
,
'a'
)
as
meta
:
for
file
in
os
.
listdir
(
out_path
):
meta
.
write
(
os
.
path
.
abspath
(
"%s/%s"
%
(
out_path
,
file
))
+
"
\n
"
)
return
meta_file
def
train
(
mapper
=
default_mapper
):
'''
Create flowers training set reader.
It returns a reader, each sample in the reader is
image pixels in [0, 1] and label in [1, 102]
translated from original color image by steps:
1. resize to 256*256
2. random crop to 224*224
3. flatten
:param mapper: a function to map sample.
:type mapper: callable
:return: train data reader
:rtype: callable
'''
return
reader_creator
(
download
(
DATA_URL
,
'flowers'
,
DATA_MD5
),
download
(
LABEL_URL
,
'flowers'
,
LABEL_MD5
),
download
(
SETID_URL
,
'flowers'
,
SETID_MD5
),
'trnid'
)
def
test
(
mapper
=
default_mapper
):
'''
Create flowers test set reader.
It returns a reader, each sample in the reader is
image pixels in [0, 1] and label in [1, 102]
translated from original color image by steps:
1. resize to 256*256
2. random crop to 224*224
3. flatten
:param mapper: a function to map sample.
:type mapper: callable
:return: test data reader
:rtype: callable
'''
return
reader_creator
(
download
(
DATA_URL
,
'flowers'
,
DATA_MD5
),
download
(
LABEL_URL
,
'flowers'
,
LABEL_MD5
),
download
(
SETID_URL
,
'flowers'
,
SETID_MD5
),
'tstid'
)
def
valid
():
'''
Create flowers validation set reader.
It returns a reader, each sample in the reader is
image pixels in [0, 1] and label in [1, 102]
translated from original color image by steps:
1. resize to 256*256
2. random crop to 224*224
3. flatten
'''
return
reader_creator
(
download
(
DATA_URL
,
'flowers'
,
DATA_MD5
),
download
(
LABEL_URL
,
'flowers'
,
LABEL_MD5
),
download
(
SETID_URL
,
'flowers'
,
SETID_MD5
),
'valid'
)
def
fetch
():
download
(
DATA_URL
,
'flowers'
,
DATA_MD5
)
download
(
LABEL_URL
,
'flowers'
,
LABEL_MD5
)
download
(
SETID_URL
,
'flowers'
,
SETID_MD5
)
if
__name__
==
'__main__'
:
for
i
in
test
()():
pass
python/paddle/v2/dataset/tests/flowers_test.py
0 → 100644
浏览文件 @
2799b0ec
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle.v2.dataset.flowers
import
unittest
class
TestFlowers
(
unittest
.
TestCase
):
def
check_reader
(
self
,
reader
):
sum
=
0
label
=
0
size
=
224
*
224
*
3
for
l
in
reader
():
self
.
assertEqual
(
l
[
0
].
size
,
size
)
if
l
[
1
]
>
label
:
label
=
l
[
1
]
sum
+=
1
return
sum
,
label
def
test_train
(
self
):
instances
,
max_label_value
=
self
.
check_reader
(
paddle
.
v2
.
dataset
.
flowers
.
train
())
self
.
assertEqual
(
instances
,
1020
)
self
.
assertEqual
(
max_label_value
,
102
)
def
test_test
(
self
):
instances
,
max_label_value
=
self
.
check_reader
(
paddle
.
v2
.
dataset
.
flowers
.
test
())
self
.
assertEqual
(
instances
,
6149
)
self
.
assertEqual
(
max_label_value
,
102
)
def
test_valid
(
self
):
instances
,
max_label_value
=
self
.
check_reader
(
paddle
.
v2
.
dataset
.
flowers
.
valid
())
self
.
assertEqual
(
instances
,
1020
)
self
.
assertEqual
(
max_label_value
,
102
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/v2/image.py
浏览文件 @
2799b0ec
import
numpy
as
np
try
:
import
cv2
except
:
print
(
"import cv2 error, please install opencv-python: pip install opencv-python"
)
except
ImportError
:
cv2
=
None
from
cv2
import
resize
__all__
=
[
"load_image
"
,
"resize_short"
,
"to_chw"
,
"center_crop"
,
"random
_crop"
,
"left_right_flip"
,
"simple_transform"
,
"load_and_transform"
"load_image
_bytes"
,
"load_image"
,
"resize_short"
,
"to_chw"
,
"center
_crop"
,
"
random_crop"
,
"
left_right_flip"
,
"simple_transform"
,
"load_and_transform"
]
"""
This file contains some common interfaces for image preprocess.
...
...
@@ -28,6 +28,28 @@ the image layout as follows.
"""
def
load_image_bytes
(
bytes
,
is_color
=
True
):
"""
Load an color or gray image from bytes array.
Example usage:
.. code-block:: python
with open('cat.jpg') as f:
im = load_image(f.read())
:param bytes: the input image bytes array.
:type file: str
:param is_color: If set is_color True, it will load and
return a color image. Otherwise, it will
load and return a gray image.
"""
flag
=
1
if
is_color
else
0
file_bytes
=
np
.
asarray
(
bytearray
(
bytes
),
dtype
=
np
.
uint8
)
img
=
cv2
.
imdecode
(
file_bytes
,
flag
)
return
img
def
load_image
(
file
,
is_color
=
True
):
"""
Load an color or gray image from the file path.
...
...
@@ -76,7 +98,7 @@ def resize_short(im, size):
h_new
=
size
*
h
/
w
else
:
w_new
=
size
*
w
/
h
im
=
cv2
.
resize
(
im
,
(
h_new
,
w_new
),
interpolation
=
cv2
.
INTER_CUBIC
)
im
=
resize
(
im
,
(
h_new
,
w_new
),
interpolation
=
cv2
.
INTER_CUBIC
)
return
im
...
...
python/paddle/v2/reader/decorator.py
浏览文件 @
2799b0ec
...
...
@@ -14,13 +14,15 @@
__all__
=
[
'map_readers'
,
'buffered'
,
'compose'
,
'chain'
,
'shuffle'
,
'ComposeNotAligned'
,
'firstn'
'ComposeNotAligned'
,
'firstn'
,
'xmap'
]
import
itertools
import
random
from
Queue
import
Queue
from
threading
import
Thread
from
multiprocessing
import
Queue
as
MQueue
from
multiprocessing
import
Process
def
map_readers
(
func
,
*
readers
):
...
...
@@ -224,3 +226,74 @@ def firstn(reader, n):
yield
item
return
firstn_reader
class
XmapEndSignal
():
pass
def
xmap
(
mapper
,
reader
,
process_num
,
buffer_size
):
"""
Use multiprocess to map samples from reader by a mapper defined by user.
And this function contains a buffered decorator.
:param mapper: a function to map sample.
:type mapper: callable
:param reader: the data reader to read from
:type reader: callable
:param process_num: process number to handle original sample
:type process_num: int
:param buffer_size: max buffer size
:type buffer_size: int
:return: the decarated reader
:rtype: callable
"""
end
=
XmapEndSignal
()
in_queue
=
MQueue
(
buffer_size
)
out_queue
=
MQueue
(
buffer_size
)
# define a worker to read samples from reader to in_queue
def
read_worker
(
reader
,
in_queue
):
for
i
in
reader
():
in_queue
.
put
(
i
)
in_queue
.
put
(
end
)
# start a read worker in a thread
t
=
Thread
(
target
=
read_worker
,
args
=
(
reader
,
in_queue
))
t
.
daemon
=
True
t
.
start
()
# define a worker to handle samples from in_queue by mapper
# and put mapped samples into out_queue
def
handle_worker
(
in_queue
,
out_queue
,
mapper
):
sample
=
in_queue
.
get
()
while
not
isinstance
(
sample
,
XmapEndSignal
):
r
=
mapper
(
sample
)
out_queue
.
put
(
r
)
sample
=
in_queue
.
get
()
in_queue
.
put
(
end
)
out_queue
.
put
(
end
)
# start several handle_workers
workers
=
[]
for
i
in
xrange
(
process_num
):
worker
=
Process
(
target
=
handle_worker
,
args
=
(
in_queue
,
out_queue
,
mapper
))
worker
.
daemon
=
True
workers
.
append
(
worker
)
for
w
in
workers
:
w
.
start
()
def
xreader
():
sample
=
out_queue
.
get
()
while
not
isinstance
(
sample
,
XmapEndSignal
):
yield
sample
sample
=
out_queue
.
get
()
finish
=
1
while
finish
<
process_num
:
sample
=
out_queue
.
get
()
if
isinstance
(
sample
,
XmapEndSignal
):
finish
+=
1
else
:
yield
sample
return
xreader
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录