sgd_op.cc 3.3 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Q
Qiao Longfei 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/sgd_op.h"
Q
Qiao Longfei 已提交
16 17 18 19

namespace paddle {
namespace operators {

D
dongzhihong 已提交
20
class SGDOp : public framework::OperatorWithKernel {
Y
Yu Yang 已提交
21 22 23
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

Q
qijun 已提交
24
  void InferShape(framework::InferShapeContext* ctx) const override {
25 26 27 28 29 30 31 32
    PADDLE_ENFORCE(ctx->HasInput("Param"),
                   "Input(Param) of SGDOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Grad"),
                   "Input(Grad) of SGDOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
                   "Input(LearningRate) of SGDOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
                   "Output(ParamOut) of SGDOp should not be null.");
Q
Qiao Longfei 已提交
33

34
    auto lr_dims = ctx->GetInputDim("LearningRate");
35 36
    PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
                      "Learning rate should have 1 element");
37
    auto param_dim = ctx->GetInputDim("Param");
Q
qijun 已提交
38 39
    // TODO(qijun): check dimensions of Param and Grad at complie
    // and run time.
40
    ctx->SetOutputDim("ParamOut", param_dim);
Q
Qiao Longfei 已提交
41
  }
42 43 44 45

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
46 47 48 49 50 51 52 53 54 55 56 57 58
    auto* table_var = ctx.InputVar("Param");
    if (table_var->IsType<framework::LoDTensor>()) {
      return framework::OpKernelType(
          framework::ToDataType(table_var->Get<framework::LoDTensor>().type()),
          ctx.device_context());
    } else if (table_var->IsType<framework::SelectedRows>()) {
      return framework::OpKernelType(
          framework::ToDataType(
              table_var->Get<framework::SelectedRows>().value().type()),
          ctx.device_context());
    } else {
      PADDLE_THROW("Param should be LoDTensor or SelectedRows");
    }
59
  }
Q
Qiao Longfei 已提交
60 61
};

D
dongzhihong 已提交
62
class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
63
 public:
64
  SGDOpMaker(OpProto* proto, OpAttrChecker* op_checker)
65
      : OpProtoAndCheckerMaker(proto, op_checker) {
66
    AddInput("Param", "(Tensor or SelectedRows) Input parameter");
67
    AddInput("LearningRate", "(Tensor) Learning rate of SGD");
68 69 70 71
    AddInput("Grad", "(Tensor or SelectedRows) Input gradient");
    AddOutput("ParamOut",
              "(Tensor or SelectedRows, same with Param) "
              "Output parameter, should share the same memory with Param");
Q
Qiao Longfei 已提交
72 73
    AddComment(R"DOC(

74
SGD operator
Q
Qiao Longfei 已提交
75

76 77
This operator implements one step of the stochastic gradient descent algorithm.

78
$$param\_out = param - learning\_rate * grad$$
Q
Qiao Longfei 已提交
79 80 81 82

)DOC");
  }
};
Q
qijun 已提交
83

Q
Qiao Longfei 已提交
84 85 86
}  // namespace operators
}  // namespace paddle

D
dongzhihong 已提交
87
namespace ops = paddle::operators;
F
fengjiayi 已提交
88
REGISTER_OP_WITHOUT_GRADIENT(sgd, ops::SGDOp, ops::SGDOpMaker);
C
chengduoZH 已提交
89
REGISTER_OP_CPU_KERNEL(sgd, ops::SGDOpKernel<float>, ops::SGDOpKernel<double>);