提交 d34516fb 编写于 作者: 武毅 提交者: GitHub

Get output when training (#3978)

* get output when training

* follow comments
上级 0f7bc938
......@@ -53,10 +53,13 @@ class BeginPass(object):
class EndPass(WithMetric):
"""
Event On One Pass Training Complete.
To get the output of a specific layer, add "event.gm.getLayerOutputs('predict_layer')"
in your event_handler call back
"""
def __init__(self, pass_id, evaluator):
def __init__(self, pass_id, evaluator, gm):
self.pass_id = pass_id
self.gm = gm
WithMetric.__init__(self, evaluator)
......@@ -73,10 +76,13 @@ class BeginIteration(object):
class EndIteration(WithMetric):
"""
Event On One Batch Training Complete.
To get the output of a specific layer, add "event.gm.getLayerOutputs('predict_layer')"
in your event_handler call back
"""
def __init__(self, pass_id, batch_id, cost, evaluator):
def __init__(self, pass_id, batch_id, cost, evaluator, gm):
self.pass_id = pass_id
self.batch_id = batch_id
self.cost = cost
self.gm = gm
WithMetric.__init__(self, evaluator)
......@@ -174,13 +174,18 @@ class SGD(object):
pass_id=pass_id,
batch_id=batch_id,
cost=cost,
evaluator=batch_evaluator))
evaluator=batch_evaluator,
gm=self.__gradient_machine__))
self.__parameter_updater__.finishBatch(cost)
batch_evaluator.finish()
self.__parameter_updater__.finishPass()
pass_evaluator.finish()
event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
event_handler(
v2_event.EndPass(
pass_id,
evaluator=pass_evaluator,
gm=self.__gradient_machine__))
self.__gradient_machine__.finish()
def test(self, reader, feeding=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册