/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/operators/sequence_erase_op.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
using LoDTensor = framework::LoDTensor;

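// Marks, for each input element, whether it matches any of the tokens to be
// erased: num_erased[i + 1] is set to 1 if in_dat[i] should be erased and 0
// otherwise, and num_erased[0] is fixed to 0 so that a later prefix sum
// yields the number of erased elements before each position.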
template <typename T>
__global__ void LabelErasedIdx(const T* in_dat, const int in_len,
                               const T* tokens, const int tokens_len,
                               int* num_erased) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < in_len) {
    int erased = 0;
    for (int i = 0; i < tokens_len; ++i) {
      if (in_dat[index] == tokens[i]) {
        erased = 1;
        break;
      }
    }
    num_erased[index + 1] = erased;
    if (index == 0) {
      num_erased[0] = 0;
    }
  }
}

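// Computes the output LoD: each LoD offset is reduced by the number of
// elements erased before that offset (num_erased holds the prefix sums).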
template <typename T>
__global__ void GetOutLod(const T* num_erased, const size_t* in_lod,
                          const int lod_len, size_t* out_lod0) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < lod_len) {
    out_lod0[index] = in_lod[index] - num_erased[in_lod[index]];
  }
}

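// Scatters the kept elements to their new positions: an element is kept when
// the prefix sums at index and index + 1 are equal (i.e. it is not erased),
// and it is shifted left by the number of elements erased before it.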
template <typename T>
__global__ void SetOutput(const T* in_dat, const int in_len,
                          const int* num_erased, T* out_dat) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < in_len) {
    if (num_erased[index] == num_erased[index + 1]) {
      out_dat[index - num_erased[index]] = in_dat[index];
    }
  }
}

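// Copies a host-side vector to a thrust::device_vector via a staging
// host_vector.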
template <typename T, typename Vector>
thrust::device_vector<T> set_device_vector(const Vector& vector) {
  thrust::host_vector<T> host_vec(vector.size());
  for (size_t i = 0; i < vector.size(); ++i) {
    host_vec[i] = vector[i];
  }
  thrust::device_vector<T> dev_vec = host_vec;
  return dev_vec;
}

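// Copies a thrust::device_vector back to a std::vector via a staging
// host_vector.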
template <typename T>
std::vector<T> get_std_vector(const thrust::device_vector<T>& dev_vec) {
  thrust::host_vector<T> host_vec = dev_vec;
  std::vector<T> std_vec(host_vec.size(), 0);
  for (size_t i = 0; i < host_vec.size(); ++i) {
    std_vec[i] = host_vec[i];
  }
  return std_vec;
}

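// CUDA kernel for the sequence_erase operator: removes every occurrence of
// the given tokens from a one-level LoD sequence and recomputes the LoD.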
template <typename T>
class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<LoDTensor>("X");
    auto* out = ctx.Output<LoDTensor>("Out");

    auto lod = in->lod();
    PADDLE_ENFORCE_EQ(lod.size(), 1UL,
                      "Only one-level LoD sequences are supported now.");
    PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(),
                      "The input size does not match the LoD information.");
    auto tokens = ctx.Attr<std::vector<T>>("tokens");
    auto in_len = in->numel();
    auto in_dat = in->data<T>();
    // Copy tokens to GPU
    thrust::device_vector<T> dev_tokens =
        set_device_vector<T, std::vector<T>>(tokens);
    T* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data());

    // Count number of elements to be erased
    thrust::device_vector<int> num_erased(in_len + 1);
    int* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data());
    auto stream = ctx.cuda_device_context().stream();
    LabelErasedIdx<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
                     PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
        in_dat, in_len, dev_tokens_ptr, tokens.size(), num_erased_ptr);
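
    // Prefix-sum the 0/1 flags so that num_erased[i] holds the number of
    // erased elements before position i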
    thrust::inclusive_scan(num_erased.begin() + 1, num_erased.end(),
                           num_erased.begin() + 1);

    // Copy LoD to GPU
    auto lod0 = lod[0];
    auto lod_len = lod0.size();
    thrust::device_vector<size_t> dev_in_lod =
        set_device_vector<size_t, paddle::framework::Vector<size_t>>(lod0);
    size_t* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data());

    // Calc output LoD
    thrust::device_vector<size_t> dev_out_lod(lod_len);
    size_t* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data());
    GetOutLod<<<(lod_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
                PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
        num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr);

    // Set LoD for output
    std::vector<size_t> out_lod0 = get_std_vector<size_t>(dev_out_lod);
    framework::LoD out_lod;
    out_lod.push_back(out_lod0);
    out->set_lod(out_lod);

    // Set output
    out->Resize({static_cast<int64_t>(out_lod0.back()), 1});
    auto out_dat = out->mutable_data<T>(ctx.GetPlace());
    SetOutput<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
                PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_dat, in_len,
                                                      num_erased_ptr, out_dat);
  }
};

}  // namespace operators
}  // namespace paddle

REGISTER_OP_CUDA_KERNEL(sequence_erase,
                        paddle::operators::SequenceEraseOpCUDAKernel<int32_t>);