#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimization and learning rate scheduling."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import paddle.fluid as fluid
from utils.fp16 import create_master_params_grads, master_param_to_train_param, apply_dynamic_loss_scaling


def linear_warmup_decay(learning_rate, warmup_steps, num_train_steps):
    """Build a learning-rate schedule with linear warmup from 0, then decay to 0.

    While the global step counter is below ``warmup_steps`` the rate is
    ``learning_rate * step / warmup_steps``. From then on it is the value of
    ``polynomial_decay(learning_rate, decay_steps=num_train_steps,
    end_learning_rate=0.0, power=1.0, cycle=False)``.

    Args:
        learning_rate: peak (base) learning rate.
        warmup_steps: number of warmup steps.
        num_train_steps: ``decay_steps`` passed to ``polynomial_decay``.

    Returns:
        The persistable float32 variable named "scheduled_learning_rate",
        which holds the scheduled learning rate.
    """
    program = fluid.default_main_program()
    with program._lr_schedule_guard():
        scheduled = fluid.layers.tensor.create_global_var(
            name="scheduled_learning_rate",
            shape=[1],
            value=0.0,
            dtype='float32',
            persistable=True)
        step = fluid.layers.learning_rate_scheduler._decay_step_counter()

        # Exactly one branch assigns into `scheduled` each step.
        with fluid.layers.control_flow.Switch() as switch:
            with switch.case(step < warmup_steps):
                # Fraction of the warmup completed so far.
                fraction = step / warmup_steps
                fluid.layers.tensor.assign(learning_rate * fraction, scheduled)
            with switch.default():
                decayed = fluid.layers.learning_rate_scheduler.polynomial_decay(
                    learning_rate=learning_rate,
                    decay_steps=num_train_steps,
                    end_learning_rate=0.0,
                    power=1.0,
                    cycle=False)
                fluid.layers.tensor.assign(decayed, scheduled)
    return scheduled


def _constant_learning_rate(learning_rate):
    """Return a new persistable float32 variable fixed at ``learning_rate``.

    Used as the fallback schedule when a warmup-based scheduler is requested
    with a non-positive ``warmup_steps``.
    """
    return fluid.layers.create_global_var(
        name=fluid.unique_name.generate("learning_rate"),
        shape=[1],
        value=learning_rate,
        dtype='float32',
        persistable=True)


# Suffix stripped from parameter names before classification. Presumably the
# fp32 master copies made by utils.fp16.create_master_params_grads are named
# "<param>.master" -- confirm against that helper.
_MASTER_PARAM_SUFFIX = ".master"
# Parameter-name endings treated as biases (excluded from weight decay).
_BIAS_SUFFIXES = ("_bias", "_b", ".b_0")


def _exclude_from_weight_decay(param):
    """Return True if ``param`` must not receive weight decay.

    Layer-norm parameters (name contains "layer_norm") and biases (name ends
    with one of _BIAS_SUFFIXES) are excluded. A trailing ".master" is removed
    first so master copies are classified by their base name.
    """
    name = param.name
    # Remove the suffix exactly once. Note: str.rstrip(".master") would strip
    # any trailing run of the characters ".master" (e.g. "proj_bias" ->
    # "proj_bi"), which misclassifies biases.
    if name.endswith(_MASTER_PARAM_SUFFIX):
        name = name[:-len(_MASTER_PARAM_SUFFIX)]
    if "layer_norm" in name:
        return True
    return name.endswith(_BIAS_SUFFIXES)


def _snapshot_decayed_params(params):
    """Map each weight-decayed parameter's name to a gradient-free copy.

    The copies (``param * 1.0`` with ``stop_gradient = True``) are created
    before the optimizer update ops are appended. They therefore hold the
    parameter values from before the update. Excluded parameters get no
    copy, since weight decay never reads one for them.
    """
    snapshots = {}
    for param in params:
        if _exclude_from_weight_decay(param):
            continue
        snapshot = param * 1.0
        snapshot.stop_gradient = True
        snapshots[param.name] = snapshot
    return snapshots


def _apply_weight_decay(param_grads, snapshots, weight_decay, scheduled_lr):
    """Apply decoupled weight decay in place to each non-excluded parameter.

    Each parameter ``p`` is updated as
    ``p <- p - snapshot(p) * weight_decay * scheduled_lr``.

    Args:
        param_grads: iterable of ``(param, grad)`` pairs.
        snapshots: mapping from name to pre-update copy, as produced by
            _snapshot_decayed_params. It must cover every non-excluded
            param in ``param_grads``.
        weight_decay: decay coefficient.
        scheduled_lr: learning-rate variable.
    """
    for param, grad in param_grads:
        if _exclude_from_weight_decay(param):
            continue
        # Create the update ops under the optimizer guard for this
        # (param, grad) pair, inside a "weight_decay" name scope.
        with param.block.program._optimized_guard(
                [param, grad]), fluid.framework.name_scope("weight_decay"):
            updated_param = param - snapshots[
                param.name] * weight_decay * scheduled_lr
            fluid.layers.assign(output=param, input=updated_param)


def optimization(loss,
                 warmup_steps,
                 num_train_steps,
                 learning_rate,
                 train_program,
                 startup_prog,
                 weight_decay,
                 scheduler='linear_warmup_decay',
                 use_fp16=False,
                 use_dynamic_loss_scaling=False,
                 init_loss_scaling=1.0,
                 incr_every_n_steps=1000,
                 decr_every_n_nan_or_inf=2,
                 incr_ratio=2.0,
                 decr_ratio=0.8):
    """Build the learning-rate schedule and Adam update ops for ``loss``.

    Gradients are clipped by global norm (clip_norm=1.0) and applied with
    Adam. When ``weight_decay > 0``, decoupled weight decay is applied after
    the Adam step to every parameter except layer-norm parameters and biases.

    Args:
        loss: loss variable to minimize.
        warmup_steps: warmup length for the scheduler. If it is not positive,
            a constant learning rate is used and a warning is printed.
        num_train_steps: total training steps (decay horizon of
            linear_warmup_decay).
        learning_rate: base/peak learning rate (a Python number).
        train_program: program whose parameters are trained; also passed to
            the fp16 helpers.
        startup_prog: startup program, passed to create_master_params_grads.
        weight_decay: decoupled weight-decay coefficient; values <= 0
            disable weight decay.
        scheduler: 'linear_warmup_decay' (default) or 'noam_decay'.
        use_fp16: if True, scale the loss and update fp32 master parameters
            via the utils.fp16 helpers.
        use_dynamic_loss_scaling: fp16 only; enable
            apply_dynamic_loss_scaling.
        init_loss_scaling: fp16 only; initial loss-scaling value.
        incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio:
            fp16 only; forwarded to apply_dynamic_loss_scaling.

    Returns:
        ``(scheduled_lr, loss_scaling)``. ``loss_scaling`` is None unless
        ``use_fp16`` is True.

    Raises:
        ValueError: if ``scheduler`` is not a known scheduler name.
    """
    scheduled_lr, loss_scaling = None, None
    if scheduler == 'noam_decay':
        if warmup_steps > 0:
            # Assuming Paddle's noam formula
            #   lr = d_model**-0.5 * min(step**-0.5, step * warmup**-1.5),
            # choosing d_model = 1 / (warmup_steps * learning_rate**2) makes
            # the peak (at step == warmup_steps) equal learning_rate.
            scheduled_lr = fluid.layers.learning_rate_scheduler.noam_decay(
                1 / (warmup_steps * (learning_rate ** 2)), warmup_steps)
        else:
            print("WARNING: noam decay of learning rate should have positive "
                  "warmup steps but given {}, using constant learning rate "
                  "instead!".format(warmup_steps))
            scheduled_lr = _constant_learning_rate(learning_rate)
    elif scheduler == 'linear_warmup_decay':
        if warmup_steps > 0:
            scheduled_lr = linear_warmup_decay(learning_rate, warmup_steps,
                                               num_train_steps)
        else:
            print("WARNING: linear warmup decay of learning rate should have "
                  "positive warmup steps but given {}, using constant "
                  "learning rate instead!".format(warmup_steps))
            scheduled_lr = _constant_learning_rate(learning_rate)
    else:
        raise ValueError("Unknown learning rate scheduler {!r}, should be "
                         "'noam_decay' or 'linear_warmup_decay'".format(
                             scheduler))

    clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0)
    optimizer = fluid.optimizer.Adam(learning_rate=scheduled_lr, grad_clip=clip)

    if use_fp16:
        loss_scaling = fluid.layers.create_global_var(
            name=fluid.unique_name.generate("loss_scaling"),
            shape=[1],
            value=init_loss_scaling,
            dtype='float32',
            persistable=True)
        loss *= loss_scaling

        param_grads = optimizer.backward(loss)
        master_param_grads = create_master_params_grads(
            param_grads, train_program, startup_prog, loss_scaling)

        # Snapshots must be taken before apply_gradients adds the update ops.
        snapshots = (_snapshot_decayed_params(
            param for param, _ in master_param_grads)
                     if weight_decay > 0 else {})

        if use_dynamic_loss_scaling:
            apply_dynamic_loss_scaling(
                loss_scaling, master_param_grads, incr_every_n_steps,
                decr_every_n_nan_or_inf, incr_ratio, decr_ratio)

        optimizer.apply_gradients(master_param_grads)

        if weight_decay > 0:
            _apply_weight_decay(master_param_grads, snapshots, weight_decay,
                                scheduled_lr)

        master_param_to_train_param(master_param_grads, param_grads,
                                    train_program)
    else:
        # Snapshots must be taken before minimize() adds the update ops.
        snapshots = (_snapshot_decayed_params(train_program.all_parameters())
                     if weight_decay > 0 else {})

        _, param_grads = optimizer.minimize(loss)

        if weight_decay > 0:
            _apply_weight_decay(param_grads, snapshots, weight_decay,
                                scheduled_lr)

    return scheduled_lr, loss_scaling