Unverified commit 82b33be3, authored by niuliling123 and committed by GitHub

Modify the reduce op according to the kernel primitive api (#35282)

Parent 7aa4d879
@@ -202,9 +202,9 @@ void SetConfigForColumnReduce(const int max_threads, const int reduce_num,
   int num_block = (max_threads / left_num);
   if (num_block > 1 && reduce_num >= REDUCE_SPLIT_BOUNDARY) {
-    *blocking_size = detail::GetLastPow2(reduce_num / num_block);
+    *blocking_size = details::GetLastPow2(reduce_num / num_block);
     if (*blocking_size <= 1) {
-      *blocking_size = detail::GetLastPow2(sqrt(reduce_num));
+      *blocking_size = details::GetLastPow2(sqrt(reduce_num));
     } else if (*blocking_size * 2 < reduce_num) {
       *blocking_size *= 2;
     }
......
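The hunk above only renames the helper namespace, but the blocking heuristic it touches is worth restating. A minimal host-side sketch, assuming GetLastPow2(n) returns the largest power of two not greater than n; PickBlockingSize is an illustrative stand-in, not the Paddle function:

// Host-side sketch of the column-reduce blocking heuristic above.
// Assumption: GetLastPow2(n) returns the largest power of two <= n.
#include <cmath>

static int GetLastPow2(int n) {
  int p = 1;
  while (p * 2 <= n) p *= 2;
  return p;
}

// Hypothetical stand-in for the blocking-size part of SetConfigForColumnReduce.
static int PickBlockingSize(int max_threads, int left_num, int reduce_num,
                            int reduce_split_boundary) {
  int blocking_size = reduce_num;
  int num_block = max_threads / left_num;
  if (num_block > 1 && reduce_num >= reduce_split_boundary) {
    blocking_size = GetLastPow2(reduce_num / num_block);
    if (blocking_size <= 1) {
      blocking_size = GetLastPow2(
          static_cast<int>(std::sqrt(static_cast<double>(reduce_num))));
    } else if (blocking_size * 2 < reduce_num) {
      blocking_size *= 2;
    }
  }
  return blocking_size;
}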
@@ -31,13 +31,15 @@ namespace kernel_primitives {
 namespace details {
 #ifdef __HIPCC__
-constexpr int kMaxThread = 256;
+constexpr int kReduceMaxThread = 256;
 constexpr int kWarpSize = 64;
 #else
-constexpr int kMaxThread = 128;
+constexpr int kReduceMaxThread = 128;
 constexpr int kWarpSize = 32;
 #endif
+// kGlobalMode: block reduce, each block gets an output;
+// kLocalMode: thread reduce, each thread gets an output;
 enum ReduceMode { kGlobalMode, kLocalMode };
 template <typename T>
@@ -118,7 +120,7 @@ __device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
  */
 template <typename T, typename ReduceOp>
 __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
-  __shared__ T shared_memory[details::kMaxThread];
+  __shared__ T shared_memory[details::kReduceMaxThread];
   shared_memory[SharedMemoryIndex(0)] = val;
   for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
     __syncthreads();
......
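BlockYReduce, whose shared buffer is resized by the rename above, is a shared-memory tree reduction over threadIdx.y. A self-contained CUDA sketch of the same pattern, assuming a single block, blockDim.y a power of two, and blockDim.x * blockDim.y <= kMaxThreads; BlockYSum and ColumnSums are illustrative names, not the Paddle API:

constexpr int kMaxThreads = 128;  // must cover blockDim.x * blockDim.y

// Tree reduction across threadIdx.y: after the loop, row 0 of the shared
// buffer holds the per-column result.
template <typename T>
__device__ T BlockYSum(T val) {
  __shared__ T smem[kMaxThreads];
  int idx = threadIdx.y * blockDim.x + threadIdx.x;
  smem[idx] = val;
  for (int stride = blockDim.y / 2; stride > 0; stride >>= 1) {
    __syncthreads();
    if (threadIdx.y < stride) {
      smem[idx] += smem[(threadIdx.y + stride) * blockDim.x + threadIdx.x];
    }
  }
  __syncthreads();
  return smem[threadIdx.x];
}

// Launch with one block of (cols, P) threads, P a power of two.
__global__ void ColumnSums(const float* in, float* out, int rows, int cols) {
  float v = 0.0f;
  // Each thread first accumulates its strided slice of one column.
  for (int r = threadIdx.y; r < rows; r += blockDim.y) {
    v += in[r * cols + threadIdx.x];
  }
  v = BlockYSum(v);
  if (threadIdx.y == 0) out[threadIdx.x] = v;
}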
@@ -124,36 +124,36 @@ struct BroadcastConfig {
 template <typename Tx, typename Ty, int NX, int NY, int BlockSize>
 __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
                                          int stride_nx, int stride_ny) {
+  int thread_offset = threadIdx.x * NX;
   if (NY == 1 && NX == 1) {
-    dst[0] = static_cast<Ty>(src[threadIdx.x]);
+    dst[0] = static_cast<Ty>(src[thread_offset]);
   } else if (NX == 1) {
-    int dx = threadIdx.x;
 #pragma unroll
     for (int idy = 0; idy < NY; ++idy) {
-      dst[idy] = static_cast<Ty>(src[dx + idy * stride_ny]);
+      dst[idy] = static_cast<Ty>(src[thread_offset + idy * stride_ny]);
     }
   } else if (NY == 1) {
 #pragma unroll
     for (int idx = 0; idx < NX; ++idx) {
-      dst[idx] = static_cast<Ty>(src[idx * stride_nx]);
+      dst[idx] = static_cast<Ty>(src[thread_offset + idx * stride_nx]);
     }
   } else {
-    int dx = threadIdx.x * NX;
 #pragma unroll
     for (int idx = 0; idx < NX; ++idx) {
 #pragma unroll
       for (int idy = 0; idy < NY; ++idy) {
-        dst[idy * NX + idx] =
-            static_cast<Ty>(src[idx * stride_nx + dx + idy * stride_ny]);
+        dst[idy * NX + idx] = static_cast<Ty>(
+            src[thread_offset + idx * stride_nx + idy * stride_ny]);
       }
     }
   }
 }
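To make the new thread_offset indexing concrete, a hedged host-side analogue of the strided tile read above for one simulated thread; ReadTileForThread and thread_idx_x are illustrative names, not part of the Paddle API:

// Host-side analogue of the 2D tile read above for a single simulated thread.
template <typename Tx, typename Ty, int NX, int NY>
void ReadTileForThread(Ty* dst, const Tx* src, int thread_idx_x,
                       int stride_nx, int stride_ny) {
  int thread_offset = thread_idx_x * NX;  // same rule as threadIdx.x * NX
  for (int idy = 0; idy < NY; ++idy) {
    for (int idx = 0; idx < NX; ++idx) {
      // Each thread owns an NX x NY tile; columns step by stride_nx, rows by
      // stride_ny, all relative to the thread's base offset.
      dst[idy * NX + idx] = static_cast<Ty>(
          src[thread_offset + idx * stride_nx + idy * stride_ny]);
    }
  }
}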
 /**
- * @brief load data from src to dst, src can be 1D data or 2D data. When
- * boundary judgment is required, you need to set a to true, and a is false by
- * default.
+ * @brief load data from src to dst with stride, src can be 1D data or 2D data.
+ * When boundary judgment is required, you need to set a to true, and a is false
+ * by default.
  * @typename:
  * Tx: data type of src
  * Ty: data type of dstt
@@ -172,17 +172,17 @@ template <typename Tx, typename Ty, int NX, int NY, int BlockSize,
 __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
                                          int size_nx, int size_ny,
                                          int stride_nx, int stride_ny) {
-  int dx = threadIdx.x * NX;
-  int size = size_nx - dx;
+  int thread_offset = threadIdx.x * NX;
+  int left_size_nx = size_nx - thread_offset;
   // Each branch is added for better performance
   if (NX == 1 && NY == 1) {  // for NX == 1 and NY == 1
     if (IsBoundary) {
-      if (dx < size_nx) {
-        dst[0] = static_cast<Ty>(src[dx]);
+      if (left_size_nx > 0) {
+        dst[0] = static_cast<Ty>(src[thread_offset]);
       }
     } else {
-      dst[0] = static_cast<Ty>(src[dx]);
+      dst[0] = static_cast<Ty>(src[thread_offset]);
     }
   } else if (NX == 1) {  // for NX == 1 and NY != 1
 #pragma unroll
@@ -192,23 +192,23 @@ __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
           break;
         }
       }
-      dst[idy] = static_cast<Ty>(src[dx + idy * stride_ny]);
+      dst[idy] = static_cast<Ty>(src[thread_offset + idy * stride_ny]);
     }
   } else if (NY == 1) {  // for NY == 1 and NX != 1
 #pragma unroll
     for (int idx = 0; idx < NX; ++idx) {
       if (IsBoundary) {
-        if (idx >= size) {
+        if (idx >= left_size_nx) {
          break;
        }
      }
-      dst[idx] = static_cast<Ty>(src[idx * stride_nx + dx]);
+      dst[idx] = static_cast<Ty>(src[thread_offset + idx * stride_nx]);
     }
   } else {  // for NX != 1 and NY != 1
 #pragma unroll
     for (int idx = 0; idx < NX; ++idx) {
       if (IsBoundary) {
-        if (idx >= size) {
+        if (idx >= left_size_nx) {
          break;
        }
      }
@@ -219,8 +219,8 @@ __device__ __forceinline__ void ReadData(Ty* dst, const Tx* __restrict__ src,
            break;
          }
        }
-        dst[idy * NX + idx] =
-            static_cast<Ty>(src[idx * stride_nx + dx + idy * stride_ny]);
+        dst[idy * NX + idx] = static_cast<Ty>(
+            src[thread_offset + idx * stride_nx + idy * stride_ny]);
       }
     }
   }
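IsBoundary above is a compile-time template switch, not a runtime flag. A hypothetical call-site sketch of the usual pattern, where only the tail block pays for the guards; LoadTile and is_tail_block are illustrative names, and ReadData is assumed to be the kernel_primitives overload shown above:

// Hypothetical call-site sketch: full blocks use the unguarded instantiation,
// the tail block uses the boundary-checked one.
template <typename T, int NX, int NY, int BlockSize>
__device__ void LoadTile(T* reg, const T* src, int size_nx, int size_ny,
                         int stride_nx, int stride_ny, bool is_tail_block) {
  if (is_tail_block) {
    kernel_primitives::ReadData<T, T, NX, NY, BlockSize, true>(
        reg, src, size_nx, size_ny, stride_nx, stride_ny);
  } else {
    kernel_primitives::ReadData<T, T, NX, NY, BlockSize, false>(
        reg, src, size_nx, size_ny, stride_nx, stride_ny);
  }
}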
@@ -251,17 +251,17 @@ template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
 __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
                                          int num) {
   if (IsBoundary) {  // blockDim.x * NX > num
-    int dx = threadIdx.x * NX;
+    int thread_offset = threadIdx.x * NX;
 #pragma unroll
     for (int idx = 0; idx < NX; ++idx) {
-      if (idx + dx < num) {
-        dst[idx] = src[idx + dx];
+      if (idx + thread_offset < num) {
+        dst[idx] = src[thread_offset + idx];
       }
     }
   } else {  // blockDim,x * NX < num
     const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
     const int kVectorsPerThread = NX / kVectorSize;
-    int tid = threadIdx.x * kVectorsPerThread;
+    int thread_offset = threadIdx.x * kVectorsPerThread;
     using VecType = details::VectorType<T, kVectorSize>;
     const VecType* vec_input = reinterpret_cast<const VecType*>(src);
@@ -269,7 +269,7 @@ __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
 #pragma unroll
     for (int i = 0; i < kVectorsPerThread; ++i) {
-      vec_temp[i] = vec_input[i + tid];
+      vec_temp[i] = vec_input[thread_offset + i];
 #pragma unroll
       for (int idx = 0; idx < NX; ++idx) {
         dst[idx] = *(reinterpret_cast<T*>(vec_temp) + idx);
@@ -289,39 +289,39 @@ __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
  * is 2
  * IsBoundary: whether to make boundary judgment
  * @param:
- * fix: data offset of this block, blockDim.x * blockIdx.x * NX;
+ * block_offset: data offset of this block, blockDim.x * blockIdx.x * NX;
  * config: get the global index in src, attention config was declared in host;
- * num: the num of out
+ * total_num_output: total num of output
  * stride_nx: the stride of cols
  * stride_ny: the stride of rows
  */
 template <typename T, int NX, int NY, int BlockSize, int ShapeSize,
           bool IsBoundary = false>
 __device__ __forceinline__ void ReadDataBc(
-    T* dst, const T* __restrict__ src, uint32_t fix,
-    details::BroadcastConfig<ShapeSize> config, int num, int stride_nx,
-    int stride_ny) {
-  uint32_t base_offset = fix + threadIdx.x * NX;
-  uint32_t offset = 0;
+    T* dst, const T* __restrict__ src, uint32_t block_offset,
+    details::BroadcastConfig<ShapeSize> config, int total_num_output,
+    int stride_nx, int stride_ny) {
+  uint32_t thread_offset = block_offset + threadIdx.x * NX;
+  uint32_t index_src = 0;
 #pragma unroll
   for (int ny = 0; ny < NY; ++ny) {
 #pragma unroll
     for (uint32_t nx = 0; nx < NX; ++nx) {
-      uint32_t idx = base_offset + ny * stride_ny + nx * stride_nx;
+      uint32_t index_output = thread_offset + ny * stride_ny + nx * stride_nx;
+      index_src = 0;
       if (IsBoundary) {
-        if (idx >= num) {
+        if (index_output >= total_num_output) {
          break;
        }
      }
-      offset = 0;
 #pragma unroll
       for (int i = 0; i < ShapeSize; ++i) {
-        auto fast_divmoder = config.divmoders[i].Divmod(idx);
-        idx = fast_divmoder.val[0];
-        offset += fast_divmoder.val[1] * config.strides[i];
+        auto fast_divmoder = config.divmoders[i].Divmod(index_output);
+        index_output = fast_divmoder.val[0];
+        index_src += fast_divmoder.val[1] * config.strides[i];
       }
-      dst[nx + ny * NX] = src[offset];
+      dst[nx + ny * NX] = src[index_src];
     }
   }
 }
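The inner ShapeSize loop above peels one output dimension per step with a fast divmod and accumulates the source offset. A hedged host-side sketch of that mapping with plain division; BroadcastSrcIndex, out_dims and src_strides are stand-ins for config.divmoders / config.strides, and the dimension order here is an assumption:

#include <cstdint>

// For each dimension, the remainder selects the coordinate and the quotient
// carries the rest of the index to the next dimension; strides are 0 on
// broadcast dimensions, so those coordinates do not move the source index.
uint32_t BroadcastSrcIndex(uint32_t index_output, const uint32_t* out_dims,
                           const uint32_t* src_strides, int shape_size) {
  uint32_t index_src = 0;
  for (int i = 0; i < shape_size; ++i) {
    uint32_t quotient = index_output / out_dims[i];
    uint32_t remainder = index_output % out_dims[i];
    index_src += remainder * src_strides[i];
    index_output = quotient;
  }
  return index_src;
}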
@@ -338,7 +338,7 @@ __device__ __forceinline__ void ReadDataBc(
  * IndexCal: get the global index in src, attention config was declared in host;
  * IsBoundary: whether to make boundary judgment
  * @param:
- * fix: data offset of this block, blockDim.x * blockIdx.x * NX;
+ * block_offset: data offset of this block, blockDim.x * blockIdx.x * NX;
  * index_cal: get the global index in src, attention config was declared in
  * host;
  * size_nx: number of columns to be processed by the current block
@@ -350,27 +350,27 @@ __device__ __forceinline__ void ReadDataBc(
 template <typename T, int NX, int NY, int BlockSize, int ShapeSize,
           typename IndexCal, bool IsBoundary = false>
 __device__ __forceinline__ void ReadDataReduce(
-    T* dst, const T* __restrict__ src, int fix, const IndexCal& index_cal,
-    int size_nx, int size_ny, int stride_nx, int stride_ny,
-    bool reduce_last_dim) {
-  int base_offset = fix;
+    T* dst, const T* __restrict__ src, int block_offset,
+    const IndexCal& index_cal, int size_nx, int size_ny, int stride_nx,
+    int stride_ny, bool reduce_last_dim) {
+  int thread_offset = 0;
   if (reduce_last_dim) {
-    base_offset += threadIdx.x;
+    thread_offset = block_offset + threadIdx.x;
   } else {
-    base_offset += threadIdx.y;
+    thread_offset = block_offset + threadIdx.y;
   }
   if (NX == 1) {
 #pragma unroll
     for (int ny = 0; ny < NY; ++ny) {
       if (IsBoundary) {
-        if (base_offset >= size_ny) {
+        if (thread_offset >= size_ny) {
          break;
        }
      }
-      uint32_t offset = index_cal(base_offset);
-      dst[ny] = src[offset];
-      base_offset += stride_ny;
+      uint32_t index_src = index_cal(thread_offset);
+      dst[ny] = src[index_src];
+      thread_offset += stride_ny;
     }
   } else {
 #pragma unroll
@@ -387,15 +387,16 @@ __device__ __forceinline__ void ReadDataReduce(
            break;
          }
        }
-        uint32_t offset = index_cal(base_offset);
-        dst[nx + ny * NX] = src[offset];
-        base_offset += stride_ny;
+        uint32_t index_src = index_cal(thread_offset);
+        dst[nx + ny * NX] = src[index_src];
+        thread_offset += stride_ny;
       }
+      thread_offset += stride_nx;
     }
   }
 }
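index_cal is any callable that maps the running reduce index of one output element to a flat offset in src. A hypothetical functor with the required shape; RowMajorIndexCal is illustrative, not Paddle's IndexCalculator:

#include <cstdint>

// Example shape of an IndexCal for ReadDataReduce: reduce over the columns of
// a row-major 2D tensor, so the source offset is row * num_cols + reduce_idx.
struct RowMajorIndexCal {
  int row;       // which output element this block is reducing
  int num_cols;  // length of the reduced dimension
  __device__ uint32_t operator()(int reduce_idx) const {
    return static_cast<uint32_t>(row) * num_cols + reduce_idx;
  }
};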
-/** @brief: WriteData
+/**
  * @brief store data from src to dst, src can be 1D data, you should set NY = 1.
  * When boundary judgment is required, you need to set a to true, and a is false
  * by default.
@@ -412,11 +413,11 @@ template <typename T, int NX, int NY, int BlockSize, bool IsBoundary = false>
 __device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src,
                                           int num) {
   if (IsBoundary) {
-    int dx = threadIdx.x * NX;
+    int thread_offset = threadIdx.x * NX;
 #pragma unroll
     for (int idx = 0; idx < NX; ++idx) {
-      if ((idx + dx) < num) {
-        dst[idx + dx] = src[idx];
+      if ((thread_offset + idx) < num) {
+        dst[thread_offset + idx] = src[idx];
       }
     }
   } else {
@@ -424,14 +425,14 @@ __device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src,
     const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
     const int kVectorsPerThread = NX / kVectorSize;
-    int dx = threadIdx.x * kVectorsPerThread;
+    int thread_offset = threadIdx.x * kVectorsPerThread;
     using VecType = details::VectorType<T, kVectorSize>;
     VecType* vec_dst = reinterpret_cast<VecType*>(dst);
     VecType vec_temp[kVectorsPerThread];
 #pragma unroll
     for (int idx = 0; idx < kVectorsPerThread; ++idx) {
       vec_temp[idx] = *(reinterpret_cast<VecType*>(src) + idx);
-      vec_dst[dx + idx] = vec_temp[idx];
+      vec_dst[thread_offset + idx] = vec_temp[idx];
     }
   }
 }
......
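Both the 1D ReadData and WriteData fast paths above reinterpret the pointer as a small vector type so each thread issues fewer, wider memory transactions. A standalone CUDA sketch of that trick; VectorType here is a local stand-in for details::VectorType, and the kernel assumes the buffers are suitably aligned with no tail elements:

// Packs kVectorSize elements so one load/store moves the whole group.
template <typename T, int VecSize>
struct VectorType {
  T val[VecSize];
};

// Each thread copies NX contiguous elements using NX / kVectorSize wide
// accesses. Assumes the element count is a multiple of
// gridDim.x * blockDim.x * NX and that src/dst are aligned to the vector size.
template <typename T, int NX>
__global__ void VectorizedCopy(T* dst, const T* src) {
  constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
  constexpr int kVectorsPerThread = NX / kVectorSize;
  using VecType = VectorType<T, kVectorSize>;

  int thread_offset =
      (blockIdx.x * blockDim.x + threadIdx.x) * kVectorsPerThread;
  const VecType* vec_src = reinterpret_cast<const VecType*>(src);
  VecType* vec_dst = reinterpret_cast<VecType*>(dst);
#pragma unroll
  for (int i = 0; i < kVectorsPerThread; ++i) {
    vec_dst[thread_offset + i] = vec_src[thread_offset + i];
  }
}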
@@ -32,7 +32,6 @@ static __device__ __forceinline__ platform::float16 LogFunctor(
 static __device__ __forceinline__ float LogFunctor(float x) { return logf(x); }
 static __device__ __forceinline__ double LogFunctor(double x) { return log(x); }
-}  // namespace details
 /*************************** Compute Functor****************************/
 // for margin_cross_entropy
 template <typename Tx, typename Ty = Tx>
@@ -75,6 +74,7 @@ struct DivideFunctor {
   T n_inv;
 };
+}  // namespace details
 }  // namespace kernel_primitives
 }  // namespace operators
 }  // namespace paddle
@@ -128,39 +128,9 @@ __global__ void AddMarginToPositiveLogitsKernel(
   }
 }
-static __device__ __forceinline__ platform::float16 exp_on_device(
-    platform::float16 x) {
-  return ::Eigen::numext::exp(x);
-}
-static __device__ __forceinline__ float exp_on_device(float x) {
-  return expf(x);
-}
-static __device__ __forceinline__ double exp_on_device(double x) {
-  return exp(x);
-}
-static __device__ __forceinline__ platform::float16 log_on_device(
-    platform::float16 x) {
-  return ::Eigen::numext::log(x);
-}
-static __device__ __forceinline__ float log_on_device(float x) {
-  return logf(x);
-}
-static __device__ __forceinline__ double log_on_device(double x) {
-  return log(x);
-}
-template <typename Tx, typename Ty = Tx>
-struct ExpLogitTransformer {
-  HOSTDEVICE explicit inline ExpLogitTransformer(int n) {}
-  HOSTDEVICE inline Ty operator()(const Tx& x) const {
-    return static_cast<Ty>(exp_on_device(x));
-  }
-};
 template <typename Tx, typename Ty = Tx>
 struct ExpAndSum {
-  using Transformer = ExpLogitTransformer<Tx>;
+  using Transformer = kpds::ExpLogitTransformer<Tx>;
   inline Ty initial() { return static_cast<Ty>(0.0f); }
@@ -189,7 +159,7 @@ __global__ void LogitsMinusLogSumKernel(T* logits, const T* logits_sum_per_row,
                                         const int64_t N, const int64_t D) {
   CUDA_KERNEL_LOOP(i, N * D) {
     auto row = i / D;
-    logits[i] -= log_on_device(logits_sum_per_row[row]);
+    logits[i] -= kpds::LogFunctor(logits_sum_per_row[row]);
   }
 }
@@ -204,9 +174,9 @@ __global__ void HardLabelSoftmaxWithCrossEntropyKernel(
     if ((col + start_index) == labels[row]) {
       auto softmax = log_softmax[i];
       loss[row] = -softmax;
-      log_softmax[i] = exp_on_device(softmax);
+      log_softmax[i] = kpds::ExpFunctor(softmax);
     } else {
-      log_softmax[i] = exp_on_device(log_softmax[i]);
+      log_softmax[i] = kpds::ExpFunctor(log_softmax[i]);
     }
   }
 }
......
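For reference, the math these kernels implement (ignoring the distributed class-partition offset start_index): with margin-adjusted logits $x_{ij}$ and hard labels $y_i$, the per-row sum of exponentials is reduced first, then

\ell_{ij} = x_{ij} - \log \textstyle\sum_{k} e^{x_{ik}}, \qquad \text{loss}_i = -\ell_{i,\,y_i}, \qquad \text{softmax}_{ij} = e^{\ell_{ij}}

LogitsMinusLogSumKernel produces $\ell_{ij}$ in place, and HardLabelSoftmaxWithCrossEntropyKernel reads off the loss at the label column and overwrites $\ell_{ij}$ with $e^{\ell_{ij}}$.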
@@ -24,9 +24,11 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
+namespace kpds = paddle::operators::kernel_primitives::details;
 template <typename Tx, typename Ty = Tx>
 struct CustomMin {
-  using Transformer = detail::IdentityFunctor<Tx>;
+  using Transformer = kpds::IdentityFunctor<Tx>;
   inline Ty initial() {
     return static_cast<Ty>(std::numeric_limits<Ty>::max());
@@ -39,7 +41,7 @@ struct CustomMin {
 template <typename Tx, typename Ty = Tx>
 struct CustomMax {
-  using Transformer = detail::IdentityFunctor<Tx>;
+  using Transformer = kpds::IdentityFunctor<Tx>;
   inline Ty initial() {
     return static_cast<Ty>(std::numeric_limits<Ty>::lowest());
@@ -53,7 +55,7 @@ struct CustomMax {
 // for cub::Reduce
 template <typename Tx, typename Ty = Tx>
 struct CustomSum {
-  using Transformer = detail::IdentityFunctor<Tx, Ty>;
+  using Transformer = kpds::IdentityFunctor<Tx, Ty>;
   inline Ty initial() { return static_cast<Ty>(0.0f); }
@@ -64,7 +66,7 @@ struct CustomSum {
 template <typename Tx, typename Ty = Tx>
 struct CustomMean {
-  using Transformer = detail::DivideFunctor<Tx>;
+  using Transformer = kpds::DivideFunctor<Tx>;
   inline Ty initial() { return static_cast<Ty>(0.0f); }
@@ -75,7 +77,7 @@ struct CustomMean {
 template <typename Tx, typename Ty = Tx>
 struct CustomMul {
-  using Transformer = detail::IdentityFunctor<Tx>;
+  using Transformer = kpds::IdentityFunctor<Tx>;
   inline Ty initial() { return static_cast<Ty>(1.0f); }
@@ -86,7 +88,7 @@ struct CustomMul {
 template <typename Tx, typename Ty = Tx>
 struct CustomLogicalOr {
-  using Transformer = detail::IdentityFunctor<Tx>;
+  using Transformer = kpds::IdentityFunctor<Tx>;
   inline Ty initial() { return static_cast<Ty>(false); }
@@ -97,7 +99,7 @@ struct CustomLogicalOr {
 template <typename Tx, typename Ty = Tx>
 struct CustomLogicalAnd {
-  using Transformer = detail::IdentityFunctor<Tx>;
+  using Transformer = kpds::IdentityFunctor<Tx>;
   inline Ty initial() { return static_cast<Ty>(true); }
......
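Each functor in this file follows the same contract the reduce kernel expects: a Transformer applied to every input element, an initial() identity value, and an associative combine in operator(). A hedged host-side illustration of that contract; IdentityFunctorSketch, CustomSumSketch, and the serial Reduce loop are illustrative, not the Paddle API:

#include <iostream>
#include <vector>

template <typename Tx, typename Ty = Tx>
struct IdentityFunctorSketch {
  explicit IdentityFunctorSketch(int /*n*/ = 1) {}
  Ty operator()(const Tx& x) const { return static_cast<Ty>(x); }
};

// Same shape as CustomSum above: Transformer + identity + associative combine.
template <typename Tx, typename Ty = Tx>
struct CustomSumSketch {
  using Transformer = IdentityFunctorSketch<Tx, Ty>;
  Ty initial() { return static_cast<Ty>(0.0f); }
  Ty operator()(const Ty& a, const Ty& b) const { return a + b; }
};

// Serial stand-in for the device-side reduce: transform each element, then
// fold with the reducer starting from its identity value.
template <typename T, typename ReduceOp>
T Reduce(const std::vector<T>& data, ReduceOp reducer) {
  typename ReduceOp::Transformer transformer(static_cast<int>(data.size()));
  T acc = reducer.initial();
  for (const T& x : data) acc = reducer(acc, transformer(x));
  return acc;
}

int main() {
  std::vector<float> v{1.f, 2.f, 3.f, 4.f};
  std::cout << Reduce(v, CustomSumSketch<float>()) << std::endl;  // prints 10
}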