diff --git a/demo/mnist/api_train.py b/demo/mnist/api_train.py
index 59043ce6c42085fe1988ceba8358b97cb01f74c7..e508af7a0c571c7ccb1f07bcd5267d45481d0e6b 100644
--- a/demo/mnist/api_train.py
+++ b/demo/mnist/api_train.py
@@ -1,6 +1,9 @@
 import py_paddle.swig_paddle as api
+from py_paddle import DataProviderConverter
+import paddle.trainer.PyDataProvider2 as dp
 import paddle.trainer.config_parser
 import numpy as np
+from mnist_util import read_from_mnist
 
 
 def init_parameter(network):
@@ -13,6 +16,22 @@ def init_parameter(network):
             array[i] = np.random.uniform(-1.0, 1.0)
 
 
+def generator_to_batch(generator, batch_size):
+    ret_val = list()
+    for each_item in generator:
+        ret_val.append(each_item)
+        if len(ret_val) == batch_size:
+            yield ret_val
+            ret_val = list()
+    if len(ret_val) != 0:
+        yield ret_val
+
+
+def input_order_converter(generator):
+    for each_item in generator:
+        yield each_item['pixel'], each_item['label']
+
+
 def main():
     api.initPaddle("-use_gpu=false", "-trainer_count=4")  # use 4 cpu cores
     config = paddle.trainer.config_parser.parse_config(
@@ -30,10 +49,20 @@ def main():
     updater = api.ParameterUpdater.createLocalUpdater(opt_config)
     assert isinstance(updater, api.ParameterUpdater)
     updater.init(m)
+
+    converter = DataProviderConverter(
+        input_types=[dp.dense_vector(784), dp.integer_value(10)])
+
+    train_file = './data/raw_data/train'
+
     m.start()
 
     for _ in xrange(100):
         updater.startPass()
+        train_data_generator = input_order_converter(
+            read_from_mnist(train_file))
+        for data_batch in generator_to_batch(train_data_generator, 128):
+            inArgs = converter(data_batch)
         updater.finishPass()
 
 
diff --git a/demo/mnist/mnist_provider.py b/demo/mnist/mnist_provider.py
index 4635833d36b9f21c992d96910f3ac9094ccefd2c..888cfef1e7e3e1b4f556756c003eeb23e741cabe 100644
--- a/demo/mnist/mnist_provider.py
+++ b/demo/mnist/mnist_provider.py
@@ -1,5 +1,5 @@
 from paddle.trainer.PyDataProvider2 import *
-import numpy
+from mnist_util import read_from_mnist
 
 
 # Define a py data provider
@@ -8,27 +8,5 @@ import numpy
                  'label': integer_value(10)},
     cache=CacheType.CACHE_PASS_IN_MEM)
 def process(settings, filename):  # settings is not used currently.
-    imgf = filename + "-images-idx3-ubyte"
-    labelf = filename + "-labels-idx1-ubyte"
-    f = open(imgf, "rb")
-    l = open(labelf, "rb")
-
-    f.read(16)
-    l.read(8)
-
-    # Define number of samples for train/test
-    if "train" in filename:
-        n = 60000
-    else:
-        n = 10000
-
-    images = numpy.fromfile(
-        f, 'ubyte', count=n * 28 * 28).reshape((n, 28 * 28)).astype('float32')
-    images = images / 255.0 * 2.0 - 1.0
-    labels = numpy.fromfile(l, 'ubyte', count=n).astype("int")
-
-    for i in xrange(n):
-        yield {"pixel": images[i, :], 'label': labels[i]}
-
-    f.close()
-    l.close()
+    for each in read_from_mnist(filename):
+        yield each
diff --git a/demo/mnist/mnist_util.py b/demo/mnist/mnist_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fd88ae7edc821296ca0accbf6dedc083e411744
--- /dev/null
+++ b/demo/mnist/mnist_util.py
@@ -0,0 +1,30 @@
+import numpy
+
+__all__ = ['read_from_mnist']
+
+
+def read_from_mnist(filename):
+    imgf = filename + "-images-idx3-ubyte"
+    labelf = filename + "-labels-idx1-ubyte"
+    f = open(imgf, "rb")
+    l = open(labelf, "rb")
+
+    f.read(16)
+    l.read(8)
+
+    # Define number of samples for train/test
+    if "train" in filename:
+        n = 60000
+    else:
+        n = 10000
+
+    images = numpy.fromfile(
+        f, 'ubyte', count=n * 28 * 28).reshape((n, 28 * 28)).astype('float32')
+    images = images / 255.0 * 2.0 - 1.0
+    labels = numpy.fromfile(l, 'ubyte', count=n).astype("int")
+
+    for i in xrange(n):
+        yield {"pixel": images[i, :], 'label': labels[i]}
+
+    f.close()
+    l.close()