Commit dea52631 authored by F fengjiayi

update error clip

Parent 4ead8e1b
@@ -190,8 +190,15 @@ def _append_backward_ops_(target,
         val(str): corresponding forward variable name
         callback(callable object): a callable object used to decorate new generated grad ops
     """
-    if callback is not None and not hasattr(callback, '__call__'):
+    if callback is None:
+
+        def empty_callback(block, context):
+            pass
+
+        callback = empty_callback
+    elif not hasattr(callback, '__call__'):
         raise ValueError("'callback' must be a callable object.")
     # grad_op_descs holds created grad_op, and will be appended to target_block
     grad_op_descs = []
     program = block.program
@@ -208,8 +215,6 @@ def _append_backward_ops_(target,
         # Getting op's corresponding grad_op
         grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
             op.desc, no_grad_dict[block.idx], grad_sub_block_list)
-        if callback is not None:
-            grad_op_desc = callback(grad_op_desc)
         grad_op_descs.extend(grad_op_desc)
         grad_to_var.update(op_grad_to_var)
@@ -230,6 +235,7 @@ def _append_backward_ops_(target,
     for op_desc in grad_op_descs:
         new_op_desc = target_block.desc.append_op()
         new_op_desc.copy_from(op_desc)
+        callback(block=target_block, context=grad_to_var)


 def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
...
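This hunk changes the callback contract: instead of being applied to each freshly generated grad_op_desc, the callback now fires once after every grad op has been appended to the target block, receiving the block itself and the grad_to_var map. A minimal sketch of a conforming callback, using only what the new call site passes (the name debug_callback is illustrative, not part of this commit):

# Invoked as callback(block=target_block, context=grad_to_var) right after
# each grad op is appended, so it can inspect or rewrite the op just added.
def debug_callback(block, context):
    # `context` is the grad_to_var map accumulated so far.
    last_op = block.desc.op(block.desc.op_size() - 1)  # the newest grad op
    print("appended grad op: %s" % last_op.type())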
@@ -6,18 +6,9 @@ __all__ = ['GradientClipByValue', 'append_gradient_clip_ops']
 class BaseErrorClipAttr(object):
-    def create_clip_op_desc(self, grad_name):
+    def append_clip_op(self, block, grad_name):
         raise NotImplementedError()

-    def prepend_clip_op_desc(self, op_descs):
-        grad_names = set()
-        for op_desc in op_descs:
-            grad_names.update(
-                filter(lambda n: n.find(core.grad_var_suffix()) != -1,
-                       op_desc.output_arg_names()))
-        for n in grad_names:
-            op_descs.append(self.create_clip_op_desc(grad_name=n))
-

 class ErrorClipByValue(BaseErrorClipAttr):
     def __init__(self, max, min=None):
@@ -29,14 +20,25 @@ class ErrorClipByValue(BaseErrorClipAttr):
         self.max = max
         self.min = min

-    def create_clip_op_desc(self, grad_name):
-        desc = core.OpDesc()
-        desc.set_type("clip")
-        desc.set_input("X", grad_name)
-        desc.set_output("Out", grad_name)
-        desc.set_attr("min", self.min)
-        desc.set_attr("max", self.max)
-        return desc
+    def append_clip_op(self, block, grad_name):
+        block.append_op(
+            type="clip",
+            inputs={"X": grad_name},
+            outputs={"Out": grad_name},
+            attrs={"min": self.min,
+                   "max": self.max})
+
+
+def error_clip_callback(block, context):
+    # the context is a grad_to_var map
+    grad_to_var = context
+    op_desc = block.desc.op(block.desc.op_size() - 1)
+    for grad_n in filter(lambda n: grad_to_var.has_key(n),
+                         op_desc.output_arg_names()):
+        fwd_var = block.var_recursive(grad_to_var[grad_n])
+        error_clip = getattr(fwd_var, "error_clip", None)
+        if error_clip is not None:
+            error_clip.append_clip_op(block, grad_n)


 class BaseGradientClipAttr(object):
...
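With this hunk, error clipping becomes data-driven: a forward variable carries an ErrorClipByValue instance in its new error_clip attribute, and error_clip_callback, when passed as the backward callback, appends a clip op for every gradient whose forward variable requests one. A hedged end-to-end sketch; the layer calls and the append_backward entry point are assumptions based on the fluid API of that era, not shown in this diff:

import paddle.v2.fluid as fluid
from paddle.v2.fluid.backward import append_backward
from paddle.v2.fluid.clip import ErrorClipByValue, error_clip_callback

x = fluid.layers.data(name="x", shape=[13], dtype="float32")
hidden = fluid.layers.fc(input=x, size=64, act="relu")
loss = fluid.layers.mean(x=hidden)

# Ask for the gradient flowing into `hidden` to be clipped to [-5.0, 5.0].
hidden.error_clip = ErrorClipByValue(max=5.0, min=-5.0)

# error_clip_callback runs once per appended grad op; for each gradient
# output whose forward variable carries an error_clip attribute, it
# appends a `clip` op on that gradient in place.
append_backward(loss=loss, callback=error_clip_callback)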
@@ -147,6 +147,7 @@ class Variable(object):
                  dtype=None,
                  lod_level=None,
                  persistable=None,
+                 error_clip=None,
                  stop_gradient=False,
                  **kwargs):
         self.block = block
@@ -626,6 +627,17 @@ class Block(object):
             raise ValueError("var %s not in this block" % name)
         return v

+    def var_recursive(self, name):
+        if self.has_var(name):
+            return self.var(name)
+        else:
+            if self.idx == 0:
+                raise ValueError("var %s is not in block(%d) nor its parents." %
+                                 (name, self.idx))
+            else:
+                parent_block = self.program.block(self.parent_idx)
+                return parent_block.var_recursive(name)
+
     def all_parameters(self):
         return list(self.iter_parameters())
@@ -744,6 +756,7 @@ class Block(object):
                     optimize_attr=p.optimize_attr,
                     regularizer=p.regularizer,
                     clip_attr=p.clip_attr,
+                    error_clip=p.error_clip,
                     name=v.name)
                 self.vars[new_p.name] = new_p
...
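var_recursive exists so that code running on a sub-block (such as the grad block of a control-flow op) can resolve a forward variable that actually lives in an ancestor block, which is exactly what error_clip_callback needs when mapping a gradient name back to its forward variable. A small illustrative sketch of the lookup, assuming the Program/Block construction API of that era:

import paddle.v2.fluid.framework as framework

program = framework.Program()
parent = program.global_block()
parent.create_var(name="x", shape=[1], dtype="float32")

sub = program.create_block()  # child of the current (global) block
assert not sub.has_var("x")   # "x" lives in the parent, not here

# var_recursive follows parent_idx links upward until the name is found,
# raising ValueError only if the root block (idx 0) also misses.
v = sub.var_recursive("x")
assert v.name == "x"
program.rollback()  # return the program's current-block pointer to the parent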