diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py
index 31c3e9932b68d436ccddea9a83d119700c665b7d..f925d70765e1c6700ca5ab4b8cf9369743e43947 100644
--- a/demo/mnist/api_train_v2.py
+++ b/demo/mnist/api_train_v2.py
@@ -36,7 +36,7 @@ def main():
         learning_rate=0.01, learning_method=AdamOptimizer())
 
     def event_handler(event):
-        if isinstance(event, paddle.trainer.EndIteration):
+        if isinstance(event, paddle.event.EndIteration):
             para = parameters['___fc_layer_2__.w0']
             print "Pass %d, Batch %d, Cost %f, Weight Mean Of Fc 2 is %f" % (
                 event.pass_id, event.batch_id, event.cost, para.mean())
@@ -44,7 +44,7 @@ def main():
         else:
             pass
 
-    trainer = paddle.trainer.SGDTrainer(update_equation=adam_optimizer)
+    trainer = paddle.trainer.SGD(update_equation=adam_optimizer)
 
     trainer.train(train_data_reader=train_reader,
                   topology=model_config,
diff --git a/python/paddle/v2/__init__.py b/python/paddle/v2/__init__.py
index f50c36ce64995bbbbd6fe96b123bdcb711fd5248..72f1168e94f2a7de627551486b7dd6a5bc92940c 100644
--- a/python/paddle/v2/__init__.py
+++ b/python/paddle/v2/__init__.py
@@ -15,8 +15,9 @@ import optimizer
 import parameters
 import py_paddle.swig_paddle as api
 import trainer
+import event
 
-__all__ = ['optimizer', 'parameters', 'init', 'trainer']
+__all__ = ['optimizer', 'parameters', 'init', 'trainer', 'event']
 
 
 def init(**kwargs):
diff --git a/python/paddle/v2/event.py b/python/paddle/v2/event.py
new file mode 100644
index 0000000000000000000000000000000000000000..04158f4765299de8d7045af88961ac54b94f08a3
--- /dev/null
+++ b/python/paddle/v2/event.py
@@ -0,0 +1,12 @@
+__all__ = ['EndIteration']
+
+
+class EndIteration(object):
+    """
+    Event On One Batch Training Complete.
+    """
+
+    def __init__(self, pass_id, batch_id, cost):
+        self.pass_id = pass_id
+        self.batch_id = batch_id
+        self.cost = cost
diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py
index 63695be1566094f63aea9d21932f373472f3b068..a29c3a05f853952131883d8f4fdd18d1438b7671 100644
--- a/python/paddle/v2/trainer.py
+++ b/python/paddle/v2/trainer.py
@@ -6,19 +6,9 @@ from py_paddle import DataProviderConverter
 from paddle.proto.ModelConfig_pb2 import ModelConfig
 from . import optimizer as v2_optimizer
 from . import parameters as v2_parameters
+from . import event as v2_event
 
-__all__ = ['ITrainer', 'SGDTrainer', 'EndIteration']
-
-
-class EndIteration(object):
-    """
-    Event On One Batch Training Complete.
-    """
-
-    def __init__(self, pass_id, batch_id, cost):
-        self.pass_id = pass_id
-        self.batch_id = batch_id
-        self.cost = cost
+__all__ = ['ITrainer', 'SGD']
 
 
 def default_event_handler(event):
@@ -35,7 +25,7 @@ class ITrainer(object):
         raise NotImplementedError()
 
 
-class SGDTrainer(ITrainer):
+class SGD(ITrainer):
     def __init__(self, update_equation):
         """
         Simple SGD Trainer.
@@ -110,7 +100,7 @@ class SGDTrainer(ITrainer):
                 cost = cost_vec.sum() / len(data_batch)
                 updater.finishBatch(cost)
                 event_handler(
-                    EndIteration(
+                    v2_event.EndIteration(
                         pass_id=pass_id, batch_id=batch_id, cost=cost))
 
             updater.finishPass()
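
For reference, the demo and trainer hunks above reduce to the usage pattern sketched below. This is a condensed sketch, not part of the patch: adam_optimizer, train_reader, model_config and the parameters object are assumed to be set up as in demo/mnist/api_train_v2.py, and the event_handler keyword on train() is likewise assumed from the full demo, since that call is truncated in the hunk.

    import paddle.v2 as paddle

    def event_handler(event):
        # Batch-end events are now reported as paddle.event.EndIteration
        # (previously paddle.trainer.EndIteration).
        if isinstance(event, paddle.event.EndIteration):
            print "Pass %d, Batch %d, Cost %f" % (
                event.pass_id, event.batch_id, event.cost)

    # The trainer class is now paddle.trainer.SGD (previously SGDTrainer);
    # adam_optimizer is assumed to be built via paddle.optimizer as in the demo.
    trainer = paddle.trainer.SGD(update_equation=adam_optimizer)
    trainer.train(train_data_reader=train_reader,
                  topology=model_config,
                  event_handler=event_handler)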