From 059a162be5a3b7a6cb80f41e1dd73ccc39df0181 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Thu, 16 Feb 2017 11:27:25 +0800 Subject: [PATCH] Follow comments --- demo/mnist/api_train_v2.py | 2 +- python/paddle/v2/parameters.py | 3 +++ python/paddle/v2/trainer.py | 13 +++---------- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index 8618a8f2111..31c3e9932b6 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -36,7 +36,7 @@ def main(): learning_rate=0.01, learning_method=AdamOptimizer()) def event_handler(event): - if isinstance(event, paddle.trainer.CompleteTrainOneBatch): + if isinstance(event, paddle.trainer.EndIteration): para = parameters['___fc_layer_2__.w0'] print "Pass %d, Batch %d, Cost %f, Weight Mean Of Fc 2 is %f" % ( event.pass_id, event.batch_id, event.cost, para.mean()) diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py index a30bf5d3631..c2e74b8fb12 100644 --- a/python/paddle/v2/parameters.py +++ b/python/paddle/v2/parameters.py @@ -49,6 +49,9 @@ class Parameters(object): def has_key(self, key): return key in self.__param_conf__.keys() + def __iter__(self): + return iter(self.__param_conf__) + def __getitem__(self, key): shape = self.get_shape(key) diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index baed7d0025b..63695be1566 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -7,17 +7,10 @@ from paddle.proto.ModelConfig_pb2 import ModelConfig from . import optimizer as v2_optimizer from . import parameters as v2_parameters -__all__ = ['ITrainer', 'SGDTrainer', 'CompleteTrainOneBatch', 'BaseEvent'] +__all__ = ['ITrainer', 'SGDTrainer', 'EndIteration'] -class BaseEvent(object): - """ - Just a marker class - """ - pass - - -class CompleteTrainOneBatch(BaseEvent): +class EndIteration(object): """ Event On One Batch Training Complete. """ @@ -117,7 +110,7 @@ class SGDTrainer(ITrainer): cost = cost_vec.sum() / len(data_batch) updater.finishBatch(cost) event_handler( - CompleteTrainOneBatch( + EndIteration( pass_id=pass_id, batch_id=batch_id, cost=cost)) updater.finishPass() -- GitLab