Unverified commit f4e74887, authored by niuliling123, committed by GitHub

Add Sort API for Kernel Primitive API (#39734)

* Add Sort API for Kernel Primitive API

* update & -> ptr
Parent de760d2c
@@ -132,6 +132,40 @@ __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
  return shared_memory[threadIdx.x];
}
// Swap data
template <typename T>
__device__ __forceinline__ void Swap(T* first_value, T* second_value) {
T t_value;
t_value = (*first_value);
(*first_value) = (*second_value);
(*second_value) = t_value;
}
// Compare-and-swap controlled by monotonic_type:
// 1 orders the pair ascending, 0 orders it descending.
template <typename T>
__device__ __forceinline__ void Comparator(T* first_value,
T* second_value,
int monotonic_type) {
if (((*first_value) > (*second_value)) == monotonic_type) {
Swap<T>(first_value, second_value);
}
}
template <typename T, typename IndexType>
__device__ __forceinline__ void ComparatorWithIndex(T* first_value,
T* second_value,
IndexType* first_index,
IndexType* second_index,
int monotonic_type) {
if ((*first_value > (*second_value)) == monotonic_type) {
// swap value
Swap<T>(first_value, second_value);
// swap index
Swap<IndexType>(first_index, second_index);
}
}
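
// Hypothetical illustration (not part of this commit) of the monotonic_type
// convention used by the comparators above: 1 orders a pair ascending,
// 0 orders it descending.
__device__ __forceinline__ void ComparatorDemo() {
  float a = 3.0f, b = 1.0f;
  Comparator<float>(&a, &b, 1);  // ascending: now a == 1.0f, b == 3.0f
  Comparator<float>(&a, &b, 0);  // descending: now a == 3.0f, b == 1.0f
}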
} // namespace details
/**
@@ -481,5 +515,94 @@ __device__ __forceinline__ void Cumsum(OutT* out,
static_cast<OutT>(temp[tidx + shared_size + (tidx + shared_size) / 32]);
}
#define SHARED_SIZE_LIMIT \
  1024  // each thread loads 2 elements from global memory, so
        // SHARED_SIZE_LIMIT must be larger than blockDim.x * 2
// If monotonic_type = 1 the result is increasing (ascending).
// If gridDim.x > 1, set monotonic_type = blockIdx.x & 1 so that blocks with
// blockIdx.x % 2 == 1 sort in increasing order.
template <typename T>
__device__ __forceinline__ void Sort(T* dst,
const T* src_data,
int num,
int monotonic_type) {
// TODO: round num up to a power of 2 (num = Pow2(num)).
// Shared memory for values; num must be smaller than SHARED_SIZE_LIMIT / 2.
__shared__ T value[SHARED_SIZE_LIMIT];  // shared memory size must be larger
                                        // than blockDim.x * 2
// Each thread stages its two elements from src_data into shared memory.
value[threadIdx.x] = src_data[0];
value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)] = src_data[1];
// Bitonic sort: build bitonic sequences of increasing length.
for (int size = 2; size < num; size <<= 1) {
int bitonic_type = (threadIdx.x & (size / 2)) != 0;
for (int stride = size / 2; stride > 0; stride >>= 1) {
__syncthreads();
int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
details::Comparator<T>(&value[pos], &value[pos + stride], bitonic_type);
}
}
// Final merge across the whole shared buffer.
for (int stride = SHARED_SIZE_LIMIT / 2; stride > 0; stride >>= 1) {
__syncthreads();
int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
// Final merge: when monotonic_type = 1 the result is increasing.
details::Comparator<T>(&value[pos], &value[pos + stride], monotonic_type);
}
__syncthreads();
dst[0] = value[threadIdx.x];
dst[1] = value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)];
}
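
// Usage sketch (illustrative only, not part of this commit). Assumptions:
// blockDim.x == SHARED_SIZE_LIMIT / 2, each block owns a contiguous tile of
// SHARED_SIZE_LIMIT elements, and the tile length is already a power of 2.
// Each thread stages its two elements with the same offset layout Sort uses
// internally, sorts the tile ascending (monotonic_type = 1), and writes its
// two sorted elements back to the same positions.
template <typename T>
__global__ void BlockSortSketch(T* out, const T* in) {
  constexpr int kHalf = SHARED_SIZE_LIMIT / 2;
  int tile_start = blockIdx.x * SHARED_SIZE_LIMIT;
  T thread_in[2];
  thread_in[0] = in[tile_start + threadIdx.x];
  thread_in[1] = in[tile_start + threadIdx.x + kHalf];
  T thread_out[2];
  Sort<T>(thread_out, thread_in, SHARED_SIZE_LIMIT, 1);
  out[tile_start + threadIdx.x] = thread_out[0];
  out[tile_start + threadIdx.x + kHalf] = thread_out[1];
}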
template <typename T, typename IndexType>
__device__ __forceinline__ void Sort(T* dst,
IndexType* dst_index,
const T* src_data,
IndexType* src_index,
int num,
int monotonic_type) {
// TODO: round num up to a power of 2 (num = Pow2(num)).
// Shared memory for values and indices; num must be smaller than
// SHARED_SIZE_LIMIT / 2.
__shared__ T value[SHARED_SIZE_LIMIT];  // shared memory size must be larger
                                        // than blockDim.x * 2
__shared__ IndexType index[SHARED_SIZE_LIMIT];
// Each thread stages its two values and their indices from src_data and
// src_index into shared memory.
value[threadIdx.x] = src_data[0];
value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)] = src_data[1];
index[threadIdx.x] = src_index[0];
index[threadIdx.x + (SHARED_SIZE_LIMIT / 2)] = src_index[1];
// Bitonic sort: build bitonic sequences of increasing length.
for (int size = 2; size < num; size <<= 1) {
int bitonic_type = (threadIdx.x & (size / 2)) != 0;
for (int stride = size / 2; stride > 0; stride >>= 1) {
__syncthreads();
int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
details::ComparatorWithIndex<T, IndexType>(&value[pos],
&value[pos + stride],
&index[pos],
&index[pos + stride],
bitonic_type);
}
}
for (int stride = SHARED_SIZE_LIMIT / 2; stride > 0; stride >>= 1) {
__syncthreads();
int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
// Final merge: when monotonic_type = 1 the result is increasing.
details::ComparatorWithIndex<T, IndexType>(&value[pos],
&value[pos + stride],
&index[pos],
&index[pos + stride],
monotonic_type);
}
__syncthreads();
dst[0] = value[threadIdx.x];
dst[1] = value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)];
dst_index[0] = index[threadIdx.x];
dst_index[1] = index[threadIdx.x + (SHARED_SIZE_LIMIT / 2)];
}
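
// Usage sketch (illustrative only, not part of this commit). Same layout
// assumptions as the sketch above; here each thread also carries the global
// indices of its two elements so that sorted values and an argsort-style
// index permutation are produced together.
template <typename T, typename IndexType>
__global__ void BlockArgSortSketch(T* out, IndexType* out_index, const T* in) {
  constexpr int kHalf = SHARED_SIZE_LIMIT / 2;
  int tile_start = blockIdx.x * SHARED_SIZE_LIMIT;
  T thread_in[2];
  IndexType thread_idx_in[2];
  thread_in[0] = in[tile_start + threadIdx.x];
  thread_in[1] = in[tile_start + threadIdx.x + kHalf];
  thread_idx_in[0] = static_cast<IndexType>(tile_start + threadIdx.x);
  thread_idx_in[1] = static_cast<IndexType>(tile_start + threadIdx.x + kHalf);
  T thread_out[2];
  IndexType thread_idx_out[2];
  Sort<T, IndexType>(thread_out, thread_idx_out, thread_in, thread_idx_in,
                     SHARED_SIZE_LIMIT, 1);
  out[tile_start + threadIdx.x] = thread_out[0];
  out[tile_start + threadIdx.x + kHalf] = thread_out[1];
  out_index[tile_start + threadIdx.x] = thread_idx_out[0];
  out_index[tile_start + threadIdx.x + kHalf] = thread_idx_out[1];
}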
} // namespace kps
} // namespace phi