未验证 提交 82b33be3 编写于 作者: N niuliling123 提交者: GitHub

Modify the reduce op according to the kernel primitive api (#35282)

上级 7aa4d879
...@@ -202,9 +202,9 @@ void SetConfigForColumnReduce(const int max_threads, const int reduce_num, ...@@ -202,9 +202,9 @@ void SetConfigForColumnReduce(const int max_threads, const int reduce_num,
int num_block = (max_threads / left_num); int num_block = (max_threads / left_num);
if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) { if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
*blocking_size = detail::GetLastPow2(reduce_num / num_block); *blocking_size = details::GetLastPow2(reduce_num / num_block);
if (*blocking_size <= 1) { if (*blocking_size <= 1) {
*blocking_size = detail::GetLastPow2(sqrt(reduce_num)); *blocking_size = details::GetLastPow2(sqrt(reduce_num));
} else if (*blocking_size * 2 < reduce_num) { } else if (*blocking_size * 2 < reduce_num) {
*blocking_size *= 2; *blocking_size *= 2;
} }
......
...@@ -31,13 +31,15 @@ namespace kernel_primitives { ...@@ -31,13 +31,15 @@ namespace kernel_primitives {
namespace details { namespace details {
#ifdef __HIPCC__ #ifdef __HIPCC__
constexpr int kMaxThread = 256; constexpr int kReduceMaxThread = 256;
constexpr int kWarpSize = 64; constexpr int kWarpSize = 64;
#else #else
constexpr int kMaxThread = 128; constexpr int kReduceMaxThread = 128;
constexpr int kWarpSize = 32; constexpr int kWarpSize = 32;
#endif #endif
// kGlobalMode: block reduce, each block gets an output;
// kLocalMode: thread reduce, each thread gets an output;
enum ReduceMode { kGlobalMode, kLocalMode }; enum ReduceMode { kGlobalMode, kLocalMode };
template <typename T> template <typename T>
...@@ -118,7 +120,7 @@ __device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) { ...@@ -118,7 +120,7 @@ __device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
*/ */
template <typename T, typename ReduceOp> template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) { __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
__shared__ T shared_memory[details::kMaxThread]; __shared__ T shared_memory[details::kReduceMaxThread];
shared_memory[SharedMemoryIndex(0)] = val; shared_memory[SharedMemoryIndex(0)] = val;
for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) { for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
__syncthreads(); __syncthreads();
......
...@@ -124,36 +124,36 @@ struct BroadcastConfig { ...@@ -124,36 +124,36 @@ struct BroadcastConfig {
template <typename Tx, typename Ty, int NX, int NY, int BlockSize> template <typename Tx, typename Ty, int NX, int NY, int BlockSize>
__device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src, __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
int stride_nx, int stride_ny) { int stride_nx, int stride_ny) {
int thread_offset = threadIdx.x * NX;
if (NY == 1 && NX == 1) { if (NY == 1 && NX == 1) {
dst[0] = static_cast<Ty>(src[threadIdx.x]); dst[0] = static_cast<Ty>(src[thread_offset]);
} else if (NX == 1) { } else if (NX == 1) {
int dx = threadIdx.x;
#pragma unroll #pragma unroll
for (int idy = 0; idy < NY; ++idy) { for (int idy = 0; idy < NY; ++idy) {
dst[idy] = static_cast<Ty>(src[dx + idy * stride_ny]); dst[idy] = static_cast<Ty>(src[thread_offset + idy * stride_ny]);
} }
} else if (NY == 1) { } else if (NY == 1) {
#pragma unroll #pragma unroll
for (int idx = 0; idx < NX; ++idx) { for (int idx = 0; idx < NX; ++idx) {
dst[idx] = static_cast<Ty>(src[idx * stride_nx]); dst[idx] = static_cast<Ty>(src[thread_offset + idx * stride_nx]);
} }
} else { } else {
int dx = threadIdx.x * NX;
#pragma unroll #pragma unroll
for (int idx = 0; idx < NX; ++idx) { for (int idx = 0; idx < NX; ++idx) {
#pragma unroll #pragma unroll
for (int idy = 0; idy < NY; ++idy) { for (int idy = 0; idy < NY; ++idy) {
dst[idy * NX + idx] = dst[idy * NX + idx] = static_cast<Ty>(
static_cast<Ty>(src[idx * stride_nx + dx + idy * stride_ny]); src[thread_offset + idx * stride_nx + idy * stride_ny]);
} }
} }
} }
} }
/** /**
* @brief load data from src to dst, src can be 1D data or 2D data. When * @brief load data from src to dst with stride, src can be 1D data or 2D data.
* boundary judgment is required, you need to set a to true, and a is false by * When boundary judgment is required, you need to set a to true, and a is false
* default. * by default.
* @typename: * @typename:
* Tx: data type of src * Tx: data type of src
* Ty: data type of dstt * Ty: data type of dstt
...@@ -172,17 +172,17 @@ template <typename Tx, typename Ty, int NX, int NY, int BlockSize, ...@@ -172,17 +172,17 @@ template <typename Tx, typename Ty, int NX, int NY, int BlockSize,
__device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src, __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
int size_nx, int size_ny, int size_nx, int size_ny,
int stride_nx, int stride_ny) { int stride_nx, int stride_ny) {
int dx = threadIdx.x * NX; int thread_offset = threadIdx.x * NX;
int size = size_nx - dx; int left_size_nx = size_nx - thread_offset;
// Each branch is added for better performance // Each branch is added for better performance
if (NX == 1 && NY == 1) { // for NX == 1 and NY == 1 if (NX == 1 && NY == 1) { // for NX == 1 and NY == 1
if (IsBoundary) { if (IsBoundary) {
if (dx < size_nx) { if (left_size_nx > 0) {
dst[0] = static_cast<Ty>(src[dx]); dst[0] = static_cast<Ty>(src[thread_offset]);
} }
} else { } else {
dst[0] = static_cast<Ty>(src[dx]); dst[0] = static_cast<Ty>(src[thread_offset]);
} }
} else if (NX == 1) { // for NX == 1 and NY != 1 } else if (NX == 1) { // for NX == 1 and NY != 1
#pragma unroll #pragma unroll
...@@ -192,23 +192,23 @@ __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src, ...@@ -192,23 +192,23 @@ __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
break; break;
} }
} }
dst[idy] = static_cast<Ty>(src[dx + idy * stride_ny]); dst[idy] = static_cast<Ty>(src[thread_offset + idy * stride_ny]);
} }
} else if (NY == 1) { // for NY == 1 and NX != 1 } else if (NY == 1) { // for NY == 1 and NX != 1
#pragma unroll #pragma unroll
for (int idx = 0; idx < NX; ++idx) { for (int idx = 0; idx < NX; ++idx) {
if (IsBoundary) { if (IsBoundary) {
if (idx >= size) { if (idx >= left_size_nx) {
break; break;
} }
} }
dst[idx] = static_cast<Ty>(src[idx * stride_nx + dx]); dst[idx] = static_cast<Ty>(src[thread_offset + idx * stride_nx]);
} }
} else { // for NX != 1 and NY != 1 } else { // for NX != 1 and NY != 1
#pragma unroll #pragma unroll
for (int idx = 0; idx < NX; ++idx) { for (int idx = 0; idx < NX; ++idx) {
if (IsBoundary) { if (IsBoundary) {
if (idx >= size) { if (idx >= left_size_nx) {
break; break;
} }
} }
...@@ -219,8 +219,8 @@ __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src, ...@@ -219,8 +219,8 @@ __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
break; break;
} }
} }
dst[idy * NX + idx] = dst[idy * NX + idx] = static_cast<Ty>(
static_cast<Ty>(src[idx * stride_nx + dx + idy * stride_ny]); src[thread_offset + idx * stride_nx + idy * stride_ny]);
} }
} }
} }
...@@ -251,17 +251,17 @@ template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false> ...@@ -251,17 +251,17 @@ template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src, __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
int num) { int num) {
if (IsBoundary) { // blockDim.x * NX > num if (IsBoundary) { // blockDim.x * NX > num
int dx = threadIdx.x * NX; int thread_offset = threadIdx.x * NX;
#pragma unroll #pragma unroll
for (int idx = 0; idx < NX; ++idx) { for (int idx = 0; idx < NX; ++idx) {
if (idx + dx < num) { if (idx + thread_offset < num) {
dst[idx] = src[idx + dx]; dst[idx] = src[thread_offset + idx];
} }
} }
} else { // blockDim,x * NX < num } else { // blockDim,x * NX < num
const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1; const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
const int kVectorsPerThread = NX / kVectorSize; const int kVectorsPerThread = NX / kVectorSize;
int tid = threadIdx.x * kVectorsPerThread; int thread_offset = threadIdx.x * kVectorsPerThread;
using VecType = details::VectorType<T, kVectorSize>; using VecType = details::VectorType<T, kVectorSize>;
const VecType* vec_input = reinterpret_cast<const VecType*>(src); const VecType* vec_input = reinterpret_cast<const VecType*>(src);
...@@ -269,7 +269,7 @@ __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src, ...@@ -269,7 +269,7 @@ __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
#pragma unroll #pragma unroll
for (int i = 0; i < kVectorsPerThread; ++i) { for (int i = 0; i < kVectorsPerThread; ++i) {
vec_temp[i] = vec_input[i + tid]; vec_temp[i] = vec_input[thread_offset + i];
#pragma unroll #pragma unroll
for (int idx = 0; idx < NX; ++idx) { for (int idx = 0; idx < NX; ++idx) {
dst[idx] = *(reinterpret_cast<T*>(vec_temp) + idx); dst[idx] = *(reinterpret_cast<T*>(vec_temp) + idx);
...@@ -289,39 +289,39 @@ __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src, ...@@ -289,39 +289,39 @@ __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
* is 2 * is 2
* IsBoundary: whether to make boundary judgment * IsBoundary: whether to make boundary judgment
* @param: * @param:
* fix: data offset of this block, blockDim.x * blockIdx.x * NX; * block_offset: data offset of this block, blockDim.x * blockIdx.x * NX;
* config: get the global index in src, attention config was declared in host; * config: get the global index in src, attention config was declared in host;
* num: the num of out * total_num_output: total num of output
* stride_nx: the stride of cols * stride_nx: the stride of cols
* stride_ny: the stride of rows * stride_ny: the stride of rows
*/ */
template <typename T, int NX, int NY, int BlockSize, int ShapeSize, template <typename T, int NX, int NY, int BlockSize, int ShapeSize,
bool IsBoundary = false> bool IsBoundary = false>
__device__ __forceinline__ void ReadDataBc( __device__ __forceinline__ void ReadDataBc(
T* dst, const T* __restrict__ src, uint32_t fix, T* dst, const T* __restrict__ src, uint32_t block_offset,
details::BroadcastConfig<ShapeSize> config, int num, int stride_nx, details::BroadcastConfig<ShapeSize> config, int total_num_output,
int stride_ny) { int stride_nx, int stride_ny) {
uint32_t base_offset = fix + threadIdx.x * NX; uint32_t thread_offset = block_offset + threadIdx.x * NX;
uint32_t offset = 0; uint32_t index_src = 0;
#pragma unroll #pragma unroll
for (int ny = 0; ny < NY; ++ny) { for (int ny = 0; ny < NY; ++ny) {
#pragma unroll #pragma unroll
for (uint32_t nx = 0; nx < NX; ++nx) { for (uint32_t nx = 0; nx < NX; ++nx) {
uint32_t idx = base_offset + ny * stride_ny + nx * stride_nx; uint32_t index_output = thread_offset + ny * stride_ny + nx * stride_nx;
index_src = 0;
if (IsBoundary) { if (IsBoundary) {
if (idx >= num) { if (index_output >= total_num_output) {
break; break;
} }
} }
offset = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < ShapeSize; ++i) { for (int i = 0; i < ShapeSize; ++i) {
auto fast_divmoder = config.divmoders[i].Divmod(idx); auto fast_divmoder = config.divmoders[i].Divmod(index_output);
idx = fast_divmoder.val[0]; index_output = fast_divmoder.val[0];
offset += fast_divmoder.val[1] * config.strides[i]; index_src += fast_divmoder.val[1] * config.strides[i];
} }
dst[nx + ny * NX] = src[offset]; dst[nx + ny * NX] = src[index_src];
} }
} }
} }
...@@ -338,7 +338,7 @@ __device__ __forceinline__ void ReadDataBc( ...@@ -338,7 +338,7 @@ __device__ __forceinline__ void ReadDataBc(
* IndexCal: get the global index in src, attention config was declared in host; * IndexCal: get the global index in src, attention config was declared in host;
* IsBoundary: whether to make boundary judgment * IsBoundary: whether to make boundary judgment
* @param: * @param:
* fix: data offset of this block, blockDim.x * blockIdx.x * NX; * block_offset: data offset of this block, blockDim.x * blockIdx.x * NX;
* index_cal: get the global index in src, attention config was declared in * index_cal: get the global index in src, attention config was declared in
* host; * host;
* size_nx: number of columns to be processed by the current block * size_nx: number of columns to be processed by the current block
...@@ -350,27 +350,27 @@ __device__ __forceinline__ void ReadDataBc( ...@@ -350,27 +350,27 @@ __device__ __forceinline__ void ReadDataBc(
template <typename T, int NX, int NY, int BlockSize, int ShapeSize, template <typename T, int NX, int NY, int BlockSize, int ShapeSize,
typename IndexCal, bool IsBoundary = false> typename IndexCal, bool IsBoundary = false>
__device__ __forceinline__ void ReadDataReduce( __device__ __forceinline__ void ReadDataReduce(
T* dst, const T* __restrict__ src, int fix, const IndexCal& index_cal, T* dst, const T* __restrict__ src, int block_offset,
int size_nx, int size_ny, int stride_nx, int stride_ny, const IndexCal& index_cal, int size_nx, int size_ny, int stride_nx,
bool reduce_last_dim) { int stride_ny, bool reduce_last_dim) {
int base_offset = fix; int thread_offset = 0;
if (reduce_last_dim) { if (reduce_last_dim) {
base_offset += threadIdx.x; thread_offset = block_offset + threadIdx.x;
} else { } else {
base_offset += threadIdx.y; thread_offset = block_offset + threadIdx.y;
} }
if (NX == 1) { if (NX == 1) {
#pragma unroll #pragma unroll
for (int ny = 0; ny < NY; ++ny) { for (int ny = 0; ny < NY; ++ny) {
if (IsBoundary) { if (IsBoundary) {
if (base_offset >= size_ny) { if (thread_offset >= size_ny) {
break; break;
} }
} }
uint32_t offset = index_cal(base_offset); uint32_t index_src = index_cal(thread_offset);
dst[ny] = src[offset]; dst[ny] = src[index_src];
base_offset += stride_ny; thread_offset += stride_ny;
} }
} else { } else {
#pragma unroll #pragma unroll
...@@ -387,15 +387,16 @@ __device__ __forceinline__ void ReadDataReduce( ...@@ -387,15 +387,16 @@ __device__ __forceinline__ void ReadDataReduce(
break; break;
} }
} }
uint32_t offset = index_cal(base_offset); uint32_t index_src = index_cal(thread_offset);
dst[nx + ny * NX] = src[offset]; dst[nx + ny * NX] = src[index_src];
base_offset += stride_ny; thread_offset += stride_ny;
} }
thread_offset += stride_nx;
} }
} }
} }
/** @brief: WriteData /**
* @brief store data from src to dst, src can be 1D data, you should set NY = 1. * @brief store data from src to dst, src can be 1D data, you should set NY = 1.
* When boundary judgment is required, you need to set a to true, and a is false * When boundary judgment is required, you need to set a to true, and a is false
* by default. * by default.
...@@ -412,11 +413,11 @@ template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false> ...@@ -412,11 +413,11 @@ template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
__device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src, __device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src,
int num) { int num) {
if (IsBoundary) { if (IsBoundary) {
int dx = threadIdx.x * NX; int thread_offset = threadIdx.x * NX;
#pragma unroll #pragma unroll
for (int idx = 0; idx < NX; ++idx) { for (int idx = 0; idx < NX; ++idx) {
if ((idx + dx) < num) { if ((thread_offset + idx) < num) {
dst[idx + dx] = src[idx]; dst[thread_offset + idx] = src[idx];
} }
} }
} else { } else {
...@@ -424,14 +425,14 @@ __device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src, ...@@ -424,14 +425,14 @@ __device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src,
const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1; const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
const int kVectorsPerThread = NX / kVectorSize; const int kVectorsPerThread = NX / kVectorSize;
int dx = threadIdx.x * kVectorsPerThread; int thread_offset = threadIdx.x * kVectorsPerThread;
using VecType = details::VectorType<T, kVectorSize>; using VecType = details::VectorType<T, kVectorSize>;
VecType* vec_dst = reinterpret_cast<VecType*>(dst); VecType* vec_dst = reinterpret_cast<VecType*>(dst);
VecType vec_temp[kVectorsPerThread]; VecType vec_temp[kVectorsPerThread];
#pragma unroll #pragma unroll
for (int idx = 0; idx < kVectorsPerThread; ++idx) { for (int idx = 0; idx < kVectorsPerThread; ++idx) {
vec_temp[idx] = *(reinterpret_cast<VecType*>(src) + idx); vec_temp[idx] = *(reinterpret_cast<VecType*>(src) + idx);
vec_dst[dx + idx] = vec_temp[idx]; vec_dst[thread_offset + idx] = vec_temp[idx];
} }
} }
} }
......
...@@ -32,7 +32,6 @@ static __device__ __forceinline__ platform::float16 LogFunctor( ...@@ -32,7 +32,6 @@ static __device__ __forceinline__ platform::float16 LogFunctor(
static __device__ __forceinline__ float LogFunctor(float x) { return logf(x); } static __device__ __forceinline__ float LogFunctor(float x) { return logf(x); }
static __device__ __forceinline__ double LogFunctor(double x) { return log(x); } static __device__ __forceinline__ double LogFunctor(double x) { return log(x); }
} // namespace details
/*************************** Compute Functor****************************/ /*************************** Compute Functor****************************/
// for margin_cross_entropy // for margin_cross_entropy
template <typename Tx, typename Ty = Tx> template <typename Tx, typename Ty = Tx>
...@@ -75,6 +74,7 @@ struct DivideFunctor { ...@@ -75,6 +74,7 @@ struct DivideFunctor {
T n_inv; T n_inv;
}; };
} // namespace details
} // namespace kernel_primitives } // namespace kernel_primitives
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -128,39 +128,9 @@ __global__ void AddMarginToPositiveLogitsKernel( ...@@ -128,39 +128,9 @@ __global__ void AddMarginToPositiveLogitsKernel(
} }
} }
static __device__ __forceinline__ platform::float16 exp_on_device(
platform::float16 x) {
return ::Eigen::numext::exp(x);
}
static __device__ __forceinline__ float exp_on_device(float x) {
return expf(x);
}
static __device__ __forceinline__ double exp_on_device(double x) {
return exp(x);
}
static __device__ __forceinline__ platform::float16 log_on_device(
platform::float16 x) {
return ::Eigen::numext::log(x);
}
static __device__ __forceinline__ float log_on_device(float x) {
return logf(x);
}
static __device__ __forceinline__ double log_on_device(double x) {
return log(x);
}
template <typename Tx, typename Ty = Tx>
struct ExpLogitTransformer {
HOSTDEVICE explicit inline ExpLogitTransformer(int n) {}
HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(exp_on_device(x));
}
};
template <typename Tx, typename Ty = Tx> template <typename Tx, typename Ty = Tx>
struct ExpAndSum { struct ExpAndSum {
using Transformer = ExpLogitTransformer<Tx>; using Transformer = kpds::ExpLogitTransformer<Tx>;
inline Ty initial() { return static_cast<Ty>(0.0f); } inline Ty initial() { return static_cast<Ty>(0.0f); }
...@@ -189,7 +159,7 @@ __global__ void LogitsMinusLogSumKernel(T* logits, const T* logits_sum_per_row, ...@@ -189,7 +159,7 @@ __global__ void LogitsMinusLogSumKernel(T* logits, const T* logits_sum_per_row,
const int64_t N, const int64_t D) { const int64_t N, const int64_t D) {
CUDA_KERNEL_LOOP(i, N * D) { CUDA_KERNEL_LOOP(i, N * D) {
auto row = i / D; auto row = i / D;
logits[i] -= log_on_device(logits_sum_per_row[row]); logits[i] -= kpds::LogFunctor(logits_sum_per_row[row]);
} }
} }
...@@ -204,9 +174,9 @@ __global__ void HardLabelSoftmaxWithCrossEntropyKernel( ...@@ -204,9 +174,9 @@ __global__ void HardLabelSoftmaxWithCrossEntropyKernel(
if ((col + start_index) == labels[row]) { if ((col + start_index) == labels[row]) {
auto softmax = log_softmax[i]; auto softmax = log_softmax[i];
loss[row] = -softmax; loss[row] = -softmax;
log_softmax[i] = exp_on_device(softmax); log_softmax[i] = kpds::ExpFunctor(softmax);
} else { } else {
log_softmax[i] = exp_on_device(log_softmax[i]); log_softmax[i] = kpds::ExpFunctor(log_softmax[i]);
} }
} }
} }
......
...@@ -24,9 +24,11 @@ limitations under the License. */ ...@@ -24,9 +24,11 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace kpds = paddle::operators::kernel_primitives::details;
template <typename Tx, typename Ty = Tx> template <typename Tx, typename Ty = Tx>
struct CustomMin { struct CustomMin {
using Transformer = detail::IdentityFunctor<Tx>; using Transformer = kpds::IdentityFunctor<Tx>;
inline Ty initial() { inline Ty initial() {
return static_cast<Ty>(std::numeric_limits<Ty>::max()); return static_cast<Ty>(std::numeric_limits<Ty>::max());
...@@ -39,7 +41,7 @@ struct CustomMin { ...@@ -39,7 +41,7 @@ struct CustomMin {
template <typename Tx, typename Ty = Tx> template <typename Tx, typename Ty = Tx>
struct CustomMax { struct CustomMax {
using Transformer = detail::IdentityFunctor<Tx>; using Transformer = kpds::IdentityFunctor<Tx>;
inline Ty initial() { inline Ty initial() {
return static_cast<Ty>(std::numeric_limits<Ty>::lowest()); return static_cast<Ty>(std::numeric_limits<Ty>::lowest());
...@@ -53,7 +55,7 @@ struct CustomMax { ...@@ -53,7 +55,7 @@ struct CustomMax {
// for cub::Reduce // for cub::Reduce
template <typename Tx, typename Ty = Tx> template <typename Tx, typename Ty = Tx>
struct CustomSum { struct CustomSum {
using Transformer = detail::IdentityFunctor<Tx, Ty>; using Transformer = kpds::IdentityFunctor<Tx, Ty>;
inline Ty initial() { return static_cast<Ty>(0.0f); } inline Ty initial() { return static_cast<Ty>(0.0f); }
...@@ -64,7 +66,7 @@ struct CustomSum { ...@@ -64,7 +66,7 @@ struct CustomSum {
template <typename Tx, typename Ty = Tx> template <typename Tx, typename Ty = Tx>
struct CustomMean { struct CustomMean {
using Transformer = detail::DivideFunctor<Tx>; using Transformer = kpds::DivideFunctor<Tx>;
inline Ty initial() { return static_cast<Ty>(0.0f); } inline Ty initial() { return static_cast<Ty>(0.0f); }
...@@ -75,7 +77,7 @@ struct CustomMean { ...@@ -75,7 +77,7 @@ struct CustomMean {
template <typename Tx, typename Ty = Tx> template <typename Tx, typename Ty = Tx>
struct CustomMul { struct CustomMul {
using Transformer = detail::IdentityFunctor<Tx>; using Transformer = kpds::IdentityFunctor<Tx>;
inline Ty initial() { return static_cast<Ty>(1.0f); } inline Ty initial() { return static_cast<Ty>(1.0f); }
...@@ -86,7 +88,7 @@ struct CustomMul { ...@@ -86,7 +88,7 @@ struct CustomMul {
template <typename Tx, typename Ty = Tx> template <typename Tx, typename Ty = Tx>
struct CustomLogicalOr { struct CustomLogicalOr {
using Transformer = detail::IdentityFunctor<Tx>; using Transformer = kpds::IdentityFunctor<Tx>;
inline Ty initial() { return static_cast<Ty>(false); } inline Ty initial() { return static_cast<Ty>(false); }
...@@ -97,7 +99,7 @@ struct CustomLogicalOr { ...@@ -97,7 +99,7 @@ struct CustomLogicalOr {
template <typename Tx, typename Ty = Tx> template <typename Tx, typename Ty = Tx>
struct CustomLogicalAnd { struct CustomLogicalAnd {
using Transformer = detail::IdentityFunctor<Tx>; using Transformer = kpds::IdentityFunctor<Tx>;
inline Ty initial() { return static_cast<Ty>(true); } inline Ty initial() { return static_cast<Ty>(true); }
......
...@@ -34,6 +34,7 @@ namespace cub = hipcub; ...@@ -34,6 +34,7 @@ namespace cub = hipcub;
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/fast_divmod.h" #include "paddle/fluid/platform/fast_divmod.h"
...@@ -43,28 +44,10 @@ namespace cub = hipcub; ...@@ -43,28 +44,10 @@ namespace cub = hipcub;
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace detail {
// Post processing function for sum, max, min, prod, any namespace kps = paddle::operators::kernel_primitives;
template <typename Tx, typename Ty = Tx>
struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor(int n) {}
HOSTDEVICE inline Ty operator()(const Tx& x) const { namespace details {
return static_cast<Ty>(x);
}
};
// Post processing function for mean
template <typename T>
struct DivideFunctor {
HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}
HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }
private:
T n_inv;
};
static inline int GetLastPow2(int n) { static inline int GetLastPow2(int n) {
n |= (n >> 1); n |= (n >> 1);
...@@ -90,17 +73,11 @@ static inline std::vector<int> GetDimStrides(const std::vector<int>& dims, ...@@ -90,17 +73,11 @@ static inline std::vector<int> GetDimStrides(const std::vector<int>& dims,
return strides; return strides;
} }
#ifdef __HIPCC__
constexpr int kMaxThread = 256;
constexpr int kWarpSize = 64;
#else
constexpr int kMaxThread = 128;
constexpr int kWarpSize = 32;
#endif
// get blockDim for reduceLastDim and reduceAny // get blockDim for reduceLastDim and reduceAny
static inline int GetBlockDim(int block_dim) { static inline int GetBlockDim(int block_dim) {
return block_dim >= kMaxThread ? kMaxThread : GetLastPow2(block_dim); return block_dim >= kps::details::kReduceMaxThread
? kps::details::kReduceMaxThread
: GetLastPow2(block_dim);
} }
// check reduce rand is valid // check reduce rand is valid
...@@ -140,7 +117,7 @@ static inline paddle::framework::Array<T, ElementCount> VectorToArray( ...@@ -140,7 +117,7 @@ static inline paddle::framework::Array<T, ElementCount> VectorToArray(
return ret; return ret;
} }
} // namespace detail } // namespace details
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
constexpr int kMaxRank = framework::DDim::kMaxRank; constexpr int kMaxRank = framework::DDim::kMaxRank;
...@@ -156,18 +133,18 @@ struct IndexCalculator { ...@@ -156,18 +133,18 @@ struct IndexCalculator {
const std::vector<int>& cal_strides, const std::vector<int>& cal_strides,
const std::vector<int>& full_strides) const std::vector<int>& full_strides)
: dim(dim) { : dim(dim) {
dims = detail::VectorToArray<int, kMaxRank>(cal_dims); dims = details::VectorToArray<int, kMaxRank>(cal_dims);
strides = detail::VectorToArray<int, kMaxRank>(full_strides); strides = details::VectorToArray<int, kMaxRank>(full_strides);
std::vector<platform::FastDivMod> cal_divmoders; std::vector<platform::FastDivMod> cal_divmoders;
// fast divmod // fast divmod
for (auto i : cal_strides) { for (auto i : cal_strides) {
cal_divmoders.push_back(platform::FastDivMod(i)); cal_divmoders.push_back(platform::FastDivMod(i));
} }
divmoders = divmoders =
detail::VectorToArray<platform::FastDivMod, kMaxRank>(cal_divmoders); details::VectorToArray<platform::FastDivMod, kMaxRank>(cal_divmoders);
} }
__device__ inline int Get(int offset) const { __device__ inline int operator()(int offset) const {
int index = 0; int index = 0;
#pragma unroll #pragma unroll
for (int i = 0; i < kMaxRank; ++i) { for (int i = 0; i < kMaxRank; ++i) {
...@@ -187,6 +164,15 @@ struct IndexCalculator { ...@@ -187,6 +164,15 @@ struct IndexCalculator {
framework::Array<platform::FastDivMod, kMaxRank> divmoders; framework::Array<platform::FastDivMod, kMaxRank> divmoders;
}; };
// when reduce_type == kReduceLastDim this struct will be used
// for higher performance
struct LastDimIndexCal {
explicit LastDimIndexCal(int num) : stride(num) {}
__device__ inline int operator()(int index) const { return index * stride; }
int stride;
};
// reduce config // reduce config
template <typename Ty> template <typename Ty>
struct ReduceConfig { struct ReduceConfig {
...@@ -307,7 +293,7 @@ struct ReduceConfig { ...@@ -307,7 +293,7 @@ struct ReduceConfig {
left_dim.assign(left_set.begin(), left_set.end()); left_dim.assign(left_set.begin(), left_set.end());
// if the last dim gets involved in reduction // if the last dim gets involved in reduction
reduce_lastdim = (reduce_dim.back() == x_dim.size() - 1); reduce_last_dim = (reduce_dim.back() == x_dim.size() - 1);
} }
// set x_strides, reduce_strides, left_strides for reduceLastDim and reduceAny // set x_strides, reduce_strides, left_strides for reduceLastDim and reduceAny
...@@ -320,9 +306,9 @@ struct ReduceConfig { ...@@ -320,9 +306,9 @@ struct ReduceConfig {
idx_dim.push_back(i); idx_dim.push_back(i);
} }
x_strides = detail::GetDimStrides(x_dim, idx_dim); x_strides = details::GetDimStrides(x_dim, idx_dim);
reduce_strides = detail::GetDimStrides(x_dim, reduce_dim); reduce_strides = details::GetDimStrides(x_dim, reduce_dim);
left_strides = detail::GetDimStrides(x_dim, left_dim); left_strides = details::GetDimStrides(x_dim, left_dim);
reduce_num = reduce_strides[0] * x_dim[reduce_dim[0]]; reduce_num = reduce_strides[0] * x_dim[reduce_dim[0]];
left_num = 1; left_num = 1;
...@@ -354,36 +340,36 @@ struct ReduceConfig { ...@@ -354,36 +340,36 @@ struct ReduceConfig {
void SetBlockDimForReduceAny(dim3* block_dim, dim3* grid_dim) { void SetBlockDimForReduceAny(dim3* block_dim, dim3* grid_dim) {
constexpr int min_reduce_num_per_thread = 16; constexpr int min_reduce_num_per_thread = 16;
constexpr int max_reduce_num_per_thread = 256; constexpr int max_reduce_num_per_thread = 256;
constexpr int max_num_threads = detail::kMaxThread; constexpr int max_num_threads = kps::details::kReduceMaxThread;
// set block size. // set block size.
// 1. If reduce_lastdim == true, all the threads whose threadIdx.y are same // 1. If reduce_last_dim == true, all the threads whose threadIdx.y are same
// will process the reduction for one output. // will process the reduction for one output.
// The number of output for one block is blockDim.y; // The number of output for one block is blockDim.y;
// 2. If reduce_lastdim == false, different threadIdx.x will process // 2. If reduce_last_dim == false, different threadIdx.x will process
// different reduction and gets the output separately. If it is // different reduction and gets the output separately. If it is
// necessary, it should reduce in block y. // necessary, it should reduce in block y.
// The number of output for one block is blockDim.x; // The number of output for one block is blockDim.x;
int block_x, block_y; int block_x, block_y;
int grid_num, reduce_num_per_thread; int grid_num, reduce_num_per_thread;
if (reduce_lastdim) { if (reduce_last_dim) {
block_x = detail::GetBlockDim(reduce_num); block_x = details::GetBlockDim(reduce_num);
block_y = detail::GetBlockDim(left_num); block_y = details::GetBlockDim(left_num);
block_dim->x = block_x; block_dim->x = block_x;
block_dim->y = block_dim->y =
std::min(block_y, static_cast<int>(max_num_threads / block_dim->x)); std::min(block_y, static_cast<int>(max_num_threads / block_dim->x));
grid_num = detail::AlignUp(left_num, block_dim->y); grid_num = details::AlignUp(left_num, block_dim->y);
reduce_num_per_thread = detail::AlignUp(reduce_num, block_dim->x); reduce_num_per_thread = details::AlignUp(reduce_num, block_dim->x);
} else { } else {
block_x = detail::GetBlockDim(left_num); block_x = details::GetBlockDim(left_num);
block_y = detail::GetBlockDim(reduce_num); block_y = details::GetBlockDim(reduce_num);
block_dim->x = std::min(block_x, 32); block_dim->x = std::min(block_x, 32);
block_dim->y = block_dim->y =
std::min(block_y, static_cast<int>(max_num_threads / block_dim->x)); std::min(block_y, static_cast<int>(max_num_threads / block_dim->x));
block_dim->x = block_dim->x =
std::min(block_x, static_cast<int>(max_num_threads / block_dim->y)); std::min(block_x, static_cast<int>(max_num_threads / block_dim->y));
grid_num = detail::AlignUp(left_num, block_dim->x); grid_num = details::AlignUp(left_num, block_dim->x);
reduce_num_per_thread = detail::AlignUp(reduce_num, block_dim->y); reduce_num_per_thread = details::AlignUp(reduce_num, block_dim->y);
} }
int device_id = platform::GetCurrentDeviceId(); int device_id = platform::GetCurrentDeviceId();
int max_mp = platform::GetCUDAMultiProcessors(device_id); int max_mp = platform::GetCUDAMultiProcessors(device_id);
...@@ -403,10 +389,10 @@ struct ReduceConfig { ...@@ -403,10 +389,10 @@ struct ReduceConfig {
// the number cannot be larger than max_reduce_num_per_thread, so we // the number cannot be larger than max_reduce_num_per_thread, so we
// choose the maximum between the result above and input_split_num_2. // choose the maximum between the result above and input_split_num_2.
int input_split_num_1 = int input_split_num_1 =
detail::AlignUp(reduce_num_per_thread, min_reduce_num_per_thread); details::AlignUp(reduce_num_per_thread, min_reduce_num_per_thread);
int input_split_num_2 = int input_split_num_2 =
detail::AlignUp(reduce_num_per_thread, max_reduce_num_per_thread); details::AlignUp(reduce_num_per_thread, max_reduce_num_per_thread);
int input_split_num_3 = detail::AlignUp(max_num_blocks, grid_num); int input_split_num_3 = details::AlignUp(max_num_blocks, grid_num);
grid_dim->x = grid_num; grid_dim->x = grid_num;
grid_dim->y = std::max(std::min(input_split_num_1, input_split_num_3), grid_dim->y = std::max(std::min(input_split_num_1, input_split_num_3),
...@@ -423,7 +409,7 @@ struct ReduceConfig { ...@@ -423,7 +409,7 @@ struct ReduceConfig {
// for others: block(block_num, 1) , grid(left_num, 1) // for others: block(block_num, 1) , grid(left_num, 1)
void SetBlockDim() { void SetBlockDim() {
// init // init
int block_num = detail::GetBlockDim(reduce_num); int block_num = details::GetBlockDim(reduce_num);
should_reduce_again = false; should_reduce_again = false;
dim3 block_dim(block_num, 1); dim3 block_dim(block_num, 1);
...@@ -449,23 +435,23 @@ struct ReduceConfig { ...@@ -449,23 +435,23 @@ struct ReduceConfig {
int num_block = (max_threads / left_num); int num_block = (max_threads / left_num);
if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) { if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
blocking_size = detail::GetLastPow2(reduce_num / num_block); blocking_size = details::GetLastPow2(reduce_num / num_block);
if (blocking_size <= 1) { if (blocking_size <= 1) {
blocking_size = detail::GetLastPow2(sqrt(reduce_num)); blocking_size = details::GetLastPow2(sqrt(reduce_num));
} else if (blocking_size * 2 < reduce_num) { } else if (blocking_size * 2 < reduce_num) {
blocking_size *= 2; blocking_size *= 2;
} }
should_reduce_again = true; should_reduce_again = true;
block_dim.x = 32; block_dim.x = details::GetBlockDim(left_num);
block_dim.y = 1; block_dim.y = 1;
grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x; grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;
grid_dim.y = (reduce_num + blocking_size - 1) / blocking_size; grid_dim.y = (reduce_num + blocking_size - 1) / blocking_size;
} else { } else {
block_dim.x = 32; block_dim.x = details::GetBlockDim(left_num);
block_dim.y = 1; block_dim.y = 1;
blocking_size = reduce_num; blocking_size = reduce_num;
grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x; grid_dim.x = (left_num + block_dim.x - 1) / block_dim.x;
...@@ -493,272 +479,210 @@ struct ReduceConfig { ...@@ -493,272 +479,210 @@ struct ReduceConfig {
int left_num; int left_num;
int blocking_size; int blocking_size;
bool should_reduce_again; bool should_reduce_again;
bool reduce_lastdim; bool reduce_last_dim;
Ty* output_data; Ty* output_data;
dim3 block; dim3 block;
dim3 grid; dim3 grid;
}; };
/* size : how many colonms left have to be reduced
static __device__ int SharedMemoryIndex(int index) { * loop : how many rows data have to be reduced
return (threadIdx.y + index) * blockDim.x + threadIdx.x; * block_size: max rows this block to reduce
}
template <typename T, typename ReduceOp>
static __device__ T WarpReduce(T val, ReduceOp reducer) {
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true);
for (int stride = detail::kWarpSize / 2; stride > 0; stride >>= 1) {
T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
val = reducer(val, temp);
}
return val;
}
/* e.g.
* |---------block---------|
* |warp0|warp1|warp2|warp3|
* |0~31|32~63|64~95|96~127| ---->blockDim.x = 128
* \|/ \|/ \|/ \|/ ---->1. First WarpReduce in each warp
* res0 res1 res2 res3 ---->2. Store result of each warp to shared memory
* \ \ / / ---->3. Load the result above from shared memory
* res to warp0 and process the second WarpReduce
*/ */
template <typename T, typename ReduceOp>
static __device__ T BlockXReduce(T val, ReduceOp reducer) {
using detail::kWarpSize;
__shared__ T shared[2 * kWarpSize];
int block_dim_x = blockDim.x;
if (blockDim.x > kWarpSize) {
block_dim_x = blockDim.x / kWarpSize;
int lane = threadIdx.x % kWarpSize;
int tid = threadIdx.y * blockDim.x + threadIdx.x;
int wid = tid / kWarpSize;
int bid = threadIdx.y;
val = WarpReduce(val, reducer);
if (lane == 0) {
shared[wid] = val;
}
__syncthreads();
val = shared[bid * block_dim_x + lane];
}
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true);
for (int stride = 1; stride < block_dim_x; stride <<= 1) {
T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
val = reducer(val, temp);
}
return val;
}
template <typename T, typename ReduceOp>
static __device__ T BlockYReduce(T val, ReduceOp reducer) {
__shared__ T shared_memory[detail::kMaxThread];
shared_memory[SharedMemoryIndex(0)] = val;
for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
__syncthreads();
if (threadIdx.y < stride && threadIdx.y + stride < blockDim.y) {
T temp = shared_memory[SharedMemoryIndex(stride)];
val = reducer(val, temp);
}
shared_memory[SharedMemoryIndex(0)] = val;
}
return val;
}
// when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
// function will be used
// eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
// if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx / 32
// else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32
template <typename Tx, typename Ty, typename MPType, typename ReduceOp, template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
typename TransformOp> typename TransformOp, bool IsBoundary = false>
__device__ void ReduceHigherDim(const Tx* x, Ty* y, ReduceOp reducer, __device__ void HigherDimDealSegment(const Tx* x, Ty* y, ReduceOp reducer,
TransformOp transformer, MPType init, TransformOp transformer, MPType init,
int reduce_num, int left_num, int block_size) { int reduce_num, int left_num,
int idx = blockIdx.x * blockDim.x + threadIdx.x; int block_size) {
const int NY = 1;
int idx = blockIdx.x * blockDim.x;
int idy = blockIdx.y * block_size; int idy = blockIdx.y * block_size;
// block_offset of rows
MPType reduce_var = init; Tx reduce_input[NY];
MPType reduce_compute[NY];
if (idx < left_num) { MPType result = init;
// the offset of this block
int block_offset = idy * left_num + idx + blockIdx.z * reduce_num * left_num;
const Tx* input = x + block_offset;
int store_offset =
blockIdx.y * left_num + blockIdx.z * gridDim.y * left_num + idx;
// how many columns left
int size = left_num - idx;
// how many rows have to be reduced
int loop = reduce_num - idy; int loop = reduce_num - idy;
loop = loop > block_size ? block_size : loop; loop = loop > block_size ? block_size : loop;
for (int iy = 0; iy < loop; iy++) { for (int loop_index = 0; loop_index < loop; loop_index += NY) {
int id = (idy + iy) * left_num + idx + blockIdx.z * reduce_num * left_num; kps::ReadData<Tx, Tx, 1, NY, 1, IsBoundary>(
reduce_var = reducer(reduce_var, static_cast<MPType>(transformer(x[id]))); &reduce_input[0], input + loop_index * left_num, size, NY, 1, left_num);
kps::ElementwiseUnary<Tx, MPType, REDUCE_VEC_SIZE, 1, 1, TransformOp>(
&reduce_compute[0], &reduce_input[0], transformer);
kps::Reduce<MPType, NY, 1, 1, ReduceOp,
kps::details::ReduceMode::kLocalMode>(
&result, &reduce_compute[0], reducer, false);
} }
y[idx + blockIdx.y * left_num + blockIdx.z * gridDim.y * left_num] = Ty temp_data = static_cast<Ty>(result);
static_cast<Ty>(reduce_var); kps::WriteData<Ty, 1, 1, 1, IsBoundary>(y + store_offset, &temp_data, size);
}
} }
// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or // when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
// when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this // when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
// function will be used // function will be used
template <typename Tx, typename Ty, typename MPType, typename ReduceOp, template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
typename TransformOp> typename TransformOp, typename Calculator>
__device__ void ReduceAny(const Tx* x, Ty* y, ReduceOp reducer, __global__ void ReduceAnyKernel(const Tx* x, Ty* y, ReduceOp reducer,
TransformOp transformer, MPType init, int reduce_num, TransformOp transformer, MPType init,
int left_num, bool reduce_lastdim, int reduce_num, int left_num,
const IndexCalculator& reduce_index_calculator, bool reduce_last_dim,
const IndexCalculator& left_index_calculator) { const Calculator reduce_index_calculator,
const Calculator left_index_calculator) {
int input_idx, left_idx, stride; int input_idx, left_idx, stride;
int block_size = 0;
bool need_store = true;
int tid = 0;
// the last dim gets involved in reduction // the last dim gets involved in reduction
if (reduce_lastdim) { if (reduce_last_dim) {
input_idx = blockIdx.y * blockDim.x + threadIdx.x; input_idx = blockIdx.y * blockDim.x;
left_idx = blockIdx.x * blockDim.y + threadIdx.y; left_idx = blockIdx.x * blockDim.y + threadIdx.y;
stride = gridDim.y * blockDim.x; stride = gridDim.y * blockDim.x;
block_size = blockDim.x;
need_store = (threadIdx.x == 0) && (left_idx < left_num);
tid = threadIdx.x;
} else { } else {
input_idx = blockIdx.y * blockDim.y + threadIdx.y; input_idx = blockIdx.y * blockDim.y;
left_idx = blockIdx.x * blockDim.x + threadIdx.x; left_idx = blockIdx.x * blockDim.x + threadIdx.x;
stride = gridDim.y * blockDim.y; stride = gridDim.y * blockDim.y;
block_size = blockDim.y;
need_store = (threadIdx.y == 0) && (left_idx < left_num);
tid = threadIdx.y;
} }
int store_offset = blockIdx.y * left_num + left_idx;
// calculate the offset, means the addr where each thread really start. // calculate the offset, means the addr where each thread really start.
int input_offset = left_index_calculator.Get(left_idx); int input_offset = left_index_calculator(left_idx);
const Tx* input = x + input_offset; const Tx* input = x + input_offset;
MPType reduce_var = init; MPType reduce_var = init;
Ty store_data;
// 1. reduce for each thread // 1. reduce for each thread
if (left_idx < left_num) { if (left_idx < left_num) {
// load REDUCE_VEC_SIZE data once, and then compute // load REDUCE_VEC_SIZE data once, and then compute
Tx input_reg[REDUCE_VEC_SIZE]; Tx input_reg[REDUCE_VEC_SIZE];
MPType input_compute[REDUCE_VEC_SIZE];
int bound = reduce_num - (REDUCE_VEC_SIZE - 1) * stride; int bound = reduce_num - (REDUCE_VEC_SIZE - 1) * stride;
while (input_idx < bound) { for (; input_idx + block_size < bound;
#pragma unroll input_idx += REDUCE_VEC_SIZE * stride) {
for (int i = 0; i < REDUCE_VEC_SIZE; ++i) { kps::ReadDataReduce<Tx, 1, REDUCE_VEC_SIZE, 1, 1, Calculator>(
int reduce_idx = input_idx + i * stride; &input_reg[0], input, input_idx, reduce_index_calculator, 1,
int idx_x = reduce_index_calculator.Get(reduce_idx); reduce_num, 1, stride, reduce_last_dim);
input_reg[i] = input[idx_x]; kps::ElementwiseUnary<Tx, MPType, REDUCE_VEC_SIZE, 1, 1, TransformOp>(
} &input_compute[0], &input_reg[0], transformer);
#pragma unroll kps::Reduce<MPType, REDUCE_VEC_SIZE, 1, 1, ReduceOp,
for (int i = 0; i < REDUCE_VEC_SIZE; ++i) { kps::details::ReduceMode::kLocalMode>(
reduce_var = &reduce_var, &input_compute[0], reducer, reduce_last_dim);
reducer(reduce_var, static_cast<MPType>(transformer(input_reg[i]))); }
}
input_idx += REDUCE_VEC_SIZE * stride; kps::Init<MPType, REDUCE_VEC_SIZE>(&input_compute[0], init);
} kps::ReadDataReduce<Tx, 1, REDUCE_VEC_SIZE, 1, 1, Calculator, true>(
&input_reg[0], input, input_idx, reduce_index_calculator, 1, reduce_num,
// deal with the remain part 1, stride, reduce_last_dim);
int input_idx_tmp = input_idx; input_idx += tid;
#pragma unroll
for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
if (input_idx >= reduce_num) {
break;
}
int reduce_idx = input_idx;
int idx_x = reduce_index_calculator.Get(reduce_idx);
input_reg[i] = input[idx_x];
input_idx += stride;
}
input_idx = input_idx_tmp;
#pragma unroll #pragma unroll
for (int i = 0; i < REDUCE_VEC_SIZE; ++i) { for (int i = 0; i < REDUCE_VEC_SIZE; ++i) {
if (input_idx >= reduce_num) { if (input_idx >= reduce_num) {
break; break;
} }
reduce_var = input_compute[i] = static_cast<MPType>(transformer(input_reg[i]));
reducer(reduce_var, static_cast<MPType>(transformer(input_reg[i])));
input_idx += stride; input_idx += stride;
} }
kps::Reduce<MPType, REDUCE_VEC_SIZE, 1, 1, ReduceOp,
kps::details::ReduceMode::kLocalMode>(
&reduce_var, &input_compute[0], reducer, reduce_last_dim);
} }
// 2. reduce in block y kps::Reduce<MPType, 1, 1, 1, ReduceOp, kps::details::kGlobalMode>(
if (!reduce_lastdim && blockDim.y > 1) { &reduce_var, &reduce_var, reducer, reduce_last_dim);
reduce_var = BlockYReduce(reduce_var, reducer); if (need_store) {
} y[store_offset] = static_cast<Ty>(reduce_var);
__syncthreads();
if (reduce_lastdim) {
// 3. reduce in block x
reduce_var = BlockXReduce(reduce_var, reducer);
if (left_idx < left_num && threadIdx.x == 0) {
y[blockIdx.y * left_num + left_idx] = static_cast<Ty>(reduce_var);
}
} else {
if (left_idx < left_num && threadIdx.y == 0) {
y[blockIdx.y * left_num + left_idx] = static_cast<Ty>(reduce_var);
}
} }
} }
// module function designed for global function
template <typename Tx, typename Ty, typename MPType, typename ReduceOp, template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
typename TransformOp> typename TransformOp>
__device__ void ReduceModule(const Tx* x, Ty* y, ReduceOp reducer, __global__ void ReduceHigherDimKernel(const Tx* x, Ty* y, ReduceOp reducer,
TransformOp transformer, MPType init, TransformOp transformer, MPType init,
int reduce_num, int left_num, int blocking_size, int reduce_num, int left_num,
int reduce_type, bool reduce_lastdim, int blocking_size) {
const IndexCalculator& reduce_index_calculator, // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
const IndexCalculator& left_index_calculator) { // function will be used
if (reduce_type == ReduceType::kReduceLastDim || // eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
reduce_type == ReduceType::kReduceAny) { // if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx /
ReduceAny<Tx, Ty, MPType, ReduceOp, TransformOp>( // 32
x, y, reducer, transformer, init, reduce_num, left_num, reduce_lastdim, // else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32
reduce_index_calculator, left_index_calculator); int idx = blockIdx.x * blockDim.x;
// reduce_rank == 1 && reduce_dim[0] != x_dim.size() - 1 int size = left_num - idx;
} else if (reduce_type == ReduceType::kReduceHigherDim) { if (size >= blockDim.x) { // complete segment
ReduceHigherDim<Tx, Ty, MPType, ReduceOp, TransformOp>( HigherDimDealSegment<Tx, Ty, MPType, ReduceOp, TransformOp>(
x, y, reducer, transformer, init, reduce_num, left_num, blocking_size);
} else {
HigherDimDealSegment<Tx, Ty, MPType, ReduceOp, TransformOp, true>(
x, y, reducer, transformer, init, reduce_num, left_num, blocking_size); x, y, reducer, transformer, init, reduce_num, left_num, blocking_size);
} }
} }
template <typename Tx, typename Ty, typename MPType, typename ReduceOp,
typename TransformOp>
__global__ void ReduceKernelFunction(const Tx* x, Ty* y, ReduceOp reducer,
TransformOp transformer, MPType init,
int reduce_num, int left_num,
int blocking_size, int reduce_type,
bool reduce_lastdim,
IndexCalculator reduce_index_calculator,
IndexCalculator left_index_calculator) {
ReduceModule<Tx, Ty, MPType, ReduceOp, TransformOp>(
x, y, reducer, transformer, init, reduce_num, left_num, blocking_size,
reduce_type, reduce_lastdim, reduce_index_calculator,
left_index_calculator);
}
template <typename Tx, typename Ty, typename MPType, typename ReduceOp> template <typename Tx, typename Ty, typename MPType, typename ReduceOp>
static void LaunchReduceKernel(const Tx* x_data, Ty* y_data, static void LaunchReduceKernel(const Tx* x_data, Ty* y_data,
const ReduceOp& reducer, MPType init, const ReduceOp& reducer, MPType init,
gpuStream_t stream, ReduceConfig<Ty> config) { gpuStream_t stream, ReduceConfig<Ty> config) {
using TransformOp = typename ReduceOp::Transformer; using TransformOp = typename ReduceOp::Transformer;
if (config.reduce_type == kReduceLastDim) {
int stride_reduce = 1;
int stride_left = config.reduce_num;
// for higher performance
auto reduce_index_calculator = LastDimIndexCal(stride_reduce);
auto left_index_calculator = LastDimIndexCal(stride_left);
ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp,
LastDimIndexCal><<<config.grid, config.block, 0, stream>>>(
x_data, config.output_data, reducer, TransformOp(config.reduce_num),
init, config.reduce_num, config.left_num, config.reduce_last_dim,
reduce_index_calculator, left_index_calculator);
} else {
int reduce_rank = config.reduce_strides.size(); int reduce_rank = config.reduce_strides.size();
int left_rank = config.left_strides.size(); int left_rank = config.left_strides.size();
auto reduce_index_calculator = IndexCalculator( auto reduce_index_calculator =
reduce_rank, config.reduce_dim, config.reduce_strides, config.x_strides); IndexCalculator(reduce_rank, config.reduce_dim, config.reduce_strides,
config.x_strides);
auto left_index_calculator = IndexCalculator( auto left_index_calculator = IndexCalculator(
left_rank, config.left_dim, config.left_strides, config.x_strides); left_rank, config.left_dim, config.left_strides, config.x_strides);
ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp,
ReduceKernelFunction<Tx, Ty, MPType, ReduceOp, IndexCalculator><<<config.grid, config.block, 0, stream>>>(
TransformOp><<<config.grid, config.block, 0, stream>>>( x_data, config.output_data, reducer, TransformOp(config.reduce_num),
x_data, config.output_data, reducer, TransformOp(config.reduce_num), init, init, config.reduce_num, config.left_num, config.reduce_last_dim,
config.reduce_num, config.left_num, config.blocking_size, reduce_index_calculator, left_index_calculator);
config.reduce_type, config.reduce_lastdim, reduce_index_calculator, }
left_index_calculator);
if (config.should_reduce_again) { if (config.should_reduce_again) {
dim3 block; dim3 block;
dim3 grid; dim3 grid;
if (config.reduce_lastdim) { if (config.reduce_last_dim) {
block = dim3(32, 1, 1); block = dim3(32, 1, 1);
grid = dim3(detail::AlignUp(config.left_num, 32), 1, 1); grid = dim3(details::AlignUp(config.left_num, 32), 1, 1);
} else { } else {
block = dim3(config.block.x, 1, 1); block = dim3(config.block.x, 1, 1);
grid = dim3(config.grid.x, 1, config.grid.z); grid = dim3(config.grid.x, 1, config.grid.z);
} }
ReduceKernelFunction< ReduceHigherDimKernel<
Ty, Ty, MPType, ReduceOp, Ty, Ty, MPType, ReduceOp,
detail::IdentityFunctor<Ty, MPType>><<<grid, block, 0, stream>>>( kps::details::IdentityFunctor<Ty, MPType>><<<grid, block, 0, stream>>>(
config.output_data, y_data, reducer, config.output_data, y_data, reducer,
detail::IdentityFunctor<Ty, MPType>(config.grid.y), init, config.grid.y, kps::details::IdentityFunctor<Ty, MPType>(config.grid.y), init,
config.left_num, config.grid.y, ReduceType::kReduceHigherDim, config.grid.y, config.left_num, config.grid.y);
config.reduce_lastdim, reduce_index_calculator, left_index_calculator);
} }
} }
...@@ -812,6 +736,39 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y, ...@@ -812,6 +736,39 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y,
using MPType = typename details::MPTypeTrait<Ty>::Type; using MPType = typename details::MPTypeTrait<Ty>::Type;
auto reducer = ReduceOp<Tx, MPType>(); auto reducer = ReduceOp<Tx, MPType>();
// launch ReduceHigherDimKernel
// when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
// function will be used
// eg: x_dim = {nz, ny, nx}, nx != 1, axis can be 0 or 1
// if axis = 1 then grid.z = nz, grid.y = ny / block_size, grid.x = nx /
// 32
// else grid.z = 1, grid.y = ny / block_size, grid.x = nx /32
if (config.reduce_type == ReduceType::kReduceHigherDim) {
using TransformOp = typename ReduceOp<Tx, MPType>::Transformer;
ReduceHigherDimKernel<
Tx, Ty, MPType, ReduceOp<Tx, MPType>,
TransformOp><<<config.grid, config.block, 0, stream>>>(
x_data, config.output_data, reducer, TransformOp(config.reduce_num),
reducer.initial(), config.reduce_num, config.left_num,
config.blocking_size);
if (config.should_reduce_again) {
dim3 block = dim3(config.block.x, 1, 1);
dim3 grid = dim3(config.grid.x, 1, config.grid.z);
ReduceHigherDimKernel<Ty, Ty, MPType, ReduceOp<Tx, MPType>,
kps::details::IdentityFunctor<
Ty, MPType>><<<grid, block, 0, stream>>>(
config.output_data, y_data, reducer,
kps::details::IdentityFunctor<Ty, MPType>(config.grid.y),
reducer.initial(), config.grid.y, config.left_num, config.grid.y);
}
return;
}
// when reduce_dim.size() == 1 and reduce_dim[0] == x_dim.size() - 1, or
// when reduce_dim.size() != 1 and reduce_dim.size() != x_dim.size(), this
// function will be used
LaunchReduceKernel<Tx, Ty, MPType, ReduceOp<Tx, MPType>>( LaunchReduceKernel<Tx, Ty, MPType, ReduceOp<Tx, MPType>>(
x_data, y_data, reducer, reducer.initial(), stream, config); x_data, y_data, reducer, reducer.initial(), stream, config);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册