diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index 1c2c831bec18ef4215338949f96c3357dd8335c9..612b0d218fc0be68b668ffb8f1bfb0ba92c4d741 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -38,7 +38,6 @@ def main(): cost=cost, parameters=parameters, event_handler=event_handler, - batch_size=32, # batch size should be refactor in Data reader reader_dict={images.name: 0, label.name: 1}) diff --git a/python/paddle/v2/dataset/mnist.py b/python/paddle/v2/dataset/mnist.py index 653c91aacde6618389069d571c3d8a23b88c5f50..045bcfcc805ec01c4fb47901285e9a4e62aa9b2b 100644 --- a/python/paddle/v2/dataset/mnist.py +++ b/python/paddle/v2/dataset/mnist.py @@ -1,6 +1,7 @@ import paddle.v2.dataset.common import subprocess import numpy +import platform __all__ = ['train', 'test'] @@ -18,12 +19,19 @@ TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432' def reader_creator(image_filename, label_filename, buffer_size): def reader(): + if platform.system() == 'Darwin': + zcat_cmd = 'gzcat' + elif platform.system() == 'Linux': + zcat_cmd = 'zcat' + else: + raise NotImplementedError() + # According to http://stackoverflow.com/a/38061619/724872, we # cannot use standard package gzip here. - m = subprocess.Popen(["zcat", image_filename], stdout=subprocess.PIPE) + m = subprocess.Popen([zcat_cmd, image_filename], stdout=subprocess.PIPE) m.stdout.read(16) # skip some magic bytes - l = subprocess.Popen(["zcat", label_filename], stdout=subprocess.PIPE) + l = subprocess.Popen([zcat_cmd, label_filename], stdout=subprocess.PIPE) l.stdout.read(8) # skip some magic bytes while True: