Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
3472f673
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3472f673
编写于
1月 11, 2019
作者:
L
lujun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix mnist-dataset bug at windows,test=develop
上级
13b1141b
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
43 addition
and
48 deletion
+43
-48
python/paddle/dataset/mnist.py
python/paddle/dataset/mnist.py
+43
-48
未找到文件。
python/paddle/dataset/mnist.py
浏览文件 @
3472f673
...
...
@@ -21,10 +21,9 @@ parse training set and test set into paddle reader creators.
from
__future__
import
print_function
import
paddle.dataset.common
import
subprocess
import
gzip
import
numpy
import
platform
import
tempfile
import
struct
from
six.moves
import
range
__all__
=
[
'train'
,
'test'
,
'convert'
]
...
...
@@ -41,51 +40,47 @@ TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432'
def
reader_creator
(
image_filename
,
label_filename
,
buffer_size
):
def
reader
():
if
platform
.
system
()
==
'Darwin'
:
zcat_cmd
=
'gzcat'
elif
platform
.
system
()
==
'Linux'
:
zcat_cmd
=
'zcat'
else
:
raise
NotImplementedError
()
# According to http://stackoverflow.com/a/38061619/724872, we
# cannot use standard package gzip here.
tmp_image_file
=
tempfile
.
TemporaryFile
(
prefix
=
'paddle_dataset'
)
m
=
subprocess
.
Popen
(
[
zcat_cmd
,
image_filename
],
stdout
=
tmp_image_file
).
communicate
()
tmp_image_file
.
seek
(
16
)
# skip some magic bytes
# Python3 will not take stdout as file
tmp_label_file
=
tempfile
.
TemporaryFile
(
prefix
=
'paddle_dataset'
)
l
=
subprocess
.
Popen
(
[
zcat_cmd
,
label_filename
],
stdout
=
tmp_label_file
).
communicate
()
tmp_label_file
.
seek
(
8
)
# skip some magic bytes
try
:
# reader could be break.
while
True
:
labels
=
numpy
.
fromfile
(
tmp_label_file
,
'ubyte'
,
count
=
buffer_size
).
astype
(
"int"
)
if
labels
.
size
!=
buffer_size
:
break
# numpy.fromfile returns empty slice after EOF.
images
=
numpy
.
fromfile
(
tmp_image_file
,
'ubyte'
,
count
=
buffer_size
*
28
*
28
).
reshape
((
buffer_size
,
28
*
28
)).
astype
(
'float32'
)
images
=
images
/
255.0
*
2.0
-
1.0
for
i
in
range
(
buffer_size
):
yield
images
[
i
,
:],
int
(
labels
[
i
])
finally
:
try
:
m
.
terminate
()
except
:
pass
try
:
l
.
terminate
()
except
:
pass
with
gzip
.
GzipFile
(
image_filename
,
'rb'
)
as
image_file
:
img_buf
=
image_file
.
read
()
with
gzip
.
GzipFile
(
label_filename
,
'rb'
)
as
label_file
:
lab_buf
=
label_file
.
read
()
step_label
=
0
offset_img
=
0
# read from Big-endian
# get file info from magic byte
# image file : 16B
magic_byte_img
=
'>IIII'
magic_img
,
image_num
,
rows
,
cols
=
struct
.
unpack_from
(
magic_byte_img
,
img_buf
,
offset_img
)
offset_img
+=
struct
.
calcsize
(
magic_byte_img
)
offset_lab
=
0
# label file : 8B
magic_byte_lab
=
'>II'
magic_lab
,
label_num
=
struct
.
unpack_from
(
magic_byte_lab
,
lab_buf
,
offset_lab
)
offset_lab
+=
struct
.
calcsize
(
magic_byte_lab
)
while
True
:
if
step_label
>=
label_num
:
break
fmt_label
=
'>'
+
str
(
buffer_size
)
+
'B'
labels
=
struct
.
unpack_from
(
fmt_label
,
lab_buf
,
offset_lab
)
offset_lab
+=
struct
.
calcsize
(
fmt_label
)
step_label
+=
buffer_size
fmt_images
=
'>'
+
str
(
buffer_size
*
rows
*
cols
)
+
'B'
images_temp
=
struct
.
unpack_from
(
fmt_images
,
img_buf
,
offset_img
)
images
=
numpy
.
reshape
(
images_temp
,
(
buffer_size
,
rows
*
cols
)).
astype
(
'float32'
)
offset_img
+=
struct
.
calcsize
(
fmt_images
)
images
=
images
/
255.0
*
2.0
-
1.0
for
i
in
range
(
buffer_size
):
yield
images
[
i
,
:],
int
(
labels
[
i
])
return
reader
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录