// c_embedding_op.cu
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_embedding_op.h"

#include <algorithm>
#include <cstdint>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"

namespace paddle {
namespace operators {

// Threads per block used for every kernel launch in this file.
static constexpr int kNumCUDAThreads = 512;
// Upper bound on the grid size returned by NumBlocks.
static constexpr int kNumMaxinumNumBlocks = 4096;

// Returns the 1-D grid size for N work items:
// ceil(N / kNumCUDAThreads), clamped to [1, kNumMaxinumNumBlocks].
//
// N is taken as a 64-bit value so that element counts larger than INT_MAX
// (callers compute them as 64-bit products) are not truncated before the
// division. The lower clamp keeps the result a valid grid dimension even when
// N is zero or negative.
static inline int NumBlocks(const int64_t N) {
  const int64_t threads = kNumCUDAThreads;
  // Overflow-free ceil-div: never forms N + threads - 1.
  const int64_t wanted = N / threads + ((N % threads) > 0 ? 1 : 0);
  const int64_t at_least_one = std::max<int64_t>(wanted, 1);
  return static_cast<int>(
      std::min<int64_t>(at_least_one, kNumMaxinumNumBlocks));
}

// Embedding lookup over a table whose row indices are offset by start_idx.
//
// For every flat output position i in [0, limit):
//   row = i / columns selects which entry of `ids` to read,
//   col = i % columns selects the column inside that table row.
// If the id lies in [start_idx, end_idx), the element
// table[(id - start_idx) * columns + col] is copied to out[i];
// otherwise out[i] is set to zero.
//
// Parameters:
//   out       - destination buffer, indexed by i.
//   table     - source table, indexed by (id - start_idx) * columns + col.
//   ids       - lookup ids, indexed by i / columns.
//   rows      - not read by the kernel body.
//   columns   - row width used to split i into (row, col).
//   N         - exclusive upper bound enforced on the rebased index.
//   start_idx - first id handled by this table.
//   end_idx   - one past the last id handled by this table.
//   limit     - total number of output elements to process.
//
// NOTE(review): CUDA_KERNEL_LOOP is defined outside this file; confirm that
// its loop index type is wide enough for the `limit` values callers pass.
template <typename T, typename IndexT>
__global__ void CEmbedding(T *out,
                           const T *table,
                           const IndexT *ids,
                           const int rows,
                           const int columns,
                           const int64_t N,
                           const int64_t start_idx,
                           const int64_t end_idx,
                           const int64_t limit) {
  CUDA_KERNEL_LOOP(i, limit) {
    size_t row = i / columns;
    size_t col = i % columns;
    auto id = ids[row];

    if (id >= start_idx && id < end_idx) {
      // id - start_idx is evaluated in 64 bits because start_idx is int64_t.
      auto real_idx = id - start_idx;
      // Both reported values are 64-bit, so they are printed with %lld and
      // cast to long long explicitly; %d would consume only part of each
      // argument. NOTE(review): assumes PADDLE_ENFORCE forwards its arguments
      // printf-style in device code - confirm against its definition.
      PADDLE_ENFORCE(real_idx < N,
                     "The index is out of bounds, "
                     "please check whether the dimensions of index and "
                     "input meet the requirements. It should "
                     "be less than [%lld], but received [%lld]",
                     static_cast<long long>(N),
                     static_cast<long long>(real_idx));
      out[i] = table[real_idx * columns + col];
    } else {
      // Ids outside [start_idx, end_idx) contribute a zero row.
      out[i] = static_cast<T>(0);
    }
  }
}

// Backward pass of the offset embedding lookup: accumulates output gradients
// into the table gradient.
//
// For every flat position i in [0, limit), the id at ids[i / columns] is
// inspected. When it lies in [start_idx, end_idx), output[i] is atomically
// added to table[(id - start_idx) * columns + (i % columns)]; ids outside that
// range are skipped. The atomic add is needed because several positions may
// carry the same id and therefore target the same table element.
//
// `rows` and `N` are part of the signature but are not read by the body.
template <typename T, typename IndexT>
__global__ void CEmbeddingGrad(T *table,
                               const T *output,
                               const IndexT *ids,
                               const int rows,
                               const int columns,
                               const int64_t N,
                               const int64_t start_idx,
                               const int64_t end_idx,
                               const int64_t limit) {
  CUDA_KERNEL_LOOP(i, limit) {
    const size_t id_pos = i / columns;
    const size_t col_pos = i % columns;
    const auto token = ids[id_pos];

    const bool in_range = (token >= start_idx) && (token < end_idx);
    if (in_range) {
      // Rebase the id so it indexes the table starting at row 0.
      const auto local_row = token - start_idx;
      phi::CudaAtomicAdd(&table[local_row * columns + col_pos], output[i]);
    }
  }
}

// Host-side launcher for the CEmbedding kernel.
//
// Reads input tensors "W" (table) and "Ids", the int64 attribute
// "start_index", and writes output tensor "Out". The table's first dimension
// gives the number of table rows N and its second dimension the row width D;
// ids in [start_index, start_index + N) are looked up, all others yield zero
// (see CEmbedding).
//
// Throws platform::errors::Unavailable when the dtype of "Ids" is neither
// int32 nor int64.
template <typename T, typename DeviceContext>
class CEmbeddingCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *table_t = context.Input<phi::DenseTensor>("W");
    auto *ids_t = context.Input<phi::DenseTensor>("Ids");
    auto *output_t = context.Output<phi::DenseTensor>("Out");

    const auto &dev_ctx = context.template device_context<phi::GPUContext>();
    const int64_t start_idx = context.Attr<int64_t>("start_index");
    const int64_t N = table_t->dims()[0];
    const int64_t D = table_t->dims()[1];
    const int64_t K = ids_t->numel();

    const int64_t end_idx = start_idx + N;

    auto *table = table_t->data<T>();
    auto *output = output_t->mutable_data<T>(context.GetPlace());

    // Total number of output elements; kept in 64 bits end to end.
    const int64_t limit = K * D;

    // Any count at or above one full grid already saturates NumBlocks, so
    // clamp in 64 bits first: the value handed to NumBlocks then always fits
    // in int. A zero-sized grid is not a valid launch configuration, so at
    // least one block is used even when there is nothing to process.
    const int64_t full_grid =
        static_cast<int64_t>(kNumCUDAThreads) * kNumMaxinumNumBlocks;
    const int blocks =
        std::max(NumBlocks(static_cast<int>(std::min(limit, full_grid))), 1);
    const int threads = kNumCUDAThreads;

    // The kernel declares rows/columns as int.
    const int rows = static_cast<int>(K);
    const int columns = static_cast<int>(D);

    const auto &index_type = framework::TransToProtoVarType(ids_t->dtype());
    if (index_type == framework::proto::VarType::INT32) {
      CEmbedding<T, int32_t>
          <<<blocks, threads, 0, dev_ctx.stream()>>>(output,
                                                     table,
                                                     ids_t->data<int32_t>(),
                                                     rows,
                                                     columns,
                                                     N,
                                                     start_idx,
                                                     end_idx,
                                                     limit);
    } else if (index_type == framework::proto::VarType::INT64) {
      CEmbedding<T, int64_t>
          <<<blocks, threads, 0, dev_ctx.stream()>>>(output,
                                                     table,
                                                     ids_t->data<int64_t>(),
                                                     rows,
                                                     columns,
                                                     N,
                                                     start_idx,
                                                     end_idx,
                                                     limit);
    } else {
      PADDLE_THROW(platform::errors::Unavailable(
          "GPU c_embedding ids only support int32 or int64."));
    }
  }
};

// Host-side launcher for the CEmbeddingGrad kernel.
//
// Reads input tensor "Ids", the gradient of "Out", and the int64 attribute
// "start_index"; writes the gradient of "W". The table gradient is first
// filled with zeros, then the kernel atomically accumulates the rows of the
// output gradient whose ids fall in [start_index, start_index + N), where N is
// the first dimension of the table gradient.
//
// Throws platform::errors::Unavailable when the dtype of "Ids" is neither
// int32 nor int64, matching the forward kernel class.
template <typename T, typename DeviceContext>
class CEmbeddingGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    const auto &dev_ctx = context.template device_context<phi::GPUContext>();
    const int64_t start_idx = context.Attr<int64_t>("start_index");
    auto ids_t = context.Input<phi::DenseTensor>("Ids");
    auto d_output_t =
        context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
    auto d_table_t =
        context.Output<phi::DenseTensor>(framework::GradVarName("W"));

    // Sizes are kept in 64 bits so that K * D cannot overflow int.
    const int64_t N = d_table_t->dims()[0];
    const int64_t D = d_table_t->dims()[1];
    const int64_t K = ids_t->numel();

    const int64_t end_idx = start_idx + N;
    const int64_t limit = K * D;

    // Any count at or above one full grid already saturates NumBlocks, so
    // clamp in 64 bits first: the value handed to NumBlocks then always fits
    // in int. A zero-sized grid is not a valid launch configuration, so at
    // least one block is used even when there is nothing to process.
    const int64_t full_grid =
        static_cast<int64_t>(kNumCUDAThreads) * kNumMaxinumNumBlocks;
    const int blocks =
        std::max(NumBlocks(static_cast<int>(std::min(limit, full_grid))), 1);
    const int threads = kNumCUDAThreads;

    const T *d_output = d_output_t->data<T>();
    T *d_table = d_table_t->mutable_data<T>(context.GetPlace());

    // Zero the table gradient before the kernel accumulates into it.
    auto t = framework::EigenVector<T>::Flatten(*d_table_t);
    t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));

    // The kernel declares rows/columns as int.
    const int rows = static_cast<int>(K);
    const int columns = static_cast<int>(D);

    const auto &index_type = framework::TransToProtoVarType(ids_t->dtype());
    if (index_type == framework::proto::VarType::INT32) {
      CEmbeddingGrad<T, int32_t>
          <<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
                                                     d_output,
                                                     ids_t->data<int32_t>(),
                                                     rows,
                                                     columns,
                                                     N,
                                                     start_idx,
                                                     end_idx,
                                                     limit);
    } else if (index_type == framework::proto::VarType::INT64) {
      CEmbeddingGrad<T, int64_t>
          <<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
                                                     d_output,
                                                     ids_t->data<int64_t>(),
                                                     rows,
                                                     columns,
                                                     N,
                                                     start_idx,
                                                     end_idx,
                                                     limit);
    } else {
      // Previously an unsupported ids dtype fell through silently and left an
      // all-zero gradient; report it the same way the forward kernel does.
      PADDLE_THROW(platform::errors::Unavailable(
          "GPU c_embedding ids only support int32 or int64."));
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Registers the GPU kernel for the c_embedding op for float, double and
// float16; bfloat16 is added only when NCCL_VERSION_CODE >= 21000.
PD_REGISTER_STRUCT_KERNEL(
    c_embedding, GPU, ALL_LAYOUT, ops::CEmbeddingCUDAKernel,
    float,
    double,
#if NCCL_VERSION_CODE >= 21000
    plat::bfloat16,
#endif
    plat::float16) {}

// Registers the GPU kernel for the c_embedding_grad op for float, double and
// float16; bfloat16 is added only when NCCL_VERSION_CODE >= 21000.
PD_REGISTER_STRUCT_KERNEL(
    c_embedding_grad, GPU, ALL_LAYOUT, ops::CEmbeddingGradCUDAKernel,
    float,
    double,
#if NCCL_VERSION_CODE >= 21000
    plat::bfloat16,
#endif
    plat::float16) {}