提交 22b9ab05 编写于 作者: C chengduoZH

refine Sum

上级 d4dabe3e
...@@ -357,6 +357,14 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out, ...@@ -357,6 +357,14 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
} }
} }
#ifdef __NVCC__ #ifdef __NVCC__
// __shfl_down has been deprecated as of CUDA 9.0
#if CUDA_VERSION < 9000
template <typename T>
__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
return __shfl_down(val, delta);
}
#endif
template <typename T, typename DX_OP, typename DY_OP> template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast1CUDAKernel( static __global__ void ElemwiseGradBroadcast1CUDAKernel(
const T* x, const T* y, const T* out, const T* dout, int h, int w, const T* x, const T* y, const T* out, const T* dout, int h, int w,
...@@ -381,15 +389,22 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel( ...@@ -381,15 +389,22 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
} while (i < h); } while (i < h);
if (dy) { if (dy) {
T val = shm[threadIdx.x];
int warpSize = 32;
for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += __shfl_down_sync(0, val, offset);
__syncthreads(); __syncthreads();
shm[tid] = 0;
if (threadIdx.x % 32 == 0) {
shm[threadIdx.x / 32] = val;
}
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; val = shm[threadIdx.x];
for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += __shfl_down_sync(0, val, offset);
// Sum, could be optimized
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
for (int k = 1; k < h; ++k) {
shm[0] += shm[k];
}
dy[j] = shm[0]; dy[j] = shm[0];
} }
} }
...@@ -468,15 +483,22 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel( ...@@ -468,15 +483,22 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
} }
if (dy) { if (dy) {
T val = shm[threadIdx.x];
int warpSize = 32;
for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += __shfl_down_sync(0, val, offset);
__syncthreads(); __syncthreads();
int h = pre * post; shm[tid] = 0;
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; if (threadIdx.x % 32 == 0) {
shm[threadIdx.x / 32] = val;
}
// Sum, could be optimized val = shm[threadIdx.x];
if (tid == 0) { for (int offset = warpSize / 2; offset > 0; offset /= 2)
for (int i = 1; i < h; ++i) { val += __shfl_down_sync(0, val, offset);
shm[0] += shm[i];
} if (threadIdx.x == 0) {
dy[j] = shm[0]; dy[j] = shm[0];
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册