diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index 612b0d218fc0be68b668ffb8f1bfb0ba92c4d741..a59b30ccdb2eddea6680d6ad5c790c857b9c5141 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -20,26 +20,29 @@ def main(): adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01) + trainer = paddle.trainer.SGD(cost=cost, + parameters=parameters, + update_equation=adam_optimizer) + def event_handler(event): if isinstance(event, paddle.event.EndIteration): - if event.batch_id % 100 == 0: - print "Pass %d, Batch %d, Cost %f, %s" % ( - event.pass_id, event.batch_id, event.cost, event.metrics) + if event.batch_id % 1000 == 0: + result = trainer.test(reader=paddle.reader.batched( + paddle.dataset.mnist.test(), batch_size=256)) + + print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % ( + event.pass_id, event.batch_id, event.cost, event.metrics, + result.metrics) + else: pass - trainer = paddle.trainer.SGD(update_equation=adam_optimizer) - trainer.train( reader=paddle.reader.batched( paddle.reader.shuffle( paddle.dataset.mnist.train(), buf_size=8192), batch_size=32), - cost=cost, - parameters=parameters, - event_handler=event_handler, - reader_dict={images.name: 0, - label.name: 1}) + event_handler=event_handler) if __name__ == '__main__': diff --git a/python/paddle/v2/dataset/mnist.py b/python/paddle/v2/dataset/mnist.py index f1315b35cd55c5387295f1f883b997cd6dd71bd1..1512a3c3189de4e54f8502cfadf450b0710a246e 100644 --- a/python/paddle/v2/dataset/mnist.py +++ b/python/paddle/v2/dataset/mnist.py @@ -9,9 +9,9 @@ __all__ = ['train', 'test'] URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/' TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz' -TEST_IMAGE_MD5 = '25e3cc63507ef6e98d5dc541e8672bb6' +TEST_IMAGE_MD5 = '9fb629c4189551a2d022fa330f9573f3' TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz' -TEST_LABEL_MD5 = '4e9511fe019b2189026bd0421ba7b688' +TEST_LABEL_MD5 = 'ec29112dd5afa0611ce80d1b7f02629c' TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz' TRAIN_IMAGE_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873' TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz' diff --git a/python/paddle/v2/event.py b/python/paddle/v2/event.py index 835e28e6218df22e1cad7f7bb31c3c9941657252..a78bcf076cc65e0dfdfc5760e099900418162f35 100644 --- a/python/paddle/v2/event.py +++ b/python/paddle/v2/event.py @@ -11,7 +11,10 @@ There are: TODO(yuyang18): Complete it! """ import py_paddle.swig_paddle as api -__all__ = ['EndIteration', 'BeginIteration', 'BeginPass', 'EndPass'] + +__all__ = [ + 'EndIteration', 'BeginIteration', 'BeginPass', 'EndPass', 'TestResult' +] class WithMetric(object): @@ -30,6 +33,11 @@ class WithMetric(object): return retv +class TestResult(WithMetric): + def __init__(self, evaluator): + super(TestResult, self).__init__(evaluator) + + class BeginPass(object): """ Event On One Pass Training Start. diff --git a/python/paddle/v2/topology.py b/python/paddle/v2/topology.py index a51b1073b4fc4fd3ac44c355e050b0d720944645..16fc92e63d98cfa714fd9a0a94f7f10385374f80 100644 --- a/python/paddle/v2/topology.py +++ b/python/paddle/v2/topology.py @@ -21,6 +21,14 @@ import layer as v2_layer __all__ = ['Topology'] +def __bfs_travel__(callback, *layers): + for each_layer in layers: + __break__ = callback(each_layer) + if __break__: + return + __bfs_travel__(callback, *each_layer.__parent_layers__.values()) + + class Topology(object): """ Topology is used to store the information about all layers @@ -46,21 +54,17 @@ class Topology(object): :param name: :return: """ - result_layer = [] + result_layer = [None] - def find_layer_by_name(layer, layer_name): - if len(result_layer) == 1: - return - elif layer.name == layer_name: - result_layer.append(layer) - else: - for parent_layer in layer.__parent_layers__.values(): - find_layer_by_name(parent_layer, layer_name) + def __impl__(l): + if l.name == name: + result_layer[0] = l + return True # break + return False - for layer in self.layers: - find_layer_by_name(layer, name) - - assert len(result_layer) == 1 + __bfs_travel__(__impl__, *self.layers) + if result_layer[0] is None: + raise ValueError("No such layer %s" % name) return result_layer[0] def data_layers(self): @@ -68,17 +72,13 @@ class Topology(object): get all data layer :return: """ - data_layers = set() - - def find_data_layer(layer): - if isinstance(layer, v2_layer.DataLayerV2): - data_layers.add(layer) - for parent_layer in layer.__parent_layers__.values(): - find_data_layer(parent_layer) + data_layers = dict() - for layer in self.layers: - find_data_layer(layer) + def __impl__(l): + if isinstance(l, v2_layer.DataLayerV2): + data_layers[l.name] = l + __bfs_travel__(__impl__, *self.layers) return data_layers def data_type(self): @@ -86,8 +86,9 @@ class Topology(object): get data_type from proto, such as: [('image', dense_vector(768)), ('label', integer_value(10))] """ - return [(data_layer.name, data_layer.type) - for data_layer in self.data_layers()] + data_layers = self.data_layers() + return [(nm, data_layers[nm].type) + for nm in self.proto().input_layer_names] def __check_layer_type__(layer): diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index bf8b181e42064f01a78b92313805b5fed3a3ceac..5003f55f3e0d15149d28d1478e0487d6873d6e0a 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -42,25 +42,35 @@ class ITrainer(object): class SGD(ITrainer): - def __init__(self, update_equation): + def __init__(self, cost, parameters, update_equation): """ Simple SGD Trainer. :param update_equation: The optimizer object. :type update_equation: v2_optimizer.Optimizer """ + + if not isinstance(parameters, v2_parameters.Parameters): + raise TypeError('parameters should be parameters') + if not isinstance(update_equation, v2_optimizer.Optimizer): - raise ValueError("update equation parameter must be " - "paddle.v2.optimizer.Optimizer") + raise TypeError("update equation parameter must be " + "paddle.v2.optimizer.Optimizer") + topology = Topology(cost) self.__optimizer__ = update_equation + self.__topology__ = topology + self.__parameters__ = parameters + self.__topology_in_proto__ = topology.proto() + self.__data_types__ = topology.data_type() + gm = api.GradientMachine.createFromConfigProto( + self.__topology_in_proto__, api.CREATE_MODE_NORMAL, + self.__optimizer__.enable_types()) + assert isinstance(gm, api.GradientMachine) + parameters.append_gradient_machine(gm) + self.__gradient_machine__ = gm + self.__gradient_machine__.randParameters() - def train(self, - reader, - cost, - parameters, - num_passes=1, - event_handler=None, - reader_dict=None): + def train(self, reader, num_passes=1, event_handler=None, reader_dict=None): """ Training method. Will train num_passes of input data. @@ -76,27 +86,22 @@ class SGD(ITrainer): if event_handler is None: event_handler = default_event_handler - topology = Topology(cost) + if reader_dict is None: + reader_dict = self.default_reader_dict() __check_train_args__(**locals()) - gm = api.GradientMachine.createFromConfigProto( - topology.proto(), api.CREATE_MODE_NORMAL, - self.__optimizer__.enable_types()) - assert isinstance(gm, api.GradientMachine) - parameters.append_gradient_machine(gm) - gm.randParameters() updater = self.__optimizer__.create_local_updater() - updater.init(gm) + updater.init(self.__gradient_machine__) - gm.start() - batch_evaluator = gm.makeEvaluator() + self.__gradient_machine__.start() + batch_evaluator = self.__gradient_machine__.makeEvaluator() assert isinstance(batch_evaluator, api.Evaluator) - pass_evaluator = gm.makeEvaluator() + pass_evaluator = self.__gradient_machine__.makeEvaluator() assert isinstance(pass_evaluator, api.Evaluator) out_args = api.Arguments.createArguments(0) - feeder = DataFeeder(topology.data_type(), reader_dict) + feeder = DataFeeder(self.__data_types__, reader_dict) for pass_id in xrange(num_passes): event_handler(v2_event.BeginPass(pass_id)) @@ -104,16 +109,18 @@ class SGD(ITrainer): updater.startPass() for batch_id, data_batch in enumerate(reader()): pass_type = updater.startBatch(len(data_batch)) - gm.forwardBackward(feeder(data_batch), out_args, pass_type) + self.__gradient_machine__.forwardBackward( + feeder(data_batch), out_args, pass_type) batch_evaluator.start() event_handler( v2_event.BeginIteration( pass_id=pass_id, batch_id=batch_id)) pass_type = updater.startBatch(len(data_batch)) - gm.forwardBackward(feeder(data_batch), out_args, pass_type) - gm.eval(pass_evaluator) - gm.eval(batch_evaluator) - for each_param in gm.getParameters(): + self.__gradient_machine__.forwardBackward( + feeder(data_batch), out_args, pass_type) + self.__gradient_machine__.eval(pass_evaluator) + self.__gradient_machine__.eval(batch_evaluator) + for each_param in self.__gradient_machine__.getParameters(): updater.update(each_param) # Get cost. We use numpy to calculate total cost for this batch. cost_vec = out_args.getSlotValue(0) @@ -131,22 +138,37 @@ class SGD(ITrainer): updater.finishPass() pass_evaluator.finish() event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator)) - gm.finish() + self.__gradient_machine__.finish() + + def default_reader_dict(self): + reader_dict = dict() + for i, tp in enumerate(self.__data_types__): + reader_dict[tp[0]] = i + return reader_dict + + def test(self, reader, reader_dict=None): + if reader_dict is None: + reader_dict = self.default_reader_dict() + + feeder = DataFeeder(self.__data_types__, reader_dict) + evaluator = self.__gradient_machine__.makeEvaluator() + out_args = api.Arguments.createArguments(0) + evaluator.start() + for data_batch in reader(): + self.__gradient_machine__.forward( + feeder(data_batch), out_args, api.PASS_TEST) + self.__gradient_machine__.eval(evaluator) + + evaluator.finish() + return v2_event.TestResult(evaluator=evaluator) -def __check_train_args__(reader, topology, parameters, event_handler, **kwargs): +def __check_train_args__(reader, event_handler, **kwargs): """ Check train function's argument types """ if not callable(reader) or not isinstance(reader(), collections.Iterator): raise TypeError('train_data_reader should be a function, ' 'which can return a iterator') - - if not isinstance(topology, Topology): - raise TypeError('topology should be a model config') - - if not isinstance(parameters, v2_parameters.Parameters): - raise TypeError('parameters should be a parameter pool') - if not callable(event_handler): raise TypeError('event handler should be a function')