Commit 33cb12a8, authored by fengjiayi

update error clip

Parent: dea52631
@@ -278,7 +278,7 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
         _infer_var_data_type_(arg, block)

-def append_backward(loss, parameter_list=None, no_grad_set=None):
+def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
     """
     Append backward part to main_program
@@ -322,7 +322,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None):
     grad_to_var = dict()
     _append_backward_ops_(loss, root_block, root_block, no_grad_dict,
-                          grad_to_var)
+                          grad_to_var, callback)
     _append_backward_vars_(root_block, fwd_op_num, grad_to_var, grad_info_map)
     program.current_block_idx = current_block_idx
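These hunks (in the backward-building module, evidently `backward.py` given `_append_backward_ops_`) thread an optional `callback` through `append_backward`, so later machinery such as error clipping can react to each backward op as it is appended. The diff does not show how `_append_backward_ops_` invokes the callback; the sketch below assumes a `callback(block, context)` convention, which is an inference from the rest of this commit, and is illustrative only.

```python
# Minimal sketch of a backward callback, assuming _append_backward_ops_
# calls it as callback(block, context) right after appending each backward
# op; that invocation convention is an assumption, not shown in this diff.
def log_backward_ops(block, context):
    # Inspect the most recently appended op descriptor.
    op_desc = block.desc.op(block.desc.op_size() - 1)
    print("appended backward op:", op_desc.type())
```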
@@ -2,7 +2,9 @@ import functools
 import layers
 from . import core

-__all__ = ['GradientClipByValue', 'append_gradient_clip_ops']
+__all__ = [
+    'GradientClipByValue', 'append_gradient_clip_ops', 'error_clip_callback'
+]

 class BaseErrorClipAttr(object):
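This hunk (in `clip.py`, judging by the exported names) publishes the new `error_clip_callback` and shows `BaseErrorClipAttr`, the base class for per-variable error-clip attributes. The diff collapses the class bodies, so the subclass below is a plausible reconstruction rather than the committed code: it assumes `append_clip_op(block, grad_name)` is the extension point and that a `clip` operator with `min`/`max` attributes is available.

```python
# Plausible concrete error-clip attribute; append_clip_op and the "clip"
# op attributes are assumptions, since the diff hides the class bodies.
class ErrorClipByValue(BaseErrorClipAttr):
    def __init__(self, max, min=None):
        max = float(max)
        # Default to a symmetric range when only the upper bound is given.
        self.min = -max if min is None else float(min)
        self.max = max

    def append_clip_op(self, block, grad_name):
        # Clip the error (gradient) tensor in place.
        block.append_op(
            type="clip",
            inputs={"X": grad_name},
            outputs={"Out": grad_name},
            attrs={"min": self.min,
                   "max": self.max})
```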
@@ -151,6 +151,7 @@ class Variable(object):
                  stop_gradient=False,
                  **kwargs):
         self.block = block
+        self.error_clip = error_clip
         if name is None:
             name = Variable._unique_var_name_()
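In `framework.py`'s `Variable.__init__`, every variable now carries an `error_clip` attribute; the constructor parameter it is read from sits outside the shown context lines, so its exact name and default are not visible here. A hypothetical usage, where `ErrorClipByValue` is the sketch above rather than anything this diff defines, and the fluid-era import path is an assumption:

```python
# Hypothetical usage: tag a variable so the backward callback can find
# and apply its error clip. ErrorClipByValue is the illustrative class
# sketched above; the import path assumes the v2 fluid layout.
import paddle.v2.fluid as fluid

image = fluid.layers.data(name="image", shape=[784])
hidden = fluid.layers.fc(input=image, size=128, act="relu")
hidden.error_clip = ErrorClipByValue(max=5.0)
```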
@@ -6,7 +6,7 @@ from framework import unique_name, program_guard
 from initializer import Constant
 from layer_helper import LayerHelper
 from regularizer import append_regularization_ops
-from clip import append_gradient_clip_ops
+from clip import append_gradient_clip_ops, error_clip_callback

 __all__ = ['SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad']
@@ -197,7 +197,8 @@ class Optimizer(object):
         This method combines interface `append_backward()` and
         `create_optimization_pass()` into one.
         """
-        params_grads = append_backward(loss, parameter_list, no_grad_set)
+        params_grads = append_backward(loss, parameter_list, no_grad_set,
+                                       error_clip_callback)
         params_grads = append_gradient_clip_ops(params_grads)
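With this change in `optimizer.py`, `Optimizer.minimize()` passes `error_clip_callback` into `append_backward`, so error clipping runs automatically as the backward pass is built; no extra user call is needed. The callback's body is not shown in the diff; the sketch below reconstructs what it plausibly does from the pieces this commit wires together (`grad_to_var` as the callback context, the new `Variable.error_clip` attribute, and `BaseErrorClipAttr`) and should be read as an assumption, not the committed implementation.

```python
# Assumed shape of error_clip_callback: find the gradients produced by
# the op just appended, map them back to forward variables via the
# grad_to_var context, and let each variable's error_clip attribute
# append its clip op.
def error_clip_callback(block, context):
    grad_to_var = context  # assumed: context is the grad_to_var map
    op_desc = block.desc.op(block.desc.op_size() - 1)
    for grad_name in op_desc.output_arg_names():
        if grad_name not in grad_to_var:
            continue
        fwd_var = block.var(grad_to_var[grad_name])
        error_clip = getattr(fwd_var, "error_clip", None)
        if error_clip is not None:
            error_clip.append_clip_op(block, grad_name)
```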