/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/optimizers/sgd_op.h"

#include <string>

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"

namespace paddle {
namespace operators {

D
dongzhihong 已提交
29
class SGDOp : public framework::OperatorWithKernel {
Y
Yu Yang 已提交
30 31 32
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

33 34
 protected:
  framework::OpKernelType GetExpectedKernelType(
C
chengduo 已提交
35
      const framework::ExecutionContext &ctx) const override {
36
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param");
37

38 39 40
    // NOTE(jiahongyu): Below codes originally enclosed by PADDLE_WITH_MKLDNN
    const auto *param_var = ctx.InputVar("Param");
    const auto *grad_var = ctx.InputVar("Grad");
41

42 43 44 45 46 47 48
    // supported cases
    bool dense_param_sparse_grad = param_var->IsType<phi::DenseTensor>() &&
                                   grad_var->IsType<phi::SelectedRows>();
    bool dense_param_and_grad = param_var->IsType<phi::DenseTensor>() &&
                                grad_var->IsType<phi::DenseTensor>();
    if (!(dense_param_sparse_grad || dense_param_and_grad)) {
      this->SetDnnFallback(true);
49
    }
50 51
    // NOTE(jiahongyu): Above codes originally enclosed by PADDLE_WITH_MKLDNN

Q
qiaolongfei 已提交
52
    return framework::OpKernelType(data_type, ctx.device_context());
53
  }
54 55

  framework::OpKernelType GetKernelTypeForVar(
56
      const std::string &var_name,
57
      const phi::DenseTensor &tensor,
58
      const framework::OpKernelType &expected_kernel_type) const override {
59
    if (var_name == "LearningRate") {
60
      return framework::OpKernelType(
61 62
          framework::TransToProtoVarType(tensor.dtype()),
          tensor.place(),
63
          tensor.layout());
64
    }
65 66
    return framework::OpKernelType(
        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
67
  }
Q
Qiao Longfei 已提交
68 69
};

Y
Yancey1989 已提交
70 71
class SGDOpInferVarType : public framework::VarTypeInference {
 public:
M
minqiyang 已提交
72
  void operator()(framework::InferVarTypeContext *ctx) const override {
73 74 75
    auto in_var_type = ctx->GetInputType("Param");
    PADDLE_ENFORCE_EQ(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
                          in_var_type == framework::proto::VarType::LOD_TENSOR,
76 77 78 79 80
                      true,
                      platform::errors::InvalidArgument(
                          "The input Var's type should be LoDtensor or "
                          "SelectedRows, but the received type is %s",
                          in_var_type));
C
chengduo 已提交
81

82
    ctx->SetOutputType("ParamOut", in_var_type, framework::ALL_ELEMENTS);
Y
Yancey1989 已提交
83 84 85
  }
};

D
dongzhihong 已提交
86
class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
87
 public:
Y
Yu Yang 已提交
88
  void Make() override {
89
    AddInput("Param", "(Tensor or SelectedRows) Input parameter");
90
    AddInput("LearningRate", "(Tensor) Learning rate of SGD");
91
    AddInput("Grad", "(Tensor or SelectedRows) Input gradient");
92
    AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
93 94
    AddOutput("ParamOut",
              "(Tensor or SelectedRows, same with Param) "
95
              "Output parameter, should share the same memory with Param");
96 97 98 99 100
    AddOutput("MasterParamOut",
              "The updated FP32 master weight for AMP. "
              "It shared memory with Input(MasterParam).")
        .AsDispensable();

101 102 103 104
    AddAttr<bool>(
        "use_mkldnn",
        "(bool, default false) Indicates if MKL-DNN kernel will be used")
        .SetDefault(false);
105 106 107 108 109
    AddAttr<bool>("multi_precision",
                  "(bool, default false) "
                  "Whether to use multi-precision during weight updating.")
        .SetDefault(false);

Q
Qiao Longfei 已提交
110 111
    AddComment(R"DOC(

112
SGD operator
Q
Qiao Longfei 已提交
113

114 115
This operator implements one step of the stochastic gradient descent algorithm.

116
$$param\_out = param - learning\_rate * grad$$
Q
Qiao Longfei 已提交
117 118 119 120

)DOC");
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
126 127
DECLARE_INFER_SHAPE_FUNCTOR(sgd,
                            SGDInferShapeFunctor,
Z
zyfncg 已提交
128
                            PD_INFER_META(phi::SgdInferMeta));
H
hong 已提交
129
REGISTER_OPERATOR(
130 131 132
    sgd,
    ops::SGDOp,
    ops::SGDOpMaker,
H
hong 已提交
133 134
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
135 136
    ops::SGDOpInferVarType,
    SGDInferShapeFunctor);