From dea52631792ca765b0007a9389bf1203a8b31aab Mon Sep 17 00:00:00 2001
From: fengjiayi
Date: Fri, 5 Jan 2018 17:57:42 +0800
Subject: [PATCH] update error clip

---
 python/paddle/v2/fluid/backward.py  | 12 ++++++---
 python/paddle/v2/fluid/clip.py      | 38 +++++++++++++++--------------
 python/paddle/v2/fluid/framework.py | 13 ++++++++++
 3 files changed, 42 insertions(+), 21 deletions(-)

diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py
index b788a23eb60..b3f6887e4c1 100644
--- a/python/paddle/v2/fluid/backward.py
+++ b/python/paddle/v2/fluid/backward.py
@@ -190,8 +190,15 @@ def _append_backward_ops_(target,
         val(str): corresponding forward variable name
         callback(callable object): a callable object used to decorate new generated grad ops
     """
-    if callback is not None and not hasattr(callback, '__call__'):
+    if callback is None:
+
+        def empty_callback(block, context):
+            pass
+
+        callback = empty_callback
+    elif not hasattr(callback, '__call__'):
         raise ValueError("'callback' must be a callable object.")
+
     # grad_op_descs holds created grad_op, and will be appended to target_block
     grad_op_descs = []
     program = block.program
@@ -208,8 +215,6 @@ def _append_backward_ops_(target,
         # Getting op's corresponding grad_op
         grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
             op.desc, no_grad_dict[block.idx], grad_sub_block_list)
-        if callback is not None:
-            grad_op_desc = callback(grad_op_desc)
         grad_op_descs.extend(grad_op_desc)
         grad_to_var.update(op_grad_to_var)
 
@@ -230,6 +235,7 @@ def _append_backward_ops_(target,
     for op_desc in grad_op_descs:
         new_op_desc = target_block.desc.append_op()
         new_op_desc.copy_from(op_desc)
+        callback(block=target_block, context=grad_to_var)
 
 
 def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
diff --git a/python/paddle/v2/fluid/clip.py b/python/paddle/v2/fluid/clip.py
index 89972b8346f..8fb27d69fab 100644
--- a/python/paddle/v2/fluid/clip.py
+++ b/python/paddle/v2/fluid/clip.py
@@ -6,18 +6,9 @@ __all__ = ['GradientClipByValue', 'append_gradient_clip_ops']
 
 
 class BaseErrorClipAttr(object):
-    def create_clip_op_desc(self, grad_name):
+    def append_clip_op(self, block, grad_name):
         raise NotImplementedError()
 
-    def prepend_clip_op_desc(self, op_descs):
-        grad_names = set()
-        for op_desc in op_descs:
-            grad_names.update(
-                filter(lambda n: n.find(core.grad_var_suffix()) != -1,
-                       op_desc.output_arg_names()))
-        for n in grad_names:
-            op_descs.append(self.create_clip_op_desc(grad_name=n))
-
 
 class ErrorClipByValue(BaseErrorClipAttr):
     def __init__(self, max, min=None):
@@ -29,14 +20,25 @@ class ErrorClipByValue(BaseErrorClipAttr):
         self.max = max
         self.min = min
 
-    def create_clip_op_desc(self, grad_name):
-        desc = core.OpDesc()
-        desc.set_type("clip")
-        desc.set_input("X", grad_name)
-        desc.set_output("Out", grad_name)
-        desc.set_attr("min", self.min)
-        desc.set_attr("max", self.max)
-        return desc
+    def append_clip_op(self, block, grad_name):
+        block.append_op(
+            type="clip",
+            inputs={"X": grad_name},
+            outputs={"Out": grad_name},
+            attrs={"min": self.min,
+                   "max": self.max})
+
+
+def error_clip_callback(block, context):
+    # the context is a grad_to_var map
+    grad_to_var = context
+    op_desc = block.desc.op(block.desc.op_size() - 1)
+    for grad_n in filter(lambda n: grad_to_var.has_key(n),
+                         op_desc.output_arg_names()):
+        fwd_var = block.var_recursive(grad_to_var[grad_n])
+        error_clip = getattr(fwd_var, "error_clip", None)
+        if error_clip is not None:
+            error_clip.append_clip_op(block, grad_n)
 
 
 class BaseGradientClipAttr(object):
diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py
index b66a8bce5f4..4b01a2d0465 100644
--- a/python/paddle/v2/fluid/framework.py
+++ b/python/paddle/v2/fluid/framework.py
@@ -147,6 +147,7 @@ class Variable(object):
                  dtype=None,
                  lod_level=None,
                  persistable=None,
+                 error_clip=None,
                  stop_gradient=False,
                  **kwargs):
         self.block = block
@@ -626,6 +627,17 @@ class Block(object):
             raise ValueError("var %s not in this block" % name)
         return v
 
+    def var_recursive(self, name):
+        if self.has_var(name):
+            return self.var(name)
+        else:
+            if self.idx == 0:
+                raise ValueError("var %s is not in block(%d) nor its parents."
+                                 % (name, self.idx))
+            else:
+                parent_block = self.program.block(self.parent_idx)
+                return parent_block.var_recursive(name)
+
     def all_parameters(self):
         return list(self.iter_parameters())
 
@@ -744,6 +756,7 @@ class Block(object):
                 optimize_attr=p.optimize_attr,
                 regularizer=p.regularizer,
                 clip_attr=p.clip_attr,
+                error_clip=p.error_clip,
                 name=v.name)
         self.vars[new_p.name] = new_p
--
GitLab
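
Note on usage: taken together, the three files wire error clipping into the
backward pass. A forward Variable can carry an error_clip attribute, and
error_clip_callback, passed as the backward builder's callback, appends a
"clip" op right after each grad op whose output gradient maps back (through
grad_to_var) to such a variable. The sketch below is a minimal illustration
under two assumptions that are not part of this patch: the layer calls are
ordinary fluid layers of that era, and append_backward stands in for
whichever public wrapper forwards its callback argument down to
_append_backward_ops_.

    # Minimal usage sketch (assumed public wiring, see note above).
    import paddle.v2.fluid as fluid
    from paddle.v2.fluid.backward import append_backward
    from paddle.v2.fluid.clip import ErrorClipByValue, error_clip_callback

    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
    hidden = fluid.layers.fc(input=x, size=128, act='relu')
    predict = fluid.layers.fc(input=hidden, size=1)
    cost = fluid.layers.square_error_cost(input=predict, label=y)
    avg_cost = fluid.layers.mean(x=cost)

    # Mark `hidden` for error clipping. error_clip_callback reads this
    # attribute (via getattr) from the forward variable that grad_to_var
    # maps a gradient name back to, so hidden@GRAD is clipped in place
    # to the range [-1.0, 1.0].
    hidden.error_clip = ErrorClipByValue(max=1.0, min=-1.0)

    # The callback fires once per appended grad op and only inspects the
    # op just added (block.desc.op(op_size() - 1)).
    append_backward(loss=avg_cost, callback=error_clip_callback)

The move from create_clip_op_desc/prepend_clip_op_desc, which mutated raw
op-desc lists, to append_clip_op(block, grad_name) means clip ops are added
through Block.append_op, keeping the Python-side Block consistent with its
underlying desc; and invoking the callback once per appended grad op lets
error_clip_callback reason about exactly one new op at a time.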