#  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Y
Yu Yang 已提交
14 15
import functools
import layers
F
fengjiayi 已提交
16
import framework
F
fengjiayi 已提交
17
from . import core
Y
Yu Yang 已提交
18

F
fengjiayi 已提交
19
# Public API of this module. The norm-based gradient-clip classes and the
# global-norm helper are listed after the original exports so that a
# star-import of this module exposes every public clipping entry point.
__all__ = [
    'GradientClipByValue',
    'ErrorClipByValue',
    'append_gradient_clip_ops',
    'error_clip_callback',
    'GradientClipByNorm',
    'GradientClipByGlobalNorm',
    'gradient_clip_by_global_norm',
]
Y
Yu Yang 已提交
25 26


F
fengjiayi 已提交
27
class BaseErrorClipAttr(object):
    """Interface for error-clip strategies attached to forward variables.

    Subclasses implement `append_clip_op`, which appends to `block` the
    operator(s) that clip the gradient variable named `grad_name`.
    """

    def append_clip_op(self, block, grad_name):
        """Append clipping op(s) for `grad_name` to `block`; abstract here."""
        raise NotImplementedError


class ErrorClipByValue(BaseErrorClipAttr):
    """Error-clip strategy that clips a gradient into [min, max].

    Args:
        max: upper bound; converted to float.
        min: lower bound; converted to float. When omitted it defaults to
            the negation of `max`.
    """

    def __init__(self, max, min=None):
        upper = float(max)
        lower = -upper if min is None else float(min)
        self.max = upper
        self.min = lower

    def append_clip_op(self, block, grad_name):
        """Append a `clip` op on `grad_name` to `block`."""
        desc = block.desc.append_op()
        desc.set_type("clip")
        # Input and output name the same variable, so the clipped value
        # replaces the gradient in place.
        desc.set_input("X", [grad_name])
        desc.set_output("Out", [grad_name])
        desc.set_attr("min", self.min)
        desc.set_attr("max", self.max)
F
fengjiayi 已提交
49 50 51 52 53 54 55 56 57 58


def error_clip_callback(block, context):
    """Apply error clipping to gradients produced by the last op in `block`.

    Presumably registered as a callback while the backward pass appends ops
    (TODO confirm against the caller). Each output of the most recently
    appended op that is a key of `context` is treated as a gradient. Its
    forward variable is looked up, and if that variable carries an
    `error_clip` attribute, the clip op is appended to `block`.

    Args:
        block: block whose most recently appended op is inspected.
        context: mapping from gradient variable names to forward variable
            names (the grad_to_var map).

    Raises:
        TypeError: if a forward variable's `error_clip` is neither None nor
            a BaseErrorClipAttr instance.
    """
    # the context is a grad_to_var map
    grad_to_var = context
    op_desc = block.desc.op(block.desc.op_size() - 1)
    # `in` works on both Python 2 and 3, unlike the removed dict.has_key().
    grad_names = [
        n for n in op_desc.output_arg_names() if n in grad_to_var
    ]
    for grad_n in grad_names:
        fwd_var = block.var_recursive(grad_to_var[grad_n])
        error_clip = getattr(fwd_var, "error_clip", None)
        if not (error_clip is None or isinstance(error_clip,
                                                 BaseErrorClipAttr)):
            raise TypeError(
                "Variable's error_clip should be an instance of BaseErrorClipAttr or None."
            )
        if error_clip is not None:
            error_clip.append_clip_op(block, grad_n)
F
fengjiayi 已提交
66 67


Y
Yu Yang 已提交
68
class BaseGradientClipAttr(object):
    """Interface for per-parameter gradient-clip strategies.

    Clipping happens in two phases: `process_context` is called for every
    (param, grad) pair first, then `create_operators` is called for each
    pair and returns the (param, grad) pair to use after clipping.
    """

    def process_context(self, context, param, grad):
        """Phase 1: inspect (param, grad); may record state in `context`."""
        raise NotImplementedError

    def create_operators(self, param, grad):
        """Phase 2: build clip ops and return the resulting (param, grad)."""
        raise NotImplementedError


class NullGradientClipAttr(BaseGradientClipAttr):
    """No-op clip strategy: gradients are passed through unchanged."""

    def process_context(self, context, param, grad):
        """Nothing to record for a pass-through strategy."""
        return None

    def create_operators(self, param, grad):
        """Return (param, grad) untouched; no operators are created."""
        return (param, grad)


class GradientClipByValue(BaseGradientClipAttr):
    """Clip each parameter's gradient into [min, max] via `layers.clip`.

    Args:
        max: upper bound; converted to float.
        min: lower bound; converted to float. When omitted it defaults to
            the negation of `max`.
    """

    def __init__(self, max, min=None):
        upper = float(max)
        if min is not None:
            lower = float(min)
        else:
            lower = -upper
        self.max, self.min = upper, lower

    def process_context(self, context, param, grad):
        """Value clipping needs no cross-gradient state; nothing to do."""
        pass

    def create_operators(self, param, grad):
        """Return `param` paired with its gradient clipped to [min, max]."""
        clipped = layers.clip(x=grad, min=self.min, max=self.max)
        return param, clipped


F
fengjiayi 已提交
102 103 104 105
class GradientClipByNorm(BaseGradientClipAttr):
    """Clip each gradient independently by its norm via `layers.clip_by_norm`.

    Args:
        clip_norm: value passed as `max_norm` to `layers.clip_by_norm`.
            It is converted to float, matching how the value-based clip
            classes normalize their bounds before handing them to an
            operator attribute. A non-numeric value therefore fails here
            rather than when the op is built.
    """

    def __init__(self, clip_norm):
        self.clip_norm = float(clip_norm)

    def process_context(self, context, param, grad):
        """Per-gradient norm clipping needs no shared state; nothing to do."""
        pass

    def create_operators(self, param, grad):
        """Return `param` paired with its norm-clipped gradient."""
        new_grad = layers.clip_by_norm(x=grad, max_norm=self.clip_norm)
        return param, new_grad


F
fengjiayi 已提交
114 115
class GradientClipByGlobalNorm(BaseGradientClipAttr):
    """Scale all gradients by one factor derived from their global norm.

    State lives on the class, so every instance (typically one per
    parameter) contributes to and reads the same variables:

    * `global_norm_var` accumulates the sum of squares of every gradient.
      Once the scale is built, it is overwritten with its square root.
    * `local_norm_var` holds the most recent per-gradient sum of squares.
    * `clip_norm_var` holds `clip_norm` as a constant.
    * `scale_var` is clip_norm / max(clip_norm, global_norm) and is built
      once per `init()`.

    Call `init(clip_norm)` before use.
    """
    global_norm_var = None
    local_norm_var = None
    clip_norm_var = None
    scale_var = None

    @classmethod
    def init(cls, clip_norm):
        """Create the shared variables for a new clipping pass.

        Raises:
            TypeError: if `clip_norm` is not an int or float.
        """
        if not isinstance(clip_norm, (int, float)):
            raise TypeError("The 'clip_norm' must be a value of int or float")

        cls.global_norm_var = layers.fill_constant(
            shape=[1], dtype="float32", value=0.0)
        cls.local_norm_var = layers.create_tensor(dtype="float32")
        cls.clip_norm_var = layers.fill_constant(
            shape=[1], dtype="float32", value=clip_norm)
        # Forget any scale built after a previous init(). Otherwise
        # create_operators would reuse a variable computed from the old
        # global/clip norm and never take the sqrt of the new accumulator.
        cls.scale_var = None

    @classmethod
    def check_init(cls):
        """Raise ValueError unless `init()` has created the shared variables."""
        if not (isinstance(cls.global_norm_var, framework.Variable) and
                isinstance(cls.local_norm_var, framework.Variable) and
                isinstance(cls.clip_norm_var, framework.Variable)):
            raise ValueError(
                "Class 'GradientClipByGlobalNorm' has not been properly "
                "initialized. Please call GradientClipByGlobalNorm.init() "
                "first.")

    def process_context(self, context, param, grad):
        """Add the sum of squares of `grad` into the shared global norm."""
        cls = self.__class__
        cls.check_init()

        cls.local_norm_var = layers.reduce_sum(
            input=layers.pow(x=grad, factor=2.0))
        layers.sums(
            input=[cls.local_norm_var, cls.global_norm_var],
            out=[cls.global_norm_var])

    def create_operators(self, param, grad):
        """Return `param` paired with `grad` multiplied by the shared scale.

        The first call after `init()` builds the scale: it writes the sqrt
        of the accumulated sum of squares back into `global_norm_var`, then
        computes clip_norm / max(clip_norm, global_norm). Later calls reuse
        that scale.
        """
        cls = self.__class__
        cls.check_init()

        if cls.scale_var is None:
            layers.sqrt(x=cls.global_norm_var, out=cls.global_norm_var)
            cls.scale_var = layers.elementwise_div(
                x=cls.clip_norm_var,
                y=layers.elementwise_max(
                    x=cls.clip_norm_var, y=cls.global_norm_var))
            # `(1, )` equals the Python-2 long form `(1L, )` and is also
            # valid Python 3 syntax.
            assert cls.scale_var.shape == (1, )

        new_grad = layers.elementwise_mul(x=grad, y=cls.scale_var)
        return param, new_grad
F
fengjiayi 已提交
164 165


F
fengjiayi 已提交
166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182
def gradient_clip_by_global_norm(clip_norm, param_list=None, program=None):
    """Attach global-norm gradient clipping to a set of parameters.

    Initializes GradientClipByGlobalNorm's shared state with `clip_norm`,
    then sets a new GradientClipByGlobalNorm instance as the
    `gradient_clip_attr` of every selected parameter.

    Args:
        clip_norm: int or float threshold for the global norm.
        param_list: iterable of Parameter objects and/or parameter names.
            Names are looked up in block 0 of `program`. Defaults to every
            parameter in block 0.
        program: program to operate on; defaults to
            framework.default_main_program().

    Raises:
        TypeError: if an element is neither a Parameter nor a name that
            resolves to one, or if `clip_norm` is not an int or float.
    """
    try:
        string_types = basestring  # Python 2
    except NameError:
        string_types = str  # Python 3

    if program is None:
        program = framework.default_main_program()
    block = program.block(0)
    if param_list is None:
        param_list = block.all_parameters()

    # Materialize exactly once, so a generator argument is not drained by a
    # validation pass. Resolve names element by element, so a mix of names
    # and Parameter objects is accepted.
    param_list = [
        block.var(elem) if isinstance(elem, string_types) else elem
        for elem in param_list
    ]
    if not all(isinstance(elem, framework.Parameter) for elem in param_list):
        raise TypeError(
            "'param_list' should be a list of Parameter or basestring(parameter's name)."
        )

    GradientClipByGlobalNorm.init(clip_norm)
    for param in param_list:
        param.gradient_clip_attr = GradientClipByGlobalNorm()


Y
Yu Yang 已提交
183 184 185 186
def append_gradient_clip_ops(param_grad):
    """Apply each parameter's gradient-clip strategy to its gradient.

    Works in two passes, so strategies that share state across gradients
    see every gradient before any clipped gradient is produced:

    1. `process_context` is called for every (param, grad) pair.
    2. `create_operators` is then called for every pair, in the same order.

    A parameter whose `gradient_clip_attr` is missing or None falls back to
    NullGradientClipAttr.

    Args:
        param_grad: iterable of (param, grad) pairs.

    Returns:
        list of the (param, grad) pairs returned by `create_operators`.

    Raises:
        TypeError: if a `gradient_clip_attr` is not a BaseGradientClipAttr.
    """
    context = {}
    pending = []
    for param, grad in param_grad:
        attr = getattr(param, 'gradient_clip_attr', NullGradientClipAttr())
        if attr is None:
            attr = NullGradientClipAttr()
        if not isinstance(attr, BaseGradientClipAttr):
            raise TypeError(
                "clip attribute should be an instance of BaseGradientClipAttr")
        attr.process_context(context=context, param=param, grad=grad)
        pending.append((attr, param, grad))

    return [
        clip.create_operators(param=p, grad=g) for clip, p, g in pending
    ]


# Shorter alias of GradientClipByValue. Presumably kept so code that refers
# to `ClipByValue` keeps working — confirm with callers before removing.
ClipByValue = GradientClipByValue