diff --git a/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl b/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl
index a86689f9809d08be790a9d72794bca19e9a5b152..b9aa47ad340de290d0d943f0b916d2caa0463646 100644
--- a/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl
+++ b/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl
@@ -57,10 +57,13 @@ struct Global2SharedMem {
     T* smem;
     int stride;
     int start_h, start_w, bound_h, bound_w, ring_smem_h, ring_src_h;
+    // only used for backward src data
+    int stride_h, stride_w;
     const T* g_ptr;
 
-    __device__ __forceinline__
-    Global2SharedMem(T* smem_, int stride_, int s_h, int s_w, int b_h, int b_w);
+    __device__ __forceinline__ Global2SharedMem(
+            T* smem_, int stride_, int s_h, int s_w, int b_h, int b_w, int stride_h_,
+            int stride_w_);
 
     __device__ __forceinline__ void first_copy();
     __device__ __forceinline__ void copy();
@@ -77,7 +80,7 @@ struct Global2SharedMem {
 
 template <
         typename ldg_dtype, DepthwiseConv2dDirection kDirection, typename ThreadConfig_,
-        typename OutTileConfig_, typename FilterTileConfig_>
+        typename OutTileConfig_, typename FilterTileConfig_, int stride_w, int stride_h>
 struct ConvTrait {
     using ThreadConfig = ThreadConfig_;
     using OutTileConfig = OutTileConfig_;
@@ -88,19 +91,19 @@ struct ConvTrait {
         static int const unroll_h =
                 OutTileConfig::unroll_h + FilterTileConfig::unroll_h - 1;
         static int const unroll_w =
-                OutTileConfig::unroll_w + FilterTileConfig::unroll_w - 1;
+                (OutTileConfig::unroll_w - 1) * stride_w + FilterTileConfig::unroll_w;
         static int const unroll_size = unroll_h * unroll_w;
     };
 
     struct SrcTileCount {
         static int const smem_src_h =
-                OutTileConfig::block_h + FilterTileConfig::unroll_h - 1;
+                (OutTileConfig::block_h - 1) * stride_h + FilterTileConfig::unroll_h;
         static int const smem_buff_h = FilterTileConfig::unroll_h;
         static int const smem_load_h = smem_src_h + smem_buff_h;
         static int const smem_h = smem_load_h + smem_buff_h;
         static int const smem_w =
-                DIVUP(OutTileConfig::block_w +
-                              FilterTileConfig::unroll_w * ThreadConfig::thread_x - 1,
+                DIVUP((OutTileConfig::block_w - 1) * stride_w +
+                              FilterTileConfig::unroll_w * ThreadConfig::thread_x,
                       2) *
                 2;
         static int const smem_size = smem_h * smem_w;
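Note on the tile arithmetic above: this is the standard input-footprint relation for strided convolution — an output tile of `w` elements at stride `s` under a filter of width `k` reads `(w - 1) * s + k` consecutive inputs, which degenerates to the old `w + k - 1` when `s == 1`. A minimal host-side sketch of the relation (helper name and numbers are illustrative, not part of the patch):

    #include <cstdio>

    // An output tile of `out` elements at stride `s` under a filter of width
    // `k` reads (out - 1) * s + k consecutive input elements.
    constexpr int src_tile_extent(int out, int s, int k) {
        return (out - 1) * s + k;
    }

    int main() {
        std::printf("stride 1: %d\n", src_tile_extent(4, 1, 3));  // 6 == 4 + 3 - 1
        std::printf("stride 2: %d\n", src_tile_extent(4, 2, 3));  // 9: more smem per tile
        return 0;
    }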
@@ -140,20 +143,25 @@ template <
         typename TileCount_>
 __device__ __forceinline__
 Global2SharedMem<T, kDirection, ThreadConfig_, TileCount_>::Global2SharedMem(
-        T* smem_, int stride_, int s_h, int s_w, int b_h, int b_w)
+        T* smem_, int stride_, int s_h, int s_w, int b_h, int b_w, int stride_h_,
+        int stride_w_)
         : smem(smem_),
           stride(stride_),
           start_h(s_h),
           start_w(s_w),
           bound_h(b_h),
           bound_w(b_w),
-          ring_smem_h(TileCount::smem_load_h) {
+          ring_smem_h(TileCount::smem_load_h),
+          stride_h(stride_h_),
+          stride_w(stride_w_) {
     if (is_fwd) {
         ring_src_h = s_h + TileCount::smem_load_h;
         w_offset = 0;
     } else {
         ring_src_h = s_h - 1;
         w_offset = TileCount::smem_w - b_w;
+        // stride_h and stride_w are only used for backward src data
+        stride_h = stride_w = 1;
     }
 }
 
@@ -195,9 +203,10 @@ __device__ __forceinline__ void Global2SharedMem<
             T val = 0.0f;
             if (src_h_idx >= 0 && src_h_idx < bound_h && src_w_idx >= 0 &&
                 src_w_idx < bound_w &&
-                (is_fwd || (TileCount::smem_load_h - smem_h_idx - 1 >= 0 &&
-                            TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0))) {
-                val = g_ptr[src_h_idx * stride + src_w_idx];
+                ((is_fwd && src_h_idx % stride_h == 0 && src_w_idx % stride_w == 0) ||
+                 (!is_fwd && TileCount::smem_load_h - smem_h_idx - 1 >= 0 &&
+                  TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0))) {
+                val = g_ptr[src_h_idx / stride_h * stride + src_w_idx / stride_w];
             }
             *(sh_ptr_as_copy_t(smem_h_idx, smem_w_idx)) = val;
         }
@@ -223,8 +232,9 @@ __device__ __forceinline__ void Global2SharedMem<
             T val = 0.0f;
             if (ring_src_h >= 0 && ring_src_h < bound_h && src_w_idx >= 0 &&
                 src_w_idx < bound_w &&
-                (is_fwd || TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0)) {
-                val = g_ptr[ring_src_h * stride + src_w_idx];
+                ((is_fwd && ring_src_h % stride_h == 0 && src_w_idx % stride_w == 0) ||
+                 (!is_fwd && TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0))) {
+                val = g_ptr[ring_src_h / stride_h * stride + src_w_idx / stride_w];
             }
             reg[j] = val;
         }
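The guarded loads above amount to reading a stride-dilated view of the global tensor: a logical coordinate maps back to dense storage only when it is a multiple of the stride, and every other position yields zero. A standalone CPU sketch of the same mapping (illustrative only; the visitor's real interface differs):

    #include <cstdio>

    float load_dilated(const float* dense, int dense_len, int logical_idx, int stride) {
        if (logical_idx % stride != 0)
            return 0.0f;                 // hole introduced by the dilation
        int src = logical_idx / stride;  // back to the dense coordinate
        return (src >= 0 && src < dense_len) ? dense[src] : 0.0f;  // out of bounds pads 0
    }

    int main() {
        const float row[3] = {1.f, 2.f, 3.f};
        for (int i = 0; i < 6; ++i)      // logical length = dense length * stride
            std::printf("%g ", load_dilated(row, 3, i, 2));  // prints: 1 0 2 0 3 0
        std::printf("\n");
        return 0;
    }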
@@ -286,21 +296,23 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
     int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
         off_oh = threadIdx.y, off_ow = threadIdx.x;
 
-    const int t2_src_unroll_w = (SrcTileConfig::unroll_w + 1) / 2;
-    const int t2_flt_unroll_w = (FilterTileConfig::unroll_w + 2) / 2;
-    const int t2_out_unroll_w = (OutTileConfig::unroll_w + 1) / 2;
+    constexpr int t2_src_unroll_w = (SrcTileConfig::unroll_w + 1) / 2;
+    constexpr int t2_flt_unroll_w = (FilterTileConfig::unroll_w + 2) / 2;
+    constexpr int t2_out_unroll_w = (OutTileConfig::unroll_w + 1) / 2;
 
     extern __shared__ __align__(8) unsigned char smem[];
     static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
     T* smem_src = reinterpret_cast<T*>(smem);
     T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);
+    int stride_h = is_fwd ? param.stride_h : 1;
+    int stride_w = is_fwd ? param.stride_w : 1;
 
     int off_ichannel = off_ochannel / param.chl_mul,
         off_fchannel = off_ichannel % param.src_chl,
         out_start_h = off_obh * OutTileConfig::block_h,
         out_start_w = off_obw * OutTileConfig::block_w,
-        src_start_h = out_start_h - param.pad_h,
-        src_start_w = out_start_w - param.pad_w,
+        src_start_h = out_start_h * stride_h - param.pad_h,
+        src_start_w = out_start_w * stride_w - param.pad_w,
         out_base_h_idx = out_start_h + off_oh * OutTileConfig::unroll_h;
 
     T* smem_src_ptr = smem_src + off_ow * FilterTileConfig::unroll_w;
@@ -308,12 +320,28 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
 
     T* out_base_ptr = output + off_ochannel * param.out_h * param.out_w;
 
-    SrcGlobal2ShareVisitor gl2sh_src(
-            smem_src, param.src_w, src_start_h, src_start_w, param.src_h, param.src_w);
-
-    FilterGlobal2ShareVisitor gl2sh_flt = {
-            smem_flt, param.flt_w, is_fwd ? 0 : param.flt_h - 2,
-            0,        param.flt_h, param.flt_w};
+    SrcGlobal2ShareVisitor gl2sh_src = {
+            smem_src,
+            param.src_w,
+            is_fwd ? src_start_h
+                   : src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h -
+                                    param.src_h * param.stride_h / 2),
+            is_fwd ? src_start_w
+                   : src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w -
+                                    param.src_w * param.stride_w / 2),
+            is_fwd ? param.src_h : param.src_h * param.stride_h,
+            is_fwd ? param.src_w : param.src_w * param.stride_w,
+            is_fwd ? 1 : param.stride_h,
+            is_fwd ? 1 : param.stride_w};
+
+    FilterGlobal2ShareVisitor gl2sh_flt = {smem_flt,
+                                           param.flt_w,
+                                           is_fwd ? 0 : param.flt_h - 2,
+                                           0,
+                                           param.flt_h,
+                                           param.flt_w,
+                                           1,
+                                           1};
 
     gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w;
     gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w;
@@ -326,7 +354,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
     T2 reg_src[SrcTileConfig::unroll_h * t2_src_unroll_w],
             reg_flt[2][FilterTileConfig::unroll_h * t2_flt_unroll_w];
 
-    T2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}};
+    float2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}};
 
     for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
         gl2sh_src.copy();
@@ -335,7 +363,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
         for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
 #pragma unroll
             for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) {
-                int src_offset = (off_oh + fh + s_h) % SrcTileCount::smem_h *
+                int src_offset = (off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h *
                                          SrcTileCount::smem_w +
                                  s_w * 2;
                 reg_src[s_h * t2_src_unroll_w + s_w] =
                         *reinterpret_cast<T2*>(smem_src_ptr + src_offset);
@@ -373,9 +401,10 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
 #pragma unroll
                 for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
                     sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
-                            reg_flt[ow % 2][inner_fh * t2_flt_unroll_w + fw],
+                            reg_flt[ow * stride_w % 2]
+                                   [inner_fh * t2_flt_unroll_w + fw],
                             reg_src[(inner_fh + oh) * t2_src_unroll_w + fw +
-                                    ow / 2],
+                                    ow * stride_w / 2],
                             sum[oh * t2_out_unroll_w + ow]);
                 }
             }
@@ -392,7 +421,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
 
     for (int o = 0; o < OutTileConfig::unroll_size; ++o) {
         for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) {
-            sum[o] += __shfl_xor(sum[o], i, 32);
+            sum[o].x += __shfl_xor(sum[o].x, i, 32);
+            sum[o].y += __shfl_xor(sum[o].y, i, 32);
         }
     }
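The reduction above is a butterfly (XOR) pattern: after log2(thread_x) exchange steps every participating lane holds the full sum, so any lane may write the output. The same idea with the modern mask-taking intrinsic, as a sketch (the kernel itself still uses the legacy `__shfl_xor`):

    __device__ float warp_reduce_sum(float v, int width) {
        // XOR-ing the lane id pairs each lane with a distinct partner per step;
        // after log2(width) steps all `width` lanes hold the same total.
        for (int i = 1; i < width; i <<= 1)
            v += __shfl_xor_sync(0xffffffffu, v, i, 32);
        return v;
    }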
@@ -406,9 +436,9 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
             int out_w_idx = out_start_w + j;
             if (out_w_idx >= param.out_w)
                 return;
-            out_base_ptr[out_h_idx * param.out_w + out_w_idx] =
+            out_base_ptr[out_h_idx * param.out_w + out_w_idx] = __float2half(
                     sum[i * OutTileConfig::unroll_w + j].x +
-                    sum[i * OutTileConfig::unroll_w + j].y;
+                    sum[i * OutTileConfig::unroll_w + j].y);
         }
     }
 }
@@ -433,21 +463,19 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
     int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
         off_oh = threadIdx.y, off_ow = threadIdx.x;
 
-    const int t2_src_unroll_w = (SrcTileConfig::unroll_w + 1) / 2;
-    const int t2_flt_unroll_w = (FilterTileConfig::unroll_w + 2) / 2;
-    const int t2_out_unroll_w = (OutTileConfig::unroll_w + 1) / 2;
-
     extern __shared__ __align__(8) unsigned char smem[];
     static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
     T* smem_src = reinterpret_cast<T*>(smem);
     T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);
+    int stride_h = is_fwd ? param.stride_h : 1;
+    int stride_w = is_fwd ? param.stride_w : 1;
 
     int off_ichannel = off_ochannel / param.chl_mul,
         off_fchannel = off_ichannel % param.src_chl,
         out_start_h = off_obh * OutTileConfig::block_h,
         out_start_w = off_obw * OutTileConfig::block_w,
-        src_start_h = out_start_h - param.pad_h,
-        src_start_w = out_start_w - param.pad_w,
+        src_start_h = out_start_h * stride_h - param.pad_h,
+        src_start_w = out_start_w * stride_w - param.pad_w,
         out_base_h_idx = out_start_h + off_oh * OutTileConfig::unroll_h;
 
     T* smem_src_ptr = smem_src + off_ow * FilterTileConfig::unroll_w;
@@ -455,12 +483,28 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
 
     T* out_base_ptr = output + off_ochannel * param.out_h * param.out_w;
 
-    SrcGlobal2ShareVisitor gl2sh_src(
-            smem_src, param.src_w, src_start_h, src_start_w, param.src_h, param.src_w);
-
-    FilterGlobal2ShareVisitor gl2sh_flt = {
-            smem_flt, param.flt_w, is_fwd ? 0 : param.flt_h - 2,
-            0,        param.flt_h, param.flt_w};
+    SrcGlobal2ShareVisitor gl2sh_src = {
+            smem_src,
+            param.src_w,
+            is_fwd ? src_start_h
+                   : src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h -
+                                    param.src_h * param.stride_h / 2),
+            is_fwd ? src_start_w
+                   : src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w -
+                                    param.src_w * param.stride_w / 2),
+            is_fwd ? param.src_h : param.src_h * param.stride_h,
+            is_fwd ? param.src_w : param.src_w * param.stride_w,
+            is_fwd ? 1 : param.stride_h,
+            is_fwd ? 1 : param.stride_w};
+
+    FilterGlobal2ShareVisitor gl2sh_flt = {smem_flt,
+                                           param.flt_w,
+                                           is_fwd ? 0 : param.flt_h - 2,
+                                           0,
+                                           param.flt_h,
+                                           param.flt_w,
+                                           1,
+                                           1};
 
     gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w;
     gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w;
@@ -470,10 +514,10 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
 
     __syncthreads();
 
-    T2 reg_src[SrcTileConfig::unroll_h * t2_src_unroll_w],
-            reg_flt[2][FilterTileConfig::unroll_h * t2_flt_unroll_w];
+    T reg_src[SrcTileConfig::unroll_h * SrcTileConfig::unroll_w],
+            reg_flt[FilterTileConfig::unroll_h * FilterTileConfig::unroll_w];
 
-    T2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}};
+    T sum[OutTileConfig::unroll_size] = {0.0};
 
     for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
         gl2sh_src.copy();
@@ -481,34 +525,23 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
 #pragma unroll
         for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
 #pragma unroll
-            for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) {
-                int src_offset = (off_oh + fh + s_h) % SrcTileCount::smem_h *
-                                         SrcTileCount::smem_w +
-                                 s_w * 2;
-                reg_src[s_h * t2_src_unroll_w + s_w] =
-                        *reinterpret_cast<T2*>(smem_src_ptr + src_offset);
+            for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
+                reg_src[s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
+                        [(off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h *
+                                 SrcTileCount::smem_w +
+                         s_w];
             }
         }
 
#pragma unroll
         for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
 #pragma unroll
-            for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) {
-                int flt_offset =
-                        (fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w +
-                        f_w * 2;
-                reg_flt[0][f_h * t2_flt_unroll_w + f_w] =
-                        *reinterpret_cast<T2*>(smem_flt_ptr + flt_offset);
-                reg_flt[1][f_h * t2_flt_unroll_w + f_w] = {
-                        f_w > 0 ? reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y
-                                : static_cast<T>(0.0),
-                        reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
+            for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
+                reg_flt[f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
+                        [(fh + f_h) % FilterTileCount::smem_h *
+                                 FilterTileCount::smem_w +
+                         f_w];
             }
-            reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = {
-                    static_cast<T>(0.0), static_cast<T>(0.0)};
-            reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = {
-                    reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y,
-                    static_cast<T>(0.0)};
         }
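The `(off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h` term above indexes shared memory as a ring buffer: as `fh` walks down the filter, freshly copied source rows overwrite the oldest slots and lookups wrap modulo `smem_h`. A trivial host-side sketch of the wrap-around (names illustrative):

    #include <cstdio>

    // Map a global source row to its slot in a ring buffer of `smem_h` rows.
    constexpr int ring_row(int global_row, int smem_h) {
        return global_row % smem_h;
    }

    int main() {
        const int smem_h = 6;
        for (int row = 0; row < 10; ++row)  // rows 6..9 reuse slots 0..3
            std::printf("row %d -> slot %d\n", row, ring_row(row, smem_h));
        return 0;
    }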
@@ -516,14 +554,13 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
 #pragma unroll
         for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
 #pragma unroll
-            for (int fw = 0; fw < t2_flt_unroll_w; ++fw) {
+            for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) {
 #pragma unroll
                 for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
-                    sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
-                            reg_flt[ow % 2][inner_fh * t2_flt_unroll_w + fw],
-                            reg_src[(inner_fh + oh) * t2_src_unroll_w + fw +
-                                    ow / 2],
-                            sum[oh * t2_out_unroll_w + ow]);
+                    sum[oh * OutTileConfig::unroll_w + ow] +=
+                            reg_flt[inner_fh * FilterTileConfig::unroll_w + fw] *
+                            reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w + fw +
+                                    ow * stride_w];
                 }
             }
         }
@@ -539,8 +585,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
 
     for (int o = 0; o < OutTileConfig::unroll_size; ++o) {
         for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) {
-            sum[o].x += __shfl_xor(sum[o].x, i, 32);
-            sum[o].y += __shfl_xor(sum[o].y, i, 32);
+            sum[o] += __shfl_xor(sum[o], i, 32);
         }
     }
@@ -555,8 +600,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
             if (out_w_idx >= param.out_w)
                 return;
             out_base_ptr[out_h_idx * param.out_w + out_w_idx] =
-                    sum[i * OutTileConfig::unroll_w + j].x +
-                    sum[i * OutTileConfig::unroll_w + j].y;
+                    sum[i * OutTileConfig::unroll_w + j];
         }
     }
 }
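With the `T2` vectorization dropped, the fp32 inner loop above is a plain scalar multiply-accumulate: output element `ow` gathers `flt[fw] * src[ow * stride_w + fw]` across the filter row. A CPU reference of that indexing (shapes hard-coded for illustration):

    #include <cstdio>

    void conv_row(const float* src, const float* flt, float* out, int out_w,
                  int flt_w, int stride) {
        for (int ow = 0; ow < out_w; ++ow) {
            float acc = 0.f;
            for (int fw = 0; fw < flt_w; ++fw)
                acc += flt[fw] * src[ow * stride + fw];  // same indexing as the kernel
            out[ow] = acc;
        }
    }

    int main() {
        const float src[7] = {1, 2, 3, 4, 5, 6, 7};
        const float flt[3] = {1, 0, -1};
        float out[3];
        conv_row(src, flt, out, 3, 3, 2);  // stride 2: windows [0..2], [2..4], [4..6]
        for (int i = 0; i < 3; ++i)
            std::printf("%g ", out[i]);    // -2 -2 -2
        std::printf("\n");
        return 0;
    }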
@@ -565,7 +609,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
 
 template <
         typename T, typename T2, DepthwiseConv2dDirection kDirection, int unroll_fw,
-        int unroll_ow>
+        int unroll_ow, int stride>
 void LaunchDepthwiseConv2dGPUSmall(
         const Param& param, const T* input, const T* filter, T* output,
         cudaStream_t stream) {
@@ -574,8 +618,9 @@ void LaunchDepthwiseConv2dGPUSmall(
     using FilterTileConfig = FilterTileConfig<unroll_fh, unroll_fw>;
     using ThreadConfig = ThreadConfig<4, 32>;
     using OutTileConfig = OutTileConfig<ThreadConfig, unroll_oh, unroll_ow>;
-    using IConvTrait =
-            ConvTrait<T, kDirection, ThreadConfig, OutTileConfig, FilterTileConfig>;
+    using IConvTrait = ConvTrait<
+            T, kDirection, ThreadConfig, OutTileConfig, FilterTileConfig, stride,
+            stride>;
     using SrcTileCount = typename IConvTrait::SrcTileCount;
     using FilterTileCount = typename IConvTrait::FilterTileCount;
@@ -593,10 +638,16 @@ void LaunchDepthwiseConv2dGPUSmall(
     after_kernel_launch();
 }
 
-#define INSTANCE_AB(type1, type2, a, b, direction)                            \
-    if (param.out_w > b * 4) {                                                \
-        LaunchDepthwiseConv2dGPUSmall<type1, type2, direction, a + 2, b + 1>( \
-                param, src, flt, dst, stream);                                \
+#define INSTANCE_AB(type1, type2, a, b, direction)                              \
+    if (param.out_w > b * 4) {                                                  \
+        if (direction == DepthwiseConv2dDirection::DIRECTION_BACKWARD ||        \
+            (param.stride_h == 1 && param.stride_w == 1)) {                     \
+            LaunchDepthwiseConv2dGPUSmall<type1, type2, direction, a + 2,       \
+                                          b + 1, 1>(                            \
+                    param, src, flt, dst, stream);                              \
+        } else if (param.stride_h == 2 && param.stride_w == 2) {                \
+            LaunchDepthwiseConv2dGPUSmall<type1, type2, direction, a + 2,       \
+                                          b + 1, 2>(                            \
+                    param, src, flt, dst, stream);                              \
+        }                                                                       \
     }
 
 #define INSTANCE_A(type1, type2, a, direction)                                \
diff --git a/dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu b/dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu
index debe3910b289eb3c63a261091fa492c57b789b39..d4570b88815627f136d121d81e9063123f129ed5 100644
--- a/dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu
+++ b/dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu
@@ -11,7 +11,6 @@
 #include "cuda.h"
 #include "cuda_fp16.h"
-// #include "src/cuda/conv_bias/chanwise/fwd_depthwise_large_filter.cuh"
 #include "src/cuda/conv_bias/chanwise/kern.cuh"
 #include "src/cuda/conv_bias/chanwise/kern_helper.cuh"
 #include "src/cuda/conv_bias/chanwise/launch_config.cuh"
diff --git a/dnn/src/cuda/conv_bias/depthwise_large_filter.cpp b/dnn/src/cuda/conv_bias/depthwise_large_filter.cpp
index 5c14d2a8f046c02fe7a3555985292a5c3824c40c..2f28e87ad0db995f05ab7ccfccd5aee4404a462b 100644
--- a/dnn/src/cuda/conv_bias/depthwise_large_filter.cpp
+++ b/dnn/src/cuda/conv_bias/depthwise_large_filter.cpp
@@ -32,15 +32,14 @@ inline bool is_available_depthwise_large_filter(const chanwise::Param& param) {
                               : 1 + (ow + 3) / 4 + flt_smem_w / 4 - 1;
     int out_reg_per_thread = (ow + 3) / 4 * 4;
     if (device_prop.regsPerBlock < 4 * 32 *
-                                           (flt_reg_per_thread + src_reg_per_thread +
-                                            out_reg_per_thread) ||
+                                           (flt_reg_per_thread * 2 +
+                                            src_reg_per_thread + out_reg_per_thread) ||
         device_prop.sharedMemPerBlock < static_cast<size_t>(
-                flt_smem_w * flt_smem_h + src_smem_w * src_smem_h)) {
+                flt_smem_w * flt_smem_h * 2 + src_smem_w * src_smem_h)) {
         return false;
     }
-    return param.stride_h == 1 && param.stride_w == 1 && param.src_h == param.out_h &&
-           param.src_w == param.out_w;
+    return true;
 }
 
 }  // anonymous namespace
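The availability check above (and its backward-data counterpart below) budgets one 4x32-thread block against the device limits before selecting this algorithm; the doubled filter terms track the kernel's paired filter registers and double-buffered filter tile. A condensed sketch of such a gate (`cudaDeviceProp` fields are real; the helper itself is illustrative):

    #include <cuda_runtime.h>

    // Reject the large-filter kernel when one 4x32 block would exceed the
    // device's register file or shared-memory budget.
    bool block_fits(int regs_per_thread, size_t smem_bytes) {
        cudaDeviceProp prop;
        cudaGetDeviceProperties(&prop, 0);
        const int threads = 4 * 32;  // matches ThreadConfig<4, 32>
        return prop.regsPerBlock >= threads * regs_per_thread &&
               prop.sharedMemPerBlock >= smem_bytes;
    }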
printf("param.src_w = %d, param.src_h = %d, param.out_w = %d, param.out_h = %d\n", + param.src_w, param.src_h, param.out_w, param.out_h); + return (param.stride_h == 1 && param.stride_w == 1) || + (param.stride_h == 2 && param.stride_w == 2); } } // anonymous namespace diff --git a/dnn/src/cuda/fp16_help.cuh b/dnn/src/cuda/fp16_help.cuh index f85a82df7622fa4e05cbbc6f2045d27e5ae56c79..afebd19a17aab7708669af974f239a98136ac090 100644 --- a/dnn/src/cuda/fp16_help.cuh +++ b/dnn/src/cuda/fp16_help.cuh @@ -45,6 +45,12 @@ fma2(const __half2 a, const __half2 b, const __half2 c) { #endif } +__device__ __forceinline__ float2 +fma2(const __half2 a, const __half2 b, const float2 c) { + return {__half2float(a.x) * __half2float(b.x) + c.x, + __half2float(a.y) * __half2float(b.y) + c.y}; +} + #endif // CUDA_VERSION >= 9000 } // namespace cuda diff --git a/dnn/test/cuda/conv_bias.cpp b/dnn/test/cuda/conv_bias.cpp index 1c6dc08ab993f733a1a940b4535e9a6da07043aa..f76304d16048ecd0af9003b1702090cb4bc395de 100644 --- a/dnn/test/cuda/conv_bias.cpp +++ b/dnn/test/cuda/conv_bias.cpp @@ -701,8 +701,10 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) { ConvBiasForward::algo_name( "DEPTHWISE_LARGE_FILTER", {}) .c_str())); - for (auto dtype : std::vector{dtype::Float16()}) { - auto run = [&checker, &dtype](size_t n, size_t g, size_t h, size_t fh) { + for (auto dtype : std::vector{dtype::Float32(), dtype::Float16()}) { + auto run = [&checker, &dtype]( + size_t n, size_t g, size_t h, size_t fh, size_t padding, + size_t stride) { param::ConvBias cur_param; cur_param.mode = param::ConvBias::Mode::CROSS_CORRELATION; cur_param.sparse = ConvBias::Param::Sparse::GROUP; @@ -711,42 +713,52 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) { .set_dtype(2, dtype) .set_dtype(3, dtype) .set_dtype(4, dtype); + float scale = 64.f / sqrt(fh * fh); + UniformFloatRNG rng(scale, 2 * scale); + checker.set_rng(0, &rng) + .set_rng(1, &rng) + .set_rng(2, &rng) + .set_rng(3, &rng) + .set_rng(4, &rng); + if (dtype.enumv() == DTypeEnum::Float16) { + checker.set_epsilon(1e-1); + } - cur_param.pad_h = cur_param.pad_w = fh / 2; - cur_param.stride_h = cur_param.stride_w = 1; + cur_param.pad_h = cur_param.pad_w = padding; + cur_param.stride_h = cur_param.stride_w = stride; checker.set_param(cur_param).execs( {{n, g, h, h}, {g, 1, 1, fh, fh}, {}, {}, {}}); }; - run(4, 8, 32, 5); - run(4, 8, 32, 7); - run(4, 8, 32, 9); - run(4, 8, 32, 11); - run(4, 8, 32, 13); - run(4, 8, 32, 15); - run(4, 8, 32, 17); - run(4, 8, 32, 19); - run(4, 8, 32, 21); - run(4, 8, 32, 23); - run(4, 8, 32, 25); - run(4, 8, 32, 27); - run(4, 8, 32, 29); - run(4, 8, 32, 31); - run(4, 8, 64, 5); - run(4, 8, 64, 7); - run(4, 8, 64, 9); - run(4, 8, 64, 11); - run(4, 8, 64, 13); - run(4, 8, 64, 15); - run(4, 8, 64, 17); - run(4, 8, 64, 19); - run(4, 8, 64, 21); - run(4, 8, 64, 23); - run(4, 8, 64, 25); - run(4, 8, 64, 27); - run(4, 8, 64, 29); - run(4, 8, 64, 31); - run(1, 2, 128, 31); - run(1, 2, 256, 31); + run(4, 8, 32, 5, 5 / 2, 1); + run(4, 8, 32, 7, 7 / 2, 1); + run(4, 8, 32, 9, 9 / 2, 1); + run(4, 8, 32, 11, 11 / 2, 1); + run(4, 8, 32, 13, 13 / 2, 1); + run(4, 8, 32, 15, 15 / 2, 1); + run(4, 8, 32, 17, 17 / 2, 1); + run(4, 8, 32, 19, 19 / 2, 1); + run(4, 8, 32, 21, 21 / 2, 1); + run(4, 8, 32, 23, 23 / 2, 1); + run(4, 8, 32, 25, 25 / 2, 1); + run(4, 8, 32, 27, 27 / 2, 1); + run(4, 8, 32, 29, 29 / 2, 1); + run(4, 8, 32, 31, 31 / 2, 1); + run(4, 8, 64, 5, 5 / 3, 2); + run(4, 8, 64, 7, 7 / 3, 2); + run(4, 8, 64, 9, 9 / 3, 2); + run(4, 8, 64, 11, 11 / 3, 2); + run(4, 
diff --git a/dnn/test/cuda/conv_bias.cpp b/dnn/test/cuda/conv_bias.cpp
index 1c6dc08ab993f733a1a940b4535e9a6da07043aa..f76304d16048ecd0af9003b1702090cb4bc395de 100644
--- a/dnn/test/cuda/conv_bias.cpp
+++ b/dnn/test/cuda/conv_bias.cpp
@@ -701,8 +701,10 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
                     ConvBiasForward::algo_name<ConvBias::DefaultParam>(
                             "DEPTHWISE_LARGE_FILTER", {})
                             .c_str()));
-    for (auto dtype : std::vector<DType>{dtype::Float16()}) {
-        auto run = [&checker, &dtype](size_t n, size_t g, size_t h, size_t fh) {
+    for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
+        auto run = [&checker, &dtype](
+                           size_t n, size_t g, size_t h, size_t fh, size_t padding,
+                           size_t stride) {
             param::ConvBias cur_param;
             cur_param.mode = param::ConvBias::Mode::CROSS_CORRELATION;
             cur_param.sparse = ConvBias::Param::Sparse::GROUP;
@@ -711,42 +713,52 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
                     .set_dtype(2, dtype)
                     .set_dtype(3, dtype)
                     .set_dtype(4, dtype);
+            float scale = 64.f / sqrt(fh * fh);
+            UniformFloatRNG rng(scale, 2 * scale);
+            checker.set_rng(0, &rng)
+                    .set_rng(1, &rng)
+                    .set_rng(2, &rng)
+                    .set_rng(3, &rng)
+                    .set_rng(4, &rng);
+            if (dtype.enumv() == DTypeEnum::Float16) {
+                checker.set_epsilon(1e-1);
+            }
 
-            cur_param.pad_h = cur_param.pad_w = fh / 2;
-            cur_param.stride_h = cur_param.stride_w = 1;
+            cur_param.pad_h = cur_param.pad_w = padding;
+            cur_param.stride_h = cur_param.stride_w = stride;
 
             checker.set_param(cur_param).execs(
                     {{n, g, h, h}, {g, 1, 1, fh, fh}, {}, {}, {}});
         };
-        run(4, 8, 32, 5);
-        run(4, 8, 32, 7);
-        run(4, 8, 32, 9);
-        run(4, 8, 32, 11);
-        run(4, 8, 32, 13);
-        run(4, 8, 32, 15);
-        run(4, 8, 32, 17);
-        run(4, 8, 32, 19);
-        run(4, 8, 32, 21);
-        run(4, 8, 32, 23);
-        run(4, 8, 32, 25);
-        run(4, 8, 32, 27);
-        run(4, 8, 32, 29);
-        run(4, 8, 32, 31);
-        run(4, 8, 64, 5);
-        run(4, 8, 64, 7);
-        run(4, 8, 64, 9);
-        run(4, 8, 64, 11);
-        run(4, 8, 64, 13);
-        run(4, 8, 64, 15);
-        run(4, 8, 64, 17);
-        run(4, 8, 64, 19);
-        run(4, 8, 64, 21);
-        run(4, 8, 64, 23);
-        run(4, 8, 64, 25);
-        run(4, 8, 64, 27);
-        run(4, 8, 64, 29);
-        run(4, 8, 64, 31);
-        run(1, 2, 128, 31);
-        run(1, 2, 256, 31);
+        run(4, 8, 32, 5, 5 / 2, 1);
+        run(4, 8, 32, 7, 7 / 2, 1);
+        run(4, 8, 32, 9, 9 / 2, 1);
+        run(4, 8, 32, 11, 11 / 2, 1);
+        run(4, 8, 32, 13, 13 / 2, 1);
+        run(4, 8, 32, 15, 15 / 2, 1);
+        run(4, 8, 32, 17, 17 / 2, 1);
+        run(4, 8, 32, 19, 19 / 2, 1);
+        run(4, 8, 32, 21, 21 / 2, 1);
+        run(4, 8, 32, 23, 23 / 2, 1);
+        run(4, 8, 32, 25, 25 / 2, 1);
+        run(4, 8, 32, 27, 27 / 2, 1);
+        run(4, 8, 32, 29, 29 / 2, 1);
+        run(4, 8, 32, 31, 31 / 2, 1);
+        run(4, 8, 64, 5, 5 / 3, 2);
+        run(4, 8, 64, 7, 7 / 3, 2);
+        run(4, 8, 64, 9, 9 / 3, 2);
+        run(4, 8, 64, 11, 11 / 3, 2);
+        run(4, 8, 64, 13, 13 / 3, 2);
+        run(4, 8, 64, 15, 15 / 3, 2);
+        run(4, 8, 64, 17, 17 / 3, 2);
+        run(4, 8, 64, 19, 19 / 3, 2);
+        run(4, 8, 64, 21, 21 / 3, 2);
+        run(4, 8, 64, 23, 23 / 3, 2);
+        run(4, 8, 64, 25, 25 / 3, 2);
+        run(4, 8, 64, 27, 27 / 3, 2);
+        run(4, 8, 64, 29, 29 / 3, 2);
+        run(4, 8, 64, 31, 31 / 3, 2);
+        run(1, 2, 128, 31, 10, 2);
+        run(1, 2, 256, 31, 10, 2);
     }
 }
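For reading the shapes in these tests: with input `i`, filter `f`, stride `s` and padding `p`, cross-correlation yields `(i + 2p - f) / s + 1` outputs (integer division), which is also what `infer_conv_shape` returns in the benchmarks below. A small check using values from the `run()` calls above:

    #include <cstdio>

    constexpr size_t conv_out(size_t i, size_t f, size_t s, size_t p) {
        return (i + 2 * p - f) / s + 1;
    }

    int main() {
        std::printf("%zu\n", conv_out(32, 31, 1, 31 / 2));  // 32: "same" output, stride 1
        std::printf("%zu\n", conv_out(64, 31, 2, 31 / 3));  // 27
        return 0;
    }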
@@ -1530,7 +1542,7 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_TENSORCORE_INT8) {
     run_bench(256, 512, 7, 7, 2048, 1, 1, 1, 1, 1000);
 }
 
-TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
+TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER_FP16) {
     require_compute_capability(7, 5);
     Benchmarker<ConvBiasForward> bencher(handle_cuda());
     bencher.set_display(false);
@@ -1552,6 +1564,11 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_TENSORCORE_INT8) {
         param.stride_h = sh;
         param.stride_w = sw;
 
+        bencher.set_param(param)
+                .set_dtype(0, dtype::Float16())
+                .set_dtype(1, dtype::Float16())
+                .set_dtype(2, dtype::Float16())
+                .set_dtype(4, dtype::Float16());
         bencher.set_times(nr_times);
         size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
         size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w);
@@ -1562,25 +1579,13 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_TENSORCORE_INT8) {
                         out.total_nr_elems()) /
                 (1024 * 1024 * 1024) * 1e3;
 
-        bencher.set_param(param)
-                .set_dtype(0, dtype::Float32())
-                .set_dtype(1, dtype::Float32())
-                .set_dtype(2, dtype::Float32())
-                .set_dtype(4, dtype::Float32());
-        auto fp32_time_in_ms = bencher.execs({inp, kern, {}, {}, out}) / nr_times;
-        bencher.set_param(param)
-                .set_dtype(0, dtype::Float16())
-                .set_dtype(1, dtype::Float16())
-                .set_dtype(2, dtype::Float16())
-                .set_dtype(4, dtype::Float16());
-        auto fp16_time_in_ms = bencher.execs({inp, kern, {}, {}, out}) / nr_times;
-        printf("chanwise_depthwise_large_filter: inp=%s, kern=%s, out=%s, fp32_time: "
-               "%.2fms, fp16_time: %.2fms, speedup: %0.2f (fp16/fp32) "
-               "fp32_bandwidth: %.2fGB/s fp16_bandwidth: %.2fGB/s.\n",
+        auto time_in_ms = bencher.execs({inp, kern, {}, {}, out}) / nr_times;
+        auto ops = 2.0 * batch * g * ho * wo * fh * fw / (time_in_ms * 1e-3) * 1e-12;
+        printf("chanwise_depthwise_large_filter: inp=%s, kern=%s, out=%s, time: "
+               "%.2fms, "
+               "perf: %.2f Tops bandwidth: %.2fGB/s.\n",
                inp.to_string().c_str(), kern.to_string().c_str(),
-               out.to_string().c_str(), fp32_time_in_ms, fp16_time_in_ms,
-               fp32_time_in_ms / fp16_time_in_ms, bandwith * 4 / fp32_time_in_ms,
-               bandwith * 2 / fp16_time_in_ms);
+               out.to_string().c_str(), time_in_ms, ops, bandwith * 4 / time_in_ms);
     };
 
     run_bench(64, 384, 32, 32, 3, 3, 1, 1, 10);
@@ -1600,7 +1605,7 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
     run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10);
 }
 
-TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER_FP16) {
+TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER_FP32) {
     require_compute_capability(7, 5);
     Benchmarker<ConvBiasForward> bencher(handle_cuda());
     bencher.set_display(false);
@@ -1623,10 +1628,10 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER_FP16) {
         param.stride_w = sw;
 
         bencher.set_param(param)
-                .set_dtype(0, dtype::Float16())
-                .set_dtype(1, dtype::Float16())
-                .set_dtype(2, dtype::Float16())
-                .set_dtype(4, dtype::Float16());
+                .set_dtype(0, dtype::Float32())
+                .set_dtype(1, dtype::Float32())
+                .set_dtype(2, dtype::Float32())
+                .set_dtype(4, dtype::Float32());
         bencher.set_times(nr_times);
         size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
         size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w);
diff --git a/dnn/test/cuda/convolution.cpp b/dnn/test/cuda/convolution.cpp
index b29d6ef325175b29df864d1bce398ae3aa8d93d9..697b784fc2f44175ee17623df1cfde535a03c2f2 100644
--- a/dnn/test/cuda/convolution.cpp
+++ b/dnn/test/cuda/convolution.cpp
@@ -728,48 +728,57 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DEPTHWISE_LARGE_FILTER) {
     Checker<ConvolutionBackwardData> checker(handle_cuda());
     checker.set_before_exec_callback(
             AlgoChecker<ConvolutionBackwardData>("DEPTHWISE_LARGE_FILTER"));
-    for (auto dtype : std::vector<DType>{dtype::Float16()}) {
-        auto run = [&checker, &dtype](size_t n, size_t g, size_t h, size_t fh) {
+    for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
+        auto run = [&checker, &dtype](
+                           size_t n, size_t g, size_t h, size_t fh, size_t padding,
+                           size_t stride) {
             param::Convolution param;
-            param.stride_h = param.stride_w = 1;
-            param.pad_h = param.pad_w = fh / 2;
+            param.stride_h = param.stride_w = stride;
+            param.pad_h = param.pad_w = padding;
             param.mode = Convolution::Mode::CROSS_CORRELATION;
             param.sparse = param::Convolution::Sparse::GROUP;
             checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
+            UniformFloatRNG rng(1.0, 1.0);
+            checker.set_rng(0, &rng).set_rng(1, &rng).set_rng(2, &rng);
+            if (dtype.enumv() == DTypeEnum::Float16)
+                checker.set_epsilon(1e-1);
             checker.set_param(param).execs(
-                    {{g, 1, 1, fh, fh}, {n, g, h, h}, {n, g, h, h}});
+                    {{g, 1, 1, fh, fh},
+                     {n, g, (h + 2 * padding - fh + 1) / stride,
+                      (h + 2 * padding - fh + 1) / stride},
+                     {n, g, h, h}});
         };
-        run(4, 8, 32, 5);
-        run(4, 8, 32, 7);
-        run(4, 8, 32, 9);
-        run(4, 8, 32, 11);
-        run(4, 8, 32, 13);
-        run(4, 8, 32, 15);
-        run(4, 8, 32, 17);
-        run(4, 8, 32, 19);
-        run(4, 8, 32, 21);
-        run(4, 8, 32, 23);
-        run(4, 8, 32, 25);
-        run(4, 8, 32, 27);
-        run(4, 8, 32, 29);
-        run(4, 8, 32, 31);
-        run(4, 8, 64, 7);
-        run(4, 8, 64, 5);
-        run(4, 8, 64, 9);
-        run(4, 8, 64, 11);
-        run(4, 8, 64, 13);
-        run(4, 8, 64, 15);
-        run(4, 8, 64, 17);
-        run(4, 8, 64, 19);
-        run(4, 8, 64, 21);
-        run(4, 8, 64, 23);
-        run(4, 8, 64, 25);
-        run(4, 8, 64, 27);
-        run(4, 8, 64, 29);
-        run(4, 8, 64, 31);
-        run(1, 2, 128, 31);
-        run(1, 2, 256, 31);
+        run(4, 8, 32, 5, 5 / 2, 1);
+        run(4, 8, 32, 7, 7 / 2, 1);
+        run(4, 8, 32, 9, 9 / 2, 1);
+        run(4, 8, 32, 11, 11 / 2, 1);
+        run(4, 8, 32, 13, 13 / 2, 1);
+        run(4, 8, 32, 15, 15 / 2, 1);
+        run(4, 8, 32, 17, 17 / 2, 1);
+        run(4, 8, 32, 19, 19 / 2, 1);
+        run(4, 8, 32, 21, 21 / 2, 1);
+        run(4, 8, 32, 23, 23 / 2, 1);
+        run(4, 8, 32, 25, 25 / 2, 1);
+        run(4, 8, 32, 27, 27 / 2, 1);
+        run(4, 8, 32, 29, 29 / 2, 1);
+        run(4, 8, 32, 31, 31 / 2, 1);
+        run(4, 8, 64, 5, 5 / 2, 2);
+        run(4, 8, 64, 7, 7 / 3, 2);
+        run(4, 8, 64, 9, 9 / 3, 2);
+        run(4, 8, 64, 11, 11 / 3, 2);
+        run(4, 8, 64, 13, 13 / 3, 2);
+        run(4, 8, 64, 15, 15 / 3, 2);
+        run(4, 8, 64, 17, 17 / 3, 2);
+        run(4, 8, 64, 19, 19 / 3, 2);
+        run(4, 8, 64, 21, 21 / 3, 2);
+        run(4, 8, 64, 23, 23 / 3, 2);
+        run(4, 8, 64, 25, 25 / 3, 2);
+        run(4, 8, 64, 27, 27 / 3, 2);
+        run(4, 8, 64, 29, 29 / 3, 2);
+        run(4, 8, 64, 31, 31 / 3, 2);
+        run(1, 2, 128, 31, 31 / 3, 2);
+        run(1, 2, 256, 31, 31 / 3, 2);
     }
 }
@@ -950,7 +960,7 @@ TEST_F(CUDA, CONVOLUTION_BWD_DATA_BENCHMARK) {
     run(32, 64, 64, 56, 56, 1, 1, 0);
 }
 
-TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_DEPTHWISE_LARGE_FILTER) {
+TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_DEPTHWISE_LARGE_FILTER_FP32) {
     CUBenchmarker<ConvolutionBackwardData> bencher{handle_cuda()};
     bencher.set_display(false);
     bencher.set_before_exec_callback(