/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/jit/kernels.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class SGDOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override;
};
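
// The CPU specialization below covers two parameter layouts: a dense
// LoDTensor parameter (with either a dense or a SelectedRows gradient),
// and a SelectedRows parameter that is updated row by row in place.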

template <typename T>
class SGDOpKernel<platform::CPUDeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");

    const auto *param_var = ctx.InputVar("Param");
    const auto *grad_var = ctx.InputVar("Grad");

    if (param_var->IsType<framework::LoDTensor>()) {
      const auto *param = ctx.Input<framework::Tensor>("Param");
      auto *param_out = ctx.Output<framework::Tensor>("ParamOut");
      // Actually, all tensors are LoDTensor except SelectedRows.
      if (grad_var->IsType<framework::LoDTensor>()) {
        const auto *grad = ctx.Input<framework::Tensor>("Grad");
        auto sz = param_out->numel();
        PADDLE_ENFORCE_EQ(param->numel(), sz,
                          platform::errors::InvalidArgument(
                              "The input tensor Param's numel of SgdOp "
                              "should be equal to ParamOut's numel. "
                              "But received Param's "
                              "numel = [%s], ParamOut's numel = [%s]",
                              param->numel(), sz));
        PADDLE_ENFORCE_EQ(grad->numel(), sz,
                          platform::errors::InvalidArgument(
                              "The input tensor Grad's numel of SgdOp "
                              "should be equal to ParamOut's numel. "
                              "But received Grad's "
                              "numel = [%s], ParamOut's numel = [%s]",
                              grad->numel(), sz));

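        // Dense path: view Param/Grad as a single row of length sz so the
        // row-wise JIT SGD kernel can be reused (param_height = grad_height
        // = 1, selected_rows_size = 1, with a single dummy row index of 0).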
        jit::sgd_attr_t attr(1, sz, 1, sz, 1);
        const T *lr = learning_rate->data<T>();
        const T *param_data = param->data<T>();
        const T *grad_data = grad->data<T>();
        int64_t rows_idx = 0;
        T *out_data = param_out->mutable_data<T>(ctx.GetPlace());

        auto sgd =
            jit::KernelFuncs<jit::SgdTuple<T>, platform::CPUPlace>::Cache().At(
                attr);
        sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr);
      } else if (grad_var->IsType<framework::SelectedRows>()) {
        // TODO(qijun): In Sparse SGD operator, in-place update is enforced.
        // This manual optimization brings difficulty to track data dependency.
        // It's better to find a more elegant solution.
        PADDLE_ENFORCE_EQ(param, param_out,
                          platform::errors::InvalidArgument(
                              "The input tensor Param of SgdOp "
                              "should be the same as the output ParamOut "
                              "when Grad's type is SelectedRows."));
        const auto *grad = ctx.Input<framework::SelectedRows>("Grad");
        auto &grad_rows = grad->rows();

        // for distributed training, a sparse var may be empty,
        // just skip updating.
        if (grad_rows.size() == 0) {
          return;
        }

        auto out_dims = param_out->dims();
        PADDLE_ENFORCE_EQ(
            grad->height(), out_dims[0],
            platform::errors::InvalidArgument(
                "The input tensor Grad's height of SgdOp "
                "should be equal to ParamOut's first dimension. But received "
                "Grad's height [%s] and ParamOut's dims[0] [%s]",
                grad->height(), out_dims[0]));
        auto &grad_value = grad->value();
        const T *param_data = param->data<T>();
        const T *grad_data = grad_value.data<T>();
        const T *lr = learning_rate->data<T>();
        const int64_t *rows_data = grad_rows.data();
        T *out_data = param_out->mutable_data<T>(ctx.GetPlace());

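        // Sparse-gradient path: grad.value() holds only the selected rows
        // (grad_height compressed rows of width grad_width), and rows_data
        // maps each compressed row to its destination row in ParamOut.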
        jit::sgd_attr_t attr;
        attr.param_height = out_dims[0];
        attr.param_width = param_out->numel() / attr.param_height;
        attr.grad_height = grad_rows.size();  // note: it is not grad->height()
        attr.grad_width = grad_value.numel() / attr.grad_height;
        attr.selected_rows_size = grad_rows.size();
        PADDLE_ENFORCE_EQ(
            attr.grad_width, attr.param_width,
            platform::errors::InvalidArgument(
                "The grad_value's row width of SgdOp "
                "should be equal to param_out's row width. But received "
                "grad_value's width [%s] and param_out's width [%s]",
                attr.grad_width, attr.param_width));

        auto sgd =
            jit::KernelFuncs<jit::SgdTuple<T>, platform::CPUPlace>::Cache().At(
                attr);
        sgd(lr, param_data, grad_data, rows_data, out_data, &attr);
      } else {
        PADDLE_ENFORCE_EQ(
            false, true,
            platform::errors::PermissionDenied(
                "Unsupported Variable Type of Grad in SgdOp. Expected "
                "LoDTensor or SelectedRows, but received [%s]",
                paddle::framework::ToTypeName(grad_var->Type())));
      }
    } else if (param_var->IsType<framework::SelectedRows>()) {
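      // Both Param and Grad are SelectedRows: update the parameter's value
      // tensor in place, one selected row at a time.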
      PADDLE_ENFORCE_EQ(grad_var->IsType<framework::SelectedRows>(), true,
                        platform::errors::InvalidArgument(
                            "When Param is SelectedRows, "
                            "Grad should also be SelectedRows."));
      const auto &param = param_var->Get<framework::SelectedRows>();
      auto *param_out = ctx.Output<framework::SelectedRows>("ParamOut");
      const auto &grad = grad_var->Get<framework::SelectedRows>();

      // for distributed training, a sparse var may be empty,
      // just skip updating.
      if (grad.rows().size() == 0) {
        return;
      }

      auto param_row_width = param.value().dims()[1];
      auto grad_row_width = grad.value().dims()[1];
      PADDLE_ENFORCE_EQ(
          param_row_width, grad_row_width,
          platform::errors::InvalidArgument(
              "The param_row in SgdOp should have the same width as grad_row. "
              "But received param_row's width is [%s], and grad_row's width is "
              "[%s]",
              param_row_width, grad_row_width));

      const auto *lr = learning_rate->data<T>();
      const auto *grad_data = grad.value().data<T>();
      auto *out_data = param_out->mutable_value()->data<T>();
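      // In-place scatter update: map each selected row's global id to its
      // local row index in ParamOut, then apply
      //   param_out[row] -= lr * grad[row]   (element-wise).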
      for (size_t i = 0; i < grad.rows().size(); i++) {
        int64_t id_index = param_out->AutoGrownIndex(grad.rows()[i], false);
        PADDLE_ENFORCE_GE(
            id_index, static_cast<int64_t>(0),
            platform::errors::InvalidArgument(
                "The id in SgdOp should be >= 0. But received id_index is [%s]",
                id_index));
        for (int64_t j = 0; j < grad_row_width; j++) {
          out_data[id_index * grad_row_width + j] -=
              lr[0] * grad_data[i * grad_row_width + j];
        }
      }
    } else {
      PADDLE_ENFORCE_EQ(
          false, true,
          platform::errors::PermissionDenied(
              "Unsupported Variable Type of Parameter in SgdOp. Expected "
              "LoDTensor or SelectedRows, but received [%s]",
              paddle::framework::ToTypeName(param_var->Type())));
    }
  }
};
}  // namespace operators
}  // namespace paddle
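
// A minimal sketch of a typical registration for this kernel (the real
// registration lives in sgd_op.cc; the types shown are illustrative):
//
//   namespace ops = paddle::operators;
//   REGISTER_OP_CPU_KERNEL(
//       sgd, ops::SGDOpKernel<paddle::platform::CPUDeviceContext, float>,
//       ops::SGDOpKernel<paddle::platform::CPUDeviceContext, double>);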