sequence_mask_compute.cu 5.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <thrust/device_ptr.h>
16
#include <thrust/functional.h>
17 18 19 20
#include <thrust/reduce.h>

#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/op_registry.h"
21
#include "lite/kernels/cuda/sequence_mask_compute.h"
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39

namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {

// Writes a 0/1 mask into `dst`, viewing the output as rows of `maxlen`
// entries:
//   row = index / maxlen selects the length src[row];
//   col = index % maxlen is the position inside that row;
//   dst[index] = 1 when col < src[row], otherwise 0.
//
// `count` is the total number of output elements. CUDA_KERNEL_LOOP comes from
// cuda_utils.h and presumably iterates `index` over [0, count) across the
// launched threads — see its definition for the exact stride.
template <typename T>
__global__ void SequenceMaskKernel(T* dst,
                                   const int64_t* src,
                                   int count,
                                   int maxlen) {
  CUDA_KERNEL_LOOP(index, count) {
    const int row = index / maxlen;
    // Same value as index % maxlen (C++ defines % via truncating division).
    const int col = index - row * maxlen;
    const int keep = (col < src[row]) ? 1 : 0;
    dst[index] = static_cast<T>(keep);
  }
}

// Block-wise maximum reduction.
//
// Phase 1: every thread walks `in_data` starting at its global thread index
// with a stride of blockDim.x * gridDim.x. It keeps a running maximum that
// starts at -1, so a thread that sees no element contributes -1.
//
// Phase 2: the per-thread maxima are stored in dynamic shared memory and
// folded by halving the active range each step.
//
// Phase 3: thread 0 writes the block's maximum to out[blockIdx.x].
//
// Launch requirements (from the code below):
//   * blockDim.x must be a power of two — the halving fold assumes it;
//   * dynamic shared memory of blockDim.x * sizeof(T) bytes;
//   * `out` must hold at least gridDim.x elements.
template <typename T>
__global__ void VecMaxKernel(const T* in_data, T* out, const int count) {
  extern __shared__ T partial_max[];

  const int tid = threadIdx.x;
  const int step = blockDim.x * gridDim.x;

  T local_max = -1;
  for (int pos = blockIdx.x * blockDim.x + tid; pos < count; pos += step) {
    const T value = in_data[pos];
    local_max = value > local_max ? value : local_max;
  }
  partial_max[tid] = local_max;
  __syncthreads();

  // Fold the upper half of the live range onto the lower half until one
  // value remains. The barrier sits outside the `if`, so every thread reaches
  // it.
  for (int half_width = blockDim.x >> 1; half_width > 0; half_width >>= 1) {
    if (tid < half_width) {
      const T other = partial_max[tid + half_width];
      if (other > partial_max[tid]) {
        partial_max[tid] = other;
      }
    }
    __syncthreads();
  }

  if (tid == 0) {
    out[blockIdx.x] = partial_max[0];
  }
}

// Computes Y = sequence_mask(X): Y has X's dims with one extra trailing axis
// of size `maxlen`. Y[..., j] is 1 when j < X[...] and 0 otherwise. The work
// is done by SequenceMaskKernel on the context's exec stream.
//
// Resolution of `maxlen`:
//   1. the `maxlen` attribute;
//   2. overridden by the single int32 value of MaxLenTensor when it is bound;
//   3. if the result is negative, the maximum value found in X. An empty X
//      yields 0, i.e. an empty mask axis.
template <typename T, PrecisionType Ptype>
void SequenceMaskCompute<T, Ptype>::Run() {
  auto& param = this->template Param<param_t>();
  auto& ctx = this->ctx_->template As<CUDAContext>();
  auto stream = ctx.exec_stream();

  const auto* x = param.X;
  const int64_t* x_data = x->template data<int64_t>();
  auto* y = param.Y;
  int maxlen = param.maxlen;

  // A bound MaxLenTensor overrides the attribute; read its value on the host.
  if (param.MaxLenTensor) {
    auto* len_tensor_data = param.MaxLenTensor->template data<int32_t>();
    int32_t len_data{0};
    TargetWrapperCuda::MemcpySync(
        &len_data, len_tensor_data, sizeof(int32_t), IoDirection::DtoH);
    maxlen = len_data;
  }

  // Negative maxlen: use the largest value in X.
  if (maxlen < 0) {
    const int64_t numel = x->numel();
    // With no elements there is no maximum. Dereferencing max_element of an
    // empty range is UB, so an empty X gives an empty mask axis instead.
    maxlen = 0;
    if (numel > 0) {
      // Choose the algorithm according to magic_num.
      const int magic_num = 256;
      std::vector<int64_t> h_max_data;
      if (numel < magic_num) {
        // Small input: copy all of X back and reduce on the host.
        h_max_data.resize(numel);
        TargetWrapperCuda::MemcpySync(h_max_data.data(),
                                      x_data,
                                      numel * sizeof(int64_t),
                                      IoDirection::DtoH);
      } else {
        // Large input: compute one partial maximum per block on the device,
        // then finish on the host. VecMaxKernel needs a power-of-two block
        // size and blockDim.x * sizeof(int64_t) bytes of shared memory.
        const int threads = 256;
        const int blocks = static_cast<int>((numel + threads - 1) / threads);
        max_tensor_.Resize({blocks});
        auto* max_data = max_tensor_.mutable_data<int64_t>(TARGET(kCUDA));
        VecMaxKernel<
            int64_t><<<blocks, threads, threads * sizeof(int64_t), stream>>>(
            x_data, max_data, static_cast<int>(numel));
        h_max_data.resize(blocks);
        TargetWrapperCuda::MemcpyAsync(h_max_data.data(),
                                       max_data,
                                       sizeof(int64_t) * blocks,
                                       IoDirection::DtoH,
                                       stream);
        // The host reads h_max_data right below, so wait for the copy.
        TargetWrapperCuda::StreamSync(stream);
      }
      maxlen = static_cast<int>(
          *std::max_element(h_max_data.begin(), h_max_data.end()));
    }
  }

  auto y_dim = x->dims().Vectorize();
  y_dim.push_back(maxlen);
  y->Resize(y_dim);
  const int count = y->numel();
  auto* dst_data = y->template mutable_data<T>(TARGET(kCUDA));

  // NOTE(review): 5 presumably denotes FP32 in Paddle's VarType enum —
  // confirm. Both the float and the half instantiations go through this
  // check.
  if (param.out_dtype == 5) {
    // Skip the launch for an empty output. CUDA_GET_BLOCKS(0) presumably
    // yields 0 blocks, and CUDA rejects a zero-size grid as an invalid
    // configuration.
    if (count > 0) {
      SequenceMaskKernel<
          T><<<CUDA_GET_BLOCKS(count), CUDA_NUM_THREADS, 0, stream>>>(
          dst_data, x_data, count, maxlen);
    }
  } else {
    LOG(FATAL) << "not supported out_dtype: " << param.out_dtype;
  }
  CUDA_POST_KERNEL_CHECK;
}

}  // namespace cuda
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

// Concrete instantiations of the sequence_mask CUDA kernel: one writing a
// float mask and one writing a half mask.
typedef paddle::lite::kernels::cuda::SequenceMaskCompute<float,
                                                         PRECISION(kFloat)>
    SeqMaskFp32;

typedef paddle::lite::kernels::cuda::SequenceMaskCompute<half, PRECISION(kFP16)>
    SeqMaskFp16;

// Both registrations bind the same inputs: int64 X and int32 MaxLenTensor,
// both on CUDA. They differ in the output binding. The FP32 variant passes no
// precision to GetTensorTy and so relies on its default (presumably kFloat —
// confirm in the type system header). The FP16 variant binds kFP16
// explicitly.
REGISTER_LITE_KERNEL(sequence_mask, kCUDA, kFloat, kNCHW, SeqMaskFp32, def)
    .BindInput("X",
               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt64))})
    .BindInput("MaxLenTensor",
               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt32))})
    .BindOutput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
    .Finalize();

REGISTER_LITE_KERNEL(sequence_mask, kCUDA, kFP16, kNCHW, SeqMaskFp16, def)
    .BindInput("X",
               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt64))})
    .BindInput("MaxLenTensor",
               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt32))})
    .BindOutput("Y",
                {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
    .Finalize();