From 472e2f96556f531a35c41cddf8c9b5847ea0def0 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 21 Feb 2022 21:54:44 +0800 Subject: [PATCH] refactor(cuda): depthwish large kernel GitOrigin-RevId: dade8710b4122dce27dcebfd74491e810288a160 --- .../chanwise/depthwise_large_filter.cuh | 179 ++++++++++++++++++ ...go.inl => depthwise_large_filter_algo.cuh} | 145 ++++---------- .../conv_bias/chanwise/fwd_large_filter.cu | 2 +- .../cuda/conv_bias/depthwise_large_filter.cpp | 30 +-- dnn/src/cuda/convolution/backward_data/algo.h | 2 +- .../backward_data/depthwise_large_filter.cpp | 41 ++-- .../convolution/chanwise/bwd_large_filter.cu | 4 +- dnn/test/cuda/convolution.cpp | 58 +++--- 8 files changed, 272 insertions(+), 189 deletions(-) create mode 100644 dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh rename dnn/src/cuda/conv_bias/chanwise/{depthwise_large_filter_algo.inl => depthwise_large_filter_algo.cuh} (83%) diff --git a/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh b/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh new file mode 100644 index 000000000..bc9d13ca1 --- /dev/null +++ b/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh @@ -0,0 +1,179 @@ +/** + * \file dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +namespace { +#define DIVUP(x, y) (((x) + (y)-1) / (y)) +enum DepthwiseConv2dDirection { DIRECTION_FORWARD, DIRECTION_BACKWARD }; + +template +struct OutTileConfig { + using ThreadConfig = ThreadConfig_; + static int const unroll_h = oh_; + static int const unroll_w = ThreadConfig::thread_x * ow_; + static int const unroll_size = unroll_h * unroll_w; + static int const block_h = unroll_h * ThreadConfig::thread_y; + static int const block_w = unroll_w; +}; + +template +struct FilterTileConfig { + static int const unroll_h = fh_; + static int const unroll_w = fw_; + static int const unroll_size = unroll_h * unroll_w; +}; + +template +struct ThreadConfig { + static int const thread_x = x_; + static_assert((thread_x & (thread_x - 1)) == 0, "thread_x must be pow of 2!"); + static int const thread_y = y_; + static int const nr_threads = x_ * y_; +}; + +template < + typename ldg_dtype, typename ThreadConfig_, typename OutTileConfig_, + typename FilterTileConfig_, int stride_w, int stride_h> +struct ConvTraitInner { + using ThreadConfig = ThreadConfig_; + using OutTileConfig = OutTileConfig_; + using FilterTileConfig = FilterTileConfig_; + using CompType = ldg_dtype; + + struct SrcTileConfig { + static int const unroll_h = + OutTileConfig::unroll_h + FilterTileConfig::unroll_h - 1; + static int const unroll_w = + (OutTileConfig::unroll_w - 1) * stride_w + FilterTileConfig::unroll_w; + static int const unroll_size = unroll_h * unroll_w; + }; + + struct SrcTileCount { + static int const smem_src_h = + (OutTileConfig::block_h - 1) * stride_h + FilterTileConfig::unroll_h; + static int const smem_buff_h = FilterTileConfig::unroll_h; + static int const smem_load_h = smem_src_h + smem_buff_h; + static int const smem_h = smem_load_h + smem_buff_h; + static int const smem_w = + DIVUP((OutTileConfig::block_w - 1) * stride_w + + FilterTileConfig::unroll_w * ThreadConfig::thread_x, + 2) * + 2; + static int const smem_size = smem_h * smem_w; + static int const load_w = + smem_w > ThreadConfig::nr_threads ? ThreadConfig::nr_threads : smem_w; + static int const load_h = 1; + static int const reg_h = 1; + static int const reg_w = DIVUP(smem_w, load_w); + static bool constexpr check_bounds_h = smem_h % load_h != 0; + static bool constexpr check_bounds_w = smem_w % load_w != 0; + }; + + struct FilterTileCount { + static int const smem_flt_h = FilterTileConfig::unroll_h; + static int const smem_buff_h = FilterTileConfig::unroll_h; + static int const smem_load_h = smem_flt_h + smem_buff_h; + static int const smem_h = smem_load_h + smem_buff_h; + static int const smem_w = FilterTileConfig::unroll_w * ThreadConfig::thread_x; + static int const smem_size = smem_h * smem_w; + static int const load_w = smem_w > 32 ? 32 : smem_w; + static int const load_h = ThreadConfig::nr_threads / load_w; + static int const reg_h = 1; + static int const reg_w = DIVUP(smem_w, load_w); + static bool constexpr check_bounds_h = smem_h % load_h != 0; + static bool constexpr check_bounds_w = smem_w % load_w != 0; + }; +}; + +#define CHECK_AB_FWD(a, b) \ + if (param.out_w > b * 4) { \ + if (param.stride_h == 1 && param.stride_w == 1) { \ + using FilterTileConfig_ = FilterTileConfig; \ + using ThreadConfig_ = ThreadConfig<4, 32>; \ + using OutTileConfig_ = OutTileConfig; \ + using IConvTrait = ConvTraitInner< \ + float, ThreadConfig_, OutTileConfig_, FilterTileConfig_, 1, 1>; \ + using SrcTileConfig = typename IConvTrait::SrcTileConfig; \ + using SrcTileCount = typename IConvTrait::SrcTileCount; \ + using FilterTileCount = typename IConvTrait::FilterTileCount; \ + \ + if (device_prop.regsPerBlock < \ + 4 * 32 * \ + (FilterTileConfig_::unroll_h * \ + FilterTileConfig_::unroll_w * 2 + \ + SrcTileConfig::unroll_h * SrcTileConfig::unroll_w) || \ + device_prop.sharedMemPerBlock < \ + static_cast( \ + (SrcTileCount::smem_size + \ + FilterTileCount::smem_size))) { \ + return false; \ + } \ + return true; \ + } else if (param.stride_h == 2 && param.stride_w == 2) { \ + using FilterTileConfig_ = FilterTileConfig; \ + using ThreadConfig_ = ThreadConfig<4, 32>; \ + using OutTileConfig_ = OutTileConfig; \ + using IConvTrait = ConvTraitInner< \ + float, ThreadConfig_, OutTileConfig_, FilterTileConfig_, 2, 2>; \ + using SrcTileConfig = typename IConvTrait::SrcTileConfig; \ + using SrcTileCount = typename IConvTrait::SrcTileCount; \ + using FilterTileCount = typename IConvTrait::FilterTileCount; \ + \ + if (device_prop.regsPerBlock < \ + 4 * 32 * \ + (FilterTileConfig_::unroll_h * \ + FilterTileConfig_::unroll_w * 2 + \ + SrcTileConfig::unroll_h * SrcTileConfig::unroll_w) || \ + device_prop.sharedMemPerBlock < \ + static_cast( \ + (SrcTileCount::smem_size + \ + FilterTileCount::smem_size))) { \ + return false; \ + } \ + return true; \ + } \ + } + +#define CHECK_AB_BWD(a, b) \ + if (param.out_w > b * 4) { \ + using FilterTileConfig_ = FilterTileConfig; \ + using ThreadConfig_ = ThreadConfig<4, 32>; \ + using OutTileConfig_ = OutTileConfig; \ + using IConvTrait = ConvTraitInner< \ + float, ThreadConfig_, OutTileConfig_, FilterTileConfig_, 1, 1>; \ + using SrcTileConfig = typename IConvTrait::SrcTileConfig; \ + using SrcTileCount = typename IConvTrait::SrcTileCount; \ + using FilterTileCount = typename IConvTrait::FilterTileCount; \ + \ + if (device_prop.regsPerBlock < \ + 4 * 32 * \ + (FilterTileConfig_::unroll_h * \ + FilterTileConfig_::unroll_w * 2 + \ + SrcTileConfig::unroll_h * SrcTileConfig::unroll_w) || \ + device_prop.sharedMemPerBlock < \ + static_cast( \ + (SrcTileCount::smem_size + FilterTileCount::smem_size))) { \ + return false; \ + } \ + return true; \ + } + +#define CHECK_A(a, cb) \ + if (param.flt_w > a * 4) { \ + CHECK_AB_##cb( \ + a, \ + 15) else CHECK_AB_##cb(a, 14) else CHECK_AB_##cb(a, 13) else CHECK_AB_##cb(a, 12) else CHECK_AB_##cb(a, 11) else CHECK_AB_##cb(a, 10) else CHECK_AB_##cb(a, 9) else CHECK_AB_##cb(a, 8) else CHECK_AB_##cb(a, 7) else CHECK_AB_##cb(a, 6) else CHECK_AB_##cb(a, 5) else CHECK_AB_##cb(a, 4) else CHECK_AB_##cb(a, 3) else CHECK_AB_##cb(a, 2) else CHECK_AB_##cb(a, 1) else CHECK_AB_##cb(a, 0) \ + } + +#define CHECK(cb) \ + CHECK_A(6, cb) \ + else CHECK_A(4, cb) else CHECK_A(2, cb) else CHECK_A(0, cb) + +} // namespace diff --git a/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl b/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh similarity index 83% rename from dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl rename to dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh index b9aa47ad3..a6df314b5 100644 --- a/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl +++ b/dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh @@ -1,5 +1,5 @@ /** - * \file dnn/src/cuda/conv_bias/chanwise/fwd_depthwise_large_filter.inl + * \file dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -9,35 +9,10 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once +#include "depthwise_large_filter.cuh" #include "src/cuda/cuda_shfl_compat.cuh" -namespace { - -enum DepthwiseConv2dDirection { DIRECTION_FORWARD, DIRECTION_BACKWARD }; - -template -struct OutTileConfig { - using ThreadConfig = ThreadConfig_; - static int const unroll_h = oh_; - static int const unroll_w = ThreadConfig::thread_x * ow_; - static int const unroll_size = unroll_h * unroll_w; - static int const block_h = unroll_h * ThreadConfig::thread_y; - static int const block_w = unroll_w; -}; -template -struct FilterTileConfig { - static int const unroll_h = fh_; - static int const unroll_w = fw_; - static int const unroll_size = unroll_h * unroll_w; -}; - -template -struct ThreadConfig { - static int const thread_x = x_; - static_assert((thread_x & (thread_x - 1)) == 0, "thread_x must be pow of 2!"); - static int const thread_y = y_; - static int const nr_threads = x_ * y_; -}; +namespace { template < typename T, DepthwiseConv2dDirection kDirection, typename ThreadConfig_, @@ -87,49 +62,12 @@ struct ConvTrait { using FilterTileConfig = FilterTileConfig_; using CompType = ldg_dtype; - struct SrcTileConfig { - static int const unroll_h = - OutTileConfig::unroll_h + FilterTileConfig::unroll_h - 1; - static int const unroll_w = - (OutTileConfig::unroll_w - 1) * stride_w + FilterTileConfig::unroll_w; - static int const unroll_size = unroll_h * unroll_w; - }; - - struct SrcTileCount { - static int const smem_src_h = - (OutTileConfig::block_h - 1) * stride_h + FilterTileConfig::unroll_h; - static int const smem_buff_h = FilterTileConfig::unroll_h; - static int const smem_load_h = smem_src_h + smem_buff_h; - static int const smem_h = smem_load_h + smem_buff_h; - static int const smem_w = - DIVUP((OutTileConfig::block_w - 1) * stride_w + - FilterTileConfig::unroll_w * ThreadConfig::thread_x, - 2) * - 2; - static int const smem_size = smem_h * smem_w; - static int const load_w = - smem_w > ThreadConfig::nr_threads ? ThreadConfig::nr_threads : smem_w; - static int const load_h = 1; - static int const reg_h = 1; - static int const reg_w = DIVUP(smem_w, load_w); - static bool constexpr check_bounds_h = smem_h % load_h != 0; - static bool constexpr check_bounds_w = smem_w % load_w != 0; - }; - - struct FilterTileCount { - static int const smem_flt_h = FilterTileConfig::unroll_h; - static int const smem_buff_h = FilterTileConfig::unroll_h; - static int const smem_load_h = smem_flt_h + smem_buff_h; - static int const smem_h = smem_load_h + smem_buff_h; - static int const smem_w = FilterTileConfig::unroll_w * ThreadConfig::thread_x; - static int const smem_size = smem_h * smem_w; - static int const load_w = smem_w > 32 ? 32 : smem_w; - static int const load_h = ThreadConfig::nr_threads / load_w; - static int const reg_h = 1; - static int const reg_w = DIVUP(smem_w, load_w); - static bool constexpr check_bounds_h = smem_h % load_h != 0; - static bool constexpr check_bounds_w = smem_w % load_w != 0; - }; + using CI = ConvTraitInner< + ldg_dtype, ThreadConfig_, OutTileConfig_, FilterTileConfig_, stride_w, + stride_h>; + using SrcTileConfig = typename CI::SrcTileConfig; + using SrcTileCount = typename CI::SrcTileCount; + using FilterTileCount = typename CI::FilterTileCount; using SrcGlobal2ShareVisitor = Global2SharedMem< CompType, DepthwiseConv2dDirection::DIRECTION_FORWARD, ThreadConfig, @@ -272,14 +210,15 @@ __device__ __forceinline__ void Global2SharedMem< // CUDA kernel to compute the depthwise convolution forward pass in NCHW format, // tailored for small images up to 32x32. Stride and depth multiplier must be 1. // Padding must be 'SAME', which allows to reuse the index computation. Only -// use this kernel if CanLaunchDepthwiseConv2dGPUSmall(args) returns true. +// use this kernel if CanLaunchDepthwiseConv2dGPU(args) returns true. // Tiles of the input and filter tensors are loaded into shared memory before // performing the convolution. Each thread handles two elements per iteration, // one each in the lower and upper half of a tile. // Backprop input direction is the same as forward direction with the filter // rotated by 180°. +#if CUDA_VERSION >= 9000 template -__global__ void DepthwiseConv2dGPUKernelNCHWSmall( +__global__ void DepthwiseConv2dGPUKernelNCHW( const Param param, const __half* input, const __half* filter, __half* output) { using T = __half; using T2 = __half2; @@ -380,16 +319,18 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall( f_w * 2; reg_flt[0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast(smem_flt_ptr + flt_offset); - reg_flt[1][f_h * t2_flt_unroll_w + f_w] = { - f_w > 0 ? reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y - : static_cast(0.0), - reg_flt[0][f_h * t2_flt_unroll_w + f_w].x}; + if (f_w > 0) { + reg_flt[1][f_h * t2_flt_unroll_w + f_w] = { + reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y, + reg_flt[0][f_h * t2_flt_unroll_w + f_w].x}; + } else { + reg_flt[1][f_h * t2_flt_unroll_w + f_w] = { + 0.0, reg_flt[0][f_h * t2_flt_unroll_w + f_w].x}; + } } - reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = { - static_cast(0.0), static_cast(0.0)}; + reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = {0.0, 0.0}; reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = { - reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, - static_cast(0.0)}; + reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0}; } #pragma unroll @@ -444,9 +385,10 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall( } } } +#endif template -__global__ void DepthwiseConv2dGPUKernelNCHWSmall( +__global__ void DepthwiseConv2dGPUKernelNCHW( const Param param, const float* input, const float* filter, float* output) { using T = float; using T2 = float2; @@ -530,11 +472,6 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall( [(off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h * SrcTileCount::smem_w + s_w]; - if (off_ochannel == 0 && off_obw == 0 && off_obh == 0 && off_oh == 30 && - off_ow == 0) { - printf("reg_src[%d] = %f\n", s_h * SrcTileConfig::unroll_w + s_w, - reg_src[s_h * SrcTileConfig::unroll_w + s_w]); - } } } @@ -561,15 +498,6 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall( reg_flt[inner_fh * FilterTileConfig::unroll_w + fw] * reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w + fw + ow * stride_w]; - if (off_ochannel == 0 && off_obw == 0 && off_obh == 0 && - off_oh == 30) { - printf("sum[%d] += %f * %f\nsum = %f\n", - oh * OutTileConfig::unroll_w + ow, - reg_flt[inner_fh * FilterTileConfig::unroll_w + fw], - reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w + - fw + ow * stride_w], - sum[oh * OutTileConfig::unroll_w + ow]); - } } } } @@ -610,7 +538,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall( template < typename T, typename T2, DepthwiseConv2dDirection kDirection, int unroll_fw, int unroll_ow, int stride> -void LaunchDepthwiseConv2dGPUSmall( +void LaunchDepthwiseConv2dGPU( const Param& param, const T* input, const T* filter, T* output, cudaStream_t stream) { static int const unroll_oh = 1, unroll_fh = 1; @@ -633,22 +561,21 @@ void LaunchDepthwiseConv2dGPUSmall( (SrcTileCount::smem_size + FilterTileCount::smem_size) * sizeof(T); void (*kernel)(const Param, const T*, const T*, T*); - kernel = DepthwiseConv2dGPUKernelNCHWSmall; + kernel = DepthwiseConv2dGPUKernelNCHW; kernel<<>>(param, input, filter, output); after_kernel_launch(); } -#define INSTANCE_AB(type1, type2, a, b, direction) \ - if (param.out_w > b * 4) { \ - printf("param.out_w = %d, b = %d\n", param.out_w, b); \ - if (direction == DepthwiseConv2dDirection::DIRECTION_BACKWARD || \ - (param.stride_h == 1 && param.stride_w == 1)) { \ - LaunchDepthwiseConv2dGPUSmall( \ - param, src, flt, dst, stream); \ - } else if (param.stride_h == 2 && param.stride_w == 2) { \ - LaunchDepthwiseConv2dGPUSmall( \ - param, src, flt, dst, stream); \ - } \ +#define INSTANCE_AB(type1, type2, a, b, direction) \ + if (param.out_w > b * 4) { \ + if (direction == DepthwiseConv2dDirection::DIRECTION_BACKWARD || \ + (param.stride_h == 1 && param.stride_w == 1)) { \ + LaunchDepthwiseConv2dGPU( \ + param, src, flt, dst, stream); \ + } else if (param.stride_h == 2 && param.stride_w == 2) { \ + LaunchDepthwiseConv2dGPU( \ + param, src, flt, dst, stream); \ + } \ } #define INSTANCE_A(type1, type2, a, direction) \ diff --git a/dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu b/dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu index d4570b888..ef643efc3 100644 --- a/dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu +++ b/dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu @@ -21,7 +21,7 @@ using namespace cuda; using namespace conv_bias; using namespace chanwise; -#include "src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl" +#include "src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh" namespace megdnn { namespace cuda { diff --git a/dnn/src/cuda/conv_bias/depthwise_large_filter.cpp b/dnn/src/cuda/conv_bias/depthwise_large_filter.cpp index 2f28e87ad..aba456801 100644 --- a/dnn/src/cuda/conv_bias/depthwise_large_filter.cpp +++ b/dnn/src/cuda/conv_bias/depthwise_large_filter.cpp @@ -9,6 +9,7 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +#include "src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh" #include "src/common/conv_bias.h" #include "src/cuda/conv_bias/algo.h" #include "src/cuda/conv_bias/chanwise/kern.cuh" @@ -20,26 +21,13 @@ using namespace conv_bias; namespace { inline bool is_available_depthwise_large_filter(const chanwise::Param& param) { - auto&& device_prop = cuda::current_device_prop(); - int flt_smem_w = (param.flt_w + 3) / 4 * 4; - int flt_smem_h = 3; - int flt_reg_per_thread = - flt_smem_w > 32 ? (flt_smem_w + 31) / 32 : 1 + flt_smem_w / 4; - int ow = param.out_w > 64 ? 64 : param.out_w; - int src_smem_w = ow + flt_smem_w - 1; - int src_smem_h = flt_smem_h + param.flt_h - 1; - int src_reg_per_thread = src_smem_w > 128 ? (flt_smem_w + 127) / 128 - : 1 + (ow + 3) / 4 + flt_smem_w / 4 - 1; - int out_reg_per_thread = (ow + 3) / 4 * 4; - if (device_prop.regsPerBlock < 4 * 32 * - (flt_reg_per_thread * 2 + - src_reg_per_thread + out_reg_per_thread) || - device_prop.sharedMemPerBlock < - static_cast( - flt_smem_w * flt_smem_h * 2 + src_smem_w * src_smem_h)) { - return false; + if ((param.stride_h == 1 && param.stride_w == 1) || + (param.stride_h == 2 && param.stride_w == 2)) { + auto&& device_prop = cuda::current_device_prop(); + static int const unroll_oh = 1, unroll_fh = 1; + CHECK(FWD) } - return true; + return false; } } // anonymous namespace @@ -64,8 +52,8 @@ bool ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::is_available( return fm.group > 1 && args.filter_meta.format == Param::Format::NCHW && args.src_layout->dtype.category() == DTypeCategory::FLOAT && args.opr->param().compute_mode == Param::ComputeMode::DEFAULT && - fm.spatial_ndim == 2 && fm.icpg == 1 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && !fm.should_flip && + fm.spatial_ndim == 2 && fm.icpg == 1 && fm.ocpg == 1 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && !fm.should_flip && is_available_depthwise_large_filter(param); } diff --git a/dnn/src/cuda/convolution/backward_data/algo.h b/dnn/src/cuda/convolution/backward_data/algo.h index a9009c79b..a4f5536cc 100644 --- a/dnn/src/cuda/convolution/backward_data/algo.h +++ b/dnn/src/cuda/convolution/backward_data/algo.h @@ -68,7 +68,7 @@ public: const TensorLayout& grad); convolution::ForwardSizeArgs as_fwd_args() const { - return {handle, diff_layout, filter_layout, filter_meta, grad_layout}; + return {handle, grad_layout, filter_layout, filter_meta, diff_layout}; } }; struct ExecArgs : public SizeArgs { diff --git a/dnn/src/cuda/convolution/backward_data/depthwise_large_filter.cpp b/dnn/src/cuda/convolution/backward_data/depthwise_large_filter.cpp index 5ebcd66d1..b5c3c5443 100644 --- a/dnn/src/cuda/convolution/backward_data/depthwise_large_filter.cpp +++ b/dnn/src/cuda/convolution/backward_data/depthwise_large_filter.cpp @@ -9,6 +9,7 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +#include "src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh" #include "src/cuda/convolution/backward_data/algo.h" #include "src/cuda/convolution/chanwise/kern.cuh" #include "src/cuda/utils.h" @@ -19,29 +20,13 @@ using namespace convolution; namespace { inline bool is_available_depthwise_large_filter(const chanwise::Param& param) { - auto&& device_prop = cuda::current_device_prop(); - int flt_smem_w = (param.flt_w + 3) / 4 * 4; - int flt_smem_h = 3; - int flt_reg_per_thread = - flt_smem_w > 32 ? (flt_smem_w + 31) / 32 : 1 + flt_smem_w / 4; - int ow = param.out_w > 64 ? 64 : param.out_w; - int src_smem_w = ow + flt_smem_w - 1; - int src_smem_h = flt_smem_h + param.flt_h - 1; - int src_reg_per_thread = src_smem_w > 128 ? (flt_smem_w + 127) / 128 - : 1 + (ow + 3) / 4 + flt_smem_w / 4 - 1; - int out_reg_per_thread = (ow + 3) / 4 * 4; - if (device_prop.regsPerBlock < 4 * 32 * - (flt_reg_per_thread * 2 + - src_reg_per_thread + out_reg_per_thread) || - device_prop.sharedMemPerBlock < - static_cast( - flt_smem_w * flt_smem_h * 2 + src_smem_w * src_smem_h)) { - return false; + if ((param.stride_h == 1 && param.stride_w == 1) || + (param.stride_h == 2 && param.stride_w == 2)) { + auto&& device_prop = cuda::current_device_prop(); + static int const unroll_oh = 1, unroll_fh = 1; + CHECK(BWD) } - printf("param.src_w = %d, param.src_h = %d, param.out_w = %d, param.out_h = %d\n", - param.src_w, param.src_h, param.out_w, param.out_h); - return (param.stride_h == 1 && param.stride_w == 1) || - (param.stride_h == 2 && param.stride_w == 2); + return false; } } // anonymous namespace @@ -59,13 +44,15 @@ bool ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter::is_available( return false; } - auto param = chanwise::Param::from_fwd_args(args.as_fwd_args()); + auto param = chanwise::Param::from_fwd_args( + {args.handle, args.diff_layout, args.filter_layout, args.filter_meta, + args.grad_layout}); auto&& fm = args.filter_meta; return fm.group > 1 && args.filter_meta.format == Param::Format::NCHW && args.diff_layout->dtype.category() == DTypeCategory::FLOAT && args.opr->param().compute_mode == Param::ComputeMode::DEFAULT && - fm.spatial_ndim == 2 && fm.icpg == 1 && fm.dilation[0] == 1 && - fm.dilation[1] == 1 && !fm.should_flip && + fm.spatial_ndim == 2 && fm.icpg == 1 && fm.ocpg == 1 && + fm.dilation[0] == 1 && fm.dilation[1] == 1 && !fm.should_flip && is_available_depthwise_large_filter(param); } @@ -76,7 +63,9 @@ size_t ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter::get_workspace_in_b void ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter::exec( const ExecArgs& args) const { - auto kparam = chanwise::Param::from_fwd_args(args.as_fwd_args()); + auto kparam = chanwise::Param::from_fwd_args( + {args.handle, args.diff_layout, args.filter_layout, args.filter_meta, + args.grad_layout}); auto stream = cuda_stream(args.handle); switch (args.diff_layout->dtype.enumv()) { case DTypeEnum::Float32: diff --git a/dnn/src/cuda/convolution/chanwise/bwd_large_filter.cu b/dnn/src/cuda/convolution/chanwise/bwd_large_filter.cu index acbb7b9a8..f1a8b770a 100644 --- a/dnn/src/cuda/convolution/chanwise/bwd_large_filter.cu +++ b/dnn/src/cuda/convolution/chanwise/bwd_large_filter.cu @@ -1,5 +1,5 @@ /** - * \file dnn/src/cuda/conv_bias/chanwise/fwd_depthwise_large_filter.cu + * \file dnn/src/cuda/conv_bias/chanwise/bwd_large_filter.cu * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -21,7 +21,7 @@ using namespace cuda; using namespace convolution; using namespace chanwise; -#include "src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl" +#include "src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh" namespace megdnn { namespace cuda { diff --git a/dnn/test/cuda/convolution.cpp b/dnn/test/cuda/convolution.cpp index 697b784fc..41b8ab3b3 100644 --- a/dnn/test/cuda/convolution.cpp +++ b/dnn/test/cuda/convolution.cpp @@ -739,7 +739,7 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DEPTHWISE_LARGE_FILTER) { param.sparse = param::Convolution::Sparse::GROUP; checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype); float scale = 64.f / sqrt(fh * fh); - UniformFloatRNG rng(1.0, 1.0); + UniformFloatRNG rng(scale, scale * 2); checker.set_rng(0, &rng).set_rng(1, &rng).set_rng(2, &rng); if (dtype.enumv() == DTypeEnum::Float16) checker.set_epsilon(1e-1); @@ -751,35 +751,35 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DEPTHWISE_LARGE_FILTER) { {n, g, h, h}}); }; run(4, 8, 32, 5, 5 / 2, 1); - run(4, 8, 32, 7, 7/2, 1); - run(4, 8, 32, 9, 9/2, 1); - run(4, 8, 32, 11, 11/2, 1); - run(4, 8, 32, 13, 13/2, 1); - run(4, 8, 32, 15, 15/2, 1); - run(4, 8, 32, 17, 17/2, 1); - run(4, 8, 32, 19, 19/2, 1); - run(4, 8, 32, 21, 21/2, 1); - run(4, 8, 32, 23, 23/2, 1); - run(4, 8, 32, 25, 25/2, 1); - run(4, 8, 32, 27, 27/2, 1); - run(4, 8, 32, 29, 29/2, 1); - run(4, 8, 32, 31, 31/2, 1); + run(4, 8, 32, 7, 7 / 2, 1); + run(4, 8, 32, 9, 9 / 2, 1); + run(4, 8, 32, 11, 11 / 2, 1); + run(4, 8, 32, 13, 13 / 2, 1); + run(4, 8, 32, 15, 15 / 2, 1); + run(4, 8, 32, 17, 17 / 2, 1); + run(4, 8, 32, 19, 19 / 2, 1); + run(4, 8, 32, 21, 21 / 2, 1); + run(4, 8, 32, 23, 23 / 2, 1); + run(4, 8, 32, 25, 25 / 2, 1); + run(4, 8, 32, 27, 27 / 2, 1); + run(4, 8, 32, 29, 29 / 2, 1); + run(4, 8, 32, 31, 31 / 2, 1); run(4, 8, 64, 5, 5 / 2, 2); - run(4, 8, 64, 7, 7/3, 2); - run(4, 8, 64, 9, 9/3, 2); - run(4, 8, 64, 11, 11/3, 2); - run(4, 8, 64, 13, 13/3, 2); - run(4, 8, 64, 15, 15/3, 2); - run(4, 8, 64, 17, 17/3, 2); - run(4, 8, 64, 19, 19/3, 2); - run(4, 8, 64, 21, 21/3, 2); - run(4, 8, 64, 23, 23/3, 2); - run(4, 8, 64, 25, 25/3, 2); - run(4, 8, 64, 27, 27/3, 2); - run(4, 8, 64, 29, 29/3, 2); - run(4, 8, 64, 31, 31/3, 2); - run(1, 2, 128, 31, 31/3, 2); - run(1, 2, 256, 31, 31/3, 2); + run(4, 8, 64, 7, 7 / 3, 2); + run(4, 8, 64, 9, 9 / 3, 2); + run(4, 8, 64, 11, 11 / 3, 2); + run(4, 8, 64, 13, 13 / 3, 2); + run(4, 8, 64, 15, 15 / 3, 2); + run(4, 8, 64, 17, 17 / 3, 2); + run(4, 8, 64, 19, 19 / 3, 2); + run(4, 8, 64, 21, 21 / 3, 2); + run(4, 8, 64, 23, 23 / 3, 2); + run(4, 8, 64, 25, 25 / 3, 2); + run(4, 8, 64, 27, 27 / 3, 2); + run(4, 8, 64, 29, 29 / 3, 2); + run(4, 8, 64, 31, 31 / 3, 2); + run(1, 2, 128, 31, 31 / 3, 2); + run(1, 2, 256, 31, 31 / 3, 2); } } -- GitLab