From 134c5c4db7b0d8e6188821cef41c028a790f26e9 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Thu, 4 Jan 2018 15:14:17 +0800 Subject: [PATCH] Support callback --- python/paddle/v2/fluid/backward.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py index ac60bf54360..b788a23eb60 100644 --- a/python/paddle/v2/fluid/backward.py +++ b/python/paddle/v2/fluid/backward.py @@ -188,7 +188,10 @@ def _append_backward_ops_(target, grad_to_var(dict)(output argument): key(str): grad variable name val(str): corresponding forward variable name + callback(callable object): a callable object used to decorate new generated grad ops """ + if callback is not None and not hasattr(callback, '__call__'): + raise ValueError("'callback' must be a callable object.") # grad_op_descs holds created grad_op, and will be appended to target_block grad_op_descs = [] program = block.program @@ -205,6 +208,8 @@ def _append_backward_ops_(target, # Getting op's corresponding grad_op grad_op_desc, op_grad_to_var = core.get_grad_op_desc( op.desc, no_grad_dict[block.idx], grad_sub_block_list) + if callback is not None: + grad_op_desc = callback(grad_op_desc) grad_op_descs.extend(grad_op_desc) grad_to_var.update(op_grad_to_var) -- GitLab