diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index c787a6122ad6c6331b29002881d331ec530ccb93..45a70bc84afa29c5f12d2a8dddf17b8034e1c541 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -21,28 +21,29 @@ def main(): adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01) + trainer = paddle.trainer.SGD(topology=cost, + parameters=parameters, + update_equation=adam_optimizer) + def event_handler(event): if isinstance(event, paddle.event.EndIteration): - if event.batch_id % 100 == 0: - print "Pass %d, Batch %d, Cost %f, %s" % ( - event.pass_id, event.batch_id, event.cost, event.metrics) + if event.batch_id % 1000 == 0: + result = trainer.test(reader=paddle.reader.batched( + paddle.dataset.mnist.test_creator(), batch_size=256)) + + print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % ( + event.pass_id, event.batch_id, event.cost, event.metrics, + result.metrics) + else: pass - trainer = paddle.trainer.SGD(update_equation=adam_optimizer) trainer.train( reader=paddle.reader.batched( - paddle.reader.shuffle(paddle.dataset.mnist.train_creator(), - buf_size=8192), batch_size=32), - topology=cost, - parameters=parameters, - event_handler=event_handler, - data_types=[ # data_types will be removed, It should be in - # network topology - ('pixel', images.type), - ('label', label.type)], - reader_dict={'pixel': 0, 'label': 1} - ) + paddle.reader.shuffle( + paddle.dataset.mnist.train_creator(), buf_size=8192), + batch_size=32), + event_handler=event_handler) if __name__ == '__main__': diff --git a/python/paddle/v2/event.py b/python/paddle/v2/event.py index 835e28e6218df22e1cad7f7bb31c3c9941657252..a78bcf076cc65e0dfdfc5760e099900418162f35 100644 --- a/python/paddle/v2/event.py +++ b/python/paddle/v2/event.py @@ -11,7 +11,10 @@ There are: TODO(yuyang18): Complete it! """ import py_paddle.swig_paddle as api -__all__ = ['EndIteration', 'BeginIteration', 'BeginPass', 'EndPass'] + +__all__ = [ + 'EndIteration', 'BeginIteration', 'BeginPass', 'EndPass', 'TestResult' +] class WithMetric(object): @@ -30,6 +33,11 @@ class WithMetric(object): return retv +class TestResult(WithMetric): + def __init__(self, evaluator): + super(TestResult, self).__init__(evaluator) + + class BeginPass(object): """ Event On One Pass Training Start. diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index f3a323db4fcfae9358dc23134971dcbda8e09344..77232d5ac5e411e7cb6c44236c7d1a7e9341af05 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -23,6 +23,13 @@ def default_event_handler(event): pass +def __bfs_travel_topology__(callback, *topologies): + for each_layer in topologies: + callback(each_layer) + __bfs_travel_topology__(callback, + *each_layer.__parent_layers__.values()) + + class ITrainer(object): """ The interface of Trainer. The only exposed method is `train`. @@ -43,26 +50,49 @@ class ITrainer(object): class SGD(ITrainer): - def __init__(self, update_equation): + def __init__(self, topology, parameters, update_equation): """ Simple SGD Trainer. :param update_equation: The optimizer object. :type update_equation: v2_optimizer.Optimizer """ + if not isinstance(parameters, v2_parameters.Parameters): + raise TypeError('parameters should be parameters') + if not isinstance(update_equation, v2_optimizer.Optimizer): - raise ValueError("update equation parameter must be " - "paddle.v2.optimizer.Optimizer") + raise TypeError("update equation parameter must be " + "paddle.v2.optimizer.Optimizer") self.__optimizer__ = update_equation + self.__topology__ = topology + self.__parameters__ = parameters + self.__topology_in_proto__ = v2_layer.parse_network(topology) + data_types = dict() + + def __travel__(l): + if hasattr(l, 'type'): + data_types[l.name] = l.type + + if not isinstance(topology, collections.Sequence): + topology = [topology] + __bfs_travel_topology__(__travel__, *topology) + self.__data_types__ = [ + (iname, data_types[iname]) + for iname in self.__topology_in_proto__.input_layer_names + ] + + if not isinstance(self.__topology_in_proto__, ModelConfig): + raise TypeError('topology should be a model config') - def train(self, - reader, - topology, - parameters, - num_passes=1, - event_handler=None, - data_types=None, - reader_dict=None): + gm = api.GradientMachine.createFromConfigProto( + self.__topology_in_proto__, api.CREATE_MODE_NORMAL, + self.__optimizer__.enable_types()) + assert isinstance(gm, api.GradientMachine) + parameters.append_gradient_machine(gm) + self.__gradient_machine__ = gm + self.__gradient_machine__.randParameters() + + def train(self, reader, num_passes=1, event_handler=None, reader_dict=None): """ Training method. Will train num_passes of input data. @@ -79,26 +109,21 @@ class SGD(ITrainer): if event_handler is None: event_handler = default_event_handler - topology = v2_layer.parse_network(topology) + if reader_dict is None: + reader_dict = self.default_reader_dict() __check_train_args__(**locals()) - - gm = api.GradientMachine.createFromConfigProto( - topology, api.CREATE_MODE_NORMAL, self.__optimizer__.enable_types()) - assert isinstance(gm, api.GradientMachine) - parameters.append_gradient_machine(gm) - gm.randParameters() updater = self.__optimizer__.create_local_updater() - updater.init(gm) + updater.init(self.__gradient_machine__) - gm.start() - batch_evaluator = gm.makeEvaluator() + self.__gradient_machine__.start() + batch_evaluator = self.__gradient_machine__.makeEvaluator() assert isinstance(batch_evaluator, api.Evaluator) - pass_evaluator = gm.makeEvaluator() + pass_evaluator = self.__gradient_machine__.makeEvaluator() assert isinstance(pass_evaluator, api.Evaluator) out_args = api.Arguments.createArguments(0) - feeder = DataFeeder(data_types, reader_dict) + feeder = DataFeeder(self.__data_types__, reader_dict) for pass_id in xrange(num_passes): event_handler(v2_event.BeginPass(pass_id)) @@ -106,16 +131,18 @@ class SGD(ITrainer): updater.startPass() for batch_id, data_batch in enumerate(reader()): pass_type = updater.startBatch(len(data_batch)) - gm.forwardBackward(feeder(data_batch), out_args, pass_type) + self.__gradient_machine__.forwardBackward( + feeder(data_batch), out_args, pass_type) batch_evaluator.start() event_handler( v2_event.BeginIteration( pass_id=pass_id, batch_id=batch_id)) pass_type = updater.startBatch(len(data_batch)) - gm.forwardBackward(feeder(data_batch), out_args, pass_type) - gm.eval(pass_evaluator) - gm.eval(batch_evaluator) - for each_param in gm.getParameters(): + self.__gradient_machine__.forwardBackward( + feeder(data_batch), out_args, pass_type) + self.__gradient_machine__.eval(pass_evaluator) + self.__gradient_machine__.eval(batch_evaluator) + for each_param in self.__gradient_machine__.getParameters(): updater.update(each_param) # Get cost. We use numpy to calculate total cost for this batch. cost_vec = out_args.getSlotValue(0) @@ -133,10 +160,32 @@ class SGD(ITrainer): updater.finishPass() pass_evaluator.finish() event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator)) - gm.finish() + self.__gradient_machine__.finish() + + def default_reader_dict(self): + reader_dict = dict() + for i, tp in enumerate(self.__data_types__): + reader_dict[tp[0]] = i + return reader_dict + def test(self, reader, reader_dict=None): + if reader_dict is None: + reader_dict = self.default_reader_dict() + + feeder = DataFeeder(self.__data_types__, reader_dict) + evaluator = self.__gradient_machine__.makeEvaluator() + out_args = api.Arguments.createArguments(0) + evaluator.start() + for data_batch in reader(): + self.__gradient_machine__.forward( + feeder(data_batch), out_args, api.PASS_TEST) + self.__gradient_machine__.eval(evaluator) -def __check_train_args__(reader, topology, parameters, event_handler, **kwargs): + evaluator.finish() + return v2_event.TestResult(evaluator=evaluator) + + +def __check_train_args__(reader, event_handler, **kwargs): """ Check train function's argument types """ @@ -144,11 +193,5 @@ def __check_train_args__(reader, topology, parameters, event_handler, **kwargs): raise TypeError('train_data_reader should be a function, ' 'which can return a iterator') - if not isinstance(topology, ModelConfig): - raise TypeError('topology should be a model config') - - if not isinstance(parameters, v2_parameters.Parameters): - raise TypeError('parameters should be a parameter pool') - if not callable(event_handler): raise TypeError('event handler should be a function')