diff --git a/paddle/phi/kernels/primitive/compute_primitives.h b/paddle/phi/kernels/primitive/compute_primitives.h index a9146c8aa58959bdbd6c994ffc2d9a3b08b55099..2d9a7522515d0109a4bb4f268e48ab2f6285154b 100644 --- a/paddle/phi/kernels/primitive/compute_primitives.h +++ b/paddle/phi/kernels/primitive/compute_primitives.h @@ -132,6 +132,40 @@ __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) { return shared_memory[threadIdx.x]; } +// Swap data +template +__device__ __forceinline__ void Swap(T* first_value, T* second_value) { + T t_value; + t_value = (*first_value); + (*first_value) = (*second_value); + (*second_value) = t_value; +} + +// swap with monotonic_type +template +__device__ __forceinline__ void Comparator(T* first_value, + T* second_value, + int monotonic_type) { + if (((*first_value) > (*second_value)) == monotonic_type) { + Swap(first_value, second_value); + } +} + +template +__device__ __forceinline__ void ComparatorWithIndex(T* first_value, + + T* second_value, + IndexType* first_index, + IndexType* second_index, + int monotonic_type) { + if ((*first_value > (*second_value)) == monotonic_type) { + // swap value + Swap(first_value, second_value); + // swap index + Swap(first_index, second_index); + } +} + } // namespace details /** @@ -481,5 +515,94 @@ __device__ __forceinline__ void Cumsum(OutT* out, static_cast(temp[tidx + shared_size + (tidx + shared_size) / 32]); } +#define SHARED_SIZE_LIMIT \ + 1024 // each thread load 2 data from global memory so SHARED_SIZE_LIMIT must + // larger than blockDim.x * 2 +// if monotonic_type = 1 then increase +// if gridDim.x > 1 please set monotonic_type = blockIdx.x & 1; blockIdx.x % 2 +// == 1 the increase +template +__device__ __forceinline__ void Sort(T* dst, + const T* src_data, + int num, + int monotonic_type) { + // todo: set num = Pow2(num) + // shareMem for value and index num must smaller than SHARED_SIZE_LIMIT / 2 + __shared__ T value[SHARED_SIZE_LIMIT]; // shareMem's size must larger than + // blockDim * 2 + // Copy value and index from src and src_index + value[threadIdx.x] = src_data[0]; + value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)] = src_data[1]; + // make bitonicSort + for (int size = 2; size < num; size <<= 1) { + int bitonic_type = (threadIdx.x & (size / 2)) != 0; + for (int stride = size / 2; stride > 0; stride >>= 1) { + __syncthreads(); + int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); + details::Comparator(&value[pos], &value[pos + stride], bitonic_type); + } + } + // last sort + for (int stride = SHARED_SIZE_LIMIT / 2; stride > 0; stride >>= 1) { + __syncthreads(); + int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); + // last sort when monotonic_type = 1 then increase + details::Comparator(&value[pos], &value[pos + stride], monotonic_type); + } + __syncthreads(); + dst[0] = value[threadIdx.x]; + dst[1] = value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)]; +} + +template +__device__ __forceinline__ void Sort(T* dst, + IndexType* dst_index, + const T* src_data, + IndexType* src_index, + int num, + int monotonic_type) { + // todo: set num = Pow2(num) + // shareMem for value and index num must smaller than SHARED_SIZE_LIMIT / 2 + __shared__ T value[SHARED_SIZE_LIMIT]; // shareMem's size must larger than + // blockDim * 2 + __shared__ IndexType index[SHARED_SIZE_LIMIT]; + // Copy value and index from src and src_index + value[threadIdx.x] = src_data[0]; + value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)] = src_data[1]; + // index + index[threadIdx.x] = src_index[0]; + index[threadIdx.x + (SHARED_SIZE_LIMIT / 2)] = src_index[1]; + // make bitonicSort + for (int size = 2; size < num; size <<= 1) { + int bitonic_type = (threadIdx.x & (size / 2)) != 0; + for (int stride = size / 2; stride > 0; stride >>= 1) { + __syncthreads(); + int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); + details::ComparatorWithIndex(&value[pos], + &value[pos + stride], + &index[pos], + &index[pos + stride], + bitonic_type); + } + } + + for (int stride = SHARED_SIZE_LIMIT / 2; stride > 0; stride >>= 1) { + __syncthreads(); + int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1)); + // last sort when monotonic_type = 1 then increase + details::ComparatorWithIndex(&value[pos], + &value[pos + stride], + &index[pos], + &index[pos + stride], + monotonic_type); + } + + __syncthreads(); + dst[0] = value[threadIdx.x]; + dst[1] = value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)]; + dst_index[0] = index[threadIdx.x]; + dst_index[1] = index[threadIdx.x + (SHARED_SIZE_LIMIT / 2)]; +} + } // namespace kps } // namespace phi