Commit d968942f authored by Megvii Engine Team

perf(cuda): speedup direct large kernel conv

GitOrigin-RevId: 3ff6a9caebbd1dc4c5c1c23b51945f7574f186ca
Parent b2cffdde
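
Two changes run through the hunks below. First, the shared-memory tiles in SrcTileCount / FilterTileCount gain a bank_offset_line constant and a correspondingly enlarged smem_size, and every shared-memory index on both the copy and the compute path is shifted by row / bank_offset_line so that consecutive rows are staggered across the 32 shared-memory banks. Second, the per-row register loads (reg_src, reg_flt) are double-buffered and the filter-height loop processes two rows per iteration, with a prologue before the loop and a tail for odd param.flt_h. The stand-alone CUDA sketch below illustrates only the padding idea; the tile shape, the padded_index helper and the row_sum kernel are assumptions made for the example, not MegEngine code.

#include <cstdio>
#include <cuda_runtime.h>

// Illustrative tile shape and padding rule; the real values come from the
// ConvTrait structs in the diff, these are assumptions for the sketch.
constexpr int SMEM_H = 32;
constexpr int SMEM_W = 32;           // multiple of 32 floats: rows alias the same banks
constexpr int BANK_OFFSET_LINE = 1;  // insert one padding element after every row
constexpr int SMEM_SIZE =
        SMEM_H * SMEM_W + (SMEM_H + BANK_OFFSET_LINE - 1) / BANK_OFFSET_LINE;

__device__ __forceinline__ int padded_index(int row, int col) {
    // same form as in the diff: linear offset plus one extra slot per padded row group
    return row * SMEM_W + col + row / BANK_OFFSET_LINE;
}

__global__ void row_sum(const float* src, float* dst) {
    __shared__ float tile[SMEM_SIZE];
    int row = threadIdx.x;  // one thread per row, blockDim.x == SMEM_H
    for (int c = 0; c < SMEM_W; ++c)
        tile[padded_index(row, c)] = src[row * SMEM_W + c];
    __syncthreads();
    float acc = 0.f;
    for (int c = 0; c < SMEM_W; ++c)
        // at a fixed c the 32 threads of the warp read the same column of 32
        // different rows; the padding makes their addresses stride by
        // SMEM_W + 1 elements, so they fall into 32 distinct banks instead of one
        acc += tile[padded_index(row, c)];
    dst[row] = acc;
}

int main() {
    float h_src[SMEM_H * SMEM_W], h_dst[SMEM_H];
    for (int i = 0; i < SMEM_H * SMEM_W; ++i) h_src[i] = 1.f;
    float *d_src, *d_dst;
    cudaMalloc(&d_src, sizeof(h_src));
    cudaMalloc(&d_dst, sizeof(h_dst));
    cudaMemcpy(d_src, h_src, sizeof(h_src), cudaMemcpyHostToDevice);
    row_sum<<<1, SMEM_H>>>(d_src, d_dst);
    cudaMemcpy(h_dst, d_dst, sizeof(h_dst), cudaMemcpyDeviceToHost);
    printf("row 0 sum = %.1f (expect %d)\n", h_dst[0], SMEM_W);
    cudaFree(d_src);
    cudaFree(d_dst);
    return 0;
}

Without the extra element per row, the addresses read in the second loop would differ by exactly SMEM_W = 32 floats between neighbouring threads and all 32 accesses would serialize on a single bank; bank_offset_line generalizes the pad to one slot per group of rows when smem_w is only partially bank-aligned.
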
......@@ -59,14 +59,15 @@ struct ConvTraitInner {
static int const smem_src_h =
(OutTileConfig::block_h - 1) * stride_h + FilterTileConfig::unroll_h;
static int const smem_buff_h = FilterTileConfig::unroll_h;
static int const smem_load_h = smem_src_h + smem_buff_h;
static int const smem_load_h = smem_src_h + smem_buff_h *
FilterTileConfig::unroll_w *
ThreadConfig::thread_x;
static int const smem_h = smem_load_h + smem_buff_h;
static int const smem_w =
DIVUP((OutTileConfig::block_w - 1) * stride_w +
FilterTileConfig::unroll_w * ThreadConfig::thread_x,
2) *
2;
static int const smem_size = smem_h * smem_w;
static int const load_w =
smem_w > ThreadConfig::nr_threads ? ThreadConfig::nr_threads : smem_w;
static int const load_h = 1;
......@@ -74,21 +75,36 @@ struct ConvTraitInner {
static int const reg_w = DIVUP(smem_w, load_w);
static bool constexpr check_bounds_h = smem_h % load_h != 0;
static bool constexpr check_bounds_w = smem_w % load_w != 0;
// to avoid bank conflicts, add one offset element after every bank_offset_line lines
static int const bank_w = smem_w / (4 / sizeof(CompType));
static int const bank_offset_line =
(bank_w % 32 == 0 || bank_w % FilterTileConfig::unroll_w == 0)
? 1
: (bank_w % 16 == 0 ? 2 : 4);
static int const smem_size = smem_h * smem_w + DIVUP(smem_h, bank_offset_line) *
(4 / sizeof(CompType));
};
struct FilterTileCount {
static int const smem_flt_h = FilterTileConfig::unroll_h;
static int const smem_buff_h = FilterTileConfig::unroll_h;
static int const smem_load_h = smem_flt_h + smem_buff_h;
static int const smem_h = smem_load_h + smem_buff_h;
static int const smem_w = FilterTileConfig::unroll_w * ThreadConfig::thread_x;
static int const smem_size = smem_h * smem_w;
static int const smem_load_h = smem_flt_h + smem_buff_h * smem_w;
static int const smem_h = smem_load_h + smem_buff_h;
static int const load_w = smem_w > 32 ? 32 : smem_w;
static int const load_h = ThreadConfig::nr_threads / load_w;
static int const reg_h = 1;
static int const reg_w = DIVUP(smem_w, load_w);
static bool constexpr check_bounds_h = smem_h % load_h != 0;
static bool constexpr check_bounds_w = smem_w % load_w != 0;
// to avoid bank conflicts, add one offset element after every bank_offset_line lines
static int const bank_w = smem_w / (4 / sizeof(CompType));
static int const bank_offset_line =
(bank_w % 32 == 0 || bank_w % FilterTileConfig::unroll_w == 0)
? 1
: (bank_w % 16 == 0 ? 2 : 4);
static int const smem_size = smem_h * smem_w + DIVUP(smem_h, bank_offset_line) *
(4 / sizeof(CompType));
};
};
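// Worked example of the padding arithmetic above (assumed numbers, CompType = float):
//   4 / sizeof(CompType) = 1, smem_w = 128
//   bank_w           = 128 / 1 = 128
//   bank_offset_line = 1                       (because 128 % 32 == 0)
//   smem_size        = smem_h * 128 + DIVUP(smem_h, 1) * 1 = smem_h * 129
// i.e. one extra float per row, so each row starts one bank later than the row
// before it and a warp reading the same column of 32 consecutive rows touches
// 32 distinct banks. For CompType = __half, 4 / sizeof(CompType) = 2 and the
// pad is two halves (one 4-byte bank slot) per bank_offset_line rows.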
......
......@@ -119,11 +119,12 @@ __device__ __forceinline__ void Global2SharedMem<
#pragma unroll
for (int i = 0; i < h_per_thread; ++i) {
int smem_h_idx = y_base_idx + i * load_h;
int bank_offset = smem_h_idx / TileCount::bank_offset_line;
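// bank_offset counts the padding elements that precede row smem_h_idx (one per
// bank_offset_line rows); the store below shifts its column by
// bank_offset * (4 / sizeof(T)) so that it lands in the padded layout whose
// size is TileCount::smem_size.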
int src_h_idx;
if (is_fwd) {
src_h_idx = start_h + smem_h_idx;
} else {
src_h_idx = start_h + TileCount::smem_load_h - smem_h_idx - 1;
src_h_idx = start_h - smem_h_idx;
}
if (check_bounds_h && smem_h_idx >= TileCount::smem_load_h)
continue;
......@@ -146,7 +147,8 @@ __device__ __forceinline__ void Global2SharedMem<
TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0))) {
val = g_ptr[src_h_idx / stride_h * stride + src_w_idx / stride_w];
}
*(sh_ptr_as_copy_t(smem_h_idx, smem_w_idx)) = val;
*(sh_ptr_as_copy_t(
smem_h_idx, smem_w_idx + bank_offset * (4 / sizeof(T)))) = val;
}
}
}
......@@ -261,24 +263,29 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
SrcGlobal2ShareVisitor gl2sh_src = {
smem_src,
param.src_w,
static_cast<int>(param.src_w),
static_cast<int>(
is_fwd ? src_start_h
: src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h -
param.src_h * param.stride_h / 2),
: src_start_h -
(param.out_h / 2 + param.flt_h / 2 - param.pad_h -
param.src_h * param.stride_h / 2)),
static_cast<int>(
is_fwd ? src_start_w
: src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w -
param.src_w * param.stride_w / 2),
is_fwd ? param.src_h : param.src_h * param.stride_h,
is_fwd ? param.src_w : param.src_w * param.stride_w,
is_fwd ? 1 : param.stride_h,
is_fwd ? 1 : param.stride_w};
FilterGlobal2ShareVisitor gl2sh_flt = {smem_flt,
param.flt_w,
is_fwd ? 0 : param.flt_h - 2,
: src_start_w -
(param.out_w / 2 + param.flt_w / 2 - param.pad_w -
param.src_w * param.stride_w / 2)),
static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h),
static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w),
is_fwd ? 1 : static_cast<int>(param.stride_h),
is_fwd ? 1 : static_cast<int>(param.stride_w)};
FilterGlobal2ShareVisitor gl2sh_flt = {
smem_flt,
static_cast<int>(param.flt_w),
is_fwd ? 0 : static_cast<int>(param.flt_h - 1),
0,
param.flt_h,
param.flt_w,
static_cast<int>(param.flt_h),
static_cast<int>(param.flt_w),
1,
1};
......@@ -290,14 +297,51 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
__syncthreads();
T2 reg_src[SrcTileConfig::unroll_h * t2_src_unroll_w],
reg_flt[2][FilterTileConfig::unroll_h * t2_flt_unroll_w];
T2 reg_src[2][SrcTileConfig::unroll_h * t2_src_unroll_w],
reg_flt[2][2][FilterTileConfig::unroll_h * t2_flt_unroll_w];
T2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}};
for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
gl2sh_src.copy();
gl2sh_flt.copy();
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
#pragma unroll
for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) {
int src_offset = (off_oh * stride_h + s_h) % SrcTileCount::smem_h *
SrcTileCount::smem_w +
s_w * 2;
reg_src[0][s_h * t2_src_unroll_w + s_w] = *reinterpret_cast<T2*>(
smem_src_ptr + src_offset +
((off_oh * stride_h + s_h) / SrcTileCount::bank_offset_line) * 2);
}
}
if (off_ow == ThreadConfig::thread_x - 1) {
reg_src[0][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0};
}
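// Two packed copies of each filter row are kept in registers:
//   reg_flt[b][0][k] = {flt[2k],     flt[2k + 1]}   pair-aligned taps
//   reg_flt[b][1][k] = {flt[2k - 1], flt[2k]}       shifted by one scalar,
//                                                   zero-padded at k == 0
// Outputs whose first source column is even use copy 0 and odd ones use copy 1
// (reg_flt[b][ow * stride_w % 2] in the fma2 loops), so reg_src is always read
// at pair-aligned offsets and the two lanes of each T2 accumulator together
// hold the full dot product for that output.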
#pragma unroll
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) {
int flt_offset =
(f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + f_w * 2;
reg_flt[0][0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast<T2*>(
smem_flt_ptr + flt_offset +
2 * (f_h / FilterTileCount::bank_offset_line));
if (f_w > 0) {
reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] =
T2{reg_flt[0][0][f_h * t2_flt_unroll_w + f_w - 1].y,
reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x};
} else {
reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] =
T2{0.0, reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x};
}
}
reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0};
reg_flt[0][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] =
T2{reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0};
}
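// Software pipeline over filter rows: reg_src / reg_flt are double-buffered and
// each iteration of the loop below advances by unroll_h * 2. It first stages
// buffer 1 from shared memory while the fma2 chain consumes buffer 0, advances
// the global-to-shared visitors with iter_forward() between the two barriers,
// then stages buffer 0 for row fh + 1 while consuming buffer 1. The block
// guarded by (param.flt_h % 2 != 0) after the loop drains the last buffered row
// when the filter height is odd.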
for (int fh = 1; fh < param.flt_h - 1; fh += FilterTileConfig::unroll_h * 2) {
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
#pragma unroll
......@@ -305,10 +349,15 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
int src_offset = (off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h *
SrcTileCount::smem_w +
s_w * 2;
reg_src[s_h * t2_src_unroll_w + s_w] =
*reinterpret_cast<T2*>(smem_src_ptr + src_offset);
reg_src[1][s_h * t2_src_unroll_w + s_w] = *reinterpret_cast<T2*>(
smem_src_ptr + src_offset +
2 * ((off_oh * stride_h + fh + s_h) /
SrcTileCount::bank_offset_line));
}
}
if (off_ow == ThreadConfig::thread_x - 1) {
reg_src[1][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0};
}
#pragma unroll
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
......@@ -317,20 +366,21 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
int flt_offset =
(fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w +
f_w * 2;
reg_flt[0][f_h * t2_flt_unroll_w + f_w] =
*reinterpret_cast<T2*>(smem_flt_ptr + flt_offset);
reg_flt[1][0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast<T2*>(
smem_flt_ptr + flt_offset +
2 * ((fh + f_h) / FilterTileCount::bank_offset_line));
if (f_w > 0) {
reg_flt[1][f_h * t2_flt_unroll_w + f_w] =
T2{reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y,
reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
reg_flt[1][1][f_h * t2_flt_unroll_w + f_w] =
T2{reg_flt[1][0][f_h * t2_flt_unroll_w + f_w - 1].y,
reg_flt[1][0][f_h * t2_flt_unroll_w + f_w].x};
} else {
reg_flt[1][f_h * t2_flt_unroll_w + f_w] =
T2{0.0, reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
reg_flt[1][1][f_h * t2_flt_unroll_w + f_w] =
T2{0.0, reg_flt[1][0][f_h * t2_flt_unroll_w + f_w].x};
}
}
reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0};
reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] =
T2{reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0};
reg_flt[1][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0};
reg_flt[1][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{
reg_flt[1][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0};
}
#pragma unroll
......@@ -342,9 +392,10 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
reg_flt[ow * stride_w % 2]
reg_flt[0][ow * stride_w % 2]
[inner_fh * t2_flt_unroll_w + fw],
reg_src[(inner_fh + oh) * t2_src_unroll_w + fw +
reg_src[0]
[(inner_fh + oh) * t2_src_unroll_w + fw +
ow * stride_w / 2],
sum[oh * t2_out_unroll_w + ow]);
}
......@@ -352,13 +403,91 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
}
}
__syncthreads();
gl2sh_src.commit();
gl2sh_flt.commit();
gl2sh_src.iter_forward();
gl2sh_flt.iter_forward();
__syncthreads();
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
#pragma unroll
for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) {
int src_offset = (off_oh * stride_h + fh + 1 + s_h) %
SrcTileCount::smem_h * SrcTileCount::smem_w +
s_w * 2;
reg_src[0][s_h * t2_src_unroll_w + s_w] = *reinterpret_cast<T2*>(
smem_src_ptr + src_offset +
2 * ((off_oh * stride_h + fh + 1 + s_h) /
SrcTileCount::bank_offset_line));
}
}
if (off_ow == ThreadConfig::thread_x - 1) {
reg_src[0][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0};
}
#pragma unroll
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) {
int flt_offset = (fh + 1 + f_h) % FilterTileCount::smem_h *
FilterTileCount::smem_w +
f_w * 2;
reg_flt[0][0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast<T2*>(
smem_flt_ptr + flt_offset +
2 * ((fh + 1 + f_h) / FilterTileCount::bank_offset_line));
if (f_w > 0) {
reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] =
T2{reg_flt[0][0][f_h * t2_flt_unroll_w + f_w - 1].y,
reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x};
} else {
reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] =
T2{0.0, reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x};
}
}
reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0};
reg_flt[0][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{
reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0};
}
#pragma unroll
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
#pragma unroll
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
#pragma unroll
for (int fw = 0; fw < t2_flt_unroll_w; ++fw) {
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
reg_flt[1][ow * stride_w % 2]
[inner_fh * t2_flt_unroll_w + fw],
reg_src[1]
[(inner_fh + oh) * t2_src_unroll_w + fw +
ow * stride_w / 2],
sum[oh * t2_out_unroll_w + ow]);
}
}
}
}
}
if (param.flt_h % 2 != 0) {
#pragma unroll
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
#pragma unroll
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
#pragma unroll
for (int fw = 0; fw < t2_flt_unroll_w; ++fw) {
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
reg_flt[0][ow * stride_w % 2]
[inner_fh * t2_flt_unroll_w + fw],
reg_src[0]
[(inner_fh + oh) * t2_src_unroll_w + fw +
ow * stride_w / 2],
sum[oh * t2_out_unroll_w + ow]);
}
}
}
}
}
__syncthreads();
for (int o = 0; o < OutTileConfig::unroll_size; ++o) {
for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) {
......@@ -429,24 +558,29 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
SrcGlobal2ShareVisitor gl2sh_src = {
smem_src,
param.src_w,
static_cast<int>(param.src_w),
static_cast<int>(
is_fwd ? src_start_h
: src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h -
param.src_h * param.stride_h / 2),
: src_start_h -
(param.out_h / 2 + param.flt_h / 2 - param.pad_h -
param.src_h * param.stride_h / 2)),
static_cast<int>(
is_fwd ? src_start_w
: src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w -
param.src_w * param.stride_w / 2),
is_fwd ? param.src_h : param.src_h * param.stride_h,
is_fwd ? param.src_w : param.src_w * param.stride_w,
is_fwd ? 1 : param.stride_h,
is_fwd ? 1 : param.stride_w};
FilterGlobal2ShareVisitor gl2sh_flt = {smem_flt,
param.flt_w,
is_fwd ? 0 : param.flt_h - 2,
: src_start_w -
(param.out_w / 2 + param.flt_w / 2 - param.pad_w -
param.src_w * param.stride_w / 2)),
static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h),
static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w),
is_fwd ? 1 : static_cast<int>(param.stride_h),
is_fwd ? 1 : static_cast<int>(param.stride_w)};
FilterGlobal2ShareVisitor gl2sh_flt = {
smem_flt,
static_cast<int>(param.flt_w),
is_fwd ? 0 : static_cast<int>(param.flt_h - 1),
0,
param.flt_h,
param.flt_w,
static_cast<int>(param.flt_h),
static_cast<int>(param.flt_w),
1,
1};
......@@ -458,14 +592,51 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
__syncthreads();
T2 reg_src[SrcTileConfig::unroll_h * t2_src_unroll_w],
reg_flt[2][FilterTileConfig::unroll_h * t2_flt_unroll_w];
T2 reg_src[2][SrcTileConfig::unroll_h * t2_src_unroll_w],
reg_flt[2][2][FilterTileConfig::unroll_h * t2_flt_unroll_w];
float2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}};
for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
gl2sh_src.copy();
gl2sh_flt.copy();
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
#pragma unroll
for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) {
int src_offset = (off_oh * stride_h + s_h) % SrcTileCount::smem_h *
SrcTileCount::smem_w +
s_w * 2;
reg_src[0][s_h * t2_src_unroll_w + s_w] = *reinterpret_cast<T2*>(
smem_src_ptr + src_offset +
((off_oh * stride_h + s_h) / SrcTileCount::bank_offset_line) * 2);
}
}
if (off_ow == ThreadConfig::thread_x - 1) {
reg_src[0][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0};
}
#pragma unroll
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) {
int flt_offset =
(f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + f_w * 2;
reg_flt[0][0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast<T2*>(
smem_flt_ptr + flt_offset +
2 * (f_h / FilterTileCount::bank_offset_line));
if (f_w > 0) {
reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] =
T2{reg_flt[0][0][f_h * t2_flt_unroll_w + f_w - 1].y,
reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x};
} else {
reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] =
T2{0.0, reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x};
}
}
reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0};
reg_flt[0][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] =
T2{reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0};
}
for (int fh = 1; fh < param.flt_h - 1; fh += FilterTileConfig::unroll_h * 2) {
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
#pragma unroll
......@@ -473,10 +644,15 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
int src_offset = (off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h *
SrcTileCount::smem_w +
s_w * 2;
reg_src[s_h * t2_src_unroll_w + s_w] =
*reinterpret_cast<T2*>(smem_src_ptr + src_offset);
reg_src[1][s_h * t2_src_unroll_w + s_w] = *reinterpret_cast<T2*>(
smem_src_ptr + src_offset +
2 * ((off_oh * stride_h + fh + s_h) /
SrcTileCount::bank_offset_line));
}
}
if (off_ow == ThreadConfig::thread_x - 1) {
reg_src[1][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0};
}
#pragma unroll
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
......@@ -485,20 +661,21 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
int flt_offset =
(fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w +
f_w * 2;
reg_flt[0][f_h * t2_flt_unroll_w + f_w] =
*reinterpret_cast<T2*>(smem_flt_ptr + flt_offset);
reg_flt[1][0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast<T2*>(
smem_flt_ptr + flt_offset +
2 * ((fh + f_h) / FilterTileCount::bank_offset_line));
if (f_w > 0) {
reg_flt[1][f_h * t2_flt_unroll_w + f_w] =
T2{reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y,
reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
reg_flt[1][1][f_h * t2_flt_unroll_w + f_w] =
T2{reg_flt[1][0][f_h * t2_flt_unroll_w + f_w - 1].y,
reg_flt[1][0][f_h * t2_flt_unroll_w + f_w].x};
} else {
reg_flt[1][f_h * t2_flt_unroll_w + f_w] =
T2{0.0, reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
reg_flt[1][1][f_h * t2_flt_unroll_w + f_w] =
T2{0.0, reg_flt[1][0][f_h * t2_flt_unroll_w + f_w].x};
}
}
reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0};
reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] =
T2{reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0};
reg_flt[1][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0};
reg_flt[1][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{
reg_flt[1][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0};
}
#pragma unroll
......@@ -510,9 +687,10 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
reg_flt[ow * stride_w % 2]
reg_flt[0][ow * stride_w % 2]
[inner_fh * t2_flt_unroll_w + fw],
reg_src[(inner_fh + oh) * t2_src_unroll_w + fw +
reg_src[0]
[(inner_fh + oh) * t2_src_unroll_w + fw +
ow * stride_w / 2],
sum[oh * t2_out_unroll_w + ow]);
}
......@@ -520,13 +698,91 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
}
}
__syncthreads();
gl2sh_src.commit();
gl2sh_flt.commit();
gl2sh_src.iter_forward();
gl2sh_flt.iter_forward();
__syncthreads();
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
#pragma unroll
for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) {
int src_offset = (off_oh * stride_h + fh + 1 + s_h) %
SrcTileCount::smem_h * SrcTileCount::smem_w +
s_w * 2;
reg_src[0][s_h * t2_src_unroll_w + s_w] = *reinterpret_cast<T2*>(
smem_src_ptr + src_offset +
2 * ((off_oh * stride_h + fh + 1 + s_h) /
SrcTileCount::bank_offset_line));
}
}
if (off_ow == ThreadConfig::thread_x - 1) {
reg_src[0][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0};
}
#pragma unroll
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) {
int flt_offset = (fh + 1 + f_h) % FilterTileCount::smem_h *
FilterTileCount::smem_w +
f_w * 2;
reg_flt[0][0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast<T2*>(
smem_flt_ptr + flt_offset +
2 * ((fh + 1 + f_h) / FilterTileCount::bank_offset_line));
if (f_w > 0) {
reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] =
T2{reg_flt[0][0][f_h * t2_flt_unroll_w + f_w - 1].y,
reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x};
} else {
reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] =
T2{0.0, reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x};
}
}
reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0};
reg_flt[0][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{
reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0};
}
#pragma unroll
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
#pragma unroll
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
#pragma unroll
for (int fw = 0; fw < t2_flt_unroll_w; ++fw) {
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
reg_flt[1][ow * stride_w % 2]
[inner_fh * t2_flt_unroll_w + fw],
reg_src[1]
[(inner_fh + oh) * t2_src_unroll_w + fw +
ow * stride_w / 2],
sum[oh * t2_out_unroll_w + ow]);
}
}
}
}
}
if (param.flt_h % 2 != 0) {
#pragma unroll
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
#pragma unroll
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
#pragma unroll
for (int fw = 0; fw < t2_flt_unroll_w; ++fw) {
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
reg_flt[0][ow * stride_w % 2]
[inner_fh * t2_flt_unroll_w + fw],
reg_src[0]
[(inner_fh + oh) * t2_src_unroll_w + fw +
ow * stride_w / 2],
sum[oh * t2_out_unroll_w + ow]);
}
}
}
}
}
__syncthreads();
for (int o = 0; o < OutTileConfig::unroll_size; ++o) {
for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) {
......@@ -595,24 +851,29 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
SrcGlobal2ShareVisitor gl2sh_src = {
smem_src,
param.src_w,
static_cast<int>(param.src_w),
static_cast<int>(
is_fwd ? src_start_h
: src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h -
param.src_h * param.stride_h / 2),
: src_start_h -
(param.out_h / 2 + param.flt_h / 2 - param.pad_h -
param.src_h * param.stride_h / 2)),
static_cast<int>(
is_fwd ? src_start_w
: src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w -
param.src_w * param.stride_w / 2),
is_fwd ? param.src_h : param.src_h * param.stride_h,
is_fwd ? param.src_w : param.src_w * param.stride_w,
is_fwd ? 1 : param.stride_h,
is_fwd ? 1 : param.stride_w};
FilterGlobal2ShareVisitor gl2sh_flt = {smem_flt,
param.flt_w,
is_fwd ? 0 : param.flt_h - 2,
: src_start_w -
(param.out_w / 2 + param.flt_w / 2 - param.pad_w -
param.src_w * param.stride_w / 2)),
static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h),
static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w),
is_fwd ? 1 : static_cast<int>(param.stride_h),
is_fwd ? 1 : static_cast<int>(param.stride_w)};
FilterGlobal2ShareVisitor gl2sh_flt = {
smem_flt,
static_cast<int>(param.flt_w),
is_fwd ? 0 : static_cast<int>(param.flt_h - 1),
0,
param.flt_h,
param.flt_w,
static_cast<int>(param.flt_h),
static_cast<int>(param.flt_w),
1,
1};
......@@ -624,22 +885,43 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
__syncthreads();
T reg_src[SrcTileConfig::unroll_h * SrcTileConfig::unroll_w],
reg_flt[FilterTileConfig::unroll_h * FilterTileConfig::unroll_w];
T reg_src[2][SrcTileConfig::unroll_h * SrcTileConfig::unroll_w],
reg_flt[2][FilterTileConfig::unroll_h * FilterTileConfig::unroll_w];
T sum[OutTileConfig::unroll_size] = {0.0};
for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
gl2sh_src.copy();
gl2sh_flt.copy();
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
#pragma unroll
for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
reg_src[s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
reg_src[0][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
[(off_oh * stride_h + s_h) % SrcTileCount::smem_h *
SrcTileCount::smem_w +
s_w + (off_oh * stride_h + s_h) / SrcTileCount::bank_offset_line];
}
}
#pragma unroll
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
reg_flt[0][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
[(f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + f_w +
f_h / FilterTileCount::bank_offset_line];
}
}
for (int fh = 1; fh < param.flt_h + 1; fh += FilterTileConfig::unroll_h * 2) {
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
#pragma unroll
for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
reg_src[1][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
[(off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h *
SrcTileCount::smem_w +
s_w];
s_w +
(off_oh * stride_h + fh + s_h) /
SrcTileCount::bank_offset_line];
}
}
......@@ -647,13 +929,53 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
reg_flt[f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
reg_flt[1][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
[(fh + f_h) % FilterTileCount::smem_h *
FilterTileCount::smem_w +
f_w];
f_w + (fh + f_h) / FilterTileCount::bank_offset_line];
}
}
#pragma unroll
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
#pragma unroll
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
#pragma unroll
for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) {
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
sum[oh * OutTileConfig::unroll_w + ow] +=
reg_flt[0][inner_fh * FilterTileConfig::unroll_w + fw] *
reg_src[0]
[(inner_fh + oh) * SrcTileConfig::unroll_w + fw +
ow * stride_w];
}
}
}
}
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
#pragma unroll
for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
reg_src[0][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
[(off_oh * stride_h + fh + 1 + s_h) % SrcTileCount::smem_h *
SrcTileCount::smem_w +
s_w +
(off_oh * stride_h + fh + 1 + s_h) /
SrcTileCount::bank_offset_line];
}
}
#pragma unroll
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
reg_flt[0][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
[(fh + 1 + f_h) % FilterTileCount::smem_h *
FilterTileCount::smem_w +
f_w + (fh + 1 + f_h) / FilterTileCount::bank_offset_line];
}
}
#pragma unroll
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
#pragma unroll
......@@ -663,21 +985,37 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
sum[oh * OutTileConfig::unroll_w + ow] +=
reg_flt[inner_fh * FilterTileConfig::unroll_w + fw] *
reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w + fw +
reg_flt[1][inner_fh * FilterTileConfig::unroll_w + fw] *
reg_src[1]
[(inner_fh + oh) * SrcTileConfig::unroll_w + fw +
ow * stride_w];
}
}
}
}
}
__syncthreads();
gl2sh_src.commit();
gl2sh_flt.commit();
gl2sh_src.iter_forward();
gl2sh_flt.iter_forward();
__syncthreads();
if (param.flt_h % 2 != 0) {
#pragma unroll
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
#pragma unroll
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
#pragma unroll
for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) {
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
sum[oh * OutTileConfig::unroll_w + ow] +=
reg_flt[0][inner_fh * FilterTileConfig::unroll_w + fw] *
reg_src[0]
[(inner_fh + oh) * SrcTileConfig::unroll_w + fw +
ow * stride_w];
}
}
}
}
}
__syncthreads();
for (int o = 0; o < OutTileConfig::unroll_size; ++o) {
for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) {
......@@ -743,24 +1081,29 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
SrcGlobal2ShareVisitor gl2sh_src = {
smem_src,
param.src_w,
static_cast<int>(param.src_w),
static_cast<int>(
is_fwd ? src_start_h
: src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h -
param.src_h * param.stride_h / 2),
: src_start_h -
(param.out_h / 2 + param.flt_h / 2 - param.pad_h -
param.src_h * param.stride_h / 2)),
static_cast<int>(
is_fwd ? src_start_w
: src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w -
param.src_w * param.stride_w / 2),
is_fwd ? param.src_h : param.src_h * param.stride_h,
is_fwd ? param.src_w : param.src_w * param.stride_w,
is_fwd ? 1 : param.stride_h,
is_fwd ? 1 : param.stride_w};
FilterGlobal2ShareVisitor gl2sh_flt = {smem_flt,
param.flt_w,
is_fwd ? 0 : param.flt_h - 2,
: src_start_w -
(param.out_w / 2 + param.flt_w / 2 - param.pad_w -
param.src_w * param.stride_w / 2)),
static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h),
static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w),
is_fwd ? 1 : static_cast<int>(param.stride_h),
is_fwd ? 1 : static_cast<int>(param.stride_w)};
FilterGlobal2ShareVisitor gl2sh_flt = {
smem_flt,
static_cast<int>(param.flt_w),
is_fwd ? 0 : static_cast<int>(param.flt_h - 1),
0,
param.flt_h,
param.flt_w,
static_cast<int>(param.flt_h),
static_cast<int>(param.flt_w),
1,
1};
......@@ -772,22 +1115,43 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
__syncthreads();
T reg_src[SrcTileConfig::unroll_h * SrcTileConfig::unroll_w],
reg_flt[FilterTileConfig::unroll_h * FilterTileConfig::unroll_w];
T reg_src[2][SrcTileConfig::unroll_h * SrcTileConfig::unroll_w],
reg_flt[2][FilterTileConfig::unroll_h * FilterTileConfig::unroll_w];
T sum[OutTileConfig::unroll_size] = {0.0};
for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
gl2sh_src.copy();
gl2sh_flt.copy();
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
#pragma unroll
for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
reg_src[s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
reg_src[0][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
[(off_oh * stride_h + s_h) % SrcTileCount::smem_h *
SrcTileCount::smem_w +
s_w + (off_oh * stride_h + s_h) / SrcTileCount::bank_offset_line];
}
}
#pragma unroll
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
reg_flt[0][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
[(f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + f_w +
f_h / FilterTileCount::bank_offset_line];
}
}
for (int fh = 1; fh < param.flt_h + 1; fh += FilterTileConfig::unroll_h * 2) {
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
#pragma unroll
for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
reg_src[1][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
[(off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h *
SrcTileCount::smem_w +
s_w];
s_w +
(off_oh * stride_h + fh + s_h) /
SrcTileCount::bank_offset_line];
}
}
......@@ -795,13 +1159,73 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
reg_flt[f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
reg_flt[1][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
[(fh + f_h) % FilterTileCount::smem_h *
FilterTileCount::smem_w +
f_w];
f_w + (fh + f_h) / FilterTileCount::bank_offset_line];
}
}
#pragma unroll
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
#pragma unroll
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
#pragma unroll
for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) {
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
sum[oh * OutTileConfig::unroll_w + ow] +=
reg_flt[0][inner_fh * FilterTileConfig::unroll_w + fw] *
reg_src[0]
[(inner_fh + oh) * SrcTileConfig::unroll_w + fw +
ow * stride_w];
}
}
}
}
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
#pragma unroll
for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
reg_src[0][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
[(off_oh * stride_h + fh + 1 + s_h) % SrcTileCount::smem_h *
SrcTileCount::smem_w +
s_w +
(off_oh * stride_h + fh + 1 + s_h) /
SrcTileCount::bank_offset_line];
}
}
#pragma unroll
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
reg_flt[0][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
[(fh + 1 + f_h) % FilterTileCount::smem_h *
FilterTileCount::smem_w +
f_w + (fh + 1 + f_h) / FilterTileCount::bank_offset_line];
}
}
#pragma unroll
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
#pragma unroll
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
#pragma unroll
for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) {
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
sum[oh * OutTileConfig::unroll_w + ow] +=
reg_flt[1][inner_fh * FilterTileConfig::unroll_w + fw] *
reg_src[1]
[(inner_fh + oh) * SrcTileConfig::unroll_w + fw +
ow * stride_w];
}
}
}
}
}
if (param.flt_h % 2 != 0) {
#pragma unroll
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
#pragma unroll
......@@ -811,21 +1235,17 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
sum[oh * OutTileConfig::unroll_w + ow] +=
reg_flt[inner_fh * FilterTileConfig::unroll_w + fw] *
reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w + fw +
reg_flt[0][inner_fh * FilterTileConfig::unroll_w + fw] *
reg_src[0]
[(inner_fh + oh) * SrcTileConfig::unroll_w + fw +
ow * stride_w];
}
}
}
}
}
__syncthreads();
gl2sh_src.commit();
gl2sh_flt.commit();
gl2sh_src.iter_forward();
gl2sh_flt.iter_forward();
__syncthreads();
}
for (int o = 0; o < OutTileConfig::unroll_size; ++o) {
for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) {
......@@ -901,9 +1321,8 @@ void LaunchDepthwiseConv2dGPU(
#define INSTANCE_A(type1, type2, a, direction) \
if (param.flt_w > a * 4) { \
INSTANCE_AB(type1, type2, a, 15, direction) \
else INSTANCE_AB(type1, type2, a, 7, direction) else INSTANCE_AB( \
type1, type2, a, 3, direction) \
INSTANCE_AB(type1, type2, a, 7, direction) \
else INSTANCE_AB(type1, type2, a, 3, direction) \
}
#define INSTANCE(type1, type2, direction) \
......
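
For reference, the aligned/shifted filter-register pattern used by the fp16 kernels above can be reproduced in isolation. The sketch below is a toy 1-D, stride-1 convolution written against CUDA's half2 intrinsics (compile for sm_53 or newer); the names (conv1d_half2, flt_aligned, flt_shifted) and all sizes are assumptions made for the example, not MegEngine identifiers.

#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

constexpr int FLT_W = 3;                    // assumed filter width
constexpr int OUT_W = 8;                    // assumed output width (one thread each)
constexpr int SRC_W = OUT_W + FLT_W - 1;    // "valid" convolution, no padding
constexpr int FLT_PAIRS = (FLT_W + 2) / 2;  // packed taps, including zero padding at the edges

__global__ void conv1d_half2(const __half* src, const __half* flt, float* dst) {
    int o = threadIdx.x;  // one output element per thread
    // Two packed copies of the taps, mirroring reg_flt[.][0] / reg_flt[.][1]:
    //   flt_aligned[p] = {flt[2p],     flt[2p + 1]}
    //   flt_shifted[p] = {flt[2p - 1], flt[2p]}   (zero-padded at the edges)
    __half2 flt_aligned[FLT_PAIRS], flt_shifted[FLT_PAIRS];
    for (int p = 0; p < FLT_PAIRS; ++p) {
        __half lo = (2 * p < FLT_W) ? flt[2 * p] : __float2half(0.f);
        __half hi = (2 * p + 1 < FLT_W) ? flt[2 * p + 1] : __float2half(0.f);
        __half pl = (2 * p - 1 >= 0 && 2 * p - 1 < FLT_W) ? flt[2 * p - 1]
                                                          : __float2half(0.f);
        flt_aligned[p] = __halves2half2(lo, hi);
        flt_shifted[p] = __halves2half2(pl, lo);
    }
    // Even outputs start on a pair boundary and use the aligned copy; odd ones
    // use the shifted copy, so the source is always read at pair-aligned offsets.
    const __half2* taps = (o % 2 == 0) ? flt_aligned : flt_shifted;
    __half2 acc = __float2half2_rn(0.f);
    for (int p = 0; p < FLT_PAIRS; ++p) {
        int base = (o / 2 + p) * 2;  // pair-aligned source index
        __half lo = (base < SRC_W) ? src[base] : __float2half(0.f);
        __half hi = (base + 1 < SRC_W) ? src[base + 1] : __float2half(0.f);
        acc = __hfma2(taps[p], __halves2half2(lo, hi), acc);
    }
    // Each lane holds half of the tap products; their sum is the full output.
    dst[o] = __low2float(acc) + __high2float(acc);
}

int main() {
    __half h_src[SRC_W], h_flt[FLT_W];
    for (int i = 0; i < SRC_W; ++i) h_src[i] = __float2half(1.f);
    for (int i = 0; i < FLT_W; ++i) h_flt[i] = __float2half(1.f);
    __half *d_src, *d_flt;
    float* d_dst;
    cudaMalloc(&d_src, sizeof(h_src));
    cudaMalloc(&d_flt, sizeof(h_flt));
    cudaMalloc(&d_dst, OUT_W * sizeof(float));
    cudaMemcpy(d_src, h_src, sizeof(h_src), cudaMemcpyHostToDevice);
    cudaMemcpy(d_flt, h_flt, sizeof(h_flt), cudaMemcpyHostToDevice);
    conv1d_half2<<<1, OUT_W>>>(d_src, d_flt, d_dst);
    float h_dst[OUT_W];
    cudaMemcpy(h_dst, d_dst, OUT_W * sizeof(float), cudaMemcpyDeviceToHost);
    printf("out[0] = %.1f, out[1] = %.1f (expect %d each)\n", h_dst[0], h_dst[1], FLT_W);
    cudaFree(d_src);
    cudaFree(d_flt);
    cudaFree(d_dst);
    return 0;
}

In the real kernels the same parity selection appears as reg_flt[b][ow * stride_w % 2][...], and the pair-aligned source read corresponds to the fw + ow * stride_w / 2 offset into reg_src.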