提交 dea52631 编写于 作者: F fengjiayi

update error clip

上级 4ead8e1b
......@@ -190,8 +190,15 @@ def _append_backward_ops_(target,
val(str): corresponding forward variable name
callback(callable object): a callable object used to decorate new generated grad ops
"""
if callback is not None and not hasattr(callback, '__call__'):
if callback is None:
    # Default to a no-op callback. The signature must accept the same
    # keyword arguments used at the call site later in this function,
    # which invokes callback(block=target_block, context=grad_to_var);
    # a one-argument default would raise TypeError there.
    def empty_callback(block, context):
        pass
    callback = empty_callback
elif not hasattr(callback, '__call__'):
    raise ValueError("'callback' must be a callable object.")
# grad_op_descs holds created grad_op, and will be appended to target_block
grad_op_descs = []
program = block.program
......@@ -208,8 +215,6 @@ def _append_backward_ops_(target,
# Getting op's corresponding grad_op
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, no_grad_dict[block.idx], grad_sub_block_list)
if callback is not None:
grad_op_desc = callback(grad_op_desc)
grad_op_descs.extend(grad_op_desc)
grad_to_var.update(op_grad_to_var)
......@@ -230,6 +235,7 @@ def _append_backward_ops_(target,
for op_desc in grad_op_descs:
new_op_desc = target_block.desc.append_op()
new_op_desc.copy_from(op_desc)
callback(block=target_block, context=grad_to_var)
def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
......
......@@ -6,18 +6,9 @@ __all__ = ['GradientClipByValue', 'append_gradient_clip_ops']
class BaseErrorClipAttr(object):
def create_clip_op_desc(self, grad_name):
def append_clip_op(self, block, grad_name):
    """Append error-clip op(s) for gradient variable `grad_name` to `block`.

    Interface method; concrete error-clip attributes must override it.
    """
    raise NotImplementedError()
def prepend_clip_op_desc(self, op_descs):
    """Add one clip op desc per gradient output found in `op_descs`.

    NOTE(review): despite the name, the new descs are appended to the
    end of `op_descs`, not prepended.
    """
    # Collect every output argument name carrying the gradient suffix.
    grad_names = set()
    for op_desc in op_descs:
        grad_names.update(
            filter(lambda n: n.find(core.grad_var_suffix()) != -1,
                   op_desc.output_arg_names()))
    # One clip op desc per distinct gradient name.
    for n in grad_names:
        op_descs.append(self.create_clip_op_desc(grad_name=n))
class ErrorClipByValue(BaseErrorClipAttr):
def __init__(self, max, min=None):
......@@ -29,14 +20,25 @@ class ErrorClipByValue(BaseErrorClipAttr):
self.max = max
self.min = min
def create_clip_op_desc(self, grad_name):
    """Build a `clip` op desc clamping gradient `grad_name` in place
    to the range [self.min, self.max].

    Returns:
        core.OpDesc: the configured clip op description.
    """
    desc = core.OpDesc()
    desc.set_type("clip")
    desc.set_input("X", grad_name)
    # Output equals input: the gradient is clipped in place.
    desc.set_output("Out", grad_name)
    desc.set_attr("min", self.min)
    desc.set_attr("max", self.max)
    return desc
def append_clip_op(self, block, grad_name):
    """Append a `clip` op to `block` that clamps gradient variable
    `grad_name` in place to [self.min, self.max].
    """
    clip_attrs = {"min": self.min, "max": self.max}
    block.append_op(
        type="clip",
        inputs={"X": grad_name},
        outputs={"Out": grad_name},
        attrs=clip_attrs)
def error_clip_callback(block, context):
    """Backward-pass callback that appends error-clip ops.

    Called once for each grad op appended to `block`; for every gradient
    output whose forward variable carries an `error_clip` attribute, the
    corresponding clip op is appended right after the grad op.

    Args:
        block: the block the grad op was just appended to.
        context(dict): the grad_to_var map - gradient variable name ->
            forward variable name.
    """
    # The context is the grad_to_var map built by the backward pass.
    grad_to_var = context
    # The grad op to post-process is the last op appended to the block.
    op_desc = block.desc.op(block.desc.op_size() - 1)
    # NOTE: `n in grad_to_var` replaces the Python-2-only dict.has_key().
    for grad_n in filter(lambda n: n in grad_to_var,
                         op_desc.output_arg_names()):
        fwd_var = block.var_recursive(grad_to_var[grad_n])
        error_clip = getattr(fwd_var, "error_clip", None)
        if error_clip is not None:
            error_clip.append_clip_op(block, grad_n)
class BaseGradientClipAttr(object):
......
......@@ -147,6 +147,7 @@ class Variable(object):
dtype=None,
lod_level=None,
persistable=None,
error_clip=None,
stop_gradient=False,
**kwargs):
self.block = block
......@@ -626,6 +627,17 @@ class Block(object):
raise ValueError("var %s not in this block" % name)
return v
def var_recursive(self, name):
    """Resolve variable `name` in this block, searching ancestor blocks
    recursively when it is not found locally.

    Args:
        name(str): the variable name to resolve.

    Returns:
        The variable found in this block or the nearest ancestor.

    Raises:
        ValueError: if the search reaches the root block (idx 0)
            without finding the variable.
    """
    if self.has_var(name):
        return self.var(name)
    if self.idx == 0:
        # BUGFIX: the format arguments must be a parenthesized tuple;
        # `% name, self.idx` applied `%` to `name` alone (TypeError for
        # the missing %d) and passed self.idx as a stray second arg.
        raise ValueError("var %s is not in block(%d) nor its parents." %
                         (name, self.idx))
    # Not here and not the root: delegate to the parent block.
    parent_block = self.program.block(self.parent_idx)
    return parent_block.var_recursive(name)
def all_parameters(self):
return list(self.iter_parameters())
......@@ -744,6 +756,7 @@ class Block(object):
optimize_attr=p.optimize_attr,
regularizer=p.regularizer,
clip_attr=p.clip_attr,
error_clip=p.error_clip,
name=v.name)
self.vars[new_p.name] = new_p
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册