Commit 22b9ab05 authored by chengduoZH

refine Sum

Replace the serial shared-memory accumulation in the dy reduction of
ElemwiseGradBroadcast1CUDAKernel and ElemwiseGradBroadcast2CUDAKernel
with a warp-shuffle reduction, and add a __shfl_down_sync wrapper so
the kernels still build against CUDA toolkits older than 9.0.

Parent d4dabe3e
@@ -357,6 +357,14 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
}
}
#ifdef __NVCC__
// __shfl_down has been deprecated as of CUDA 9.0; on older toolkits,
// provide a __shfl_down_sync shim that forwards to it (the mask
// argument is accepted but ignored there).
#if CUDA_VERSION < 9000
template <typename T>
__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
return __shfl_down(val, delta);
}
#endif
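
With this shim in scope, the kernels below can call __shfl_down_sync unconditionally: the template is only compiled for CUDA < 9.0, where it forwards to the legacy __shfl_down, while newer toolkits use the builtin intrinsic. A minimal standalone sketch of the call pattern (the demo kernel and host code are illustrative and not part of this commit):

    #include <cstdio>
    #include <cuda_runtime.h>

    // Each of the 32 lanes contributes 1.0f; after log2(32) = 5 shuffle
    // steps, lane 0 holds the warp total (32.0f).
    __global__ void WarpSumDemo(float* out) {
      float val = 1.0f;
      for (int offset = 16; offset > 0; offset /= 2)
        val += __shfl_down_sync(0xFFFFFFFF, val, offset);
      if (threadIdx.x == 0) *out = val;
    }

    int main() {
      float* d_out = nullptr;
      float h_out = 0.0f;
      cudaMalloc(&d_out, sizeof(float));
      WarpSumDemo<<<1, 32>>>(d_out);
      cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
      printf("warp sum = %.1f\n", h_out);  // prints 32.0
      cudaFree(d_out);
      return 0;
    }
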
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast1CUDAKernel(
const T* x, const T* y, const T* out, const T* dout, int h, int w,
@@ -381,15 +389,22 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
} while (i < h);
  if (dy) {
    T val = shm[threadIdx.x];
    int warpSize = 32;
    // Stage 1: each warp reduces its lanes with shuffles. CUDA 9.0+
    // requires a real participation mask (0 is invalid there), so pass
    // the full-warp mask; all lanes of a warp take part here.
    for (int offset = warpSize / 2; offset > 0; offset /= 2)
      val += __shfl_down_sync(0xFFFFFFFF, val, offset);
    __syncthreads();
    shm[tid] = 0;
    // Lane 0 of each warp publishes its warp's partial sum; the
    // remaining slots stay zero.
    if (threadIdx.x % 32 == 0) {
      shm[threadIdx.x / 32] = val;
    }
    // Make the partials visible before they are summed below.
    __syncthreads();
    h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
    // Stage 2: serial sum of the zero-padded partials; could be
    // optimized into a second warp-level reduction.
    if (threadIdx.x == 0) {
      for (int k = 1; k < h; ++k) {
        shm[0] += shm[k];
      }
      dy[j] = shm[0];
    }
  }
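
As the stage-2 comment notes, the serial sum of warp partials could itself be a warp-level reduction. A self-contained sketch of that fully shuffled variant, assuming blockDim.x is a multiple of 32 and at most 1024 (the kernel and its names are illustrative, not Paddle APIs):

    #include <cstdio>
    #include <cuda_runtime.h>

    // Two-stage block sum: (1) every warp reduces its lanes with
    // shuffles; (2) warp 0 reduces the per-warp partials from shared
    // memory. Assumes blockDim.x is a multiple of 32 and <= 1024.
    __global__ void BlockSumDemo(const float* in, float* out) {
      __shared__ float shm[32];
      float val = in[blockIdx.x * blockDim.x + threadIdx.x];

      for (int offset = 16; offset > 0; offset /= 2)  // stage 1
        val += __shfl_down_sync(0xFFFFFFFF, val, offset);

      if (threadIdx.x < 32) shm[threadIdx.x] = 0.0f;  // clear the slots
      __syncthreads();
      if (threadIdx.x % 32 == 0) shm[threadIdx.x / 32] = val;
      __syncthreads();                                // publish partials

      if (threadIdx.x < 32) {                         // stage 2
        val = shm[threadIdx.x];
        for (int offset = 16; offset > 0; offset /= 2)
          val += __shfl_down_sync(0xFFFFFFFF, val, offset);
        if (threadIdx.x == 0) out[blockIdx.x] = val;
      }
    }

    int main() {
      const int n = 256;
      float h_in[n], h_out = 0.0f;
      for (int i = 0; i < n; ++i) h_in[i] = 1.0f;
      float *d_in, *d_out;
      cudaMalloc(&d_in, n * sizeof(float));
      cudaMalloc(&d_out, sizeof(float));
      cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
      BlockSumDemo<<<1, n>>>(d_in, d_out);
      cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
      printf("block sum = %.1f\n", h_out);  // prints 256.0
      cudaFree(d_in);
      cudaFree(d_out);
      return 0;
    }

Either way, stage 1 replaces the old length-blockDim serial loop over shared memory with log2(32) = 5 register-to-register shuffle steps per warp.
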
@@ -468,15 +483,22 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
}
  if (dy) {
    __syncthreads();
    int h = pre * post;
    h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
    T val = shm[threadIdx.x];
    int warpSize = 32;
    // Stage 1: each warp reduces its lanes with shuffles (CUDA 9.0+
    // requires a real participation mask, so pass the full-warp mask).
    for (int offset = warpSize / 2; offset > 0; offset /= 2)
      val += __shfl_down_sync(0xFFFFFFFF, val, offset);
    __syncthreads();
    shm[tid] = 0;
    // Lane 0 of each warp publishes its warp's partial sum.
    if (threadIdx.x % 32 == 0) {
      shm[threadIdx.x / 32] = val;
    }
    // Make the partials visible before they are summed below.
    __syncthreads();
    // Stage 2: serial sum of the zero-padded partials; could be
    // optimized into a second warp-level reduction.
    if (tid == 0) {
      for (int i = 1; i < h; ++i) {
        shm[0] += shm[i];
      }
      dy[j] = shm[0];
    }
  }
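
For reference, the quantity each block computes is just a reduction over the broadcast axis. A host-side sketch of the broadcast-1 case's semantics, assuming the usual row-major h x w layout with y of length w (an illustrative helper, not a Paddle API):

    // dy[j] reduces dy_op over the h rows of column j.
    template <typename T, typename DY_OP>
    void BroadcastDyReference(const T* x, const T* y, const T* out,
                              const T* dout, int h, int w, DY_OP dy_op,
                              T* dy) {
      for (int j = 0; j < w; ++j) {
        T sum = 0;
        for (int i = 0; i < h; ++i) {
          int idx = i * w + j;
          sum += dy_op(x[idx], y[j], out[idx], dout[idx]);
        }
        dy[j] = sum;
      }
    }

Each CUDA block above produces one such dy[j], with the per-row terms accumulated in registers and then combined by the two-stage reduction.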