From 19c554f9e4ef5c96e47f65efd44e2524417e38d7 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 19 Jan 2018 19:19:35 +0800 Subject: [PATCH] update --- python/paddle/v2/fluid/clip.py | 82 +++++++++---------- .../v2/fluid/tests/test_gradient_clip.py | 44 +++++----- 2 files changed, 59 insertions(+), 67 deletions(-) diff --git a/python/paddle/v2/fluid/clip.py b/python/paddle/v2/fluid/clip.py index d97cd9ecc9..fb0907c9f4 100644 --- a/python/paddle/v2/fluid/clip.py +++ b/python/paddle/v2/fluid/clip.py @@ -112,58 +112,52 @@ class GradientClipByNorm(BaseGradientClipAttr): class GradientClipByGlobalNorm(BaseGradientClipAttr): - global_norm_var = None - local_norm_var = None - clip_norm_var = None - scale_var = None - - @classmethod - def init(cls, clip_norm): - if not (isinstance(clip_norm, int) or isinstance(clip_norm, float)): - raise TypeError("The 'clip_norm' must be a value of int or float") - - cls.global_norm_var = layers.fill_constant( - shape=[1], dtype="float32", value=0.0) - cls.local_norm_var = layers.create_tensor(dtype="float32") - cls.clip_norm_var = layers.fill_constant( - shape=[1], dtype="float32", value=clip_norm) - - @classmethod - def check_init(cls): - if not (isinstance(cls.global_norm_var, framework.Variable) and - isinstance(cls.local_norm_var, framework.Variable) and - isinstance(cls.clip_norm_var, framework.Variable)): - raise ValueError( - "Class 'GradientClipByGlobalNorm' has not been properly initialized. \ - Please call GradientClipByGlobalNorm.init() first.") + def __init__(self, clip_norm, group_name="default_group"): + if not isinstance(group_name, basestring): + raise TypeError("'group_name' must be a basestring.") + + self.clip_norm = clip_norm + self.group_name = group_name def process_context(self, context, param, grad): - cls = self.__class__ - cls.check_init() + if self.group_name not in context: + context[self.group_name] = [] + context[self.group_name + "_clip_value"] = self.clip_norm + context[self.group_name + "_clip"] = layers.fill_constant( + shape=[1], dtype="float32", value=self.clip_norm) + else: + if not self.clip_norm == context[self.group_name + "_clip_value"]: + raise ValueError( + "All parameters' 'clip_norm' of a same group should be the same" + ) - cls.local_norm_var = layers.reduce_sum( - input=layers.pow(x=grad, factor=2.0)) - layers.sums( - input=[cls.local_norm_var, cls.global_norm_var], - out=[cls.global_norm_var]) + local_norm_var = layers.reduce_sum(input=layers.pow(x=grad, factor=2.0)) + context[self.group_name].append(local_norm_var) - def create_operators(self, param, grad): - cls = self.__class__ - cls.check_init() + self.context = context - if cls.scale_var is None: - layers.sqrt(x=cls.global_norm_var, out=cls.global_norm_var) - cls.scale_var = layers.elementwise_div( - x=cls.clip_norm_var, + def create_operators(self, param, grad): + group_scale_name = self.group_name + "_scale" + if group_scale_name not in self.context: + group_norm_var = layers.sums(input=self.context[self.group_name]) + layers.sqrt(x=group_norm_var, out=group_norm_var) + clip_var = self.context[self.group_name + "_clip"] + group_scale_var = layers.elementwise_div( + x=clip_var, y=layers.elementwise_max( - x=cls.clip_norm_var, y=cls.global_norm_var)) - assert cls.scale_var.shape == (1L, ) + x=clip_var, y=group_norm_var)) + assert group_scale_var.shape == (1L, ) + self.context[group_scale_name] = group_scale_var - new_grad = layers.elementwise_mul(x=grad, y=cls.scale_var) + new_grad = layers.elementwise_mul( + x=grad, y=self.context[group_scale_name]) return param, new_grad -def gradient_clip_by_global_norm(clip_norm, param_list=None, program=None): +def gradient_clip_by_global_norm(clip_norm, + param_list=None, + group_name="default_group", + program=None): if program is None: program = framework.default_main_program() if param_list is None: @@ -175,9 +169,9 @@ def gradient_clip_by_global_norm(clip_norm, param_list=None, program=None): "'param_list' should be a list of Parameter or basestring(parameter's name)." ) - GradientClipByGlobalNorm.init(clip_norm) for param in param_list: - param.gradient_clip_attr = GradientClipByGlobalNorm() + param.gradient_clip_attr = GradientClipByGlobalNorm(clip_norm, + group_name) def append_gradient_clip_ops(param_grad): diff --git a/python/paddle/v2/fluid/tests/test_gradient_clip.py b/python/paddle/v2/fluid/tests/test_gradient_clip.py index 4fb7f0b2cb..75c5fd9892 100644 --- a/python/paddle/v2/fluid/tests/test_gradient_clip.py +++ b/python/paddle/v2/fluid/tests/test_gradient_clip.py @@ -15,21 +15,10 @@ import numpy as np import paddle.v2 as paddle import paddle.v2.fluid as fluid - -def _get_global_param_norm_(params_grads): - res = fluid.layers.fill_constant(shape=[1], dtype="float32", value=0.0) - for _, grad in params_grads: - norm_var = fluid.layers.reduce_sum( - input=fluid.layers.pow(x=grad, factor=2.0)) - fluid.layers.sums(input=[norm_var, res], out=[res]) - fluid.layers.sqrt(x=res, out=res) - return res - - BATCH_SIZE = 128 -CLIP = 0.5 -prog = fluid.framework.Program() +CLIP = 1 +prog = fluid.framework.Program() with fluid.program_guard(main_program=prog): image = fluid.layers.data(name='x', shape=[784], dtype='float32') @@ -49,13 +38,12 @@ avg_cost_clip = prog_clip.block(0).var(avg_cost.name) p_g = fluid.backward.append_backward(loss=avg_cost) p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip) -with fluid.program_guard(main_program=prog): - gloabl_norm = _get_global_param_norm_(p_g) - with fluid.program_guard(main_program=prog_clip): fluid.clip.gradient_clip_by_global_norm(clip_norm=CLIP) p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip) - gloabl_norm_clip = _get_global_param_norm_(p_g_clip) + +grad_list = [elem[1] for elem in p_g] +grad_clip_list = [elem[1] for elem in p_g_clip] train_reader = paddle.batch( paddle.reader.shuffle( @@ -72,11 +60,21 @@ for data in train_reader(): count += 1 if count > 5: break - out, = exe.run(prog, feed=feeder.feed(data), fetch_list=[gloabl_norm]) - out_clip, = exe.run(prog_clip, - feed=feeder.feed(data), - fetch_list=[gloabl_norm_clip]) - - if not np.allclose(out_clip, np.minimum(out, np.array([CLIP]))): + out = exe.run(prog, feed=feeder.feed(data), fetch_list=grad_list) + out_clip = exe.run(prog_clip, + feed=feeder.feed(data), + fetch_list=grad_clip_list) + global_norm = 0 + for v in out[1:]: + global_norm += np.sum(np.power(v, 2)) + global_norm = np.sqrt(global_norm) + + global_norm_clip = 0 + for v in out_clip[1:]: + global_norm_clip += np.sum(np.power(v, 2)) + global_norm_clip = np.sqrt(global_norm_clip) + + if not np.isclose( + a=global_norm_clip, b=np.minimum(global_norm, CLIP), rtol=5e-3): exit(1) exit(0) -- GitLab