sgd_op.cu 2.8 KB
Newer Older
L
liaogang 已提交
1 2
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
L
liaogang 已提交
14

Q
qijun 已提交
15
#define EIGEN_USE_GPU
Q
Qiao Longfei 已提交
16
#include "paddle/operators/sgd_op.h"
Q
qijun 已提交
17 18 19 20 21 22
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {

namespace {

// Scatter-subtracts a sparse gradient into a dense parameter tensor:
//   tensor_out[rows[r], :] -= learning_rate[0] * selected_rows[r, :]
// for every r in [0, row_count).
//
// Launch layout: 1-D grid, 1-D block with blockDim.x == block_size.
// Blocks walk the rows with a stride of gridDim.x, so any grid size >= 1 is
// correct; the host side caps the grid to stay within the gridDim.x limit.
// The threads of a block walk the row_numel columns of a row with a stride of
// block_size. No shared memory is used.
//
// selected_rows: row_count consecutive rows of row_numel values each.
// rows:          row_count destination row indices into tensor_out; entries
//                may repeat.
// learning_rate: only element 0 is read.
template <typename T, int block_size>
__global__ void SparseSGDFunctorKernel(const T* selected_rows,
                                       const int64_t* rows,
                                       const T* learning_rate, T* tensor_out,
                                       int64_t row_numel, int64_t row_count) {
  // Negate once in T. A bare `-1.0` literal would promote the whole update to
  // double arithmetic when T is float.
  const T neg_lr = -learning_rate[0];

  for (int64_t r = blockIdx.x; r < row_count; r += gridDim.x) {
    const T* src = selected_rows + r * row_numel;
    T* dst = tensor_out + rows[r] * row_numel;
    // 64-bit column index: row_numel is int64_t and may exceed INT_MAX.
    for (int64_t col = threadIdx.x; col < row_numel; col += block_size) {
      // Indices in rows can repeat, so several blocks may update the same
      // output element concurrently; an atomic add avoids lost updates.
      paddle::platform::CudaAtomicAdd(dst + col, neg_lr * src[col]);
    }
  }
}

}  // namespace

// CUDA implementation of the sparse SGD update:
//   output[input.rows()[r], :] -= learning_rate[0] * input.value()[r, :]
//
// Enforces that input.height() equals output's first dimension and that each
// row of input.value() has as many elements as a row of output. An input with
// no rows is a no-op. The kernel is enqueued on context.stream(); this call
// does not synchronize.
template <typename T>
struct SparseSGDFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::SelectedRows& input,
                  const framework::Tensor& learning_rate,
                  framework::Tensor* output) {
    auto in_height = input.height();
    auto out_dims = output->dims();
    PADDLE_ENFORCE_EQ(in_height, out_dims[0]);

    auto& in_value = input.value();
    auto& in_rows = input.rows();

    const int64_t row_count = static_cast<int64_t>(in_rows.size());
    // Nothing to update. Returning early also avoids the integer division by
    // zero below and an invalid zero-sized grid launch.
    if (row_count == 0) {
      return;
    }

    int64_t in_row_numel = in_value.numel() / row_count;
    PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height);

    auto* in_data = in_value.data<T>();
    auto* out_data = output->data<T>();

    constexpr int kBlockSize = 256;
    // Rows are distributed over gridDim.x, not gridDim.y: gridDim.y is capped
    // at 65535, which previously broke updates with more rows than that.
    // 65535 is also a valid gridDim.x on every compute capability; the kernel
    // strides over any remaining rows.
    constexpr int64_t kMaxGridX = 65535;
    const int64_t grid_x = row_count < kMaxGridX ? row_count : kMaxGridX;

    dim3 threads(kBlockSize, 1);
    dim3 grid(static_cast<unsigned int>(grid_x), 1);
    // NOTE(review): in_rows.data() is dereferenced on the device; confirm that
    // SelectedRows::rows() storage is device-accessible here.
    SparseSGDFunctorKernel<T, kBlockSize><<<grid, threads, 0,
                                           context.stream()>>>(
        in_data, in_rows.data(), learning_rate.data<T>(), out_data,
        in_row_numel, row_count);
  }
};

Q
QI JUN 已提交
71 72
// Explicitly instantiate the CUDA sparse-SGD functor for the element types
// registered below (double and float).
template struct SparseSGDFunctor<::paddle::platform::CUDADeviceContext, double>;
template struct SparseSGDFunctor<::paddle::platform::CUDADeviceContext, float>;
Q
qijun 已提交
73 74 75

}  // namespace operators
}  // namespace paddle
Q
Qiao Longfei 已提交
76

D
dongzhihong 已提交
77
namespace ops = paddle::operators;
// Device context shared by both registered sgd kernels.
using SgdCudaCtx = paddle::platform::CUDADeviceContext;

// Register the "sgd" operator's CUDA kernels for float and double.
REGISTER_OP_CUDA_KERNEL(sgd, ops::SGDOpKernel<SgdCudaCtx, float>,
                        ops::SGDOpKernel<SgdCudaCtx, double>);