提交 059a162b 编写于 作者: Y Yu Yang

Follow comments

上级 094d29aa
...@@ -36,7 +36,7 @@ def main(): ...@@ -36,7 +36,7 @@ def main():
learning_rate=0.01, learning_method=AdamOptimizer()) learning_rate=0.01, learning_method=AdamOptimizer())
def event_handler(event): def event_handler(event):
if isinstance(event, paddle.trainer.CompleteTrainOneBatch): if isinstance(event, paddle.trainer.EndIteration):
para = parameters['___fc_layer_2__.w0'] para = parameters['___fc_layer_2__.w0']
print "Pass %d, Batch %d, Cost %f, Weight Mean Of Fc 2 is %f" % ( print "Pass %d, Batch %d, Cost %f, Weight Mean Of Fc 2 is %f" % (
event.pass_id, event.batch_id, event.cost, para.mean()) event.pass_id, event.batch_id, event.cost, para.mean())
......
...@@ -49,6 +49,9 @@ class Parameters(object): ...@@ -49,6 +49,9 @@ class Parameters(object):
def has_key(self, key): def has_key(self, key):
return key in self.__param_conf__.keys() return key in self.__param_conf__.keys()
def __iter__(self):
return iter(self.__param_conf__)
def __getitem__(self, key): def __getitem__(self, key):
shape = self.get_shape(key) shape = self.get_shape(key)
......
...@@ -7,17 +7,10 @@ from paddle.proto.ModelConfig_pb2 import ModelConfig ...@@ -7,17 +7,10 @@ from paddle.proto.ModelConfig_pb2 import ModelConfig
from . import optimizer as v2_optimizer from . import optimizer as v2_optimizer
from . import parameters as v2_parameters from . import parameters as v2_parameters
__all__ = ['ITrainer', 'SGDTrainer', 'CompleteTrainOneBatch', 'BaseEvent'] __all__ = ['ITrainer', 'SGDTrainer', 'EndIteration']
class BaseEvent(object): class EndIteration(object):
"""
Just a marker class
"""
pass
class CompleteTrainOneBatch(BaseEvent):
""" """
Event On One Batch Training Complete. Event On One Batch Training Complete.
""" """
...@@ -117,7 +110,7 @@ class SGDTrainer(ITrainer): ...@@ -117,7 +110,7 @@ class SGDTrainer(ITrainer):
cost = cost_vec.sum() / len(data_batch) cost = cost_vec.sum() / len(data_batch)
updater.finishBatch(cost) updater.finishBatch(cost)
event_handler( event_handler(
CompleteTrainOneBatch( EndIteration(
pass_id=pass_id, batch_id=batch_id, cost=cost)) pass_id=pass_id, batch_id=batch_id, cost=cost))
updater.finishPass() updater.finishPass()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册