lookup_table_op.cu 4.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include "paddle/framework/op_registry.h"
#include "paddle/operators/functor/math_functor.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {

// Short local name for the framework tensor type used by the kernels below.
typedef framework::Tensor Tensor;

// Forward embedding gather: for every k in [0, K), copy row ids[k] of `table`
// into row k of `output`.
//
// Memory layout (as indexed below): `table` is a contiguous row-major buffer
// of N rows x D columns, `ids` holds K row indices, and `output` is a
// contiguous row-major buffer of K rows x D columns.
//
// Launch layout: the template parameters stand in for the launch shape and
// must match it: blockDim.x == blockDimX, blockDim.y == blockDimY and
// gridDim.x == gridDimX (only blockIdx.x is read). Each
// (blockIdx.x, threadIdx.y) pair handles rows idy, idy + blockDimY * gridDimX,
// ...; within a row, threadIdx.x strides over the D columns by blockDimX.
template <typename T, int blockDimX, int blockDimY, int gridDimX>
__global__ void LookupTable(T* output, const T* table, const uint32_t* ids,
                            const int N, const int K, const int D) {
  int idx = threadIdx.x;
  int idy = blockIdx.x + threadIdx.y * gridDimX;

  while (idy < K) {
    // ids are stored as uint32_t; values >= 2^31 turn negative on conversion
    // to int and are caught by the first assert.
    int id = ids[idy];
    PADDLE_ASSERT(id >= 0);
    PADDLE_ASSERT(id < N);
    // Each row is D elements wide, so row offsets must be scaled by D.
    // The multiplication is done in 64 bits so that K * D or N * D larger
    // than INT_MAX does not overflow.
    T* out = output + static_cast<int64_t>(idy) * D;
    const T* tab = table + static_cast<int64_t>(id) * D;
    for (int i = idx; i < D; i += blockDimX) {
      out[i] = tab[i];
    }
    idy += blockDimY * gridDimX;
  }
}

// Backward of the embedding gather: for every k in [0, K), add row k of
// `output` (the incoming gradient) into row ids[k] of `table` (the gradient
// of the lookup table).
//
// Memory layout (as indexed below): `table` is a contiguous row-major buffer
// of N rows x D columns and `output` is a contiguous row-major buffer of
// K rows x D columns. The kernel only accumulates; it never zeroes `table`,
// so the caller must initialise it first.
//
// The same id may appear several times in `ids`, so different threads can
// add into the same table element; CudaAtomicAdd keeps those updates from
// being lost.
//
// Launch layout: same contract as LookupTable. blockDim.x == blockDimX,
// blockDim.y == blockDimY and gridDim.x == gridDimX; only blockIdx.x is read.
template <typename T, int blockDimX, int blockDimY, int gridDimX>
__global__ void LookupTableGradKernel(T* table, const T* output,
                                      const uint32_t* ids, const int N,
                                      const int K, const int D) {
  int idx = threadIdx.x;
  int idy = blockIdx.x + threadIdx.y * gridDimX;

  while (idy < K) {
    // ids are stored as uint32_t; values >= 2^31 turn negative on conversion
    // to int and are caught by the first assert.
    int id = ids[idy];
    PADDLE_ASSERT(id >= 0);
    PADDLE_ASSERT(id < N);
    // Rows are D elements wide; scale row indices by D using 64-bit math so
    // large K * D / N * D cannot overflow.
    const T* out = output + static_cast<int64_t>(idy) * D;
    T* tab = table + static_cast<int64_t>(id) * D;
    for (int i = idx; i < D; i += blockDimX) {
      paddle::platform::CudaAtomicAdd(tab + i, out[i]);
    }
    idy += blockDimY * gridDimX;
  }
}

template <typename T>
class LookupTableCUDAKernel : public framework::OpKernel {
 public:
  // Gathers the rows of input "W" selected by input "Ids" into output "Out"
  // by launching LookupTable on the GPU.
  //
  // W's dims()[0] is passed as the row count N and dims()[1] as the row width
  // D. K is the total element count of Ids, which is read as uint32_t.
  void Compute(const framework::ExecutionContext& context) const override {
    auto w_tensor = context.Input<Tensor>("W");
    auto id_tensor = context.Input<Tensor>("Ids");
    auto out_tensor = context.Output<Tensor>("Out");

    const size_t num_rows = w_tensor->dims()[0];
    const size_t row_width = w_tensor->dims()[1];
    const size_t num_ids = product(id_tensor->dims());

    const uint32_t* id_data = id_tensor->data<uint32_t>();
    const T* w_data = w_tensor->data<T>();
    T* out_data = out_tensor->mutable_data<T>(context.GetPlace());

    // The launch shape must agree with the kernel's
    // <blockDimX, blockDimY, gridDimX> template arguments.
    constexpr int kThreadsX = 128;
    constexpr int kThreadsY = 8;
    constexpr int kBlocksX = 8;
    const dim3 block_shape(kThreadsX, kThreadsY);
    const dim3 grid_shape(kBlocksX, 1);
    LookupTable<T, kThreadsX, kThreadsY, kBlocksX><<<grid_shape, block_shape>>>(
        out_data, w_data, id_data, num_rows, num_ids, row_width);
  }
};

template <typename T>
class LookupTableGrad : public framework::OpKernel {
 public:
  // Computes the gradient of "W" for the lookup_table op.
  //
  // Fills the W gradient tensor with 0 via functor::Set. It then launches
  // LookupTableGradKernel, which adds each row of the "Out" gradient into the
  // W-gradient row selected by the matching entry of "Ids".
  void Compute(const framework::ExecutionContext& context) const override {
    auto id_tensor = context.Input<Tensor>("Ids");
    auto out_grad_tensor = context.Input<Tensor>(framework::GradVarName("Out"));
    auto w_grad_tensor = context.Output<Tensor>(framework::GradVarName("W"));

    const int num_rows = w_grad_tensor->dims()[0];
    const int row_width = w_grad_tensor->dims()[1];
    const int num_ids = product(id_tensor->dims());
    const uint32_t* id_data = id_tensor->data<uint32_t>();
    T* w_grad_data = w_grad_tensor->mutable_data<T>(context.GetPlace());
    const T* out_grad_data = out_grad_tensor->data<T>();

    // The kernel only accumulates, so the gradient buffer is cleared first.
    platform::DeviceContext* dev_ctx =
        const_cast<platform::DeviceContext*>(context.device_context_);
    functor::Set<paddle::platform::GPUPlace, T> fill_value;
    fill_value(static_cast<T>(0), w_grad_tensor, dev_ctx);

    // The launch shape must agree with the kernel's
    // <blockDimX, blockDimY, gridDimX> template arguments.
    constexpr int kThreadsX = 128;
    constexpr int kThreadsY = 8;
    constexpr int kBlocksX = 8;
    const dim3 block_shape(kThreadsX, kThreadsY);
    const dim3 grid_shape(kBlocksX, 1);
    LookupTableGradKernel<T, kThreadsX, kThreadsY, kBlocksX>
        <<<grid_shape, block_shape>>>(w_grad_data, out_grad_data, id_data,
                                      num_rows, num_ids, row_width);
  }
};

}  // namespace operators
}  // namespace paddle

// Register the GPU kernels for the lookup_table forward op and its gradient
// op. Only the T = float instantiations are registered in this file.
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(lookup_table_grad, ops::LookupTableGrad<float>);