提交 d34516fb 编写于 作者: 武毅 提交者: GitHub

Get output when training (#3978)

* get output when training

* follow comments
上级 0f7bc938
...@@ -53,10 +53,13 @@ class BeginPass(object): ...@@ -53,10 +53,13 @@ class BeginPass(object):
class EndPass(WithMetric): class EndPass(WithMetric):
""" """
Event On One Pass Training Complete. Event On One Pass Training Complete.
To get the output of a specific layer, add "event.gm.getLayerOutputs('predict_layer')"
in your event_handler call back
""" """
def __init__(self, pass_id, evaluator): def __init__(self, pass_id, evaluator, gm):
self.pass_id = pass_id self.pass_id = pass_id
self.gm = gm
WithMetric.__init__(self, evaluator) WithMetric.__init__(self, evaluator)
...@@ -73,10 +76,13 @@ class BeginIteration(object): ...@@ -73,10 +76,13 @@ class BeginIteration(object):
class EndIteration(WithMetric): class EndIteration(WithMetric):
""" """
Event On One Batch Training Complete. Event On One Batch Training Complete.
To get the output of a specific layer, add "event.gm.getLayerOutputs('predict_layer')"
in your event_handler call back
""" """
def __init__(self, pass_id, batch_id, cost, evaluator): def __init__(self, pass_id, batch_id, cost, evaluator, gm):
self.pass_id = pass_id self.pass_id = pass_id
self.batch_id = batch_id self.batch_id = batch_id
self.cost = cost self.cost = cost
self.gm = gm
WithMetric.__init__(self, evaluator) WithMetric.__init__(self, evaluator)
...@@ -174,13 +174,18 @@ class SGD(object): ...@@ -174,13 +174,18 @@ class SGD(object):
pass_id=pass_id, pass_id=pass_id,
batch_id=batch_id, batch_id=batch_id,
cost=cost, cost=cost,
evaluator=batch_evaluator)) evaluator=batch_evaluator,
gm=self.__gradient_machine__))
self.__parameter_updater__.finishBatch(cost) self.__parameter_updater__.finishBatch(cost)
batch_evaluator.finish() batch_evaluator.finish()
self.__parameter_updater__.finishPass() self.__parameter_updater__.finishPass()
pass_evaluator.finish() pass_evaluator.finish()
event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator)) event_handler(
v2_event.EndPass(
pass_id,
evaluator=pass_evaluator,
gm=self.__gradient_machine__))
self.__gradient_machine__.finish() self.__gradient_machine__.finish()
def test(self, reader, feeding=None): def test(self, reader, feeding=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册