From bc385b537494620972162b98581f8e997923ebd4 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 7 Feb 2022 18:28:06 +0800 Subject: [PATCH] feat(cuda): support float16 depthwise large kernel conv GitOrigin-RevId: fdc1b15fbcb3968e695601bff6b6a953bf66f115 --- .../chanwise/depthwise_large_filter_algo.inl | 248 +++++++++++++++--- .../conv_bias/chanwise/fwd_large_filter.cu | 11 +- .../cuda/conv_bias/depthwise_large_filter.cpp | 15 +- .../backward_data/depthwise_large_filter.cpp | 14 +- .../convolution/chanwise/bwd_large_filter.cu | 11 +- dnn/test/cuda/conv_bias.cpp | 160 +++++++---- dnn/test/cuda/convolution.cpp | 51 +++- 7 files changed, 421 insertions(+), 89 deletions(-) diff --git a/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl b/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl index 4eb54607c..a86689f98 100644 --- a/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl +++ b/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl @@ -98,9 +98,11 @@ struct ConvTrait { static int const smem_buff_h = FilterTileConfig::unroll_h; static int const smem_load_h = smem_src_h + smem_buff_h; static int const smem_h = smem_load_h + smem_buff_h; - static int const smem_w = OutTileConfig::block_w + - FilterTileConfig::unroll_w * ThreadConfig::thread_x - - 1; + static int const smem_w = + DIVUP(OutTileConfig::block_w + + FilterTileConfig::unroll_w * ThreadConfig::thread_x - 1, + 2) * + 2; static int const smem_size = smem_h * smem_w; static int const load_w = smem_w > ThreadConfig::nr_threads ? ThreadConfig::nr_threads : smem_w; @@ -266,9 +268,11 @@ __device__ __forceinline__ void Global2SharedMem< // one each in the lower and upper half of a tile. // Backprop input direction is the same as forward direction with the filter // rotated by 180°. 
-template <typename T, typename ConvTrait, DepthwiseConv2dDirection kDirection>
+template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
 __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
-        const Param param, const T* input, const T* filter, T* output) {
+        const Param param, const __half* input, const __half* filter, __half* output) {
+    using T = __half;
+    using T2 = __half2;
     using ThreadConfig = typename ConvTrait::ThreadConfig;
     using SrcTileConfig = typename ConvTrait::SrcTileConfig;
     using FilterTileConfig = typename ConvTrait::FilterTileConfig;
@@ -282,6 +286,10 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
     int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
         off_oh = threadIdx.y, off_ow = threadIdx.x;
 
+    const int t2_src_unroll_w = (SrcTileConfig::unroll_w + 1) / 2;
+    const int t2_flt_unroll_w = (FilterTileConfig::unroll_w + 2) / 2;
+    const int t2_out_unroll_w = (OutTileConfig::unroll_w + 1) / 2;
+
     extern __shared__ __align__(8) unsigned char smem[];
     static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
     T* smem_src = reinterpret_cast<T*>(smem);
@@ -315,10 +323,10 @@
 
     __syncthreads();
 
-    T reg_src[SrcTileConfig::unroll_h * SrcTileConfig::unroll_w],
-            reg_flt[FilterTileConfig::unroll_h * FilterTileConfig::unroll_w];
+    T2 reg_src[SrcTileConfig::unroll_h * t2_src_unroll_w],
+            reg_flt[2][FilterTileConfig::unroll_h * t2_flt_unroll_w];
 
-    T sum[OutTileConfig::unroll_size] = {0.0};
+    T2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}};
 
     for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
         gl2sh_src.copy();
@@ -326,23 +334,34 @@
 #pragma unroll
         for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
 #pragma unroll
-            for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
-                reg_src[s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
-                        [(off_oh + fh + s_h) % SrcTileCount::smem_h *
-                                 SrcTileCount::smem_w +
-                         s_w];
+            for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) {
+                int src_offset = (off_oh + fh + s_h) % SrcTileCount::smem_h *
+                                         SrcTileCount::smem_w +
+                                 s_w * 2;
+                reg_src[s_h * t2_src_unroll_w + s_w] =
+                        *reinterpret_cast<T2*>(smem_src_ptr + src_offset);
             }
         }
 
 #pragma unroll
         for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
 #pragma unroll
-            for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
-                reg_flt[f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
-                        [(fh + f_h) % FilterTileCount::smem_h *
-                                 FilterTileCount::smem_w +
-                         f_w];
+            for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) {
+                int flt_offset =
+                        (fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w +
+                        f_w * 2;
+                reg_flt[0][f_h * t2_flt_unroll_w + f_w] =
+                        *reinterpret_cast<T2*>(smem_flt_ptr + flt_offset);
+                reg_flt[1][f_h * t2_flt_unroll_w + f_w] = {
+                        f_w > 0 ? reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y
+                                : static_cast<T>(0.0),
+                        reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
             }
+            reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = {
+                    static_cast<T>(0.0), static_cast<T>(0.0)};
+            reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = {
+                    reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y,
+                    static_cast<T>(0.0)};
         }
 
 #pragma unroll
@@ -350,13 +369,14 @@
 #pragma unroll
         for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
 #pragma unroll
-            for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) {
+            for (int fw = 0; fw < t2_flt_unroll_w; ++fw) {
 #pragma unroll
                 for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
-                    sum[oh * OutTileConfig::unroll_w + ow] +=
-                            reg_flt[inner_fh * FilterTileConfig::unroll_w + fw] *
-                            reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w + fw +
-                                    ow];
+                    sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
+                            reg_flt[ow % 2][inner_fh * t2_flt_unroll_w + fw],
+                            reg_src[(inner_fh + oh) * t2_src_unroll_w + fw +
+                                    ow / 2],
+                            sum[oh * t2_out_unroll_w + ow]);
                 }
             }
         }
@@ -387,7 +407,156 @@
             if (out_w_idx >= param.out_w)
                 return;
             out_base_ptr[out_h_idx * param.out_w + out_w_idx] =
-                    sum[i * OutTileConfig::unroll_w + j];
+                    sum[i * OutTileConfig::unroll_w + j].x +
+                    sum[i * OutTileConfig::unroll_w + j].y;
+                }
+            }
+        }
+    }
+}
+
+template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
+__global__ void DepthwiseConv2dGPUKernelNCHWSmall(
+        const Param param, const float* input, const float* filter, float* output) {
+    using T = float;
+    using T2 = float2;
+    using ThreadConfig = typename ConvTrait::ThreadConfig;
+    using SrcTileConfig = typename ConvTrait::SrcTileConfig;
+    using FilterTileConfig = typename ConvTrait::FilterTileConfig;
+    using OutTileConfig = typename ConvTrait::OutTileConfig;
+    using SrcTileCount = typename ConvTrait::SrcTileCount;
+    using FilterTileCount = typename ConvTrait::FilterTileCount;
+    using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor;
+    using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor;
+    const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
+
+    int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
+        off_oh = threadIdx.y, off_ow = threadIdx.x;
+
+    const int t2_src_unroll_w = (SrcTileConfig::unroll_w + 1) / 2;
+    const int t2_flt_unroll_w = (FilterTileConfig::unroll_w + 2) / 2;
+    const int t2_out_unroll_w = (OutTileConfig::unroll_w + 1) / 2;
+
+    extern __shared__ __align__(8) unsigned char smem[];
+    static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
+    T* smem_src = reinterpret_cast<T*>(smem);
+    T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);
+
+    int off_ichannel = off_ochannel / param.chl_mul,
+        off_fchannel = off_ichannel % param.src_chl,
+        out_start_h = off_obh * OutTileConfig::block_h,
+        out_start_w = off_obw * OutTileConfig::block_w,
+        src_start_h = out_start_h - param.pad_h,
+        src_start_w = out_start_w - param.pad_w,
+        out_base_h_idx = out_start_h + off_oh * OutTileConfig::unroll_h;
+
+    T* smem_src_ptr = smem_src + off_ow * FilterTileConfig::unroll_w;
+    T* smem_flt_ptr = smem_flt + off_ow * FilterTileConfig::unroll_w;
+
+    T* out_base_ptr = output + off_ochannel * param.out_h * param.out_w;
+
+    SrcGlobal2ShareVisitor gl2sh_src(
+            smem_src, param.src_w, src_start_h, src_start_w, param.src_h, param.src_w);
+
+    FilterGlobal2ShareVisitor gl2sh_flt = {
+            smem_flt, param.flt_w, is_fwd ? 0 : param.flt_h - 2,
+            0,        param.flt_h, param.flt_w};
+
+    gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w;
+    gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w;
+
+    gl2sh_src.first_copy();
+    gl2sh_flt.first_copy();
+
+    __syncthreads();
+
+    T2 reg_src[SrcTileConfig::unroll_h * t2_src_unroll_w],
+            reg_flt[2][FilterTileConfig::unroll_h * t2_flt_unroll_w];
+
+    T2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}};
+
+    for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
+        gl2sh_src.copy();
+        gl2sh_flt.copy();
+#pragma unroll
+        for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
+#pragma unroll
+            for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) {
+                int src_offset = (off_oh + fh + s_h) % SrcTileCount::smem_h *
+                                         SrcTileCount::smem_w +
+                                 s_w * 2;
+                reg_src[s_h * t2_src_unroll_w + s_w] =
+                        *reinterpret_cast<T2*>(smem_src_ptr + src_offset);
+            }
+        }
+
+#pragma unroll
+        for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
+#pragma unroll
+            for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) {
+                int flt_offset =
+                        (fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w +
+                        f_w * 2;
+                reg_flt[0][f_h * t2_flt_unroll_w + f_w] =
+                        *reinterpret_cast<T2*>(smem_flt_ptr + flt_offset);
+                reg_flt[1][f_h * t2_flt_unroll_w + f_w] = {
+                        f_w > 0 ? reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y
+                                : static_cast<T>(0.0),
+                        reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
+            }
+            reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = {
+                    static_cast<T>(0.0), static_cast<T>(0.0)};
+            reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = {
+                    reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y,
+                    static_cast<T>(0.0)};
+        }
+
+#pragma unroll
+        for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
+#pragma unroll
+            for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
+#pragma unroll
+                for (int fw = 0; fw < t2_flt_unroll_w; ++fw) {
+#pragma unroll
+                    for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
+                        sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
+                                reg_flt[ow % 2][inner_fh * t2_flt_unroll_w + fw],
+                                reg_src[(inner_fh + oh) * t2_src_unroll_w + fw +
+                                        ow / 2],
+                                sum[oh * t2_out_unroll_w + ow]);
+                    }
+                }
+            }
+        }
+
+        __syncthreads();
+        gl2sh_src.commit();
+        gl2sh_flt.commit();
+        gl2sh_src.iter_forward();
+        gl2sh_flt.iter_forward();
+        __syncthreads();
+    }
+
+    for (int o = 0; o < OutTileConfig::unroll_size; ++o) {
+        for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) {
+            sum[o].x += __shfl_xor(sum[o].x, i, 32);
+            sum[o].y += __shfl_xor(sum[o].y, i, 32);
+        }
+    }
+
+    if (threadIdx.x == 0) {
+#pragma unroll
+        for (int i = 0; i < OutTileConfig::unroll_h; ++i) {
+            int out_h_idx = out_base_h_idx + i;
+            if (out_h_idx < param.out_h) {
+#pragma unroll
+                for (int j = 0; j < OutTileConfig::unroll_w; ++j) {
+                    int out_w_idx = out_start_w + j;
+                    if (out_w_idx >= param.out_w)
+                        return;
+                    out_base_ptr[out_h_idx * param.out_w + out_w_idx] =
+                            sum[i * OutTileConfig::unroll_w + j].x +
+                            sum[i * OutTileConfig::unroll_w + j].y;
+                }
+            }
+        }
+    }
+}
@@ -419,28 +588,27 @@
             (SrcTileCount::smem_size + FilterTileCount::smem_size) * sizeof(T);
 
     void (*kernel)(const Param, const T*, const T*, T*);
-    kernel = DepthwiseConv2dGPUKernelNCHWSmall<T, ConvTrait, kDirection>;
+    kernel = DepthwiseConv2dGPUKernelNCHWSmall<ConvTrait, kDirection>;
     kernel<<<grid, block, shared_storage, stream>>>(param, input, filter, output);
     after_kernel_launch();
 }
 
-#define INSTANCE_AB(a, b, direction)                                          \
-    if (param.out_w > b * 4) {                                                \
-        LaunchDepthwiseConv2dGPUSmall<float, direction, a + 2, b + 1>(        \
-                param, src, flt, dst, stream);                                \
-    }
+#define INSTANCE_AB(type1, type2, a, b, direction)                            \
+    if (param.out_w > b * 4) {                                                \
+        LaunchDepthwiseConv2dGPUSmall<type1, type2, direction, a + 2, b + 1>( \
+                param, src, flt, dst, stream);                                \
+    }
 
-#define INSTANCE_A(a, direction)                                              \
-    if (param.flt_w > 0) {                                                    \
-        INSTANCE_AB(a, 15, direction)                                         \
-        else INSTANCE_AB(a, 14, direction) else INSTANCE_AB(a, 13, direction) else INSTANCE_AB( \
-                a, 12, direction) else INSTANCE_AB(a, 11, direction) else INSTANCE_AB(a, 10, direction) else INSTANCE_AB(a, 9, direction) else INSTANCE_AB(a, 8, direction) else INSTANCE_AB(a, 7, direction) else INSTANCE_AB(a, 6, direction) else INSTANCE_AB(a, 5, direction) else INSTANCE_AB(a, 4, direction) else INSTANCE_AB(a, 3, direction) else INSTANCE_AB(a, 2, direction) else INSTANCE_AB(a, 1, direction) else INSTANCE_AB(a, 0, direction) \
-    }
+#define INSTANCE_A(type1, type2, a, direction)                                \
+    if (param.flt_w > a * 4) {                                                \
+        INSTANCE_AB(type1, type2, a, 15, direction)                           \
+        else INSTANCE_AB(type1, type2, a, 14, direction) else INSTANCE_AB(type1, type2, a, 13, direction) else INSTANCE_AB(type1, type2, a, 12, direction) else INSTANCE_AB(type1, type2, a, 11, direction) else INSTANCE_AB(type1, type2, a, 10, direction) else INSTANCE_AB( \
+                type1, type2,                                                 \
+                a, 9, direction) else INSTANCE_AB(type1, type2, a, 8, direction) else INSTANCE_AB(type1, type2, a, 7, direction) else INSTANCE_AB(type1, type2, a, 6, direction) else INSTANCE_AB(type1, type2, a, 5, direction) else INSTANCE_AB(type1, type2, a, 4, direction) else INSTANCE_AB(type1, type2, a, 3, direction) else INSTANCE_AB(type1, type2, a, 2, direction) else INSTANCE_AB(type1, type2, a, 1, direction) else INSTANCE_AB(type1, type2, a, 0, direction) \
+    }
 
-#define INSTANCE(direction)                                                   \
-    INSTANCE_A(7, direction)                                                  \
-    else INSTANCE_A(6, direction) else INSTANCE_A(5, direction) else INSTANCE_A(4, direction) else INSTANCE_A( \
-            3,                                                                \
-            direction) else INSTANCE_A(2, direction) else INSTANCE_A(1, direction) else INSTANCE_A(0, direction)
-
+#define INSTANCE(type1, type2, direction)                                     \
+    INSTANCE_A(type1, type2, 6, direction)                                    \
+    else INSTANCE_A(type1, type2, 4, direction) else INSTANCE_A(              \
+            type1, type2, 2, direction) else INSTANCE_A(type1, type2, 0, direction)
 }  // anonymous namespace
diff --git a/dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu b/dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu
index 35c142ee9..debe3910b 100644
--- a/dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu
+++ b/dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu
@@ -37,9 +37,18 @@ template <>
 void run_fwd_depthwise_large_filter(
         float* dst, const float* src, const float* flt, const Param& param,
         cudaStream_t stream) {
-    INSTANCE(DepthwiseConv2dDirection::DIRECTION_FORWARD)
+    INSTANCE(float, float2, DepthwiseConv2dDirection::DIRECTION_FORWARD)
 }
 
+#if CUDA_VERSION >= 9000
+template <>
+void run_fwd_depthwise_large_filter(
+        __half* dst, const __half* src, const __half* flt, const Param& param,
+        cudaStream_t stream) {
+    INSTANCE(__half, __half2, DepthwiseConv2dDirection::DIRECTION_FORWARD)
+}
+#endif
+
 }  // namespace chanwise
 }  // namespace conv_bias
 }  // namespace cuda
diff --git a/dnn/src/cuda/conv_bias/depthwise_large_filter.cpp b/dnn/src/cuda/conv_bias/depthwise_large_filter.cpp
index 766254b8a..5c14d2a8f 100644
--- a/dnn/src/cuda/conv_bias/depthwise_large_filter.cpp
+++ b/dnn/src/cuda/conv_bias/depthwise_large_filter.cpp
@@ -50,7 +50,11 @@ bool ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::is_available(
         return false;
     }
     if (args.src_layout->dtype != args.filter_layout->dtype &&
-        args.src_layout->dtype != dtype::Float32()) {
+        (args.src_layout->dtype != dtype::Float32()
+#if CUDA_VERSION >= 9000
+         && args.src_layout->dtype != dtype::Float16()
+#endif
+                 )) {
         return false;
     }
     if (args.z_layout->ndim > 0)
@@ -97,6 +101,15 @@ void ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::exec(const ExecArgs& args) c
                     conv_dst_tensor.ptr<float>(), args.src_tensor->ptr<float>(),
                     args.filter_tensor->ptr<float>(), kparam, stream);
             break;
+#if CUDA_VERSION >= 9000
+        case DTypeEnum::Float16:
+            chanwise::run_fwd_depthwise_large_filter(
+                    static_cast<half*>(conv_dst_tensor.raw_ptr()),
+                    static_cast<half*>(args.src_tensor->raw_ptr()),
+                    static_cast<half*>(args.filter_tensor->raw_ptr()), kparam,
+                    stream);
+            break;
+#endif
         default:
             megdnn_assert_internal(0);
     }
diff --git a/dnn/src/cuda/convolution/backward_data/depthwise_large_filter.cpp b/dnn/src/cuda/convolution/backward_data/depthwise_large_filter.cpp
index f9a6f998d..1f24af62c 100644
--- a/dnn/src/cuda/convolution/backward_data/depthwise_large_filter.cpp
+++ b/dnn/src/cuda/convolution/backward_data/depthwise_large_filter.cpp
@@ -49,7 +49,11 @@ bool ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter::is_available(
         return false;
     }
     if (args.diff_layout->dtype != args.filter_layout->dtype &&
-        args.diff_layout->dtype != dtype::Float32()) {
+        (args.diff_layout->dtype != dtype::Float32()
+#if CUDA_VERSION >= 9000
+         && args.diff_layout->dtype != dtype::Float16()
+#endif
+                 )) {
         return false;
     }
 
@@ -78,6 +82,14 @@ void ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter::exec(
                     args.grad_tensor->ptr<float>(), args.diff_tensor->ptr<float>(),
                     args.filter_tensor->ptr<float>(), kparam, stream);
             break;
+#if CUDA_VERSION >= 9000
+        case DTypeEnum::Float16:
+            chanwise::run_bwd_depthwise_large_filter(
+                    static_cast<half*>(args.grad_tensor->raw_ptr()),
+                    static_cast<half*>(args.diff_tensor->raw_ptr()),
+                    static_cast<half*>(args.filter_tensor->raw_ptr()), kparam, stream);
+            break;
+#endif
         default:
             megdnn_assert_internal(0);
     }
diff --git a/dnn/src/cuda/convolution/chanwise/bwd_large_filter.cu b/dnn/src/cuda/convolution/chanwise/bwd_large_filter.cu
index ff2ad37ef..acbb7b9a8 100644
--- a/dnn/src/cuda/convolution/chanwise/bwd_large_filter.cu
+++ b/dnn/src/cuda/convolution/chanwise/bwd_large_filter.cu
@@ -34,9 +34,18 @@ template <>
 void run_bwd_depthwise_large_filter(
         float* dst, const float* src, const float* flt, const Param& param,
         cudaStream_t stream) {
-    INSTANCE(DepthwiseConv2dDirection::DIRECTION_BACKWARD)
+    INSTANCE(float, float2, DepthwiseConv2dDirection::DIRECTION_BACKWARD)
 }
 
+#if CUDA_VERSION >= 9000
+template <>
+void run_bwd_depthwise_large_filter(
+        __half* dst, const __half* src, const __half* flt, const Param& param,
+        cudaStream_t stream) {
+    INSTANCE(__half, __half2, DepthwiseConv2dDirection::DIRECTION_BACKWARD)
+}
+#endif
+
 }  // namespace chanwise
 }  // namespace convolution
 }  // namespace cuda
diff --git a/dnn/test/cuda/conv_bias.cpp b/dnn/test/cuda/conv_bias.cpp
index eceecc5e0..1c6dc08ab 100644
--- a/dnn/test/cuda/conv_bias.cpp
+++ b/dnn/test/cuda/conv_bias.cpp
@@ -701,51 +701,53 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
             ConvBiasForward::algo_name<ConvBias::DirectParam>(
                     "DEPTHWISE_LARGE_FILTER", {})
                     .c_str()));
-    auto run = [&checker](size_t n, size_t g, size_t h, size_t fh) {
-        param::ConvBias cur_param;
-        cur_param.mode = param::ConvBias::Mode::CROSS_CORRELATION;
-        cur_param.sparse = ConvBias::Param::Sparse::GROUP;
-        checker.set_dtype(0, dtype::Float32())
-                .set_dtype(1, dtype::Float32())
-                .set_dtype(2, dtype::Float32())
-                .set_dtype(3, dtype::Float32())
-                .set_dtype(4, dtype::Float32());
+    for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
+        auto run = [&checker, &dtype](size_t n, size_t g, size_t h, size_t fh) {
+            param::ConvBias cur_param;
+            cur_param.mode = param::ConvBias::Mode::CROSS_CORRELATION;
+            cur_param.sparse = ConvBias::Param::Sparse::GROUP;
+            checker.set_dtype(0, dtype)
+                    .set_dtype(1, dtype)
+                    .set_dtype(2, dtype)
+                    .set_dtype(3, dtype)
+                    .set_dtype(4, dtype);
 
-        cur_param.pad_h = cur_param.pad_w = fh / 2;
-        cur_param.stride_h = cur_param.stride_w = 1;
-        checker.set_param(cur_param).execs(
-                {{n, g, h, h}, {g, 1, 1, fh, fh}, {}, {}, {}});
-    };
-    run(4, 8, 32, 5);
-    run(4, 8, 32, 7);
-    run(4, 8, 32, 9);
-    run(4, 8, 32, 11);
-    run(4, 8, 32, 13);
-    run(4, 8, 32, 15);
-    run(4, 8, 32, 17);
-    run(4, 8, 32, 19);
-    run(4, 8, 32, 21);
-    run(4, 8, 32, 23);
-    run(4, 8, 32, 25);
-    run(4, 8, 32, 27);
-    run(4, 8, 32, 29);
-    run(4, 8, 32, 31);
-    run(4, 8, 64, 5);
-    run(4, 8, 64, 7);
-    run(4, 8, 64, 9);
-    run(4, 8, 64, 11);
-    run(4, 8, 64, 13);
-    run(4, 8, 64, 15);
-    run(4, 8, 64, 17);
-    run(4, 8, 64, 19);
-    run(4, 8, 64, 21);
-    run(4, 8, 64, 23);
-    run(4, 8, 64, 25);
-    run(4, 8, 64, 27);
-    run(4, 8, 64, 29);
-    run(4, 8, 64, 31);
-    run(1, 2, 128, 31);
-    run(1, 2, 256, 31);
+            cur_param.pad_h = cur_param.pad_w = fh / 2;
+            cur_param.stride_h = cur_param.stride_w = 1;
+            checker.set_param(cur_param).execs(
+                    {{n, g, h, h}, {g, 1, 1, fh, fh}, {}, {}, {}});
+        };
+        run(4, 8, 32, 5);
+        run(4, 8, 32, 7);
+        run(4, 8, 32, 9);
+        run(4, 8, 32, 11);
+        run(4, 8, 32, 13);
+        run(4, 8, 32, 15);
+        run(4, 8, 32, 17);
+        run(4, 8, 32, 19);
+        run(4, 8, 32, 21);
+        run(4, 8, 32, 23);
+        run(4, 8, 32, 25);
+        run(4, 8, 32, 27);
+        run(4, 8, 32, 29);
+        run(4, 8, 32, 31);
+        run(4, 8, 64, 5);
+        run(4, 8, 64, 7);
+        run(4, 8, 64, 9);
+        run(4, 8, 64, 11);
+        run(4, 8, 64, 13);
+        run(4, 8, 64, 15);
+        run(4, 8, 64, 17);
+        run(4, 8, 64, 19);
+        run(4, 8, 64, 21);
+        run(4, 8, 64, 23);
+        run(4, 8, 64, 25);
+        run(4, 8, 64, 27);
+        run(4, 8, 64, 29);
+        run(4, 8, 64, 31);
+        run(1, 2, 128, 31);
+        run(1, 2, 256, 31);
+    }
 }
 
 TEST_F(CUDA, CONV_BIAS_FORWARD_CHANWISE_8x8x32) {
@@ -1550,11 +1552,81 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
         param.stride_h = sh;
         param.stride_w = sw;
 
+        bencher.set_times(nr_times);
+        size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
+        size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w);
+        TensorShape inp{batch, g, hi, wi}, kern{g, 1, 1, fh, fw}, out{batch, g, ho, wo};
+
+        float bandwidth = static_cast<float>(
+                                  inp.total_nr_elems() + kern.total_nr_elems() +
+                                  out.total_nr_elems()) /
+                          (1024 * 1024 * 1024) * 1e3;
+
         bencher.set_param(param)
                 .set_dtype(0, dtype::Float32())
                 .set_dtype(1, dtype::Float32())
                 .set_dtype(2, dtype::Float32())
                 .set_dtype(4, dtype::Float32());
+        auto fp32_time_in_ms = bencher.execs({inp, kern, {}, {}, out}) / nr_times;
+        bencher.set_param(param)
+                .set_dtype(0, dtype::Float16())
+                .set_dtype(1, dtype::Float16())
+                .set_dtype(2, dtype::Float16())
+                .set_dtype(4, dtype::Float16());
+        auto fp16_time_in_ms = bencher.execs({inp, kern, {}, {}, out}) / nr_times;
+        printf("chanwise_depthwise_large_filter: inp=%s, kern=%s, out=%s, fp32_time: "
+               "%.2fms, fp16_time: %.2fms, speedup: %0.2f (fp16/fp32) "
+               "fp32_bandwidth: %.2fGB/s fp16_bandwidth: %.2fGB/s.\n",
+               inp.to_string().c_str(), kern.to_string().c_str(),
+               out.to_string().c_str(), fp32_time_in_ms, fp16_time_in_ms,
+               fp32_time_in_ms / fp16_time_in_ms, bandwidth * 4 / fp32_time_in_ms,
+               bandwidth * 2 / fp16_time_in_ms);
+    };
+
+    run_bench(64, 384, 32, 32, 3, 3, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 5, 5, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 7, 7, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 9, 9, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 11, 11, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 13, 13, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 15, 15, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 17, 17, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 19, 19, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 21, 21, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 23, 23, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 25, 25, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 27, 27, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 29, 29, 1, 1, 10);
+    run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10);
+}
+
+TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER_FP16) {
+    require_compute_capability(7, 5);
+    Benchmarker<ConvBiasForward> bencher(handle_cuda());
+    bencher.set_display(false);
+    bencher.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
+            ConvBiasForward::algo_name<ConvBias::DirectParam>(
+                    "DEPTHWISE_LARGE_FILTER", {})
+                    .c_str()));
+
+    ConvBias::Param param;
+    param.format = ConvBias::Param::Format::NCHW;
+
+    using NonlineMode = ConvBias::Param::NonlineMode;
+    param.nonlineMode = NonlineMode::IDENTITY;
+    param.sparse = ConvBias::Param::Sparse::GROUP;
+    auto run_bench = [&](size_t batch, size_t g, size_t hi, size_t wi, size_t fh,
+                         size_t fw, size_t sh, size_t sw, size_t nr_times) {
+        param.pad_h = fh / 2;
+        param.pad_w = fw / 2;
+        param.stride_h = sh;
+        param.stride_w = sw;
+
+        bencher.set_param(param)
+                .set_dtype(0, dtype::Float16())
+                .set_dtype(1, dtype::Float16())
+                .set_dtype(2, dtype::Float16())
+                .set_dtype(4, dtype::Float16());
         bencher.set_times(nr_times);
         size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
         size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w);
diff --git a/dnn/test/cuda/convolution.cpp b/dnn/test/cuda/convolution.cpp
index 88cf77dcc..b29d6ef32 100644
--- a/dnn/test/cuda/convolution.cpp
+++ b/dnn/test/cuda/convolution.cpp
@@ -728,7 +728,7 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DEPTHWISE_LARGE_FILTER) {
     Checker<ConvolutionBackwardData> checker(handle_cuda());
     checker.set_before_exec_callback(
             AlgoChecker<ConvolutionBackwardData>("DEPTHWISE_LARGE_FILTER"));
-    for (auto dtype : std::vector<DType>{dtype::Float32()}) {
+    for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
         auto run = [&checker, &dtype](size_t n, size_t g, size_t h, size_t fh) {
             param::Convolution param;
             param.stride_h = param.stride_w = 1;
@@ -999,6 +999,55 @@ TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_DEPTHWISE_LARGE_FILTER) {
     run(64, 384, 384, 32, 32, 31, 1, 10);
 }
 
+TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_DEPTHWISE_LARGE_FILTER_FP16) {
+    CUBenchmarker<ConvolutionBackwardData> bencher{handle_cuda()};
+    bencher.set_display(false);
+    bencher.set_before_exec_callback(
+            AlgoChecker<ConvolutionBackwardData>("DEPTHWISE_LARGE_FILTER"));
+
+    auto run = [&](size_t N, size_t OC, size_t g, size_t IH, size_t IW, size_t FH,
+                   size_t SH, size_t nr_times) {
+        bencher.set_dtype(0, dtype::Float16())
+                .set_dtype(1, dtype::Float16())
+                .set_dtype(2, dtype::Float16());
+        param::Convolution param;
+        param.stride_h = param.stride_w = SH;
+        param.pad_h = param.pad_w = FH / 2;
+        param.sparse = param::Convolution::Sparse::GROUP;
+        bencher.set_param(param);
+        bencher.set_times(nr_times);
+        TensorLayout src{{N, g, IH, IW}, dtype::Float16()},
+                filter{{g, 1, 1, FH, FH}, dtype::Float16()};
+        TensorLayout dst;
+        {
+            auto&& opr = handle_cuda()->create_operator<ConvolutionForward>();
+            opr->param() = param;
+            opr->deduce_layout(src, filter, dst);
+        }
+        auto time_ms_fp16 = bencher.execl({filter, dst, src}) / nr_times;
+        float flo = 2.0 * N * g * dst[2] * dst[3] * FH * FH;
+        printf("inp=%s, kern=%s, dst=%s ", src.to_string().c_str(),
+               filter.to_string().c_str(), dst.to_string().c_str());
+        printf("time_fp16=%.2fms, flops=%.3fTFLOPS\n", time_ms_fp16,
+               (flo / (time_ms_fp16 * 1e9)));
+    };
+    run(64, 384, 384, 32, 32, 3, 1, 10);
+    run(64, 384, 384, 32, 32, 5, 1, 10);
+    run(64, 384, 384, 32, 32, 7, 1, 10);
+    run(64, 384, 384, 32, 32, 9, 1, 10);
+    run(64, 384, 384, 32, 32, 11, 1, 10);
+    run(64, 384, 384, 32, 32, 13, 1, 10);
+    run(64, 384, 384, 32, 32, 15, 1, 10);
+    run(64, 384, 384, 32, 32, 17, 1, 10);
+    run(64, 384, 384, 32, 32, 19, 1, 10);
+    run(64, 384, 384, 32, 32, 21, 1, 10);
+    run(64, 384, 384, 32, 32, 23, 1, 10);
+    run(64, 384, 384, 32, 32, 25, 1, 10);
+    run(64, 384, 384, 32, 32, 27, 1, 10);
+    run(64, 384, 384, 32, 32, 29, 1, 10);
+    run(64, 384, 384, 32, 32, 31, 1, 10);
+}
+
 TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_BF16) {
     CUBenchmarker<ConvolutionBackwardData> bench{handle_cuda()};
     std::unique_ptr<OprProxy<ConvolutionBackwardData>> proxy{
-- 
GitLab
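
Review note (not part of the patch): the core of this change is re-vectorizing the fp16 inner loop as __half2 pairs. Because an output column needs filter taps starting at either parity, the kernel keeps two register copies of each filter row, reg_flt[0] = {f0,f1}{f2,f3}... and reg_flt[1] = {0,f0}{f1,f2}..., so even and odd columns both see pair-aligned operands, and smem_w is rounded up to an even value (the DIVUP(..., 2) * 2 in the first hunk) so the paired loads stay in bounds. The standalone sketch below illustrates the same trick on a single row. It is not MegEngine code: the name row_conv_half2 and all shapes are invented for this example, scalar loads replace the kernel's aligned shared-memory loads, and it assumes nvcc with an sm_53+ GPU (e.g. nvcc -arch=sm_70 half2_rowconv.cu).

    #include <cuda_fp16.h>
    #include <cuda_runtime.h>
    #include <cmath>
    #include <cstdio>

    #define FLT_W 5                     // odd filter width, as in the tests above
    #define T2_FLT_W ((FLT_W + 2) / 2)  // zero-padded pair count, cf. t2_flt_unroll_w

    // One block computes out[ow] = sum_fw flt[fw] * src[ow + fw] with __half2 FMAs.
    __global__ void row_conv_half2(
            const __half* src, const __half* flt, __half* out, int out_w) {
        // Two pre-shifted copies of the filter, like reg_flt[0]/reg_flt[1] above:
        // copy 0 packs {f0,f1}{f2,f3}..., copy 1 packs {0,f0}{f1,f2}..., so both
        // output parities get pair-aligned operands; tail pairs are zero-padded.
        __half2 reg_flt[2][T2_FLT_W];
        for (int i = 0; i < T2_FLT_W; ++i) {
            __half lo = (2 * i < FLT_W) ? flt[2 * i] : __float2half(0.f);
            __half hi = (2 * i + 1 < FLT_W) ? flt[2 * i + 1] : __float2half(0.f);
            reg_flt[0][i] = __halves2half2(lo, hi);
            __half prev = (i > 0) ? __high2half(reg_flt[0][i - 1]) : __float2half(0.f);
            reg_flt[1][i] = __halves2half2(prev, lo);
        }
        for (int ow = threadIdx.x; ow < out_w; ow += blockDim.x) {
            __half2 sum = __float2half2_rn(0.f);
            // Odd columns start one element earlier so the shifted copy lines up.
            const __half* s = src + ow - (ow & 1);
            for (int i = 0; i < T2_FLT_W; ++i) {
                __half2 v = __halves2half2(s[2 * i], s[2 * i + 1]);
                sum = __hfma2(reg_flt[ow & 1][i], v, sum);  // two MACs per FMA
            }
            out[ow] = __hadd(__low2half(sum), __high2half(sum));  // .x + .y
        }
    }

    int main() {
        const int out_w = 32;
        const int src_w = out_w + 2 * T2_FLT_W;  // slack for the zero-padded pairs
        __half *src, *flt, *out;
        cudaMallocManaged(&src, src_w * sizeof(__half));
        cudaMallocManaged(&flt, FLT_W * sizeof(__half));
        cudaMallocManaged(&out, out_w * sizeof(__half));
        for (int i = 0; i < src_w; ++i) src[i] = __float2half(0.01f * i);
        for (int i = 0; i < FLT_W; ++i) flt[i] = __float2half(0.1f * (i + 1));
        row_conv_half2<<<1, 32>>>(src, flt, out, out_w);
        cudaDeviceSynchronize();
        for (int ow = 0; ow < out_w; ++ow) {  // scalar fp32 reference check
            float ref = 0.f;
            for (int fw = 0; fw < FLT_W; ++fw)
                ref += __half2float(flt[fw]) * __half2float(src[ow + fw]);
            if (fabsf(ref - __half2float(out[ow])) > 1e-2f) {
                printf("mismatch at %d: %f vs %f\n", ow, ref, __half2float(out[ow]));
                return 1;
            }
        }
        printf("half2 row conv matches fp32 reference\n");
        return 0;
    }

The zero padding here plays the role of the even smem_w rounding in the patch, and the final .x + .y reduction mirrors the sum[...].x + sum[...].y stores in both kernels; the coarser stepping of INSTANCE_A (a = 6, 4, 2, 0 instead of 7..0) likewise appears to keep the per-step tap count pair-friendly for the __half2 path.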