// sequence_mask_compute.cu
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#include <thrust/reduce.h>

#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/sequence_mask_compute.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {

// Writes the flattened sequence mask into `dst`.
//
// `dst` is viewed as rows of `maxlen` entries, one row per element of `src`.
// Entry `pos` lies in row `pos / maxlen` at column `pos % maxlen`. It is set
// to 1 when the column is below that row's length `src[row]`, otherwise 0.
//
// CUDA_KERNEL_LOOP (from cuda_utils.h) drives `pos` over [0, count).
// NOTE(review): presumably a grid-stride loop, so any launch size is valid;
// confirm against the macro definition.
template <typename T>
__global__ void SequenceMaskKernel(T* dst,
                                   const int64_t* src,
                                   int count,
                                   int maxlen) {
  CUDA_KERNEL_LOOP(pos, count) {
    const int row = pos / maxlen;
    const int col = pos - row * maxlen;  // same value as pos % maxlen
    const bool inside = static_cast<int64_t>(col) < src[row];
    dst[pos] = inside ? static_cast<T>(1) : static_cast<T>(0);
  }
}

// Builds the sequence mask for `param.X` into `param.Y`.
//
// For every int64 element x[i] of X, writes a row of `maxlen` values to Y.
// Position j of that row is 1 when j < x[i] and 0 otherwise. Y's shape is
// X's shape with `maxlen` appended.
//
// Choosing `maxlen`:
//   - it starts from `param.maxlen`;
//   - it is replaced by the first int32 element of `param.MaxLenTensor` when
//     that tensor is set;
//   - if the result is still negative, it becomes max(X) (0 for an empty X).
//
// Only `param.out_dtype == 5` is supported; any other value hits LOG(FATAL).
template <typename T, PrecisionType Ptype>
void SequenceMaskCompute<T, Ptype>::Run() {
  auto& param = this->template Param<param_t>();
  auto& ctx = this->ctx_->template As<CUDAContext>();
  auto stream = ctx.exec_stream();

  const auto* x = param.X;
  const int64_t* x_data = x->template data<int64_t>();
  auto* y = param.Y;
  int maxlen = param.maxlen;

  if (param.MaxLenTensor) {
    // NOTE(review): MemcpySync is not issued on `stream`. Confirm it is
    // ordered after whatever wrote MaxLenTensor.
    auto* len_tensor_data = param.MaxLenTensor->template data<int32_t>();
    int32_t len_data{0};
    TargetWrapperCuda::MemcpySync(
        &len_data, len_tensor_data, sizeof(int32_t), IoDirection::DtoH);
    maxlen = len_data;
  }

  if (maxlen < 0) {
    // Reduce on the execution stream so the reduction is ordered after any
    // pending work on that stream that writes X. Because thrust::reduce
    // returns its result to the host, it waits for completion before
    // `maxlen` is used.
    maxlen = static_cast<int>(
        thrust::reduce(thrust::cuda::par.on(stream),
                       thrust::device_pointer_cast(x_data),
                       thrust::device_pointer_cast(x_data) + x->numel(),
                       static_cast<int64_t>(0),
                       thrust::maximum<int64_t>()));
  }

  auto y_dim = x->dims().Vectorize();
  y_dim.push_back(maxlen);
  y->Resize(y_dim);
  // NOTE(review): numel() is narrowed to int because the kernel indexes
  // with int; outputs above INT_MAX elements are not supported here.
  const int count = y->numel();
  auto* dst_data = y->template mutable_data<T>(TARGET(kCUDA));

  // NOTE(review): 5 is the only accepted out_dtype code; presumably the
  // framework's FP32 dtype id. Confirm.
  if (param.out_dtype == 5) {
    // An empty output (empty X or maxlen == 0) has nothing to write, and
    // would launch a zero-block grid, which is an invalid launch
    // configuration.
    if (count > 0) {
      SequenceMaskKernel<
          T><<<CUDA_GET_BLOCKS(count), CUDA_NUM_THREADS, 0, stream>>>(
          dst_data, x_data, count, maxlen);
    }
  } else {
    LOG(FATAL) << "not supported out_dtype: " << param.out_dtype;
  }
  CUDA_POST_KERNEL_CHECK;
}

}  // namespace cuda
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

// Concrete kernel instantiations: float output for kFloat precision and
// half output for kFP16 precision.
using SeqMaskFp32 =
    paddle::lite::kernels::cuda::SequenceMaskCompute<float, PRECISION(kFloat)>;

using SeqMaskFp16 =
    paddle::lite::kernels::cuda::SequenceMaskCompute<half, PRECISION(kFP16)>;

// Registers the kFloat sequence_mask kernel. Inputs: int64 X and an optional
// int32 MaxLenTensor, both on CUDA. Output: Y on CUDA.
REGISTER_LITE_KERNEL(sequence_mask, kCUDA, kFloat, kNCHW, SeqMaskFp32, def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt64))})
    .BindInput("MaxLenTensor",
               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt32))})
    .BindOutput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
    .Finalize();


// Registers the kFP16 sequence_mask kernel. It takes the same int64 X and
// int32 MaxLenTensor CUDA inputs, and binds Y as a kFP16 CUDA tensor.
REGISTER_LITE_KERNEL(sequence_mask, kCUDA, kFP16, kNCHW, SeqMaskFp16, def)
    .BindInput("X",
               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt64))})
    .BindInput("MaxLenTensor",
               {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt32))})
    .BindOutput("Y",
                {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
    .Finalize();