提交 134c5c4d 编写于 作者: F fengjiayi

Support callback

上级 e57a40b8
......@@ -188,7 +188,10 @@ def _append_backward_ops_(target,
grad_to_var(dict)(output argument):
key(str): grad variable name
val(str): corresponding forward variable name
callback(callable object): a callable object used to decorate new generated grad ops
"""
if callback is not None and not hasattr(callback, '__call__'):
raise ValueError("'callback' must be a callable object.")
# grad_op_descs holds created grad_op, and will be appended to target_block
grad_op_descs = []
program = block.program
......@@ -205,6 +208,8 @@ def _append_backward_ops_(target,
# Getting op's corresponding grad_op
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, no_grad_dict[block.idx], grad_sub_block_list)
if callback is not None:
grad_op_desc = callback(grad_op_desc)
grad_op_descs.extend(grad_op_desc)
grad_to_var.update(op_grad_to_var)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册