/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/lookup_table_op.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {

21
// Forward lookup: for every position k in [0, K), copies row ids[k] of
// `table` (D elements) into row k of `output`.
//
// Work distribution: the (blockIdx.x, threadIdx.y) pair picks the position in
// `ids`, and threadIdx.x walks the D elements of that row. The strides come
// from the template constants BlockDimX / BlockDimY / GridDimX rather than
// from blockDim / gridDim, so the launch configuration is expected to match
// them.
//
//   output  destination buffer, indexed as K rows of D elements
//   table   source buffer, indexed as rows of D elements
//   ids     K row indices; each is asserted to lie in [0, N)
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTable(T* output, const T* table, const int64_t* ids,
                            const int64_t N, const int64_t K, const int64_t D) {
  for (int row = blockIdx.x + threadIdx.y * GridDimX; row < K;
       row += BlockDimY * GridDimX) {
    const int64_t id = ids[row];
    PADDLE_ASSERT(id >= 0);
    PADDLE_ASSERT(id < N);

    const T* src = table + id * D;
    T* dst = output + row * D;

    // Threads sharing the same row split its D elements between them.
    int col = threadIdx.x;
    while (col < D) {
      dst[col] = src[col];
      col += BlockDimX;
    }
  }
}

40
// Backward lookup: for every position k in [0, K), adds row k of `output`
// (D elements) into row ids[k] of `table`.
//
// The addition goes through CudaAtomicAdd because nothing prevents several
// entries of `ids` from naming the same table row, in which case different
// threads accumulate into the same addresses.
//
// Work distribution: the (blockIdx.x, threadIdx.y) pair picks the position in
// `ids`, and threadIdx.x walks the D elements of that row. The strides come
// from the template constants BlockDimX / BlockDimY / GridDimX rather than
// from blockDim / gridDim, so the launch configuration is expected to match
// them.
//
//   table   accumulation target, indexed as rows of D elements; this kernel
//           only adds to it, it does not clear it
//   output  source buffer, indexed as K rows of D elements
//   ids     K row indices; each is asserted to lie in [0, N)
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTableGrad(T* table, const T* output, const int64_t* ids,
                                const int64_t N, const int64_t K,
                                const int64_t D) {
  int idx = threadIdx.x;
  int idy = blockIdx.x + threadIdx.y * GridDimX;

  while (idy < K) {
    // Keep the full 64-bit id: narrowing it to int would wrap large values
    // before the range asserts and the row-offset computation see them.
    int64_t id = ids[idy];
    PADDLE_ASSERT(id >= 0);
    PADDLE_ASSERT(id < N);
    const T* out = output + idy * D;
    T* tab = table + id * D;
    for (int i = idx; i < D; i += BlockDimX) {
      paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
    }
    idy += BlockDimY * GridDimX;
  }
}

// GPU forward kernel for the lookup_table op: reads inputs "W" and "Ids",
// allocates output "Out" on the execution place, and launches LookupTable on
// the CUDA device context's stream.
template <typename T>
class LookupTableCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    // Launch geometry; the same values are passed as template arguments
    // because the kernel derives its strides from them.
    constexpr int kBlockDimX = 128;
    constexpr int kBlockDimY = 8;
    constexpr int kGridDimX = 8;

    auto* weight = context.Input<LoDTensor>("W");
    auto* indices = context.Input<LoDTensor>("Ids");
    auto* result = context.Output<LoDTensor>("Out");

    // Table is read as num_rows x row_width; one lookup per element of Ids.
    size_t num_rows = weight->dims()[0];
    size_t row_width = weight->dims()[1];
    size_t num_ids = indices->numel();

    auto* ids_ptr = indices->data<int64_t>();
    auto* weight_ptr = weight->data<T>();
    auto* result_ptr = result->mutable_data<T>(context.GetPlace());

    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                      context.device_context())
                      .stream();

    dim3 block(kBlockDimX, kBlockDimY);
    dim3 grid(kGridDimX, 1);
    LookupTable<T, kBlockDimX, kBlockDimY, kGridDimX><<<grid, block, 0, stream>>>(
        result_ptr, weight_ptr, ids_ptr, num_rows, num_ids, row_width);
  }
};

// GPU backward kernel for the lookup_table op.
//
// Behaviour depends on the "is_sparse" attribute:
//  * sparse: the gradient of "W" is a SelectedRows whose rows are the ids and
//    whose value tensor is a copy of the gradient of "Out";
//  * dense:  the gradient of "W" is a LoDTensor that is zero-filled and then
//    accumulated into by the LookupTableGrad kernel.
template <typename T>
class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    bool is_sparse = context.Attr<bool>("is_sparse");
    if (is_sparse) {
      auto* ids = context.Input<LoDTensor>("Ids");
      auto* table = context.Input<LoDTensor>("W");
      auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
      auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W"));

      auto* ids_data = ids->data<int64_t>();
      auto ids_dim = ids->dims();

      auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                        context.device_context())
                        .stream();
      // Copy the ids from GPU memory into a host-side row vector (described
      // by the original author as CPU pinned memory).
      framework::Vector<int64_t> new_rows;
      new_rows.resize(ids_dim[0]);
      auto gpu_place = boost::get<platform::GPUPlace>(context.GetPlace());

      // NOTE(review): this copy is issued on `stream` and `new_rows` is read
      // by set_rows() right after. Whether the data has landed by then
      // depends on memory::Copy's semantics, which are not visible in this
      // file -- confirm that no explicit stream synchronization is needed.
      memory::Copy(platform::CPUPlace(), new_rows.data(), gpu_place, ids_data,
                   ids_dim[0] * sizeof(int64_t), stream);

      d_table->set_rows(new_rows);

      // One value row per id, each as wide as a row of W.
      auto* d_table_value = d_table->mutable_value();
      d_table_value->Resize({ids_dim[0], table->dims()[1]});
      d_table_value->mutable_data<T>(context.GetPlace());

      auto* d_table_data = d_table_value->data<T>();
      auto* d_output_data = d_output->data<T>();
      PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
      // Device-to-device copy of the output gradient into the value tensor.
      memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data,
                   d_output->numel() * sizeof(T), stream);

    } else {
      auto ids_t = context.Input<LoDTensor>("Ids");
      auto d_output_t = context.Input<LoDTensor>(framework::GradVarName("Out"));
      auto d_table_t = context.Output<LoDTensor>(framework::GradVarName("W"));

      // Use 64-bit sizes: the kernel takes int64_t parameters, and storing
      // these in int would truncate large dimensions or element counts.
      const int64_t N = d_table_t->dims()[0];
      const int64_t D = d_table_t->dims()[1];
      const int64_t K = ids_t->numel();
      const int64_t* ids = ids_t->data<int64_t>();
      const T* d_output = d_output_t->data<T>();
      T* d_table = d_table_t->mutable_data<T>(context.GetPlace());

      // The kernel only accumulates, so the table gradient is zeroed first.
      auto t = framework::EigenVector<T>::Flatten(*d_table_t);
      t.device(context.GetEigenDevice<platform::GPUPlace>()) =
          t.constant(static_cast<T>(0));

      // The launch geometry must agree with the kernel's template arguments,
      // which it uses for its strides.
      dim3 threads(128, 8);
      dim3 grids(8, 1);
      LookupTableGrad<T, 128, 8,
                      8><<<grids, threads, 0,
                           reinterpret_cast<const platform::CUDADeviceContext&>(
                               context.device_context())
                               .stream()>>>(d_table, d_output, ids, N, K, D);
    }
  }
};

}  // namespace operators
}  // namespace paddle

// Short alias for the operator namespace, used by the kernel registrations.
namespace ops = paddle::operators;
152 153 154 155
// Register the GPU kernels for lookup_table and lookup_table_grad, each
// instantiated for float and double.
REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>,
                       ops::LookupTableCUDAKernel<double>);
REGISTER_OP_GPU_KERNEL(lookup_table_grad, ops::LookupTableGradCUDAKernel<float>,
                       ops::LookupTableGradCUDAKernel<double>);