提交 b8938b44 编写于 作者: C chengduoZH

refine Sum

上级 a8288392
...@@ -379,7 +379,8 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel( ...@@ -379,7 +379,8 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
} while (i < h); } while (i < h);
if (dy) { if (dy) {
val = platform::ReduceSum(val, tid); h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
dy[j] = val; dy[j] = val;
} }
...@@ -454,7 +455,9 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel( ...@@ -454,7 +455,9 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
} }
if (dy) { if (dy) {
val = platform::ReduceSum(val, threadIdx.x); int h = pre * post;
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
dy[j] = val; dy[j] = val;
} }
......
...@@ -68,19 +68,22 @@ template <typename T> ...@@ -68,19 +68,22 @@ template <typename T>
__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) { __forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
return __shfl_down(val, delta); return __shfl_down(val, delta);
} }
#define CREATE_SHFL_MASK(mask, predicate) unsigned mask = 0u; #define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else #else
#define FULL_WARP_MASK 0xFFFFFFFF #define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \ #define CREATE_SHFL_MASK(mask, predicate) \
unsigned mask = __ballot_sync(FULL_WARP_MASK, (predicate)) mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif #endif
template <typename T> template <typename T>
__device__ T ReduceSum(T val, int tid) { __device__ T reduceSum(T val, int tid, int len) {
__shared__ T shm[32]; __shared__ T shm[32];
const int warpSize = 32; const int warpSize = 32;
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, tid < len);
for (int offset = warpSize / 2; offset > 0; offset /= 2) for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += __shfl_down_sync(-1U, val, offset); val += __shfl_down_sync(mask, val, offset);
if (tid < warpSize) shm[tid] = 0; if (tid < warpSize) shm[tid] = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册