提交 b8938b44 编写于 作者: C chengduoZH

refine Sum

上级 a8288392
......@@ -379,7 +379,8 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
} while (i < h);
if (dy) {
val = platform::ReduceSum(val, tid);
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
}
......@@ -454,7 +455,9 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
}
if (dy) {
val = platform::ReduceSum(val, threadIdx.x);
int h = pre * post;
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
}
......
......@@ -68,19 +68,22 @@ template <typename T>
__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
return __shfl_down(val, delta);
}
#define CREATE_SHFL_MASK(mask, predicate) unsigned mask = 0u;
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
unsigned mask = __ballot_sync(FULL_WARP_MASK, (predicate))
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
template <typename T>
__device__ T ReduceSum(T val, int tid) {
__device__ T reduceSum(T val, int tid, int len) {
__shared__ T shm[32];
const int warpSize = 32;
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, tid < len);
for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += __shfl_down_sync(-1U, val, offset);
val += __shfl_down_sync(mask, val, offset);
if (tid < warpSize) shm[tid] = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册