momentum_op.cc 5.4 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
S
sidgoyal78 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/momentum_op.h"
S
sidgoyal78 已提交
16 17 18 19

namespace paddle {
namespace operators {

D
dzhwinter 已提交
20 21
// Convenience alias so code in paddle::operators can write `Tensor`
// instead of `framework::Tensor`.
using Tensor = framework::Tensor;

S
sidgoyal78 已提交
22 23 24 25 26
class MomentumOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
27
  void InferShape(framework::InferShapeContext* ctx) const override {
S
sidgoyal78 已提交
28 29 30 31 32 33 34 35
    PADDLE_ENFORCE(ctx->HasInput("Param"),
                   "Input(param) of Momentum should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Grad"),
                   "Input(grad) of Momentum should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Velocity"),
                   "Input(velocity) of Momentum should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
                   "Input(LearningRate) of Momentum should not be null.");
C
chengduo 已提交
36 37 38 39 40
    PADDLE_ENFORCE(
        ctx->GetInputsVarType("Param").front() ==
            framework::proto::VarType::LOD_TENSOR,
        "The input var's type should be LoDTensor, but the received is %s",
        ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
S
sidgoyal78 已提交
41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59

    PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
                   "Output(ParamOut) of Momentum should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("VelocityOut"),
                   "Output(VelocityOut) of Momentum should not be null.");

    auto param_dim = ctx->GetInputDim("Param");
    PADDLE_ENFORCE_EQ(
        param_dim, ctx->GetInputDim("Grad"),
        "Param and Grad input of MomentumOp should have the same dimension.");
    PADDLE_ENFORCE_EQ(
        param_dim, ctx->GetInputDim("Velocity"),
        "Param and Velocity of MomentumOp should have the same dimension.");
    PADDLE_ENFORCE_EQ(framework::product(ctx->GetInputDim("LearningRate")), 1,
                      "Learning_rate should be a scalar");

    ctx->SetOutputDim("ParamOut", param_dim);
    ctx->SetOutputDim("VelocityOut", param_dim);
  }
D
dzhwinter 已提交
60
  framework::OpKernelType GetExpectedKernelType(
61 62
      const framework::ExecutionContext& ctx) const override {
    auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param"));
D
dzhwinter 已提交
63 64
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
S
sidgoyal78 已提交
65 66
};

67 68 69 70 71 72 73 74 75 76
// Var-type inference for the momentum op: every "ParamOut" variable is given
// the variable type of the first "Param" input. Only SELECTED_ROWS and
// LOD_TENSOR are accepted; any other type raises via PADDLE_THROW.
class MomentumOpInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc& op_desc,
                  framework::BlockDesc* block) const override {
    const auto param_name = op_desc.Input("Param")[0];
    for (const auto& out_name : op_desc.Output("ParamOut")) {
      // Looked up inside the loop, before the output is touched, exactly as
      // the type test was evaluated per output variable.
      const auto param_type =
          block->FindRecursiveOrCreateVar(param_name).GetType();
      const bool is_supported =
          param_type == framework::proto::VarType::SELECTED_ROWS ||
          param_type == framework::proto::VarType::LOD_TENSOR;
      if (!is_supported) {
        PADDLE_THROW(
            "Only support LodTensor and SelectedRows, Unexpected Input Type.");
      }
      // The output simply mirrors the (supported) input type.
      block->FindRecursiveOrCreateVar(out_name).SetType(param_type);
    }
  }
};

S
sidgoyal78 已提交
89 90
// Declares the proto of the momentum op: its four inputs (Param, Grad,
// Velocity, LearningRate), two outputs (ParamOut, VelocityOut), the `mu` and
// `use_nesterov` attributes, and the user-facing operator documentation.
class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Param",
             "(Tensor, default Tensor<float>) "
             "Input parameter that has to be updated");
    AddInput("Grad",
             "(Tensor, default Tensor<float>) "
             "Input gradient of the parameter");
    AddInput("Velocity",
             "(Tensor, default Tensor<float>) "
             "Input velocity (corresponding to the parameter) "
             "that has to be updated");
    AddInput("LearningRate",
             "(Tensor, default Tensor<float>) "
             "Input learning rate");

    AddOutput("ParamOut",
              "(Tensor) This output is updated parameter. "
              "It shares memory with Input(Param).");
    AddOutput("VelocityOut",
              "(Tensor) This output is updated velocity. "
              "It shares memory with Input(Velocity).");

    // `mu` has no default, `use_nesterov` defaults to false.
    AddAttr<float>("mu", "(float) Momentum coefficient");
    AddAttr<bool>("use_nesterov",
                  "(bool, default false) "
                  "Use Nesterov Momentum")
        .SetDefault(false);
    AddComment(R"DOC(
Momentum Optimizer.

This optimizer has a flag for Nesterov Momentum.
The update equations are as follows:

$$
velocity = mu * velocity + gradient \\
if (use\_nesterov):   \\
  param = param - (gradient + mu * velocity) * learning\_rate \\
else:   \\
  param = param - learning\_rate * velocity. \\
$$

)DOC");
  }
};
}  // namespace operators
}  // namespace paddle

// Short alias for the operators namespace, used by the registration macros
// below.
namespace ops = paddle::operators;
139 140 141
// Registers the `momentum` operator with its operator class, its proto
// maker, EmptyGradOpMaker as the grad-op maker (presumably meaning no
// gradient op is generated for this optimizer op -- confirm), and the
// var-type inference functor.
REGISTER_OPERATOR(momentum, ops::MomentumOp, ops::MomentumOpMaker,
                  paddle::framework::EmptyGradOpMaker,
                  ops::MomentumOpInferVarType);
D
dzhwinter 已提交
142 143 144
// Registers the CPU kernels of the `momentum` op: MomentumOpKernel
// instantiated for float and for double on CPUDeviceContext.
REGISTER_OP_CPU_KERNEL(
    momentum, ops::MomentumOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::MomentumOpKernel<paddle::platform::CPUDeviceContext, double>);