Unverified commit 39210ed0, authored by MarDino, committed by GitHub

Refine name style and MoeKernel (#49432)

Parent c0d6ec63
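
This commit renames the CUDA reduction helpers in phi::funcs (warpReduceSum, blockReduceMax, and friends become WarpReduceSum, BlockReduceMax, with the multi-value V2 variants consolidated into paddle/phi/kernels/funcs/math_cuda_utils.h) and simplifies the MoE softmax kernels to treat the gating scores as a plain (batch_size, seq_len) matrix. For readers unfamiliar with the call sites in the diff below, the following standalone sketch, which is not the phi::funcs implementation itself, illustrates the shuffle-then-shared-memory pattern these block reductions follow; the names WarpReduceSumSketch, BlockReduceSumSketch, and SumKernel are invented for the illustration, and the sketch assumes blockDim.x is a multiple of 32.

// Standalone CUDA sketch of the two-level block-sum reduction pattern.
#include <cstdio>
#include <cuda_runtime.h>

#define FINAL_MASK 0xffffffffu

template <typename T>
__inline__ __device__ T WarpReduceSumSketch(T val, unsigned lane_mask) {
  // Butterfly reduction: after five xor-shuffles every lane holds the warp sum.
  for (int mask = 16; mask > 0; mask >>= 1)
    val += __shfl_xor_sync(lane_mask, val, mask, 32);
  return val;
}

template <typename T>
__inline__ __device__ T BlockReduceSumSketch(T val, unsigned lane_mask) {
  static __shared__ T shared[32];            // one partial sum per warp
  int lane = threadIdx.x & 0x1f;             // lane index within the warp
  int wid = threadIdx.x >> 5;                // warp index within the block
  val = WarpReduceSumSketch<T>(val, lane_mask);
  if (lane == 0) shared[wid] = val;          // warp leaders publish partials
  __syncthreads();
  int block_span = (blockDim.x + 31) >> 5;   // number of warps in the block
  val = (lane < block_span) ? shared[lane] : static_cast<T>(0);
  return WarpReduceSumSketch<T>(val, lane_mask);  // reduce the partials
}

__global__ void SumKernel(const float* in, float* out, int n) {
  float v = (threadIdx.x < n) ? in[threadIdx.x] : 0.0f;
  float block_sum = BlockReduceSumSketch<float>(v, FINAL_MASK);
  if (threadIdx.x == 0) *out = block_sum;
}

int main() {
  const int n = 256;
  float h_in[n], h_out = 0.0f;
  for (int i = 0; i < n; ++i) h_in[i] = 1.0f;
  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
  SumKernel<<<1, n>>>(d_in, d_out, n);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("block sum = %f\n", h_out);  // expected: 256.000000
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
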
......@@ -216,9 +216,9 @@ __device__ void BlockReduceMaxMinAndWrite(const T max_value,
if (max_ptr && min_ptr && mean_ptr) {
__syncthreads();
T block_max_value = phi::funcs::blockReduceMax<T>(max_value, FINAL_MASK);
T block_min_value = phi::funcs::blockReduceMin<T>(min_value, FINAL_MASK);
T block_mean_value = phi::funcs::blockReduceSum<T>(mean_value, FINAL_MASK);
T block_max_value = phi::funcs::BlockReduceMax<T>(max_value, FINAL_MASK);
T block_min_value = phi::funcs::BlockReduceMin<T>(min_value, FINAL_MASK);
T block_mean_value = phi::funcs::BlockReduceSum<T>(mean_value, FINAL_MASK);
if (threadIdx.x == 0) {
max_ptr[offset] = block_max_value;
......
......@@ -68,7 +68,7 @@ __global__ void merge_layernorm_v2(T *out,
}
}
mean = phi::funcs::blockReduceSum<float>(sum, FINAL_MASK);
mean = phi::funcs::BlockReduceSum<float>(sum, FINAL_MASK);
if (tid == 0) {
s_mean = mean / n;
}
......@@ -84,7 +84,7 @@ __global__ void merge_layernorm_v2(T *out,
}
}
variance = phi::funcs::blockReduceSum<float>(var, FINAL_MASK);
variance = phi::funcs::BlockReduceSum<float>(var, FINAL_MASK);
if (tid == 0) {
s_variance = rsqrtf(variance / n + layernorm_eps);
}
......
......@@ -26,6 +26,7 @@
#include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
namespace paddle {
namespace inference {
......@@ -33,41 +34,6 @@ namespace tensorrt {
namespace plugin {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
#define FINAL_MASK 0xffffffff
template <typename T, int NUM>
__inline__ __device__ T warpReduceSumV2(T *val) {
#pragma unroll
for (int i = 0; i < NUM; i++) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
}
return (T)(0.0f);
}
template <typename T, int NUM>
__inline__ __device__ T blockReduceSumV2(T *val) {
static __shared__ T shared[NUM][33];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
warpReduceSumV2<T, NUM>(val);
if (lane == 0) {
#pragma unroll
for (int i = 0; i < NUM; i++) {
shared[i][wid] = val[i];
}
}
__syncthreads();
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
for (int i = 0; i < NUM; i++) {
val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
}
warpReduceSumV2<T, NUM>(val);
return (T)0.0f;
}
template <int UNROLL_FACTOR>
__global__ void generalAddBiasResidualLayerNormOpt2(
......@@ -119,7 +85,7 @@ __global__ void generalAddBiasResidualLayerNormOpt2(
float sums[2];
sums[0] = x_sum;
sums[1] = x2_sum;
blockReduceSumV2<float, 2>(sums);
phi::funcs::BlockReduceSumV2<float, 2>(sums);
if (threadIdx.x == 0) {
s_mean = sums[0] / n / 2;
......
......@@ -70,7 +70,7 @@ __global__ void merge_layernorm_v2(T *out,
}
}
mean = phi::funcs::blockReduceSum<float>(sum, FINAL_MASK);
mean = phi::funcs::BlockReduceSum<float>(sum, FINAL_MASK);
if (tid == 0) {
s_mean = mean / n;
}
......@@ -86,7 +86,7 @@ __global__ void merge_layernorm_v2(T *out,
}
}
variance = phi::funcs::blockReduceSum<float>(var, FINAL_MASK);
variance = phi::funcs::BlockReduceSum<float>(var, FINAL_MASK);
if (tid == 0) {
s_variance = rsqrtf(variance / n + layernorm_eps);
}
......
......@@ -269,10 +269,10 @@ __global__ void SoftmaxKernelWithEltadd(T *qk_buf_,
? static_cast<float>(qk_buf_[threadIdx.x + qk_offset] +
bias_qk_[threadIdx.x + qk_offset])
: -1e20f;
float max_val = phi::funcs::blockReduceMax<float>(tmp, mask);
float max_val = phi::funcs::BlockReduceMax<float>(tmp, mask);
float qk_tmp = threadIdx.x < seq_len ? __expf(tmp - max_val) : 0.0f;
float sum_val = phi::funcs::blockReduceSum<float>(qk_tmp, mask);
float sum_val = phi::funcs::BlockReduceSum<float>(qk_tmp, mask);
if (threadIdx.x < seq_len)
qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / sum_val);
......@@ -295,10 +295,10 @@ __global__ void SoftmaxKernelWithEltadd<half>(half *qk_buf_,
? static_cast<float>(qk_buf_[threadIdx.x + qk_offset] +
bias_qk_[threadIdx.x + qk_offset])
: -1e20f;
float max_val = phi::funcs::blockReduceMax<float>(tmp, mask);
float max_val = phi::funcs::BlockReduceMax<float>(tmp, mask);
float qk_tmp = threadIdx.x < seq_len ? __expf(tmp - max_val) : 0.0f;
float sum_val = phi::funcs::blockReduceSum<float>(qk_tmp, mask);
float sum_val = phi::funcs::BlockReduceSum<float>(qk_tmp, mask);
if (threadIdx.x < seq_len)
qk_buf_[threadIdx.x + qk_offset] = (half)(qk_tmp / sum_val);
......@@ -321,12 +321,12 @@ __global__ void SoftmaxKernelWithEltadd2(T *qk_buf_,
? phi::funcs::ToFloat2<T>(qk_buf_[idx + qk_offset] +
bias_qk_[idx + qk_offset])
: make_float2(-1e20f, -1e20f);
float max_val = phi::funcs::blockReduceMax<float>(max(tmp.x, tmp.y), mask);
float max_val = phi::funcs::BlockReduceMax<float>(max(tmp.x, tmp.y), mask);
float2 qk_tmp = idx < seq_len ? make_float2(__expf(tmp.x - max_val),
__expf(tmp.y - max_val))
: make_float2(0.f, 0.f);
float sum_val =
phi::funcs::blockReduceSum<float>(qk_tmp.x + qk_tmp.y, mask) + 1e-6f;
phi::funcs::BlockReduceSum<float>(qk_tmp.x + qk_tmp.y, mask) + 1e-6f;
if (idx < seq_len) {
qk_buf_[idx + qk_offset] =
......@@ -353,12 +353,12 @@ __global__ void SoftmaxKernelWithEltadd2<half2>(half2 *qk_buf_,
? phi::funcs::ToFloat2<half2>(qk_buf_[idx + qk_offset] +
bias_qk_[idx + qk_offset])
: make_float2(-1e20f, -1e20f);
float max_val = phi::funcs::blockReduceMax<float>(max(tmp.x, tmp.y), mask);
float max_val = phi::funcs::BlockReduceMax<float>(max(tmp.x, tmp.y), mask);
float2 qk_tmp = idx < seq_len ? make_float2(__expf(tmp.x - max_val),
__expf(tmp.y - max_val))
: make_float2(0.f, 0.f);
float sum_val =
phi::funcs::blockReduceSum<float>(qk_tmp.x + qk_tmp.y, mask) + 1e-6f;
phi::funcs::BlockReduceSum<float>(qk_tmp.x + qk_tmp.y, mask) + 1e-6f;
if (idx < seq_len) {
qk_buf_[idx + qk_offset] =
......@@ -386,14 +386,14 @@ __global__ void SoftmaxKernelWithEltaddForLarge(T *qk_buf,
bias_qk[threadIdx.x + i + qk_offset]
: stride_max;
}
T max_val = phi::funcs::blockReduceMax<T>(stride_max, mask);
T max_val = phi::funcs::BlockReduceMax<T>(stride_max, mask);
T stride_sum = 0.f;
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
stride_sum += __expf(qk_buf[threadIdx.x + i + qk_offset] +
bias_qk[threadIdx.x + i + qk_offset] - max_val);
}
T sum_val = phi::funcs::blockReduceSum<T>(stride_sum, mask);
T sum_val = phi::funcs::BlockReduceSum<T>(stride_sum, mask);
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
qk_buf[threadIdx.x + i + qk_offset] =
......@@ -422,7 +422,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge(half *qk_buf,
bias_qk[threadIdx.x + i + qk_offset]);
stride_max = tmp > stride_max ? tmp : stride_max;
}
float max_val = phi::funcs::blockReduceMax<float>(stride_max, mask);
float max_val = phi::funcs::BlockReduceMax<float>(stride_max, mask);
float stride_sum = 0.f;
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
......@@ -430,7 +430,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge(half *qk_buf,
bias_qk[threadIdx.x + i + qk_offset]);
stride_sum += __expf(tmp - max_val);
}
float sum_val = phi::funcs::blockReduceSum<float>(stride_sum, mask);
float sum_val = phi::funcs::BlockReduceSum<float>(stride_sum, mask);
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
float tmp =
......@@ -461,7 +461,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_,
stride_max.y = max(stride_max.y, cur.y);
}
float max_val =
phi::funcs::blockReduceMax<float>(max(stride_max.x, stride_max.y), mask);
phi::funcs::BlockReduceMax<float>(max(stride_max.x, stride_max.y), mask);
float2 stride_sum = make_float2(0.f, 0.f);
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
......@@ -472,7 +472,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_,
}
float sum_val =
phi::funcs::blockReduceSum<float>(stride_sum.x + stride_sum.y, mask) +
phi::funcs::BlockReduceSum<float>(stride_sum.x + stride_sum.y, mask) +
1e-6f;
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
......@@ -507,7 +507,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(half2 *qk_buf_,
stride_max.y = max(stride_max.y, cur.y);
}
float max_val =
phi::funcs::blockReduceMax<float>(max(stride_max.x, stride_max.y), mask);
phi::funcs::BlockReduceMax<float>(max(stride_max.x, stride_max.y), mask);
float2 stride_sum = make_float2(0.f, 0.f);
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
......@@ -519,7 +519,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(half2 *qk_buf_,
}
float sum_val =
phi::funcs::blockReduceSum<float>(stride_sum.x + stride_sum.y, mask) +
phi::funcs::BlockReduceSum<float>(stride_sum.x + stride_sum.y, mask) +
1e-6f;
for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
......@@ -573,83 +573,6 @@ inline __device__ T hadd2(T a, T b) {
return __hadd2(a, b);
}
template <typename T, int NUM>
__inline__ __device__ T warpReduceSumV2(T *val) {
#pragma unroll
for (int i = 0; i < NUM; i++) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
}
return (T)(0.0f);
}
template <typename T, int NUM>
__inline__ __device__ T blockReduceSumV2(T *val) {
static __shared__ T shared[NUM][33];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
warpReduceSumV2<T, NUM>(val);
if (lane == 0) {
#pragma unroll
for (int i = 0; i < NUM; i++) {
shared[i][wid] = val[i];
}
}
__syncthreads();
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
for (int i = 0; i < NUM; i++) {
val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
}
warpReduceSumV2<T, NUM>(val);
return (T)0.0f;
}
template <typename T, int NUM>
__inline__ __device__ T warpReduceMaxV2(T *val) {
#pragma unroll
for (int i = 0; i < NUM; i++) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32));
}
return (T)(0.0f);
}
template <typename T, int NUM>
__inline__ __device__ T blockReduceMaxV2(T *val) {
static __shared__ T shared[32][NUM];
int lane = threadIdx.x & 0x1f; // in-warp idx
int wid = threadIdx.x >> 5; // warp idx
warpReduceMaxV2<T, NUM>(val); // get max in each warp
if (lane == 0) {
#pragma unroll
for (int i = 0; i < NUM; i++) {
shared[wid][i] = val[i];
}
}
__syncthreads();
// Modified from blockDim.x << 5 to blockDim.x / 32. to handle the case
// where blockDim.x is not divisible by 32
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
for (int i = 0; i < NUM; i++) {
val[i] = is_mask ? shared[lane][i] : (T)-1e20f;
}
warpReduceMaxV2<T, NUM>(val);
return (T)0.0f;
}
template <typename T, int ITEMS_PER_THREAD, int NUM>
__global__ void softmax_kernel_with_mask(T *qk_buf_,
const T *attr_mask,
......@@ -715,9 +638,9 @@ __global__ void softmax_kernel_with_mask(T *qk_buf_,
}
if (blockDim.x <= 32) {
warpReduceMaxV2<float, NUM>(local_max);
phi::funcs::WarpReduceMaxV2<float, NUM>(local_max);
} else {
blockReduceMaxV2<float, NUM>(local_max);
phi::funcs::BlockReduceMaxV2<float, NUM>(local_max);
}
if (threadIdx.x == 0) {
......@@ -750,9 +673,9 @@ __global__ void softmax_kernel_with_mask(T *qk_buf_,
}
if (blockDim.x <= 32) {
warpReduceSumV2<float, NUM>(local_sum);
phi::funcs::WarpReduceSumV2<float, NUM>(local_sum);
} else {
blockReduceSumV2<float, NUM>(local_sum);
phi::funcs::BlockReduceSumV2<float, NUM>(local_sum);
}
if (threadIdx.x == 0) {
......
......@@ -187,8 +187,8 @@ __global__ void L2NormKernel(
g_tmp += (tmp1 * tmp1);
tid += grid_stride;
}
p_tmp = phi::funcs::blockReduceSum<MT>(p_tmp, FINAL_MASK);
g_tmp = phi::funcs::blockReduceSum<MT>(g_tmp, FINAL_MASK);
p_tmp = phi::funcs::BlockReduceSum<MT>(p_tmp, FINAL_MASK);
g_tmp = phi::funcs::BlockReduceSum<MT>(g_tmp, FINAL_MASK);
if (threadIdx.x == 0) {
p_buffer[blockIdx.x] = p_tmp;
......@@ -198,8 +198,8 @@ __global__ void L2NormKernel(
cg->sync(); // Grid sync for writing partial result to global memory
MT p_part_sum = threadIdx.x < gridDim.x ? p_buffer[threadIdx.x] : 0;
MT g_part_sum = threadIdx.x < gridDim.x ? g_buffer[threadIdx.x] : 0;
MT tmp0 = phi::funcs::blockReduceSum<MT>(p_part_sum, FINAL_MASK);
MT tmp1 = phi::funcs::blockReduceSum<MT>(g_part_sum, FINAL_MASK);
MT tmp0 = phi::funcs::BlockReduceSum<MT>(p_part_sum, FINAL_MASK);
MT tmp1 = phi::funcs::BlockReduceSum<MT>(g_part_sum, FINAL_MASK);
if (threadIdx.x == 0) {
s_buffer[0] = tmp0;
s_buffer[1] = tmp1;
......@@ -393,8 +393,8 @@ __global__ void MomentumLarsKernel(const T* param,
MT grad_part_norm = threadIdx.x < thresh ? g_buffer[threadIdx.x] : 0;
__syncthreads();
MT param_norm =
Sqrt(phi::funcs::blockReduceSum<MT>(param_part_norm, FINAL_MASK));
MT grad_norm = Sqrt(rescale_grad_pow * phi::funcs::blockReduceSum<MT>(
Sqrt(phi::funcs::BlockReduceSum<MT>(param_part_norm, FINAL_MASK));
MT grad_norm = Sqrt(rescale_grad_pow * phi::funcs::BlockReduceSum<MT>(
grad_part_norm, FINAL_MASK));
#endif
MomentumUpdate<T, MT>(param,
......
......@@ -168,7 +168,7 @@ struct KeyValuePair<half> {
#define WARP_SIZE 32
template <typename T>
__inline__ __device__ T warpReduceSum(T val, unsigned lane_mask) {
__inline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) {
for (int mask = HALF_WARP; mask > 0; mask >>= 1)
#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
val += __shfl_xor_sync(lane_mask, val, mask, warpSize);
......@@ -180,12 +180,12 @@ __inline__ __device__ T warpReduceSum(T val, unsigned lane_mask) {
/* Calculate the sum of all elements in a block */
template <typename T>
__inline__ __device__ T blockReduceSum(T val, unsigned mask) {
__inline__ __device__ T BlockReduceSum(T val, unsigned mask) {
static __shared__ T shared[WARP_SIZE];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val, mask);
val = WarpReduceSum<T>(val, mask);
__syncthreads();
if (lane == 0) shared[wid] = val;
......@@ -195,13 +195,53 @@ __inline__ __device__ T blockReduceSum(T val, unsigned mask) {
// align block_span to warpSize
int block_span = (blockDim.x + warpSize - 1) >> 5;
val = (lane < block_span) ? shared[lane] : static_cast<T>(0.0f);
val = warpReduceSum<T>(val, mask);
val = WarpReduceSum<T>(val, mask);
return val;
}
/*
WarpReduce multi values.
*/
template <typename T, int NUM>
__inline__ __device__ T WarpReduceSumV2(T *val) {
#pragma unroll
for (int i = 0; i < NUM; i++) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
}
return (T)(0.0f);
}
template <typename T, int NUM>
__inline__ __device__ T BlockReduceSumV2(T *val) {
static __shared__ T shared[NUM][33];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
WarpReduceSumV2<T, NUM>(val);
if (lane == 0) {
#pragma unroll
for (int i = 0; i < NUM; i++) {
shared[i][wid] = val[i];
}
}
__syncthreads();
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
for (int i = 0; i < NUM; i++) {
val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
}
WarpReduceSumV2<T, NUM>(val);
return (T)0.0f;
}
template <typename T>
__inline__ __device__ T warpReduceMax(T val, unsigned lane_mask) {
__inline__ __device__ T WarpReduceMax(T val, unsigned lane_mask) {
for (int mask = HALF_WARP; mask > 0; mask >>= 1)
#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
val = max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
......@@ -211,8 +251,19 @@ __inline__ __device__ T warpReduceMax(T val, unsigned lane_mask) {
return val;
}
template <typename T, int NUM>
__inline__ __device__ T WarpReduceMaxV2(T *val) {
#pragma unroll
for (int i = 0; i < NUM; i++) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32));
}
return (T)(0.0f);
}
template <typename T>
__inline__ __device__ T warpReduceMin(T val, unsigned lane_mask) {
__inline__ __device__ T WarpReduceMin(T val, unsigned lane_mask) {
for (int mask = HALF_WARP; mask > 0; mask >>= 1)
#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
val = min(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
......@@ -246,12 +297,12 @@ __inline__ __device__ T PartialWarpReduceMin(T val, unsigned lane_mask) {
/* Calculate the maximum of all elements in a block */
template <typename T>
__inline__ __device__ T blockReduceMax(T val, unsigned mask) {
__inline__ __device__ T BlockReduceMax(T val, unsigned mask) {
static __shared__ T shared[WARP_SIZE];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceMax(val, mask);
val = WarpReduceMax(val, mask);
if (lane == 0) shared[wid] = val;
......@@ -260,26 +311,55 @@ __inline__ __device__ T blockReduceMax(T val, unsigned mask) {
// align block_span to warpSize
int block_span = (blockDim.x + warpSize - 1) >> 5;
val = (lane < block_span) ? shared[lane] : -1e10f;
val = warpReduceMax(val, mask);
val = WarpReduceMax(val, mask);
return val;
}
template <typename T, int NUM>
__inline__ __device__ T BlockReduceMaxV2(T *val) {
static __shared__ T shared[32][NUM];
int lane = threadIdx.x & 0x1f; // in-warp idx
int wid = threadIdx.x >> 5; // warp idx
WarpReduceMaxV2<T, NUM>(val); // get max in each warp
if (lane == 0) { // record in-warp max by warp idx
#pragma unroll
for (int i = 0; i < NUM; i++) {
shared[wid][i] = val[i];
}
}
__syncthreads();
// Modified from blockDim.x << 5 to blockDim.x / 32. to handle the case
// where blockDim.x is not divisible by 32
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
for (int i = 0; i < NUM; i++) {
val[i] = is_mask ? shared[lane][i] : (T)-1e20f;
}
WarpReduceMaxV2<T, NUM>(val);
return (T)0.0f;
}
/* Calculate the minimum of all elements in a block */
template <typename T>
__inline__ __device__ T blockReduceMin(T val, unsigned mask) {
__inline__ __device__ T BlockReduceMin(T val, unsigned mask) {
static __shared__ T shared[WARP_SIZE];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceMin(val, mask);
val = WarpReduceMin(val, mask);
if (lane == 0) shared[wid] = val;
__syncthreads();
// align block_span to warpSize
int block_span = (blockDim.x + warpSize - 1) >> 5;
val = (lane < block_span) ? shared[lane] : 1e10f;
val = warpReduceMin(val, mask);
val = WarpReduceMin(val, mask);
return val;
}
......
......@@ -160,30 +160,17 @@ void InitExpertChoiceRouteKernelLauncher(
<<<grid, block, 0, stream>>>(reinterpret_cast<half*>(buffer), \
(const half*)attr_mask, \
batch_size, \
head_num, \
seq_len_1, \
seq_len_2, \
(const half)scalar); \
seq_len); \
} else { \
softmax_kernel_v4_half2<__half, ITEMS_PER_THREAD> \
<<<grid, block, 0, stream>>>(reinterpret_cast<half*>(buffer), \
(const half*)attr_mask, \
batch_size, \
head_num, \
seq_len_1, \
seq_len_2, \
(const half)scalar); \
seq_len); \
} \
} else { \
softmax_kernel_v4<ITEMS_PER_THREAD, T> \
<<<grid, block, 0, stream>>>(buffer, \
buffer_src, \
attr_mask, \
batch_size, \
head_num, \
seq_len_1, \
seq_len_2, \
scalar); \
softmax_kernel_v4<ITEMS_PER_THREAD, T><<<grid, block, 0, stream>>>( \
buffer, buffer_src, attr_mask, batch_size, seq_len); \
}
template <typename T>
......@@ -191,19 +178,16 @@ void invokeMaskedSoftMax(T* buffer,
const T* buffer_src,
const T* attr_mask,
const int batch_size,
const int seq_len_1,
const int seq_len_2,
const int head_num,
const T scalar,
const int seq_len,
cudaStream_t stream) {
// NOTE: attention scores shape (batch_size, head_num, seq_len_1, seq_len_2)
dim3 grid(seq_len_1, batch_size, head_num);
if (batch_size * head_num > 360) {
grid.x = ceil(static_cast<float>(seq_len_1) / 32.0f);
// NOTE: attention scores shape (batch_size, seq_len)
dim3 grid(1, batch_size, 1);
if (batch_size > 360) {
grid.x = ceil(static_cast<float>(1) / 32.0f);
}
bool is_half2 = sizeof(T) == 2 && sizeof(T) == 2 && seq_len_2 % 2 == 0;
dim3 block((seq_len_2 / (is_half2 ? 2 : 1) + 31) / 32 * 32);
bool is_half2 = sizeof(T) == 2 && sizeof(T) == 2 && seq_len % 2 == 0;
dim3 block((seq_len / (is_half2 ? 2 : 1) + 31) / 32 * 32);
if (block.x > 2048 && block.x <= 4096) {
SOFTMAX_KERNEL(4)
......@@ -766,26 +750,19 @@ void MoeKernel(const Context& ctx,
k,
batch_size,
ctx.stream());
T scalar = (T)1.0f;
if (IS_FP16) {
invokeMaskedSoftMax<__half>(reinterpret_cast<__half*>(gating_output),
reinterpret_cast<const __half*>(gating_output),
reinterpret_cast<const __half*>(attr_mask),
/*batch_size=*/num_rows,
/*seq_len_1=*/1,
/*seq_len_2=*/num_experts,
/*head_num=*/1,
*reinterpret_cast<const __half*>(&scalar),
/*seq_len=*/num_experts,
ctx.stream());
} else {
invokeMaskedSoftMax<float>(reinterpret_cast<float*>(gating_output),
reinterpret_cast<const float*>(gating_output),
reinterpret_cast<const float*>(attr_mask),
/*batch_size=*/num_rows,
/*seq_len_1=*/1,
/*seq_len_2=*/num_experts,
/*head_num=*/1,
*reinterpret_cast<const float*>(&scalar),
/*seq_len=*/num_experts,
ctx.stream());
}
InvokeTransposeAxis01(
......
......@@ -26,87 +26,6 @@ static inline size_t AlignTo16(const size_t& input) {
return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
}
/*
WarpReduce multi values.
TODO(zhengzekang): Add blocksize templates to reduce shared memory usage.
*/
template <typename T, int NUM>
__inline__ __device__ T warpReduceSumV2(T* val) {
#pragma unroll
for (int i = 0; i < NUM; i++) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
}
return (T)(0.0f);
}
template <typename T, int NUM>
__inline__ __device__ T blockReduceSumV2(T* val) {
static __shared__ T shared[NUM][33];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
warpReduceSumV2<T, NUM>(val);
if (lane == 0) {
#pragma unroll
for (int i = 0; i < NUM; i++) {
shared[i][wid] = val[i];
}
}
__syncthreads();
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
for (int i = 0; i < NUM; i++) {
val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
}
warpReduceSumV2<T, NUM>(val);
return (T)0.0f;
}
template <typename T, int NUM>
__inline__ __device__ T warpReduceMaxV2(T* val) {
#pragma unroll
for (int i = 0; i < NUM; i++) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32));
}
return (T)(0.0f);
}
template <typename T, int NUM>
__inline__ __device__ T blockReduceMaxV2(T* val) {
static __shared__ T shared[32][NUM];
int lane = threadIdx.x & 0x1f; // in-warp idx
int wid = threadIdx.x >> 5; // warp idx
warpReduceMaxV2<T, NUM>(val); // get max in each warp
if (lane == 0) { // record in-warp max by warp idx
#pragma unroll
for (int i = 0; i < NUM; i++) {
shared[wid][i] = val[i];
}
}
__syncthreads();
// Modified from blockDim.x << 5 to blockDim.x / 32. to handle the case
// where blockDim.x is not divisible by 32
bool is_mask = threadIdx.x < (blockDim.x / 32.f);
#pragma unroll
for (int i = 0; i < NUM; i++) {
val[i] = is_mask ? shared[lane][i] : (T)-1e20f;
}
warpReduceMaxV2<T, NUM>(val);
return (T)0.0f;
}
class CubKeyValueSorter {
public:
CubKeyValueSorter();
......@@ -311,65 +230,57 @@ __global__ void initialize_expert_choice_route_kernel(
template <int ITEMS_PER_THREAD, typename T>
__global__ void softmax_kernel_v4(
T* qk_buf_,
const T* qk_buf_src, // shape [batch_size, head_num, seq_len_1, seq_len_2]
const T* attr_mask, // shape [batch_size, seq_len_1, seq_len_2]
const T* qk_buf_src, // shape [batch_size, seq_len]
const T* attr_mask, // shape [batch_size, seq_len]
const int batch_size,
const int head_num,
const int seq_len_1,
const int seq_len_2,
const T scalar) {
const int seq_len) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
for (int seq_id = blockIdx.x; seq_id < seq_len_1; seq_id += gridDim.x) {
float data[ITEMS_PER_THREAD];
int qk_offset;
__shared__ float s_mean, s_max;
float local_max = -1e20f;
for (int i = 0; blockDim.x * i + threadIdx.x < seq_len_2; i++) {
qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) *
seq_len_2 +
blockDim.x * i + threadIdx.x;
int mask_offset = (blockIdx.y * seq_len_1 + seq_id) * seq_len_2 +
blockDim.x * i + threadIdx.x;
float qk = static_cast<float>(qk_buf_src[qk_offset]);
float mask_val = static_cast<float>(__ldg(&attr_mask[mask_offset]));
mask_val = (1.0f - mask_val) * -10000.0f;
data[i] = qk * static_cast<float>(scalar) + mask_val;
local_max = fmax(local_max, data[i]);
}
float data[ITEMS_PER_THREAD];
int qk_offset;
__shared__ float s_mean, s_max;
float local_max = -1e20f;
for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) {
qk_offset =
((blockIdx.y + blockIdx.z)) * seq_len + blockDim.x * i + threadIdx.x;
int mask_offset = (blockIdx.y) * seq_len + blockDim.x * i + threadIdx.x;
float qk = static_cast<float>(qk_buf_src[qk_offset]);
float mask_val = static_cast<float>(__ldg(&attr_mask[mask_offset]));
mask_val = (1.0f - mask_val) * -10000.0f;
data[i] = qk + mask_val;
local_max = fmax(local_max, data[i]);
}
float max_val =
blockDim.x <= 32
? phi::funcs::warpReduceMax<float>(local_max, 0xFFFFFFFF)
: phi::funcs::blockReduceMax<float>(local_max, 0xffffffff);
if (threadIdx.x == 0) {
s_max = max_val;
}
__syncthreads();
float max_val =
blockDim.x <= 32
? phi::funcs::WarpReduceMax<float>(local_max, 0xFFFFFFFF)
: phi::funcs::BlockReduceMax<float>(local_max, 0xFFFFFFFF);
if (threadIdx.x == 0) {
s_max = max_val;
}
__syncthreads();
float local_sum = 0;
for (int i = 0; blockDim.x * i + threadIdx.x < seq_len_2; i++) {
data[i] = __expf(data[i] - s_max);
local_sum += data[i];
}
float sum_val =
blockDim.x <= 32
? phi::funcs::warpReduceSum<float>(local_sum, 0xffffffff)
: phi::funcs::blockReduceSum<float>(local_sum, 0xffffffff);
if (threadIdx.x == 0) {
s_mean = sum_val + 1e-6f;
s_mean = __fdividef(1.0f, s_mean);
}
__syncthreads();
float local_sum = 0;
for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) {
data[i] = __expf(data[i] - s_max);
local_sum += data[i];
}
float sum_val =
blockDim.x <= 32
? phi::funcs::WarpReduceSum<float>(local_sum, 0xFFFFFFFF)
: phi::funcs::BlockReduceSum<float>(local_sum, 0xFFFFFFFF);
if (threadIdx.x == 0) {
s_mean = sum_val + 1e-6f;
s_mean = __fdividef(1.0f, s_mean);
}
__syncthreads();
for (int i = 0; blockDim.x * i + threadIdx.x < seq_len_2; i++) {
qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) *
seq_len_2 +
blockDim.x * i + threadIdx.x;
qk_buf_[qk_offset] = (T)(data[i] * s_mean);
}
for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) {
qk_offset =
((blockIdx.y + blockIdx.z)) * seq_len + blockDim.x * i + threadIdx.x;
qk_buf_[qk_offset] = (T)(data[i] * s_mean);
}
#endif
}
......@@ -378,77 +289,69 @@ template <typename T, int ITEMS_PER_THREAD>
__global__ void softmax_kernel_v4_half2(T* qk_buf_,
const T* attr_mask,
const int batch_size,
const int head_num,
const int seq_len_1,
const int seq_len_2,
const T scalar) {
const int seq_len) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
using T2 = half2;
T2* qk_buf_half2 = reinterpret_cast<T2*>(qk_buf_);
const T2* attr_mask_half2 = (const T2*)attr_mask;
for (int seq_id = blockIdx.x; seq_id < seq_len_1; seq_id += gridDim.x) {
T2 data[ITEMS_PER_THREAD];
int qk_offset;
__shared__ float s_mean, s_max;
float local_max = -1e20f;
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD;
i++) {
qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) *
(seq_len_2 / 2) +
blockDim.x * i + threadIdx.x;
int mask_offset = (blockIdx.y * seq_len_1 + seq_id) * (seq_len_2 / 2) +
blockDim.x * i + threadIdx.x;
T2 qk = qk_buf_half2[qk_offset];
T2 mask_val = __ldg(&attr_mask_half2[mask_offset]);
mask_val = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val),
__float2half2_rn(-10000.0f));
data[i] = __hadd2(__hmul2(qk, __half2half2(scalar)), mask_val);
local_max = fmax(
local_max,
fmax(static_cast<float>(data[i].x), static_cast<float>(data[i].y)));
}
T2 data[ITEMS_PER_THREAD];
int qk_offset;
__shared__ float s_mean, s_max;
float local_max = -1e20f;
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD;
i++) {
qk_offset = ((blockIdx.y + blockIdx.z)) * (seq_len / 2) + blockDim.x * i +
threadIdx.x;
int mask_offset = blockIdx.y * (seq_len / 2) + blockDim.x * i + threadIdx.x;
T2 qk = qk_buf_half2[qk_offset];
T2 mask_val = __ldg(&attr_mask_half2[mask_offset]);
mask_val = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val),
__float2half2_rn(-10000.0f));
data[i] = __hadd2(qk, mask_val);
local_max = fmax(
local_max,
fmax(static_cast<float>(data[i].x), static_cast<float>(data[i].y)));
}
float max_val =
blockDim.x <= 32
? phi::funcs::warpReduceMax<float>(local_max, 0xFFFFFFFF)
: phi::funcs::blockReduceMax<float>(local_max, 0xFFFFFFFF);
if (threadIdx.x == 0) {
s_max = max_val;
}
__syncthreads();
float local_sum = 0;
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD;
i++) {
data[i] = h2exp(__hsub2(data[i], __float2half2_rn(s_max)));
local_sum += static_cast<float>(data[i].x + data[i].y);
}
float max_val =
blockDim.x <= 32
? phi::funcs::WarpReduceMax<float>(local_max, 0xFFFFFFFF)
: phi::funcs::BlockReduceMax<float>(local_max, 0xFFFFFFFF);
if (threadIdx.x == 0) {
s_max = max_val;
}
__syncthreads();
float sum_val =
blockDim.x <= 32
? phi::funcs::warpReduceSum<float>(local_sum, 0xFFFFFFFF)
: phi::funcs::blockReduceSum<float>(local_sum, 0xFFFFFFFF);
float local_sum = 0;
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD;
i++) {
data[i] = h2exp(__hsub2(data[i], __float2half2_rn(s_max)));
local_sum += static_cast<float>(data[i].x + data[i].y);
}
if (threadIdx.x == 0) {
s_mean = sum_val + 1e-6f;
s_mean = __fdividef(1.0f, s_mean);
}
__syncthreads();
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD;
i++) {
qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) *
(seq_len_2 / 2) +
blockDim.x * i + threadIdx.x;
qk_buf_half2[qk_offset] = __hmul2(data[i], __float2half2_rn(s_mean));
}
float sum_val =
blockDim.x <= 32
? phi::funcs::WarpReduceSum<float>(local_sum, 0xFFFFFFFF)
: phi::funcs::BlockReduceSum<float>(local_sum, 0xFFFFFFFF);
if (threadIdx.x == 0) {
s_mean = sum_val + 1e-6f;
s_mean = __fdividef(1.0f, s_mean);
}
__syncthreads();
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD;
i++) {
qk_offset = ((blockIdx.y + blockIdx.z)) * (seq_len / 2) + blockDim.x * i +
threadIdx.x;
qk_buf_half2[qk_offset] = __hmul2(data[i], __float2half2_rn(s_mean));
}
#endif
}
......@@ -457,131 +360,123 @@ template <typename T, int ITEMS_PER_THREAD, int NUM>
__global__ void softmax_kernel_v5_half2(T* qk_buf_,
const T* attr_mask,
const int batch_size,
const int head_num,
const int seq_len_1,
const int seq_len_2,
const T scalar) {
const int seq_len) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
using T2 = half2;
T2* qk_buf_half2 = reinterpret_cast<T2*>(qk_buf_);
const T2* attr_mask_half2 = (const T2*)attr_mask;
for (int seq_id = blockIdx.x; seq_id < seq_len_1; seq_id += gridDim.x * NUM) {
T2 data[NUM][ITEMS_PER_THREAD];
T2 data[NUM][ITEMS_PER_THREAD];
int qk_offset[NUM];
int qk_offset[NUM];
__shared__ float s_sum[NUM], s_max[NUM];
float local_max[NUM];
__shared__ float s_sum[NUM], s_max[NUM];
float local_max[NUM];
#pragma unroll
for (int j = 0; j < NUM; j++) {
local_max[j] = -1e20f;
}
for (int j = 0; j < NUM; j++) {
local_max[j] = -1e20f;
}
const int MAX_NUM =
min((seq_len_1 - seq_id + gridDim.x - 1) / gridDim.x, NUM);
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD;
i++) {
int mask_offset[NUM];
const int MAX_NUM = min((1 + gridDim.x - 1) / gridDim.x, NUM);
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD;
i++) {
int mask_offset[NUM];
#pragma unroll
for (int j = 0; j < MAX_NUM; j++) {
qk_offset[j] = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 +
seq_id + j * gridDim.x) *
(seq_len_2 / 2) +
for (int j = 0; j < MAX_NUM; j++) {
qk_offset[j] =
((blockIdx.y + blockIdx.z) + j * gridDim.x) * (seq_len / 2) +
blockDim.x * i + threadIdx.x;
mask_offset[j] = (blockIdx.y + j * gridDim.x) * (seq_len / 2) +
blockDim.x * i + threadIdx.x;
mask_offset[j] = (blockIdx.y * seq_len_1 + seq_id + j * gridDim.x) *
(seq_len_2 / 2) +
blockDim.x * i + threadIdx.x;
}
}
T2 mask_val[NUM];
T2 mask_val[NUM];
#pragma unroll
for (int j = 0; j < MAX_NUM; j++) {
mask_val[j] = __ldg(&attr_mask_half2[mask_offset[j]]);
}
for (int j = 0; j < MAX_NUM; j++) {
mask_val[j] = __ldg(&attr_mask_half2[mask_offset[j]]);
}
T2 qk[NUM];
#pragma unroll
for (int j = 0; j < MAX_NUM; j++) {
qk[j] = qk_buf_half2[qk_offset[j]];
}
T2 qk[NUM];
#pragma unroll
for (int j = 0; j < MAX_NUM; j++) {
mask_val[j] = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val[j]),
__float2half2_rn(-10000.0f));
}
#pragma unroll
for (int j = 0; j < MAX_NUM; j++) {
data[j][i] = __hadd2(__hmul2(qk[j], __half2half2(scalar)), mask_val[j]);
local_max[j] = fmax(local_max[j],
fmax(static_cast<float>(data[j][i].x),
static_cast<float>(data[j][i].y)));
}
for (int j = 0; j < MAX_NUM; j++) {
qk[j] = qk_buf_half2[qk_offset[j]];
}
if (blockDim.x <= 32) {
warpReduceMaxV2<float, NUM>(local_max);
} else {
blockReduceMaxV2<float, NUM>(local_max);
#pragma unroll
for (int j = 0; j < MAX_NUM; j++) {
mask_val[j] = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val[j]),
__float2half2_rn(-10000.0f));
}
if (threadIdx.x == 0) {
#pragma unroll
for (int j = 0; j < NUM; j++) {
s_max[j] = local_max[j];
}
for (int j = 0; j < MAX_NUM; j++) {
data[j][i] = __hadd2(qk[j], mask_val[j]);
local_max[j] = fmax(local_max[j],
fmax(static_cast<float>(data[j][i].x),
static_cast<float>(data[j][i].y)));
}
__syncthreads();
float local_sum[NUM];
}
if (blockDim.x <= 32) {
phi::funcs::WarpReduceMaxV2<float, NUM>(local_max);
} else {
phi::funcs::BlockReduceMaxV2<float, NUM>(local_max);
}
if (threadIdx.x == 0) {
#pragma unroll
for (int j = 0; j < NUM; j++) {
local_sum[j] = {0.f};
s_max[j] = local_max[j];
}
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD;
i++) {
}
__syncthreads();
float local_sum[NUM];
#pragma unroll
for (int j = 0; j < MAX_NUM; j++) {
data[j][i] = h2exp(__hsub2(data[j][i], __float2half2_rn(s_max[j])));
}
for (int j = 0; j < NUM; j++) {
local_sum[j] = {0.f};
}
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD;
i++) {
#pragma unroll
for (int j = 0; j < MAX_NUM; j++) {
local_sum[j] += static_cast<float>(data[j][i].x + data[j][i].y);
}
for (int j = 0; j < MAX_NUM; j++) {
data[j][i] = h2exp(__hsub2(data[j][i], __float2half2_rn(s_max[j])));
}
if (blockDim.x <= 32) {
warpReduceSumV2<float, NUM>(local_sum);
} else {
blockReduceSumV2<float, NUM>(local_sum);
#pragma unroll
for (int j = 0; j < MAX_NUM; j++) {
local_sum[j] += static_cast<float>(data[j][i].x + data[j][i].y);
}
}
if (threadIdx.x == 0) {
if (blockDim.x <= 32) {
phi::funcs::WarpReduceSumV2<float, NUM>(local_sum);
} else {
phi::funcs::BlockReduceSumV2<float, NUM>(local_sum);
}
if (threadIdx.x == 0) {
#pragma unroll
for (int j = 0; j < NUM; j++) {
s_sum[j] = __fdividef(1.0f, local_sum[j] + 1e-6f);
}
for (int j = 0; j < NUM; j++) {
s_sum[j] = __fdividef(1.0f, local_sum[j] + 1e-6f);
}
__syncthreads();
}
__syncthreads();
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len_2 / 2) && i < ITEMS_PER_THREAD;
i++) {
for (int i = 0;
blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD;
i++) {
#pragma unroll
for (int j = 0; j < MAX_NUM; j++) {
qk_offset[j] = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 +
seq_id + j * gridDim.x) *
(seq_len_2 / 2) +
blockDim.x * i + threadIdx.x;
}
for (int j = 0; j < MAX_NUM; j++) {
qk_offset[j] =
((blockIdx.y + blockIdx.z) + j * gridDim.x) * (seq_len / 2) +
blockDim.x * i + threadIdx.x;
}
#pragma unroll
for (int j = 0; j < MAX_NUM; j++) {
qk_buf_half2[qk_offset[j]] =
__hmul2(data[j][i], __float2half2_rn(s_sum[j]));
}
for (int j = 0; j < MAX_NUM; j++) {
qk_buf_half2[qk_offset[j]] =
__hmul2(data[j][i], __float2half2_rn(s_sum[j]));
}
}
#endif
......
......@@ -62,7 +62,7 @@ __global__ void ReduceSumWithSubtract(
}
__syncthreads();
sum_val = phi::funcs::blockReduceSum<T>(sum_val, FULL_MASK);
sum_val = phi::funcs::BlockReduceSum<T>(sum_val, FULL_MASK);
if (threadIdx.x == 0) {
out[blockIdx.x] = sum_val;
}
......@@ -80,7 +80,7 @@ __global__ void ReduceMaxWithSubtract(const T* x,
}
__syncthreads();
max_val = phi::funcs::blockReduceMax<T>(max_val, FULL_MASK);
max_val = phi::funcs::BlockReduceMax<T>(max_val, FULL_MASK);
if (threadIdx.x == 0) {
out[blockIdx.x] = max_val;
}
......@@ -98,7 +98,7 @@ __global__ void ReduceMinWithSubtract(const T* x,
}
__syncthreads();
min_val = phi::funcs::blockReduceMin(min_val, FULL_MASK);
min_val = phi::funcs::BlockReduceMin(min_val, FULL_MASK);
if (threadIdx.x == 0) {
out[blockIdx.x] = min_val;
}
......
......@@ -211,7 +211,7 @@ __inline__ __device__ T PartialBlockMin(T val,
if (threadIdx.x < threshold) {
shared_last_idx = (threshold >> 5) - 1;
val = phi::funcs::warpReduceMin(val, mask);
val = phi::funcs::WarpReduceMin(val, mask);
if (lane == 0) {
shared[wid] = val;
}
......@@ -226,7 +226,7 @@ __inline__ __device__ T PartialBlockMin(T val,
if (threadIdx.x < threshold) {
val = (lane <= shared_last_idx) ? shared[lane]
: std::numeric_limits<T>::max();
val = phi::funcs::warpReduceMin(val, mask);
val = phi::funcs::WarpReduceMin(val, mask);
shared_last_val = val;
}
__syncthreads();
......@@ -292,13 +292,13 @@ __global__ void KeBilinearInterpBwShareMemory(T* in,
s_data[1][threadIdx.x] = static_cast<MT>(0);
int remain = nthreads - (tid & (-blockDim.x));
int in_top_max_index =
phi::funcs::blockReduceMax(top_right_index, FINAL_MASK);
phi::funcs::BlockReduceMax(top_right_index, FINAL_MASK);
int in_bot_max_index =
phi::funcs::blockReduceMax(bot_right_index, FINAL_MASK);
phi::funcs::BlockReduceMax(bot_right_index, FINAL_MASK);
if (remain > blockDim.x) {
in_top_min_index = phi::funcs::blockReduceMin(input_index, FINAL_MASK);
in_bot_min_index = phi::funcs::blockReduceMin(bot_left_index, FINAL_MASK);
in_top_min_index = phi::funcs::BlockReduceMin(input_index, FINAL_MASK);
in_bot_min_index = phi::funcs::BlockReduceMin(bot_left_index, FINAL_MASK);
} else {
in_top_min_index = PartialBlockMin(input_index, remain, FINAL_MASK);
in_bot_min_index = PartialBlockMin(bot_left_index, remain, FINAL_MASK);
......
......@@ -47,7 +47,7 @@ __global__ void AttnSoftmaxGpuGradKernel(const int64_t* out_crows,
for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
mul += out_values[row_first + idx] * dout_values[row_first + idx];
}
T mul_sum = phi::funcs::warpReduceSum<T>(mul, 0xFFFFFFFF);
T mul_sum = phi::funcs::WarpReduceSum<T>(mul, 0xFFFFFFFF);
for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
dx_values[row_first + idx] = (dout_values[row_first + idx] - mul_sum) *
......
......@@ -72,7 +72,7 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
out_values[row_first + idx] = -std::numeric_limits<T>::infinity();
}
}
T row_max_val = phi::funcs::warpReduceMax<T>(max_val, 0xFFFFFFFF);
T row_max_val = phi::funcs::WarpReduceMax<T>(max_val, 0xFFFFFFFF);
T exp_sum = 0;
for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
......@@ -81,7 +81,7 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
exp_sum += exp;
out_values[row_first + idx] = exp;
}
T row_exp_sum = phi::funcs::warpReduceSum<T>(exp_sum, 0xFFFFFFFF);
T row_exp_sum = phi::funcs::WarpReduceSum<T>(exp_sum, 0xFFFFFFFF);
for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
out_values[row_first + idx] = out_values[row_first + idx] / row_exp_sum;
......
......@@ -53,7 +53,7 @@ __global__ void SoftmaxGradGpuKernel(const IntT* out_crows,
mul_result += out_values[row_first + idx] * dout_values[row_first + idx];
}
T sum = phi::funcs::warpReduceSum<T>(mul_result, 0xFFFFFFFF);
T sum = phi::funcs::WarpReduceSum<T>(mul_result, 0xFFFFFFFF);
for (int i = 0; i < kIteration; ++i) {
int idx = non_zero_idx + i * warpSize;
......
......@@ -57,7 +57,7 @@ __global__ void SoftmaxGpuKernel(const IntT* x_crows,
max_val = val;
}
}
T row_max_val = phi::funcs::warpReduceMax<T>(max_val, 0xFFFFFFFF);
T row_max_val = phi::funcs::WarpReduceMax<T>(max_val, 0xFFFFFFFF);
T exp_sum = 0;
for (int i = 0; i < kIteration; ++i) {
......@@ -69,7 +69,7 @@ __global__ void SoftmaxGpuKernel(const IntT* x_crows,
exp_sum += exp;
out_values[row_first + idx] = exp;
}
T row_exp_sum = phi::funcs::warpReduceSum<T>(exp_sum, 0xFFFFFFFF);
T row_exp_sum = phi::funcs::WarpReduceSum<T>(exp_sum, 0xFFFFFFFF);
for (int i = 0; i < kIteration; ++i) {
int idx = non_zero_idx + i * warpSize;
......
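
The hunks above also show the refactored MoE gating softmax (softmax_kernel_v4 and its half2 variants), which now operates on gating scores of shape (batch_size, seq_len) with seq_len = num_experts and no head_num or scalar arguments. As a self-contained companion sketch (not the Paddle kernel; MaskedRowSoftmax and the toy data are invented for illustration), the kernel below follows the same pattern under the simplifying assumption that seq_len <= 32, so one warp handles one row, mirroring the WarpReduceMax/WarpReduceSum path taken when blockDim.x <= 32: apply the mask as a large negative bias, reduce the row max, exponentiate, reduce the row sum, and normalize with a 1e-6f guard.

// Standalone CUDA sketch of a warp-per-row masked softmax (illustrative only).
#include <cstdio>
#include <cuda_runtime.h>

#define FULL_MASK 0xffffffffu

__global__ void MaskedRowSoftmax(float* scores, const float* mask,
                                 int batch_size, int seq_len) {
  int row = blockIdx.x;   // one block of 32 threads (one warp) per row
  int col = threadIdx.x;
  if (row >= batch_size) return;

  // Masked-out positions get a large negative bias, as in the kernels above.
  float v = -1e20f;
  if (col < seq_len) {
    float m = mask[row * seq_len + col];
    v = scores[row * seq_len + col] + (1.0f - m) * -10000.0f;
  }

  // Warp-reduce the row maximum for numerical stability.
  float max_val = v;
  for (int offset = 16; offset > 0; offset >>= 1)
    max_val = fmaxf(max_val, __shfl_xor_sync(FULL_MASK, max_val, offset, 32));

  float e = (col < seq_len) ? __expf(v - max_val) : 0.0f;

  // Warp-reduce the row sum, then normalize with a small epsilon guard.
  float sum_val = e;
  for (int offset = 16; offset > 0; offset >>= 1)
    sum_val += __shfl_xor_sync(FULL_MASK, sum_val, offset, 32);

  if (col < seq_len)
    scores[row * seq_len + col] = e / (sum_val + 1e-6f);
}

int main() {
  const int batch_size = 2, seq_len = 4;  // e.g. 2 tokens routed over 4 experts
  float h_scores[batch_size * seq_len] = {1.f, 2.f, 3.f, 4.f,
                                          0.5f, 0.5f, 0.5f, 0.5f};
  float h_mask[batch_size * seq_len] = {1.f, 1.f, 1.f, 0.f,  // last expert masked
                                        1.f, 1.f, 1.f, 1.f};
  float *d_scores = nullptr, *d_mask = nullptr;
  cudaMalloc(&d_scores, sizeof(h_scores));
  cudaMalloc(&d_mask, sizeof(h_mask));
  cudaMemcpy(d_scores, h_scores, sizeof(h_scores), cudaMemcpyHostToDevice);
  cudaMemcpy(d_mask, h_mask, sizeof(h_mask), cudaMemcpyHostToDevice);
  MaskedRowSoftmax<<<batch_size, 32>>>(d_scores, d_mask, batch_size, seq_len);
  cudaMemcpy(h_scores, d_scores, sizeof(h_scores), cudaMemcpyDeviceToHost);
  for (int r = 0; r < batch_size; ++r) {
    for (int c = 0; c < seq_len; ++c) printf("%.4f ", h_scores[r * seq_len + c]);
    printf("\n");
  }
  cudaFree(d_scores);
  cudaFree(d_mask);
  return 0;
}

For rows longer than a warp, the real kernels fall back to the block-level BlockReduceMax/BlockReduceSum helpers, exactly as shown in the diff.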