Unverified — Commit 255bf609 authored by niuliling123, committed by GitHub

Add function description for Kernel Primitive API (#39884)

* Add function description for Kernel Primitive API
1. Set cumsum and sort shared memory size = 1024
2. sort and cumsum API limitation: blockDim.x must not exceed 512 (blockDim.x <= 512)
Parent 2592805b
@@ -136,7 +136,9 @@ __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
  return shared_memory[threadIdx.x];
}

/**
 * @brief Swap data.
 */
template <typename T>
__device__ __forceinline__ void Swap(T* first_value, T* second_value) {
  T t_value;
@@ -145,7 +147,9 @@ __device__ __forceinline__ void Swap(T* first_value, T* second_value) {
  (*second_value) = t_value;
}

/**
 * @brief Swap data according to monotonic_type.
 */
template <typename T>
__device__ __forceinline__ void Comparator(T* first_value,
                                           T* second_value,
@@ -155,6 +159,9 @@ __device__ __forceinline__ void Comparator(T* first_value,
  }
}

/**
 * @brief Swap data and data index according to monotonic_type.
 */
template <typename T, typename IndexType>
__device__ __forceinline__ void ComparatorWithIndex(T* first_value,
@@ -170,6 +177,18 @@ __device__ __forceinline__ void ComparatorWithIndex(T* first_value,
  }
}

/**
 * @brief Get the largest power of two that is not greater than n.
 */
__device__ inline int GetLastPow2(int n) {
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return std::max(1, n - (n >> 1));
}
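// Illustrative host-side sketch (not part of this commit) of what GetLastPow2
// computes: the bit-smearing ORs copy the highest set bit into every lower
// position, so n - (n >> 1) keeps only that highest bit, i.e. the largest
// power of two not exceeding n (with a floor of 1).
inline int LastPow2Reference(int n) {
  int p = 1;
  while (2 * p <= n) {
    p *= 2;
  }
  return p;  // e.g. LastPow2Reference(7) == 4, LastPow2Reference(8) == 8
}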
} // namespace details

/**
@@ -453,6 +472,29 @@ __device__ __forceinline__ void Reduce(T* out,
  }
}
/*
 * @brief Fill register with a constant according to OpFunc.
 *
 * @template parameters
 * InT: The data type of in1 and in2.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. Currently only
 * GPU is supported.
 * OpFunc: Compute functor which has an operator() as following
 *     template <typename InT>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()() const {
 *         return a;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * compute: Compute function which was declared like OpFunc<InT>().
 */
template <typename InT,
          typename OutT,
          int NX,
@@ -466,6 +508,33 @@ __device__ __forceinline__ void ElementwiseConstant(OutT* out, OpFunc compute) {
  }
}
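// Usage sketch (illustrative only, not part of this commit; the kernel and
// functor names are hypothetical): fill each thread's NX * NY registers with
// the constant 1 via ElementwiseConstant, then write them back to global memory.
template <typename T>
struct OneFunctor {
  HOSTDEVICE T operator()() const { return static_cast<T>(1); }
};

template <typename T>
__global__ void FillOnesKernel(T* out) {
  constexpr int kNX = 4;
  T regs[kNX];
  kps::ElementwiseConstant<T, T, kNX, 1, 1, OneFunctor<T>>(regs,
                                                           OneFunctor<T>());
  int thread_offset = (blockIdx.x * blockDim.x + threadIdx.x) * kNX;
  for (int i = 0; i < kNX; ++i) {
    out[thread_offset + i] = regs[i];
  }
}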
/*
 * @brief Get ReturnsCount random data from compute according to state, where
 * state can be curandStatePhilox4_32_10_t or hiprandStatePhilox4_32_10_t,
 * which has been initialized.
 *
 * @template parameters
 * StateType: the type of state, can be curandStatePhilox4_32_10_t or
 * hiprandStatePhilox4_32_10_t.
 * OutT: the type of out register.
 * ReturnsCount: The number of random data generated by OpFunc.
 * BlockSize: Identifies the current device thread index method. Currently only
 * GPU is supported.
 * OpFunc: Compute functor which has an operator() as following
 *     template <typename T>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()(StateType state) const {
 *         return random(state);  // Returns ReturnsCount random numbers with
 *                                // data type T
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is ReturnsCount.
 * compute: Compute function which was declared like OpFunc<T>().
 */
template <typename StateType,
          typename OutT,
          int ReturnsCount,
@@ -481,131 +550,208 @@ __device__ __forceinline__ void ElementwiseRandom(OutT* out,
  }
}
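// Illustrative functor sketch for ElementwiseRandom (not part of this commit;
// the call site is elided in this hunk, and passing the state by pointer is an
// assumption). It requires <curand_kernel.h> and a Philox state already seeded
// with curand_init; curand_uniform4 then yields ReturnsCount = 4 uniform
// floats per call.
struct UniformFloat4Functor {
  __device__ float4 operator()(curandStatePhilox4_32_10_t* state) const {
    return curand_uniform4(state);
  }
};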
/*
 * @brief Complete the prefix sum in the block. Each thread calculates 2 data,
 * the size of out and in is 2, and blockDim.x must not exceed 512.
 *
 * @template parameters
 * InT: the type of input register.
 * OutT: the type of out register.
 * BlockSize: Identifies the current device thread index method. Currently only
 * GPU is supported.
 * OpFunc: Compute functor which has an operator() as following
 *     template <typename T>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()(T a, T b) const {
 *         return a + b;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is 2.
 * in: The register pointer of input, the size is 2.
 * compute: Compute function which was declared like OpFunc<T>().
 */
#define SHARED_SIZE_LIMIT 512
template <typename InT, typename OutT, int BlockSize, class OpFunc>
__device__ __forceinline__ void Cumsum(OutT* out,
                                       const InT* in,
                                       OpFunc compute) {
  constexpr int kSize = SHARED_SIZE_LIMIT * 2 + (SHARED_SIZE_LIMIT * 2) / 32;
  __shared__ InT temp[kSize];
  int stride_size = blockDim.x;
  int tidx = threadIdx.x;
  temp[tidx + tidx / 32] = in[0];
  temp[stride_size + tidx + (stride_size + tidx) / 32] = in[1];
  for (int stride = 1; stride <= stride_size; stride *= 2) {
    __syncthreads();
    int index = (tidx + 1) * 2 * stride - 1;
    if (index < (blockDim.x * 2)) {
      temp[index + index / 32] =
          compute(temp[index + index / 32],
                  temp[index - stride + (index - stride) / 32]);
    }
  }
  for (int stride = (blockDim.x * 2) / 4; stride > 0; stride /= 2) {
    __syncthreads();
    int index = (tidx + 1) * 2 * stride - 1;
    if ((index + stride) < (blockDim.x * 2)) {
      temp[index + stride + (stride + index) / 32] =
          compute(temp[index + stride + (stride + index) / 32],
                  temp[index + (index) / 32]);
    }
  }
  __syncthreads();
  out[0] = static_cast<OutT>(temp[tidx + tidx / 32]);
  out[1] =
      static_cast<OutT>(temp[tidx + stride_size + (tidx + stride_size) / 32]);
}
#undef SHARED_SIZE_LIMIT
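// Usage sketch (illustrative only, not part of this commit; the kernel and
// functor names are hypothetical, and the primitive is assumed to be visible
// as kps::Cumsum): block-wide prefix sum where each thread contributes two
// elements. Assumes the input length equals blockDim.x * 2 and blockDim.x <= 512.
struct AddFunctor {
  __device__ float operator()(float a, float b) const { return a + b; }
};

__global__ void BlockCumsumKernel(const float* in, float* out) {
  // Thread tid owns elements tid and tid + blockDim.x, which matches the
  // layout Cumsum expects for in[0] / in[1] and produces for out[0] / out[1].
  float local_in[2];
  float local_out[2];
  local_in[0] = in[threadIdx.x];
  local_in[1] = in[threadIdx.x + blockDim.x];
  kps::Cumsum<float, float, 1, AddFunctor>(local_out, local_in, AddFunctor());
  out[threadIdx.x] = local_out[0];
  out[threadIdx.x + blockDim.x] = local_out[1];
}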
/*
 * @brief Sort data in this block. Each thread calculates 2 data, the size of
 * out and in is 2, and blockDim.x must not exceed 512.
 *
 * @template parameters
 * InT: the type of input register.
 * OutT: the type of out register.
 * BlockSize: Identifies the current device thread index method. Currently only
 * GPU is supported.
 *
 * @param
 * out: The register pointer of out, the size is 2.
 * in: The register pointer of input, the size is 2.
 * num: The num of this block.
 * monotonic_type: if monotonic_type = 1 the data is sorted in ascending order,
 * else it is sorted in descending order.
 */
#define SHARED_SIZE_LIMIT 1024
// each thread loads 2 data from global memory, so SHARED_SIZE_LIMIT must
// be larger than blockDim.x * 2
template <typename InT, typename OutT, int BlockSize>
__device__ __forceinline__ void Sort(OutT* out,
                                     const InT* in,
                                     int num,
                                     int monotonic_type) {
  int upper_bound = blockDim.x;
  // update upper_bound
  upper_bound = std::min(details::GetLastPow2(num), upper_bound);
  // shareMem for value and index num must be smaller than SHARED_SIZE_LIMIT / 2
  __shared__ InT value[SHARED_SIZE_LIMIT];
  int stride_size = blockDim.x;
  // shareMem's size must be larger than blockDim * 2
  // Copy value from in
  value[threadIdx.x] = in[0];
  value[threadIdx.x + stride_size] = in[1];
  // make bitonicSort
  for (int size = 2; size < upper_bound; size <<= 1) {
    int bitonic_type = (threadIdx.x & (size / 2)) != 0;
    for (int stride = size / 2; stride > 0; stride >>= 1) {
      __syncthreads();
      int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
      details::Comparator<InT>(&value[pos], &value[pos + stride], bitonic_type);
    }
  }
  // last sort
  for (int stride = stride_size; stride > 0; stride >>= 1) {
    __syncthreads();
    int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
    // last sort when monotonic_type = 1 then increase
    details::Comparator<InT>(&value[pos], &value[pos + stride], monotonic_type);
  }
  __syncthreads();
  out[0] = static_cast<OutT>(value[threadIdx.x]);
  out[1] = static_cast<OutT>(value[threadIdx.x + stride_size]);
}
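// Usage sketch (illustrative only, not part of this commit; the kernel name is
// hypothetical): sort blockDim.x * 2 values within one block in ascending
// order (monotonic_type = 1), with num the count of valid elements in the
// block and blockDim.x <= 512.
__global__ void BlockSortKernel(const float* in, float* out, int num) {
  float local_in[2];
  float local_out[2];
  local_in[0] = in[threadIdx.x];
  local_in[1] = in[threadIdx.x + blockDim.x];
  kps::Sort<float, float, 1>(local_out, local_in, num, 1);
  out[threadIdx.x] = local_out[0];
  out[threadIdx.x + blockDim.x] = local_out[1];
}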
/*
 * @brief Sort data with data_index in this block. Each thread calculates 2
 * data, the size of out and in is 2, and blockDim.x must not exceed 512.
 *
 * @template parameters
 * InT: The type of input register.
 * OutT: The type of out register.
 * IndexType: The type of index.
 * BlockSize: Identifies the current device thread index method. Currently only
 * GPU is supported.
 *
 * @param
 * out: The register pointer of out, the size is 2.
 * out_index: The register pointer of out_index, the size is 2.
 * in: The register pointer of input, the size is 2.
 * in_index: The register pointer of in_index, the size is 2.
 * num: The num of this block.
 * monotonic_type: if monotonic_type = 1 the data is sorted in ascending order,
 * else it is sorted in descending order.
 */
template <typename InT, typename OutT, typename IndexType, int BlockSize>
__device__ __forceinline__ void Sort(OutT* out,
                                     IndexType* out_index,
                                     const InT* in,
                                     IndexType* in_index,
                                     int num,
                                     int monotonic_type) {
  int upper_bound = blockDim.x;
  // update upper_bound
  upper_bound = std::min(details::GetLastPow2(num), upper_bound);
  // shareMem for value and index num must be smaller than SHARED_SIZE_LIMIT / 2
  __shared__ InT value[SHARED_SIZE_LIMIT];
  // shareMem's size must be larger than blockDim * 2
  __shared__ IndexType index[SHARED_SIZE_LIMIT];
  // Copy value and index from in and in_index
  int stride_size = blockDim.x;
  value[threadIdx.x] = in[0];
  value[threadIdx.x + stride_size] = in[1];
  // index
  index[threadIdx.x] = in_index[0];
  index[threadIdx.x + stride_size] = in_index[1];
  // make bitonicSort
  for (int size = 2; size < upper_bound; size <<= 1) {
    int bitonic_type = (threadIdx.x & (size / 2)) != 0;
    for (int stride = size / 2; stride > 0; stride >>= 1) {
      __syncthreads();
      int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
      details::ComparatorWithIndex<InT, IndexType>(&value[pos],
                                                   &value[pos + stride],
                                                   &index[pos],
                                                   &index[pos + stride],
                                                   bitonic_type);
    }
  }

  for (int stride = stride_size; stride > 0; stride >>= 1) {
    __syncthreads();
    int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
    // last sort when monotonic_type = 1 then increase
    details::ComparatorWithIndex<InT, IndexType>(&value[pos],
                                                 &value[pos + stride],
                                                 &index[pos],
                                                 &index[pos + stride],
                                                 monotonic_type);
  }
  __syncthreads();
  out[0] = static_cast<OutT>(value[threadIdx.x]);
  out[1] = static_cast<OutT>(value[threadIdx.x + stride_size]);
  out_index[0] = index[threadIdx.x];
  out_index[1] = index[threadIdx.x + stride_size];
}

template <typename T1, typename T2, typename OutT, typename OpFunc>
HOSTDEVICE __forceinline__ void OperatorTernary(
    OutT* out, const T1* in1, const T2* in2, OpFunc func, int num) {
  func(out, in1, in2, num);
}

template <typename InT, typename OutT, typename OpFunc>
HOSTDEVICE __forceinline__ void OperatorBinary(OutT* out,
                                               const InT* in,
                                               OpFunc func,
                                               int num) {
  func(out, in, num);
}
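// Usage sketch (illustrative only, not part of this commit; the kernel name is
// hypothetical): argsort-style block sort that carries each element's original
// index along with its value, here in descending order (monotonic_type = 0).
__global__ void BlockArgSortKernel(const float* in,
                                   float* out,
                                   int* out_index,
                                   int num) {
  float v_in[2], v_out[2];
  int i_in[2], i_out[2];
  v_in[0] = in[threadIdx.x];
  v_in[1] = in[threadIdx.x + blockDim.x];
  i_in[0] = threadIdx.x;
  i_in[1] = threadIdx.x + blockDim.x;
  kps::Sort<float, float, int, 1>(v_out, i_out, v_in, i_in, num, 0);
  out[threadIdx.x] = v_out[0];
  out[threadIdx.x + blockDim.x] = v_out[1];
  out_index[threadIdx.x] = i_out[0];
  out_index[threadIdx.x + blockDim.x] = i_out[1];
}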
} // namespace kps
...
@@ -348,6 +348,29 @@ __device__ __forceinline__ void Reduce(T* out,
  }
}
/*
 * @brief Fill register with a constant according to OpFunc.
 *
 * @template parameters
 * InT: The data type of in1 and in2.
 * OutT: The data type of out.
 * NX: The number of data columns loaded by each thread.
 * NY: The number of data rows loaded by each thread.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * OpFunc: Compute functor which has an operator() as following
 *     template <typename InT>
 *     struct XxxFunctor {
 *       HOSTDEVICE InT operator()() const {
 *         return a;
 *       }
 *     };
 *
 * @param
 * out: The register pointer of out, the size is NX * NY.
 * compute: Compute function which was declared like OpFunc<InT>().
 */
template <typename InT,
          typename OutT,
          int NX,
...
@@ -297,6 +297,24 @@ __device__ __forceinline__ void ReadData(T* dst,
/**
 * @brief Read 1D data from global memory to register. The difference
 * from the above function is that it supports different data types of inputs.
 *
 * @template parameters
 * T: The type of data.
 * NX: Each thread loads NX data from global memory continuously.
 * NY: Each thread needs to load NY rows; only NY = 1 is supported.
 * ArgsT: The type of dst, ArgsT can be std::tuple<T> or std::tuple<Args>.
 * Index: The index of data stored in dst.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 * IsBoundary: Whether to make an out-of-bounds judgment on access to memory.
 * When the number of data processed by this block is less than
 * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
 * crossing the boundary.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The data pointer of the current block.
 * size: The current block needs to load size data continuously.
 */
template <typename T,
          int NX,
@@ -714,6 +732,20 @@ __device__ __forceinline__ void ReadDataBc(
  }
}
/**
 * @brief Initialize register with data index.
 *
 * @template parameters
 * T: Data type of register.
 * NX: Number of data to initialize.
 * NY: Number of data to initialize; NY can only be 1.
 * BlockSize: Identifies the current device thread index method. For GPU,
 * threadIdx.x is used as the thread index. Currently only GPU is supported.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX.
 * block_offset: The data offset of this block.
 */
template <typename T, int NX, int NY, int BlockSize>
__device__ __forceinline__ void InitWithDataIndex(T* dst, int block_offset) {
  int thread_offset = block_offset + threadIdx.x * NX;
...
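// Usage sketch (illustrative only, not part of this commit; the kernel name is
// hypothetical): each thread fills its NX registers with the global indices of
// the elements it owns, e.g. to build the index input of an argsort kernel.
__global__ void FillIndexKernel(int64_t* out, int num) {
  constexpr int kNX = 4;
  int block_offset = blockIdx.x * blockDim.x * kNX;
  int64_t idx[kNX];
  kps::InitWithDataIndex<int64_t, kNX, 1, 1>(idx, block_offset);
  int thread_offset = block_offset + threadIdx.x * kNX;
  for (int i = 0; i < kNX; ++i) {
    if (thread_offset + i < num) {
      out[thread_offset + i] = idx[i];
    }
  }
}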
@@ -244,6 +244,24 @@ __device__ __inline__ void ReadData(T* dst,
/**
 * @brief Read 1D data from global memory to register. The difference
 * from the above function is that it supports different data types of inputs.
 *
 * @template parameters
 * T: The type of data.
 * NX: Each thread loads NX data from global memory continuously.
 * NY: Each thread needs to load NY rows; only NY = 1 is supported.
 * ArgsT: The type of dst, ArgsT can be std::tuple<T> or std::tuple<Args>.
 * Index: The index of data stored in dst.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 * IsBoundary: Whether to make an out-of-bounds judgment on access to memory.
 * When the number of data processed by this block is less than
 * NX x NY x blockDim.x, boundary judgment is required to avoid memory access
 * crossing the boundary.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX * NY.
 * src: The data pointer of the current block.
 * size: The current block needs to load size data continuously.
 */
template <typename T,
          int NX,
@@ -646,5 +664,28 @@ __device__ __inline__ void ReadDataBc(
  }
}
/**
 * @brief Initialize register with data index.
 *
 * @template parameters
 * T: Data type of register.
 * NX: Number of data to initialize.
 * NY: Number of data to initialize; NY can only be 1.
 * BlockSize: Identifies the current device thread index method. For xpu,
 * core_id() is used as the index.
 *
 * @param:
 * dst: The register pointer of the thread, the size is NX.
 * block_offset: The data offset of this block.
 */
template <typename T, int NX, int NY, int BlockSize>
__device__ __forceinline__ void InitWithDataIndex(T* dst, int block_offset) {
int thread_offset = block_offset + core_id() * NX;
#pragma unroll
for (int nx = 0; nx < NX; ++nx) {
dst[nx] = static_cast<T>(thread_offset + nx);
}
}
} // namespace kps
} // namespace phi