/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/lookup_table_op.h"
#include "paddle/fluid/platform/assert.h"
D
dzhwinter 已提交
19
#include "paddle/fluid/platform/cuda_primitives.h"
20
#include "paddle/fluid/platform/float16.h"
21 22 23 24

namespace paddle {
namespace operators {

25 26
// Gathers rows of `table` into `output`: for every k in [0, K), row k of
// `output` (D elements) receives row ids[k] of `table`.
//
// Launch layout (the template parameters must match the actual launch):
//   * blockDim = (BlockDimX, BlockDimY), gridDim.x = GridDimX; no shared
//     memory is used.
//   * threadIdx.x walks the D columns of a row in steps of BlockDimX.
//   * (blockIdx.x, threadIdx.y) selects a starting row; rows are then visited
//     in steps of BlockDimY * GridDimX until K is reached.
//
// Template parameters:
//   PaddingFlag: when true, a row whose id equals `padding_idx` is written as
//     zeros instead of being copied from `table`.
//
// Args:
//   output:      device buffer of K * D elements (written).
//   table:       device buffer of N * D elements (read).
//   ids:         device buffer of K row ids. Every id, including one equal to
//                `padding_idx`, is asserted to lie in [0, N).
//   N, K, D:     table rows, number of ids, row width.
//   padding_idx: id mapped to a zero row; only consulted when PaddingFlag.
template <typename T, int BlockDimX, int BlockDimY, int GridDimX,
          bool PaddingFlag>
__global__ void LookupTable(T *output, const T *table, const int64_t *ids,
                            const int64_t N, const int64_t K, const int64_t D,
                            const int64_t padding_idx) {
  const int idx = threadIdx.x;
  // The row cursor is 64-bit because K is int64_t. With a 32-bit cursor,
  // stepping past K close to INT_MAX would be signed overflow.
  int64_t idy = blockIdx.x + threadIdx.y * GridDimX;
  const int64_t row_stride = static_cast<int64_t>(BlockDimY) * GridDimX;

  while (idy < K) {
    int64_t id = ids[idy];
    PADDLE_ASSERT_MSG(id >= 0, "received id:", id);
    PADDLE_ASSERT_MSG(id < N, "received id:", id);
    T *out = output + idy * D;
    const T *tab = table + id * D;
    for (int i = idx; i < D; i += BlockDimX) {
      if (PaddingFlag) {
        if (id == padding_idx)
          out[i] = static_cast<T>(0);
        else
          out[i] = tab[i];
      } else {
        out[i] = tab[i];
      }
    }
    idy += row_stride;
  }
}

53
// Scatter-adds rows of `output` (the gradient of the lookup result) into
// `table` (the gradient of the embedding table): for every k in [0, K),
// row ids[k] of `table` accumulates row k of `output`.
//
// Launch layout (the template parameters must match the actual launch):
//   * blockDim = (BlockDimX, BlockDimY), gridDim.x = GridDimX; no shared
//     memory is used.
//   * threadIdx.x walks the D columns of a row in steps of BlockDimX.
//   * (blockIdx.x, threadIdx.y) selects a starting row; rows are then visited
//     in steps of BlockDimY * GridDimX until K is reached.
//
// Several k may share the same id, so accumulation uses CudaAtomicAdd. The
// kernel only adds, so `table` must be initialised by the caller (typically
// zero-filled). Padding ids are not special-cased here.
//
// Args:
//   table:   device buffer of N * D gradient elements (read-modify-write).
//   output:  device buffer of K * D incoming gradient elements (read).
//   ids:     device buffer of K row ids, each asserted to lie in [0, N).
//   N, K, D: table rows, number of ids, row width.
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTableGrad(T *table, const T *output, const int64_t *ids,
                                const int64_t N, const int64_t K,
                                const int64_t D) {
  const int idx = threadIdx.x;
  // The row cursor is 64-bit because K is int64_t. With a 32-bit cursor,
  // stepping past K close to INT_MAX would be signed overflow.
  int64_t idy = blockIdx.x + threadIdx.y * GridDimX;
  const int64_t row_stride = static_cast<int64_t>(BlockDimY) * GridDimX;

  while (idy < K) {
    int64_t id = ids[idy];
    PADDLE_ASSERT_MSG(id >= 0, "received id:", id);
    PADDLE_ASSERT_MSG(id < N, "received id:", id);
    const T *out = output + idy * D;
    T *tab = table + id * D;
    for (int i = idx; i < D; i += BlockDimX) {
      paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
    }
    idy += row_stride;
  }
}

template <typename T>
class LookupTableCUDAKernel : public framework::OpKernel<T> {
 public:
  // Forward embedding lookup on the GPU.
  //
  // For each k in [0, numel(Ids)), row k of Out (W.dims()[1] elements)
  // receives row Ids[k] of W. If the "padding_idx" attribute is not -1, rows
  // whose id equals it are written as zeros instead.
  void Compute(const framework::ExecutionContext &context) const override {
    auto *table_t = context.Input<LoDTensor>("W");
    auto *ids_t = context.Input<LoDTensor>("Ids");
    auto *output_t = context.Output<LoDTensor>("Out");
    int64_t padding_idx = context.Attr<int64_t>("padding_idx");

    size_t N = table_t->dims()[0];
    size_t D = table_t->dims()[1];
    size_t K = ids_t->numel();

    auto *ids = ids_t->data<int64_t>();
    auto *table = table_t->data<T>();
    auto *output = output_t->mutable_data<T>(context.GetPlace());

    // The kernel strides by its template arguments rather than by
    // blockDim/gridDim. Both the launch shape and the template arguments
    // therefore come from these constants, so they cannot drift apart.
    constexpr int kBlockDimX = 128;
    constexpr int kBlockDimY = 8;
    constexpr int kGridDimX = 8;
    dim3 threads(kBlockDimX, kBlockDimY);
    dim3 grids(kGridDimX, 1);
    auto stream = context.cuda_device_context().stream();

    if (padding_idx == -1) {
      LookupTable<T, kBlockDimX, kBlockDimY, kGridDimX,
                  false><<<grids, threads, 0, stream>>>(
          output, table, ids, N, K, D, padding_idx);
    } else {
      LookupTable<T, kBlockDimX, kBlockDimY, kGridDimX,
                  true><<<grids, threads, 0, stream>>>(
          output, table, ids, N, K, D, padding_idx);
    }
  }
};

template <typename T>
class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
 public:
  // Computes W@GRAD for the GPU embedding lookup.
  //
  // * "is_sparse" == true: W@GRAD is a SelectedRows.
  //   - Its rows are a copy of Ids.
  //   - Its value has shape (numel(Ids), W.dims()[1]) and is filled with a
  //     flat copy of Out@GRAD. PADDLE_ENFORCE_EQ checks that shape against
  //     Out@GRAD flattened to 2-D.
  // * otherwise: W@GRAD is a dense tensor. It is zero-filled, then
  //   LookupTableGrad adds row k of Out@GRAD into row Ids[k].
  void Compute(const framework::ExecutionContext &context) const override {
    auto &dev_ctx =
        context.template device_context<platform::CUDADeviceContext>();
    bool is_sparse = context.Attr<bool>("is_sparse");

    // Since paddings are not trainable and fixed in forward, the gradient of
    // paddings makes no sense and we don't deal with it in backward.
    if (is_sparse) {
      auto *ids = context.Input<LoDTensor>("Ids");
      auto *table = context.Input<LoDTensor>("W");
      auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
      auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));

      auto *ids_data = ids->data<int64_t>();
      int64_t ids_num = ids->numel();

      auto stream = dev_ctx.stream();
      // Copy the ids (device memory) into the buffer obtained from
      // new_rows.CUDAMutableData(). Both ends of this memory::Copy are
      // `gpu_place`, so this is a device-to-device copy, not a copy to host.
      framework::Vector<int64_t> new_rows;
      new_rows.resize(ids_num);
      auto gpu_place = boost::get<platform::CUDAPlace>(context.GetPlace());

      // TODO(yuyang18): Strange code here.
      memory::Copy(gpu_place, new_rows.CUDAMutableData(context.GetPlace()),
                   gpu_place, ids_data, ids_num * sizeof(int64_t), stream);
      d_table->set_rows(new_rows);

      auto *d_table_value = d_table->mutable_value();
      d_table_value->Resize({ids_num, table->dims()[1]});
      d_table_value->mutable_data<T>(context.GetPlace());

      auto *d_table_data = d_table_value->data<T>();
      auto *d_output_data = d_output->data<T>();
      auto d_output_dims = d_output->dims();
      // The value block is filled by the flat copy below, so its shape must
      // match Out@GRAD flattened to 2-D.
      PADDLE_ENFORCE_EQ(
          d_table_value->dims(),
          framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
      memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data,
                   d_output->numel() * sizeof(T), stream);
    } else {
      auto *ids_t = context.Input<LoDTensor>("Ids");
      auto *d_output_t =
          context.Input<LoDTensor>(framework::GradVarName("Out"));
      auto *d_table_t = context.Output<LoDTensor>(framework::GradVarName("W"));

      // The kernel takes int64_t sizes. Narrowing them to int, as before,
      // could silently truncate very large tensors.
      int64_t N = d_table_t->dims()[0];
      int64_t D = d_table_t->dims()[1];
      int64_t K = ids_t->numel();
      const int64_t *ids = ids_t->data<int64_t>();
      const T *d_output = d_output_t->data<T>();
      T *d_table = d_table_t->mutable_data<T>(context.GetPlace());

      // LookupTableGrad only accumulates, so start from an all-zero gradient.
      auto t = framework::EigenVector<T>::Flatten(*d_table_t);
      t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));

      // The launch shape must match the kernel's template arguments
      // <BlockDimX=128, BlockDimY=8, GridDimX=8>.
      dim3 threads(128, 8);
      dim3 grids(8, 1);
      LookupTableGrad<T, 128, 8, 8><<<grids, threads, 0, dev_ctx.stream()>>>(
          d_table, d_output, ids, N, K, D);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// GPU kernels for the forward lookup, one per supported element type.
REGISTER_OP_CUDA_KERNEL(
    lookup_table,
    ops::LookupTableCUDAKernel<float>,
    ops::LookupTableCUDAKernel<double>,
    ops::LookupTableCUDAKernel<plat::float16>);

// GPU kernels for the gradient op, covering the same element types.
REGISTER_OP_CUDA_KERNEL(
    lookup_table_grad,
    ops::LookupTableGradCUDAKernel<float>,
    ops::LookupTableGradCUDAKernel<double>,
    ops::LookupTableGradCUDAKernel<plat::float16>);