# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import paddle
from paddle import _C_ops

from ..fluid import core, framework, unique_name
from ..fluid.dygraph import no_grad
from ..fluid.framework import in_dygraph_mode
from ..fluid.layer_helper import LayerHelper
from .optimizer import Optimizer

# Public names are exported by the package ``__init__``, not via this module.
__all__ = []


class SGD(Optimizer):
    r"""
    Optimizer of the stochastic gradient descent algorithm.

    .. math::

        param\_out = param - learning\_rate * grad

    Parameters:
        learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``.
            It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
        parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``. \
            This parameter is required in dygraph mode. \
            The default value is None in static graph mode, at this time all parameters will be updated.
        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \
            It can be a float value as coeff of L2 regularization or \
            :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
            If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already, \
            the regularization setting here in optimizer will be ignored for this parameter. \
            Otherwise, the regularization setting here in optimizer will take effect. \
            Default None, meaning there is no regularization.
        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
            some derived class of ``GradientClipBase`` . There are three clipping strategies
            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
        multi_precision (bool, optional): If True, every FP16 parameter is updated through an
            FP32 "master" copy of it, which is passed to the ``sgd`` operator alongside the
            parameter. Default is False.
        name (str, optional): The default value is None. Normally there is no need for user
                to set this property. For more information, please refer to
                :ref:`api_guide_Name` .

    Examples:
        .. code-block:: python

            import paddle

            inp = paddle.uniform(min=-0.1, max=0.1, shape=[10, 10], dtype='float32')
            linear = paddle.nn.Linear(10, 10)
            out = linear(inp)
            loss = paddle.mean(out)
            sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters(), weight_decay=0.01)
            out.backward()
            sgd.step()
            sgd.clear_grad()

    """

    def __init__(
        self,
        learning_rate=0.001,
        parameters=None,
        weight_decay=None,
        grad_clip=None,
        multi_precision=False,
        name=None,
    ):
        if learning_rate is None:
            raise ValueError("learning_rate is not set")
        super().__init__(
            learning_rate=learning_rate,
            parameters=parameters,
            weight_decay=weight_decay,
            grad_clip=grad_clip,
            name=name,
        )
        # Operator type appended in static graph mode (see _append_optimize_op).
        self.type = "sgd"
        self._multi_precision = multi_precision
        # Maps parameter name -> FP32 master weight variable.
        self._master_weights = {}

    def _create_master_weight(self, param):
        """Return the FP32 master weight of ``param``, creating it on first use.

        The master weight is a persistable float32 global variable with the
        same shape as ``param``, named from ``param.name`` plus the
        ``_fp32_master`` suffix (made unique). A ``cast`` op appended to the
        startup program initializes it from ``param``. Created variables are
        cached in ``self._master_weights`` keyed by parameter name.
        """
        if param.name in self._master_weights:
            return self._master_weights[param.name]

        assert isinstance(self.helper, LayerHelper)

        var_name = unique_name.generate(param.name + "_fp32_master")
        var = paddle.static.create_global_var(
            name=var_name,
            shape=param.shape,
            value=0,
            dtype='float32',
            persistable=True,
        )
        # Copy the current parameter value into the master weight at startup.
        block = self.helper.startup_program.global_block()
        block.append_op(
            type="cast",
            inputs={"X": [param]},
            outputs={"Out": [var]},
            attrs={
                "in_dtype": param.dtype,
                "out_dtype": core.VarDesc.VarType.FP32,
            },
        )
        self._master_weights[param.name] = var
        return var

    def _create_accumulators(self, block, parameters):
        """Prepare per-parameter state before the update ops are appended.

        SGD keeps no moment accumulators. The only per-parameter state is the
        FP32 master weight: it is created for each FP16 parameter when
        ``multi_precision`` is enabled. Without ``multi_precision``, a warning
        is issued for every FP16 parameter instead.

        Args:
            block (framework.Block): The block the optimizer works on.
            parameters (list|dict): Parameters to prepare, or a parameter-group
                dict whose ``'params'`` entry holds them.
        """
        assert isinstance(block, framework.Block)
        if isinstance(parameters, dict):
            parameters = self._update_param_group(parameters)

        for p in parameters:
            if p.dtype != core.VarDesc.VarType.FP16:
                continue
            if self._multi_precision:
                self._create_master_weight(p)
            else:
                warnings.warn(
                    "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence. "
                    "Consider using multi_precision=True option of the SGD optimizer."
                )

    @no_grad
    def _append_optimize_op(self, block, param_and_grad):
        """Apply the SGD update for one ``(param, grad)`` pair.

        In dygraph mode the in-place ``_C_ops.sgd_`` kernel runs immediately
        and ``None`` is returned. Otherwise an ``sgd`` op is appended to
        ``block`` and returned.

        When ``multi_precision`` is on and the parameter is FP16, the master
        weight from ``self._master_weights`` is passed along. This relies on
        ``_create_accumulators`` having created that master weight beforehand.
        """
        if isinstance(param_and_grad, dict):
            param_and_grad = self._update_param_group(param_and_grad)

        find_master = (
            self._multi_precision
            and param_and_grad[0].dtype == core.VarDesc.VarType.FP16
        )
        master_weight = (
            self._master_weights[param_and_grad[0].name]
            if find_master
            else None
        )

        lr = self._create_param_lr(param_and_grad)

        if in_dygraph_mode():
            _C_ops.sgd_(
                param_and_grad[0],
                lr,
                param_and_grad[1],
                master_weight,
                find_master,
            )
            return None

        assert isinstance(block, framework.Block)
        # Static graph: the parameter is updated in place via ParamOut.
        inputs = {
            "Param": param_and_grad[0],
            "Grad": param_and_grad[1],
            "LearningRate": lr,
        }
        outputs = {"ParamOut": param_and_grad[0]}
        attrs = {"multi_precision": find_master}

        if find_master:
            inputs["MasterParam"] = master_weight
            outputs["MasterParamOut"] = master_weight

        return block.append_op(
            type=self.type,
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
            stop_gradient=True,
        )

    def _update_param_group(self, parameters):
        """Return the ``'params'`` entry of a parameter-group dict (None if absent)."""
        return parameters.get('params')