提交 3f874143 编写于 作者: 武毅 提交者: GitHub

fix grad debug event (#4536)

上级 c3b46d16
......@@ -10,7 +10,8 @@ There are:
* EndPass
"""
__all__ = [
'EndIteration', 'BeginIteration', 'BeginPass', 'EndPass', 'TestResult'
'EndIteration', 'BeginIteration', 'BeginPass', 'EndPass', 'TestResult',
'EndForwardBackward'
]
......@@ -73,6 +74,17 @@ class BeginIteration(object):
self.batch_id = batch_id
class EndForwardBackward(object):
"""
Event On One Batch ForwardBackward Complete.
"""
def __init__(self, pass_id, batch_id, gm):
self.pass_id = pass_id
self.batch_id = batch_id
self.gm = gm
class EndIteration(WithMetric):
"""
Event On One Batch Training Complete.
......
......@@ -164,11 +164,18 @@ class SGD(object):
pass_type)
self.__gradient_machine__.eval(pass_evaluator)
self.__gradient_machine__.eval(batch_evaluator)
event_handler(
v2_event.EndForwardBackward(
pass_id=pass_id,
batch_id=batch_id,
gm=self.__gradient_machine__))
for each_param in self.__gradient_machine__.getNonStaticParameters(
):
self.__parameter_updater__.update(each_param)
cost_sum = out_args.sum()
cost = cost_sum / len(data_batch)
self.__parameter_updater__.finishBatch(cost)
batch_evaluator.finish()
event_handler(
v2_event.EndIteration(
pass_id=pass_id,
......@@ -176,8 +183,6 @@ class SGD(object):
cost=cost,
evaluator=batch_evaluator,
gm=self.__gradient_machine__))
self.__parameter_updater__.finishBatch(cost)
batch_evaluator.finish()
self.__parameter_updater__.finishPass()
pass_evaluator.finish()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册