From d34516fb662c2fe9727989dd2885de0b9ad9cf8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=AD=A6=E6=AF=85?= Date: Mon, 11 Sep 2017 13:00:44 +0800 Subject: [PATCH] Get output when training (#3978) * get output when training * follow comments --- python/paddle/v2/event.py | 10 ++++++++-- python/paddle/v2/trainer.py | 9 +++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/python/paddle/v2/event.py b/python/paddle/v2/event.py index 7589cc9917f..e66bf67d794 100644 --- a/python/paddle/v2/event.py +++ b/python/paddle/v2/event.py @@ -53,10 +53,13 @@ class BeginPass(object): class EndPass(WithMetric): """ Event On One Pass Training Complete. + To get the output of a specific layer, add "event.gm.getLayerOutputs('predict_layer')" + in your event_handler call back """ - def __init__(self, pass_id, evaluator): + def __init__(self, pass_id, evaluator, gm): self.pass_id = pass_id + self.gm = gm WithMetric.__init__(self, evaluator) @@ -73,10 +76,13 @@ class BeginIteration(object): class EndIteration(WithMetric): """ Event On One Batch Training Complete. + To get the output of a specific layer, add "event.gm.getLayerOutputs('predict_layer')" + in your event_handler call back """ - def __init__(self, pass_id, batch_id, cost, evaluator): + def __init__(self, pass_id, batch_id, cost, evaluator, gm): self.pass_id = pass_id self.batch_id = batch_id self.cost = cost + self.gm = gm WithMetric.__init__(self, evaluator) diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 0654a301049..ca95ef13bd4 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -174,13 +174,18 @@ class SGD(object): pass_id=pass_id, batch_id=batch_id, cost=cost, - evaluator=batch_evaluator)) + evaluator=batch_evaluator, + gm=self.__gradient_machine__)) self.__parameter_updater__.finishBatch(cost) batch_evaluator.finish() self.__parameter_updater__.finishPass() pass_evaluator.finish() - event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator)) + event_handler( + v2_event.EndPass( + pass_id, + evaluator=pass_evaluator, + gm=self.__gradient_machine__)) self.__gradient_machine__.finish() def test(self, reader, feeding=None): -- GitLab