未验证 提交 869f3a9d 编写于 作者: W wopeizl 提交者: GitHub

Merge pull request #15273 from junjun315/fix-mnist-dataset-error

Fix mnist dataset error
......@@ -21,10 +21,9 @@ parse training set and test set into paddle reader creators.
from __future__ import print_function
import paddle.dataset.common
import subprocess
import gzip
import numpy
import platform
import tempfile
import struct
from six.moves import range
__all__ = ['train', 'test', 'convert']
......@@ -41,51 +40,47 @@ TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432'
def reader_creator(image_filename, label_filename, buffer_size):
def reader():
if platform.system() == 'Darwin':
zcat_cmd = 'gzcat'
elif platform.system() == 'Linux':
zcat_cmd = 'zcat'
else:
raise NotImplementedError()
# According to http://stackoverflow.com/a/38061619/724872, we
# cannot use standard package gzip here.
tmp_image_file = tempfile.TemporaryFile(prefix='paddle_dataset')
m = subprocess.Popen(
[zcat_cmd, image_filename], stdout=tmp_image_file).communicate()
tmp_image_file.seek(16) # skip some magic bytes
# Python3 will not take stdout as file
tmp_label_file = tempfile.TemporaryFile(prefix='paddle_dataset')
l = subprocess.Popen(
[zcat_cmd, label_filename], stdout=tmp_label_file).communicate()
tmp_label_file.seek(8) # skip some magic bytes
try: # reader could be break.
while True:
labels = numpy.fromfile(
tmp_label_file, 'ubyte', count=buffer_size).astype("int")
if labels.size != buffer_size:
break # numpy.fromfile returns empty slice after EOF.
images = numpy.fromfile(
tmp_image_file, 'ubyte', count=buffer_size * 28 *
28).reshape((buffer_size, 28 * 28)).astype('float32')
images = images / 255.0 * 2.0 - 1.0
for i in range(buffer_size):
yield images[i, :], int(labels[i])
finally:
try:
m.terminate()
except:
pass
try:
l.terminate()
except:
pass
with gzip.GzipFile(image_filename, 'rb') as image_file:
img_buf = image_file.read()
with gzip.GzipFile(label_filename, 'rb') as label_file:
lab_buf = label_file.read()
step_label = 0
offset_img = 0
# read from Big-endian
# get file info from magic byte
# image file : 16B
magic_byte_img = '>IIII'
magic_img, image_num, rows, cols = struct.unpack_from(
magic_byte_img, img_buf, offset_img)
offset_img += struct.calcsize(magic_byte_img)
offset_lab = 0
# label file : 8B
magic_byte_lab = '>II'
magic_lab, label_num = struct.unpack_from(magic_byte_lab,
lab_buf, offset_lab)
offset_lab += struct.calcsize(magic_byte_lab)
while True:
if step_label >= label_num:
break
fmt_label = '>' + str(buffer_size) + 'B'
labels = struct.unpack_from(fmt_label, lab_buf, offset_lab)
offset_lab += struct.calcsize(fmt_label)
step_label += buffer_size
fmt_images = '>' + str(buffer_size * rows * cols) + 'B'
images_temp = struct.unpack_from(fmt_images, img_buf,
offset_img)
images = numpy.reshape(images_temp, (
buffer_size, rows * cols)).astype('float32')
offset_img += struct.calcsize(fmt_images)
images = images / 255.0 * 2.0 - 1.0
for i in range(buffer_size):
yield images[i, :], int(labels[i])
return reader
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册