未验证 提交 a8148ce0 编写于 作者: N niuliling123 提交者: GitHub

Add mfence for XPU2 KP (#44258)

上级 fea05f1f
......@@ -320,6 +320,7 @@ __device__ __forceinline__ void WriteData(T _global_ptr_* dst,
T* src,
int num) {
if (num > 0) {
mfence_local();
LM2GM(src, dst, num * sizeof(T));
}
}
......@@ -387,6 +388,7 @@ __device__ __inline__ void ReadData(Ty* dst,
break;
}
}
mfence_local();
GM2LM(src + thread_offset + idy * stride_ny, in_temp, sizeof(Tx));
dst[idy] = static_cast<Ty>(in_temp[0]);
}
......@@ -398,6 +400,7 @@ __device__ __inline__ void ReadData(Ty* dst,
break;
}
}
mfence_local();
GM2LM(src + thread_offset + idx * stride_nx, in_temp, sizeof(Tx));
dst[idx] = static_cast<Ty>(in_temp[0]);
}
......@@ -412,6 +415,7 @@ __device__ __inline__ void ReadData(Ty* dst,
}
}
int fix = thread_offset + idx * stride_nx + idy * stride_ny;
mfence_local();
GM2LM(src + fix, in_temp, sizeof(Tx));
dst[idy * NX + idx] = static_cast<Ty>(in_temp[0]);
}
......@@ -484,14 +488,13 @@ template <typename T, int NX, int NY, int BlockSize, bool IsBoundary>
__device__ __inline__ void ReadData(T* dst,
const T _global_ptr_* src,
int num) {
mfence_local();
int thread_offset = core_id() * NX;
__local__ T in_temp[1];
if (IsBoundary) { // core_num() * NX > num
#pragma unroll
for (int idx = 0; idx < NX; ++idx) {
if (idx + thread_offset < num) {
GM2LM(src + thread_offset + idx, in_temp, sizeof(T));
dst[idx] = in_temp[0];
GM2LM(src + thread_offset + idx, dst + idx, sizeof(T));
}
}
} else { // core_num() * NX < num
......@@ -505,13 +508,12 @@ __device__ __inline__ void ReadData(T* dst,
int num,
int read_lens) {
int thread_offset = core_id() * read_lens;
__local__ T in_temp[1];
mfence_local();
if (IsBoundary) { // core_num() * read_lens > num
#pragma unroll
for (int idx = 0; idx < read_lens; ++idx) {
if (idx + thread_offset < num) {
GM2LM(src + thread_offset + idx, in_temp, sizeof(T));
dst[idx] = in_temp[0];
GM2LM(src + thread_offset + idx, dst + idx, sizeof(T));
}
}
} else { // core_num() * read_lens < num
......@@ -607,8 +609,7 @@ __device__ __inline__ void ReadDataBc(T* dst,
int stride_ny) {
uint32_t thread_offset = block_offset + core_id();
uint32_t index_src = 0;
__local__ T in_temp[1];
mfence_local();
#pragma unroll
for (int ny = 0; ny < NY; ++ny) {
#pragma unroll
......@@ -621,8 +622,7 @@ __device__ __inline__ void ReadDataBc(T* dst,
}
}
index_src = config(index_output);
GM2LM(src + index_src, in_temp, sizeof(T));
dst[nx + ny * NX] = in_temp[0];
GM2LM(src + index_src, dst + nx + ny * NX, sizeof(T));
}
}
}
......@@ -698,8 +698,10 @@ __device__ __forceinline__ void ReadDataReduce(
}
}
uint32_t index_src = index_cal(thread_offset + block_offset);
mfence_local();
GM2LM(src + index_src, in_temp, sizeof(Tx));
dst[ny] = static_cast<Ty>(func(in_temp[0]));
thread_offset += stride_ny;
}
} else {
......@@ -714,6 +716,7 @@ __device__ __forceinline__ void ReadDataReduce(
}
}
uint32_t index_src = index_cal(thread_offset + block_offset);
mfence_local();
GM2LM(src + index_src, in_temp, sizeof(Tx));
dst[nx + ny * NX] = static_cast<Ty>(func(in_temp[0]));
thread_offset += stride_ny;
......@@ -749,19 +752,16 @@ __device__ void WriteData(T _global_ptr_* dst,
int num,
int read_lens) {
int thread_offset = core_id() * read_lens;
__local__ T in_temp[1];
mfence_local();
if (IsBoundary) { // core_num() * read_lens > num
#pragma unroll
for (int idx = 0; idx < read_lens; ++idx) {
if (idx + thread_offset < num) {
in_temp[0] = src[idx];
mfence();
LM2GM(in_temp, dst + idx + thread_offset, sizeof(T));
LM2GM(src + idx, dst + idx + thread_offset, sizeof(T));
}
}
} else { // core_num() * read_lens < num
mfence();
LM2GM(src, dst + thread_offset, read_lens * sizeof(T));
}
}
......@@ -769,17 +769,17 @@ __device__ void WriteData(T _global_ptr_* dst,
template <typename T, int NX, int NY, int BlockSize, bool IsBoundary>
__device__ void WriteData(T _global_ptr_* dst, const T* src, int num) {
int thread_offset = core_id() * NX;
__local__ T in_temp[1];
mfence_local();
if (IsBoundary) { // core_num() * NX > num
#pragma unroll
for (int idx = 0; idx < NX; ++idx) {
if (idx + thread_offset < num) {
in_temp[0] = src[idx];
LM2GM(in_temp, dst + idx + thread_offset, sizeof(T));
LM2GM(src + idx, dst + idx + thread_offset, sizeof(T));
}
}
} else { // core_num() * NX < num
mfence_local();
LM2GM(src, dst + thread_offset, NX * sizeof(T));
}
}
......@@ -831,10 +831,12 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst,
if (IsBoundary) {
if (left_size_nx > 0) {
in_temp[0] = static_cast<Ty>(src[0]);
mfence_local();
LM2GM(in_temp, dst + thread_offset, sizeof(Ty));
}
} else {
in_temp[0] = static_cast<Ty>(src[0]);
mfence_local();
LM2GM(in_temp, dst + thread_offset, sizeof(Ty));
}
} else if (NX == 1) {
......@@ -847,6 +849,7 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst,
}
in_temp[0] = static_cast<Ty>(src[idy]);
mfence_local();
LM2GM(in_temp, dst + thread_offset + idy * stride_ny, sizeof(Ty));
}
} else if (NY == 1) { // for NY == 1 and NX != 1
......@@ -859,6 +862,7 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst,
}
in_temp[0] = static_cast<Ty>(src[idx]);
mfence_local();
LM2GM(in_temp, dst + thread_offset + idx * stride_nx, sizeof(Ty));
}
} else { // for NX != 1 and NY != 1
......@@ -877,6 +881,7 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst,
}
}
in_temp[0] = static_cast<Ty>(src[idx + idy * NX]);
mfence_local();
LM2GM(in_temp,
dst + thread_offset + idx * stride_nx + idy * stride_ny,
sizeof(Ty));
......@@ -1029,6 +1034,7 @@ __device__ __inline__ void ReadDataBc1NMn(
for (int i = 0; i < last_col; i++) {
dst[i] = in_temp;
}
mfence_local();
GM2LM(src + index_base + 1, &in_temp, sizeof(T));
for (int i = 0; i < read_lens - last_col; i++) {
dst[last_col + i] = in_temp;
......@@ -1083,6 +1089,7 @@ __device__ __inline__ void ReadDataBc1N1Mnk(
} else {
next_part_index = 0;
}
mfence_local();
GM2LM(src + next_part_index, &in_temp, sizeof(T));
for (int i = 0; i < read_lens - last_col; i++) {
dst[last_col + i] = in_temp;
......@@ -1169,6 +1176,7 @@ __device__ __inline__ void ReadDataBcCanNotCmp(
if (index_src >= index_base && index_src < index_base + cache_size) {
in_temp = src_temp[index_src - index_base];
} else {
mfence_local();
GM2LM(src + index_src, &in_temp, sizeof(T));
}
dst[nx] = in_temp;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册