diff --git a/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh b/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh index bc9d13ca1d1e19e617256af9e7999cdb67f52b70..2de44d5d76fe3db32104e0f7c13a61bb5ec49d3d 100644 --- a/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh +++ b/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh @@ -142,7 +142,7 @@ struct ConvTraitInner { } #define CHECK_AB_BWD(a, b) \ - if (param.out_w > b * 4) { \ + if (param.out_w > b * 4 || b == 3) { \ using FilterTileConfig_ = FilterTileConfig; \ using ThreadConfig_ = ThreadConfig<4, 32>; \ using OutTileConfig_ = OutTileConfig; \ @@ -165,11 +165,9 @@ struct ConvTraitInner { return true; \ } -#define CHECK_A(a, cb) \ - if (param.flt_w > a * 4) { \ - CHECK_AB_##cb( \ - a, \ - 15) else CHECK_AB_##cb(a, 14) else CHECK_AB_##cb(a, 13) else CHECK_AB_##cb(a, 12) else CHECK_AB_##cb(a, 11) else CHECK_AB_##cb(a, 10) else CHECK_AB_##cb(a, 9) else CHECK_AB_##cb(a, 8) else CHECK_AB_##cb(a, 7) else CHECK_AB_##cb(a, 6) else CHECK_AB_##cb(a, 5) else CHECK_AB_##cb(a, 4) else CHECK_AB_##cb(a, 3) else CHECK_AB_##cb(a, 2) else CHECK_AB_##cb(a, 1) else CHECK_AB_##cb(a, 0) \ +#define CHECK_A(a, cb) \ + if (param.flt_w > a * 4) { \ + CHECK_AB_##cb(a, 15) else CHECK_AB_##cb(a, 7) else CHECK_AB_##cb(a, 3) \ } #define CHECK(cb) \ diff --git a/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh b/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh index c3d936bb08be43c834549738cdc1eee96f3c4d89..e434b3cfff9ec39eac957a558b9db36f2acfcdbc 100644 --- a/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh +++ b/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh @@ -217,7 +217,7 @@ __device__ __forceinline__ void Global2SharedMem< // Backprop input direction is the same as forward direction with the filter // rotated by 180°. #if CUDA_VERSION >= 9000 -template +template __global__ void DepthwiseConv2dGPUKernelNCHW( const Param param, const __half* input, const __half* filter, __half* output) { using T = __half; @@ -230,7 +230,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( using FilterTileCount = typename ConvTrait::FilterTileCount; using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor; using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor; - const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD); + constexpr bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD); int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z, off_oh = threadIdx.y, off_ow = threadIdx.x; @@ -243,8 +243,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( static_assert(sizeof(T) <= 8, "Insufficient alignment detected"); T* smem_src = reinterpret_cast(smem); T* smem_flt = reinterpret_cast(&smem_src[SrcTileCount::smem_size]); - int stride_h = is_fwd ? param.stride_h : 1; - int stride_w = is_fwd ? param.stride_w : 1; + constexpr int stride_h = is_fwd ? stride : 1; + constexpr int stride_w = is_fwd ? stride : 1; int off_ichannel = off_ochannel / param.chl_mul, off_fchannel = off_ichannel % param.src_chl, @@ -385,7 +385,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( } } -template +template __global__ void DepthwiseConv2dGPUKernelNCHWC32( const Param param, const __half* input, const __half* filter, __half* output) { using T = __half; @@ -398,7 +398,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( using FilterTileCount = typename ConvTrait::FilterTileCount; using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor; using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor; - const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD); + constexpr bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD); int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z, off_oh = threadIdx.y, off_ow = threadIdx.x; @@ -411,8 +411,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( static_assert(sizeof(T) <= 8, "Insufficient alignment detected"); T* smem_src = reinterpret_cast(smem); T* smem_flt = reinterpret_cast(&smem_src[SrcTileCount::smem_size]); - int stride_h = is_fwd ? param.stride_h : 1; - int stride_w = is_fwd ? param.stride_w : 1; + constexpr int stride_h = is_fwd ? stride : 1; + constexpr int stride_w = is_fwd ? stride : 1; int off_ichannel = off_ochannel / param.chl_mul, off_fchannel = off_ichannel % param.src_chl, @@ -555,7 +555,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( } #endif -template +template __global__ void DepthwiseConv2dGPUKernelNCHW( const Param param, const float* input, const float* filter, float* output) { using T = float; @@ -568,7 +568,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( using FilterTileCount = typename ConvTrait::FilterTileCount; using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor; using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor; - const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD); + constexpr bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD); int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z, off_oh = threadIdx.y, off_ow = threadIdx.x; @@ -577,8 +577,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( static_assert(sizeof(T) <= 8, "Insufficient alignment detected"); T* smem_src = reinterpret_cast(smem); T* smem_flt = reinterpret_cast(&smem_src[SrcTileCount::smem_size]); - int stride_h = is_fwd ? param.stride_h : 1; - int stride_w = is_fwd ? param.stride_w : 1; + constexpr int stride_h = is_fwd ? stride : 1; + constexpr int stride_w = is_fwd ? stride : 1; int off_ichannel = off_ochannel / param.chl_mul, off_fchannel = off_ichannel % param.src_chl, @@ -703,7 +703,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( } } -template +template __global__ void DepthwiseConv2dGPUKernelNCHWC32( const Param param, const float* input, const float* filter, float* output) { using T = float; @@ -716,7 +716,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( using FilterTileCount = typename ConvTrait::FilterTileCount; using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor; using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor; - const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD); + constexpr bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD); int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z, off_oh = threadIdx.y, off_ow = threadIdx.x; @@ -725,8 +725,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32( static_assert(sizeof(T) <= 8, "Insufficient alignment detected"); T* smem_src = reinterpret_cast(smem); T* smem_flt = reinterpret_cast(&smem_src[SrcTileCount::smem_size]); - int stride_h = is_fwd ? param.stride_h : 1; - int stride_w = is_fwd ? param.stride_w : 1; + constexpr int stride_h = is_fwd ? stride : 1; + constexpr int stride_w = is_fwd ? stride : 1; int off_ichannel = off_ochannel / param.chl_mul, off_fchannel = off_ichannel % param.src_chl, @@ -879,16 +879,16 @@ void LaunchDepthwiseConv2dGPU( void (*kernel)(const Param, const T*, const T*, T*); if (param.is_compute_deafult) { - kernel = DepthwiseConv2dGPUKernelNCHW; + kernel = DepthwiseConv2dGPUKernelNCHW; } else { - kernel = DepthwiseConv2dGPUKernelNCHWC32; + kernel = DepthwiseConv2dGPUKernelNCHWC32; } kernel<<>>(param, input, filter, output); after_kernel_launch(); } #define INSTANCE_AB(type1, type2, a, b, direction) \ - if (param.out_w > b * 4) { \ + if (param.out_w > b * 4 || b == 3) { \ if (direction == DepthwiseConv2dDirection::DIRECTION_BACKWARD || \ (param.stride_h == 1 && param.stride_w == 1)) { \ LaunchDepthwiseConv2dGPU( \ @@ -899,12 +899,11 @@ void LaunchDepthwiseConv2dGPU( } \ } -#define INSTANCE_A(type1, type2, a, direction) \ - if (param.flt_w > a * 4) { \ - INSTANCE_AB(type1, type2, a, 15, direction) \ - else INSTANCE_AB(type1, type2, a, 14, direction) else INSTANCE_AB(type1, type2, a, 13, direction) else INSTANCE_AB(type1, type2, a, 12, direction) else INSTANCE_AB(type1, type2, a, 11, direction) else INSTANCE_AB(type1, type2, a, 10, direction) else INSTANCE_AB( \ - type1, type2, \ - a, 9, direction) else INSTANCE_AB(type1, type2, a, 8, direction) else INSTANCE_AB(type1, type2, a, 7, direction) else INSTANCE_AB(type1, type2, a, 6, direction) else INSTANCE_AB(type1, type2, a, 5, direction) else INSTANCE_AB(type1, type2, a, 4, direction) else INSTANCE_AB(type1, type2, a, 3, direction) else INSTANCE_AB(type1, type2, a, 2, direction) else INSTANCE_AB(type1, type2, a, 1, direction) else INSTANCE_AB(type1, type2, a, 0, direction) \ +#define INSTANCE_A(type1, type2, a, direction) \ + if (param.flt_w > a * 4) { \ + INSTANCE_AB(type1, type2, a, 15, direction) \ + else INSTANCE_AB(type1, type2, a, 7, direction) else INSTANCE_AB( \ + type1, type2, a, 3, direction) \ } #define INSTANCE(type1, type2, direction) \