From a8148ce0385ef4b3fdcf97681bdb8cc98b8432e5 Mon Sep 17 00:00:00 2001 From: niuliling123 <51102941+niuliling123@users.noreply.github.com> Date: Tue, 19 Jul 2022 14:31:15 +0800 Subject: [PATCH] Add mfence for XPU2 KP (#44258) --- .../primitive/datamover_primitives_xpu2.h | 44 +++++++++++-------- 1 file changed, 26 insertions(+), 18 deletions(-) mode change 100755 => 100644 paddle/phi/kernels/primitive/datamover_primitives_xpu2.h diff --git a/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h b/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h old mode 100755 new mode 100644 index 68eb11bd6d..2915463f5f --- a/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h +++ b/paddle/phi/kernels/primitive/datamover_primitives_xpu2.h @@ -320,6 +320,7 @@ __device__ __forceinline__ void WriteData(T _global_ptr_* dst, T* src, int num) { if (num > 0) { + mfence_local(); LM2GM(src, dst, num * sizeof(T)); } } @@ -387,6 +388,7 @@ __device__ __inline__ void ReadData(Ty* dst, break; } } + mfence_local(); GM2LM(src + thread_offset + idy * stride_ny, in_temp, sizeof(Tx)); dst[idy] = static_cast(in_temp[0]); } @@ -398,6 +400,7 @@ __device__ __inline__ void ReadData(Ty* dst, break; } } + mfence_local(); GM2LM(src + thread_offset + idx * stride_nx, in_temp, sizeof(Tx)); dst[idx] = static_cast(in_temp[0]); } @@ -412,6 +415,7 @@ __device__ __inline__ void ReadData(Ty* dst, } } int fix = thread_offset + idx * stride_nx + idy * stride_ny; + mfence_local(); GM2LM(src + fix, in_temp, sizeof(Tx)); dst[idy * NX + idx] = static_cast(in_temp[0]); } @@ -484,14 +488,13 @@ template __device__ __inline__ void ReadData(T* dst, const T _global_ptr_* src, int num) { + mfence_local(); int thread_offset = core_id() * NX; - __local__ T in_temp[1]; if (IsBoundary) { // core_num() * NX > num #pragma unroll for (int idx = 0; idx < NX; ++idx) { if (idx + thread_offset < num) { - GM2LM(src + thread_offset + idx, in_temp, sizeof(T)); - dst[idx] = in_temp[0]; + GM2LM(src + thread_offset + idx, dst + idx, sizeof(T)); } } } else { // core_num() * NX < num @@ -505,13 +508,12 @@ __device__ __inline__ void ReadData(T* dst, int num, int read_lens) { int thread_offset = core_id() * read_lens; - __local__ T in_temp[1]; + mfence_local(); if (IsBoundary) { // core_num() * read_lens > num #pragma unroll for (int idx = 0; idx < read_lens; ++idx) { if (idx + thread_offset < num) { - GM2LM(src + thread_offset + idx, in_temp, sizeof(T)); - dst[idx] = in_temp[0]; + GM2LM(src + thread_offset + idx, dst + idx, sizeof(T)); } } } else { // core_num() * read_lens < num @@ -607,8 +609,7 @@ __device__ __inline__ void ReadDataBc(T* dst, int stride_ny) { uint32_t thread_offset = block_offset + core_id(); uint32_t index_src = 0; - __local__ T in_temp[1]; - + mfence_local(); #pragma unroll for (int ny = 0; ny < NY; ++ny) { #pragma unroll @@ -621,8 +622,7 @@ __device__ __inline__ void ReadDataBc(T* dst, } } index_src = config(index_output); - GM2LM(src + index_src, in_temp, sizeof(T)); - dst[nx + ny * NX] = in_temp[0]; + GM2LM(src + index_src, dst + nx + ny * NX, sizeof(T)); } } } @@ -698,8 +698,10 @@ __device__ __forceinline__ void ReadDataReduce( } } uint32_t index_src = index_cal(thread_offset + block_offset); + mfence_local(); GM2LM(src + index_src, in_temp, sizeof(Tx)); dst[ny] = static_cast(func(in_temp[0])); + thread_offset += stride_ny; } } else { @@ -714,6 +716,7 @@ __device__ __forceinline__ void ReadDataReduce( } } uint32_t index_src = index_cal(thread_offset + block_offset); + mfence_local(); GM2LM(src + index_src, in_temp, sizeof(Tx)); dst[nx + ny * NX] = static_cast(func(in_temp[0])); thread_offset += stride_ny; @@ -749,19 +752,16 @@ __device__ void WriteData(T _global_ptr_* dst, int num, int read_lens) { int thread_offset = core_id() * read_lens; - __local__ T in_temp[1]; + mfence_local(); if (IsBoundary) { // core_num() * read_lens > num #pragma unroll for (int idx = 0; idx < read_lens; ++idx) { if (idx + thread_offset < num) { - in_temp[0] = src[idx]; - mfence(); - LM2GM(in_temp, dst + idx + thread_offset, sizeof(T)); + LM2GM(src + idx, dst + idx + thread_offset, sizeof(T)); } } } else { // core_num() * read_lens < num - mfence(); LM2GM(src, dst + thread_offset, read_lens * sizeof(T)); } } @@ -769,17 +769,17 @@ __device__ void WriteData(T _global_ptr_* dst, template __device__ void WriteData(T _global_ptr_* dst, const T* src, int num) { int thread_offset = core_id() * NX; - __local__ T in_temp[1]; + mfence_local(); if (IsBoundary) { // core_num() * NX > num #pragma unroll for (int idx = 0; idx < NX; ++idx) { if (idx + thread_offset < num) { - in_temp[0] = src[idx]; - LM2GM(in_temp, dst + idx + thread_offset, sizeof(T)); + LM2GM(src + idx, dst + idx + thread_offset, sizeof(T)); } } } else { // core_num() * NX < num + mfence_local(); LM2GM(src, dst + thread_offset, NX * sizeof(T)); } } @@ -831,10 +831,12 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst, if (IsBoundary) { if (left_size_nx > 0) { in_temp[0] = static_cast(src[0]); + mfence_local(); LM2GM(in_temp, dst + thread_offset, sizeof(Ty)); } } else { in_temp[0] = static_cast(src[0]); + mfence_local(); LM2GM(in_temp, dst + thread_offset, sizeof(Ty)); } } else if (NX == 1) { @@ -847,6 +849,7 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst, } in_temp[0] = static_cast(src[idy]); + mfence_local(); LM2GM(in_temp, dst + thread_offset + idy * stride_ny, sizeof(Ty)); } } else if (NY == 1) { // for NY == 1 and NX != 1 @@ -859,6 +862,7 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst, } in_temp[0] = static_cast(src[idx]); + mfence_local(); LM2GM(in_temp, dst + thread_offset + idx * stride_nx, sizeof(Ty)); } } else { // for NX != 1 and NY != 1 @@ -877,6 +881,7 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst, } } in_temp[0] = static_cast(src[idx + idy * NX]); + mfence_local(); LM2GM(in_temp, dst + thread_offset + idx * stride_nx + idy * stride_ny, sizeof(Ty)); @@ -1029,6 +1034,7 @@ __device__ __inline__ void ReadDataBc1NMn( for (int i = 0; i < last_col; i++) { dst[i] = in_temp; } + mfence_local(); GM2LM(src + index_base + 1, &in_temp, sizeof(T)); for (int i = 0; i < read_lens - last_col; i++) { dst[last_col + i] = in_temp; @@ -1083,6 +1089,7 @@ __device__ __inline__ void ReadDataBc1N1Mnk( } else { next_part_index = 0; } + mfence_local(); GM2LM(src + next_part_index, &in_temp, sizeof(T)); for (int i = 0; i < read_lens - last_col; i++) { dst[last_col + i] = in_temp; @@ -1169,6 +1176,7 @@ __device__ __inline__ void ReadDataBcCanNotCmp( if (index_src >= index_base && index_src < index_base + cache_size) { in_temp = src_temp[index_src - index_base]; } else { + mfence_local(); GM2LM(src + index_src, &in_temp, sizeof(T)); } dst[nx] = in_temp; -- GitLab