/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/lookup_table_op.h"
18
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
19
#include "paddle/fluid/platform/float16.h"
20 21 22 23

namespace paddle {
namespace operators {

24 25
// Forward gather kernel for lookup_table: for every position `row` in
// [0, K), copies row `ids[row]` of `table` into row `row` of `output`.
// Both tables are addressed as contiguous rows of D elements
// (`table + id * D`, `output + row * D`).
//
// When PaddingFlag is true, positions whose id equals `padding_idx` are
// written as zeros instead of being read from `table`.
//
// Index math assumes the launch uses gridDim.x == GridDimX,
// blockDim.x == BlockDimX and blockDim.y == BlockDimY: each
// (blockIdx.x, threadIdx.y) pair walks rows with a stride of
// BlockDimY * GridDimX, and the threadIdx.x lanes of that pair split the D
// columns of each row with a stride of BlockDimX.
//
// Ids outside [0, N) fail the PADDLE_ENFORCE checks below.
template <typename T, int BlockDimX, int BlockDimY, int GridDimX,
          bool PaddingFlag>
__global__ void LookupTable(T *output, const T *table, const int64_t *ids,
                            const int64_t N, const int64_t K, const int64_t D,
                            const int64_t padding_idx) {
  // 64-bit counters: K and D are int64_t, and a 32-bit row counter could
  // overflow (undefined behaviour) before `idy < K` terminates the loop.
  const int64_t idx = threadIdx.x;
  const int64_t row_stride = static_cast<int64_t>(BlockDimY) * GridDimX;
  int64_t idy =
      blockIdx.x + static_cast<int64_t>(threadIdx.y) * GridDimX;

  while (idy < K) {
    const int64_t id = ids[idy];
    PADDLE_ENFORCE(
        id >= 0,
        "Variable value (input) of OP(fluid.layers.embedding) "
        "expected >= 0 and < %ld, but got %ld. Please check input value.",
        N, id);
    PADDLE_ENFORCE(
        id < N,
        "Variable value (input) of OP(fluid.layers.embedding) "
        "expected >= 0 and < %ld, but got %ld. Please check input value.",
        N, id);
    T *out = output + idy * D;
    const T *tab = table + id * D;
    // PaddingFlag is a compile-time constant, so the non-padding
    // instantiation folds this to `false`. The ternary only evaluates the
    // selected operand, so padding rows never read `tab`.
    const bool is_padding = PaddingFlag && id == padding_idx;
    for (int64_t i = idx; i < D; i += BlockDimX) {
      out[i] = is_padding ? static_cast<T>(0) : tab[i];
    }
    idy += row_stride;
  }
}

60
// Dense backward kernel for lookup_table: for every position `row` in
// [0, K), atomically adds row `row` of `output` (the gradient of Out) into
// row `ids[row]` of `table` (the gradient of W). Both are addressed as
// contiguous rows of D elements.
//
// Atomic adds are required because the same id may appear at several
// positions, so different threads can target the same `table` row. The
// caller is expected to zero `table` beforehand.
//
// Index math assumes the launch uses gridDim.x == GridDimX,
// blockDim.x == BlockDimX and blockDim.y == BlockDimY, with the same
// row/column striding as the forward LookupTable kernel.
//
// Ids outside [0, N) fail the PADDLE_ENFORCE checks below.
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTableGrad(T *table, const T *output, const int64_t *ids,
                                const int64_t N, const int64_t K,
                                const int64_t D) {
  // 64-bit counters: K and D are int64_t, and a 32-bit row counter could
  // overflow (undefined behaviour) before `idy < K` terminates the loop.
  const int64_t idx = threadIdx.x;
  const int64_t row_stride = static_cast<int64_t>(BlockDimY) * GridDimX;
  int64_t idy =
      blockIdx.x + static_cast<int64_t>(threadIdx.y) * GridDimX;

  while (idy < K) {
    const int64_t id = ids[idy];
    PADDLE_ENFORCE(
        id >= 0,
        "Variable value (input) of OP(fluid.layers.embedding) "
        "expected >= 0 and < %ld, but got %ld. Please check input value.",
        N, id);
    PADDLE_ENFORCE(
        id < N,
        "Variable value (input) of OP(fluid.layers.embedding) "
        "expected >= 0 and < %ld, but got %ld. Please check input value.",
        N, id);
    const T *out = output + idy * D;
    T *tab = table + id * D;
    for (int64_t i = idx; i < D; i += BlockDimX) {
      paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
    }
    idy += row_stride;
  }
}

// GPU forward kernel for the lookup_table op: Out[k, :] = W[Ids[k], :].
// When the `padding_idx` attribute is not -1, rows whose id equals
// padding_idx are written as zeros.
template <typename T>
class LookupTableCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *table_t = context.Input<LoDTensor>("W");
    auto *ids_t = context.Input<LoDTensor>("Ids");
    auto *output_t = context.Output<LoDTensor>("Out");
    const int64_t padding_idx = context.Attr<int64_t>("padding_idx");

    // int64_t matches the kernel's parameter types (no size_t round-trip).
    // Only dims()[0] and dims()[1] of W are read, so W is treated as [N, D].
    const int64_t N = table_t->dims()[0];
    const int64_t D = table_t->dims()[1];
    const int64_t K = ids_t->numel();

    const int64_t *ids = ids_t->data<int64_t>();
    const T *table = table_t->data<T>();
    T *output = output_t->mutable_data<T>(context.GetPlace());

    // The kernel's index math is written against its template block/grid
    // sizes, so the launch dims are derived from the same constants.
#ifdef PADDLE_WITH_HIP
    constexpr int kBlockDimX = 64;
    constexpr int kBlockDimY = 4;
#else
    constexpr int kBlockDimX = 128;
    constexpr int kBlockDimY = 8;
#endif  // PADDLE_WITH_HIP
    constexpr int kGridDimX = 8;
    dim3 threads(kBlockDimX, kBlockDimY);
    dim3 grids(kGridDimX, 1);
    auto stream = context.cuda_device_context().stream();

    // padding_idx == -1 means "no padding row": launch the instantiation
    // that skips the per-row padding comparison.
    if (padding_idx == -1) {
      LookupTable<T, kBlockDimX, kBlockDimY, kGridDimX, false>
          <<<grids, threads, 0, stream>>>(output, table, ids, N, K, D,
                                          padding_idx);
    } else {
      LookupTable<T, kBlockDimX, kBlockDimY, kGridDimX, true>
          <<<grids, threads, 0, stream>>>(output, table, ids, N, K, D,
                                          padding_idx);
    }
  }
};

// GPU backward kernel for the lookup_table op.
//
// When `is_sparse` is true, W@GRAD is produced as a SelectedRows whose rows
// are the Ids and whose value is a copy of Out@GRAD flattened to 2-D.
//
// Otherwise W@GRAD is a dense tensor. It is zeroed, and then every
// Out@GRAD row is atomically added into the W@GRAD row selected by its id.
template <typename T>
class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto &dev_ctx =
        context.template device_context<platform::CUDADeviceContext>();
    const bool is_sparse = context.Attr<bool>("is_sparse");

    // Paddings are not trainable and are fixed in forward, so neither path
    // below special-cases padding_idx. Padding rows receive gradient like
    // any other row.
    if (is_sparse) {
      auto *ids = context.Input<LoDTensor>("Ids");
      auto *table = context.Input<LoDTensor>("W");
      auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
      auto *d_table =
          context.Output<phi::SelectedRows>(framework::GradVarName("W"));

      const int64_t *ids_data = ids->data<int64_t>();
      const int64_t ids_num = ids->numel();

      auto stream = dev_ctx.stream();
      auto gpu_place = context.GetPlace();

      // Stage the ids into a host-side Vector. They are copied device-to-
      // device into the MixVector's CUDA buffer, then CopyToCPU() brings
      // them back to `new_rows`.
      framework::Vector<int64_t> new_rows;
      new_rows.resize(ids_num);
      // TODO(yuyang18): Strange code here.
      paddle::framework::MixVector<int64_t> mixv_new_rows(&new_rows);
      memory::Copy(gpu_place, mixv_new_rows.CUDAMutableData(context.GetPlace()),
                   gpu_place, ids_data, ids_num * sizeof(int64_t), stream);
      mixv_new_rows.CopyToCPU();
      d_table->set_rows(new_rows);

      // One gradient row per id occurrence, each table->dims()[1] wide.
      auto *d_table_value = d_table->mutable_value();
      d_table_value->Resize({ids_num, table->dims()[1]});
      T *d_table_data = d_table_value->mutable_data<T>(context.GetPlace());
      const T *d_output_data = d_output->data<T>();

      auto d_output_dims = d_output->dims();
      auto d_output_dims_2d =
          phi::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
      PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output_dims_2d,
                        platform::errors::InvalidArgument(
                            "ShapeError: The shape of lookup_table@Grad and "
                            "output@Grad should be same. "
                            "But received lookup_table@Grad's shape = [%s], "
                            "output@Grad's shape = [%s].",
                            d_table_value->dims(), d_output_dims_2d));
      memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data,
                   d_output->numel() * sizeof(T), stream);
    } else {
      auto *ids_t = context.Input<LoDTensor>("Ids");
      auto *d_output_t =
          context.Input<LoDTensor>(framework::GradVarName("Out"));
      auto *d_table_t = context.Output<LoDTensor>(framework::GradVarName("W"));

      // int64_t (not int): dims and numel are 64-bit, and narrowing them to
      // int silently truncated large tables or id counts before they reached
      // the kernel's int64_t parameters.
      const int64_t N = d_table_t->dims()[0];
      const int64_t D = d_table_t->dims()[1];
      const int64_t K = ids_t->numel();

      const int64_t *ids = ids_t->data<int64_t>();
      const T *d_output = d_output_t->data<T>();
      T *d_table = d_table_t->mutable_data<T>(context.GetPlace());

      // The kernel accumulates with atomic adds, so start from zero.
      auto t = framework::EigenVector<T>::Flatten(*d_table_t);
      t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));

      // The kernel's index math is written against its template block/grid
      // sizes, so the launch dims are derived from the same constants.
#ifdef PADDLE_WITH_HIP
      constexpr int kBlockDimX = 64;
      constexpr int kBlockDimY = 4;
#else
      constexpr int kBlockDimX = 128;
      constexpr int kBlockDimY = 8;
#endif  // PADDLE_WITH_HIP
      constexpr int kGridDimX = 8;
      dim3 threads(kBlockDimX, kBlockDimY);
      dim3 grids(kGridDimX, 1);

      LookupTableGrad<T, kBlockDimX, kBlockDimY, kGridDimX>
          <<<grids, threads, 0, dev_ctx.stream()>>>(d_table, d_output, ids, N,
                                                    K, D);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Forward lookup_table on the GPU: floating-point tables plus the small
// integer types int8_t / int16_t.
REGISTER_OP_CUDA_KERNEL(lookup_table,
                        ops::LookupTableCUDAKernel<float>,
                        ops::LookupTableCUDAKernel<double>,
                        ops::LookupTableCUDAKernel<plat::float16>,
                        ops::LookupTableCUDAKernel<int8_t>,
                        ops::LookupTableCUDAKernel<int16_t>);

// Backward lookup_table on the GPU: only the floating-point types are
// registered for the gradient.
REGISTER_OP_CUDA_KERNEL(lookup_table_grad,
                        ops::LookupTableGradCUDAKernel<float>,
                        ops::LookupTableGradCUDAKernel<double>,
                        ops::LookupTableGradCUDAKernel<plat::float16>);