/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/lookup_table_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

// Gathers rows of `table` into `output`: for every k in [0, K) the D-element
// row table[ids[k]] is copied into output[k] (or zero-filled, see below).
//
// Launch layout is hard-wired through the template parameters rather than
// read from blockDim/gridDim: the kernel must be launched with
// blockDim == (BlockDimX, BlockDimY) and gridDim.x == GridDimX. Each
// (blockIdx.x, threadIdx.y) pair handles one output row at a time and then
// advances by BlockDimY * GridDimX rows; the BlockDimX threads along x stride
// over the D columns of that row.
//
// Parameters:
//   output      destination, K rows of D elements each (row-major).
//   table       source table, N rows of D elements each (row-major).
//   ids         K row indices into `table`; each must lie in [0, N), which is
//               enforced per element with PADDLE_ENFORCE.
//   N, K, D     table rows, number of ids, row width.
//   padding_idx only consulted when PaddingFlag is true: output rows whose id
//               equals padding_idx are written as zeros instead of copied.
template <typename T, int BlockDimX, int BlockDimY, int GridDimX,
          bool PaddingFlag>
__global__ void LookupTable(T *output, const T *table, const int64_t *ids,
                            const int64_t N, const int64_t K, const int64_t D,
                            const int64_t padding_idx) {
  int idx = threadIdx.x;
  // 64-bit row counter: K is int64_t, and a 32-bit counter would overflow
  // (undefined behaviour) before reaching K once K exceeds INT_MAX.
  int64_t idy = blockIdx.x + threadIdx.y * GridDimX;

  while (idy < K) {
    int64_t id = ids[idy];
    PADDLE_ENFORCE(
        id >= 0,
        "Variable value (input) of OP(fluid.layers.embedding) "
        "expected >= 0 and < %ld, but got %ld. Please check input value.",
        N, id);
    PADDLE_ENFORCE(
        id < N,
        "Variable value (input) of OP(fluid.layers.embedding) "
        "expected >= 0 and < %ld, but got %ld. Please check input value.",
        N, id);
    T *out = output + idy * D;
    const T *tab = table + id * D;
    for (int i = idx; i < D; i += BlockDimX) {
      // PaddingFlag is a compile-time constant, so the non-padding
      // instantiation reduces to a plain copy.
      if (PaddingFlag && id == padding_idx) {
        out[i] = static_cast<T>(0);
      } else {
        out[i] = tab[i];
      }
    }
    idy += BlockDimY * GridDimX;
  }
}

// Scatter-adds output-gradient rows into the table gradient: for every k in
// [0, K) the D-element row output[k] is added element-wise into the row
// table[ids[k]]. Several k may carry the same id, so different threads can
// update the same table row concurrently; CudaAtomicAdd makes those
// accumulations safe. `table` is only accumulated into, never cleared here.
//
// Launch layout matches LookupTable: blockDim must be (BlockDimX, BlockDimY)
// and gridDim.x must be GridDimX, because the row/column strides use the
// template constants rather than blockDim/gridDim.
//
// Parameters:
//   table   gradient buffer being accumulated into, N rows of D elements.
//   output  incoming gradient, K rows of D elements.
//   ids     K row indices into `table`; each must lie in [0, N).
//   N, K, D table rows, number of ids, row width.
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTableGrad(T *table, const T *output, const int64_t *ids,
                                const int64_t N, const int64_t K,
                                const int64_t D) {
  int idx = threadIdx.x;
  // 64-bit row counter so the loop stays well-defined when K > INT_MAX.
  int64_t idy = blockIdx.x + threadIdx.y * GridDimX;

  while (idy < K) {
    int64_t id = ids[idy];
    PADDLE_ENFORCE(
        id >= 0,
        "Variable value (input) of OP(fluid.layers.embedding) "
        "expected >= 0 and < %ld, but got %ld. Please check input value.",
        N, id);
    PADDLE_ENFORCE(
        id < N,
        "Variable value (input) of OP(fluid.layers.embedding) "
        "expected >= 0 and < %ld, but got %ld. Please check input value.",
        N, id);
    const T *out = output + idy * D;
    T *tab = table + id * D;
    for (int i = idx; i < D; i += BlockDimX) {
      paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
    }
    idy += BlockDimY * GridDimX;
  }
}

template <typename T>
class LookupTableCUDAKernel : public framework::OpKernel<T> {
 public:
  // Forward embedding lookup: Out row k = W row Ids[k]. Rows whose id equals
  // the `padding_idx` attribute are zero-filled, unless padding_idx is -1,
  // which selects the kernel instantiation without padding handling.
  void Compute(const framework::ExecutionContext &context) const override {
    auto *table_t = context.Input<LoDTensor>("W");
    auto *ids_t = context.Input<LoDTensor>("Ids");
    auto *output_t = context.Output<LoDTensor>("Out");
    int64_t padding_idx = context.Attr<int64_t>("padding_idx");

    // Held as int64_t to match the kernel's parameter types exactly.
    const int64_t N = table_t->dims()[0];
    const int64_t D = table_t->dims()[1];
    const int64_t K = ids_t->numel();

    auto *ids = ids_t->data<int64_t>();
    auto *table = table_t->data<T>();
    auto *output = output_t->mutable_data<T>(context.GetPlace());

    // The kernel strides by its template constants, so the launch shape and
    // the template arguments must come from the same values.
    constexpr int kBlockDimX = 128;
    constexpr int kBlockDimY = 8;
    constexpr int kGridDimX = 8;
    dim3 threads(kBlockDimX, kBlockDimY);
    dim3 grids(kGridDimX, 1);
    auto stream = context.cuda_device_context().stream();

    if (padding_idx == -1) {
      LookupTable<T, kBlockDimX, kBlockDimY, kGridDimX,
                  false><<<grids, threads, 0, stream>>>(
          output, table, ids, N, K, D, padding_idx);
    } else {
      LookupTable<T, kBlockDimX, kBlockDimY, kGridDimX,
                  true><<<grids, threads, 0, stream>>>(
          output, table, ids, N, K, D, padding_idx);
    }
  }
};

template <typename T>
class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
 public:
  // Backward of the embedding lookup.
  //  - is_sparse == true: d(W) is emitted as SelectedRows whose rows are the
  //    ids (one entry per id, duplicates kept) and whose value tensor holds
  //    the matching rows of d(Out), copied as-is.
  //  - is_sparse == false: d(W) is a dense tensor that is zeroed and then
  //    scatter-added into by the LookupTableGrad kernel.
  void Compute(const framework::ExecutionContext &context) const override {
    auto &dev_ctx =
        context.template device_context<platform::CUDADeviceContext>();
    bool is_sparse = context.Attr<bool>("is_sparse");

    // Since paddings are not trainable and fixed in forward, the gradient of
    // paddings makes no sense and we don't deal with it in backward.
    if (is_sparse) {
      auto *ids = context.Input<LoDTensor>("Ids");
      auto *table = context.Input<LoDTensor>("W");
      auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
      auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));

      auto *ids_data = ids->data<int64_t>();
      int64_t ids_num = ids->numel();

      auto stream = dev_ctx.stream();
      // The row indices of d(W) are exactly the ids. Both source and
      // destination of the copy below are on gpu_place (CUDAMutableData
      // presumably exposes new_rows' device-side storage), so this is a
      // device-to-device copy issued on `stream`, not a copy to host memory.
      framework::Vector<int64_t> new_rows;
      new_rows.resize(ids_num);
      auto gpu_place = boost::get<platform::CUDAPlace>(context.GetPlace());

      // TODO(yuyang18): Strange code here.
      memory::Copy(gpu_place, new_rows.CUDAMutableData(context.GetPlace()),
                   gpu_place, ids_data, ids_num * sizeof(int64_t), stream);
      d_table->set_rows(new_rows);

      auto *d_table_value = d_table->mutable_value();
      d_table_value->Resize({ids_num, table->dims()[1]});
      auto *d_table_data = d_table_value->mutable_data<T>(context.GetPlace());

      auto *d_output_data = d_output->data<T>();
      auto d_output_dims = d_output->dims();
      // d(Out) flattened to 2-D must have exactly the value tensor's shape,
      // so one flat copy places every gradient row next to its id.
      PADDLE_ENFORCE_EQ(
          d_table_value->dims(),
          framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
      memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data,
                   d_output->numel() * sizeof(T), stream);
    } else {
      auto ids_t = context.Input<LoDTensor>("Ids");
      auto d_output_t = context.Input<LoDTensor>(framework::GradVarName("Out"));
      auto d_table_t = context.Output<LoDTensor>(framework::GradVarName("W"));

      // int64_t rather than int: these feed the kernel's int64_t parameters,
      // and narrowing would silently truncate large tables or id counts.
      const int64_t N = d_table_t->dims()[0];
      const int64_t D = d_table_t->dims()[1];
      const int64_t K = ids_t->numel();
      const int64_t *ids = ids_t->data<int64_t>();
      const T *d_output = d_output_t->data<T>();
      T *d_table = d_table_t->mutable_data<T>(context.GetPlace());

      // LookupTableGrad only accumulates, so start from an all-zero gradient.
      auto t = framework::EigenVector<T>::Flatten(*d_table_t);
      t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));

      // Launch shape must match the kernel's template arguments.
      constexpr int kBlockDimX = 128;
      constexpr int kBlockDimY = 8;
      constexpr int kGridDimX = 8;
      dim3 threads(kBlockDimX, kBlockDimY);
      dim3 grids(kGridDimX, 1);
      LookupTableGrad<T, kBlockDimX, kBlockDimY,
                      kGridDimX><<<grids, threads, 0, dev_ctx.stream()>>>(
          d_table, d_output, ids, N, K, D);
    }
  }
};

}  // namespace operators
}  // namespace paddle

// Short aliases used by the registration macros below.
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Register the forward and backward CUDA kernels of the embedding lookup for
// the float, double and float16 element types.
REGISTER_OP_CUDA_KERNEL(
    lookup_table,
    ops::LookupTableCUDAKernel<float>,
    ops::LookupTableCUDAKernel<double>,
    ops::LookupTableCUDAKernel<plat::float16>);

REGISTER_OP_CUDA_KERNEL(
    lookup_table_grad,
    ops::LookupTableGradCUDAKernel<float>,
    ops::LookupTableGradCUDAKernel<double>,
    ops::LookupTableGradCUDAKernel<plat::float16>);