提交 d968942f 编写于 作者: M Megvii Engine Team

perf(cuda): speedup direct large kernel conv

GitOrigin-RevId: 3ff6a9caebbd1dc4c5c1c23b51945f7574f186ca
上级 b2cffdde
...@@ -59,14 +59,15 @@ struct ConvTraitInner { ...@@ -59,14 +59,15 @@ struct ConvTraitInner {
static int const smem_src_h = static int const smem_src_h =
(OutTileConfig::block_h - 1) * stride_h + FilterTileConfig::unroll_h; (OutTileConfig::block_h - 1) * stride_h + FilterTileConfig::unroll_h;
static int const smem_buff_h = FilterTileConfig::unroll_h; static int const smem_buff_h = FilterTileConfig::unroll_h;
static int const smem_load_h = smem_src_h + smem_buff_h; static int const smem_load_h = smem_src_h + smem_buff_h *
FilterTileConfig::unroll_w *
ThreadConfig::thread_x;
static int const smem_h = smem_load_h + smem_buff_h; static int const smem_h = smem_load_h + smem_buff_h;
static int const smem_w = static int const smem_w =
DIVUP((OutTileConfig::block_w - 1) * stride_w + DIVUP((OutTileConfig::block_w - 1) * stride_w +
FilterTileConfig::unroll_w * ThreadConfig::thread_x, FilterTileConfig::unroll_w * ThreadConfig::thread_x,
2) * 2) *
2; 2;
static int const smem_size = smem_h * smem_w;
static int const load_w = static int const load_w =
smem_w > ThreadConfig::nr_threads ? ThreadConfig::nr_threads : smem_w; smem_w > ThreadConfig::nr_threads ? ThreadConfig::nr_threads : smem_w;
static int const load_h = 1; static int const load_h = 1;
...@@ -74,21 +75,36 @@ struct ConvTraitInner { ...@@ -74,21 +75,36 @@ struct ConvTraitInner {
static int const reg_w = DIVUP(smem_w, load_w); static int const reg_w = DIVUP(smem_w, load_w);
static bool constexpr check_bounds_h = smem_h % load_h != 0; static bool constexpr check_bounds_h = smem_h % load_h != 0;
static bool constexpr check_bounds_w = smem_w % load_w != 0; static bool constexpr check_bounds_w = smem_w % load_w != 0;
// to avoid bank conflict, every bank_offset_line in 8 lines, add one offset
static int const bank_w = smem_w / (4 / sizeof(CompType));
static int const bank_offset_line =
(bank_w % 32 == 0 || bank_w % FilterTileConfig::unroll_w == 0)
? 1
: (bank_w % 16 == 0 ? 2 : 4);
static int const smem_size = smem_h * smem_w + DIVUP(smem_h, bank_offset_line) *
(4 / sizeof(CompType));
}; };
struct FilterTileCount { struct FilterTileCount {
static int const smem_flt_h = FilterTileConfig::unroll_h; static int const smem_flt_h = FilterTileConfig::unroll_h;
static int const smem_buff_h = FilterTileConfig::unroll_h; static int const smem_buff_h = FilterTileConfig::unroll_h;
static int const smem_load_h = smem_flt_h + smem_buff_h;
static int const smem_h = smem_load_h + smem_buff_h;
static int const smem_w = FilterTileConfig::unroll_w * ThreadConfig::thread_x; static int const smem_w = FilterTileConfig::unroll_w * ThreadConfig::thread_x;
static int const smem_size = smem_h * smem_w; static int const smem_load_h = smem_flt_h + smem_buff_h * smem_w;
static int const smem_h = smem_load_h + smem_buff_h;
static int const load_w = smem_w > 32 ? 32 : smem_w; static int const load_w = smem_w > 32 ? 32 : smem_w;
static int const load_h = ThreadConfig::nr_threads / load_w; static int const load_h = ThreadConfig::nr_threads / load_w;
static int const reg_h = 1; static int const reg_h = 1;
static int const reg_w = DIVUP(smem_w, load_w); static int const reg_w = DIVUP(smem_w, load_w);
static bool constexpr check_bounds_h = smem_h % load_h != 0; static bool constexpr check_bounds_h = smem_h % load_h != 0;
static bool constexpr check_bounds_w = smem_w % load_w != 0; static bool constexpr check_bounds_w = smem_w % load_w != 0;
// to avoid bank conflict, every bank_offset_line in 8 lines, add one offset
static int const bank_w = smem_w / (4 / sizeof(CompType));
static int const bank_offset_line =
(bank_w % 32 == 0 || bank_w % FilterTileConfig::unroll_w == 0)
? 1
: (bank_w % 16 == 0 ? 2 : 4);
static int const smem_size = smem_h * smem_w + DIVUP(smem_h, bank_offset_line) *
(4 / sizeof(CompType));
}; };
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册