Unverified · Commit 3222cf16 authored by chengduo, committed by GitHub

Merge pull request #10325 from chengduoZH/fix_shfl_sync

Fix shfl_sync for CUDA8.0
@@ -228,6 +228,21 @@ extern __thread cudaStream_t default_stream;
       << "CUDA error: " << hl_get_device_error_string((size_t)err); \
   }
+// __shfl has been deprecated as of CUDA 9.0.
+#if CUDA_VERSION < 9000
+template <typename T>
+__forceinline__ __device__ T
+__shfl_sync(unsigned, T val, int src_line, int width) {
+  return __shfl(val, src_line, width);
+}
+#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
+#else
+#define FULL_WARP_MASK 0xFFFFFFFF
+#define CREATE_SHFL_MASK(mask, predicate) \
+  mask = __ballot_sync(FULL_WARP_MASK, (predicate))
+#endif
 #endif /* __NVCC__ */
 #endif /* HL_BASE_H_ */
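Note: the block added above gives every call site a single spelling, `__shfl_sync(mask, ...)`, that compiles on both toolkits. On CUDA 8.0 the template wrapper discards the mask and forwards to the old `__shfl`, and `CREATE_SHFL_MASK` expands to `mask = 0u`; on CUDA 9.0+ the toolkit's real `__shfl_sync` is used and the macro builds the participation mask with `__ballot_sync`. A minimal usage sketch (the helper name below is illustrative, not part of this commit):

```cpp
// Usage sketch (assumption, not from this commit): build the mask once with
// CREATE_SHFL_MASK, then pass it to __shfl_sync.  On CUDA < 9.0 the fallback
// above ignores the mask; on CUDA >= 9.0 it is the ballot of all lanes whose
// predicate is true.
__device__ __forceinline__ real warpBroadcastFirst(real val, int tid) {
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, tid < 32);      // every lane satisfies the predicate
  return __shfl_sync(mask, val, 0, 32);  // read lane 0's value on all lanes
}
```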
@@ -341,12 +341,15 @@ void hl_lstm_parallel_forward(real *gateValue,
 }
 __device__ __forceinline__ void transpose_32x32(real a[], const int idx) {
-  int addr = idx % 32;
+  const int warp_size = 32;
+  int addr = idx % warp_size;
+  unsigned mask = 0u;
+  CREATE_SHFL_MASK(mask, addr < warp_size);
 #pragma unroll
   for (int k = 1; k < 32; k++) {
     // rSrc[k] = __shfl_sync(rSrc[k], (threadIdx.x + k) % 32, 32);
-    addr = __shfl_sync(addr, (idx + 1) % 32, 32);
-    a[k] = __shfl_sync(a[k], addr, 32);
+    addr = __shfl_sync(mask, addr, (idx + 1) % 32, 32);
+    a[k] = __shfl_sync(mask, a[k], addr, 32);
   }
 #pragma unroll
@@ -360,10 +363,11 @@ __device__ __forceinline__ void transpose_32x32(real a[], const int idx) {
   }
   addr = (32 - idx) % 32;
+  CREATE_SHFL_MASK(mask, idx % 32 < warp_size);
 #pragma unroll
   for (int k = 0; k < 32; k++) {
-    a[k] = __shfl_sync(a[k], addr, 32);
-    addr = __shfl_sync(addr, (idx + 31) % 32, 32);
+    a[k] = __shfl_sync(mask, a[k], addr, 32);
+    addr = __shfl_sync(mask, addr, (idx + 31) % 32, 32);
   }
 }
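The two hunks above thread one `mask` through every shuffle in `transpose_32x32`. On CUDA 9.0+ the predicates `addr < warp_size` and `idx % 32 < warp_size` hold for every lane, so the ballot yields the full warp mask and all 32 lanes take part in the exchange of `addr` and `a[k]`; on CUDA 8.0 the mask is simply dropped by the wrapper added in hl_base.h. A stripped-down sketch of the same pattern (the helper below is illustrative only, not code from this PR):

```cpp
// Sketch of the pattern used in transpose_32x32 (illustrative helper name):
// create one mask up front and reuse it for every shuffle in the loop so the
// whole warp stays converged.  Here each lane fetches its neighbour's value.
__device__ __forceinline__ real warpRotateLeft(real val, int lane) {
  const int warp_size = 32;
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, lane < warp_size);  // full-warp participation mask
  return __shfl_sync(mask, val, (lane + 1) % warp_size, warp_size);
}
```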
......
@@ -244,13 +244,16 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK,
     if (--beamSize == 0) break;
     __syncthreads();
+    unsigned mask = 0u;
+    // CREATE_SHFL_MASK(mask, tid < len);
     if (tid == maxId[0]) {
       if (beam < maxLength) {
         shTopK[tid] = topK[beam];
       }
     }
     if (maxId[0] / 32 == warp) {
-      if (__shfl_sync(beam, (maxId[0]) % 32, 32) == maxLength) break;
+      if (__shfl_sync(mask, beam, (maxId[0]) % 32, 32) == maxLength) break;
     }
   }
 }
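In `blockReduce` the shuffle now also takes the leading mask argument; note that `CREATE_SHFL_MASK(mask, tid < len)` is left commented out here, so `mask` stays `0u`. The shuffle itself broadcasts `beam` from the lane that holds the current maximum, so every lane in that warp tests the same exit condition and breaks together. A hedged sketch of that broadcast-and-test pattern (names are illustrative, not from the commit):

```cpp
// Broadcast-and-test sketch (illustrative names): one lane of the warp holds
// `counter`; every lane reads that lane's value via __shfl_sync and compares
// it against the same limit, so the whole warp agrees on whether to stop.
__device__ __forceinline__ bool warpCounterReachedLimit(unsigned mask,
                                                        int counter,
                                                        int src_lane,
                                                        int limit) {
  return __shfl_sync(mask, counter, src_lane, 32) == limit;
}
```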
......
@@ -74,10 +74,6 @@ __forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
 }
 #define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
 #else
-template <typename T>
-__forceinline__ __device__ T __shfl_down_sync(unsigned mask, T val, int delta) {
-  return __shfl_down(mask, val, delta);
-}
 #define FULL_WARP_MASK 0xFFFFFFFF
 #define CREATE_SHFL_MASK(mask, predicate) \
   mask = __ballot_sync(FULL_WARP_MASK, (predicate))
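On CUDA 9.0+ the toolkit already provides `__shfl_down_sync`, so the `#else` branch no longer needs its own definition; the removed wrapper also passed `mask` as the first argument of `__shfl_down`, i.e. it shuffled the mask instead of `val`. Only the CUDA 8.0 fallback kept above the `#else` and the two mask macros remain. A minimal sketch of how these pieces are typically combined, a warp sum reduction (the helper below is an assumption, not code from this PR):

```cpp
// Warp sum reduction sketch (assumption, not from this PR): CREATE_SHFL_MASK
// builds the participation mask, and __shfl_down_sync (the toolkit intrinsic
// on CUDA >= 9.0, the fallback kept in this header on CUDA 8.0) folds the
// upper half of the warp onto the lower half until lane 0 holds the total.
template <typename T>
__device__ __forceinline__ T warpReduceSum(T val, int tid) {
  unsigned mask = 0u;
  CREATE_SHFL_MASK(mask, tid < 32);  // all lanes of the warp participate
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(mask, val, offset);
  }
  return val;  // lane 0 now holds the sum over the warp
}
```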
......