// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <limits>

#include <thrust/count.h>
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/reverse.h>
#include <thrust/scan.h>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/masked_select_kernel.h"

namespace phi {

// Converts a boolean mask into a 0/1 int32 array: mask_array[i] is 1 where
// mask[i] is true and 0 otherwise, for every i in [0, size).
// Launch: any 1-D grid/block; the grid-stride loop covers the whole range.
__global__ void SetMaskArray(const bool* mask, int32_t* mask_array, int size) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (; idx < size; idx += blockDim.x * gridDim.x) {
    mask_array[idx] = mask[idx] ? 1 : 0;
  }
}

// Compacts the selected elements of `input` into `out`, keeping their order.
// For every i in [0, size) with mask[i] true, input[i] is written to
// out[mask_prefix_sum[i]]. `mask_prefix_sum` must be the exclusive prefix sum
// of the 0/1 mask, so that destination slots are unique and contiguous.
// Launch: any 1-D grid/block; the grid-stride loop covers the whole range.
template <typename T>
__global__ void SelectWithPrefixMask(const int32_t* mask_prefix_sum,
                                     const bool* mask,
                                     const T* input,
                                     T* out,
                                     int size) {
  int idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (; idx < size; idx += blockDim.x * gridDim.x) {
    if (mask[idx]) {
      out[mask_prefix_sum[idx]] = input[idx];
    }
  }
}

// Selects the elements of `x` where `mask` is true and writes them, in
// flattened (memory) order, into the 1-D tensor `out`.
//
// All device work is issued on dev_ctx.stream():
//   1. count the true entries of `mask` to size `out` (the count is returned
//      to the host, so this call waits for the stream);
//   2. SetMaskArray turns the bool mask into a 0/1 int32 array;
//   3. an exclusive scan of that array gives each selected element its slot;
//   4. SelectWithPrefixMask scatters the selected elements into `out`.
//
// Raises InvalidArgument if `x` and `mask` have different dims, or if `mask`
// has more elements than fit in int32 (the kernels take an `int` size and the
// prefix sum is stored as int32).
template <typename T, typename Context>
void MaskedSelectKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& mask,
                        DenseTensor* out) {
  auto* mask_data = mask.data<bool>();
  auto input_data = x.data<T>();

  auto mask_size = mask.numel();
  auto input_dim = x.dims();
  auto mask_dim = mask.dims();
  PADDLE_ENFORCE_EQ(input_dim,
                    mask_dim,
                    phi::errors::InvalidArgument(
                        "The dim size of input and mask in OP(masked_selected) "
                        "must be equal, but got input dim:(%ld), mask dim: "
                        "(%ld). Please check input "
                        "value.",
                        input_dim,
                        mask_dim));
  // Without this check a larger mask would be silently truncated when passed
  // to the `int`-sized kernels and int32 prefix sum below.
  PADDLE_ENFORCE_LE(
      mask_size,
      static_cast<int64_t>(std::numeric_limits<int32_t>::max()),
      phi::errors::InvalidArgument(
          "The number of elements of mask in OP(masked_select) must be at "
          "most %d, but got %d.",
          std::numeric_limits<int32_t>::max(),
          mask_size));

  auto stream = dev_ctx.stream();
  // Run every thrust algorithm on the context stream. The plain
  // `thrust::device` policy uses the default stream, which is not guaranteed
  // to be ordered with dev_ctx.stream() (e.g. if that stream is non-blocking),
  // so the scan could otherwise read mask_array before SetMaskArray finishes.
#ifdef __HIPCC__
  const auto exec_policy = thrust::hip::par.on(stream);
#else
  const auto exec_policy = thrust::cuda::par.on(stream);
#endif

  // Count directly over the bool mask; no temporary device copy is needed.
  thrust::device_ptr<const bool> mask_dev_ptr =
      thrust::device_pointer_cast(mask_data);
  auto out_size =
      thrust::count(exec_policy, mask_dev_ptr, mask_dev_ptr + mask_size, true);

  DDim out_dim{out_size};
  out->Resize(out_dim);
  auto out_data = out->mutable_data<T>(dev_ctx.GetPlace());

  // Nothing to select. Returning here also avoids a zero-sized grid, which is
  // an invalid kernel launch configuration.
  if (mask_size == 0) {
    return;
  }

  DenseTensor mask_array;
  DenseTensor mask_prefix_sum;
  mask_array.Resize(mask_dim);
  mask_prefix_sum.Resize(mask_dim);

  int32_t* mask_array_data =
      mask_array.mutable_data<int32_t>(dev_ctx.GetPlace());
  int32_t* mask_prefix_sum_data =
      mask_prefix_sum.mutable_data<int32_t>(dev_ctx.GetPlace());

  // Safe narrowing: mask_size was checked against the int32 limit above.
  const int size = static_cast<int>(mask_size);
  const int threads = 512;
  // Computed in int64 so `mask_size + threads - 1` cannot overflow.
  const int grid = static_cast<int>((mask_size + threads - 1) / threads);
  SetMaskArray<<<grid, threads, 0, stream>>>(mask_data, mask_array_data, size);

  // Exclusive scan straight from the 0/1 array into the prefix-sum buffer:
  // slot i receives the number of selected elements before position i.
  thrust::device_ptr<int32_t> mask_array_dev_ptr =
      thrust::device_pointer_cast(mask_array_data);
  thrust::exclusive_scan(exec_policy,
                         mask_array_dev_ptr,
                         mask_array_dev_ptr + mask_size,
                         thrust::device_pointer_cast(mask_prefix_sum_data));

  SelectWithPrefixMask<T><<<grid, threads, 0, stream>>>(
      mask_prefix_sum_data, mask_data, input_data, out_data, size);
}

}  // namespace phi

// Registers phi::MaskedSelectKernel as the GPU `masked_select` kernel for
// float, double, int and int64_t data. Input 1 (the mask) is registered with
// the BOOL data type regardless of the data type of input 0.
PD_REGISTER_KERNEL(masked_select,
                   GPU,
                   ALL_LAYOUT,
                   phi::MaskedSelectKernel,
                   float,
                   double,
                   int,
                   int64_t) {
  kernel->InputAt(1).SetDataType(phi::DataType::BOOL);
}