Unverified commit afb13484, authored by: Z zhaoyuchen2018, committed by: GitHub

Fix ernie python infer diff (#21311)

* Fix ernie python infer diff
* Refine mask

test=develop
Parent b6ce4f8b
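Why the mask matters (annotation, not part of the commit): with CUDA 9+ the warp shuffles become __shfl_xor_sync, whose first argument must name exactly the lanes that are resident in the warp. The softmax kernel below is launched with block = k threads, so when k < 32, passing FINAL_MASK (0xffffffff) names lanes that were never launched — undefined behavior, which the commit message ties to the ERNIE Python inference diff. A minimal sketch of the mask rule the commit adopts (ActiveLaneMask is a hypothetical helper name; FINAL_MASK mirrors the macro in the real file):

// Hypothetical helper illustrating the commit's mask rule.
#define FINAL_MASK 0xffffffffu

__host__ __device__ inline unsigned ActiveLaneMask(int block) {
  // Full warp: all 32 mask bits set. Partial warp: one bit per resident
  // lane, e.g. block == 20 gives (1u << 20) - 1u == 0x000fffff.
  return block < 32 ? (1u << block) - 1u : FINAL_MASK;
}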
@@ -28,10 +28,10 @@ namespace operators {
 #define WARP_SIZE 32

 template <typename T>
-__inline__ __device__ T warpReduceSum(T val) {
+__inline__ __device__ T warpReduceSum(T val, unsigned lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
 #if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
-    val += __shfl_xor_sync(FINAL_MASK, val, mask, warpSize);
+    val += __shfl_xor_sync(lane_mask, val, mask, warpSize);
 #else
     val += __shfl_xor(val, mask, warpSize);
 #endif
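For orientation: warpReduceSum is the standard XOR butterfly. At offsets 16, 8, 4, 2, 1 each lane i accumulates the value from partner lane i ^ offset, so after log2(warpSize) = 5 steps every lane of a full warp holds the total. The new lane_mask parameter only tells the *_sync variant which lanes are resident; the reduction arithmetic is unchanged.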
@@ -40,28 +40,30 @@ __inline__ __device__ T warpReduceSum(T val) {
 /* Calculate the sum of all elements in a block */
 template <typename T>
-__inline__ __device__ T blockReduceSum(T val) {
+__inline__ __device__ T blockReduceSum(T val, unsigned mask) {
   static __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;

-  val = warpReduceSum<T>(val);
+  val = warpReduceSum<T>(val, mask);

   if (lane == 0) shared[wid] = val;

   __syncthreads();

-  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)(0.0f);
-  val = warpReduceSum<T>(val);
+  // align block_span to warpSize
+  int block_span = (blockDim.x + warpSize - 1) >> 5;
+  val = (threadIdx.x < block_span) ? shared[lane] : (T)(0.0f);
+  val = warpReduceSum<T>(val, mask);

   return val;
 }

 template <typename T>
-__inline__ __device__ T warpReduceMax(T val) {
+__inline__ __device__ T warpReduceMax(T val, unsigned lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
 #if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
-    val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, warpSize));
+    val = max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
 #else
     val = max(val, __shfl_xor(val, mask, warpSize));
 #endif
@@ -70,19 +72,21 @@ __inline__ __device__ T warpReduceMax(T val) {
 /* Calculate the maximum of all elements in a block */
 template <typename T>
-__inline__ __device__ T blockReduceMax(T val) {
+__inline__ __device__ T blockReduceMax(T val, unsigned mask) {
   static __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;

-  val = warpReduceMax(val);
+  val = warpReduceMax(val, mask);

   if (lane == 0) shared[wid] = val;

   __syncthreads();

-  val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : -1e10f;
-  val = warpReduceMax(val);
+  // align block_span to warpSize
+  int block_span = (blockDim.x + warpSize - 1) >> 5;
+  val = (threadIdx.x < block_span) ? shared[lane] : -1e10f;
+  val = warpReduceMax(val, mask);

   return val;
 }
@@ -190,7 +194,8 @@ template <typename T>
 __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
                                            const int batch_size,
                                            const int head_num,
-                                           const int seq_len) {
+                                           const int seq_len,
+                                           const unsigned mask) {
   int seq_id = blockIdx.x % seq_len;
   int qk_offset = blockIdx.x * seq_len;
   int bias_offset = blockIdx.x % (head_num * seq_len) * seq_len;
@@ -202,13 +207,15 @@ __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
                            bias_qk_[threadIdx.x + bias_offset]))
                     : 0.0f;
   float tmp = threadIdx.x < seq_len ? static_cast<float>(qk) : -1e20f;
-  float max_val = blockReduceMax<float>(tmp);
+  float max_val = blockReduceMax<float>(tmp, mask);

   if (threadIdx.x == 0) s_max = max_val;
   __syncthreads();

   float qk_tmp =
       threadIdx.x < seq_len ? __expf(static_cast<float>(tmp - s_max)) : 0.0f;
-  float sum_val = blockReduceSum<float>(qk_tmp);
+  float sum_val = blockReduceSum<float>(qk_tmp, mask);

   if (threadIdx.x == 0) {
     s_sum = sum_val + 1e-6f;
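For reference, the kernel computes a numerically stable softmax over each row of seq_len scores: with x_i = qk_i + bias_i and M = max_j x_j, it evaluates exp(x_i - M) / (Σ_k exp(x_k - M) + 1e-6). Subtracting M prevents overflow in __expf, and the 1e-6 guards against a zero denominator. Both reductions now receive the launch-time mask, so rows shorter than a warp reduce over exactly the resident lanes.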
@@ -258,8 +265,9 @@ void MatMulWithHeadQK(const platform::CUDADeviceContext &context, int head_num,
   int grid = m;
   int block = k;

+  unsigned mask = block < 32 ? (((unsigned)1 << block) - 1) : FINAL_MASK;
   softmax_kernel_with_eltadd<T><<<grid, block, 0, stream>>>(
-      qk_buf_, bias_qk, batch_size, head_num, seq_len);
+      qk_buf_, bias_qk, batch_size, head_num, seq_len, mask);
 }

 template <typename T>
...
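A usage-level sketch of the launch pattern the last hunk establishes (annotation, not from the commit; the sizes are made up, and the grid expression is inferred from the kernel's blockIdx.x indexing — one block per seq_len-wide score row):

// Hypothetical launch; seq_len here is deliberately < 32.
int seq_len = 20;
int grid = batch_size * head_num * seq_len;  // one block per score row
int block = seq_len;                         // 20 threads: a partial warp
unsigned mask = block < 32 ? (((unsigned)1 << block) - 1)  // 0x000fffff
                           : FINAL_MASK;                   // full warp
softmax_kernel_with_eltadd<float><<<grid, block, 0, stream>>>(
    qk_buf, bias_qk, batch_size, head_num, seq_len, mask);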