// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <thrust/count.h>
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/reverse.h>
#include <thrust/scan.h>

#include <cstdint>
#include <limits>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/masked_select_kernel.h"

namespace phi {

// Writes 1 into mask_array[i] where mask[i] is true and 0 otherwise, for
// every i in [0, size).
// Launch: 1-D grid of 1-D blocks; the grid-stride loop covers the whole range
// for any launch with at least one block. No shared memory is used.
__global__ void SetMaskArray(const bool* mask,
                             int32_t* mask_array,
                             int64_t size) {
  // 64-bit index/stride so `idx += stride` cannot overflow for large sizes.
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (; idx < size; idx += stride) {
    mask_array[idx] = mask[idx] ? 1 : 0;
  }
}

// For every i in [0, size) with mask[i] set, copies input[i] to
// out[mask_prefix_sum[i]].
// mask_prefix_sum is expected to be the exclusive prefix sum of the 0/1 mask,
// which packs the selected elements densely into `out` in their original
// order. Launch: same layout as SetMaskArray (1-D, grid-stride loop).
template <typename T>
__global__ void SelectWithPrefixMask(const int32_t* mask_prefix_sum,
                                     const bool* mask,
                                     const T* input,
                                     T* out,
                                     int64_t size) {
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (; idx < size; idx += stride) {
    if (mask[idx]) {
      out[mask_prefix_sum[idx]] = input[idx];
    }
  }
}

// Collects the elements of `x` at positions where `mask` is true into the 1-D
// tensor `out`, keeping their flattened order.
//
// Params:
//   dev_ctx - GPU context; every launch and thrust call is issued on its
//             stream.
//   x       - input tensor of element type T.
//   mask    - bool tensor whose dims must equal x's dims (checked below).
//   out     - resized to {number of true entries in mask} and filled.
// Raises InvalidArgument when the dims differ, or when mask has more elements
// than the int32 prefix-sum buffers can index.
template <typename T, typename Context>
void MaskedSelectKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        const DenseTensor& mask,
                        DenseTensor* out) {
  const bool* mask_data = mask.data<bool>();
  const T* input_data = x.data<T>();
  const int64_t mask_size = mask.numel();
  auto input_dim = x.dims();
  auto mask_dim = mask.dims();
  PADDLE_ENFORCE_EQ(input_dim,
                    mask_dim,
                    phi::errors::InvalidArgument(
                        "The dim size of input and mask in OP(masked_selected) "
                        "must be equal, but got input dim:(%ld), mask dim: "
                        "(%ld). Please check input "
                        "value.",
                        input_dim,
                        mask_dim));
  // Scatter positions are stored as int32 prefix sums; larger inputs would
  // overflow them and corrupt `out` silently.
  PADDLE_ENFORCE_LE(
      mask_size,
      static_cast<int64_t>(std::numeric_limits<int32_t>::max()),
      phi::errors::InvalidArgument(
          "The number of elements of mask in OP(masked_select) must not "
          "exceed %d, but got %d.",
          std::numeric_limits<int32_t>::max(),
          mask_size));

  auto stream = dev_ctx.stream();
  // Run thrust on the context's stream so it is ordered with the kernels
  // below instead of relying on default-stream synchronization.
#ifdef PADDLE_WITH_HIP
  auto exec_policy = thrust::hip::par.on(stream);
#else
  auto exec_policy = thrust::cuda::par.on(stream);
#endif

  // The count is needed on the host to size `out`. Counting directly on the
  // device pointer avoids copying the whole mask into a temporary vector.
  const int64_t out_size =
      mask_size == 0
          ? 0
          : static_cast<int64_t>(thrust::count(
                exec_policy, mask_data, mask_data + mask_size, true));
  DDim out_dim{out_size};
  out->Resize(out_dim);
  T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());

  // Empty input: `out` is already an empty tensor, and launching with a zero
  // grid would be an invalid launch configuration.
  if (mask_size == 0) {
    return;
  }

  DenseTensor mask_array;
  DenseTensor mask_prefix_sum;
  mask_array.Resize(mask_dim);
  mask_prefix_sum.Resize(mask_dim);
  int32_t* mask_array_data =
      mask_array.mutable_data<int32_t>(dev_ctx.GetPlace());
  int32_t* mask_prefix_sum_data =
      mask_prefix_sum.mutable_data<int32_t>(dev_ctx.GetPlace());

  constexpr int kThreads = 512;
  // mask_size <= INT32_MAX (enforced above), so the block count fits in int.
  const int grid = static_cast<int>((mask_size + kThreads - 1) / kThreads);

  SetMaskArray<<<grid, kThreads, 0, stream>>>(
      mask_data, mask_array_data, mask_size);

  // Exclusive scan straight from the device buffer (no intermediate copy):
  // prefix[i] = number of selected elements before position i.
  thrust::exclusive_scan(exec_policy,
                         mask_array_data,
                         mask_array_data + mask_size,
                         mask_prefix_sum_data);

  SelectWithPrefixMask<T><<<grid, kThreads, 0, stream>>>(
      mask_prefix_sum_data, mask_data, input_data, out_data, mask_size);
}

}  // namespace phi

PD_REGISTER_KERNEL(masked_select,
                   GPU,
                   ALL_LAYOUT,
                   phi::MaskedSelectKernel,
                   float,
                   double,
                   int,
                   int64_t) {
  kernel->InputAt(1).SetDataType(phi::DataType::BOOL);
}