sgd_op.h 2.8 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
Y
Yi Wang 已提交
16 17 18
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
Q
Qiao Longfei 已提交
19 20 21 22

namespace paddle {
namespace operators {

C
chengduoZH 已提交
23
template <typename T>
Y
Yu Yang 已提交
24
class SGDOpKernel : public framework::OpKernel<T> {
25
 public:
D
dongzhihong 已提交
26
  void Compute(const framework::ExecutionContext& ctx) const override {
Q
qijun 已提交
27 28 29
    auto* param = ctx.Input<framework::Tensor>("Param");
    auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
    auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
Q
Qiao Longfei 已提交
30

Q
qijun 已提交
31
    auto* grad_var = ctx.InputVar("Grad");
Q
qijun 已提交
32 33
    // Actually, all tensors are LoDTensor except SelectedRows.
    if (grad_var->IsType<framework::LoDTensor>()) {
Q
qijun 已提交
34 35
      param_out->mutable_data<T>(ctx.GetPlace());
      auto* grad = ctx.Input<framework::Tensor>("Grad");
Q
Qiao Longfei 已提交
36

Q
qijun 已提交
37 38 39
      auto p = framework::EigenVector<T>::Flatten(*param);
      auto g = framework::EigenVector<T>::Flatten(*grad);
      auto o = framework::EigenVector<T>::Flatten(*param_out);
C
chengduoZH 已提交
40
      auto* lr = learning_rate->data<T>();
L
liaogang 已提交
41

C
chengduoZH 已提交
42
      o = p - lr[0] * g;
Q
qijun 已提交
43 44 45 46 47 48
    } else if (grad_var->IsType<framework::SelectedRows>()) {
      // TODO(qijun): In Sparse SGD operator, in-place update is enforced.
      // This manual optimization brings difficulty to track data dependency.
      // It's better to find a more elegant solution.
      PADDLE_ENFORCE_EQ(param, param_out);
      auto* grad = ctx.Input<framework::SelectedRows>("Grad");
C
chengduoZH 已提交
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69

      auto in_height = grad->height();
      auto out_dims = param_out->dims();
      PADDLE_ENFORCE_EQ(in_height, out_dims[0]);

      auto& in_value = grad->value();
      auto& in_rows = grad->rows();

      int64_t in_row_numel = in_value.numel() / in_rows.size();
      PADDLE_ENFORCE_EQ(in_row_numel, param_out->numel() / in_height);

      auto* in_data = in_value.data<T>();
      auto* out_data = param_out->data<T>();
      auto* lr = learning_rate->data<T>();

      for (size_t i = 0; i < in_rows.size(); i++) {
        for (int64_t j = 0; j < in_row_numel; j++) {
          out_data[in_rows[i] * in_row_numel + j] -=
              lr[0] * in_data[i * in_row_numel + j];
        }
      }
Q
qijun 已提交
70 71 72
    } else {
      PADDLE_THROW("Unsupported Variable Type of Grad");
    }
Q
Qiao Longfei 已提交
73 74 75 76
  }
};
}  // namespace operators
}  // namespace paddle