Unverified commit 3c21f26b, authored by wangguanzhong, committed by GitHub

Stablize depthwise conv (#35161)

* stablize depthwise conv

* clean commend
Parent commit: 7ca28bb6
@@ -31,18 +31,43 @@ namespace operators {
 namespace math {

 template <typename T>
-__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
+static __forceinline__ __device__ T WarpReduceSum(T val, int warp_size) {
   typedef cub::WarpReduce<T> WarpReduce;
   typename WarpReduce::TempStorage temp_storage;
+  val = WarpReduce(temp_storage).Sum(val, warp_size);
+  return val;
+}

-#ifdef __HIPCC__
-  int block_size = min(blockDim.x * blockDim.y * blockDim.z, warpSize);
-  value = WarpReduce(temp_storage).Sum(value, block_size);
-#else
-  value = WarpReduce(temp_storage).Sum(value);
-#endif
+template <typename T>
+__forceinline__ __device__ T BlockReduceSum(T val) {
+  static __shared__ T shared[32];
+  int thread_id = threadIdx.x + threadIdx.y * blockDim.x +
+                  threadIdx.z * blockDim.x * blockDim.y;
+  int warp_size = min(blockDim.x * blockDim.y * blockDim.z, warpSize);
+  int lane = thread_id % warp_size;
+  int wid = thread_id / warp_size;

-  if (cub::LaneId() == 0) platform::CudaAtomicAdd(sum, value);
+  val = WarpReduceSum(val, warp_size);  // Each warp performs partial reduction
+  if (lane == 0) shared[wid] = val;     // Write reduced value to shared memory
+  __syncthreads();                      // Wait for all partial reductions
+
+  // read from shared memory only if that warp existed
+  int block_size = blockDim.x * blockDim.y * blockDim.z;
+  if (thread_id < (block_size - 1) / warp_size + 1) {
+    val = shared[lane];
+  } else {
+    val = static_cast<T>(0);
+  }
+
+  if (wid == 0) {
+    val = WarpReduceSum(val, warp_size);  // Final reduce within first warp
+  }
+  __syncthreads();
+  if (thread_id != 0) {
+    val = static_cast<T>(0);
+  }
+  return val;
 }

 #define ARG_DEFINE_KernelDepthwiseConv \
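For orientation (not part of the patch): `BlockReduceSum` returns the block-wide total only on thread 0 and zero on every other thread, so the intended calling pattern is a full-block reduction followed by a single per-block atomic. A minimal usage sketch under that assumption; the kernel `ReduceToScalar` and its grid-stride loop are hypothetical, and `T` is fixed to `float` so the plain CUDA `atomicAdd` overload exists:

```cpp
// Illustrative only: reduce an array to a scalar with one atomic per block,
// assuming BlockReduceSum<T> from the hunk above is in scope.
__global__ void ReduceToScalar(const float* in, float* out, int n) {
  float partial = 0.0f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    partial += in[i];  // every thread accumulates a grid-stride slice
  }
  // All threads must call this (it uses __syncthreads internally).
  float val = BlockReduceSum(partial);      // nonzero only on thread 0
  if (threadIdx.x == 0) atomicAdd(out, val);  // one atomic per block
}
```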
@@ -665,7 +690,9 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
       }
     }
   }
-  CudaAtomicAddWithWarp(&filter_grad_data[gbid], s);
+
+  T val = BlockReduceSum(s);
+  platform::CudaAtomicAdd(&filter_grad_data[gbid], val);
 }

 template <typename T, bool fuse_relu_before_conv>
@@ -892,6 +919,7 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
     int blocks;
     dim3 threads;
     dim3 grid;
+
     if (data_layout != DataLayout::kNHWC) {
       if (output_width > 1024 && output_width <= 2048)
         thread = (output_width - 1) / 2 + 1;
@@ -1034,6 +1062,7 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T,
     int blocks;
     dim3 threads;
     dim3 grid;
+
     if (input_width > 1024 && input_width <= 2048) {
       thread = (input_width - 1) / 2 + 1;
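The hunk headers in these last two functors each report exactly one added line (6 -> 7 and 6 -> 7); by line count that appears to be just a blank separator after `dim3 grid;`, shown as the lone `+` line above. The surrounding context also exposes the launch-width heuristic: a CUDA block holds at most 1024 threads, so a row of 1024 to 2048 elements gets ceil(width / 2) threads, each covering two columns. A sketch of that heuristic; `ChooseThreadX` is a hypothetical helper, and the `width <= 1024` fallback is an assumption since both hunks are truncated before the remaining branches:

```cpp
// Sketch of the launch-width heuristic visible in the truncated hunks.
inline int ChooseThreadX(int width) {
  if (width > 1024 && width <= 2048) {
    return (width - 1) / 2 + 1;  // ceil(width / 2): two columns per thread
  }
  return width;  // assumed: one thread per column when a block can fit it
}
```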