From 409a5774c475b67160ea5cdf22b489652da6bff3 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 21 Dec 2016 22:45:42 +0800 Subject: [PATCH] Complete a very simple mnist demo. --- demo/mnist/api_train.py | 108 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 99 insertions(+), 9 deletions(-) diff --git a/demo/mnist/api_train.py b/demo/mnist/api_train.py index 52cc13c5a3e..c1439bd526d 100644 --- a/demo/mnist/api_train.py +++ b/demo/mnist/api_train.py @@ -1,8 +1,17 @@ +""" +A very basic example for how to use current Raw SWIG API to train mnist network. + +Current implementation uses Raw SWIG, which means the API call is directly \ +passed to C++ side of Paddle. + +The user api could be simpler and carefully designed. +""" import py_paddle.swig_paddle as api from py_paddle import DataProviderConverter import paddle.trainer.PyDataProvider2 as dp import paddle.trainer.config_parser import numpy as np +import random from mnist_util import read_from_mnist @@ -27,6 +36,18 @@ def generator_to_batch(generator, batch_size): yield ret_val +class BatchPool(object): + def __init__(self, generator, batch_size): + self.data = list(generator) + self.batch_size = batch_size + + def __call__(self): + random.shuffle(self.data) + for offset in xrange(0, len(self.data), self.batch_size): + limit = min(offset + self.batch_size, len(self.data)) + yield self.data[offset:limit] + + def input_order_converter(generator): for each_item in generator: yield each_item['pixel'], each_item['label'] @@ -37,46 +58,115 @@ def main(): config = paddle.trainer.config_parser.parse_config( 'simple_mnist_network.py', '') + # get enable_types for each optimizer. + # enable_types = [value, gradient, momentum, etc] + # For each optimizer(SGD, Adam), GradientMachine should enable different + # buffers. opt_config = api.OptimizationConfig.createFromProto(config.opt_config) _temp_optimizer_ = api.ParameterOptimizer.create(opt_config) enable_types = _temp_optimizer_.getParameterTypes() + # Create Simple Gradient Machine. m = api.GradientMachine.createFromConfigProto( config.model_config, api.CREATE_MODE_NORMAL, enable_types) + + # This type check is not useful. Only enable type hint in IDE. + # Such as PyCharm assert isinstance(m, api.GradientMachine) + + # Initialize Parameter by numpy. init_parameter(network=m) + + # Create Local Updater. Local means not run in cluster. + # For a cluster training, here we can change to createRemoteUpdater + # in future. updater = api.ParameterUpdater.createLocalUpdater(opt_config) assert isinstance(updater, api.ParameterUpdater) + + # Initialize ParameterUpdater. updater.init(m) + # DataProvider Converter is a utility convert Python Object to Paddle C++ + # Input. The input format is as same as Paddle's DataProvider. converter = DataProviderConverter( input_types=[dp.dense_vector(784), dp.integer_value(10)]) train_file = './data/raw_data/train' + test_file = './data/raw_data/t10k' + # start gradient machine. + # the gradient machine must be started before invoke forward/backward. + # not just for training, but also for inference. m.start() - for _ in xrange(100): + # evaluator can print error rate, etc. It is a C++ class. + batch_evaluator = m.makeEvaluator() + test_evaluator = m.makeEvaluator() + + # Get Train Data. + # TrainData will stored in a data pool. Currently implementation is not care + # about memory, speed. Just a very naive implementation. + train_data_generator = input_order_converter(read_from_mnist(train_file)) + train_data = BatchPool(train_data_generator, 128) + + # outArgs is Neural Network forward result. Here is not useful, just passed + # to gradient_machine.forward + outArgs = api.Arguments.createArguments(0) + + for pass_id in xrange(2): # we train 2 passes. updater.startPass() - outArgs = api.Arguments.createArguments(0) - train_data_generator = input_order_converter( - read_from_mnist(train_file)) - for batch_id, data_batch in enumerate( - generator_to_batch(train_data_generator, 2048)): - trainRole = updater.startBatch(len(data_batch)) + for batch_id, data_batch in enumerate(train_data()): + # data_batch is input images. + # here, for online learning, we could get data_batch from network. + + # Start update one batch. + pass_type = updater.startBatch(len(data_batch)) + + # Start BatchEvaluator. + # batch_evaluator can be used between start/finish. + batch_evaluator.start() + + # A callback when backward. + # It is used for updating weight values vy calculated Gradient. def updater_callback(param): updater.update(param) + # forwardBackward is a shortcut for forward and backward. + # It is sometimes faster than invoke forward/backward separately, + # because in GradientMachine, it may be async. m.forwardBackward( - converter(data_batch), outArgs, trainRole, updater_callback) + converter(data_batch), outArgs, pass_type, updater_callback) + # Get cost. We use numpy to calculate total cost for this batch. cost_vec = outArgs.getSlotValue(0) cost_vec = cost_vec.copyToNumpyMat() cost = cost_vec.sum() / len(data_batch) - print 'Batch id', batch_id, 'with cost=', cost + + # Make evaluator works. + m.eval(batch_evaluator) + + # Print logs. + print 'Pass id', pass_id, 'Batch id', batch_id, 'with cost=', \ + cost, batch_evaluator + + batch_evaluator.finish() + # Finish batch. + # * will clear gradient. + # * ensure all values should be updated. updater.finishBatch(cost) + # testing stage. use test data set to test current network. + test_evaluator.start() + test_data_generator = input_order_converter(read_from_mnist(test_file)) + for data_batch in generator_to_batch(test_data_generator, 128): + # in testing stage, only forward is needed. + m.forward(converter(data_batch), outArgs, api.PASS_TEST) + m.eval(test_evaluator) + + # print error rate for test data set + print 'Pass', pass_id, ' test evaluator: ', test_evaluator + test_evaluator.finish() updater.finishPass() m.finish() -- GitLab