diff --git a/dnn/src/cuda/conv_bias/algo.cpp b/dnn/src/cuda/conv_bias/algo.cpp index 2cc0c327e85b2ca08558ff63d837646cb9fdca5c..f1e139d6967d7a438102aa642a0fa45b3d7f105f 100644 --- a/dnn/src/cuda/conv_bias/algo.cpp +++ b/dnn/src/cuda/conv_bias/algo.cpp @@ -19,6 +19,7 @@ using namespace cuda; ConvBiasForwardImpl::AlgoPack::AlgoPack() { non_cudnn_algos.push_back(&chanwise); non_cudnn_algos.push_back(&chanwise_small); + non_cudnn_algos.push_back(&depthwise_large_filter); non_cudnn_algos.push_back(&inplace_matmul); non_cudnn_algos.push_back(&matmul); @@ -34,6 +35,7 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() { std::vector conv_algos; conv_algos.push_back(&chanwise); conv_algos.push_back(&chanwise_small); + conv_algos.push_back(&depthwise_large_filter); conv_algos.push_back(&chanwise8x8x32); for (auto&& algo : cudnn_convs) { conv_algos.push_back(&algo); diff --git a/dnn/src/cuda/conv_bias/algo.h b/dnn/src/cuda/conv_bias/algo.h index ea0e60531f87413af0574a658937dbd9cb8be798..18daa829d5967e55e618e091fd81549acdde16b4 100644 --- a/dnn/src/cuda/conv_bias/algo.h +++ b/dnn/src/cuda/conv_bias/algo.h @@ -22,7 +22,6 @@ #include "src/cuda/conv_bias/opr_impl.h" #include "src/cuda/convolution_helper/parameter.cuh" #include "src/cuda/cudnn_wrapper.h" -#include "src/cuda/handle.h" #include #include @@ -57,6 +56,7 @@ public: CUDA_CUDNN_CONVBIAS, CUDA_CHANWISE, CUDA_CHANWISE_SMALL, + CUDA_DEPTHWISE_LARGE_FILTER, CUDA_CHANWISE_INT8X8X32, CUDA_CUDNN_CONV, CUDA_INPLACE_MATMUL, @@ -257,6 +257,26 @@ private: mutable std::string m_name; }; +class ConvBiasForwardImpl::AlgoDepthwiseLargeFilter final : public AlgoBase { +public: + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; + + const char* name() const override { + if (m_name.empty()) { + m_name = ConvBiasForward::algo_name( + "DEPTHWISE_LARGE_FILTER", {}); + } + return m_name.c_str(); + } + MEGDNN_DECL_ALGO_TYPE(CUDA_DEPTHWISE_LARGE_FILTER) + AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } + +private: + mutable std::string m_name; +}; + class ConvBiasForwardImpl::AlgoChanwise8x8x32 final : public AlgoBase { public: bool is_available(const SizeArgs& args) const override; @@ -1084,6 +1104,7 @@ public: AlgoFallbackNCHWQS8 fallback_nchw_qs8; AlgoChanwise chanwise; AlgoChanwiseSmall chanwise_small; + AlgoDepthwiseLargeFilter depthwise_large_filter; AlgoChanwise8x8x32 chanwise8x8x32; AlgoInplaceMatmul inplace_matmul; AlgoMatmul matmul; diff --git a/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl b/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl new file mode 100644 index 0000000000000000000000000000000000000000..4eb54607c5c46041dece33878dfe2be563a0d655 --- /dev/null +++ b/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl @@ -0,0 +1,446 @@ +/** + * \file dnn/src/cuda/conv_bias/chanwise/fwd_depthwise_large_filter.inl + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "src/cuda/cuda_shfl_compat.cuh" +namespace { + +enum DepthwiseConv2dDirection { DIRECTION_FORWARD, DIRECTION_BACKWARD }; + +template +struct OutTileConfig { + using ThreadConfig = ThreadConfig_; + static int const unroll_h = oh_; + static int const unroll_w = ThreadConfig::thread_x * ow_; + static int const unroll_size = unroll_h * unroll_w; + static int const block_h = unroll_h * ThreadConfig::thread_y; + static int const block_w = unroll_w; +}; + +template +struct FilterTileConfig { + static int const unroll_h = fh_; + static int const unroll_w = fw_; + static int const unroll_size = unroll_h * unroll_w; +}; + +template +struct ThreadConfig { + static int const thread_x = x_; + static_assert((thread_x & (thread_x - 1)) == 0, "thread_x must be pow of 2!"); + static int const thread_y = y_; + static int const nr_threads = x_ * y_; +}; + +template < + typename T, DepthwiseConv2dDirection kDirection, typename ThreadConfig_, + typename TileCount_> +struct Global2SharedMem { + using TileCount = TileCount_; + using ThreadConfig = ThreadConfig_; + T reg[TileCount::reg_w]; + const int tidx = threadIdx.x; + const int tidy = threadIdx.y; + const int tid = tidy * ThreadConfig::thread_x + tidx; + const int gl_load_y = tid / TileCount::load_w; + const int gl_load_x = tid - gl_load_y * TileCount::load_w; + const bool is_fwd = (kDirection == DIRECTION_FORWARD); + int w_offset; + + T* smem; + int stride; + int start_h, start_w, bound_h, bound_w, ring_smem_h, ring_src_h; + const T* g_ptr; + + __device__ __forceinline__ + Global2SharedMem(T* smem_, int stride_, int s_h, int s_w, int b_h, int b_w); + + __device__ __forceinline__ void first_copy(); + __device__ __forceinline__ void copy(); + __device__ __forceinline__ void commit(); + __device__ __forceinline__ void iter_forward(); + __device__ __forceinline__ T* sh_ptr(int y, int x) { + return &smem[y * TileCount::smem_w + x]; + } + + __device__ __forceinline__ T* sh_ptr_as_copy_t(int y, int x) { + return reinterpret_cast(sh_ptr(y, x)); + } +}; + +template < + typename ldg_dtype, DepthwiseConv2dDirection kDirection, typename ThreadConfig_, + typename OutTileConfig_, typename FilterTileConfig_> +struct ConvTrait { + using ThreadConfig = ThreadConfig_; + using OutTileConfig = OutTileConfig_; + using FilterTileConfig = FilterTileConfig_; + using CompType = ldg_dtype; + + struct SrcTileConfig { + static int const unroll_h = + OutTileConfig::unroll_h + FilterTileConfig::unroll_h - 1; + static int const unroll_w = + OutTileConfig::unroll_w + FilterTileConfig::unroll_w - 1; + static int const unroll_size = unroll_h * unroll_w; + }; + + struct SrcTileCount { + static int const smem_src_h = + OutTileConfig::block_h + FilterTileConfig::unroll_h - 1; + static int const smem_buff_h = FilterTileConfig::unroll_h; + static int const smem_load_h = smem_src_h + smem_buff_h; + static int const smem_h = smem_load_h + smem_buff_h; + static int const smem_w = OutTileConfig::block_w + + FilterTileConfig::unroll_w * ThreadConfig::thread_x - + 1; + static int const smem_size = smem_h * smem_w; + static int const load_w = + smem_w > ThreadConfig::nr_threads ? ThreadConfig::nr_threads : smem_w; + static int const load_h = 1; + static int const reg_h = 1; + static int const reg_w = DIVUP(smem_w, load_w); + static bool constexpr check_bounds_h = smem_h % load_h != 0; + static bool constexpr check_bounds_w = smem_w % load_w != 0; + }; + + struct FilterTileCount { + static int const smem_flt_h = FilterTileConfig::unroll_h; + static int const smem_buff_h = FilterTileConfig::unroll_h; + static int const smem_load_h = smem_flt_h + smem_buff_h; + static int const smem_h = smem_load_h + smem_buff_h; + static int const smem_w = FilterTileConfig::unroll_w * ThreadConfig::thread_x; + static int const smem_size = smem_h * smem_w; + static int const load_w = smem_w > 32 ? 32 : smem_w; + static int const load_h = ThreadConfig::nr_threads / load_w; + static int const reg_h = 1; + static int const reg_w = DIVUP(smem_w, load_w); + static bool constexpr check_bounds_h = smem_h % load_h != 0; + static bool constexpr check_bounds_w = smem_w % load_w != 0; + }; + + using SrcGlobal2ShareVisitor = Global2SharedMem< + CompType, DepthwiseConv2dDirection::DIRECTION_FORWARD, ThreadConfig, + SrcTileCount>; + using FilterGlobal2ShareVisitor = + Global2SharedMem; +}; + +template < + typename T, DepthwiseConv2dDirection kDirection, typename ThreadConfig_, + typename TileCount_> +__device__ __forceinline__ +Global2SharedMem::Global2SharedMem( + T* smem_, int stride_, int s_h, int s_w, int b_h, int b_w) + : smem(smem_), + stride(stride_), + start_h(s_h), + start_w(s_w), + bound_h(b_h), + bound_w(b_w), + ring_smem_h(TileCount::smem_load_h) { + if (is_fwd) { + ring_src_h = s_h + TileCount::smem_load_h; + w_offset = 0; + } else { + ring_src_h = s_h - 1; + w_offset = TileCount::smem_w - b_w; + } +} + +template < + typename T, DepthwiseConv2dDirection kDirection, typename ThreadConfig_, + typename TileCount_> +__device__ __forceinline__ void Global2SharedMem< + T, kDirection, ThreadConfig_, TileCount_>::first_copy() { + static int const load_w = TileCount::smem_w > 32 ? 32 : TileCount::smem_w; + static int const load_h = ThreadConfig::nr_threads / load_w; + static int const h_per_thread = DIVUP(TileCount::smem_load_h, load_h); + static int const w_per_thread = DIVUP(TileCount::smem_w, load_w); + static bool constexpr check_bounds_h = TileCount::smem_load_h % load_h != 0; + static bool constexpr check_bounds_w = TileCount::smem_w % load_w != 0; + const int y_base_idx = tid / load_w; + const int x_base_idx = tid - y_base_idx * load_w; +#pragma unroll + for (int i = 0; i < h_per_thread; ++i) { + int smem_h_idx = y_base_idx + i * load_h; + int src_h_idx; + if (is_fwd) { + src_h_idx = start_h + smem_h_idx; + } else { + src_h_idx = start_h + TileCount::smem_load_h - smem_h_idx - 1; + } + if (check_bounds_h && smem_h_idx >= TileCount::smem_load_h) + continue; +#pragma unroll + for (int j = 0; j < w_per_thread; ++j) { + int smem_w_idx = x_base_idx + j * load_w; + int src_w_idx; + if (is_fwd) { + src_w_idx = start_w + smem_w_idx; + } else { + src_w_idx = start_w + TileCount::smem_w - w_offset - smem_w_idx - 1; + } + if (check_bounds_w && smem_w_idx >= TileCount::smem_w) + continue; + T val = 0.0f; + if (src_h_idx >= 0 && src_h_idx < bound_h && src_w_idx >= 0 && + src_w_idx < bound_w && + (is_fwd || (TileCount::smem_load_h - smem_h_idx - 1 >= 0 && + TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0))) { + val = g_ptr[src_h_idx * stride + src_w_idx]; + } + *(sh_ptr_as_copy_t(smem_h_idx, smem_w_idx)) = val; + } + } +} + +template < + typename T, DepthwiseConv2dDirection kDirection, typename ThreadConfig_, + typename TileCount_> +__device__ __forceinline__ void Global2SharedMem< + T, kDirection, ThreadConfig_, TileCount_>::copy() { +#pragma unroll + for (int j = 0; j < TileCount::reg_w; ++j) { + int smem_w_idx = gl_load_x + j * TileCount::load_w; + int src_w_idx; + if (is_fwd) { + src_w_idx = start_w + smem_w_idx; + } else { + src_w_idx = start_w + TileCount::smem_w - w_offset - smem_w_idx - 1; + } + if (TileCount::check_bounds_w && smem_w_idx >= TileCount::smem_w) + continue; + T val = 0.0f; + if (ring_src_h >= 0 && ring_src_h < bound_h && src_w_idx >= 0 && + src_w_idx < bound_w && + (is_fwd || TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0)) { + val = g_ptr[ring_src_h * stride + src_w_idx]; + } + reg[j] = val; + } +} + +template < + typename T, DepthwiseConv2dDirection kDirection, typename ThreadConfig_, + typename TileCount_> +__device__ __forceinline__ void Global2SharedMem< + T, kDirection, ThreadConfig_, TileCount_>::commit() { +#pragma unroll + for (int j = 0; j < TileCount::reg_w; ++j) { + int smem_w_idx = gl_load_x + j * TileCount::load_w; + + if (TileCount::check_bounds_w && smem_w_idx >= TileCount::smem_w) + continue; + + *(sh_ptr_as_copy_t(ring_smem_h, smem_w_idx)) = reg[j]; + } +} + +template < + typename T, DepthwiseConv2dDirection kDirection, typename ThreadConfig_, + typename TileCount_> +__device__ __forceinline__ void Global2SharedMem< + T, kDirection, ThreadConfig_, TileCount_>::iter_forward() { + if (is_fwd) { + ring_src_h++; + } else { + ring_src_h--; + } + ring_smem_h = (ring_smem_h + 1) % TileCount::smem_h; +} + +// CUDA kernel to compute the depthwise convolution forward pass in NCHW format, +// tailored for small images up to 32x32. Stride and depth multiplier must be 1. +// Padding must be 'SAME', which allows to reuse the index computation. Only +// use this kernel if CanLaunchDepthwiseConv2dGPUSmall(args) returns true. +// Tiles of the input and filter tensors are loaded into shared memory before +// performing the convolution. Each thread handles two elements per iteration, +// one each in the lower and upper half of a tile. +// Backprop input direction is the same as forward direction with the filter +// rotated by 180°. +template +__global__ void DepthwiseConv2dGPUKernelNCHWSmall( + const Param param, const T* input, const T* filter, T* output) { + using ThreadConfig = typename ConvTrait::ThreadConfig; + using SrcTileConfig = typename ConvTrait::SrcTileConfig; + using FilterTileConfig = typename ConvTrait::FilterTileConfig; + using OutTileConfig = typename ConvTrait::OutTileConfig; + using SrcTileCount = typename ConvTrait::SrcTileCount; + using FilterTileCount = typename ConvTrait::FilterTileCount; + using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor; + using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor; + const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD); + + int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z, + off_oh = threadIdx.y, off_ow = threadIdx.x; + + extern __shared__ __align__(8) unsigned char smem[]; + static_assert(sizeof(T) <= 8, "Insufficient alignment detected"); + T* smem_src = reinterpret_cast(smem); + T* smem_flt = reinterpret_cast(&smem_src[SrcTileCount::smem_size]); + + int off_ichannel = off_ochannel / param.chl_mul, + off_fchannel = off_ichannel % param.src_chl, + out_start_h = off_obh * OutTileConfig::block_h, + out_start_w = off_obw * OutTileConfig::block_w, + src_start_h = out_start_h - param.pad_h, + src_start_w = out_start_w - param.pad_w, + out_base_h_idx = out_start_h + off_oh * OutTileConfig::unroll_h; + + T* smem_src_ptr = smem_src + off_ow * FilterTileConfig::unroll_w; + T* smem_flt_ptr = smem_flt + off_ow * FilterTileConfig::unroll_w; + + T* out_base_ptr = output + off_ochannel * param.out_h * param.out_w; + + SrcGlobal2ShareVisitor gl2sh_src( + smem_src, param.src_w, src_start_h, src_start_w, param.src_h, param.src_w); + + FilterGlobal2ShareVisitor gl2sh_flt = { + smem_flt, param.flt_w, is_fwd ? 0 : param.flt_h - 2, + 0, param.flt_h, param.flt_w}; + + gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w; + gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w; + + gl2sh_src.first_copy(); + gl2sh_flt.first_copy(); + + __syncthreads(); + + T reg_src[SrcTileConfig::unroll_h * SrcTileConfig::unroll_w], + reg_flt[FilterTileConfig::unroll_h * FilterTileConfig::unroll_w]; + + T sum[OutTileConfig::unroll_size] = {0.0}; + + for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) { + gl2sh_src.copy(); + gl2sh_flt.copy(); +#pragma unroll + for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) { +#pragma unroll + for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) { + reg_src[s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr + [(off_oh + fh + s_h) % SrcTileCount::smem_h * + SrcTileCount::smem_w + + s_w]; + } + } + +#pragma unroll + for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) { +#pragma unroll + for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) { + reg_flt[f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr + [(fh + f_h) % FilterTileCount::smem_h * + FilterTileCount::smem_w + + f_w]; + } + } + +#pragma unroll + for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) { +#pragma unroll + for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) { +#pragma unroll + for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) { +#pragma unroll + for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) { + sum[oh * OutTileConfig::unroll_w + ow] += + reg_flt[inner_fh * FilterTileConfig::unroll_w + fw] * + reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w + fw + + ow]; + } + } + } + } + + __syncthreads(); + gl2sh_src.commit(); + gl2sh_flt.commit(); + gl2sh_src.iter_forward(); + gl2sh_flt.iter_forward(); + __syncthreads(); + } + + for (int o = 0; o < OutTileConfig::unroll_size; ++o) { + for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) { + sum[o] += __shfl_xor(sum[o], i, 32); + } + } + + if (threadIdx.x == 0) { +#pragma unroll + for (int i = 0; i < OutTileConfig::unroll_h; ++i) { + int out_h_idx = out_base_h_idx + i; + if (out_h_idx < param.out_h) { +#pragma unroll + for (int j = 0; j < OutTileConfig::unroll_w; ++j) { + int out_w_idx = out_start_w + j; + if (out_w_idx >= param.out_w) + return; + out_base_ptr[out_h_idx * param.out_w + out_w_idx] = + sum[i * OutTileConfig::unroll_w + j]; + } + } + } + } +} + +template < + typename T, typename T2, DepthwiseConv2dDirection kDirection, int unroll_fw, + int unroll_ow> +void LaunchDepthwiseConv2dGPUSmall( + const Param& param, const T* input, const T* filter, T* output, + cudaStream_t stream) { + static int const unroll_oh = 1, unroll_fh = 1; + + using FilterTileConfig = FilterTileConfig; + using ThreadConfig = ThreadConfig<4, 32>; + using OutTileConfig = OutTileConfig; + using IConvTrait = + ConvTrait; + using SrcTileCount = typename IConvTrait::SrcTileCount; + using FilterTileCount = typename IConvTrait::FilterTileCount; + + dim3 block(ThreadConfig::thread_x, ThreadConfig::thread_y); + dim3 grid; + grid.x = param.batch * param.src_chl * param.chl_mul; + grid.y = DIVUP(param.out_w, OutTileConfig::block_w); + grid.z = DIVUP(param.out_h, OutTileConfig::block_h); + const int shared_storage = + (SrcTileCount::smem_size + FilterTileCount::smem_size) * sizeof(T); + + void (*kernel)(const Param, const T*, const T*, T*); + kernel = DepthwiseConv2dGPUKernelNCHWSmall; + kernel<<>>(param, input, filter, output); + after_kernel_launch(); +} + +#define INSTANCE_AB(a, b, direction) \ + if (param.out_w > b * 4) { \ + LaunchDepthwiseConv2dGPUSmall( \ + param, src, flt, dst, stream); \ + } + +#define INSTANCE_A(a, direction) \ + if (param.flt_w > 0) { \ + INSTANCE_AB(a, 15, direction) \ + else INSTANCE_AB(a, 14, direction) else INSTANCE_AB(a, 13, direction) else INSTANCE_AB( \ + a, 12, direction) else INSTANCE_AB(a, 11, direction) else INSTANCE_AB(a, 10, direction) else INSTANCE_AB(a, 9, direction) else INSTANCE_AB(a, 8, direction) else INSTANCE_AB(a, 7, direction) else INSTANCE_AB(a, 6, direction) else INSTANCE_AB(a, 5, direction) else INSTANCE_AB(a, 4, direction) else INSTANCE_AB(a, 3, direction) else INSTANCE_AB(a, 2, direction) else INSTANCE_AB(a, 1, direction) else INSTANCE_AB(a, 0, direction) \ + } + +#define INSTANCE(direction) \ + INSTANCE_A(7, direction) \ + else INSTANCE_A(6, direction) else INSTANCE_A(5, direction) else INSTANCE_A(4, direction) else INSTANCE_A( \ + 3, \ + direction) else INSTANCE_A(2, direction) else INSTANCE_A(1, direction) else INSTANCE_A(0, direction) + +} // anonymous namespace diff --git a/dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu b/dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu new file mode 100644 index 0000000000000000000000000000000000000000..35c142ee9d37ddcae74898ce9a50b7ced0b9f0ca --- /dev/null +++ b/dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu @@ -0,0 +1,48 @@ +/** + * \file dnn/src/cuda/conv_bias/chanwise/fwd_depthwise_large_filter.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "cuda.h" +#include "cuda_fp16.h" +// #include "src/cuda/conv_bias/chanwise/fwd_depthwise_large_filter.cuh" +#include "src/cuda/conv_bias/chanwise/kern.cuh" +#include "src/cuda/conv_bias/chanwise/kern_helper.cuh" +#include "src/cuda/conv_bias/chanwise/launch_config.cuh" +#include "src/cuda/fp16_help.cuh" + +using namespace megdnn; +using namespace cuda; +using namespace conv_bias; +using namespace chanwise; + +#include "src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl" + +namespace megdnn { +namespace cuda { +namespace conv_bias { +namespace chanwise { + +// =====================================fwd===================================== + +#define check + +template <> +void run_fwd_depthwise_large_filter( + float* dst, const float* src, const float* flt, const Param& param, + cudaStream_t stream) { + INSTANCE(DepthwiseConv2dDirection::DIRECTION_FORWARD) +} + +} // namespace chanwise +} // namespace conv_bias +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cuda.doxygen diff --git a/dnn/src/cuda/conv_bias/chanwise/kern.cuh b/dnn/src/cuda/conv_bias/chanwise/kern.cuh index 4b5a60d044e7bedd51ac457193dc76f1f216e12b..b346291c1e81c237ea62bd5bdffb4279432740ee 100644 --- a/dnn/src/cuda/conv_bias/chanwise/kern.cuh +++ b/dnn/src/cuda/conv_bias/chanwise/kern.cuh @@ -61,6 +61,10 @@ template void run_fwd_small( T* dst, const T* src, const T* flt, const Param& param, cudaStream_t stream); +template +void run_fwd_depthwise_large_filter( + T* dst, const T* src, const T* flt, const Param& param, cudaStream_t stream); + // implemented in fwd_8x8x32.cu void run_fwd_8x8x32( int32_t* dst, const int8_t* src, const int8_t* flt, const Param& param, diff --git a/dnn/src/cuda/conv_bias/depthwise_large_filter.cpp b/dnn/src/cuda/conv_bias/depthwise_large_filter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..766254b8a279c74a408a301bc1ebdad8d0794269 --- /dev/null +++ b/dnn/src/cuda/conv_bias/depthwise_large_filter.cpp @@ -0,0 +1,109 @@ +/** + * \file dnn/src/cuda/conv_bias/depthwise_large_filter.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "src/common/conv_bias.h" +#include "src/cuda/conv_bias/algo.h" +#include "src/cuda/conv_bias/chanwise/kern.cuh" +#include "src/cuda/utils.h" + +using namespace megdnn; +using namespace cuda; +using namespace conv_bias; + +namespace { +inline bool is_available_depthwise_large_filter(const chanwise::Param& param) { + auto&& device_prop = cuda::current_device_prop(); + int flt_smem_w = (param.flt_w + 3) / 4 * 4; + int flt_smem_h = 3; + int flt_reg_per_thread = + flt_smem_w > 32 ? (flt_smem_w + 31) / 32 : 1 + flt_smem_w / 4; + int ow = param.out_w > 64 ? 64 : param.out_w; + int src_smem_w = ow + flt_smem_w - 1; + int src_smem_h = flt_smem_h + param.flt_h - 1; + int src_reg_per_thread = src_smem_w > 128 ? (flt_smem_w + 127) / 128 + : 1 + (ow + 3) / 4 + flt_smem_w / 4 - 1; + int out_reg_per_thread = (ow + 3) / 4 * 4; + if (device_prop.regsPerBlock < 4 * 32 * + (flt_reg_per_thread + src_reg_per_thread + + out_reg_per_thread) || + device_prop.sharedMemPerBlock < + static_cast( + flt_smem_w * flt_smem_h + src_smem_w * src_smem_h)) { + return false; + } + return param.stride_h == 1 && param.stride_w == 1 && param.src_h == param.out_h && + param.src_w == param.out_w; +} +} // anonymous namespace + +bool ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::is_available( + const SizeArgs& args) const { + if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) { + return false; + } + if (args.src_layout->dtype != args.filter_layout->dtype && + args.src_layout->dtype != dtype::Float32()) { + return false; + } + if (args.z_layout->ndim > 0) + return false; + + auto param = chanwise::Param::from_fwd_args(args); + auto&& fm = args.filter_meta; + return fm.group > 1 && args.filter_meta.format == Param::Format::NCHW && + args.src_layout->dtype.category() == DTypeCategory::FLOAT && + args.opr->param().compute_mode == Param::ComputeMode::DEFAULT && + fm.spatial_ndim == 2 && fm.icpg == 1 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && !fm.should_flip && + is_available_depthwise_large_filter(param); +} + +size_t ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::get_workspace_in_bytes( + const SizeArgs& args) const { + auto dst_layout = *args.dst_layout; + if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) { + dst_layout.dtype = DType(); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype); + return dst_layout.span().dist_byte(); + } + return 0; +} + +void ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::exec(const ExecArgs& args) const { + WorkspaceBundle bundle{args.workspace.raw_ptr, {get_workspace_in_bytes(args)}}; + TensorND conv_dst_tensor = *args.dst_tensor; + if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { + conv_dst_tensor = TensorND{bundle.get(0), conv_dst_tensor.layout}; + conv_dst_tensor.layout.dtype = DType(); + args.opr->check_or_deduce_dtype_fwd( + args.src_layout->dtype, args.filter_layout->dtype, + conv_dst_tensor.layout.dtype); + } + { + auto kparam = chanwise::Param::from_fwd_args(args); + auto stream = cuda_stream(args.handle); + switch (args.src_layout->dtype.enumv()) { + case DTypeEnum::Float32: + chanwise::run_fwd_depthwise_large_filter( + conv_dst_tensor.ptr(), args.src_tensor->ptr(), + args.filter_tensor->ptr(), kparam, stream); + break; + default: + megdnn_assert_internal(0); + } + } + handle_bias_and_nonlinear( + args.handle, args.nonlinear_mode, &conv_dst_tensor, args.dst_tensor, + args.bias_tensor); +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/conv_bias/opr_impl.h b/dnn/src/cuda/conv_bias/opr_impl.h index a7184243da38653d49896b09d64fa03951b4f293..e1a30eccb0a17295cef344649993c32caefa8c40 100644 --- a/dnn/src/cuda/conv_bias/opr_impl.h +++ b/dnn/src/cuda/conv_bias/opr_impl.h @@ -45,6 +45,7 @@ public: class AlgoCUDNNConvBiasActivation; class AlgoChanwise; class AlgoChanwiseSmall; + class AlgoDepthwiseLargeFilter; class AlgoChanwise8x8x32; class AlgoCUDNNConv; class AlgoFallbackNCHWQS8; diff --git a/dnn/test/cuda/conv_bias.cpp b/dnn/test/cuda/conv_bias.cpp index d4eff9adf2d6e60c0c8cb0676a439d2e94a5b2be..eceecc5e0ec484513f7c8b9790df29e3fe6d3850 100644 --- a/dnn/test/cuda/conv_bias.cpp +++ b/dnn/test/cuda/conv_bias.cpp @@ -695,6 +695,59 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_CHANWISE_SMALL) { } } +TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) { + Checker checker(handle_cuda()); + checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker( + ConvBiasForward::algo_name( + "DEPTHWISE_LARGE_FILTER", {}) + .c_str())); + auto run = [&checker](size_t n, size_t g, size_t h, size_t fh) { + param::ConvBias cur_param; + cur_param.mode = param::ConvBias::Mode::CROSS_CORRELATION; + cur_param.sparse = ConvBias::Param::Sparse::GROUP; + checker.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Float32()) + .set_dtype(3, dtype::Float32()) + .set_dtype(4, dtype::Float32()); + + cur_param.pad_h = cur_param.pad_w = fh / 2; + cur_param.stride_h = cur_param.stride_w = 1; + checker.set_param(cur_param).execs( + {{n, g, h, h}, {g, 1, 1, fh, fh}, {}, {}, {}}); + }; + run(4, 8, 32, 5); + run(4, 8, 32, 7); + run(4, 8, 32, 9); + run(4, 8, 32, 11); + run(4, 8, 32, 13); + run(4, 8, 32, 15); + run(4, 8, 32, 17); + run(4, 8, 32, 19); + run(4, 8, 32, 21); + run(4, 8, 32, 23); + run(4, 8, 32, 25); + run(4, 8, 32, 27); + run(4, 8, 32, 29); + run(4, 8, 32, 31); + run(4, 8, 64, 5); + run(4, 8, 64, 7); + run(4, 8, 64, 9); + run(4, 8, 64, 11); + run(4, 8, 64, 13); + run(4, 8, 64, 15); + run(4, 8, 64, 17); + run(4, 8, 64, 19); + run(4, 8, 64, 21); + run(4, 8, 64, 23); + run(4, 8, 64, 25); + run(4, 8, 64, 27); + run(4, 8, 64, 29); + run(4, 8, 64, 31); + run(1, 2, 128, 31); + run(1, 2, 256, 31); +} + TEST_F(CUDA, CONV_BIAS_FORWARD_CHANWISE_8x8x32) { require_compute_capability(6, 1); Checker checker(handle_cuda()); @@ -1474,6 +1527,69 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_TENSORCORE_INT8) { run_bench(256, 512, 7, 7, 512, 3, 3, 1, 1, 1000); run_bench(256, 512, 7, 7, 2048, 1, 1, 1, 1, 1000); } + +TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) { + require_compute_capability(7, 5); + Benchmarker bencher(handle_cuda()); + bencher.set_display(false); + bencher.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker( + ConvBiasForward::algo_name( + "DEPTHWISE_LARGE_FILTER", {}) + .c_str())); + + ConvBias::Param param; + param.format = ConvBias::Param::Format::NCHW; + + using NonlineMode = ConvBias::Param::NonlineMode; + param.nonlineMode = NonlineMode::IDENTITY; + param.sparse = ConvBias::Param::Sparse::GROUP; + auto run_bench = [&](size_t batch, size_t g, size_t hi, size_t wi, size_t fh, + size_t fw, size_t sh, size_t sw, size_t nr_times) { + param.pad_h = fh / 2; + param.pad_w = fw / 2; + param.stride_h = sh; + param.stride_w = sw; + + bencher.set_param(param) + .set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Float32()) + .set_dtype(4, dtype::Float32()); + bencher.set_times(nr_times); + size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h); + size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w); + TensorShape inp{batch, g, hi, wi}, kern{g, 1, 1, fh, fw}, out{batch, g, ho, wo}; + + float bandwith = static_cast( + inp.total_nr_elems() + kern.total_nr_elems() + + out.total_nr_elems()) / + (1024 * 1024 * 1024) * 1e3; + + auto time_in_ms = bencher.execs({inp, kern, {}, {}, out}) / nr_times; + auto ops = 2.0 * batch * g * ho * wo * fh * fw / (time_in_ms * 1e-3) * 1e-12; + printf("chanwise_depthwise_large_filter: inp=%s, kern=%s, out=%s, time: " + "%.2fms, " + "perf: %.2f Tops bandwidth: %.2fGB/s.\n", + inp.to_string().c_str(), kern.to_string().c_str(), + out.to_string().c_str(), time_in_ms, ops, bandwith * 4 / time_in_ms); + }; + + run_bench(64, 384, 32, 32, 3, 3, 1, 1, 10); + run_bench(64, 384, 32, 32, 5, 5, 1, 1, 10); + run_bench(64, 384, 32, 32, 7, 7, 1, 1, 10); + run_bench(64, 384, 32, 32, 9, 9, 1, 1, 10); + run_bench(64, 384, 32, 32, 11, 11, 1, 1, 10); + run_bench(64, 384, 32, 32, 13, 13, 1, 1, 10); + run_bench(64, 384, 32, 32, 15, 15, 1, 1, 10); + run_bench(64, 384, 32, 32, 17, 17, 1, 1, 10); + run_bench(64, 384, 32, 32, 19, 19, 1, 1, 10); + run_bench(64, 384, 32, 32, 21, 21, 1, 1, 10); + run_bench(64, 384, 32, 32, 23, 23, 1, 1, 10); + run_bench(64, 384, 32, 32, 25, 25, 1, 1, 10); + run_bench(64, 384, 32, 32, 27, 27, 1, 1, 10); + run_bench(64, 384, 32, 32, 29, 29, 1, 1, 10); + run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10); +} #endif #endif diff --git a/dnn/test/cuda/convolution.cpp b/dnn/test/cuda/convolution.cpp index e3419a062098f6afb2d8c6ea0abd4bd4954d8e59..12f3ec9f2a1471e5aecc60358ddfab88d82c4569 100644 --- a/dnn/test/cuda/convolution.cpp +++ b/dnn/test/cuda/convolution.cpp @@ -901,6 +901,43 @@ TEST_F(CUDA, CONVOLUTION_BWD_DATA_BENCHMARK) { run(32, 64, 64, 56, 56, 1, 1, 0); } +TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_CHANWISE_SMALL_FEAT_LARGE_FILTER) { + CUBenchmarker bench{handle_cuda()}; + std::unique_ptr> proxy{ + new OprProxy{true}}; + size_t RUNS = 10; + bench.set_proxy(proxy).set_times(RUNS); + + auto run = [&](size_t N, size_t OC, size_t g, size_t IH, size_t IW, size_t FH, + size_t SH, size_t PH) { + bench.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Float32()); + param::Convolution param; + param.stride_h = param.stride_w = SH; + param.pad_h = param.pad_w = FH / 2; + param.sparse = param::Convolution::Sparse::GROUP; + bench.set_param(param); + bench.proxy()->target_execution_policy.algo.reset(); + TensorLayout src{{N, g, IH, IW}, dtype::Float32()}, + filter{{g, 1, 1, FH, FH}, dtype::Float32()}; + TensorLayout dst; + { + auto&& opr = handle_cuda()->create_operator(); + opr->param() = param; + opr->deduce_layout(src, filter, dst); + } + auto time_ms_fp32 = bench.execl({filter, dst, src}) / RUNS; + float flo = 2.0 * N * g * dst[2] * dst[3] * FH * FH; + printf("inp=%s, kern=%s, dst=%s ", src.to_string().c_str(), + filter.to_string().c_str(), dst.to_string().c_str()); + printf("time_fp32=%.2fms, flops=%.3fTFLOPS\n", time_ms_fp32, + (flo / (time_ms_fp32 * 1e9))); + }; + + run(64, 384, 384, 32, 32, 31, 1, 15); +} + TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_BF16) { CUBenchmarker bench{handle_cuda()}; std::unique_ptr> proxy{ @@ -1065,6 +1102,46 @@ TEST_F(CUDA, CONVOLUTION_BWD_FILTER_BENCHMARK) { run(32, 512, 1024, 14, 14, 1, 2, 0); run(32, 64, 64, 56, 56, 1, 1, 0); } + +TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_FILTER_CHANWISE_SMALL_FEAT_LARGE_FILTER) { + CUBenchmarker bench{handle_cuda()}; + std::unique_ptr> proxy{ + new OprProxy{true}}; + size_t RUNS = 10; + bench.set_proxy(proxy).set_times(RUNS); + + bench.set_before_exec_callback(AlgoChecker( + "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFTv7.6.3")); + + auto run = [&](size_t N, size_t OC, size_t g, size_t IH, size_t IW, size_t FH, + size_t SH, size_t PH) { + bench.set_dtype(0, dtype::Float32()) + .set_dtype(1, dtype::Float32()) + .set_dtype(2, dtype::Float32()); + param::Convolution param; + param.stride_h = param.stride_w = SH; + param.pad_h = param.pad_w = FH / 2; + param.sparse = param::Convolution::Sparse::GROUP; + bench.set_param(param); + bench.proxy()->target_execution_policy.algo.reset(); + TensorLayout src{{N, g, IH, IW}, dtype::Float32()}, + filter{{g, 1, 1, FH, FH}, dtype::Float32()}; + TensorLayout dst; + { + auto&& opr = handle_cuda()->create_operator(); + opr->param() = param; + opr->deduce_layout(src, filter, dst); + } + auto time_ms_fp32 = bench.execl({src, dst, filter}) / RUNS; + float flo = 2.0 * N * g * dst[2] * dst[3] * FH * FH; + printf("inp=%s, kern=%s, dst=%s ", src.to_string().c_str(), + filter.to_string().c_str(), dst.to_string().c_str()); + printf("time_fp32=%.2fms, flops=%.3fTFLOPS\n", time_ms_fp32, + (flo / (time_ms_fp32 * 1e9))); + }; + run(64, 384, 384, 32, 32, 31, 1, 15); +} + #endif #undef CUDNN_VERSION_STRING