diff --git a/python/paddle/v2/fluid/__init__.py b/python/paddle/v2/fluid/__init__.py index 59986c9f0ca8e4b793463db0e8c5da0489654ee9..9b3792ee9e3e4c6f319b3e2b13c4aa3a05cce8be 100644 --- a/python/paddle/v2/fluid/__init__.py +++ b/python/paddle/v2/fluid/__init__.py @@ -16,12 +16,13 @@ import regularizer from param_attr import ParamAttr from data_feeder import DataFeeder from core import LoDTensor, CPUPlace, GPUPlace +import clip Tensor = LoDTensor __all__ = framework.__all__ + executor.__all__ + [ 'io', 'initializer', 'layers', 'nets', 'optimizer', 'backward', 'regularizer', 'LoDTensor', 'CPUPlace', 'GPUPlace', 'Tensor', 'ParamAttr' - 'DataFeeder' + 'DataFeeder', 'clip' ] diff --git a/python/paddle/v2/fluid/clip.py b/python/paddle/v2/fluid/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..d7ec2fbe13fe6d9158345099b8668afc5c7d4571 --- /dev/null +++ b/python/paddle/v2/fluid/clip.py @@ -0,0 +1,61 @@ +import functools +import layers + +__all__ = ['GradientClipByValue', 'append_gradient_clip_ops'] + + +class BaseGradientClipAttr(object): + def process_context(self, context, p_g): + raise NotImplementedError() + + def create_operators(self, param, grad): + raise NotImplementedError() + + +class NullGradientClipAttr(BaseGradientClipAttr): + def process_context(self, context, p_g): + pass + + def create_operators(self, param, grad): + return param, grad + + +class GradientClipByValue(BaseGradientClipAttr): + def __init__(self, max, min=None): + max = float(max) + if min is None: + min = -max + else: + min = float(min) + self.max = max + self.min = min + + def process_context(self, context, p_g): + pass + + def create_operators(self, param, grad): + new_grad = layers.clip(x=grad, min=self.min, max=self.max) + return param, new_grad + + +def append_gradient_clip_ops(param_grad): + context = dict() + create_op_callbacks = [] + for p, g in param_grad: + clip_attr = getattr(p, 'clip_attr', NullGradientClipAttr()) + if clip_attr is None: + clip_attr = NullGradientClipAttr() + if not isinstance(clip_attr, BaseGradientClipAttr): + raise TypeError( + "clip attribute should be an instance of BaseGradientClippingAttr" + ) + + clip_attr.process_context(context=context, p_g=param_grad) + create_op_callbacks.append( + functools.partial( + clip_attr.create_operators, param=p, grad=g)) + + return [each_callback() for each_callback in create_op_callbacks] + + +ClipByValue = GradientClipByValue diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py index bf0cd275b62ae2c4d7312592b8a730291c59a071..973672e6e469c7619ea5b4a166cce120655e4c6e 100644 --- a/python/paddle/v2/fluid/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -704,6 +704,7 @@ class Block(object): trainable=p.trainable, optimize_attr=p.optimize_attr, regularizer=p.regularizer, + clip_attr=p.clip_attr, name=v.name) self.vars[new_p.name] = new_p @@ -866,6 +867,8 @@ class Parameter(Variable): self.regularizer = kwargs.get('regularizer', None) + self.clip_attr = kwargs.get('clip_attr', None) + # program is a global instance. _main_program_ = Program() diff --git a/python/paddle/v2/fluid/optimizer.py b/python/paddle/v2/fluid/optimizer.py index 9f03eeea83e6d212da5fbe3d090d82028fa378ac..84fcbcdc2f2868a84bad5e145a934e33485b1fef 100644 --- a/python/paddle/v2/fluid/optimizer.py +++ b/python/paddle/v2/fluid/optimizer.py @@ -6,6 +6,7 @@ from framework import unique_name, program_guard from initializer import Constant from layer_helper import LayerHelper from regularizer import append_regularization_ops +from clip import append_gradient_clip_ops __all__ = ['SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad'] @@ -197,9 +198,13 @@ class Optimizer(object): `create_optimization_pass()` into one. """ params_grads = append_backward_ops(loss, parameter_list, no_grad_set) + + params_grads = append_gradient_clip_ops(params_grads) + # Add regularization if any params_grads = append_regularization_ops(params_grads, self.regularization) + optimize_ops = self.create_optimization_pass(params_grads, loss, startup_program) return optimize_ops diff --git a/python/paddle/v2/fluid/param_attr.py b/python/paddle/v2/fluid/param_attr.py index 7952a5ea51c00f72664443fb26faa455e89da7be..f6f320c788e7e08d44df8aff5ad3792b237e103a 100644 --- a/python/paddle/v2/fluid/param_attr.py +++ b/python/paddle/v2/fluid/param_attr.py @@ -1,6 +1,8 @@ from initializer import Initializer, Xavier, Constant from regularizer import WeightDecayRegularizer +__all__ = ['ParamAttr'] + class ParamAttr(object): def __init__(self, @@ -8,12 +10,14 @@ class ParamAttr(object): initializer=None, learning_rate=1.0, regularizer=None, - trainable=True): + trainable=True, + clip=None): self.name = name self.initializer = initializer self.learning_rate = learning_rate self.regularizer = regularizer self.trainable = trainable + self.clip = clip def set_default_initializer(self, initializer): if initializer is None: @@ -56,7 +60,8 @@ class ParamAttr(object): 'name': self.name, 'learning_rate': self.learning_rate, 'regularizer': self.regularizer, - 'trainable': self.trainable + 'trainable': self.trainable, + 'clip_attr': self.clip } if with_initializer: kwargs['initializer'] = self.initializer diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py index d77f19660ebcd470837e8b4e63509683de4a7a82..fc073f6be8563a363c0f98b9235ae267fa68562d 100644 --- a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py @@ -11,7 +11,9 @@ regularizer = fluid.regularizer.L2Decay(0.0005 * BATCH_SIZE) hidden1 = fluid.layers.fc(input=image, size=128, act='relu', - param_attr=regularizer) + param_attr=fluid.ParamAttr( + regularizer=regularizer, + clip=fluid.clip.ClipByValue(10))) hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu',