sgd_op.cu 5.1 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
L
liaogang 已提交
14

C
chengduo 已提交
15
#include <algorithm>
W
Wu Yi 已提交
16
#include "paddle/fluid/operators/optimizers/sgd_op.h"
D
dzhwinter 已提交
17
#include "paddle/fluid/platform/cuda_primitives.h"
Q
qijun 已提交
18 19 20 21 22

namespace paddle {
namespace operators {

namespace {
C
chengduoZH 已提交
23 24 25 26 27

// Dense SGD update: p_out[i] = p[i] - learning_rate[0] * g[i] for every
// i in [0, num).
//
// Launch: 1-D grid of 1-D blocks. The loop is grid-stride, so any grid size
// covers all `num` elements. Indices are 64-bit so tensors with more than
// INT_MAX elements are processed completely instead of being truncated.
//
// g, p, p_out: device pointers to `num` elements each. If p and p_out alias,
//   the update is still safe because each element is read and then written by
//   the same thread.
// learning_rate: device pointer; only element 0 is read.
template <typename T>
__global__ void SGDKernel(const T* g, const T* p, const T* learning_rate,
                          const int64_t num, T* p_out) {
  const T lr = learning_rate[0];
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < num; i += stride) {
    const T g_data = g[i];
    const T p_data = p[i];
    p_out[i] = p_data - lr * g_data;
  }
}

C
chengduo 已提交
35
// Sparse SGD update for a SelectedRows gradient. For every i in [0, limit)
// and every column c in [0, row_numel):
//   tensor_out[rows[i] * row_numel + c] -=
//       learning_rate[0] * selected_rows[i * row_numel + c]
//
// Launch: blocks stride over gradient rows (blockIdx.x / gridDim.x), and the
// threads of a block stride over the columns of that row
// (threadIdx.x / blockDim.x), so any 1-D launch configuration covers the input.
//
// selected_rows: device pointer to `limit` contiguous rows of `row_numel`
//   gradient values.
// rows: device pointer to `limit` destination row indices. They are NOT
//   bounds-checked here; the caller must ensure they address valid rows of
//   tensor_out.
// learning_rate: device pointer; only element 0 is read.
// tensor_out: updated in place with atomic adds.
template <typename T>
__global__ void SparseSGDFunctorKernel(const T* selected_rows,
                                       const int64_t* rows,
                                       const T* learning_rate, T* tensor_out,
                                       int64_t row_numel, int64_t limit) {
  // Load the learning rate once per thread. tensor_out is a plain T* that is
  // written atomically in the loop below, so the compiler cannot prove it does
  // not alias learning_rate and would otherwise re-read it for every element.
  // The expression is the same as before, so the result is unchanged.
  const T neg_lr = -static_cast<T>(1.0) * learning_rate[0];
  for (int64_t i = blockIdx.x; i < limit; i += gridDim.x) {
    const T* selected_rows_ptr = selected_rows + i * row_numel;
    T* tensor_out_ptr = tensor_out + rows[i] * row_numel;
    for (int64_t index = threadIdx.x; index < row_numel; index += blockDim.x) {
      // The row indices of a SelectedRows may contain duplicates, so several
      // blocks can update the same output element concurrently; an atomic add
      // prevents those updates from being lost.
      paddle::platform::CudaAtomicAdd(tensor_out_ptr + index,
                                      neg_lr * selected_rows_ptr[index]);
    }
  }
}
}  // namespace

// CUDA implementation of the SGD operator:
//   ParamOut = Param - LearningRate[0] * Grad
//
// Param must be a LoDTensor. Two gradient representations are supported:
//   * LoDTensor (dense): every element of Param is updated into ParamOut.
//   * SelectedRows (sparse): only the rows listed in Grad are updated, and the
//     update is applied in place, so ParamOut must be the same tensor as Param.
// Any other Grad type raises an error.
template <typename T>
class SGDOpKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* param_var = ctx.InputVar("Param");
    PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
                   "The Var(%s)'s type should be LoDTensor, "
                   "but the received is %s",
                   ctx.InputNames("Param").front(),
                   framework::ToTypeName(param_var->Type()));

    auto* param = ctx.Input<framework::Tensor>("Param");
    auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
    auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");

    auto* grad_var = ctx.InputVar("Grad");
    // Every dense tensor variable is a LoDTensor; SelectedRows is the only
    // sparse representation.
    if (grad_var->IsType<framework::LoDTensor>()) {
      DenseUpdate(ctx, *param, *learning_rate, param_out);
    } else if (grad_var->IsType<framework::SelectedRows>()) {
      SparseUpdate(ctx, param, *learning_rate, param_out);
    } else {
      PADDLE_THROW("Unsupported Variable Type of Grad");
    }
  }

 private:
  // Dense path: param_out[i] = param[i] - lr * grad[i] for every element.
  static void DenseUpdate(const framework::ExecutionContext& ctx,
                          const framework::Tensor& param,
                          const framework::Tensor& learning_rate,
                          framework::Tensor* param_out) {
    auto* grad = ctx.Input<framework::Tensor>("Grad");
    // The kernel reads param.numel() elements from grad, so a smaller grad
    // would be read out of bounds.
    PADDLE_ENFORCE_EQ(param.numel(), grad->numel(),
                      "Param and Grad must have the same number of elements.");

    auto* param_out_data = param_out->mutable_data<T>(ctx.GetPlace());

    const int64_t numel = param.numel();
    // A grid with zero blocks is an invalid launch configuration, and there is
    // nothing to update anyway.
    if (numel == 0) {
      return;
    }

    const int block = 512;
    const int grid = static_cast<int>((numel + block - 1) / block);
    SGDKernel<T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
        grad->data<T>(), param.data<T>(), learning_rate.data<T>(), numel,
        param_out_data);
  }

  // Sparse path: subtracts lr * grad from the Param rows selected by Grad,
  // writing directly into Param (ParamOut must alias Param).
  static void SparseUpdate(const framework::ExecutionContext& ctx,
                           const framework::Tensor* param,
                           const framework::Tensor& learning_rate,
                           framework::Tensor* param_out) {
    // TODO(qijun): In Sparse SGD operator, in-place update is enforced.
    // This manual optimization brings difficulty to track data dependency.
    // It's better to find a more elegant solution.
    PADDLE_ENFORCE_EQ(param, param_out);
    auto* grad = ctx.Input<framework::SelectedRows>("Grad");

    auto in_height = grad->height();
    auto out_dims = param_out->dims();
    PADDLE_ENFORCE_EQ(in_height, out_dims[0]);

    auto& in_value = grad->value();
    auto& in_rows = grad->rows();
    // With no selected rows there is nothing to update, and the row-width
    // computation below would divide by zero.
    if (in_rows.size() == 0) {
      return;
    }

    int64_t in_row_numel =
        in_value.numel() / static_cast<int64_t>(in_rows.size());
    PADDLE_ENFORCE_EQ(in_row_numel, param_out->numel() / in_height);

    auto* in_data = in_value.data<T>();
    auto* out_data = param_out->data<T>();

    // One block per gradient row (strided), capped by how many 256-thread
    // blocks the device can keep resident.
    const int kThreadsPerBlock = 256;
    int thread_x = kThreadsPerBlock;
    int max_threads = ctx.cuda_device_context().GetMaxPhysicalThreadCount();
    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);

    SparseSGDFunctorKernel<<<max_blocks, thread_x, 0,
                             ctx.cuda_device_context().stream()>>>(
        in_data, in_rows.CUDAData(ctx.GetPlace()), learning_rate.data<T>(),
        out_data, in_row_numel, in_rows.size());
  }
};
}  // namespace operators
}  // namespace paddle
Q
Qiao Longfei 已提交
127

D
dongzhihong 已提交
128
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Register the CUDA `sgd` kernel for float, double and float16 parameters.
REGISTER_OP_CUDA_KERNEL(
    sgd, ops::SGDOpKernel<plat::CUDADeviceContext, float>,
    ops::SGDOpKernel<plat::CUDADeviceContext, double>,
    ops::SGDOpKernel<plat::CUDADeviceContext, plat::float16>);