提交 c8067ed5 编写于 作者: Y Yu Yang

Rearrange event.

上级 059a162b
...@@ -36,7 +36,7 @@ def main(): ...@@ -36,7 +36,7 @@ def main():
learning_rate=0.01, learning_method=AdamOptimizer()) learning_rate=0.01, learning_method=AdamOptimizer())
def event_handler(event): def event_handler(event):
if isinstance(event, paddle.trainer.EndIteration): if isinstance(event, paddle.event.EndIteration):
para = parameters['___fc_layer_2__.w0'] para = parameters['___fc_layer_2__.w0']
print "Pass %d, Batch %d, Cost %f, Weight Mean Of Fc 2 is %f" % ( print "Pass %d, Batch %d, Cost %f, Weight Mean Of Fc 2 is %f" % (
event.pass_id, event.batch_id, event.cost, para.mean()) event.pass_id, event.batch_id, event.cost, para.mean())
...@@ -44,7 +44,7 @@ def main(): ...@@ -44,7 +44,7 @@ def main():
else: else:
pass pass
trainer = paddle.trainer.SGDTrainer(update_equation=adam_optimizer) trainer = paddle.trainer.SGD(update_equation=adam_optimizer)
trainer.train(train_data_reader=train_reader, trainer.train(train_data_reader=train_reader,
topology=model_config, topology=model_config,
......
...@@ -15,8 +15,9 @@ import optimizer ...@@ -15,8 +15,9 @@ import optimizer
import parameters import parameters
import py_paddle.swig_paddle as api import py_paddle.swig_paddle as api
import trainer import trainer
import event
__all__ = ['optimizer', 'parameters', 'init', 'trainer'] __all__ = ['optimizer', 'parameters', 'init', 'trainer', 'event']
def init(**kwargs): def init(**kwargs):
......
__all__ = ['EndIteration']
class EndIteration(object):
"""
Event On One Batch Training Complete.
"""
def __init__(self, pass_id, batch_id, cost):
self.pass_id = pass_id
self.batch_id = batch_id
self.cost = cost
...@@ -6,19 +6,9 @@ from py_paddle import DataProviderConverter ...@@ -6,19 +6,9 @@ from py_paddle import DataProviderConverter
from paddle.proto.ModelConfig_pb2 import ModelConfig from paddle.proto.ModelConfig_pb2 import ModelConfig
from . import optimizer as v2_optimizer from . import optimizer as v2_optimizer
from . import parameters as v2_parameters from . import parameters as v2_parameters
from . import event as v2_event
__all__ = ['ITrainer', 'SGDTrainer', 'EndIteration'] __all__ = ['ITrainer', 'SGD']
class EndIteration(object):
"""
Event On One Batch Training Complete.
"""
def __init__(self, pass_id, batch_id, cost):
self.pass_id = pass_id
self.batch_id = batch_id
self.cost = cost
def default_event_handler(event): def default_event_handler(event):
...@@ -35,7 +25,7 @@ class ITrainer(object): ...@@ -35,7 +25,7 @@ class ITrainer(object):
raise NotImplementedError() raise NotImplementedError()
class SGDTrainer(ITrainer): class SGD(ITrainer):
def __init__(self, update_equation): def __init__(self, update_equation):
""" """
Simple SGD Trainer. Simple SGD Trainer.
...@@ -110,7 +100,7 @@ class SGDTrainer(ITrainer): ...@@ -110,7 +100,7 @@ class SGDTrainer(ITrainer):
cost = cost_vec.sum() / len(data_batch) cost = cost_vec.sum() / len(data_batch)
updater.finishBatch(cost) updater.finishBatch(cost)
event_handler( event_handler(
EndIteration( v2_event.EndIteration(
pass_id=pass_id, batch_id=batch_id, cost=cost)) pass_id=pass_id, batch_id=batch_id, cost=cost))
updater.finishPass() updater.finishPass()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册