diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index 650bf392bbc73415d9033f8c8134d90fd05f0cc2..767dffdad82f59d73b1505586260a61f5008f1f8 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -27,19 +27,14 @@ def main(): cost = paddle.layer.classification_cost(input=inference, label=label) parameters = paddle.parameters.create(cost) - for param_name in parameters.keys(): - array = parameters.get(param_name) - array[:] = numpy.random.uniform(low=-1.0, high=1.0, size=array.shape) - parameters.set(parameter_name=param_name, value=array) adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01) def event_handler(event): if isinstance(event, paddle.event.EndIteration): - para = parameters.get('___fc_2__.w0') - print "Pass %d, Batch %d, Cost %f, Weight Mean Of Fc 2 is %f" % ( - event.pass_id, event.batch_id, event.cost, para.mean()) - + if event.batch_id % 100 == 0: + print "Pass %d, Batch %d, Cost %f, %s" % ( + event.pass_id, event.batch_id, event.cost, event.metrics) else: pass diff --git a/python/paddle/v2/event.py b/python/paddle/v2/event.py index a16cfa91f062a60a141ea8fa962b3ecf6f5f0a22..835e28e6218df22e1cad7f7bb31c3c9941657252 100644 --- a/python/paddle/v2/event.py +++ b/python/paddle/v2/event.py @@ -3,8 +3,6 @@ All training events. There are: -* BeginTraining -* EndTraining * BeginIteration * EndIteration * BeginPass @@ -12,15 +10,62 @@ There are: TODO(yuyang18): Complete it! """ -__all__ = ['EndIteration'] +import py_paddle.swig_paddle as api +__all__ = ['EndIteration', 'BeginIteration', 'BeginPass', 'EndPass'] -class EndIteration(object): +class WithMetric(object): + def __init__(self, evaluator): + if not isinstance(evaluator, api.Evaluator): + raise TypeError("Evaluator should be api.Evaluator type") + self.__evaluator__ = evaluator + + @property + def metrics(self): + names = self.__evaluator__.getNames() + retv = dict() + for each_name in names: + val = self.__evaluator__.getValue(each_name) + retv[each_name] = val + return retv + + +class BeginPass(object): + """ + Event On One Pass Training Start. + """ + + def __init__(self, pass_id): + self.pass_id = pass_id + + +class EndPass(WithMetric): + """ + Event On One Pass Training Complete. + """ + + def __init__(self, pass_id, evaluator): + self.pass_id = pass_id + WithMetric.__init__(self, evaluator) + + +class BeginIteration(object): + """ + Event On One Batch Training Start. + """ + + def __init__(self, pass_id, batch_id): + self.pass_id = pass_id + self.batch_id = batch_id + + +class EndIteration(WithMetric): """ Event On One Batch Training Complete. """ - def __init__(self, pass_id, batch_id, cost): + def __init__(self, pass_id, batch_id, cost, evaluator): self.pass_id = pass_id self.batch_id = batch_id self.cost = cost + WithMetric.__init__(self, evaluator) diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 023ab5e42d25b9f70827b1e2efba985a5442db1f..097814d2f4619797470668cbd0ea95f112a1fde6 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -97,22 +97,34 @@ class SGD(ITrainer): topology, api.CREATE_MODE_NORMAL, self.__optimizer__.enable_types()) assert isinstance(gm, api.GradientMachine) parameters.append_gradient_machine(gm) - + gm.randParameters() updater = self.__optimizer__.create_local_updater() updater.init(gm) gm.start() + batch_evaluator = gm.makeEvaluator() + assert isinstance(batch_evaluator, api.Evaluator) + pass_evaluator = gm.makeEvaluator() + assert isinstance(pass_evaluator, api.Evaluator) out_args = api.Arguments.createArguments(0) feeder = DataFeeder(data_types, reader_dict) for pass_id in xrange(num_passes): + event_handler(v2_event.BeginPass(pass_id)) + pass_evaluator.start() updater.startPass() for batch_id, data_batch in enumerate( __data_reader_to_batch__(train_data_reader, batch_size, topology)): + batch_evaluator.start() + event_handler( + v2_event.BeginIteration( + pass_id=pass_id, batch_id=batch_id)) pass_type = updater.startBatch(len(data_batch)) gm.forwardBackward(feeder(data_batch), out_args, pass_type) + gm.eval(pass_evaluator) + gm.eval(batch_evaluator) for each_param in gm.getParameters(): updater.update(each_param) # Get cost. We use numpy to calculate total cost for this batch. @@ -120,11 +132,17 @@ class SGD(ITrainer): cost_vec = cost_vec.copyToNumpyMat() cost = cost_vec.sum() / len(data_batch) updater.finishBatch(cost) + batch_evaluator.finish() event_handler( v2_event.EndIteration( - pass_id=pass_id, batch_id=batch_id, cost=cost)) + pass_id=pass_id, + batch_id=batch_id, + cost=cost, + evaluator=batch_evaluator)) updater.finishPass() + pass_evaluator.finish() + event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator)) gm.finish()