From c8067ed586d47c30e22df37026ff87410f694258 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Thu, 16 Feb 2017 11:36:23 +0800 Subject: [PATCH] Rearrange event. --- demo/mnist/api_train_v2.py | 4 ++-- python/paddle/v2/__init__.py | 3 ++- python/paddle/v2/event.py | 12 ++++++++++++ python/paddle/v2/trainer.py | 18 ++++-------------- 4 files changed, 20 insertions(+), 17 deletions(-) create mode 100644 python/paddle/v2/event.py diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index 31c3e9932b6..f925d70765e 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -36,7 +36,7 @@ def main(): learning_rate=0.01, learning_method=AdamOptimizer()) def event_handler(event): - if isinstance(event, paddle.trainer.EndIteration): + if isinstance(event, paddle.event.EndIteration): para = parameters['___fc_layer_2__.w0'] print "Pass %d, Batch %d, Cost %f, Weight Mean Of Fc 2 is %f" % ( event.pass_id, event.batch_id, event.cost, para.mean()) @@ -44,7 +44,7 @@ def main(): else: pass - trainer = paddle.trainer.SGDTrainer(update_equation=adam_optimizer) + trainer = paddle.trainer.SGD(update_equation=adam_optimizer) trainer.train(train_data_reader=train_reader, topology=model_config, diff --git a/python/paddle/v2/__init__.py b/python/paddle/v2/__init__.py index f50c36ce649..72f1168e94f 100644 --- a/python/paddle/v2/__init__.py +++ b/python/paddle/v2/__init__.py @@ -15,8 +15,9 @@ import optimizer import parameters import py_paddle.swig_paddle as api import trainer +import event -__all__ = ['optimizer', 'parameters', 'init', 'trainer'] +__all__ = ['optimizer', 'parameters', 'init', 'trainer', 'event'] def init(**kwargs): diff --git a/python/paddle/v2/event.py b/python/paddle/v2/event.py new file mode 100644 index 00000000000..04158f47652 --- /dev/null +++ b/python/paddle/v2/event.py @@ -0,0 +1,12 @@ +__all__ = ['EndIteration'] + + +class EndIteration(object): + """ + Event On One Batch Training Complete. + """ + + def __init__(self, pass_id, batch_id, cost): + self.pass_id = pass_id + self.batch_id = batch_id + self.cost = cost diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 63695be1566..a29c3a05f85 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -6,19 +6,9 @@ from py_paddle import DataProviderConverter from paddle.proto.ModelConfig_pb2 import ModelConfig from . import optimizer as v2_optimizer from . import parameters as v2_parameters +from . import event as v2_event -__all__ = ['ITrainer', 'SGDTrainer', 'EndIteration'] - - -class EndIteration(object): - """ - Event On One Batch Training Complete. - """ - - def __init__(self, pass_id, batch_id, cost): - self.pass_id = pass_id - self.batch_id = batch_id - self.cost = cost +__all__ = ['ITrainer', 'SGD'] def default_event_handler(event): @@ -35,7 +25,7 @@ class ITrainer(object): raise NotImplementedError() -class SGDTrainer(ITrainer): +class SGD(ITrainer): def __init__(self, update_equation): """ Simple SGD Trainer. @@ -110,7 +100,7 @@ class SGDTrainer(ITrainer): cost = cost_vec.sum() / len(data_batch) updater.finishBatch(cost) event_handler( - EndIteration( + v2_event.EndIteration( pass_id=pass_id, batch_id=batch_id, cost=cost)) updater.finishPass() -- GitLab