// sequence_erase_op.cu
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/scan.h>
#include <thrust/system/cuda/execution_policy.h>

#include "paddle/fluid/operators/sequence_erase_op.h"
#include "paddle/fluid/platform/cuda_helper.h"

namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
using LoDTensor = framework::LoDTensor;

template <typename T>
// Marks the input elements that must be erased.
//
// For every i in [0, in_len) whose value equals one of the `tokens_len`
// entries of `tokens`, sets num_erased[i + 1] = 1. Slot 0 and the slots of
// kept elements are never written, so the caller must zero-initialize
// `num_erased` (which needs in_len + 1 entries).
//
// Launch: 1-D grid of 1-D blocks. A grid-stride loop is used, so any grid
// size is correct. The index is 64-bit because `in_len` is int64_t and a
// 32-bit `blockIdx.x * blockDim.x` product can overflow for large inputs.
__global__ void LabelErasedIdx(const T* in_dat, const int64_t in_len,
                               const int* tokens, const size_t tokens_len,
                               size_t* num_erased) {
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t index =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       index < in_len; index += stride) {
    for (size_t i = 0; i < tokens_len; ++i) {
      if (in_dat[index] == tokens[i]) {
        // Offset by one so that an inclusive scan over num_erased[1..]
        // yields num_erased[k] = number of erased elements in [0, k).
        num_erased[index + 1] = 1;
        break;
      }
    }
  }
}

40 41
__global__ void GetOutLod(const size_t* num_erased, const size_t* in_lod,
                          const size_t lod_len, size_t* out_lod0) {
42 43 44 45 46 47 48
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < lod_len) {
    out_lod0[index] = in_lod[index] - num_erased[in_lod[index]];
  }
}

template <typename T>
// Compacts the kept input elements into `out_dat`, preserving their order.
//
// Element i was kept iff num_erased[i] == num_erased[i + 1], i.e. the scan
// did not increase at i. Its destination is i - num_erased[i], shifted left
// by the number of elements erased before it. `num_erased` must hold
// in_len + 1 prefix counts; `out_dat` must have room for every kept element.
//
// Launch: 1-D grid of 1-D blocks with a 64-bit grid-stride loop, so large
// `in_len` values do not overflow the index.
__global__ void SetOutput(const T* in_dat, const int64_t in_len,
                          const size_t* num_erased, T* out_dat) {
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t index =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       index < in_len; index += stride) {
    if (num_erased[index] == num_erased[index + 1]) {
      out_dat[index - num_erased[index]] = in_dat[index];
    }
  }
}

// GPU kernel of the sequence_erase operator.
//
// Removes from the one-level LoD tensor "X" every element equal to one of the
// integers in the "tokens" attribute. The remaining elements are written, in
// their original order, to "Out" with shape [num_kept, 1]. Out's LoD is the
// input LoD with each offset reduced by the number of elements erased before
// it.
//
// Pipeline (kernels and the scan are issued on the device context's stream):
//   1. LabelErasedIdx sets num_erased[i + 1] = 1 for each erased element i.
//   2. An inclusive scan over num_erased[1..] turns those marks into prefix
//      counts: num_erased[k] = number of erased elements in [0, k).
//   3. GetOutLod shifts each input LoD offset by num_erased[offset].
//   4. SetOutput compacts the kept elements into the output buffer.
template <typename T>
class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<LoDTensor>("X");
    auto* out = ctx.Output<LoDTensor>("Out");

    // The LoD is only read here, so bind it by reference instead of copying.
    const auto& lod = in->lod();
    PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
    // lod[0].back() below would be undefined on an empty offset vector.
    PADDLE_ENFORCE_GT(lod[0].size(), 0UL,
                      "The LoD of Input(X) should not be empty.");
    PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(),
                      "The actual size mismatches with the LoD information.");
    auto tokens = ctx.Attr<std::vector<int>>("tokens");
    auto in_len = in->numel();
    auto in_dat = in->data<T>();
    auto stream = ctx.cuda_device_context().stream();

    // Copy tokens to GPU.
    thrust::device_vector<int> dev_tokens(tokens.begin(), tokens.end());
    int* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data());

    // Count number of elements to be erased. The buffer has in_len + 1 slots.
    // Slot 0 stays 0, so after the scan num_erased[k] is defined for every
    // k in [0, in_len].
    // NOTE(review): the device_vector allocations/fills in this function go
    // through Thrust's default stream. Their ordering relative to `stream`
    // relies on legacy default-stream semantics.
    thrust::device_vector<size_t> num_erased(in_len + 1, 0);
    size_t* num_erased_ptr = thrust::raw_pointer_cast(num_erased.data());
    LabelErasedIdx<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
                     PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
        in_dat, in_len, dev_tokens_ptr, tokens.size(), num_erased_ptr);
    // Issue the scan on `stream` so it is explicitly ordered after
    // LabelErasedIdx, instead of on Thrust's default stream.
    thrust::inclusive_scan(thrust::cuda::par.on(stream),
                           num_erased.begin() + 1, num_erased.end(),
                           num_erased.begin() + 1);

    // Copy LoD to GPU.
    auto lod0 = lod[0];
    auto lod_len = lod0.size();
    const size_t* dev_in_lod_ptr = lod0.CUDAData(ctx.GetPlace());

    // Calc output LoD.
    thrust::device_vector<size_t> dev_out_lod(lod_len);
    size_t* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data());
    GetOutLod<<<(lod_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
                PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
        num_erased_ptr, dev_in_lod_ptr, lod_len, dev_out_lod_ptr);

    // Set LoD for output. Wait for GetOutLod on `stream` before the host
    // reads the offsets it produced.
    PADDLE_ENFORCE(cudaStreamSynchronize(stream));
    std::vector<size_t> out_lod0(dev_out_lod.begin(), dev_out_lod.end());
    framework::LoD out_lod;
    out_lod.push_back(out_lod0);
    out->set_lod(out_lod);

    // Set output: the last output offset is the number of kept elements.
    out->Resize({static_cast<int64_t>(out_lod0.back()), 1});
    auto out_dat = out->mutable_data<T>(ctx.GetPlace());
    SetOutput<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
                PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_dat, in_len,
                                                      num_erased_ptr, out_dat);
  }
};

}  // namespace operators
}  // namespace paddle

// Register the GPU kernel of sequence_erase for int32 and int64 inputs.
REGISTER_OP_CUDA_KERNEL(sequence_erase,
                        paddle::operators::SequenceEraseOpCUDAKernel<int32_t>,
                        paddle::operators::SequenceEraseOpCUDAKernel<int64_t>);