Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
4eb54c24
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4eb54c24
编写于
2月 28, 2017
作者:
Y
Yi Wang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Debug unit tests
上级
6bc82c8e
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
93 addition
and
67 deletion
+93
-67
python/paddle/v2/dataset/cifar.py
python/paddle/v2/dataset/cifar.py
+32
-53
python/paddle/v2/dataset/common.py
python/paddle/v2/dataset/common.py
+0
-1
python/paddle/v2/dataset/mnist.py
python/paddle/v2/dataset/mnist.py
+6
-4
python/paddle/v2/dataset/tests/cifar_test.py
python/paddle/v2/dataset/tests/cifar_test.py
+42
-0
python/paddle/v2/dataset/tests/mnist_test.py
python/paddle/v2/dataset/tests/mnist_test.py
+13
-9
未找到文件。
python/paddle/v2/dataset/cifar.py
浏览文件 @
4eb54c24
"""
CIFAR Dataset.
URL: https://www.cs.toronto.edu/~kriz/cifar.html
the default train_creator, test_creator used for CIFAR-10 dataset.
CIFAR dataset: https://www.cs.toronto.edu/~kriz/cifar.html
"""
import
cPickle
import
itertools
import
tarfile
import
numpy
import
paddle.v2.dataset.common
import
tarfile
from
common
import
download
__all__
=
[
'cifar_100_train_creator'
,
'cifar_100_test_creator'
,
'train_creator'
,
'test_creator'
]
__all__
=
[
'train100'
,
'test100'
,
'train10'
,
'test10'
]
CIFAR10_URL
=
'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
URL_PREFIX
=
'https://www.cs.toronto.edu/~kriz/'
CIFAR10_URL
=
URL_PREFIX
+
'cifar-10-python.tar.gz'
CIFAR10_MD5
=
'c58f30108f718f92721af3b95e74349a'
CIFAR100_URL
=
'https://www.cs.toronto.edu/~kriz/
cifar-100-python.tar.gz'
CIFAR100_URL
=
URL_PREFIX
+
'
cifar-100-python.tar.gz'
CIFAR100_MD5
=
'eb9058c3a382ffc7106e4002c42a8d85'
def
__read_batch__
(
filename
,
sub_name
):
def
reader
():
def
__read_one_batch_impl__
(
batch
):
data
=
batch
[
'data'
]
labels
=
batch
.
get
(
'labels'
,
batch
.
get
(
'fine_labels'
,
None
))
assert
labels
is
not
None
for
sample
,
label
in
itertools
.
izip
(
data
,
labels
):
yield
(
sample
/
255.0
).
astype
(
numpy
.
float32
),
int
(
label
)
def
reader_creator
(
filename
,
sub_name
):
def
read_batch
(
batch
):
data
=
batch
[
'data'
]
labels
=
batch
.
get
(
'labels'
,
batch
.
get
(
'fine_labels'
,
None
))
assert
labels
is
not
None
for
sample
,
label
in
itertools
.
izip
(
data
,
labels
):
yield
(
sample
/
255.0
).
astype
(
numpy
.
float32
),
int
(
label
)
def
reader
():
with
tarfile
.
open
(
filename
,
mode
=
'r'
)
as
f
:
names
=
(
each_item
.
name
for
each_item
in
f
if
sub_name
in
each_item
.
name
)
for
name
in
names
:
batch
=
cPickle
.
load
(
f
.
extractfile
(
name
))
for
item
in
__read_one_batch_impl__
(
batch
):
for
item
in
read_batch
(
batch
):
yield
item
return
reader
def
cifar_100_train_creator
():
fn
=
download
(
url
=
CIFAR100_URL
,
md5
=
CIFAR100_MD5
)
return
__read_batch__
(
fn
,
'train'
)
def
cifar_100_test_creator
():
fn
=
download
(
url
=
CIFAR100_URL
,
md5
=
CIFAR100_MD5
)
return
__read_batch__
(
fn
,
'test'
)
def
train_creator
():
"""
Default train reader creator. Use CIFAR-10 dataset.
"""
fn
=
download
(
url
=
CIFAR10_URL
,
md5
=
CIFAR10_MD5
)
return
__read_batch__
(
fn
,
'data_batch'
)
def
train100
():
return
reader_creator
(
paddle
.
v2
.
dataset
.
common
.
download
(
CIFAR100_URL
,
'cifar'
,
CIFAR100_MD5
),
'train'
)
def
test_creator
():
"""
Default test reader creator. Use CIFAR-10 dataset.
"""
fn
=
download
(
url
=
CIFAR10_URL
,
md5
=
CIFAR10_MD5
)
return
__read_batch__
(
fn
,
'test_batch'
)
def
test100
():
return
reader_creator
(
paddle
.
v2
.
dataset
.
common
.
download
(
CIFAR100_URL
,
'cifar'
,
CIFAR100_MD5
),
'test'
)
def
unittest
():
for
_
in
train_creator
()():
pass
for
_
in
test_creator
()():
pass
def
train10
():
return
reader_creator
(
paddle
.
v2
.
dataset
.
common
.
download
(
CIFAR10_URL
,
'cifar'
,
CIFAR10_MD5
),
'data_batch'
)
if
__name__
==
'__main__'
:
unittest
()
def
test10
():
return
reader_creator
(
paddle
.
v2
.
dataset
.
common
.
download
(
CIFAR10_URL
,
'cifar'
,
CIFAR10_MD5
),
'test_batch'
)
python/paddle/v2/dataset/common.py
浏览文件 @
4eb54c24
...
...
@@ -27,7 +27,6 @@ def download(url, module_name, md5sum):
filename
=
os
.
path
.
join
(
dirname
,
url
.
split
(
'/'
)[
-
1
])
if
not
(
os
.
path
.
exists
(
filename
)
and
md5file
(
filename
)
==
md5sum
):
# If file doesn't exist or MD5 doesn't match, then download.
r
=
requests
.
get
(
url
,
stream
=
True
)
with
open
(
filename
,
'w'
)
as
f
:
shutil
.
copyfileobj
(
r
.
raw
,
f
)
...
...
python/paddle/v2/dataset/mnist.py
浏览文件 @
4eb54c24
"""
MNIST dataset.
"""
import
numpy
import
paddle.v2.dataset.common
import
subprocess
import
numpy
__all__
=
[
'train'
,
'test'
]
URL_PREFIX
=
'http://yann.lecun.com/exdb/mnist/'
TEST_IMAGE_URL
=
URL_PREFIX
+
't10k-images-idx3-ubyte.gz'
TEST_IMAGE_MD5
=
'25e3cc63507ef6e98d5dc541e8672bb6'
TEST_LABEL_URL
=
URL_PREFIX
+
't10k-labels-idx1-ubyte.gz'
...
...
@@ -40,12 +42,12 @@ def reader_creator(image_filename, label_filename, buffer_size):
images
=
images
/
255.0
*
2.0
-
1.0
for
i
in
xrange
(
buffer_size
):
yield
images
[
i
,
:],
labels
[
i
]
yield
images
[
i
,
:],
int
(
labels
[
i
])
m
.
terminate
()
l
.
terminate
()
return
reader
()
return
reader
def
train
():
...
...
python/paddle/v2/dataset/tests/cifar_test.py
0 → 100644
浏览文件 @
4eb54c24
import
paddle.v2.dataset.cifar
import
unittest
class
TestCIFAR
(
unittest
.
TestCase
):
def
check_reader
(
self
,
reader
):
sum
=
0
label
=
0
for
l
in
reader
():
self
.
assertEqual
(
l
[
0
].
size
,
3072
)
if
l
[
1
]
>
label
:
label
=
l
[
1
]
sum
+=
1
return
sum
,
label
def
test_test10
(
self
):
instances
,
max_label_value
=
self
.
check_reader
(
paddle
.
v2
.
dataset
.
cifar
.
test10
())
self
.
assertEqual
(
instances
,
10000
)
self
.
assertEqual
(
max_label_value
,
9
)
def
test_train10
(
self
):
instances
,
max_label_value
=
self
.
check_reader
(
paddle
.
v2
.
dataset
.
cifar
.
train10
())
self
.
assertEqual
(
instances
,
50000
)
self
.
assertEqual
(
max_label_value
,
9
)
def
test_test100
(
self
):
instances
,
max_label_value
=
self
.
check_reader
(
paddle
.
v2
.
dataset
.
cifar
.
test100
())
self
.
assertEqual
(
instances
,
10000
)
self
.
assertEqual
(
max_label_value
,
99
)
def
test_train100
(
self
):
instances
,
max_label_value
=
self
.
check_reader
(
paddle
.
v2
.
dataset
.
cifar
.
train100
())
self
.
assertEqual
(
instances
,
50000
)
self
.
assertEqual
(
max_label_value
,
99
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/v2/dataset/tests/mnist_test.py
浏览文件 @
4eb54c24
...
...
@@ -5,21 +5,25 @@ import unittest
class
TestMNIST
(
unittest
.
TestCase
):
def
check_reader
(
self
,
reader
):
sum
=
0
for
l
in
reader
:
label
=
0
for
l
in
reader
():
self
.
assertEqual
(
l
[
0
].
size
,
784
)
self
.
assertEqual
(
l
[
1
].
size
,
1
)
self
.
assertLess
(
l
[
1
],
10
)
self
.
assertGreaterEqual
(
l
[
1
],
0
)
if
l
[
1
]
>
label
:
label
=
l
[
1
]
sum
+=
1
return
sum
return
sum
,
label
def
test_train
(
self
):
self
.
assertEqual
(
self
.
check_reader
(
paddle
.
v2
.
dataset
.
mnist
.
train
()),
60000
)
instances
,
max_label_value
=
self
.
check_reader
(
paddle
.
v2
.
dataset
.
mnist
.
train
())
self
.
assertEqual
(
instances
,
60000
)
self
.
assertEqual
(
max_label_value
,
9
)
def
test_test
(
self
):
self
.
assertEqual
(
self
.
check_reader
(
paddle
.
v2
.
dataset
.
mnist
.
test
()),
10000
)
instances
,
max_label_value
=
self
.
check_reader
(
paddle
.
v2
.
dataset
.
mnist
.
test
())
self
.
assertEqual
(
instances
,
10000
)
self
.
assertEqual
(
max_label_value
,
9
)
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录