/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/lookup_table_op.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

// Embedding gather: for every k in [0, K), copies the D-element row
// table[ids[k], :] into output[k, :]. When PaddingFlag is true, output rows
// whose id equals padding_idx are zero-filled instead of copied.
//
// Launch layout expected by the indexing below:
//   blockDim = (BlockDimX, BlockDimY), gridDim.x = GridDimX.
//   Each (blockIdx.x, threadIdx.y) pair handles one output row at a time and
//   advances by BlockDimY * GridDimX rows; threadIdx.x strides over the D
//   columns of that row in steps of BlockDimX. The template arguments must
//   match the actual launch configuration.
//
// Every id is checked with PADDLE_ASSERT_MSG to lie in [0, N).
template <typename T, int BlockDimX, int BlockDimY, int GridDimX,
          bool PaddingFlag>
__global__ void LookupTable(T *output, const T *table, const int64_t *ids,
                            const int64_t N, const int64_t K, const int64_t D,
                            const int64_t padding_idx) {
  // 64-bit counters: K and D are int64_t, and a 32-bit row counter would
  // overflow once K approaches 2^31.
  int64_t idx = threadIdx.x;
  int64_t idy = blockIdx.x + threadIdx.y * GridDimX;

  while (idy < K) {
    int64_t id = ids[idy];
    PADDLE_ASSERT_MSG(
        id >= 0,
        "Variable value (input) of OP(fluid.layers.embedding) "
        "expected >= 0 and < %ld, but got %ld. Please check input value.",
        N, id);
    PADDLE_ASSERT_MSG(
        id < N,
        "Variable value (input) of OP(fluid.layers.embedding) "
        "expected >= 0 and < %ld, but got %ld. Please check input value.",
        N, id);
    T *out = output + idy * D;
    const T *tab = table + id * D;
    // Whether this row is padding is identical for all of its columns, so
    // decide once per row rather than once per element.
    if (PaddingFlag && id == padding_idx) {
      for (int64_t i = idx; i < D; i += BlockDimX) {
        out[i] = static_cast<T>(0);
      }
    } else {
      for (int64_t i = idx; i < D; i += BlockDimX) {
        out[i] = tab[i];
      }
    }
    idy += BlockDimY * GridDimX;
  }
}

// Embedding gradient scatter-add: for every k in [0, K),
//   table[ids[k], :] += output[k, :]   (D elements per row).
// Several k may map to the same id, so accumulation uses CudaAtomicAdd. The
// kernel only adds; `table` must already hold the desired starting values
// (LookupTableGradCUDAKernel zero-fills it before launching).
//
// Launch layout expected by the indexing below:
//   blockDim = (BlockDimX, BlockDimY), gridDim.x = GridDimX.
//   Each (blockIdx.x, threadIdx.y) pair handles one row at a time and advances
//   by BlockDimY * GridDimX rows; threadIdx.x strides over the D columns.
//
// Every id is checked with PADDLE_ASSERT_MSG to lie in [0, N).
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTableGrad(T *table, const T *output, const int64_t *ids,
                                const int64_t N, const int64_t K,
                                const int64_t D) {
  // 64-bit counters: K and D are int64_t, and a 32-bit row counter would
  // overflow once K approaches 2^31.
  int64_t idx = threadIdx.x;
  int64_t idy = blockIdx.x + threadIdx.y * GridDimX;

  while (idy < K) {
    int64_t id = ids[idy];
    PADDLE_ASSERT_MSG(
        id >= 0,
        "Variable value (input) of OP(fluid.layers.embedding) "
        "expected >= 0 and < %ld, but got %ld. Please check input value.",
        N, id);
    PADDLE_ASSERT_MSG(
        id < N,
        "Variable value (input) of OP(fluid.layers.embedding) "
        "expected >= 0 and < %ld, but got %ld. Please check input value.",
        N, id);
    const T *out = output + idy * D;
    T *tab = table + id * D;
    for (int64_t i = idx; i < D; i += BlockDimX) {
      // Atomic because other rows k' with ids[k'] == id may update the same
      // table row concurrently.
      paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
    }
    idy += BlockDimY * GridDimX;
  }
}

template <typename T>
class LookupTableCUDAKernel : public framework::OpKernel<T> {
 public:
  // Forward of lookup_table: Out[k, :] = W[Ids[k], :] for every k in
  // [0, numel(Ids)). When the `padding_idx` attribute is not -1, rows whose id
  // equals padding_idx are zero-filled instead.
  void Compute(const framework::ExecutionContext &context) const override {
    auto *table_t = context.Input<LoDTensor>("W");
    auto *ids_t = context.Input<LoDTensor>("Ids");
    auto *output_t = context.Output<LoDTensor>("Out");
    int64_t padding_idx = context.Attr<int64_t>("padding_idx");

    // Held in 64 bits to match the kernel's int64_t parameters.
    const int64_t N = table_t->dims()[0];
    const int64_t D = table_t->dims()[1];
    const int64_t K = ids_t->numel();

    auto *ids = ids_t->data<int64_t>();
    auto *table = table_t->data<T>();
    auto *output = output_t->mutable_data<T>(context.GetPlace());

    // Launch geometry; it is also passed as the kernel's template arguments,
    // so both must stay in sync.
    constexpr int kBlockDimX = 128;
    constexpr int kBlockDimY = 8;
    constexpr int kGridDimX = 8;
    dim3 threads(kBlockDimX, kBlockDimY);
    dim3 grids(kGridDimX, 1);
    auto stream = context.cuda_device_context().stream();

    // padding_idx == -1 means "no padding"; pick the kernel variant at compile
    // time so the common path carries no padding check.
    if (padding_idx == -1) {
      LookupTable<T, kBlockDimX, kBlockDimY, kGridDimX, false>
          <<<grids, threads, 0, stream>>>(output, table, ids, N, K, D,
                                          padding_idx);
    } else {
      LookupTable<T, kBlockDimX, kBlockDimY, kGridDimX, true>
          <<<grids, threads, 0, stream>>>(output, table, ids, N, K, D,
                                          padding_idx);
    }
  }
};

template <typename T>
class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
 public:
  // Backward of lookup_table: produces the gradient of W from the gradient of
  // Out.
  //  * is_sparse == true: the gradient is a SelectedRows whose rows are a copy
  //    of Ids (duplicates kept) and whose value is Out's gradient flattened to
  //    [numel(Ids), W.dims()[1]]. This is a plain device-to-device copy with
  //    no accumulation.
  //  * is_sparse == false: the gradient is a dense LoDTensor. It is zeroed,
  //    then LookupTableGrad scatter-adds each row of Out's gradient into the
  //    row selected by the matching id.
  void Compute(const framework::ExecutionContext &context) const override {
    auto &dev_ctx =
        context.template device_context<platform::CUDADeviceContext>();
    bool is_sparse = context.Attr<bool>("is_sparse");

    // Since paddings are not trainable and fixed in forward, the gradient of
    // paddings makes no sense and we don't deal with it in backward.
    if (is_sparse) {
      auto *ids = context.Input<LoDTensor>("Ids");
      auto *table = context.Input<LoDTensor>("W");
      auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
      auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));

      auto *ids_data = ids->data<int64_t>();
      int64_t ids_num = ids->numel();

      auto stream = dev_ctx.stream();
      auto gpu_place = boost::get<platform::CUDAPlace>(context.GetPlace());

      // Copy the ids into the device-side buffer of the row-index vector.
      // Source and destination are both on gpu_place; this is a
      // device-to-device copy enqueued on `stream`.
      framework::Vector<int64_t> new_rows;
      new_rows.resize(ids_num);
      // TODO(yuyang18): Strange code here.
      memory::Copy(gpu_place, new_rows.CUDAMutableData(context.GetPlace()),
                   gpu_place, ids_data, ids_num * sizeof(int64_t), stream);
      d_table->set_rows(new_rows);
      // NOTE(review): the SelectedRows height is not set in this branch;
      // confirm it is set elsewhere to W's row count.

      auto *d_table_value = d_table->mutable_value();
      d_table_value->Resize({ids_num, table->dims()[1]});
      T *d_table_data = d_table_value->mutable_data<T>(context.GetPlace());
      const T *d_output_data = d_output->data<T>();

      // Out's gradient, flattened to 2-D, must have exactly the value's shape
      // so that the byte copy below is a one-to-one row mapping.
      auto d_output_dims = d_output->dims();
      PADDLE_ENFORCE_EQ(
          d_table_value->dims(),
          framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1));
      memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data,
                   d_output->numel() * sizeof(T), stream);
    } else {
      auto *ids_t = context.Input<LoDTensor>("Ids");
      auto *d_output_t = context.Input<LoDTensor>(framework::GradVarName("Out"));
      auto *d_table_t = context.Output<LoDTensor>(framework::GradVarName("W"));

      // Kept in 64 bits: the kernel takes int64_t sizes, and narrowing
      // numel()/dims() to int would silently truncate large inputs.
      const int64_t N = d_table_t->dims()[0];
      const int64_t D = d_table_t->dims()[1];
      const int64_t K = ids_t->numel();

      const int64_t *ids = ids_t->data<int64_t>();
      const T *d_output = d_output_t->data<T>();
      T *d_table = d_table_t->mutable_data<T>(context.GetPlace());

      // LookupTableGrad only accumulates, so start from zero.
      auto t = framework::EigenVector<T>::Flatten(*d_table_t);
      t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));

      // Launch geometry; it is also passed as the kernel's template
      // arguments, so both must stay in sync.
      constexpr int kBlockDimX = 128;
      constexpr int kBlockDimY = 8;
      constexpr int kGridDimX = 8;
      dim3 threads(kBlockDimX, kBlockDimY);
      dim3 grids(kGridDimX, 1);
      LookupTableGrad<T, kBlockDimX, kBlockDimY, kGridDimX>
          <<<grids, threads, 0, dev_ctx.stream()>>>(d_table, d_output, ids, N,
                                                    K, D);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Forward lookup_table kernels for float, double and float16 elements.
REGISTER_OP_CUDA_KERNEL(
    lookup_table,
    ops::LookupTableCUDAKernel<float>,
    ops::LookupTableCUDAKernel<double>,
    ops::LookupTableCUDAKernel<plat::float16>);

// Backward (gradient) kernels for the same element types.
REGISTER_OP_CUDA_KERNEL(
    lookup_table_grad,
    ops::LookupTableGradCUDAKernel<float>,
    ops::LookupTableGradCUDAKernel<double>,
    ops::LookupTableGradCUDAKernel<plat::float16>);