未验证 提交 a8148ce0 作者: niuliling123 提交者: GitHub

Add mfence for XPU2 KP (#44258)

上级 fea05f1f
...@@ -320,6 +320,7 @@ __device__ __forceinline__ void WriteData(T _global_ptr_* dst, ...@@ -320,6 +320,7 @@ __device__ __forceinline__ void WriteData(T _global_ptr_* dst,
T* src, T* src,
int num) { int num) {
if (num > 0) { if (num > 0) {
mfence_local();
LM2GM(src, dst, num * sizeof(T)); LM2GM(src, dst, num * sizeof(T));
} }
} }
...@@ -387,6 +388,7 @@ __device__ __inline__ void ReadData(Ty* dst, ...@@ -387,6 +388,7 @@ __device__ __inline__ void ReadData(Ty* dst,
break; break;
} }
} }
mfence_local();
GM2LM(src + thread_offset + idy * stride_ny, in_temp, sizeof(Tx)); GM2LM(src + thread_offset + idy * stride_ny, in_temp, sizeof(Tx));
dst[idy] = static_cast<Ty>(in_temp[0]); dst[idy] = static_cast<Ty>(in_temp[0]);
} }
...@@ -398,6 +400,7 @@ __device__ __inline__ void ReadData(Ty* dst, ...@@ -398,6 +400,7 @@ __device__ __inline__ void ReadData(Ty* dst,
break; break;
} }
} }
mfence_local();
GM2LM(src + thread_offset + idx * stride_nx, in_temp, sizeof(Tx)); GM2LM(src + thread_offset + idx * stride_nx, in_temp, sizeof(Tx));
dst[idx] = static_cast<Ty>(in_temp[0]); dst[idx] = static_cast<Ty>(in_temp[0]);
} }
...@@ -412,6 +415,7 @@ __device__ __inline__ void ReadData(Ty* dst, ...@@ -412,6 +415,7 @@ __device__ __inline__ void ReadData(Ty* dst,
} }
} }
int fix = thread_offset + idx * stride_nx + idy * stride_ny; int fix = thread_offset + idx * stride_nx + idy * stride_ny;
mfence_local();
GM2LM(src + fix, in_temp, sizeof(Tx)); GM2LM(src + fix, in_temp, sizeof(Tx));
dst[idy * NX + idx] = static_cast<Ty>(in_temp[0]); dst[idy * NX + idx] = static_cast<Ty>(in_temp[0]);
} }
...@@ -484,14 +488,13 @@ template <typename T, int NX, int NY, int BlockSize, bool IsBoundary> ...@@ -484,14 +488,13 @@ template <typename T, int NX, int NY, int BlockSize, bool IsBoundary>
__device__ __inline__ void ReadData(T* dst, __device__ __inline__ void ReadData(T* dst,
const T _global_ptr_* src, const T _global_ptr_* src,
int num) { int num) {
mfence_local();
int thread_offset = core_id() * NX; int thread_offset = core_id() * NX;
__local__ T in_temp[1];
if (IsBoundary) { // core_num() * NX > num if (IsBoundary) { // core_num() * NX > num
#pragma unroll #pragma unroll
for (int idx = 0; idx < NX; ++idx) { for (int idx = 0; idx < NX; ++idx) {
if (idx + thread_offset < num) { if (idx + thread_offset < num) {
GM2LM(src + thread_offset + idx, in_temp, sizeof(T)); GM2LM(src + thread_offset + idx, dst + idx, sizeof(T));
dst[idx] = in_temp[0];
} }
} }
} else { // core_num() * NX < num } else { // core_num() * NX < num
...@@ -505,13 +508,12 @@ __device__ __inline__ void ReadData(T* dst, ...@@ -505,13 +508,12 @@ __device__ __inline__ void ReadData(T* dst,
int num, int num,
int read_lens) { int read_lens) {
int thread_offset = core_id() * read_lens; int thread_offset = core_id() * read_lens;
__local__ T in_temp[1]; mfence_local();
if (IsBoundary) { // core_num() * read_lens > num if (IsBoundary) { // core_num() * read_lens > num
#pragma unroll #pragma unroll
for (int idx = 0; idx < read_lens; ++idx) { for (int idx = 0; idx < read_lens; ++idx) {
if (idx + thread_offset < num) { if (idx + thread_offset < num) {
GM2LM(src + thread_offset + idx, in_temp, sizeof(T)); GM2LM(src + thread_offset + idx, dst + idx, sizeof(T));
dst[idx] = in_temp[0];
} }
} }
} else { // core_num() * read_lens < num } else { // core_num() * read_lens < num
...@@ -607,8 +609,7 @@ __device__ __inline__ void ReadDataBc(T* dst, ...@@ -607,8 +609,7 @@ __device__ __inline__ void ReadDataBc(T* dst,
int stride_ny) { int stride_ny) {
uint32_t thread_offset = block_offset + core_id(); uint32_t thread_offset = block_offset + core_id();
uint32_t index_src = 0; uint32_t index_src = 0;
__local__ T in_temp[1]; mfence_local();
#pragma unroll #pragma unroll
for (int ny = 0; ny < NY; ++ny) { for (int ny = 0; ny < NY; ++ny) {
#pragma unroll #pragma unroll
...@@ -621,8 +622,7 @@ __device__ __inline__ void ReadDataBc(T* dst, ...@@ -621,8 +622,7 @@ __device__ __inline__ void ReadDataBc(T* dst,
} }
} }
index_src = config(index_output); index_src = config(index_output);
GM2LM(src + index_src, in_temp, sizeof(T)); GM2LM(src + index_src, dst + nx + ny * NX, sizeof(T));
dst[nx + ny * NX] = in_temp[0];
} }
} }
} }
...@@ -698,8 +698,10 @@ __device__ __forceinline__ void ReadDataReduce( ...@@ -698,8 +698,10 @@ __device__ __forceinline__ void ReadDataReduce(
} }
} }
uint32_t index_src = index_cal(thread_offset + block_offset); uint32_t index_src = index_cal(thread_offset + block_offset);
mfence_local();
GM2LM(src + index_src, in_temp, sizeof(Tx)); GM2LM(src + index_src, in_temp, sizeof(Tx));
dst[ny] = static_cast<Ty>(func(in_temp[0])); dst[ny] = static_cast<Ty>(func(in_temp[0]));
thread_offset += stride_ny; thread_offset += stride_ny;
} }
} else { } else {
...@@ -714,6 +716,7 @@ __device__ __forceinline__ void ReadDataReduce( ...@@ -714,6 +716,7 @@ __device__ __forceinline__ void ReadDataReduce(
} }
} }
uint32_t index_src = index_cal(thread_offset + block_offset); uint32_t index_src = index_cal(thread_offset + block_offset);
mfence_local();
GM2LM(src + index_src, in_temp, sizeof(Tx)); GM2LM(src + index_src, in_temp, sizeof(Tx));
dst[nx + ny * NX] = static_cast<Ty>(func(in_temp[0])); dst[nx + ny * NX] = static_cast<Ty>(func(in_temp[0]));
thread_offset += stride_ny; thread_offset += stride_ny;
...@@ -749,19 +752,16 @@ __device__ void WriteData(T _global_ptr_* dst, ...@@ -749,19 +752,16 @@ __device__ void WriteData(T _global_ptr_* dst,
int num, int num,
int read_lens) { int read_lens) {
int thread_offset = core_id() * read_lens; int thread_offset = core_id() * read_lens;
__local__ T in_temp[1]; mfence_local();
if (IsBoundary) { // core_num() * read_lens > num if (IsBoundary) { // core_num() * read_lens > num
#pragma unroll #pragma unroll
for (int idx = 0; idx < read_lens; ++idx) { for (int idx = 0; idx < read_lens; ++idx) {
if (idx + thread_offset < num) { if (idx + thread_offset < num) {
in_temp[0] = src[idx]; LM2GM(src + idx, dst + idx + thread_offset, sizeof(T));
mfence();
LM2GM(in_temp, dst + idx + thread_offset, sizeof(T));
} }
} }
} else { // core_num() * read_lens < num } else { // core_num() * read_lens < num
mfence();
LM2GM(src, dst + thread_offset, read_lens * sizeof(T)); LM2GM(src, dst + thread_offset, read_lens * sizeof(T));
} }
} }
...@@ -769,17 +769,17 @@ __device__ void WriteData(T _global_ptr_* dst, ...@@ -769,17 +769,17 @@ __device__ void WriteData(T _global_ptr_* dst,
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary> template <typename T, int NX, int NY, int BlockSize, bool IsBoundary>
__device__ void WriteData(T _global_ptr_* dst, const T* src, int num) { __device__ void WriteData(T _global_ptr_* dst, const T* src, int num) {
int thread_offset = core_id() * NX; int thread_offset = core_id() * NX;
__local__ T in_temp[1]; mfence_local();
if (IsBoundary) { // core_num() * NX > num if (IsBoundary) { // core_num() * NX > num
#pragma unroll #pragma unroll
for (int idx = 0; idx < NX; ++idx) { for (int idx = 0; idx < NX; ++idx) {
if (idx + thread_offset < num) { if (idx + thread_offset < num) {
in_temp[0] = src[idx]; LM2GM(src + idx, dst + idx + thread_offset, sizeof(T));
LM2GM(in_temp, dst + idx + thread_offset, sizeof(T));
} }
} }
} else { // core_num() * NX < num } else { // core_num() * NX < num
mfence_local();
LM2GM(src, dst + thread_offset, NX * sizeof(T)); LM2GM(src, dst + thread_offset, NX * sizeof(T));
} }
} }
...@@ -831,10 +831,12 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst, ...@@ -831,10 +831,12 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst,
if (IsBoundary) { if (IsBoundary) {
if (left_size_nx > 0) { if (left_size_nx > 0) {
in_temp[0] = static_cast<Ty>(src[0]); in_temp[0] = static_cast<Ty>(src[0]);
mfence_local();
LM2GM(in_temp, dst + thread_offset, sizeof(Ty)); LM2GM(in_temp, dst + thread_offset, sizeof(Ty));
} }
} else { } else {
in_temp[0] = static_cast<Ty>(src[0]); in_temp[0] = static_cast<Ty>(src[0]);
mfence_local();
LM2GM(in_temp, dst + thread_offset, sizeof(Ty)); LM2GM(in_temp, dst + thread_offset, sizeof(Ty));
} }
} else if (NX == 1) { } else if (NX == 1) {
...@@ -847,6 +849,7 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst, ...@@ -847,6 +849,7 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst,
} }
in_temp[0] = static_cast<Ty>(src[idy]); in_temp[0] = static_cast<Ty>(src[idy]);
mfence_local();
LM2GM(in_temp, dst + thread_offset + idy * stride_ny, sizeof(Ty)); LM2GM(in_temp, dst + thread_offset + idy * stride_ny, sizeof(Ty));
} }
} else if (NY == 1) { // for NY == 1 and NX != 1 } else if (NY == 1) { // for NY == 1 and NX != 1
...@@ -859,6 +862,7 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst, ...@@ -859,6 +862,7 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst,
} }
in_temp[0] = static_cast<Ty>(src[idx]); in_temp[0] = static_cast<Ty>(src[idx]);
mfence_local();
LM2GM(in_temp, dst + thread_offset + idx * stride_nx, sizeof(Ty)); LM2GM(in_temp, dst + thread_offset + idx * stride_nx, sizeof(Ty));
} }
} else { // for NX != 1 and NY != 1 } else { // for NX != 1 and NY != 1
...@@ -877,6 +881,7 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst, ...@@ -877,6 +881,7 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst,
} }
} }
in_temp[0] = static_cast<Ty>(src[idx + idy * NX]); in_temp[0] = static_cast<Ty>(src[idx + idy * NX]);
mfence_local();
LM2GM(in_temp, LM2GM(in_temp,
dst + thread_offset + idx * stride_nx + idy * stride_ny, dst + thread_offset + idx * stride_nx + idy * stride_ny,
sizeof(Ty)); sizeof(Ty));
...@@ -1029,6 +1034,7 @@ __device__ __inline__ void ReadDataBc1NMn( ...@@ -1029,6 +1034,7 @@ __device__ __inline__ void ReadDataBc1NMn(
for (int i = 0; i < last_col; i++) { for (int i = 0; i < last_col; i++) {
dst[i] = in_temp; dst[i] = in_temp;
} }
mfence_local();
GM2LM(src + index_base + 1, &in_temp, sizeof(T)); GM2LM(src + index_base + 1, &in_temp, sizeof(T));
for (int i = 0; i < read_lens - last_col; i++) { for (int i = 0; i < read_lens - last_col; i++) {
dst[last_col + i] = in_temp; dst[last_col + i] = in_temp;
...@@ -1083,6 +1089,7 @@ __device__ __inline__ void ReadDataBc1N1Mnk( ...@@ -1083,6 +1089,7 @@ __device__ __inline__ void ReadDataBc1N1Mnk(
} else { } else {
next_part_index = 0; next_part_index = 0;
} }
mfence_local();
GM2LM(src + next_part_index, &in_temp, sizeof(T)); GM2LM(src + next_part_index, &in_temp, sizeof(T));
for (int i = 0; i < read_lens - last_col; i++) { for (int i = 0; i < read_lens - last_col; i++) {
dst[last_col + i] = in_temp; dst[last_col + i] = in_temp;
...@@ -1169,6 +1176,7 @@ __device__ __inline__ void ReadDataBcCanNotCmp( ...@@ -1169,6 +1176,7 @@ __device__ __inline__ void ReadDataBcCanNotCmp(
if (index_src >= index_base && index_src < index_base + cache_size) { if (index_src >= index_base && index_src < index_base + cache_size) {
in_temp = src_temp[index_src - index_base]; in_temp = src_temp[index_src - index_base];
} else { } else {
mfence_local();
GM2LM(src + index_src, &in_temp, sizeof(T)); GM2LM(src + index_src, &in_temp, sizeof(T));
} }
dst[nx] = in_temp; dst[nx] = in_temp;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册