Unverified commit 77dbb318, authored by Bo Zhang and committed by GitHub

fix reduce_any kernel data race on sharedMem (#47233)

* fix reduce_any kernel data race on sharedMem

* use bit operation instead of div & mod

* unbranch

* modified according to PR comments
Parent commit: cb746665
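The race being fixed: BlockXReduce reuses a single shared-memory buffer for two reduction stages, and without a barrier between them a thread with threadIdx.x == 0 can start writing its stage-2 result into a slot that threads of another warp are still reading. The following is a minimal, self-contained sketch of that pattern, not Paddle's kernel; the kernel name, the 8 x 8 launch shape, and the row-sum workload are invented for illustration.

#include <cstdio>
#include <cuda_runtime.h>

// Illustrative sketch of the shared-memory reuse pattern behind this fix.
// A 2D block first publishes per-thread values into `shared`, each row then
// reads those values back, and finally the SAME buffer is reused to hold one
// result per row. The middle __syncthreads() is the barrier this commit adds.
__global__ void RowSumSketch(const float* in, float* out) {
  __shared__ float shared[64];  // sized for the 8 x 8 launch in main()

  int idx = threadIdx.y * blockDim.x + threadIdx.x;

  // Stage 1: every thread publishes its value.
  shared[idx] = in[blockIdx.x * blockDim.x * blockDim.y + idx];
  __syncthreads();

  // Stage 2: each thread reads the stage-1 values of its own row.
  float sum = 0.f;
  for (int x = 0; x < (int)blockDim.x; ++x) {
    sum += shared[threadIdx.y * blockDim.x + x];
  }

  // All stage-2 reads, in every warp, must complete before the buffer is
  // reused below with a different indexing scheme. Dropping this barrier
  // reintroduces the write-after-read race named in the commit title.
  __syncthreads();

  // Stage 3: reuse the same buffer, one slot per row.
  if (threadIdx.x == 0) {
    shared[threadIdx.y] = sum;  // overwrites slots 0 .. blockDim.y - 1
  }
  __syncthreads();

  if (idx < (int)blockDim.y) {
    out[blockIdx.x * blockDim.y + idx] = shared[idx];
  }
}

int main() {
  const int rows = 8, cols = 8;
  float h_in[rows * cols], h_out[rows];
  for (int i = 0; i < rows * cols; ++i) h_in[i] = 1.0f;

  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);

  RowSumSketch<<<1, dim3(cols, rows)>>>(d_in, d_out);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);

  for (int y = 0; y < rows; ++y) {
    printf("row %d sum = %.1f\n", y, h_out[y]);  // expect 8.0 per row
  }
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

With this 8 x 8 launch, rows 0-3 live in warp 0 and rows 4-7 in warp 1, so without the middle barrier thread (0, 4) can overwrite shared[4] while thread (4, 0) is still reading it; a race checker such as compute-sanitizer's racecheck tool would report that hazard.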
@@ -91,10 +91,13 @@ __device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
   __shared__ T shared[2 * kWarpSize];
   int block_dim_x = blockDim.x;
   if (blockDim.x > kWarpSize) {
-    block_dim_x = blockDim.x / kWarpSize;
-    int lane = threadIdx.x % kWarpSize;
+    // Bit operation can be used when kWarpSize is 32 or 64 now
+    constexpr int rshift_val =
+        (kWarpSize != 32) ? ((kWarpSize == 64) ? 6 : 5) : 5;
+    block_dim_x = blockDim.x >> rshift_val;
+    int lane = threadIdx.x & (kWarpSize - 1);
     int tid = threadIdx.y * blockDim.x + threadIdx.x;
-    int wid = tid / kWarpSize;
+    int wid = tid >> rshift_val;
     int bid = threadIdx.y;
     val = WarpReduce(val, reducer);
     if (lane == 0) {
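The other change in this hunk replaces division and modulo by kWarpSize with a right shift and a mask. That is equivalent only because kWarpSize is a power of two (32, or 64 on 64-lane hardware), and rshift_val is its compile-time log2. Below is a small host-side check of the equivalence; it is illustration only, and RShiftVal is not a Paddle helper.

#include <cassert>
#include <cstdio>

// Illustrative only (not Paddle code): for a power-of-two warp size,
//   tid / kWarpSize == tid >> log2(kWarpSize)
//   tid % kWarpSize == tid & (kWarpSize - 1)
// The commit selects the shift amount at compile time: 5 for 32-lane warps,
// 6 for 64-lane ones.
template <int kWarpSize>
constexpr int RShiftVal() {
  static_assert(kWarpSize == 32 || kWarpSize == 64,
                "only 32- and 64-lane warps are handled");
  return (kWarpSize == 64) ? 6 : 5;
}

int main() {
  constexpr int kWarpSize = 32;
  constexpr int shift = RShiftVal<kWarpSize>();
  for (int tid = 0; tid < 1024; ++tid) {
    assert(tid / kWarpSize == (tid >> shift));            // wid
    assert(tid % kWarpSize == (tid & (kWarpSize - 1)));   // lane
  }
  printf("div/mod and shift/mask agree for every tid in [0, 1024)\n");
  return 0;
}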
@@ -110,6 +113,7 @@ __device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
     T temp = paddle::platform::CudaShuffleDownSync(mask, val, stride);
     val = reducer(val, temp);
   }
+  __syncthreads();
   if (threadIdx.x == 0) {
     shared[threadIdx.y] = val;
   }
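For context, the unchanged loop shown above the added barrier is a warp-level shuffle-down reduction; CudaShuffleDownSync is Paddle's portability wrapper, which on CUDA builds maps to __shfl_down_sync. Below is a standalone sketch of that pattern with the generic reducer replaced by addition; the kernel name and launch shape are invented for the example.

#include <cstdio>
#include <cuda_runtime.h>

// Illustrative shuffle-down reduction: after the loop, lane 0 of each warp
// holds the sum of all 32 lanes, mirroring the stride-doubling loop shown
// as context in the hunk above (with `reducer(val, temp)` specialized to +).
__global__ void WarpSumSketch(const float* in, float* out) {
  int lane = threadIdx.x & 31;
  float val = in[threadIdx.x];

  unsigned mask = 0xffffffffu;  // every lane of the warp participates
  for (int stride = 1; stride < 32; stride <<= 1) {
    float temp = __shfl_down_sync(mask, val, stride);
    val += temp;
  }

  if (lane == 0) out[threadIdx.x >> 5] = val;  // one result per warp
}

int main() {
  const int n = 64;  // two warps
  float h_in[n], h_out[2];
  for (int i = 0; i < n; ++i) h_in[i] = 1.0f;

  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);

  WarpSumSketch<<<1, n>>>(d_in, d_out);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);

  printf("warp sums: %.1f %.1f\n", h_out[0], h_out[1]);  // expect 32.0 32.0
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}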
@@ -385,8 +389,8 @@ __device__ __forceinline__ void CycleBinary(OutT* out,
 /**
  * @brief The Reduce provides collective methods for computing a parallel
  * reduction of items partitioned across a CUDA block and intra thread. When
- * ReduceMode == kLocalMode, thread reduce along nx. When ReduceMode ==
- * kGlobalMode, use shared memory to reduce between threads.
+ * ReduceMode == kLocalMode, use shared memory to reduce between threads.When
+ * ReduceMode == kGlobalMode, thread reduce along nx.
  *
  * @template paraments
  * T: The type of data.