#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import functools

import layers
import framework
from . import core

# Public names exported by ``from clip import *``: the error-clip attribute
# and its callback, plus the gradient-clip attributes and the helper that
# turns them into ops.
__all__ = [
    'ErrorClipByValue',
    'GradientClipByValue',
    'GradientClipByNorm',
    'GradientClipByGlobalNorm',
    'append_gradient_clip_ops',
    'error_clip_callback',
]


class BaseErrorClipAttr(object):
    """Interface for error-clip attributes stored on a variable's ``error_clip``.

    Concrete subclasses describe a clipping rule and know how to append the
    op that applies it to a gradient variable.
    """

    def __str__(self):
        """Return a short description of the clipping rule (abstract)."""
        raise NotImplementedError

    def append_clip_op(self, block, grad_name):
        """Append an op to ``block`` that clips ``grad_name`` (abstract)."""
        raise NotImplementedError


class ErrorClipByValue(BaseErrorClipAttr):
    """Error-clip rule that clamps every element of a gradient to [min, max].

    Args:
        max: Upper bound, converted with ``float``.
        min: Lower bound, converted with ``float``; defaults to ``-max``.
    """

    def __init__(self, max, min=None):
        upper = float(max)
        lower = -upper if min is None else float(min)
        self.max = upper
        self.min = lower

    def __str__(self):
        """Describe the rule, e.g. ``ByValue, min=-1.000000, max=1.000000``."""
        return "ByValue, min=%f, max=%f" % (self.min, self.max)

    def append_clip_op(self, block, grad_name):
        """Append a ``clip`` op to ``block`` that overwrites ``grad_name``.

        The op's input ``X`` and output ``Out`` are the same variable, so the
        gradient is replaced in place by its clamped value.
        """
        op_desc = block.desc.append_op()
        op_desc.set_type("clip")
        op_desc.set_input("X", [grad_name])
        op_desc.set_output("Out", [grad_name])
        for attr_name, attr_value in (("min", self.min), ("max", self.max)):
            op_desc.set_attr(attr_name, attr_value)


def error_clip_callback(block, context):
    """Append error-clip ops for gradients written by the last op of ``block``.

    Looks at the most recently appended op in ``block.desc``. For each output
    name of that op that is a key of ``context``, the forward variable named
    by ``context[name]`` is resolved with ``block.var_recursive``. If that
    variable has a non-None ``error_clip`` attribute, the attribute's clip op
    is appended to ``block``.

    Args:
        block: Block whose last op is inspected and to which clip ops are
            appended.
        context: Mapping from gradient variable name to forward variable name.

    Raises:
        TypeError: If a forward variable's ``error_clip`` is neither ``None``
            nor an instance of ``BaseErrorClipAttr``.
    """
    # the context is a grad_to_var map
    grad_to_var = context
    op_desc = block.desc.op(block.desc.op_size() - 1)
    # Use ``in`` rather than the Python-2-only ``dict.has_key``. The names
    # are collected into a list up front, so the clip ops appended below
    # cannot affect the iteration.
    grad_names = [
        n for n in op_desc.output_arg_names() if n in grad_to_var
    ]
    for grad_n in grad_names:
        fwd_var = block.var_recursive(grad_to_var[grad_n])
        error_clip = getattr(fwd_var, "error_clip", None)
        if not (error_clip is None or
                isinstance(error_clip, BaseErrorClipAttr)):
            raise TypeError(
                "Variable's error_clip should be an instance of BaseErrorClipAttr or None."
            )
        if error_clip is not None:
            error_clip.append_clip_op(block, grad_n)


class BaseGradientClipAttr(object):
    """Interface for gradient-clip attributes stored on ``gradient_clip_attr``.

    ``append_gradient_clip_ops`` first calls ``process_context`` for every
    (param, grad) pair with one shared ``context`` dict. It then calls
    ``create_operators`` for each pair to obtain the clipped gradient.
    """

    def __str__(self):
        """Return a short description of the clipping rule (abstract)."""
        raise NotImplementedError

    def process_context(self, context, param, grad):
        """Record whatever shared state this rule needs in ``context`` (abstract)."""
        raise NotImplementedError

    def create_operators(self, param, grad):
        """Build the clipping ops and return ``(param, new_grad)`` (abstract)."""
        raise NotImplementedError


class NullGradientClipAttr(BaseGradientClipAttr):
    """Gradient-clip rule that performs no clipping at all."""

    def __str__(self):
        return "Null"

    def process_context(self, context, param, grad):
        """No shared state is required, so this is a no-op."""
        return None

    def create_operators(self, param, grad):
        """Return the pair unchanged; no ops are created."""
        return (param, grad)


class GradientClipByValue(BaseGradientClipAttr):
    """Gradient-clip rule that clamps each gradient element to [min, max].

    Args:
        max: Upper bound, converted with ``float``.
        min: Lower bound, converted with ``float``; defaults to ``-max``.
    """

    def __init__(self, max, min=None):
        upper = float(max)
        if min is not None:
            lower = float(min)
        else:
            lower = -upper
        self.max, self.min = upper, lower

    def __str__(self):
        return "ByValue, min=%f, max=%f" % (self.min, self.max)

    def process_context(self, context, param, grad):
        """Clipping is per gradient, so nothing is shared through ``context``."""
        pass

    def create_operators(self, param, grad):
        """Return ``param`` and the output of ``layers.clip`` applied to ``grad``."""
        clipped = layers.clip(x=grad, min=self.min, max=self.max)
        return param, clipped


class GradientClipByNorm(BaseGradientClipAttr):
    """Gradient-clip rule that clips each gradient with ``layers.clip_by_norm``.

    Args:
        clip_norm: Value forwarded to ``layers.clip_by_norm`` as ``max_norm``.
    """

    def __init__(self, clip_norm):
        self.clip_norm = clip_norm

    def __str__(self):
        return "ByNorm, clip_norm=%f" % self.clip_norm

    def process_context(self, context, param, grad):
        """Clipping is per gradient, so nothing is shared through ``context``."""
        return None

    def create_operators(self, param, grad):
        """Return ``param`` and the output of ``layers.clip_by_norm`` on ``grad``."""
        return param, layers.clip_by_norm(x=grad, max_norm=self.clip_norm)


class GradientClipByGlobalNorm(BaseGradientClipAttr):
    """Scale a group of gradients so their combined L2 norm is at most ``clip_norm``.

    Each gradient whose parameter carries an instance with the same
    ``group_name`` contributes ``reduce_sum(grad ** 2)`` to the group. The
    group norm is the ``sqrt`` of the total. Every gradient in the group is
    then multiplied by ``clip_norm / max(clip_norm, group_norm)``, which is 1
    when the group norm does not exceed ``clip_norm``.

    Args:
        clip_norm: Maximum allowed group norm. All members of one group must
            use the same value.
        group_name (basestring): Name of the group this attribute joins.

    Raises:
        TypeError: If ``group_name`` is not a string.
    """

    def __init__(self, clip_norm, group_name="default_group"):
        if not isinstance(group_name, basestring):
            raise TypeError("'group_name' must be a basestring.")

        self.clip_norm = clip_norm
        self.group_name = group_name

    def __str__(self):
        return "ByGlobalNorm, group_name=%s, clip_norm=%f" % (self.group_name,
                                                              self.clip_norm)

    def process_context(self, context, param, grad):
        """Register ``grad`` in this group's shared state inside ``context``.

        The first member of a group creates three context entries:
        ``<group>`` (list of per-gradient squared norms),
        ``<group>_clip_value`` (the Python ``clip_norm``) and ``<group>_clip``
        (a ``fill_constant`` variable holding ``clip_norm``).

        Raises:
            ValueError: If an earlier member of the group used a different
                ``clip_norm``.
        """
        clip_value_key = self.group_name + "_clip_value"
        if self.group_name not in context:
            context[self.group_name] = []
            context[clip_value_key] = self.clip_norm
            context[self.group_name + "_clip"] = layers.fill_constant(
                shape=[1], dtype="float32", value=self.clip_norm)
        elif self.clip_norm != context[clip_value_key]:
            raise ValueError(
                "All parameters' 'clip_norm' of a same group should be the same"
            )

        # Squared L2 norm of this gradient; the group total is summed later.
        local_norm_var = layers.reduce_sum(input=layers.pow(x=grad, factor=2.0))
        context[self.group_name].append(local_norm_var)

        # append_gradient_clip_ops calls create_operators only after every
        # process_context call, so keep the shared dict for that second phase.
        self.context = context

    def create_operators(self, param, grad):
        """Return ``param`` and ``grad`` multiplied by the group's scale factor.

        The scale variable is built on the first call for a group and cached
        in the shared context as ``<group>_scale`` for the other members.
        """
        group_scale_name = self.group_name + "_scale"
        if group_scale_name not in self.context:
            group_norm_var = layers.sums(input=self.context[self.group_name])
            layers.sqrt(x=group_norm_var, out=group_norm_var)
            clip_var = self.context[self.group_name + "_clip"]
            group_scale_var = layers.elementwise_div(
                x=clip_var,
                y=layers.elementwise_max(
                    x=clip_var, y=group_norm_var))
            # ``(1, )`` compares equal to ``(1L, )`` on Python 2. Unlike the
            # ``1L`` literal, it is also valid syntax on Python 3.
            assert group_scale_var.shape == (1, )
            self.context[group_scale_name] = group_scale_var

        new_grad = layers.elementwise_mul(
            x=grad, y=self.context[group_scale_name])
        return param, new_grad


def set_gradient_clip(clip, param_list=None, program=None):
    """Attach a gradient-clip attribute to parameters.

    Each selected parameter gets its own ``copy.deepcopy`` of ``clip`` as its
    ``gradient_clip_attr``.

    Args:
        clip (BaseGradientClipAttr): Instance of a class derived from
            BaseGradientClipAttr that describes the kind of clipping and its
            settings.
        param_list (iterable, optional): Parameters that require gradient
            clipping. Elements may be Parameter objects or parameter names
            (basestring), and the two may be mixed. Any iterable is accepted.
            When None, ``program.block(0).all_parameters()`` is used.
        program (Program, optional): Program that owns the parameters.
            Defaults to ``framework.default_main_program()``.

    Raises:
        TypeError: If ``clip`` is not a BaseGradientClipAttr, or if an
            element of ``param_list`` does not resolve to a Parameter.
    """
    if not isinstance(clip, BaseGradientClipAttr):
        raise TypeError(
            "'clip' should be an instance of BaseGradientClipAttr's derived class"
        )
    if program is None:
        program = framework.default_main_program()
    global_block = program.block(0)
    if param_list is None:
        param_list = global_block.all_parameters()
    # Build a list once. Checking a generator with all() would exhaust it,
    # so every later step would see nothing and no parameter would be
    # clipped. Names are resolved one by one, so names and Parameters can
    # be mixed.
    param_list = [
        global_block.var(elem) if isinstance(elem, basestring) else elem
        for elem in param_list
    ]
    if not all(isinstance(elem, framework.Parameter) for elem in param_list):
        raise TypeError(
            "'param_list' should be a list of Parameter or basestring(parameter's name)."
        )

    for param in param_list:
        param.gradient_clip_attr = copy.deepcopy(clip)


def append_gradient_clip_ops(param_grad):
    """Apply each parameter's gradient-clip attribute to its gradient.

    Works in two passes over ``param_grad``. First, ``process_context`` is
    called for every (param, grad) pair with one shared ``context`` dict.
    Then ``create_operators`` is called for every pair, in the same order.
    A parameter whose ``gradient_clip_attr`` is missing or None is handled
    by a ``NullGradientClipAttr``.

    Args:
        param_grad: Iterable of ``(param, grad)`` pairs.

    Returns:
        list: The values returned by each ``create_operators`` call, in input
        order.

    Raises:
        TypeError: If a parameter's clip attribute is not a
            BaseGradientClipAttr.
    """
    shared_context = {}
    pending = []
    for param, grad in param_grad:
        attr = getattr(param, 'gradient_clip_attr', None)
        if attr is None:
            attr = NullGradientClipAttr()
        if not isinstance(attr, BaseGradientClipAttr):
            raise TypeError(
                "clip attribute should be an instance of BaseGradientClipAttr")

        attr.process_context(context=shared_context, param=param, grad=grad)
        # Capture the bound method now; call it only after every pair has
        # populated the shared context.
        pending.append((attr.create_operators, param, grad))

    return [
        create(param=param, grad=grad) for create, param, grad in pending
    ]


# Shorter aliases for the gradient-clip attribute classes.
ClipByValue, ClipByNorm, ClipByGlobalNorm = (
    GradientClipByValue,
    GradientClipByNorm,
    GradientClipByGlobalNorm,
)