Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
0bcc4d48
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
0bcc4d48
编写于
2月 27, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Simplize cifar
上级
434ada47
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
53 addition
and
117 deletion
+53
-117
python/paddle/v2/dataset/cifar.py
python/paddle/v2/dataset/cifar.py
+53
-117
未找到文件。
python/paddle/v2/dataset/cifar.py
浏览文件 @
0bcc4d48
...
...
@@ -15,159 +15,95 @@ import cPickle
import
itertools
import
numpy
__all__
=
[
'CIFAR10'
,
'CIFAR100'
,
'train_creator'
,
'test_creator'
]
__all__
=
[
'cifar_100_train_creator'
,
'cifar_100_test_creator'
,
'train_creator'
,
'test_creator'
]
def
__download_file__
(
filename
,
url
,
md5
):
def
__file_ok__
():
if
not
os
.
path
.
exists
(
filename
):
return
False
md5_hash
=
hashlib
.
md5
()
with
open
(
filename
,
'rb'
)
as
f
:
for
chunk
in
iter
(
lambda
:
f
.
read
(
4096
),
b
""
):
md5_hash
.
update
(
chunk
)
return
md5_hash
.
hexdigest
()
==
md5
while
not
__file_ok__
():
response
=
urllib2
.
urlopen
(
url
)
with
open
(
filename
,
mode
=
'wb'
)
as
of
:
shutil
.
copyfileobj
(
fsrc
=
response
,
fdst
=
of
)
CIFAR10_URL
=
'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
CIFAR10_MD5
=
'c58f30108f718f92721af3b95e74349a'
CIFAR100_URL
=
'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
CIFAR100_MD5
=
'eb9058c3a382ffc7106e4002c42a8d85'
def
__read_one_batch__
(
batch
):
def
__read_batch__
(
filename
,
sub_name
):
def
reader
():
def
__read_one_batch_impl__
(
batch
):
data
=
batch
[
'data'
]
labels
=
batch
.
get
(
'labels'
,
batch
.
get
(
'fine_labels'
,
None
))
assert
labels
is
not
None
for
sample
,
label
in
itertools
.
izip
(
data
,
labels
):
yield
(
sample
/
255.0
).
astype
(
numpy
.
float32
),
int
(
label
)
with
tarfile
.
open
(
filename
,
mode
=
'r'
)
as
f
:
names
=
(
each_item
.
name
for
each_item
in
f
if
sub_name
in
each_item
.
name
)
CIFAR10_URL
=
'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
CIFAR10_MD5
=
'c58f30108f718f92721af3b95e74349a'
CIFAR100_URL
=
'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
CIFAR100_MD5
=
'eb9058c3a382ffc7106e4002c42a8d85'
for
name
in
names
:
batch
=
cPickle
.
load
(
f
.
extractfile
(
name
))
for
item
in
__read_one_batch_impl__
(
batch
):
yield
item
class
CIFAR
(
object
):
"""
CIFAR dataset reader. The base class for CIFAR-10 and CIFAR-100
return
reader
:param url: Download url.
:param md5: File md5sum
:param meta_filename: Meta file name in package.
:param train_filename: Train file name in package.
:param test_filename: Test file name in package.
"""
def
__init__
(
self
,
url
,
md5
,
meta_filename
,
train_filename
,
test_filename
):
def
download
(
url
,
md5
):
filename
=
os
.
path
.
split
(
url
)[
-
1
]
assert
DATA_HOME
is
not
None
filepath
=
os
.
path
.
join
(
DATA_HOME
,
md5
)
if
not
os
.
path
.
exists
(
filepath
):
os
.
makedirs
(
filepath
)
__full_file__
=
os
.
path
.
join
(
filepath
,
filename
)
self
.
__full_file__
=
os
.
path
.
join
(
filepath
,
filename
)
self
.
__meta_filename__
=
meta_filename
self
.
__train_filename__
=
train_filename
self
.
__test_filename__
=
test_filename
__download_file__
(
filename
=
self
.
__full_file__
,
url
=
url
,
md5
=
md5
)
def
labels
(
self
):
"""
labels get all dataset label in order.
:return: a list of label.
:rtype: list[string]
"""
with
tarfile
.
open
(
self
.
__full_file__
,
mode
=
'r'
)
as
f
:
name
=
[
each_item
.
name
for
each_item
in
f
if
self
.
__meta_filename__
in
each_item
.
name
][
0
]
meta_f
=
f
.
extractfile
(
name
)
meta
=
cPickle
.
load
(
meta_f
)
for
key
in
meta
:
if
'label'
in
key
:
return
meta
[
key
]
else
:
raise
RuntimeError
(
"Unexpected branch."
)
def
train
(
self
):
"""
Train Reader
"""
return
self
.
__read_batch__
(
self
.
__train_filename__
)
def
test
(
self
):
"""
Test Reader
"""
return
self
.
__read_batch__
(
self
.
__test_filename__
)
def
__read_batch__
(
self
,
sub_name
):
with
tarfile
.
open
(
self
.
__full_file__
,
mode
=
'r'
)
as
f
:
names
=
(
each_item
.
name
for
each_item
in
f
if
sub_name
in
each_item
.
name
)
for
name
in
names
:
batch
=
cPickle
.
load
(
f
.
extractfile
(
name
))
for
item
in
__read_one_batch__
(
batch
):
yield
item
def
__file_ok__
():
if
not
os
.
path
.
exists
(
__full_file__
):
return
False
md5_hash
=
hashlib
.
md5
()
with
open
(
__full_file__
,
'rb'
)
as
f
:
for
chunk
in
iter
(
lambda
:
f
.
read
(
4096
),
b
""
):
md5_hash
.
update
(
chunk
)
return
md5_hash
.
hexdigest
()
==
md5
class
CIFAR10
(
CIFAR
):
"""
CIFAR-10 dataset, images are classified in 10 classes.
"""
while
not
__file_ok__
():
response
=
urllib2
.
urlopen
(
url
)
with
open
(
__full_file__
,
mode
=
'wb'
)
as
of
:
shutil
.
copyfileobj
(
fsrc
=
response
,
fdst
=
of
)
return
__full_file__
def
__init__
(
self
):
super
(
CIFAR10
,
self
).
__init__
(
CIFAR10_URL
,
CIFAR10_MD5
,
meta_filename
=
'batches.meta'
,
train_filename
=
'data_batch'
,
test_filename
=
'test_batch'
)
def
cifar_100_train_creator
():
fn
=
download
(
url
=
CIFAR100_URL
,
md5
=
CIFAR100_MD5
)
return
__read_batch__
(
fn
,
'train'
)
class
CIFAR100
(
CIFAR
):
"""
CIFAR-100 dataset, images are classified in 100 classes.
"""
def
__init__
(
self
):
super
(
CIFAR100
,
self
).
__init__
(
CIFAR100_URL
,
CIFAR100_MD5
,
meta_filename
=
'meta'
,
train_filename
=
'train'
,
test_filename
=
'test'
)
def
cifar_100_test_creator
():
fn
=
download
(
url
=
CIFAR100_URL
,
md5
=
CIFAR100_MD5
)
return
__read_batch__
(
fn
,
'test'
)
def
train_creator
():
"""
Default train reader creator. Use CIFAR-10 dataset.
"""
cifar
=
CIFAR10
(
)
return
cifar
.
train
fn
=
download
(
url
=
CIFAR10_URL
,
md5
=
CIFAR10_MD5
)
return
__read_batch__
(
fn
,
'data_batch'
)
def
test_creator
():
"""
Default test reader creator. Use CIFAR-10 dataset.
"""
cifar
=
CIFAR10
(
)
return
cifar
.
test
fn
=
download
(
url
=
CIFAR10_URL
,
md5
=
CIFAR10_MD5
)
return
__read_batch__
(
fn
,
'test_batch'
)
def
unittest
(
label_count
=
100
):
cifar
=
globals
()[
"CIFAR%d"
%
label_count
]()
assert
len
(
cifar
.
labels
())
==
label_count
for
_
in
cifar
.
test
():
def
unittest
():
for
_
in
train_creator
()():
pass
for
_
in
cifar
.
train
():
for
_
in
test_creator
()
():
pass
if
__name__
==
'__main__'
:
unittest
(
10
)
unittest
(
100
)
unittest
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录