提交 33cb12a8 编写于 作者: F fengjiayi

update error clip

上级 dea52631
......@@ -278,7 +278,7 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
_infer_var_data_type_(arg, block)
def append_backward(loss, parameter_list=None, no_grad_set=None):
def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
"""
Append backward part to main_program
......@@ -322,7 +322,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None):
grad_to_var = dict()
_append_backward_ops_(loss, root_block, root_block, no_grad_dict,
grad_to_var)
grad_to_var, callback)
_append_backward_vars_(root_block, fwd_op_num, grad_to_var, grad_info_map)
program.current_block_idx = current_block_idx
......
......@@ -2,7 +2,9 @@ import functools
import layers
from . import core
__all__ = ['GradientClipByValue', 'append_gradient_clip_ops']
__all__ = [
'GradientClipByValue', 'append_gradient_clip_ops', 'error_clip_callback'
]
class BaseErrorClipAttr(object):
......
......@@ -151,6 +151,7 @@ class Variable(object):
stop_gradient=False,
**kwargs):
self.block = block
self.error_clip = error_clip
if name is None:
name = Variable._unique_var_name_()
......
......@@ -6,7 +6,7 @@ from framework import unique_name, program_guard
from initializer import Constant
from layer_helper import LayerHelper
from regularizer import append_regularization_ops
from clip import append_gradient_clip_ops
from clip import append_gradient_clip_ops, error_clip_callback
__all__ = ['SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad']
......@@ -197,7 +197,8 @@ class Optimizer(object):
This method combines interface `append_backward()` and
`create_optimization_pass()` into one.
"""
params_grads = append_backward(loss, parameter_list, no_grad_set)
params_grads = append_backward(loss, parameter_list, no_grad_set,
error_clip_callback)
params_grads = append_gradient_clip_ops(params_grads)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册