From c852003d6b2e11c8d0f75d5c0adf9d332dc429f1 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 28 Feb 2017 14:00:22 +0800 Subject: [PATCH] Fix errors in mnist dataset --- demo/mnist/api_train_v2.py | 1 - python/paddle/v2/dataset/mnist.py | 12 ++++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index 1c2c831be..612b0d218 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -38,7 +38,6 @@ def main(): cost=cost, parameters=parameters, event_handler=event_handler, - batch_size=32, # batch size should be refactor in Data reader reader_dict={images.name: 0, label.name: 1}) diff --git a/python/paddle/v2/dataset/mnist.py b/python/paddle/v2/dataset/mnist.py index 653c91aac..045bcfcc8 100644 --- a/python/paddle/v2/dataset/mnist.py +++ b/python/paddle/v2/dataset/mnist.py @@ -1,6 +1,7 @@ import paddle.v2.dataset.common import subprocess import numpy +import platform __all__ = ['train', 'test'] @@ -18,12 +19,19 @@ TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432' def reader_creator(image_filename, label_filename, buffer_size): def reader(): + if platform.system() == 'Darwin': + zcat_cmd = 'gzcat' + elif platform.system() == 'Linux': + zcat_cmd = 'zcat' + else: + raise NotImplementedError() + # According to http://stackoverflow.com/a/38061619/724872, we # cannot use standard package gzip here. - m = subprocess.Popen(["zcat", image_filename], stdout=subprocess.PIPE) + m = subprocess.Popen([zcat_cmd, image_filename], stdout=subprocess.PIPE) m.stdout.read(16) # skip some magic bytes - l = subprocess.Popen(["zcat", label_filename], stdout=subprocess.PIPE) + l = subprocess.Popen([zcat_cmd, label_filename], stdout=subprocess.PIPE) l.stdout.read(8) # skip some magic bytes while True: -- GitLab