diff --git a/python/paddle/v2/fluid/clip.py b/python/paddle/v2/fluid/clip.py index 5241f4843c0df0314eba6168f8f36335c3f19d0a..386df9823de9119287abf87569eab0b283ecc802 100644 --- a/python/paddle/v2/fluid/clip.py +++ b/python/paddle/v2/fluid/clip.py @@ -14,6 +14,7 @@ import functools import layers +import framework from . import core __all__ = [ @@ -66,7 +67,7 @@ def error_clip_callback(block, context): class BaseGradientClipAttr(object): - def process_context(self, context, p_g): + def process_context(self, context, param, grad): raise NotImplementedError() def create_operators(self, param, grad): @@ -74,7 +75,7 @@ class BaseGradientClipAttr(object): class NullGradientClipAttr(BaseGradientClipAttr): - def process_context(self, context, p_g): + def process_context(self, context, param, grad): pass def create_operators(self, param, grad): @@ -91,7 +92,7 @@ class GradientClipByValue(BaseGradientClipAttr): self.max = max self.min = min - def process_context(self, context, p_g): + def process_context(self, context, param, grad): pass def create_operators(self, param, grad): @@ -99,19 +100,93 @@ class GradientClipByValue(BaseGradientClipAttr): return param, new_grad +class GradientClipByNorm(BaseGradientClipAttr): + def __init__(self, clip_norm): + self.clip_norm = clip_norm + + def process_context(self, context, param, grad): + pass + + def create_operators(self, param, grad): + new_grad = layers.clip_by_norm(x=grad, max_norm=self.clip_norm) + return param, new_grad + + +class GradientClipByGlobalNorm(BaseGradientClipAttr): + def __init__(self, clip_norm, group_name="default_group"): + if not isinstance(group_name, basestring): + raise TypeError("'group_name' must be a basestring.") + + self.clip_norm = clip_norm + self.group_name = group_name + + def process_context(self, context, param, grad): + if self.group_name not in context: + context[self.group_name] = [] + context[self.group_name + "_clip_value"] = self.clip_norm + context[self.group_name + "_clip"] = layers.fill_constant( + shape=[1], dtype="float32", value=self.clip_norm) + else: + if not self.clip_norm == context[self.group_name + "_clip_value"]: + raise ValueError( + "All parameters' 'clip_norm' of a same group should be the same" + ) + + local_norm_var = layers.reduce_sum(input=layers.pow(x=grad, factor=2.0)) + context[self.group_name].append(local_norm_var) + + self.context = context + + def create_operators(self, param, grad): + group_scale_name = self.group_name + "_scale" + if group_scale_name not in self.context: + group_norm_var = layers.sums(input=self.context[self.group_name]) + layers.sqrt(x=group_norm_var, out=group_norm_var) + clip_var = self.context[self.group_name + "_clip"] + group_scale_var = layers.elementwise_div( + x=clip_var, + y=layers.elementwise_max( + x=clip_var, y=group_norm_var)) + assert group_scale_var.shape == (1L, ) + self.context[group_scale_name] = group_scale_var + + new_grad = layers.elementwise_mul( + x=grad, y=self.context[group_scale_name]) + return param, new_grad + + +def gradient_clip_by_global_norm(clip_norm, + param_list=None, + group_name="default_group", + program=None): + if program is None: + program = framework.default_main_program() + if param_list is None: + param_list = program.block(0).all_parameters() + if all(isinstance(elem, basestring) for elem in param_list): + param_list = [program.block(0).var(elem) for elem in param_list] + if not all(isinstance(elem, framework.Parameter) for elem in param_list): + raise TypeError( + "'param_list' should be a list of Parameter or basestring(parameter's name)." + ) + + for param in param_list: + param.gradient_clip_attr = GradientClipByGlobalNorm(clip_norm, + group_name) + + def append_gradient_clip_ops(param_grad): context = dict() create_op_callbacks = [] for p, g in param_grad: - clip_attr = getattr(p, 'clip_attr', NullGradientClipAttr()) + clip_attr = getattr(p, 'gradient_clip_attr', NullGradientClipAttr()) if clip_attr is None: clip_attr = NullGradientClipAttr() if not isinstance(clip_attr, BaseGradientClipAttr): raise TypeError( - "clip attribute should be an instance of BaseGradientClippingAttr" - ) + "clip attribute should be an instance of BaseGradientClipAttr") - clip_attr.process_context(context=context, p_g=param_grad) + clip_attr.process_context(context=context, param=p, grad=g) create_op_callbacks.append( functools.partial( clip_attr.create_operators, param=p, grad=g)) diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py index f87666545893884b2664b52be2d1a11b5aa4a244..4d8343e7de9526d527ebe93f334b59108d5ace8e 100644 --- a/python/paddle/v2/fluid/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -780,7 +780,7 @@ class Block(object): trainable=p.trainable, optimize_attr=p.optimize_attr, regularizer=p.regularizer, - clip_attr=p.clip_attr, + gradient_clip_attr=p.gradient_clip_attr, error_clip=p.error_clip, name=v.name) self.vars[new_p.name] = new_p @@ -948,7 +948,7 @@ class Parameter(Variable): self.regularizer = kwargs.get('regularizer', None) - self.clip_attr = kwargs.get('clip_attr', None) + self.gradient_clip_attr = kwargs.get('gradient_clip_attr', None) # program is a global instance. diff --git a/python/paddle/v2/fluid/layers/ops.py b/python/paddle/v2/fluid/layers/ops.py index 7716052a5cae0d8883dd9996f34d39e97f4399ea..d29607616266973b3e59dbae660724d0c3874def 100644 --- a/python/paddle/v2/fluid/layers/ops.py +++ b/python/paddle/v2/fluid/layers/ops.py @@ -46,20 +46,10 @@ __activations__ = [ ] __all__ = [ - 'mean', - 'mul', - 'reshape', - 'scale', - 'transpose', - 'sigmoid_cross_entropy_with_logits', - 'elementwise_add', - 'elementwise_div', - 'elementwise_sub', - 'elementwise_mul', - 'elementwise_max', - 'elementwise_min', - 'clip', - 'sequence_softmax', + 'mean', 'mul', 'reshape', 'scale', 'transpose', + 'sigmoid_cross_entropy_with_logits', 'elementwise_add', 'elementwise_div', + 'elementwise_sub', 'elementwise_mul', 'elementwise_max', 'elementwise_min', + 'clip', 'clip_by_norm', 'sequence_softmax' ] + __activations__ for _OP in set(__all__): diff --git a/python/paddle/v2/fluid/param_attr.py b/python/paddle/v2/fluid/param_attr.py index 26e9111f6f31019ab14780cc4bea01e617561fb7..dcca8b6c547d10864ff4cd0af1c217d89e3b522f 100644 --- a/python/paddle/v2/fluid/param_attr.py +++ b/python/paddle/v2/fluid/param_attr.py @@ -25,13 +25,13 @@ class ParamAttr(object): learning_rate=1.0, regularizer=None, trainable=True, - clip=None): + gradient_clip=None): self.name = name self.initializer = initializer self.learning_rate = learning_rate self.regularizer = regularizer self.trainable = trainable - self.clip = clip + self.gradient_clip = gradient_clip def set_default_initializer(self, initializer): if initializer is None: @@ -77,7 +77,7 @@ class ParamAttr(object): }, 'regularizer': self.regularizer, 'trainable': self.trainable, - 'clip_attr': self.clip + 'gradient_clip_attr': self.gradient_clip } if with_initializer: kwargs['initializer'] = self.initializer diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py index be22e97054a16d69cf9e2d1e88629497e519c778..8776a65bf804e93dfeb295ecca34fac0840b0a90 100644 --- a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py @@ -27,7 +27,7 @@ hidden1 = fluid.layers.fc(input=image, act='relu', param_attr=fluid.ParamAttr( regularizer=regularizer, - clip=fluid.clip.ClipByValue(10))) + gradient_clip=fluid.clip.ClipByValue(10))) hidden2 = fluid.layers.fc(input=hidden1, size=64, diff --git a/python/paddle/v2/fluid/tests/test_clip.py b/python/paddle/v2/fluid/tests/test_error_clip.py similarity index 100% rename from python/paddle/v2/fluid/tests/test_clip.py rename to python/paddle/v2/fluid/tests/test_error_clip.py diff --git a/python/paddle/v2/fluid/tests/test_gradient_clip.py b/python/paddle/v2/fluid/tests/test_gradient_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..4e6e6a1ef6961d8f087dfc1ac5a4c4a8ad90032e --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_gradient_clip.py @@ -0,0 +1,81 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle.v2 as paddle +import paddle.v2.fluid as fluid + +BATCH_SIZE = 128 +CLIP = 1 + +prog = fluid.framework.Program() +with fluid.program_guard(main_program=prog): + image = fluid.layers.data(name='x', shape=[784], dtype='float32') + + hidden1 = fluid.layers.fc(input=image, size=128, act='relu') + hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu') + predict = fluid.layers.fc(input=hidden2, size=10, act='softmax') + + label = fluid.layers.data(name='y', shape=[1], dtype='int64') + + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + +prog_clip = prog.clone() + +avg_cost_clip = prog_clip.block(0).var(avg_cost.name) + +p_g = fluid.backward.append_backward(loss=avg_cost) +p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip) + +with fluid.program_guard(main_program=prog_clip): + fluid.clip.gradient_clip_by_global_norm(clip_norm=CLIP) + p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip) + +grad_list = [elem[1] for elem in p_g] +grad_clip_list = [elem[1] for elem in p_g_clip] + +train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.mnist.train(), buf_size=8192), + batch_size=BATCH_SIZE) + +place = fluid.CPUPlace() +exe = fluid.Executor(place) +feeder = fluid.DataFeeder(feed_list=[image, label], place=place) +exe.run(fluid.default_startup_program()) + +count = 0 +for data in train_reader(): + count += 1 + if count > 5: + break + out = exe.run(prog, feed=feeder.feed(data), fetch_list=grad_list) + out_clip = exe.run(prog_clip, + feed=feeder.feed(data), + fetch_list=grad_clip_list) + global_norm = 0 + for v in out[1:]: + global_norm += np.sum(np.power(v, 2)) + global_norm = np.sqrt(global_norm) + + global_norm_clip = 0 + for v in out_clip[1:]: + global_norm_clip += np.sum(np.power(v, 2)) + global_norm_clip = np.sqrt(global_norm_clip) + + if not np.isclose( + a=global_norm_clip, b=np.minimum(global_norm, CLIP), rtol=5e-3): + exit(1) +exit(0)