From 173a81b56b60567eb4e30f66331ebab4e004ead7 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Thu, 23 Feb 2017 16:01:20 +0800 Subject: [PATCH] Complete Event, Add Metric to Event. --- demo/mnist/api_train_v2.py | 14 ++++------ python/paddle/v2/event.py | 55 +++++++++++++++++++++++++++++++++---- python/paddle/v2/trainer.py | 22 +++++++++++++-- 3 files changed, 75 insertions(+), 16 deletions(-) diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index 6fc01ce58be..6b0cfa0b05c 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -27,19 +27,14 @@ def main(): cost = paddle.layer.classification_cost(input=inference, label=label) parameters = paddle.parameters.create(cost) - for param_name in parameters.keys(): - array = parameters.get(param_name) - array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape) - parameters.set(parameter_name=param_name, value=array) adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01) def event_handler(event): if isinstance(event, paddle.event.EndIteration): - para = parameters.get('___fc_2__.w0') - print "Pass %d, Batch %d, Cost %f, Weight Mean Of Fc 2 is %f" % ( - event.pass_id, event.batch_id, event.cost, para.mean()) - + if event.batch_id % 100 == 0: + print "Pass %d, Batch %d, Cost %f, %s" % ( + event.pass_id, event.batch_id, event.cost, event.metrics) else: pass @@ -49,7 +44,8 @@ def main(): topology=cost, parameters=parameters, event_handler=event_handler, - batch_size=32, # batch size should be refactor in Data reader + num_passes=100, + batch_size=200, # batch size should be refactor in Data reader data_types={ # data_types will be removed, It should be in # network topology 'pixel': images.type, diff --git a/python/paddle/v2/event.py b/python/paddle/v2/event.py index a16cfa91f06..835e28e6218 100644 --- a/python/paddle/v2/event.py +++ b/python/paddle/v2/event.py @@ -3,8 +3,6 @@ All training events. There are: -* BeginTraining -* EndTraining * BeginIteration * EndIteration * BeginPass @@ -12,15 +10,62 @@ There are: TODO(yuyang18): Complete it! """ -__all__ = ['EndIteration'] +import py_paddle.swig_paddle as api +__all__ = ['EndIteration', 'BeginIteration', 'BeginPass', 'EndPass'] -class EndIteration(object): +class WithMetric(object): + def __init__(self, evaluator): + if not isinstance(evaluator, api.Evaluator): + raise TypeError("Evaluator should be api.Evaluator type") + self.__evaluator__ = evaluator + + @property + def metrics(self): + names = self.__evaluator__.getNames() + retv = dict() + for each_name in names: + val = self.__evaluator__.getValue(each_name) + retv[each_name] = val + return retv + + +class BeginPass(object): + """ + Event On One Pass Training Start. + """ + + def __init__(self, pass_id): + self.pass_id = pass_id + + +class EndPass(WithMetric): + """ + Event On One Pass Training Complete. + """ + + def __init__(self, pass_id, evaluator): + self.pass_id = pass_id + WithMetric.__init__(self, evaluator) + + +class BeginIteration(object): + """ + Event On One Batch Training Start. + """ + + def __init__(self, pass_id, batch_id): + self.pass_id = pass_id + self.batch_id = batch_id + + +class EndIteration(WithMetric): """ Event On One Batch Training Complete. """ - def __init__(self, pass_id, batch_id, cost): + def __init__(self, pass_id, batch_id, cost, evaluator): self.pass_id = pass_id self.batch_id = batch_id self.cost = cost + WithMetric.__init__(self, evaluator) diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 4365bd41e70..0acfcee2cee 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -96,11 +96,15 @@ class SGD(ITrainer): topology, api.CREATE_MODE_NORMAL, self.__optimizer__.enable_types()) assert isinstance(gm, api.GradientMachine) parameters.append_gradient_machine(gm) - + gm.randParameters() updater = self.__optimizer__.create_local_updater() updater.init(gm) gm.start() + batch_evaluator = gm.makeEvaluator() + assert isinstance(batch_evaluator, api.Evaluator) + pass_evaluator = gm.makeEvaluator() + assert isinstance(pass_evaluator, api.Evaluator) out_args = api.Arguments.createArguments(0) data_types_lists = [] @@ -112,12 +116,20 @@ class SGD(ITrainer): converter = DataProviderConverter(input_types=data_types_lists) for pass_id in xrange(num_passes): + event_handler(v2_event.BeginPass(pass_id)) + pass_evaluator.start() updater.startPass() for batch_id, data_batch in enumerate( __data_reader_to_batch__(train_data_reader, batch_size, topology)): + batch_evaluator.start() + event_handler( + v2_event.BeginIteration( + pass_id=pass_id, batch_id=batch_id)) pass_type = updater.startBatch(len(data_batch)) gm.forwardBackward(converter(data_batch), out_args, pass_type) + gm.eval(pass_evaluator) + gm.eval(batch_evaluator) for each_param in gm.getParameters(): updater.update(each_param) # Get cost. We use numpy to calculate total cost for this batch. @@ -125,11 +137,17 @@ class SGD(ITrainer): cost_vec = cost_vec.copyToNumpyMat() cost = cost_vec.sum() / len(data_batch) updater.finishBatch(cost) + batch_evaluator.finish() event_handler( v2_event.EndIteration( - pass_id=pass_id, batch_id=batch_id, cost=cost)) + pass_id=pass_id, + batch_id=batch_id, + cost=cost, + evaluator=batch_evaluator)) updater.finishPass() + pass_evaluator.finish() + event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator)) gm.finish() -- GitLab