momentum_op.cc 4.9 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
S
sidgoyal78 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

W
Wu Yi 已提交
15
#include "paddle/fluid/operators/optimizers/momentum_op.h"
16
#include "paddle/fluid/framework/op_version_registry.h"
S
sidgoyal78 已提交
17 18 19 20

namespace paddle {
namespace operators {

D
dzhwinter 已提交
21 22
// Local shorthand for framework::Tensor within the operators namespace.
using Tensor = framework::Tensor;

23 24
// Variable-type inference for the momentum op: the "ParamOut" output is given
// the same variable type as the "Param" input.
class MomentumOpInferVarType : public framework::VarTypeInference {
 public:
  // Reads the variable type of "Param", enforces that it is either
  // SELECTED_ROWS or LOD_TENSOR, and then assigns that type to "ParamOut".
  void operator()(framework::InferVarTypeContext* ctx) const override {
    const auto in_var_type = ctx->GetInputType("Param");

    // Any other variable type is rejected with an InvalidArgument error.
    PADDLE_ENFORCE_EQ(
        in_var_type == framework::proto::VarType::SELECTED_ROWS ||
            in_var_type == framework::proto::VarType::LOD_TENSOR,
        true,
        platform::errors::InvalidArgument(
            "Only support LodTensor and SelectedRows, Unexpected Input Type."));

    // Propagate the (validated) input type to the updated parameter.
    ctx->SetOutputType("ParamOut", in_var_type, framework::ALL_ELEMENTS);
  }
};

38 39 40 41 42 43 44 45 46 47 48 49 50 51
void MomentumOpMaker::Make() {
  AddInput("Param",
           "(Tensor, default Tensor<float>) "
           "Input parameter that has to be updated");
  AddInput("Grad",
           "(Tensor, default Tensor<float>) "
           "Input gradient of the parameter");
  AddInput("Velocity",
           "(Tensor, default Tensor<float>) "
           "Input velocity (corresponding to the parameter) "
           "that has to be updated");
  AddInput("LearningRate",
           "(Tensor, default Tensor<float>) "
           "Input learning rate");
52
  AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
53 54 55 56 57 58
  AddOutput("ParamOut",
            "(Tensor) This output is updated parameter. "
            "It shared memory with Input(Param).");
  AddOutput("VelocityOut",
            "(Tensor) This output is updated velocity. "
            "It shared memory with Input(Velocity).");
59 60 61 62
  AddOutput("MasterParamOut",
            "The updated FP32 master weight for AMP. "
            "It shared memory with Input(MasterParam).")
      .AsDispensable();
S
sidgoyal78 已提交
63

64 65 66 67 68
  AddAttr<float>("mu", "(float) Momentum coefficient");
  AddAttr<bool>("use_nesterov",
                "(bool, default false) "
                "Use Nesterov Momentum")
      .SetDefault(false);
69 70 71 72 73
  AddAttr<std::string>(
      "regularization_method",
      "(string) regularization_method, right now only support l2decay or none")
      .SetDefault("");
  AddAttr<float>("regularization_coeff", "(float) regularization_coeff")
74 75 76 77 78 79 80 81 82 83 84
      .SetDefault(0.0f);
  AddAttr<bool>("multi_precision",
                "(bool, default false) "
                "Whether to use multi-precision during weight updating.")
      .SetDefault(false);
  AddAttr<float>(
      "rescale_grad",
      "(float, default 1.0) Multiply the gradient with `rescale_grad`"
      "before updating. Often choose to be `1.0/batch_size`.")
      .SetDefault(1.0f);

85
  AddComment(R"DOC(
K
kexinzhao 已提交
86 87 88 89 90 91 92 93
Momentum Optimizer.

This optimizer has a flag for Nestrov Momentum.
The update equations are as follows:

$$
velocity = mu * velocity + gradient \\
if (use\_nesterov):   \\
94
  param = param - (gradient + mu * velocity) * learning\_rate \\
K
kexinzhao 已提交
95 96 97
else:   \\
  param = param - learning\_rate * velocity. \\
$$
S
sidgoyal78 已提交
98 99

)DOC");
100 101
}

S
sidgoyal78 已提交
102 103 104 105
}  // namespace operators
}  // namespace paddle

// Short alias for paddle::operators, used by the registration macros in this
// file.
namespace ops = paddle::operators;
H
hong 已提交
106 107 108 109 110
// Registers the "momentum" operator with its op class (MomentumOp), its proto
// maker (MomentumOpMaker) and its variable-type inference
// (MomentumOpInferVarType). EmptyGradOpMaker is supplied for both the OpDesc
// and the imperative OpBase variants.
// NOTE(review): presumably this means the op has no gradient op — confirm
// against the EmptyGradOpMaker definition.
REGISTER_OPERATOR(
    momentum, ops::MomentumOp, ops::MomentumOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
    ops::MomentumOpInferVarType);
D
dzhwinter 已提交
111 112 113
// Registers the CPU kernels of the "momentum" op: MomentumOpKernel
// instantiated on CPUDeviceContext for float and for double.
REGISTER_OP_CPU_KERNEL(
    momentum, ops::MomentumOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::MomentumOpKernel<paddle::platform::CPUDeviceContext, double>);
114 115 116 117 118 119 120 121 122 123 124 125

// Op-version checkpoint for "momentum". It records the four attributes that
// were added to the op, each with the same default value that
// MomentumOpMaker::Make() declares for it.
REGISTER_OP_VERSION(momentum)
    .AddCheckpoint(
        R"ROC(
      Upgrade momentum add 4 attributes [regularization_method, regularization_coeff,
      multi_precision, rescale_grad].
    )ROC",
        paddle::framework::compatible::OpVersionDesc()
            .NewAttr("regularization_method",
                     "(string) regularization_method, right now only support "
                     "l2decay or none",
                     std::string(""))
            .NewAttr("regularization_coeff", "(float) regularization_coeff",
                     0.0f)
            .NewAttr(
                "multi_precision",
                "(bool) Whether to use multi-precision during weight updating.",
                false)
            // The first literal ends with a space so the concatenated remark
            // reads "`rescale_grad` before updating".
            .NewAttr("rescale_grad",
                     "(float) Multiply the gradient with `rescale_grad` "
                     "before updating. Often choose to be `1.0/batch_size`.",
                     1.0f));