Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
d6c62e85
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d6c62e85
编写于
2月 28, 2017
作者:
Y
Yi Wang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rewrite mnist.py and add mnist_test.py
上级
91115ab6
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
78 addition
and
23 deletion
+78
-23
python/paddle/v2/dataset/mnist.py
python/paddle/v2/dataset/mnist.py
+51
-23
python/paddle/v2/dataset/tests/mnist_test.py
python/paddle/v2/dataset/tests/mnist_test.py
+27
-0
未找到文件。
python/paddle/v2/dataset/mnist.py
浏览文件 @
d6c62e85
import
sklearn.datasets.mldata
import
s
klearn.model_selection
import
paddle.v2.dataset.common
import
s
ubprocess
import
numpy
from
common
import
DATA_HOME
__all__
=
[
'train_creator'
,
'test_creator'
]
URL_PREFIX
=
'http://yann.lecun.com/exdb/mnist/'
TEST_IMAGE_URL
=
URL_PREFIX
+
't10k-images-idx3-ubyte.gz'
TEST_IMAGE_MD5
=
'25e3cc63507ef6e98d5dc541e8672bb6'
def
__mnist_reader_creator__
(
data
,
target
):
def
reader
():
n_samples
=
data
.
shape
[
0
]
for
i
in
xrange
(
n_samples
):
yield
(
data
[
i
]
/
255.0
).
astype
(
numpy
.
float32
),
int
(
target
[
i
])
TEST_LABEL_URL
=
URL_PREFIX
+
't10k-labels-idx1-ubyte.gz'
TEST_LABEL_MD5
=
'4e9511fe019b2189026bd0421ba7b688'
TRAIN_IMAGE_URL
=
URL_PREFIX
+
'train-images-idx3-ubyte.gz'
TRAIN_IMAGE_MD5
=
'f68b3c2dcbeaaa9fbdd348bbdeb94873'
return
reader
TRAIN_LABEL_URL
=
URL_PREFIX
+
'train-labels-idx1-ubyte.gz'
TRAIN_LABEL_MD5
=
'd53e105ee54ea40749a09fcbcd1e9432'
TEST_SIZE
=
10000
def
reader_creator
(
image_filename
,
label_filename
,
buffer_size
):
def
reader
():
# According to http://stackoverflow.com/a/38061619/724872, we
# cannot use standard package gzip here.
m
=
subprocess
.
Popen
([
"zcat"
,
image_filename
],
stdout
=
subprocess
.
PIPE
)
m
.
stdout
.
read
(
16
)
# skip some magic bytes
l
=
subprocess
.
Popen
([
"zcat"
,
label_filename
],
stdout
=
subprocess
.
PIPE
)
l
.
stdout
.
read
(
8
)
# skip some magic bytes
data
=
sklearn
.
datasets
.
mldata
.
fetch_mldata
(
"MNIST original"
,
data_home
=
DATA_HOME
)
X_train
,
X_test
,
y_train
,
y_test
=
sklearn
.
model_selection
.
train_test_split
(
data
.
data
,
data
.
target
,
test_size
=
TEST_SIZE
,
random_state
=
0
)
while
True
:
labels
=
numpy
.
fromfile
(
l
.
stdout
,
'ubyte'
,
count
=
buffer_size
).
astype
(
"int"
)
if
labels
.
size
!=
buffer_size
:
break
# numpy.fromfile returns empty slice after EOF.
def
train_creator
():
return
__mnist_reader_creator__
(
X_train
,
y_train
)
images
=
numpy
.
fromfile
(
m
.
stdout
,
'ubyte'
,
count
=
buffer_size
*
28
*
28
).
reshape
((
buffer_size
,
28
*
28
)
).
astype
(
'float32'
)
images
=
images
/
255.0
*
2.0
-
1.0
def
test_creator
(
):
return
__mnist_reader_creator__
(
X_test
,
y_test
)
for
i
in
xrange
(
buffer_size
):
yield
images
[
i
,
:],
labels
[
i
]
m
.
terminate
()
l
.
terminate
()
def
unittest
():
assert
len
(
list
(
test_creator
()()))
==
TEST_SIZE
return
reader
()
def
train
():
return
reader_creator
(
paddle
.
v2
.
dataset
.
common
.
download
(
TRAIN_IMAGE_URL
,
'mnist'
,
TRAIN_IMAGE_MD5
),
paddle
.
v2
.
dataset
.
common
.
download
(
TRAIN_LABEL_URL
,
'mnist'
,
TRAIN_LABEL_MD5
),
100
)
if
__name__
==
'__main__'
:
unittest
()
def
test
():
return
reader_creator
(
paddle
.
v2
.
dataset
.
common
.
download
(
TEST_IMAGE_URL
,
'mnist'
,
TEST_IMAGE_MD5
),
paddle
.
v2
.
dataset
.
common
.
download
(
TEST_LABEL_URL
,
'mnist'
,
TEST_LABEL_MD5
),
100
)
python/paddle/v2/dataset/tests/mnist_test.py
0 → 100644
浏览文件 @
d6c62e85
import
paddle.v2.dataset.mnist
import
unittest
class
TestMNIST
(
unittest
.
TestCase
):
def
check_reader
(
self
,
reader
):
sum
=
0
for
l
in
reader
:
self
.
assertEqual
(
l
[
0
].
size
,
784
)
self
.
assertEqual
(
l
[
1
].
size
,
1
)
self
.
assertLess
(
l
[
1
],
10
)
self
.
assertGreaterEqual
(
l
[
1
],
0
)
sum
+=
1
return
sum
def
test_train
(
self
):
self
.
assertEqual
(
self
.
check_reader
(
paddle
.
v2
.
dataset
.
mnist
.
train
()),
60000
)
def
test_test
(
self
):
self
.
assertEqual
(
self
.
check_reader
(
paddle
.
v2
.
dataset
.
mnist
.
test
()),
10000
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录