From 1dac173b518faeb8f31c321a61fa287b8de4246e Mon Sep 17 00:00:00 2001
From: fengjiayi
Date: Wed, 17 Jan 2018 20:15:03 +0800
Subject: [PATCH] add API for clip_by_global_norm

---
 python/paddle/v2/fluid/clip.py | 23 ++++++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/python/paddle/v2/fluid/clip.py b/python/paddle/v2/fluid/clip.py
index f7917fc142..d1e6987e01 100644
--- a/python/paddle/v2/fluid/clip.py
+++ b/python/paddle/v2/fluid/clip.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import functools
 import layers
-from framework import Variable
+import framework
 from . import core
 
 __all__ = [
@@ -128,8 +128,8 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
 
     @classmethod
     def check_init(cls):
-        if not (isinstance(cls.global_norm_var, Variable) and
-                isinstance(cls.clip_norm_var, Variable)):
+        if not (isinstance(cls.global_norm_var, framework.Variable) and
+                isinstance(cls.clip_norm_var, framework.Variable)):
             raise ValueError(
                 "Class 'GradientClipByGlobalNorm' has not been properly initialized. \
                 Please call GradientClipByGlobalNorm.init() first.")
@@ -158,6 +158,23 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
         return param, new_grad
 
 
+def gradient_clip_by_global_norm(clip_norm, param_list=None, program=None):
+    if program is None:
+        program = framework.default_main_program()
+    if param_list is None:
+        param_list = program.block(0).all_parameters()
+    if all(isinstance(elem, basestring) for elem in param_list):
+        param_list = [program.block(0).var(elem) for elem in param_list]
+    if not all(isinstance(elem, framework.Parameter) for elem in param_list):
+        raise TypeError(
+            "'param_list' should be a list of Parameter or basestring(parameter's name)."
+        )
+
+    GradientClipByGlobalNorm.init(clip_norm)
+    for param in param_list:
+        param.gradient_clip_attr = GradientClipByGlobalNorm()
+
+
 def append_gradient_clip_ops(param_grad):
     context = dict()
     create_op_callbacks = []
--
GitLab
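
Usage sketch (not part of the patch): the snippet below shows how the new
gradient_clip_by_global_norm API introduced above might be called from user
code. The surrounding network and optimizer setup are illustrative
assumptions based on the Paddle v2 Fluid Python 2 API of that period, and
whether the clip ops are actually appended during the backward pass depends
on how append_gradient_clip_ops is wired in, which is outside this patch.

# Hypothetical usage sketch; only gradient_clip_by_global_norm comes from
# this patch, the rest is an assumed minimal regression network.
import paddle.v2.fluid as fluid

x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(x=cost)

# Mark every parameter in the default main program for clipping by a
# global norm of 1.0 (the API added by this patch). Passing param_list
# with parameter names or Parameter objects would restrict the clipping
# to those parameters only.
fluid.clip.gradient_clip_by_global_norm(clip_norm=1.0)

sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01)
sgd_optimizer.minimize(avg_cost)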