提交 3f874143 编写于 作者: 武毅 提交者: GitHub

fix grad debug event (#4536)

上级 c3b46d16
...@@ -10,7 +10,8 @@ There are: ...@@ -10,7 +10,8 @@ There are:
* EndPass * EndPass
""" """
__all__ = [ __all__ = [
'EndIteration', 'BeginIteration', 'BeginPass', 'EndPass', 'TestResult' 'EndIteration', 'BeginIteration', 'BeginPass', 'EndPass', 'TestResult',
'EndForwardBackward'
] ]
...@@ -73,6 +74,17 @@ class BeginIteration(object): ...@@ -73,6 +74,17 @@ class BeginIteration(object):
self.batch_id = batch_id self.batch_id = batch_id
class EndForwardBackward(object):
"""
Event On One Batch ForwardBackward Complete.
"""
def __init__(self, pass_id, batch_id, gm):
self.pass_id = pass_id
self.batch_id = batch_id
self.gm = gm
class EndIteration(WithMetric): class EndIteration(WithMetric):
""" """
Event On One Batch Training Complete. Event On One Batch Training Complete.
......
...@@ -164,11 +164,18 @@ class SGD(object): ...@@ -164,11 +164,18 @@ class SGD(object):
pass_type) pass_type)
self.__gradient_machine__.eval(pass_evaluator) self.__gradient_machine__.eval(pass_evaluator)
self.__gradient_machine__.eval(batch_evaluator) self.__gradient_machine__.eval(batch_evaluator)
event_handler(
v2_event.EndForwardBackward(
pass_id=pass_id,
batch_id=batch_id,
gm=self.__gradient_machine__))
for each_param in self.__gradient_machine__.getNonStaticParameters( for each_param in self.__gradient_machine__.getNonStaticParameters(
): ):
self.__parameter_updater__.update(each_param) self.__parameter_updater__.update(each_param)
cost_sum = out_args.sum() cost_sum = out_args.sum()
cost = cost_sum / len(data_batch) cost = cost_sum / len(data_batch)
self.__parameter_updater__.finishBatch(cost)
batch_evaluator.finish()
event_handler( event_handler(
v2_event.EndIteration( v2_event.EndIteration(
pass_id=pass_id, pass_id=pass_id,
...@@ -176,8 +183,6 @@ class SGD(object): ...@@ -176,8 +183,6 @@ class SGD(object):
cost=cost, cost=cost,
evaluator=batch_evaluator, evaluator=batch_evaluator,
gm=self.__gradient_machine__)) gm=self.__gradient_machine__))
self.__parameter_updater__.finishBatch(cost)
batch_evaluator.finish()
self.__parameter_updater__.finishPass() self.__parameter_updater__.finishPass()
pass_evaluator.finish() pass_evaluator.finish()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册