diff --git a/paddle/phi/kernels/funcs/select_impl.cu.h b/paddle/phi/kernels/funcs/select_impl.cu.h index 3a1d9b8ea7a7a36c31f31f7fc60dffc1f827d34e..16e00414ad772a162d33fb18c0dfafc889e53951 100644 --- a/paddle/phi/kernels/funcs/select_impl.cu.h +++ b/paddle/phi/kernels/funcs/select_impl.cu.h @@ -168,57 +168,82 @@ __global__ void CumsumOneBlock( } } +// where_index template + int MaskData> struct SelectCaller { - __device__ void inline operator()(OutT *store_data, + __device__ void inline operator()(OutT *out, const MT *mask_data, const InT *in, Functor func, - int num, - int data_offset) { - // where_index op - IdT index_reg[VecSize]; - // Set data index of global - kps::InitWithDataIndex(&index_reg[0], data_offset); + int data_offset, + int store_num, + int thread_fix, + int num) { + int64_t in_data[VecSize]; + OutT store_data[VecSize * phi::DDim::kMaxRank]; + // set index + kps::InitWithDataIndex(&in_data[0], data_offset); // Get store data according to mask_idt - kps::OperatorTernary( - store_data, mask_data, &index_reg[0], func, VecSize); + kps::OperatorTernary( + store_data, mask_data, &in_data[0], func, VecSize); + kps::details::WriteData(out + thread_fix, &store_data[0], store_num); } }; +// masked_select template -struct SelectCaller { // masked_select - __device__ void inline operator()(OutT *store_data, +struct SelectCaller { + __device__ void inline operator()(OutT *out, const MT *mask_data, const InT *in, Functor func, - int num, - int data_offset) { + int data_offset, + int store_num, + int thread_fix, + int num) { InT in_data[VecSize]; + OutT store_data[VecSize * phi::DDim::kMaxRank]; kps::ReadData(&in_data[0], in, num); // Get store data according to mask_idt kps::OperatorTernary( store_data, mask_data, &in_data[0], func, VecSize); + kps::details::WriteData(out + thread_fix, &store_data[0], store_num); + } +}; + +// masked_select_grad +template +struct SelectCaller { + __device__ void inline operator()(OutT *out, + const MT *mask_data, + const InT *in, + Functor func, + int data_offset, + int store_num, + int thread_fix, + int num) { + InT in_data[VecSize]; + OutT store_data[VecSize * phi::DDim::kMaxRank]; + kps::details::ReadData(&in_data[0], in + thread_fix, store_num); + kps::OperatorTernary( + store_data, mask_data, &in_data[0], func, VecSize); + kps::WriteData(out, &store_data[0], num); } }; @@ -253,7 +278,6 @@ SelectKernelImpl(OutT *out, IdT num_thread[kCVecSize]; IdT cumsum_thread[kCVecSize]; - OutT store_data[VecSize * phi::DDim::kMaxRank]; MT mask_data[VecSize]; IdT mask_idt[VecSize]; // init data_pr @@ -271,17 +295,15 @@ SelectKernelImpl(OutT *out, // Get cumsum_thread cumsum from 0 to num_thread cumsum_thread[0] is the // thread_fix kps::Cumsum(&cumsum_thread[0], &num_thread[0], Add()); - // Get store data(index) according to mask_idt - SelectCaller - compute; - compute(&store_data[0], &mask_data[0], in, func, num, data_offset); // get thread_fix int thread_fix = (static_cast(cumsum_thread[0] - num_thread[0]) * store_rank); // get how many data need to store int store_num = static_cast(num_thread[0]) * store_rank; // thread store num data, each thread may has different num - kps::details::WriteData(out + thread_fix, &store_data[0], store_num); + // Get store data(index) according to mask_idt + SelectCaller select; + select(out, mask_data, in, func, data_offset, store_num, thread_fix, num); } template (&block_store_offset, cumsum + idx_cumsum, 1); + int out_fix = MaskData < 2 ? block_store_offset * store_rank : data_offset; + int in_fix = MaskData < 2 ? data_offset : block_store_offset * store_rank; SelectKernelImpl( - out + block_store_offset * store_rank, + out + out_fix, mask + data_offset, - in + data_offset, + in + in_fix, func, size, data_offset, @@ -323,12 +347,13 @@ __global__ void SelectKernel(OutT *out, if (num > 0) { // Cumsum index int idx_cumsum = repeat * GRID_NUM_X + BLOCK_ID_X; - // niuliling todo: us ReadData API - int block_store_offset = static_cast(cumsum[idx_cumsum]); + kps::details::ReadData(&block_store_offset, cumsum + idx_cumsum, 1); + int out_fix = MaskData < 2 ? block_store_offset * store_rank : data_offset; + int in_fix = MaskData < 2 ? data_offset : block_store_offset * store_rank; SelectKernelImpl( - out + block_store_offset * store_rank, + out + out_fix, mask + data_offset, - in + data_offset, + in + in_fix, func, num, data_offset, @@ -402,6 +427,7 @@ void SelectKernel(const KPDevice &dev_ctx, const int kCumVesize = 2; const int block_c = 256; const int main_offset_c = Floor(size_count_block, (kCumVesize * block_c)); + using Add = kps::AddFunctor; CumsumOneBlock<<<1, block_c, 0, stream>>>( count_data, cumsum_data, size_count_block, main_offset_c, Add()); @@ -418,10 +444,13 @@ void SelectKernel(const KPDevice &dev_ctx, dev_ctx.Wait(); // 3.1.2 allock for out with total_true_num std::vector out_dim = {static_cast(total_true_num)}; - if (SelectData == 0) { // where_index + + if (SelectData == 1) { + out->Resize(phi::make_ddim(out_dim)); + } else if (SelectData == 0) { // == 0 where_index out_dim.push_back(rank); + out->Resize(phi::make_ddim(out_dim)); } - out->Resize(phi::make_ddim(out_dim)); auto out_data = out->mutable_data(cuda_place); // 3.2 get true data's index according to cond_data and cumsum_data if (total_true_num <= 0) return; diff --git a/paddle/phi/kernels/gpu/masked_select_grad_kernel.cu b/paddle/phi/kernels/gpu/masked_select_grad_kernel.cu index 71b7cd8750462fdf0dad20b2b221bd18cc6dbbe6..5d0097af2ca9abe4c7a4feb2d312068a5150ae1b 100644 --- a/paddle/phi/kernels/gpu/masked_select_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/masked_select_grad_kernel.cu @@ -17,38 +17,31 @@ #include #include -#include "paddle/phi/backends/cpu/cpu_context.h" -#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/select_impl.cu.h" #include "paddle/phi/kernels/masked_select_grad_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" namespace phi { -__global__ void SetMaskArrayT(const bool* mask, int32_t* mask_array, int size) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; - for (; idx < size; idx += blockDim.x * gridDim.x) { - if (mask[idx]) - mask_array[idx] = 1; - else - mask_array[idx] = 0; - } -} +template +struct MaskedSelectGradFunctor { + HOSTDEVICE MaskedSelectGradFunctor() {} -template -__global__ void SelectGradWithPrefixMask(const int32_t* mask_prefix_sum, - const bool* mask, - const T* input, - T* out, - int size) { - int idx = blockDim.x * blockIdx.x + threadIdx.x; - for (; idx < size; idx += blockDim.x * gridDim.x) { - if (mask[idx]) { - int index = mask_prefix_sum[idx]; - out[idx] = input[index]; - } else { - out[idx] = 0; + HOSTDEVICE inline void operator()(OutT* out, + const MT* mask, + const InT* value, + int num) { + int read_fix = 0; + for (int idx = 0; idx < num; idx++) { + if (mask[idx]) { + out[idx] = value[read_fix++]; + } else { + out[idx] = 0; + } } } -} +}; template void MaskedSelectGradKernel(const Context& dev_ctx, @@ -56,42 +49,12 @@ void MaskedSelectGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& mask, DenseTensor* x_grad) { - auto* mask_data = mask.data(); - auto* input_data = out_grad.data(); - auto* out_data = x_grad->mutable_data(dev_ctx.GetPlace()); - - auto input_size = out_grad.numel(); auto mask_size = mask.numel(); - auto mask_dim = mask.dims(); - - auto out_size = mask_size; - - DenseTensor mask_array; - DenseTensor mask_prefix_sum; - mask_array.Resize(mask_dim); - mask_prefix_sum.Resize(mask_dim); - - int32_t* mask_array_data = - mask_array.mutable_data(dev_ctx.GetPlace()); - int32_t* mask_prefix_sum_data = - mask_prefix_sum.mutable_data(dev_ctx.GetPlace()); - int threads = 512; - int grid = (mask_size + threads - 1) / threads; - auto stream = dev_ctx.stream(); - SetMaskArrayT<<>>( - mask_data, mask_array_data, mask_size); - - thrust::device_ptr mask_array_dev_ptr = - thrust::device_pointer_cast(mask_array_data); - thrust::device_vector mask_array_vec(mask_array_dev_ptr, - mask_array_dev_ptr + mask_size); - thrust::exclusive_scan(thrust::device, - mask_array_vec.begin(), - mask_array_vec.end(), - mask_prefix_sum_data); - - SelectGradWithPrefixMask<<>>( - mask_prefix_sum_data, mask_data, input_data, out_data, mask_size); + auto* out_data = x_grad->mutable_data(dev_ctx.GetPlace()); + if (mask_size <= 0) return; + using Functor = MaskedSelectGradFunctor; + phi::funcs::SelectKernel( + dev_ctx, mask, out_grad, x_grad, Functor()); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/masked_select_kernel.cu b/paddle/phi/kernels/gpu/masked_select_kernel.cu index b443ae6b8fb5e6c3bf5264a50d25205a419f22ad..8986c97583e2086eb9ba53aee34745533dbc91c5 100644 --- a/paddle/phi/kernels/gpu/masked_select_kernel.cu +++ b/paddle/phi/kernels/gpu/masked_select_kernel.cu @@ -17,11 +17,12 @@ #include #include -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/select_impl.cu.h" #include "paddle/phi/kernels/masked_select_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + namespace phi { template diff --git a/paddle/phi/kernels/gpu/where_index_kernel.cu b/paddle/phi/kernels/gpu/where_index_kernel.cu index 9538533f70d597e21b393d2650d56bebd823c360..616679057ffce29b8d911d56d5cf428801138589 100644 --- a/paddle/phi/kernels/gpu/where_index_kernel.cu +++ b/paddle/phi/kernels/gpu/where_index_kernel.cu @@ -20,13 +20,14 @@ namespace cub = hipcub; #endif -#include "paddle/phi/core/ddim.h" -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/select_impl.cu.h" #include "paddle/phi/kernels/where_index_kernel.h" +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" + namespace phi { template struct IndexFunctor { diff --git a/paddle/phi/kernels/primitive/datamover_primitives.h b/paddle/phi/kernels/primitive/datamover_primitives.h index 1d4181f3b9a89509ada2a8fe27d584a9b5aa039c..993349f2d9e14112e2acbd0bc098f3a00351232b 100644 --- a/paddle/phi/kernels/primitive/datamover_primitives.h +++ b/paddle/phi/kernels/primitive/datamover_primitives.h @@ -123,6 +123,15 @@ __device__ __forceinline__ void WriteData(T* dst, dst[i] = src[i]; } } + +template +__device__ __forceinline__ void ReadData(T* dst, + const T* __restrict__ src, + int num) { + for (int i = 0; i < num; i++) { + dst[i] = src[i]; + } +} #undef INT_BITS } // namespace details