Unverified · Commit 3c21f26b authored by wangguanzhong, committed by GitHub

Stabilize depthwise conv (#35161)

* stabilize depthwise conv

* clean up comments
Parent 7ca28bb6
@@ -31,18 +31,43 @@ namespace operators {
 namespace math {
 
 template <typename T>
-__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
+static __forceinline__ __device__ T WarpReduceSum(T val, int warp_size) {
   typedef cub::WarpReduce<T> WarpReduce;
   typename WarpReduce::TempStorage temp_storage;
-
-#ifdef __HIPCC__
-  int block_size = min(blockDim.x * blockDim.y * blockDim.z, warpSize);
-  value = WarpReduce(temp_storage).Sum(value, block_size);
-#else
-  value = WarpReduce(temp_storage).Sum(value);
-#endif
-
-  if (cub::LaneId() == 0) platform::CudaAtomicAdd(sum, value);
+  val = WarpReduce(temp_storage).Sum(val, warp_size);
+  return val;
+}
+
+template <typename T>
+__forceinline__ __device__ T BlockReduceSum(T val) {
+  static __shared__ T shared[32];
+  int thread_id = threadIdx.x + threadIdx.y * blockDim.x +
+                  threadIdx.z * blockDim.x * blockDim.y;
+  int warp_size = min(blockDim.x * blockDim.y * blockDim.z, warpSize);
+  int lane = thread_id % warp_size;
+  int wid = thread_id / warp_size;
+
+  val = WarpReduceSum(val, warp_size);  // Each warp performs partial reduction
+  if (lane == 0) shared[wid] = val;     // Write reduced value to shared memory
+  __syncthreads();                      // Wait for all partial reductions
+
+  // read from shared memory only if that warp existed
+  int block_size = blockDim.x * blockDim.y * blockDim.z;
+  if (thread_id < (block_size - 1) / warp_size + 1) {
+    val = shared[lane];
+  } else {
+    val = static_cast<T>(0);
+  }
+
+  if (wid == 0) {
+    val = WarpReduceSum(val, warp_size);  // Final reduce within first warp
+  }
+  __syncthreads();
+  if (thread_id != 0) {
+    val = static_cast<T>(0);
+  }
+  return val;
 }
 
 #define ARG_DEFINE_KernelDepthwiseConv \
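The new helper pair is a standard two-level CUDA sum: WarpReduceSum collapses each warp's values via cub::WarpReduce, lane 0 of every warp stages its partial sum in shared memory, and the first warp then reduces those partials, so only thread 0 returns the block total (every other thread returns zero). Below is a minimal self-contained sketch of the same pattern, assuming plain CUDA with float data and a block size that is a multiple of warpSize; it uses raw __shfl_down_sync in place of cub::WarpReduce, and the names warpReduceSum, blockReduceSum, and sumKernel are ours for illustration, not Paddle's.

#include <cstdio>
#include <cuda_runtime.h>

// Warp-level sum via shuffle; after the loop, lane 0 holds the warp's total.
// Assumes all 32 lanes are active (blockDim.x is a multiple of warpSize).
__device__ float warpReduceSum(float val) {
  for (int offset = 16; offset > 0; offset /= 2)
    val += __shfl_down_sync(0xffffffffu, val, offset);
  return val;
}

// Block-level sum, same shape as the patch: per-warp reduce, stage partials
// in shared memory, then the first warp reduces the partials.
__device__ float blockReduceSum(float val) {
  static __shared__ float shared[32];  // one slot per possible warp
  int lane = threadIdx.x % warpSize;
  int wid = threadIdx.x / warpSize;

  val = warpReduceSum(val);            // each warp's partial sum
  if (lane == 0) shared[wid] = val;    // lane 0 publishes the partial
  __syncthreads();

  int num_warps = (blockDim.x + warpSize - 1) / warpSize;
  val = (threadIdx.x < num_warps) ? shared[lane] : 0.0f;
  if (wid == 0) val = warpReduceSum(val);  // final reduce in warp 0
  return val;                              // thread 0 holds the block total
}

__global__ void sumKernel(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  float v = (i < n) ? in[i] : 0.0f;
  v = blockReduceSum(v);
  if (threadIdx.x == 0) atomicAdd(out, v);  // one atomic per block
}

int main() {
  const int n = 1 << 20;
  float *in, *out;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < n; ++i) in[i] = 1.0f;
  *out = 0.0f;
  sumKernel<<<(n + 255) / 256, 256>>>(in, out, n);
  cudaDeviceSynchronize();
  printf("sum = %.0f (expect %d)\n", *out, n);
  cudaFree(in);
  cudaFree(out);
  return 0;
}

The single atomicAdd per block at the end mirrors the call pattern the patch introduces in KernelDepthwiseConvFilterGradNCHW in the next hunk.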
@@ -665,7 +690,9 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
       }
     }
   }
-  CudaAtomicAddWithWarp(&filter_grad_data[gbid], s);
+
+  T val = BlockReduceSum(s);
+  platform::CudaAtomicAdd(&filter_grad_data[gbid], val);
 }
 
 template <typename T, bool fuse_relu_before_conv>
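This hunk is the payoff of the new helper (our reading of the change, not text from the patch): the old CudaAtomicAddWithWarp issued one CudaAtomicAdd per warp, so filter_grad_data[gbid] accumulated many partial sums in whatever order the atomics happened to land; the new code folds the whole block deterministically and issues a single atomic per block. Because floating-point addition is not associative, fewer nondeterministically ordered additions means more reproducible filter gradients, which is presumably the "stabilize" in the commit title. A two-line illustration of that non-associativity, compilable with any C++/CUDA compiler:

#include <cstdio>

int main() {
  // Same three addends, two association orders:
  float x = (1e20f + 1.0f) - 1e20f;  // 1.0f is absorbed first -> prints 0
  float y = (1e20f - 1e20f) + 1.0f;  // big terms cancel first -> prints 1
  printf("%g vs %g\n", x, y);
  return 0;
}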
@@ -892,6 +919,7 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
     int blocks;
     dim3 threads;
     dim3 grid;
+
     if (data_layout != DataLayout::kNHWC) {
       if (output_width > 1024 && output_width <= 2048)
         thread = (output_width - 1) / 2 + 1;
@@ -1034,6 +1062,7 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T,
     int blocks;
     dim3 threads;
     dim3 grid;
+
     if (data_layout != DataLayout::kNHWC) {
       if (input_width > 1024 && input_width <= 2048) {
         thread = (input_width - 1) / 2 + 1;
...
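For context on the launch heuristic visible in the last two hunks: CUDA caps a block at 1024 threads, so for widths in (1024, 2048] the code assigns one thread to every two columns via the ceiling division (width - 1) / 2 + 1. A hypothetical host-side sketch of that arithmetic; the helper name and the behavior outside the visible range are our assumptions:

#include <cstdio>

// Ceiling-halving heuristic from the hunks above: one thread per column up
// to 1024 columns, one thread per two columns for (1024, 2048]. The helper
// name and the clamping outside the visible range are illustrative only.
static int PickThreadCount(int width) {
  if (width > 1024 && width <= 2048) return (width - 1) / 2 + 1;  // ceil(w/2)
  return width <= 1024 ? width : 1024;
}

int main() {
  printf("%d %d %d\n",
         PickThreadCount(1024),   // 1024
         PickThreadCount(1500),   // 750
         PickThreadCount(2048));  // 1024
  return 0;
}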