Commit d36af62c authored by chengduoZH

wrap_shfl_x_sync

Parent 0285a2b9
@@ -224,7 +224,7 @@ __global__ void RowConvGradFilterImproved(const T *in, const T *dout,
       for (int offset = 16; offset > 0;
            offset = offset / 2) {  // blockDim.x is 32.
-        val += platform::__shfl_down_sync(mask, val, offset);
+        val += platform::CudaShuffleDownSync(mask, val, offset);
       }
       __syncthreads();
@@ -284,7 +284,7 @@ __global__ void RowConvGradFilter(const T *in, const T *dout, int num_sequence,
       for (int offset = 16; offset > 0;
            offset = offset / 2) {  // blockDim.x is 32.
-        val += platform::__shfl_down_sync(mask, val, offset);
+        val += platform::CudaShuffleDownSync(mask, val, offset);
       }
       __syncthreads();
......
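The two RowConvGradFilter hunks above use the standard warp-level tree reduction: with a 32-thread block, each lane adds the value held `offset` lanes above it, halving `offset` from 16 down to 1, so lane 0 ends up holding the sum of all 32 values. Below is a minimal, self-contained sketch of that pattern. `ShuffleDownSync` is a hypothetical stand-in for `platform::CudaShuffleDownSync` that assumes CUDA 9.0 or newer; the wrapper introduced by this commit additionally falls back to `__shfl_down` on older toolkits.

```cpp
#include <cstdio>

// Hypothetical stand-in for platform::CudaShuffleDownSync, assuming CUDA >= 9.0.
template <typename T>
__forceinline__ __device__ T ShuffleDownSync(unsigned mask, T val, int delta,
                                             int width = 32) {
  return __shfl_down_sync(mask, val, delta, width);
}

// One 32-thread block reduces 32 values to a single sum in lane 0,
// mirroring the loop in RowConvGradFilterImproved / RowConvGradFilter.
__global__ void WarpSumKernel(const float* in, float* out) {
  float val = in[threadIdx.x];
  unsigned mask = __ballot_sync(0xFFFFFFFFu, true);  // like CREATE_SHFL_MASK(mask, true)
  for (int offset = 16; offset > 0; offset /= 2) {   // blockDim.x is 32
    val += ShuffleDownSync(mask, val, offset);
  }
  if (threadIdx.x == 0) *out = val;  // only lane 0 holds the full sum
}

int main() {
  float h_in[32], h_out = 0.f, *d_in, *d_out;
  for (int i = 0; i < 32; ++i) h_in[i] = 1.f;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  WarpSumKernel<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %f\n", h_out);  // expect 32.0
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```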
@@ -241,7 +241,8 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
     CREATE_SHFL_MASK(mask, true);
     if (maxid[0] / 32 == warp) {
-      if (platform::__shfl_sync(mask, *beam, (maxid[0]) % 32, 32) == MaxLength)
+      if (platform::CudaShuffleSync(mask, *beam, (maxid[0]) % 32, 32) ==
+          MaxLength)
         break;
     }
   }
......
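In the BlockReduce hunk above, the shuffle is used as a broadcast rather than a reduction: every lane reads `*beam` from the one lane that owns the current maximum (`maxid[0] % 32`). The sketch below shows that broadcast in isolation, calling the `__shfl_sync` intrinsic directly and assuming CUDA 9.0 or newer; the kernel and variable names are illustrative only.

```cpp
#include <cstdio>

// Each lane holds its own `beam` value; after the shuffle, every lane sees
// the value held by src_lane (the lane that owns the maximum in BlockReduce).
__global__ void BroadcastFromLane(const int* beams, int src_lane, int* out) {
  int beam = beams[threadIdx.x];
  unsigned mask = __ballot_sync(0xFFFFFFFFu, true);  // all 32 lanes participate
  out[threadIdx.x] = __shfl_sync(mask, beam, src_lane, 32);
}

int main() {
  int h_beams[32], h_out[32], *d_beams, *d_out;
  for (int i = 0; i < 32; ++i) h_beams[i] = i * 10;
  cudaMalloc(&d_beams, sizeof(h_beams));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_beams, h_beams, sizeof(h_beams), cudaMemcpyHostToDevice);
  BroadcastFromLane<<<1, 32>>>(d_beams, 7, d_out);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  printf("lane 0 sees %d (expect 70)\n", h_out[0]);
  cudaFree(d_beams);
  cudaFree(d_out);
  return 0;
}
```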
@@ -18,34 +18,33 @@ limitations under the License. */
 namespace paddle {
 namespace platform {
 // __shfl_down and __shfl have been deprecated as of CUDA 9.0.
 #if CUDA_VERSION < 9000
-template <typename T>
-__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
-  return __shfl_down(val, delta);
-}
-template <typename T>
-__forceinline__ __device__ T __shfl_sync(unsigned, T val, int src_line,
-                                         int width) {
-  return __shfl(val, src_line, width);
-}
 #define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
 #else
 #define FULL_WARP_MASK 0xFFFFFFFF
 #define CREATE_SHFL_MASK(mask, predicate) \
   mask = __ballot_sync(FULL_WARP_MASK, (predicate))
+#endif
 template <typename T>
-__forceinline__ __device__ T __shfl_down_sync(unsigned mask, T val, int delta) {
-  return __shfl_down_sync(mask, val, delta);
+__forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
+                                                 int delta, int width = 32) {
+#if CUDA_VERSION < 9000
+  return __shfl_down(val, delta, width);
+#else
+  return __shfl_down_sync(mask, val, delta, width);
+#endif
 }
 template <typename T>
-__forceinline__ __device__ T __shfl_sync(unsigned mask, T val, int src_line,
-                                         int width) {
+__forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line,
+                                             int width = 32) {
+#if CUDA_VERSION < 9000
+  return __shfl(val, src_line, width);
+#else
   return __shfl_sync(mask, val, src_line, width);
-}
 #endif
+}
 template <typename T>
 __device__ T reduceSum(T val, int tid, int len) {
@@ -61,7 +60,7 @@ __device__ T reduceSum(T val, int tid, int len) {
   CREATE_SHFL_MASK(mask, tid < len);
   for (int offset = warpSize / 2; offset > 0; offset /= 2)
-    val += platform::__shfl_down_sync(mask, val, offset);
+    val += platform::CudaShuffleDownSync(mask, val, offset);
   if (tid < warpSize) shm[tid] = 0;
@@ -75,7 +74,7 @@ __device__ T reduceSum(T val, int tid, int len) {
   if (tid < warpSize) {
     val = shm[tid];
     for (int offset = warpSize / 2; offset > 0; offset /= 2)
-      val += platform::__shfl_down_sync(mask, val, offset);
+      val += platform::CudaShuffleDownSync(mask, val, offset);
   }
   return val;
 }
......
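Assembled from the added lines above, the post-commit helpers read as follows (a sketch of the resulting header, with includes and blank-line layout assumed). `CREATE_SHFL_MASK` degrades to `0u` before CUDA 9.0, where `__ballot_sync` and the `*_sync` shuffles do not exist, and each wrapper picks the matching intrinsic at compile time.

```cpp
// __shfl_down and __shfl have been deprecated as of CUDA 9.0.
#if CUDA_VERSION < 9000
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
  mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif

template <typename T>
__forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
                                                 int delta, int width = 32) {
#if CUDA_VERSION < 9000
  return __shfl_down(val, delta, width);  // pre-9.0 intrinsic, no mask argument
#else
  return __shfl_down_sync(mask, val, delta, width);
#endif
}

template <typename T>
__forceinline__ __device__ T CudaShuffleSync(unsigned mask, T val, int src_line,
                                             int width = 32) {
#if CUDA_VERSION < 9000
  return __shfl(val, src_line, width);  // pre-9.0 intrinsic, no mask argument
#else
  return __shfl_sync(mask, val, src_line, width);
#endif
}
```

Giving the wrappers their own names also means they no longer share a name with the intrinsics they call: the removed `platform::__shfl_down_sync` overload invoked `__shfl_down_sync` from inside `namespace platform`, where the wrapper itself shadows the intrinsic of the same name, which is presumably part of what this rename avoids.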