# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.fluid import framework, unique_name
from paddle.fluid.dygraph import base as imperative_base
from paddle.fluid.framework import Variable
from paddle.fluid.layer_helper import LayerHelper
from paddle.optimizer import Optimizer

__all__ = []


class LookAhead(Optimizer):
    r"""
    This implements the Lookahead optimizer of the paper:
    https://arxiv.org/abs/1907.08610.

    Lookahead keeps two sets of params: the fast_params and the slow_params.
    The inner_optimizer updates the fast_params at every training step.
    Lookahead updates the slow_params and fast_params every k training steps
    as follows:

    .. math::

        slow\_param_t &= slow\_param_{t-1} + \alpha * (fast\_param_{t-1} - slow\_param_{t-1})

        fast\_param_t &= slow\_param_t

    Args:
        inner_optimizer (Optimizer): The optimizer that updates the fast params step by step.
        alpha (float, optional): The learning rate of Lookahead. The default value is 0.5.
        k (int, optional): The slow params are updated every k steps. The default value is 5.
        name (str, optional): Normally there is no need for user to set this property.
            For more information, please refer to :ref:`api_guide_Name`.
            The default value is None.

    Examples:

        .. code-block:: python

            import numpy as np
            import paddle
            import paddle.nn as nn

            BATCH_SIZE = 16
            BATCH_NUM = 4
            EPOCH_NUM = 4

            IMAGE_SIZE = 784
            CLASS_NUM = 10
            # define a random dataset
            class RandomDataset(paddle.io.Dataset):
                def __init__(self, num_samples):
                    self.num_samples = num_samples

                def __getitem__(self, idx):
                    image = np.random.random([IMAGE_SIZE]).astype('float32')
                    label = np.random.randint(0, CLASS_NUM - 1,
                                            (1, )).astype('int64')
                    return image, label

                def __len__(self):
                    return self.num_samples

            class LinearNet(nn.Layer):
                def __init__(self):
                    super().__init__()
                    self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM)
                    self.bias = self._linear.bias

                @paddle.jit.to_static
                def forward(self, x):
                    return self._linear(x)

            def train(layer, loader, loss_fn, opt):
                for epoch_id in range(EPOCH_NUM):
                    for batch_id, (image, label) in enumerate(loader()):
                        out = layer(image)
                        loss = loss_fn(out, label)
                        loss.backward()
                        opt.step()
                        opt.clear_grad()
                        print("Train Epoch {} batch {}: loss = {}".format(
                            epoch_id, batch_id, np.mean(loss.numpy())))

            layer = LinearNet()
            loss_fn = nn.CrossEntropyLoss()
            optimizer = paddle.optimizer.SGD(learning_rate=0.1, parameters=layer.parameters())
            lookahead = paddle.incubate.LookAhead(optimizer, alpha=0.2, k=5)

            # create data loader
            dataset = RandomDataset(BATCH_NUM * BATCH_SIZE)
            loader = paddle.io.DataLoader(
                dataset,
                batch_size=BATCH_SIZE,
                shuffle=True,
                drop_last=True,
                num_workers=2)

            train(layer, loader, loss_fn, lookahead)

    """
    _slow_str = "slow"

    def __init__(self, inner_optimizer, alpha=0.5, k=5, name=None):
        assert inner_optimizer is not None, "inner optimizer cannot be None"
        assert (
            0.0 <= alpha <= 1.0
        ), "alpha should be greater than or equal to 0.0 and less than or equal to 1.0"
        assert isinstance(k, int) and k > 0, "k should be a positive integer"

        self.inner_optimizer = inner_optimizer
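        # If the inner optimizer was built without an explicit parameter list
        # (static graph mode), fall back to all parameters of the default
        # main program.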
        if self.inner_optimizer._parameter_list is None:
            parameters = (
                framework.default_main_program().global_block().all_parameters()
            )
        else:
            parameters = self.inner_optimizer._parameter_list

        super().__init__(
            learning_rate=alpha,
            parameters=parameters,
            weight_decay=None,
            grad_clip=None,
            name=name,
        )

        self.alpha = alpha
        self.k = k
        self.type = "lookahead"
        self.helper = LayerHelper(self.__class__.__name__)
        self._global_step_var = None
        self._k_var = None

    def _set_auxiliary_var(self, key, val):
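        # Forward auxiliary variables to the wrapped inner optimizer as well,
        # so both optimizers see the same state.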
        super()._set_auxiliary_var(key, val)
        self.inner_optimizer._set_auxiliary_var(key, val)

    @framework.dygraph_only
    @imperative_base.no_grad
    def step(self):
        """
        Execute the optimizer and update parameters once.

        Returns:
            None

        Examples:

            .. code-block:: python

                import paddle
                inp = paddle.rand([1, 10], dtype="float32")
                linear = paddle.nn.Linear(10, 1)
                out = linear(inp)
                loss = paddle.mean(out)
                sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters())
                lookahead = paddle.incubate.LookAhead(sgd, alpha=0.2, k=5)
                loss.backward()
                lookahead.step()
                lookahead.clear_grad()

        """
        self.inner_optimizer.step()

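        # The inner optimizer has just performed the fast update; now advance the
        # global step counter, which the appended lookahead op uses to decide
        # whether the slow parameters should be synchronized.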
        self._increment_global_var()
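        # Collect (param, grad) pairs for every trainable parameter that received
        # a gradient in this step.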
        params_grads = []
        for param in self._parameter_list:
            if not param.trainable:
                continue
            if param._grad_ivar() is not None:
                grad_var = param._grad_ivar()
                params_grads.append((param, grad_var))

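        # Append and run the lookahead update (see _append_optimize_op) for the
        # collected pairs.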
        self._apply_optimize(
            loss=None, startup_program=None, params_grads=params_grads
        )

    def _create_accumulators(self, block, parameters):
        assert isinstance(block, framework.Block)
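        # One "slow" accumulator is added per parameter below; it stores the slow
        # weights that are synchronized with the fast weights every k steps.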

        for p in parameters:
            self._add_accumulator(self._slow_str, p)

    def _increment_global_var(self):
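        # Lazily create a persistable int32 counter that records how many
        # optimization steps have been executed.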
        if self._global_step_var is None:
            self._global_step_var = paddle.static.create_global_var(
                name=unique_name.generate("lookahead_step"),
                shape=[1],
                value=0,
                dtype='int32',
                persistable=True,
            )

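        # Advance the counter by one on every call.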
        self.helper.append_op(
            type='increment',
            inputs={'X': [self._global_step_var]},
            outputs={'Out': [self._global_step_var]},
            attrs={'step': 1.0},
        )

    def _append_optimize_op(self, block, param_and_grad):
        # Constant tensors and the persisted k used to evaluate the update
        # conditions below.
        one_var = paddle.ones(shape=[1], dtype='int32', name='lookahead_ones')
        zero_var = paddle.zeros(
            shape=[1], dtype='int32', name='lookahead_zeros'
        )
        k_var = paddle.static.create_global_var(
            name=unique_name.generate("lookahead_k"),
            shape=[1],
            value=self.k,
            dtype='int32',
            persistable=True,
        )

        mod = paddle.remainder(self._global_step_var, k_var)

        # cond_1: this is the very first step, so the slow weights still need
        # to be initialized from the fast weights.
        cond_1 = paddle.equal(self._global_step_var, one_var)
        cond_1 = paddle.cast(cond_1, dtype='float32')

        # cond_2: the step counter is a multiple of k, so it is time to
        # synchronize the slow and fast weights. Both conditions are cast to
        # float masks so the conditional update needs no control-flow ops.
        cond_2 = paddle.equal(mod, zero_var)
        cond_2 = paddle.cast(cond_2, dtype='float32')

        slow_var = self._get_accumulator(self._slow_str, param_and_grad[0])

        # On the first step, copy the fast weights into the slow weights;
        # otherwise keep the slow weights unchanged.
        tmp_var = cond_1 * param_and_grad[0] + (1 - cond_1) * slow_var
        paddle.assign(tmp_var, slow_var)

        # Every k steps: slow = alpha * fast + (1 - alpha) * slow, and the fast
        # weights are reset to the new slow weights; otherwise both are left
        # as they are.
        tmp_var = self.alpha * param_and_grad[0] + (1.0 - self.alpha) * slow_var
        tmp_var_1 = cond_2 * tmp_var + (1 - cond_2) * param_and_grad[0]
        paddle.assign(tmp_var_1, param_and_grad[0])

        tmp_var_1 = cond_2 * tmp_var + (1 - cond_2) * slow_var
        paddle.assign(tmp_var_1, slow_var)

    @imperative_base.no_grad
    def minimize(
        self, loss, startup_program=None, parameters=None, no_grad_set=None
    ):
        """
        Add operations to minimize ``loss`` by updating ``parameters``.

        Args:
            loss (Tensor): A ``Tensor`` containing the value to minimize.
            startup_program (Program, optional): :ref:`api_fluid_Program` for
                initializing parameters in ``parameters``. The default value
                is None, at this time :ref:`api_fluid_default_startup_program` will be used.
            parameters (list, optional): List of ``Tensor`` or ``Tensor.name`` to update
                to minimize ``loss``. The default value is None, at this time all parameters
                will be updated.
            no_grad_set (set, optional): Set of ``Tensor``  or ``Tensor.name`` that don't need
                to be updated. The default value is None.

        Returns:
            tuple: tuple (optimize_ops, params_grads), a list of operators appended
            by minimize and a list of (param, grad) tensor pairs, where param is a
            ``Parameter`` and grad is the gradient value corresponding to the parameter.
            In static graph mode, the returned tuple can be passed to ``fetch_list`` in
            ``Executor.run()`` to indicate program pruning. If so, the program will be
            pruned by ``feed`` and ``fetch_list`` before run, see details in ``Executor``.

        Examples:

            .. code-block:: python

                import paddle

                inp = paddle.rand([1, 10], dtype="float32")
                linear = paddle.nn.Linear(10, 1)
                out = linear(inp)
                loss = paddle.mean(out)
                sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters())
                lookahead = paddle.incubate.LookAhead(sgd, alpha=0.2, k=5)
                loss.backward()
                lookahead.minimize(loss)
                lookahead.clear_grad()

        """
        assert isinstance(loss, Variable), "The loss should be a Tensor."

        # Apply inner optimizer to the main_program
        optimize_ops, params_grads = self.inner_optimizer.minimize(
            loss,
            startup_program=startup_program,
            parameters=parameters,
            no_grad_set=no_grad_set,
        )

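        # Advance the step counter so the appended lookahead op can decide when
        # to synchronize the slow parameters.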
        self._increment_global_var()

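        # Append the lookahead update op for every (param, grad) pair produced
        # by the inner optimizer.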
        _ = self._apply_optimize(
            loss, startup_program=startup_program, params_grads=params_grads
        )

        return optimize_ops, params_grads