Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
dcbfbb15
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
dcbfbb15
编写于
2月 28, 2017
作者:
Y
Yi Wang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
yapf format
上级
d6c62e85
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
24 addition
and
26 deletion
+24
-26
python/paddle/v2/dataset/mnist.py
python/paddle/v2/dataset/mnist.py
+16
-18
python/paddle/v2/dataset/tests/common_test.py
python/paddle/v2/dataset/tests/common_test.py
+5
-4
python/paddle/v2/dataset/tests/mnist_test.py
python/paddle/v2/dataset/tests/mnist_test.py
+3
-4
未找到文件。
python/paddle/v2/dataset/mnist.py
浏览文件 @
dcbfbb15
...
...
@@ -22,23 +22,21 @@ def reader_creator(image_filename, label_filename, buffer_size):
# According to http://stackoverflow.com/a/38061619/724872, we
# cannot use standard package gzip here.
m
=
subprocess
.
Popen
([
"zcat"
,
image_filename
],
stdout
=
subprocess
.
PIPE
)
m
.
stdout
.
read
(
16
)
# skip some magic bytes
m
.
stdout
.
read
(
16
)
# skip some magic bytes
l
=
subprocess
.
Popen
([
"zcat"
,
label_filename
],
stdout
=
subprocess
.
PIPE
)
l
.
stdout
.
read
(
8
)
# skip some magic bytes
l
.
stdout
.
read
(
8
)
# skip some magic bytes
while
True
:
labels
=
numpy
.
fromfile
(
l
.
stdout
,
'ubyte'
,
count
=
buffer_size
).
astype
(
"int"
)
l
.
stdout
,
'ubyte'
,
count
=
buffer_size
).
astype
(
"int"
)
if
labels
.
size
!=
buffer_size
:
break
# numpy.fromfile returns empty slice after EOF.
break
# numpy.fromfile returns empty slice after EOF.
images
=
numpy
.
fromfile
(
m
.
stdout
,
'ubyte'
,
count
=
buffer_size
*
28
*
28
).
reshape
((
buffer_size
,
28
*
28
)
).
astype
(
'float32'
)
m
.
stdout
,
'ubyte'
,
count
=
buffer_size
*
28
*
28
).
reshape
(
(
buffer_size
,
28
*
28
)).
astype
(
'float32'
)
images
=
images
/
255.0
*
2.0
-
1.0
...
...
@@ -50,18 +48,18 @@ def reader_creator(image_filename, label_filename, buffer_size):
return
reader
()
def
train
():
return
reader_creator
(
paddle
.
v2
.
dataset
.
common
.
download
(
TRAIN_IMAGE_URL
,
'mnist'
,
TRAIN_IMAGE_MD5
),
paddle
.
v2
.
dataset
.
common
.
download
(
TRAIN_LABEL_URL
,
'mnist'
,
TRAIN_LABEL_MD5
),
100
)
paddle
.
v2
.
dataset
.
common
.
download
(
TRAIN_IMAGE_URL
,
'mnist'
,
TRAIN_IMAGE_MD5
),
paddle
.
v2
.
dataset
.
common
.
download
(
TRAIN_LABEL_URL
,
'mnist'
,
TRAIN_LABEL_MD5
),
100
)
def
test
():
return
reader_creator
(
paddle
.
v2
.
dataset
.
common
.
download
(
TEST_IMAGE_URL
,
'mnist'
,
TEST_IMAGE_MD5
),
paddle
.
v2
.
dataset
.
common
.
download
(
TEST_LABEL_URL
,
'mnist'
,
TEST_LABEL_MD5
),
100
)
paddle
.
v2
.
dataset
.
common
.
download
(
TEST_IMAGE_URL
,
'mnist'
,
TEST_IMAGE_MD5
),
paddle
.
v2
.
dataset
.
common
.
download
(
TEST_LABEL_URL
,
'mnist'
,
TEST_LABEL_MD5
),
100
)
python/paddle/v2/dataset/tests/common_test.py
浏览文件 @
dcbfbb15
...
...
@@ -2,14 +2,14 @@ import paddle.v2.dataset.common
import
unittest
import
tempfile
class
TestCommon
(
unittest
.
TestCase
):
def
test_md5file
(
self
):
_
,
temp_path
=
tempfile
.
mkstemp
()
_
,
temp_path
=
tempfile
.
mkstemp
()
with
open
(
temp_path
,
'w'
)
as
f
:
f
.
write
(
"Hello
\n
"
)
self
.
assertEqual
(
'09f7e02f1290be211da707a266f153b3'
,
paddle
.
v2
.
dataset
.
common
.
md5file
(
temp_path
))
self
.
assertEqual
(
'09f7e02f1290be211da707a266f153b3'
,
paddle
.
v2
.
dataset
.
common
.
md5file
(
temp_path
))
def
test_download
(
self
):
yi_avatar
=
'https://avatars0.githubusercontent.com/u/1548775?v=3&s=460'
...
...
@@ -18,5 +18,6 @@ class TestCommon(unittest.TestCase):
paddle
.
v2
.
dataset
.
common
.
download
(
yi_avatar
,
'test'
,
'f75287202d6622414c706c36c16f8e0d'
))
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/v2/dataset/tests/mnist_test.py
浏览文件 @
dcbfbb15
import
paddle.v2.dataset.mnist
import
unittest
class
TestMNIST
(
unittest
.
TestCase
):
def
check_reader
(
self
,
reader
):
sum
=
0
...
...
@@ -14,13 +15,11 @@ class TestMNIST(unittest.TestCase):
def
test_train
(
self
):
self
.
assertEqual
(
self
.
check_reader
(
paddle
.
v2
.
dataset
.
mnist
.
train
()),
60000
)
self
.
check_reader
(
paddle
.
v2
.
dataset
.
mnist
.
train
()),
60000
)
def
test_test
(
self
):
self
.
assertEqual
(
self
.
check_reader
(
paddle
.
v2
.
dataset
.
mnist
.
test
()),
10000
)
self
.
check_reader
(
paddle
.
v2
.
dataset
.
mnist
.
test
()),
10000
)
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录