/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/optimizers/sgd_op.h"

#include <string>

namespace paddle {
namespace operators {

D
dongzhihong 已提交
20
class SGDOp : public framework::OperatorWithKernel {
Y
Yu Yang 已提交
21 22 23
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

C
chengduo 已提交
24
  void InferShape(framework::InferShapeContext *ctx) const override {
25 26 27 28 29 30 31 32
    PADDLE_ENFORCE(ctx->HasInput("Param"),
                   "Input(Param) of SGDOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Grad"),
                   "Input(Grad) of SGDOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
                   "Input(LearningRate) of SGDOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
                   "Output(ParamOut) of SGDOp should not be null.");
Q
Qiao Longfei 已提交
33

34
    auto lr_dims = ctx->GetInputDim("LearningRate");
35 36 37 38 39
    PADDLE_ENFORCE_NE(framework::product(lr_dims), 0,
                      "Maybe the Input variable LearningRate has not "
                      "been initialized. You may need to confirm "
                      "if you put exe.run(startup_program) "
                      "after optimizer.minimize function.");
40 41
    PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
                      "Learning rate should have 1 element");
42
    auto param_dim = ctx->GetInputDim("Param");
43 44 45 46 47 48
    if (ctx->GetInputsVarType("Grad")[0] ==
        framework::proto::VarType::LOD_TENSOR) {
      PADDLE_ENFORCE_EQ(
          param_dim, ctx->GetInputDim("Grad"),
          platform::errors::InvalidArgument(
              "SGD Operator's input Param and Grad dimensions do not match. "
49 50 51
              "The Param %s shape is [%s], but the Grad %s shape is [%s].",
              ctx->Inputs("Param")[0], param_dim, ctx->Inputs("Grad")[0],
              ctx->GetInputDim("Grad")));
52
    }
53
    ctx->SetOutputDim("ParamOut", param_dim);
Q
Qiao Longfei 已提交
54
  }
55 56 57

 protected:
  framework::OpKernelType GetExpectedKernelType(
C
chengduo 已提交
58
      const framework::ExecutionContext &ctx) const override {
59
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param");
Q
qiaolongfei 已提交
60
    return framework::OpKernelType(data_type, ctx.device_context());
61
  }
62 63 64 65 66 67 68 69 70 71 72

  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name, const framework::Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const {
    if (var_name == "LearningRate") {
      return framework::OpKernelType(tensor.type(), tensor.place(),
                                     tensor.layout());
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
Q
Qiao Longfei 已提交
73 74
};

Y
Yancey1989 已提交
75 76
class SGDOpInferVarType : public framework::VarTypeInference {
 public:
M
minqiyang 已提交
77 78 79
  void operator()(framework::InferVarTypeContext *ctx) const override {
    auto &input_var_n = ctx->Input("Param")[0];
    auto in_var_type = ctx->GetType(input_var_n);
C
chengduo 已提交
80 81 82 83 84 85
    PADDLE_ENFORCE(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
                       in_var_type == framework::proto::VarType::LOD_TENSOR,
                   "The input Var's type should be LoDtensor or SelectedRows,"
                   " but the received var(%s)'s type is %s",
                   input_var_n, in_var_type);

M
minqiyang 已提交
86 87 88
    for (auto &out_var_n : ctx->Output("ParamOut")) {
      if (ctx->GetType(out_var_n) != in_var_type) {
        ctx->SetType(out_var_n, in_var_type);
Y
Yancey1989 已提交
89 90 91 92 93
      }
    }
  }
};

D
dongzhihong 已提交
94
class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
95
 public:
Y
Yu Yang 已提交
96
  void Make() override {
97
    AddInput("Param", "(Tensor or SelectedRows) Input parameter");
98
    AddInput("LearningRate", "(Tensor) Learning rate of SGD");
99 100 101
    AddInput("Grad", "(Tensor or SelectedRows) Input gradient");
    AddOutput("ParamOut",
              "(Tensor or SelectedRows, same with Param) "
102
              "Output parameter, should share the same memory with Param");
Q
Qiao Longfei 已提交
103 104
    AddComment(R"DOC(

105
SGD operator
Q
Qiao Longfei 已提交
106

107 108
This operator implements one step of the stochastic gradient descent algorithm.

109
$$param\_out = param - learning\_rate * grad$$
Q
Qiao Longfei 已提交
110 111 112 113

)DOC");
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// SGD has no gradient op of its own, so empty grad-op makers are registered
// for both the static-graph (OpDesc) and imperative (OpBase) modes.
REGISTER_OPERATOR(
    sgd, ops::SGDOp, ops::SGDOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
    ops::SGDOpInferVarType);
// CPU kernels for float and double parameter/gradient element types.
REGISTER_OP_CPU_KERNEL(
    sgd, ops::SGDOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::SGDOpKernel<paddle::platform::CPUDeviceContext, double>);