未验证 提交 869f3a9d 编写于 作者: W wopeizl 提交者: GitHub

Merge pull request #15273 from junjun315/fix-mnist-dataset-error

Fix mnist dataset error
...@@ -21,10 +21,9 @@ parse training set and test set into paddle reader creators. ...@@ -21,10 +21,9 @@ parse training set and test set into paddle reader creators.
from __future__ import print_function from __future__ import print_function
import paddle.dataset.common import paddle.dataset.common
import subprocess import gzip
import numpy import numpy
import platform import struct
import tempfile
from six.moves import range from six.moves import range
__all__ = ['train', 'test', 'convert'] __all__ = ['train', 'test', 'convert']
...@@ -41,51 +40,47 @@ TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432' ...@@ -41,51 +40,47 @@ TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432'
def reader_creator(image_filename, label_filename, buffer_size): def reader_creator(image_filename, label_filename, buffer_size):
def reader(): def reader():
if platform.system() == 'Darwin': with gzip.GzipFile(image_filename, 'rb') as image_file:
zcat_cmd = 'gzcat' img_buf = image_file.read()
elif platform.system() == 'Linux': with gzip.GzipFile(label_filename, 'rb') as label_file:
zcat_cmd = 'zcat' lab_buf = label_file.read()
else:
raise NotImplementedError() step_label = 0
# According to http://stackoverflow.com/a/38061619/724872, we offset_img = 0
# cannot use standard package gzip here. # read from Big-endian
tmp_image_file = tempfile.TemporaryFile(prefix='paddle_dataset') # get file info from magic byte
m = subprocess.Popen( # image file : 16B
[zcat_cmd, image_filename], stdout=tmp_image_file).communicate() magic_byte_img = '>IIII'
tmp_image_file.seek(16) # skip some magic bytes magic_img, image_num, rows, cols = struct.unpack_from(
magic_byte_img, img_buf, offset_img)
# Python3 will not take stdout as file offset_img += struct.calcsize(magic_byte_img)
tmp_label_file = tempfile.TemporaryFile(prefix='paddle_dataset')
l = subprocess.Popen( offset_lab = 0
[zcat_cmd, label_filename], stdout=tmp_label_file).communicate() # label file : 8B
tmp_label_file.seek(8) # skip some magic bytes magic_byte_lab = '>II'
magic_lab, label_num = struct.unpack_from(magic_byte_lab,
try: # reader could be break. lab_buf, offset_lab)
while True: offset_lab += struct.calcsize(magic_byte_lab)
labels = numpy.fromfile(
tmp_label_file, 'ubyte', count=buffer_size).astype("int") while True:
if step_label >= label_num:
if labels.size != buffer_size: break
break # numpy.fromfile returns empty slice after EOF. fmt_label = '>' + str(buffer_size) + 'B'
labels = struct.unpack_from(fmt_label, lab_buf, offset_lab)
images = numpy.fromfile( offset_lab += struct.calcsize(fmt_label)
tmp_image_file, 'ubyte', count=buffer_size * 28 * step_label += buffer_size
28).reshape((buffer_size, 28 * 28)).astype('float32')
fmt_images = '>' + str(buffer_size * rows * cols) + 'B'
images = images / 255.0 * 2.0 - 1.0 images_temp = struct.unpack_from(fmt_images, img_buf,
offset_img)
for i in range(buffer_size): images = numpy.reshape(images_temp, (
yield images[i, :], int(labels[i]) buffer_size, rows * cols)).astype('float32')
finally: offset_img += struct.calcsize(fmt_images)
try:
m.terminate() images = images / 255.0 * 2.0 - 1.0
except: for i in range(buffer_size):
pass yield images[i, :], int(labels[i])
try:
l.terminate()
except:
pass
return reader return reader
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册