/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/lookup_table_op.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

// Forward embedding lookup: output[k, :] = table[ids[k], :] for k in [0, K).
//
// Template parameters:
//   BlockDimX   - expected to equal blockDim.x; the x-threads of a row stride
//                 over its D columns in steps of BlockDimX.
//   BlockDimY   - expected to equal blockDim.y.
//   GridDimX    - expected to equal gridDim.x.
//   PaddingFlag - when true, a row whose id equals padding_idx is written as
//                 zeros instead of being copied from the table.
//
// Work split: thread (x, y) of block b starts at row b + y * GridDimX and
// advances by BlockDimY * GridDimX, so each row in [0, K) is handled by
// exactly one (block, threadIdx.y) pair for the expected launch shape.
//
// Arguments:
//   output      - K x D row-major buffer that receives the gathered rows.
//   table       - N x D row-major embedding table.
//   ids         - K row indices into table; each must satisfy 0 <= id < N
//                 (checked with PADDLE_ASSERT).
//   padding_idx - only consulted when PaddingFlag is true.
template <typename T, int BlockDimX, int BlockDimY, int GridDimX,
          bool PaddingFlag>
__global__ void LookupTable(T *output, const T *table, const int64_t *ids,
                            const int64_t N, const int64_t K, const int64_t D,
                            const int64_t padding_idx) {
  // Use a 64-bit row index: K is int64_t, and a 32-bit index could overflow
  // on the stride increment when K approaches INT_MAX.
  const int64_t row_stride = static_cast<int64_t>(BlockDimY) * GridDimX;
  int64_t idy = blockIdx.x + static_cast<int64_t>(threadIdx.y) * GridDimX;

  while (idy < K) {
    int64_t id = ids[idy];
    PADDLE_ASSERT(id >= 0);
    PADDLE_ASSERT(id < N);

    T *out = output + idy * D;
    // Whether this row is the padding row depends only on id, so decide it
    // once per row instead of once per element.
    const bool zero_row = PaddingFlag && id == padding_idx;
    if (zero_row) {
      for (int64_t i = threadIdx.x; i < D; i += BlockDimX) {
        out[i] = static_cast<T>(0);
      }
    } else {
      const T *tab = table + id * D;
      for (int64_t i = threadIdx.x; i < D; i += BlockDimX) {
        out[i] = tab[i];
      }
    }
    idy += row_stride;
  }
}

// Dense backward of the lookup: scatters each row of the output gradient into
// the table-gradient row selected by its id:
//   table[ids[k], :] += output[k, :] for k in [0, K).
// Several k may share the same id, so accumulation uses atomic adds; the
// caller is responsible for initialising `table` (the dense grad kernel in
// this file zero-fills it before launching).
//
// Template parameters BlockDimX / BlockDimY / GridDimX are expected to equal
// blockDim.x / blockDim.y / gridDim.x; rows are distributed exactly as in the
// forward LookupTable kernel.
//
// Arguments:
//   table  - N x D row-major gradient buffer for the embedding table.
//   output - K x D row-major gradient of the lookup output.
//   ids    - K row indices; each must satisfy 0 <= id < N.
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTableGrad(T *table, const T *output, const int64_t *ids,
                                const int64_t N, const int64_t K,
                                const int64_t D) {
  const int64_t row_stride = static_cast<int64_t>(BlockDimY) * GridDimX;
  int64_t idy = blockIdx.x + static_cast<int64_t>(threadIdx.y) * GridDimX;

  while (idy < K) {
    // Keep the id in 64 bits: narrowing it to int could wrap an out-of-range
    // id into [0, N), letting it pass the asserts and corrupt another row.
    int64_t id = ids[idy];
    PADDLE_ASSERT(id >= 0);
    PADDLE_ASSERT(id < N);

    const T *out = output + idy * D;
    T *tab = table + id * D;
    for (int64_t i = threadIdx.x; i < D; i += BlockDimX) {
      paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
    }
    idy += row_stride;
  }
}

template <typename T>
class LookupTableCUDAKernel : public framework::OpKernel<T> {
 public:
  // Gathers the rows of the embedding table "W" selected by "Ids" into "Out"
  // on the GPU. A padding_idx attribute of -1 selects the kernel variant
  // without padding handling; any other value makes rows whose id equals
  // padding_idx come out as zeros.
  void Compute(const framework::ExecutionContext &context) const override {
    auto *table_t = context.Input<LoDTensor>("W");
    auto *ids_t = context.Input<LoDTensor>("Ids");
    auto *output_t = context.Output<LoDTensor>("Out");
    const int64_t padding_idx = context.Attr<int64_t>("padding_idx");

    const size_t N = table_t->dims()[0];  // rows in the table
    const size_t D = table_t->dims()[1];  // width of each row
    const size_t K = ids_t->numel();      // number of ids to look up

    const int64_t *ids = ids_t->data<int64_t>();
    const T *table = table_t->data<T>();
    T *output = output_t->mutable_data<T>(context.GetPlace());

    auto stream = context.cuda_device_context().stream();
    // Launch shape must agree with the <128, 8, 8> template arguments below.
    const dim3 threads(128, 8);
    const dim3 grids(8, 1);

    if (padding_idx != -1) {
      LookupTable<T, 128, 8, 8, true><<<grids, threads, 0, stream>>>(
          output, table, ids, N, K, D, padding_idx);
    } else {
      LookupTable<T, 128, 8, 8, false><<<grids, threads, 0, stream>>>(
          output, table, ids, N, K, D, padding_idx);
    }
  }
};

template <typename T>
class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
 public:
  // Computes the gradient of the embedding table "W" from the gradient of
  // "Out".
  //
  // is_sparse == true: the W gradient is a SelectedRows whose rows are a copy
  //   of "Ids" and whose value is a copy of d(Out) flattened to 2-D, so
  //   duplicate ids appear as separate rows.
  // is_sparse == false: the W gradient is a dense tensor (dims already set on
  //   the output) that is zero-filled and then accumulated row by row by the
  //   LookupTableGrad kernel.
  //
  // Since paddings are not trainable and fixed in forward, the gradient of
  // paddings makes no sense and we don't deal with it in backward.
  void Compute(const framework::ExecutionContext &context) const override {
    auto &dev_ctx =
        context.template device_context<platform::CUDADeviceContext>();
    bool is_sparse = context.Attr<bool>("is_sparse");

    if (is_sparse) {
      auto *ids = context.Input<LoDTensor>("Ids");
      auto *table = context.Input<LoDTensor>("W");
      auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
      auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));

      auto *ids_data = ids->data<int64_t>();
      int64_t ids_num = ids->numel();

      auto stream = dev_ctx.stream();
      auto gpu_place = boost::get<platform::CUDAPlace>(context.GetPlace());

      // The sparse gradient's row indices are exactly the ids. Copy them into
      // the device-side buffer of new_rows. CUDAMutableData(GPU place) hands
      // back that device buffer, so this is a GPU -> GPU copy and the
      // destination place must be gpu_place (not CPUPlace).
      framework::Vector<int64_t> new_rows;
      new_rows.resize(ids_num);
      memory::Copy(gpu_place, new_rows.CUDAMutableData(context.GetPlace()),
                   gpu_place, ids_data, ids_num * sizeof(int64_t), stream);
      d_table->set_rows(new_rows);

      // One gradient row per id, each as wide as a table row.
      auto *d_table_value = d_table->mutable_value();
      d_table_value->Resize({ids_num, table->dims()[1]});
      T *d_table_data = d_table_value->mutable_data<T>(context.GetPlace());
      auto *d_output_data = d_output->data<T>();

      // d(Out) flattened to 2-D must have exactly the shape of the value.
      auto d_output_dims = d_output->dims();
      PADDLE_ENFORCE_EQ(
          d_table_value->dims(),
          framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
      memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data,
                   d_output->numel() * sizeof(T), stream);
    } else {
      auto ids_t = context.Input<LoDTensor>("Ids");
      auto d_output_t = context.Input<LoDTensor>(framework::GradVarName("Out"));
      auto d_table_t = context.Output<LoDTensor>(framework::GradVarName("W"));

      // Keep sizes in 64 bits to match the kernel's int64_t parameters;
      // narrowing to int would truncate large id counts or tables.
      int64_t N = d_table_t->dims()[0];
      int64_t D = d_table_t->dims()[1];
      int64_t K = ids_t->numel();
      const int64_t *ids = ids_t->data<int64_t>();
      const T *d_output = d_output_t->data<T>();
      T *d_table = d_table_t->mutable_data<T>(context.GetPlace());

      // The kernel accumulates with atomic adds, so start from zero.
      auto t = framework::EigenVector<T>::Flatten(*d_table_t);
      t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));

      // Launch shape must agree with the <128, 8, 8> template arguments.
      dim3 threads(128, 8);
      dim3 grids(8, 1);
      LookupTableGrad<T, 128, 8, 8><<<grids, threads, 0, dev_ctx.stream()>>>(
          d_table, d_output, ids, N, K, D);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// CUDA kernels for the forward lookup and its gradient, instantiated for
// float and double.
REGISTER_OP_CUDA_KERNEL(
    lookup_table,
    ops::LookupTableCUDAKernel<float>,
    ops::LookupTableCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(
    lookup_table_grad,
    ops::LookupTableGradCUDAKernel<float>,
    ops::LookupTableGradCUDAKernel<double>);