MegEngine / MegEngine · Commit d968942f
Authored Apr 05, 2022 by Megvii Engine Team
perf(cuda): speedup direct large kernel conv
GitOrigin-RevId: 3ff6a9caebbd1dc4c5c1c23b51945f7574f186ca
Parent: b2cffdde
Showing 2 changed files with 611 additions and 176 deletions (+611 −176)
dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh        +21 −5
dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh    +590 −171
dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh
...
...
@@ -59,14 +59,15 @@ struct ConvTraitInner {
         static int const smem_src_h =
                 (OutTileConfig::block_h - 1) * stride_h + FilterTileConfig::unroll_h;
         static int const smem_buff_h = FilterTileConfig::unroll_h;
-        static int const smem_load_h = smem_src_h + smem_buff_h;
+        static int const smem_load_h =
+                smem_src_h + smem_buff_h * FilterTileConfig::unroll_w * ThreadConfig::thread_x;
         static int const smem_h = smem_load_h + smem_buff_h;
         static int const smem_w =
                 DIVUP((OutTileConfig::block_w - 1) * stride_w +
                               FilterTileConfig::unroll_w * ThreadConfig::thread_x,
                       2) *
                 2;
-        static int const smem_size = smem_h * smem_w;
         static int const load_w =
                 smem_w > ThreadConfig::nr_threads ? ThreadConfig::nr_threads : smem_w;
         static int const load_h = 1;
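DIVUP above is MegDNN's round-up integer division, presumably defined in a shared header elsewhere in the tree; its conventional form is:

// Conventional definition of the DIVUP used in these tile-size formulas
// (round-up integer division); MegDNN defines its own equivalent elsewhere.
#define DIVUP(x, y) (((x) + (y) - 1) / (y))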
...
...
@@ -74,21 +75,36 @@ struct ConvTraitInner {
         static int const reg_w = DIVUP(smem_w, load_w);
         static bool constexpr check_bounds_h = smem_h % load_h != 0;
         static bool constexpr check_bounds_w = smem_w % load_w != 0;
+        // to avoid bank confilct, every bank_offset_line in 8 lines, add one offset
+        static int const bank_w = smem_w / (4 / sizeof(CompType));
+        static int const bank_offset_line =
+                (bank_w % 32 == 0 || bank_w % FilterTileConfig::unroll_w == 0)
+                        ? 1
+                        : (bank_w % 16 == 0 ? 2 : 4);
+        static int const smem_size =
+                smem_h * smem_w + DIVUP(smem_h, bank_offset_line) * (4 / sizeof(CompType));
     };

     struct FilterTileCount {
         static int const smem_flt_h = FilterTileConfig::unroll_h;
         static int const smem_buff_h = FilterTileConfig::unroll_h;
-        static int const smem_load_h = smem_flt_h + smem_buff_h;
-        static int const smem_h = smem_load_h + smem_buff_h;
         static int const smem_w = FilterTileConfig::unroll_w * ThreadConfig::thread_x;
-        static int const smem_size = smem_h * smem_w;
+        static int const smem_load_h = smem_flt_h + smem_buff_h * smem_w;
+        static int const smem_h = smem_load_h + smem_buff_h;
         static int const load_w = smem_w > 32 ? 32 : smem_w;
         static int const load_h = ThreadConfig::nr_threads / load_w;
         static int const reg_h = 1;
         static int const reg_w = DIVUP(smem_w, load_w);
         static bool constexpr check_bounds_h = smem_h % load_h != 0;
         static bool constexpr check_bounds_w = smem_w % load_w != 0;
+        // to avoid bank confilct, every bank_offset_line in 8 lines, add one offset
+        static int const bank_w = smem_w / (4 / sizeof(CompType));
+        static int const bank_offset_line =
+                (bank_w % 32 == 0 || bank_w % FilterTileConfig::unroll_w == 0)
+                        ? 1
+                        : (bank_w % 16 == 0 ? 2 : 4);
+        static int const smem_size =
+                smem_h * smem_w + DIVUP(smem_h, bank_offset_line) * (4 / sizeof(CompType));
     };
 };
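Both tile-count structs now budget one extra element after every bank_offset_line rows and fold that padding into smem_size; the loads and stores later in the diff add row / bank_offset_line to their linear offsets so they skip over the pads. A minimal, standalone illustration of the same padding trick (TILE_H, TILE_W, BANK_OFFSET_LINE, padded_idx and row_sum are our names, not MegEngine's):

// Sketch of the bank-conflict padding idea used by bank_offset_line/smem_size.
// With TILE_W a multiple of 32, "one row per thread" accesses put every thread
// of a warp on the same 32-bit bank; inserting one pad element after every
// BANK_OFFSET_LINE rows shifts row r by r / BANK_OFFSET_LINE and spreads the
// accesses across banks, mirroring the DIVUP(smem_h, bank_offset_line) term.
constexpr int TILE_H = 32;
constexpr int TILE_W = 64;           // multiple of the 32 banks -> conflict-prone
constexpr int BANK_OFFSET_LINE = 1;  // add one pad element after every row
constexpr int PADDED_SIZE =
        TILE_H * TILE_W + (TILE_H + BANK_OFFSET_LINE - 1) / BANK_OFFSET_LINE;

__device__ __forceinline__ int padded_idx(int row, int col) {
    // row * TILE_W + col is the unpadded offset; row / BANK_OFFSET_LINE counts
    // how many pad elements precede this row.
    return row * TILE_W + col + row / BANK_OFFSET_LINE;
}

__global__ void row_sum(const float* __restrict__ in, float* __restrict__ out) {
    __shared__ float tile[PADDED_SIZE];
    int row = threadIdx.x;  // launch with blockDim.x == TILE_H (32, one warp)
    for (int col = 0; col < TILE_W; ++col)
        tile[padded_idx(row, col)] = in[row * TILE_W + col];
    __syncthreads();
    float acc = 0.f;
    for (int col = 0; col < TILE_W; ++col)
        acc += tile[padded_idx(row, col)];  // unpadded: 32-way bank conflict here
    out[row] = acc;
}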
...
...
dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh
...
...
@@ -119,11 +119,12 @@ __device__ __forceinline__ void Global2SharedMem<
 #pragma unroll
     for (int i = 0; i < h_per_thread; ++i) {
         int smem_h_idx = y_base_idx + i * load_h;
+        int bank_offset = smem_h_idx / TileCount::bank_offset_line;
         int src_h_idx;
         if (is_fwd) {
             src_h_idx = start_h + smem_h_idx;
         } else {
-            src_h_idx = start_h + TileCount::smem_load_h - smem_h_idx - 1;
+            src_h_idx = start_h - smem_h_idx;
         }
         if (check_bounds_h && smem_h_idx >= TileCount::smem_load_h)
             continue;
...
...
@@ -146,7 +147,8 @@ __device__ __forceinline__ void Global2SharedMem<
                       TileCount::smem_w - w_offset - smem_w_idx - 1 >= 0))) {
                 val = g_ptr[src_h_idx / stride_h * stride + src_w_idx / stride_w];
             }
-            *(sh_ptr_as_copy_t(smem_h_idx, smem_w_idx)) = val;
+            *(sh_ptr_as_copy_t(smem_h_idx, smem_w_idx + bank_offset * (4 / sizeof(T)))) =
+                    val;
         }
     }
 }
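The Global2SharedMem visitor whose writer is patched here is driven through a copy()/commit()/iter_forward() sequence in the kernels below; judging from those call sites, copy() stages the next chunk into registers, commit() publishes it to shared memory, and iter_forward() advances the global-memory cursor. A rough sketch of that staging pattern under those assumed semantics (this is not the actual MegEngine class, only its shape):

// Rough sketch of a copy()/commit()/iter_forward() staging helper, with the
// semantics implied by the call sites below. Not the MegEngine visitor.
template <int LOAD_PER_THREAD>
struct Global2SharedSketch {
    const float* g_ptr;  // current position in global memory
    float* smem;         // destination tile in shared memory
    int stride;          // elements to advance per iteration
    float staged[LOAD_PER_THREAD];

    __device__ void copy() {    // global -> registers (can overlap with compute)
        for (int i = 0; i < LOAD_PER_THREAD; ++i)
            staged[i] = g_ptr[i * blockDim.x + threadIdx.x];
    }
    __device__ void commit() {  // registers -> shared memory
        for (int i = 0; i < LOAD_PER_THREAD; ++i)
            smem[i * blockDim.x + threadIdx.x] = staged[i];
    }
    __device__ void iter_forward() { g_ptr += stride; }
};

// Typical call sequence inside a tiled loop:
//   visitor.copy();        // prefetch the next chunk while this one is consumed
//   ... compute on the chunk already resident in shared memory ...
//   __syncthreads();
//   visitor.commit();      // make the prefetched chunk visible
//   visitor.iter_forward();
//   __syncthreads();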
...
...
@@ -261,24 +263,29 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
     SrcGlobal2ShareVisitor gl2sh_src = {
             smem_src,
-            param.src_w,
-            is_fwd ? src_start_h
-                   : src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h -
-                                    param.src_h * param.stride_h / 2),
-            is_fwd ? src_start_w
-                   : src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w -
-                                    param.src_w * param.stride_w / 2),
-            is_fwd ? param.src_h : param.src_h * param.stride_h,
-            is_fwd ? param.src_w : param.src_w * param.stride_w,
-            is_fwd ? 1 : param.stride_h,
-            is_fwd ? 1 : param.stride_w};
-    FilterGlobal2ShareVisitor gl2sh_flt = {
-            smem_flt, param.flt_w, is_fwd ? 0 : param.flt_h - 2,
-            0,        param.flt_h, param.flt_w,
-            1,        1};
+            static_cast<int>(param.src_w),
+            static_cast<int>(
+                    is_fwd ? src_start_h
+                           : src_start_h - (param.out_h / 2 + param.flt_h / 2 -
+                                            param.pad_h - param.src_h * param.stride_h / 2)),
+            static_cast<int>(
+                    is_fwd ? src_start_w
+                           : src_start_w - (param.out_w / 2 + param.flt_w / 2 -
+                                            param.pad_w - param.src_w * param.stride_w / 2)),
+            static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h),
+            static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w),
+            is_fwd ? 1 : static_cast<int>(param.stride_h),
+            is_fwd ? 1 : static_cast<int>(param.stride_w)};
+    FilterGlobal2ShareVisitor gl2sh_flt = {
+            smem_flt,
+            static_cast<int>(param.flt_w),
+            is_fwd ? 0 : static_cast<int>(param.flt_h - 1),
+            0,
+            static_cast<int>(param.flt_h),
+            static_cast<int>(param.flt_w),
+            1,
+            1};
...
...
@@ -290,14 +297,51 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
     __syncthreads();

-    T2 reg_src[SrcTileConfig::unroll_h * t2_src_unroll_w],
-            reg_flt[2][FilterTileConfig::unroll_h * t2_flt_unroll_w];
+    T2 reg_src[2][SrcTileConfig::unroll_h * t2_src_unroll_w],
+            reg_flt[2][2][FilterTileConfig::unroll_h * t2_flt_unroll_w];

     T2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}};

-    for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
-        gl2sh_src.copy();
-        gl2sh_flt.copy();
+    gl2sh_src.copy();
+    gl2sh_flt.copy();
+#pragma unroll
+    for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
+#pragma unroll
+        for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) {
+            int src_offset = (off_oh * stride_h + s_h) % SrcTileCount::smem_h *
+                                     SrcTileCount::smem_w +
+                             s_w * 2;
+            reg_src[0][s_h * t2_src_unroll_w + s_w] = *reinterpret_cast<T2*>(
+                    smem_src_ptr + src_offset +
+                    ((off_oh * stride_h + s_h) / SrcTileCount::bank_offset_line) * 2);
+        }
+    }
+    if (off_ow == ThreadConfig::thread_x - 1) {
+        reg_src[0][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0};
+    }
+#pragma unroll
+    for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
+#pragma unroll
+        for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) {
+            int flt_offset = (f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + f_w * 2;
+            reg_flt[0][0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast<T2*>(
+                    smem_flt_ptr + flt_offset + 2 * (f_h / FilterTileCount::bank_offset_line));
+            if (f_w > 0) {
+                reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] =
+                        T2{reg_flt[0][0][f_h * t2_flt_unroll_w + f_w - 1].y,
+                           reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x};
+            } else {
+                reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] =
+                        T2{0.0, reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x};
+            }
+        }
+        reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0};
+        reg_flt[0][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] =
+                T2{reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0};
+    }
+
+    for (int fh = 1; fh < param.flt_h - 1; fh += FilterTileConfig::unroll_h * 2) {
 #pragma unroll
         for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
 #pragma unroll
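Where the old kernel kept a single reg_src/reg_flt tile and reloaded it from shared memory for every group of filter rows, the new code keeps two register tiles and ping-pongs them: the prologue above fills buffer 0 for the first filter row, and each iteration of the fh loop loads buffer 1 for row fh while the FMAs still consume buffer 0, then refills buffer 0 for row fh + 1. The same structure on a deliberately tiny 1-D stand-in (load_row/accumulate/pipelined_sum are illustrative, not MegEngine code; an odd number of rows is assumed, matching the odd filter heights these kernels target):

// Skeleton of the two-buffer register pipeline introduced above, on a toy
// 1-D problem. load_row()/accumulate() stand in for the shared-memory loads
// and the fma2 loops.
__device__ __forceinline__ void load_row(float* dst, const float* smem, int row, int n) {
    for (int i = 0; i < n; ++i)
        dst[i] = smem[row * n + i];
}

__device__ __forceinline__ void accumulate(float& sum, const float* reg, int n) {
    for (int i = 0; i < n; ++i)
        sum += reg[i];
}

template <int N>
__device__ float pipelined_sum(const float* smem, int rows) {  // rows assumed odd
    float buf[2][N];
    float sum = 0.f;
    load_row(buf[0], smem, 0, N);          // prologue: row 0 into buffer 0
    for (int r = 1; r < rows - 1; r += 2) {
        load_row(buf[1], smem, r, N);      // fetch row r ...
        accumulate(sum, buf[0], N);        // ... while consuming buffer 0
        load_row(buf[0], smem, r + 1, N);  // fetch row r + 1 ...
        accumulate(sum, buf[1], N);        // ... while consuming buffer 1
    }
    if (rows % 2 != 0)                     // odd tail, like the flt_h % 2 branch
        accumulate(sum, buf[0], N);
    return sum;
}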
...
...
@@ -305,10 +349,15 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
                 int src_offset = (off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h *
                                          SrcTileCount::smem_w +
                                  s_w * 2;
-                reg_src[s_h * t2_src_unroll_w + s_w] =
-                        *reinterpret_cast<T2*>(smem_src_ptr + src_offset);
+                reg_src[1][s_h * t2_src_unroll_w + s_w] = *reinterpret_cast<T2*>(
+                        smem_src_ptr + src_offset +
+                        2 * ((off_oh * stride_h + fh + s_h) /
+                             SrcTileCount::bank_offset_line));
             }
         }
+        if (off_ow == ThreadConfig::thread_x - 1) {
+            reg_src[1][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0};
+        }
 #pragma unroll
         for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
...
...
@@ -317,20 +366,21 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
                 int flt_offset =
                         (fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w +
                         f_w * 2;
-                reg_flt[0][f_h * t2_flt_unroll_w + f_w] =
-                        *reinterpret_cast<T2*>(smem_flt_ptr + flt_offset);
+                reg_flt[1][0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast<T2*>(
+                        smem_flt_ptr + flt_offset +
+                        2 * ((fh + f_h) / FilterTileCount::bank_offset_line));
                 if (f_w > 0) {
-                    reg_flt[1][f_h * t2_flt_unroll_w + f_w] =
-                            T2{reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y,
-                               reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
+                    reg_flt[1][1][f_h * t2_flt_unroll_w + f_w] =
+                            T2{reg_flt[1][0][f_h * t2_flt_unroll_w + f_w - 1].y,
+                               reg_flt[1][0][f_h * t2_flt_unroll_w + f_w].x};
                 } else {
-                    reg_flt[1][f_h * t2_flt_unroll_w + f_w] =
-                            T2{0.0, reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
+                    reg_flt[1][1][f_h * t2_flt_unroll_w + f_w] =
+                            T2{0.0, reg_flt[1][0][f_h * t2_flt_unroll_w + f_w].x};
                 }
             }
-            reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0};
-            reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] =
-                    T2{reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0};
+            reg_flt[1][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0};
+            reg_flt[1][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] =
+                    T2{reg_flt[1][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0};
         }
 #pragma unroll
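reg_flt[…][0] holds the filter row exactly as stored (two taps per T2), while reg_flt[…][1] holds the same taps shifted by one scalar — {previous.y, current.x}, with a zero patched in at the left edge — so the multiply loop can pick the [ow * stride_w % 2] variant and keep both even- and odd-aligned output columns on paired arithmetic. The construction in isolation (an illustrative helper, not part of the patch):

// Illustrative only: build an even-aligned and an odd-aligned view of one
// filter row so both even and odd output columns can use paired (float2)
// multiplies, mirroring the reg_flt[...][0] / reg_flt[...][1] setup above.
__device__ void make_shifted_pairs(const float2* row, int n_pairs,
                                   float2* aligned, float2* shifted) {
    for (int i = 0; i < n_pairs; ++i) {
        aligned[i] = row[i];                          // taps (2i, 2i+1)
        float prev_y = (i > 0) ? row[i - 1].y : 0.f;  // zero-pad at the left edge
        shifted[i] = make_float2(prev_y, row[i].x);   // taps (2i-1, 2i)
    }
}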
...
...
@@ -342,9 +392,10 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
 #pragma unroll
                 for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
                     sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
-                            reg_flt[ow * stride_w % 2][inner_fh * t2_flt_unroll_w + fw],
-                            reg_src[(inner_fh + oh) * t2_src_unroll_w + fw +
-                                    ow * stride_w / 2],
+                            reg_flt[0][ow * stride_w % 2]
+                                   [inner_fh * t2_flt_unroll_w + fw],
+                            reg_src[0][(inner_fh + oh) * t2_src_unroll_w + fw +
+                                       ow * stride_w / 2],
                             sum[oh * t2_out_unroll_w + ow]);
                 }
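megdnn::cuda::fma2(a, b, c) is used here as a lane-wise fused multiply-add over T2 values; assuming the semantics implied by the accumulation ({a.x*b.x + c.x, a.y*b.y + c.y}), a float2 equivalent would be:

// Assumed behaviour of the fma2 helper used above: lane-wise fused
// multiply-add on two-element vectors. Sketch only; the real helper lives in
// MegEngine's CUDA utilities.
__device__ __forceinline__ float2 fma2_sketch(float2 a, float2 b, float2 c) {
    return make_float2(fmaf(a.x, b.x, c.x), fmaf(a.y, b.y, c.y));
}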
...
...
@@ -352,13 +403,91 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
             }
         }
         __syncthreads();

         gl2sh_src.commit();
         gl2sh_flt.commit();
         gl2sh_src.iter_forward();
         gl2sh_flt.iter_forward();
         __syncthreads();
+
+#pragma unroll
+        for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
+#pragma unroll
+            for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) {
+                int src_offset = (off_oh * stride_h + fh + 1 + s_h) %
+                                         SrcTileCount::smem_h * SrcTileCount::smem_w +
+                                 s_w * 2;
+                reg_src[0][s_h * t2_src_unroll_w + s_w] = *reinterpret_cast<T2*>(
+                        smem_src_ptr + src_offset +
+                        2 * ((off_oh * stride_h + fh + 1 + s_h) /
+                             SrcTileCount::bank_offset_line));
+            }
+        }
+        if (off_ow == ThreadConfig::thread_x - 1) {
+            reg_src[0][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0};
+        }
+#pragma unroll
+        for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
+#pragma unroll
+            for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) {
+                int flt_offset = (fh + 1 + f_h) % FilterTileCount::smem_h *
+                                         FilterTileCount::smem_w +
+                                 f_w * 2;
+                reg_flt[0][0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast<T2*>(
+                        smem_flt_ptr + flt_offset +
+                        2 * ((fh + 1 + f_h) / FilterTileCount::bank_offset_line));
+                if (f_w > 0) {
+                    reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] =
+                            T2{reg_flt[0][0][f_h * t2_flt_unroll_w + f_w - 1].y,
+                               reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x};
+                } else {
+                    reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] =
+                            T2{0.0, reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x};
+                }
+            }
+            reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0};
+            reg_flt[0][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] =
+                    T2{reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0};
+        }
+#pragma unroll
+        for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
+#pragma unroll
+            for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
+#pragma unroll
+                for (int fw = 0; fw < t2_flt_unroll_w; ++fw) {
+#pragma unroll
+                    for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
+                        sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
+                                reg_flt[1][ow * stride_w % 2]
+                                       [inner_fh * t2_flt_unroll_w + fw],
+                                reg_src[1][(inner_fh + oh) * t2_src_unroll_w + fw +
+                                           ow * stride_w / 2],
+                                sum[oh * t2_out_unroll_w + ow]);
+                    }
+                }
+            }
+        }
     }

+    if (param.flt_h % 2 != 0) {
+#pragma unroll
+        for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
+#pragma unroll
+            for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
+#pragma unroll
+                for (int fw = 0; fw < t2_flt_unroll_w; ++fw) {
+#pragma unroll
+                    for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
+                        sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
+                                reg_flt[0][ow * stride_w % 2]
+                                       [inner_fh * t2_flt_unroll_w + fw],
+                                reg_src[0][(inner_fh + oh) * t2_src_unroll_w + fw +
+                                           ow * stride_w / 2],
+                                sum[oh * t2_out_unroll_w + ow]);
+                    }
+                }
+            }
+        }
+    }
+    __syncthreads();
+
     for (int o = 0; o < OutTileConfig::unroll_size; ++o) {
         for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) {
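The trailing for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) loops are the usual log2-step combine across the thread_x threads that cooperate on one output tile. With warp shuffles such a combine typically looks like the sketch below (the kernels here may equally well combine through shared memory; this only shows the shape of the reduction, and width must be a power of two no larger than 32):

// Butterfly combine across a group of `width` lanes; after the loop every
// lane in the group holds the group's sum.
__device__ __forceinline__ float warp_combine(float v, int width) {
    for (int offset = 1; offset < width; offset <<= 1)
        v += __shfl_xor_sync(0xffffffffu, v, offset, width);
    return v;
}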
...
...
@@ -429,24 +558,29 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
     SrcGlobal2ShareVisitor gl2sh_src = {
             smem_src,
-            param.src_w,
-            is_fwd ? src_start_h
-                   : src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h -
-                                    param.src_h * param.stride_h / 2),
-            is_fwd ? src_start_w
-                   : src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w -
-                                    param.src_w * param.stride_w / 2),
-            is_fwd ? param.src_h : param.src_h * param.stride_h,
-            is_fwd ? param.src_w : param.src_w * param.stride_w,
-            is_fwd ? 1 : param.stride_h,
-            is_fwd ? 1 : param.stride_w};
-    FilterGlobal2ShareVisitor gl2sh_flt = {
-            smem_flt, param.flt_w, is_fwd ? 0 : param.flt_h - 2,
-            0,        param.flt_h, param.flt_w,
-            1,        1};
+            static_cast<int>(param.src_w),
+            static_cast<int>(
+                    is_fwd ? src_start_h
+                           : src_start_h - (param.out_h / 2 + param.flt_h / 2 -
+                                            param.pad_h - param.src_h * param.stride_h / 2)),
+            static_cast<int>(
+                    is_fwd ? src_start_w
+                           : src_start_w - (param.out_w / 2 + param.flt_w / 2 -
+                                            param.pad_w - param.src_w * param.stride_w / 2)),
+            static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h),
+            static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w),
+            is_fwd ? 1 : static_cast<int>(param.stride_h),
+            is_fwd ? 1 : static_cast<int>(param.stride_w)};
+    FilterGlobal2ShareVisitor gl2sh_flt = {
+            smem_flt,
+            static_cast<int>(param.flt_w),
+            is_fwd ? 0 : static_cast<int>(param.flt_h - 1),
+            0,
+            static_cast<int>(param.flt_h),
+            static_cast<int>(param.flt_w),
+            1,
+            1};
...
...
@@ -458,14 +592,51 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
     __syncthreads();

-    T2 reg_src[SrcTileConfig::unroll_h * t2_src_unroll_w],
-            reg_flt[2][FilterTileConfig::unroll_h * t2_flt_unroll_w];
+    T2 reg_src[2][SrcTileConfig::unroll_h * t2_src_unroll_w],
+            reg_flt[2][2][FilterTileConfig::unroll_h * t2_flt_unroll_w];

     float2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}};

-    for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
-        gl2sh_src.copy();
-        gl2sh_flt.copy();
+    gl2sh_src.copy();
+    gl2sh_flt.copy();
+#pragma unroll
+    for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
+#pragma unroll
+        for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) {
+            int src_offset = (off_oh * stride_h + s_h) % SrcTileCount::smem_h *
+                                     SrcTileCount::smem_w +
+                             s_w * 2;
+            reg_src[0][s_h * t2_src_unroll_w + s_w] = *reinterpret_cast<T2*>(
+                    smem_src_ptr + src_offset +
+                    ((off_oh * stride_h + s_h) / SrcTileCount::bank_offset_line) * 2);
+        }
+    }
+    if (off_ow == ThreadConfig::thread_x - 1) {
+        reg_src[0][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0};
+    }
+#pragma unroll
+    for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
+#pragma unroll
+        for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) {
+            int flt_offset = (f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + f_w * 2;
+            reg_flt[0][0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast<T2*>(
+                    smem_flt_ptr + flt_offset + 2 * (f_h / FilterTileCount::bank_offset_line));
+            if (f_w > 0) {
+                reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] =
+                        T2{reg_flt[0][0][f_h * t2_flt_unroll_w + f_w - 1].y,
+                           reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x};
+            } else {
+                reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] =
+                        T2{0.0, reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x};
+            }
+        }
+        reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0};
+        reg_flt[0][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] =
+                T2{reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0};
+    }
+
+    for (int fh = 1; fh < param.flt_h - 1; fh += FilterTileConfig::unroll_h * 2) {
 #pragma unroll
         for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
 #pragma unroll
...
...
@@ -473,10 +644,15 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
                 int src_offset = (off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h *
                                          SrcTileCount::smem_w +
                                  s_w * 2;
-                reg_src[s_h * t2_src_unroll_w + s_w] =
-                        *reinterpret_cast<T2*>(smem_src_ptr + src_offset);
+                reg_src[1][s_h * t2_src_unroll_w + s_w] = *reinterpret_cast<T2*>(
+                        smem_src_ptr + src_offset +
+                        2 * ((off_oh * stride_h + fh + s_h) /
+                             SrcTileCount::bank_offset_line));
             }
         }
+        if (off_ow == ThreadConfig::thread_x - 1) {
+            reg_src[1][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0};
+        }
 #pragma unroll
         for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
...
...
@@ -485,20 +661,21 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
                 int flt_offset =
                         (fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w +
                         f_w * 2;
-                reg_flt[0][f_h * t2_flt_unroll_w + f_w] =
-                        *reinterpret_cast<T2*>(smem_flt_ptr + flt_offset);
+                reg_flt[1][0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast<T2*>(
+                        smem_flt_ptr + flt_offset +
+                        2 * ((fh + f_h) / FilterTileCount::bank_offset_line));
                 if (f_w > 0) {
-                    reg_flt[1][f_h * t2_flt_unroll_w + f_w] =
-                            T2{reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y,
-                               reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
+                    reg_flt[1][1][f_h * t2_flt_unroll_w + f_w] =
+                            T2{reg_flt[1][0][f_h * t2_flt_unroll_w + f_w - 1].y,
+                               reg_flt[1][0][f_h * t2_flt_unroll_w + f_w].x};
                 } else {
-                    reg_flt[1][f_h * t2_flt_unroll_w + f_w] =
-                            T2{0.0, reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
+                    reg_flt[1][1][f_h * t2_flt_unroll_w + f_w] =
+                            T2{0.0, reg_flt[1][0][f_h * t2_flt_unroll_w + f_w].x};
                 }
             }
-            reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0};
-            reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] =
-                    T2{reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0};
+            reg_flt[1][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0};
+            reg_flt[1][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] =
+                    T2{reg_flt[1][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0};
         }
 #pragma unroll
...
...
@@ -510,9 +687,10 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
 #pragma unroll
                 for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
                     sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
-                            reg_flt[ow * stride_w % 2][inner_fh * t2_flt_unroll_w + fw],
-                            reg_src[(inner_fh + oh) * t2_src_unroll_w + fw +
-                                    ow * stride_w / 2],
+                            reg_flt[0][ow * stride_w % 2]
+                                   [inner_fh * t2_flt_unroll_w + fw],
+                            reg_src[0][(inner_fh + oh) * t2_src_unroll_w + fw +
+                                       ow * stride_w / 2],
                             sum[oh * t2_out_unroll_w + ow]);
                 }
...
...
@@ -520,13 +698,91 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
             }
         }
         __syncthreads();

         gl2sh_src.commit();
         gl2sh_flt.commit();
         gl2sh_src.iter_forward();
         gl2sh_flt.iter_forward();
         __syncthreads();
+
+#pragma unroll
+        for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
+#pragma unroll
+            for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) {
+                int src_offset = (off_oh * stride_h + fh + 1 + s_h) %
+                                         SrcTileCount::smem_h * SrcTileCount::smem_w +
+                                 s_w * 2;
+                reg_src[0][s_h * t2_src_unroll_w + s_w] = *reinterpret_cast<T2*>(
+                        smem_src_ptr + src_offset +
+                        2 * ((off_oh * stride_h + fh + 1 + s_h) /
+                             SrcTileCount::bank_offset_line));
+            }
+        }
+        if (off_ow == ThreadConfig::thread_x - 1) {
+            reg_src[0][SrcTileConfig::unroll_h * t2_src_unroll_w - 1] = T2{0, 0};
+        }
+#pragma unroll
+        for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
+#pragma unroll
+            for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) {
+                int flt_offset = (fh + 1 + f_h) % FilterTileCount::smem_h *
+                                         FilterTileCount::smem_w +
+                                 f_w * 2;
+                reg_flt[0][0][f_h * t2_flt_unroll_w + f_w] = *reinterpret_cast<T2*>(
+                        smem_flt_ptr + flt_offset +
+                        2 * ((fh + 1 + f_h) / FilterTileCount::bank_offset_line));
+                if (f_w > 0) {
+                    reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] =
+                            T2{reg_flt[0][0][f_h * t2_flt_unroll_w + f_w - 1].y,
+                               reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x};
+                } else {
+                    reg_flt[0][1][f_h * t2_flt_unroll_w + f_w] =
+                            T2{0.0, reg_flt[0][0][f_h * t2_flt_unroll_w + f_w].x};
+                }
+            }
+            reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0};
+            reg_flt[0][1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] =
+                    T2{reg_flt[0][0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0};
+        }
+#pragma unroll
+        for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
+#pragma unroll
+            for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
+#pragma unroll
+                for (int fw = 0; fw < t2_flt_unroll_w; ++fw) {
+#pragma unroll
+                    for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
+                        sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
+                                reg_flt[1][ow * stride_w % 2]
+                                       [inner_fh * t2_flt_unroll_w + fw],
+                                reg_src[1][(inner_fh + oh) * t2_src_unroll_w + fw +
+                                           ow * stride_w / 2],
+                                sum[oh * t2_out_unroll_w + ow]);
+                    }
+                }
+            }
+        }
     }

+    if (param.flt_h % 2 != 0) {
+#pragma unroll
+        for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
+#pragma unroll
+            for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
+#pragma unroll
+                for (int fw = 0; fw < t2_flt_unroll_w; ++fw) {
+#pragma unroll
+                    for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
+                        sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
+                                reg_flt[0][ow * stride_w % 2]
+                                       [inner_fh * t2_flt_unroll_w + fw],
+                                reg_src[0][(inner_fh + oh) * t2_src_unroll_w + fw +
+                                           ow * stride_w / 2],
+                                sum[oh * t2_out_unroll_w + ow]);
+                    }
+                }
+            }
+        }
+    }
+    __syncthreads();
+
     for (int o = 0; o < OutTileConfig::unroll_size; ++o) {
         for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) {
...
...
@@ -595,24 +851,29 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
     SrcGlobal2ShareVisitor gl2sh_src = {
             smem_src,
-            param.src_w,
-            is_fwd ? src_start_h
-                   : src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h -
-                                    param.src_h * param.stride_h / 2),
-            is_fwd ? src_start_w
-                   : src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w -
-                                    param.src_w * param.stride_w / 2),
-            is_fwd ? param.src_h : param.src_h * param.stride_h,
-            is_fwd ? param.src_w : param.src_w * param.stride_w,
-            is_fwd ? 1 : param.stride_h,
-            is_fwd ? 1 : param.stride_w};
-    FilterGlobal2ShareVisitor gl2sh_flt = {
-            smem_flt, param.flt_w, is_fwd ? 0 : param.flt_h - 2,
-            0,        param.flt_h, param.flt_w,
-            1,        1};
+            static_cast<int>(param.src_w),
+            static_cast<int>(
+                    is_fwd ? src_start_h
+                           : src_start_h - (param.out_h / 2 + param.flt_h / 2 -
+                                            param.pad_h - param.src_h * param.stride_h / 2)),
+            static_cast<int>(
+                    is_fwd ? src_start_w
+                           : src_start_w - (param.out_w / 2 + param.flt_w / 2 -
+                                            param.pad_w - param.src_w * param.stride_w / 2)),
+            static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h),
+            static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w),
+            is_fwd ? 1 : static_cast<int>(param.stride_h),
+            is_fwd ? 1 : static_cast<int>(param.stride_w)};
+    FilterGlobal2ShareVisitor gl2sh_flt = {
+            smem_flt,
+            static_cast<int>(param.flt_w),
+            is_fwd ? 0 : static_cast<int>(param.flt_h - 1),
+            0,
+            static_cast<int>(param.flt_h),
+            static_cast<int>(param.flt_w),
+            1,
+            1};
...
...
@@ -624,22 +885,43 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
     __syncthreads();

-    T reg_src[SrcTileConfig::unroll_h * SrcTileConfig::unroll_w],
-            reg_flt[FilterTileConfig::unroll_h * FilterTileConfig::unroll_w];
+    T reg_src[2][SrcTileConfig::unroll_h * SrcTileConfig::unroll_w],
+            reg_flt[2][FilterTileConfig::unroll_h * FilterTileConfig::unroll_w];

     T sum[OutTileConfig::unroll_size] = {0.0};

-    for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
-        gl2sh_src.copy();
-        gl2sh_flt.copy();
-#pragma unroll
-        for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
-#pragma unroll
-            for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
-                reg_src[s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
-                        [(off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h *
-                                 SrcTileCount::smem_w +
-                         s_w];
-            }
-        }
+    gl2sh_src.copy();
+    gl2sh_flt.copy();
+#pragma unroll
+    for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
+#pragma unroll
+        for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
+            reg_src[0][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
+                    [(off_oh * stride_h + s_h) % SrcTileCount::smem_h *
+                             SrcTileCount::smem_w +
+                     s_w + (off_oh * stride_h + s_h) / SrcTileCount::bank_offset_line];
+        }
+    }
+#pragma unroll
+    for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
+#pragma unroll
+        for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
+            reg_flt[0][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
+                    [(f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + f_w +
+                     f_h / FilterTileCount::bank_offset_line];
+        }
+    }
+
+    for (int fh = 1; fh < param.flt_h + 1; fh += FilterTileConfig::unroll_h * 2) {
+#pragma unroll
+        for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
+#pragma unroll
+            for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
+                reg_src[1][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
+                        [(off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h *
+                                 SrcTileCount::smem_w +
+                         s_w + (off_oh * stride_h + fh + s_h) /
+                                       SrcTileCount::bank_offset_line];
+            }
+        }
...
...
@@ -647,13 +929,53 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
         for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
 #pragma unroll
             for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
-                reg_flt[f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
-                        [(fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w +
-                         f_w];
+                reg_flt[1][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
+                        [(fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w +
+                         f_w + (fh + f_h) / FilterTileCount::bank_offset_line];
             }
         }
+#pragma unroll
+        for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
+#pragma unroll
+            for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
+#pragma unroll
+                for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) {
+#pragma unroll
+                    for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
+                        sum[oh * OutTileConfig::unroll_w + ow] +=
+                                reg_flt[0][inner_fh * FilterTileConfig::unroll_w + fw] *
+                                reg_src[0][(inner_fh + oh) * SrcTileConfig::unroll_w +
+                                           fw + ow * stride_w];
+                    }
+                }
+            }
+        }
+#pragma unroll
+        for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
+#pragma unroll
+            for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
+                reg_src[0][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
+                        [(off_oh * stride_h + fh + 1 + s_h) % SrcTileCount::smem_h *
+                                 SrcTileCount::smem_w +
+                         s_w + (off_oh * stride_h + fh + 1 + s_h) /
+                                       SrcTileCount::bank_offset_line];
+            }
+        }
+#pragma unroll
+        for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
+#pragma unroll
+            for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
+                reg_flt[0][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
+                        [(fh + 1 + f_h) % FilterTileCount::smem_h *
+                                 FilterTileCount::smem_w +
+                         f_w + (fh + 1 + f_h) / FilterTileCount::bank_offset_line];
+            }
+        }
 #pragma unroll
         for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
 #pragma unroll
...
...
@@ -663,21 +985,37 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
 #pragma unroll
                     for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
                         sum[oh * OutTileConfig::unroll_w + ow] +=
-                                reg_flt[inner_fh * FilterTileConfig::unroll_w + fw] *
-                                reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w + fw +
-                                        ow * stride_w];
+                                reg_flt[1][inner_fh * FilterTileConfig::unroll_w + fw] *
+                                reg_src[1][(inner_fh + oh) * SrcTileConfig::unroll_w +
+                                           fw + ow * stride_w];
                     }
                 }
             }
         }
-        __syncthreads();
-
-        gl2sh_src.commit();
-        gl2sh_flt.commit();
-        gl2sh_src.iter_forward();
-        gl2sh_flt.iter_forward();
-        __syncthreads();
     }
+
+    if (param.flt_h % 2 != 0) {
+#pragma unroll
+        for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
+#pragma unroll
+            for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
+#pragma unroll
+                for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) {
+#pragma unroll
+                    for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
+                        sum[oh * OutTileConfig::unroll_w + ow] +=
+                                reg_flt[0][inner_fh * FilterTileConfig::unroll_w + fw] *
+                                reg_src[0][(inner_fh + oh) * SrcTileConfig::unroll_w +
+                                           fw + ow * stride_w];
+                    }
+                }
+            }
+        }
+    }
+    __syncthreads();

     for (int o = 0; o < OutTileConfig::unroll_size; ++o) {
         for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) {
...
...
@@ -743,24 +1081,29 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
     SrcGlobal2ShareVisitor gl2sh_src = {
             smem_src,
-            param.src_w,
-            is_fwd ? src_start_h
-                   : src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h -
-                                    param.src_h * param.stride_h / 2),
-            is_fwd ? src_start_w
-                   : src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w -
-                                    param.src_w * param.stride_w / 2),
-            is_fwd ? param.src_h : param.src_h * param.stride_h,
-            is_fwd ? param.src_w : param.src_w * param.stride_w,
-            is_fwd ? 1 : param.stride_h,
-            is_fwd ? 1 : param.stride_w};
-    FilterGlobal2ShareVisitor gl2sh_flt = {
-            smem_flt, param.flt_w, is_fwd ? 0 : param.flt_h - 2,
-            0,        param.flt_h, param.flt_w,
-            1,        1};
+            static_cast<int>(param.src_w),
+            static_cast<int>(
+                    is_fwd ? src_start_h
+                           : src_start_h - (param.out_h / 2 + param.flt_h / 2 -
+                                            param.pad_h - param.src_h * param.stride_h / 2)),
+            static_cast<int>(
+                    is_fwd ? src_start_w
+                           : src_start_w - (param.out_w / 2 + param.flt_w / 2 -
+                                            param.pad_w - param.src_w * param.stride_w / 2)),
+            static_cast<int>(is_fwd ? param.src_h : param.src_h * param.stride_h),
+            static_cast<int>(is_fwd ? param.src_w : param.src_w * param.stride_w),
+            is_fwd ? 1 : static_cast<int>(param.stride_h),
+            is_fwd ? 1 : static_cast<int>(param.stride_w)};
+    FilterGlobal2ShareVisitor gl2sh_flt = {
+            smem_flt,
+            static_cast<int>(param.flt_w),
+            is_fwd ? 0 : static_cast<int>(param.flt_h - 1),
+            0,
+            static_cast<int>(param.flt_h),
+            static_cast<int>(param.flt_w),
+            1,
+            1};
...
...
@@ -772,22 +1115,43 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
     __syncthreads();

-    T reg_src[SrcTileConfig::unroll_h * SrcTileConfig::unroll_w],
-            reg_flt[FilterTileConfig::unroll_h * FilterTileConfig::unroll_w];
+    T reg_src[2][SrcTileConfig::unroll_h * SrcTileConfig::unroll_w],
+            reg_flt[2][FilterTileConfig::unroll_h * FilterTileConfig::unroll_w];

     T sum[OutTileConfig::unroll_size] = {0.0};

-    for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
-        gl2sh_src.copy();
-        gl2sh_flt.copy();
-#pragma unroll
-        for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
-#pragma unroll
-            for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
-                reg_src[s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
-                        [(off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h *
-                                 SrcTileCount::smem_w +
-                         s_w];
-            }
-        }
+    gl2sh_src.copy();
+    gl2sh_flt.copy();
+#pragma unroll
+    for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
+#pragma unroll
+        for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
+            reg_src[0][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
+                    [(off_oh * stride_h + s_h) % SrcTileCount::smem_h *
+                             SrcTileCount::smem_w +
+                     s_w + (off_oh * stride_h + s_h) / SrcTileCount::bank_offset_line];
+        }
+    }
+#pragma unroll
+    for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
+#pragma unroll
+        for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
+            reg_flt[0][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
+                    [(f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w + f_w +
+                     f_h / FilterTileCount::bank_offset_line];
+        }
+    }
+
+    for (int fh = 1; fh < param.flt_h + 1; fh += FilterTileConfig::unroll_h * 2) {
+#pragma unroll
+        for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
+#pragma unroll
+            for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
+                reg_src[1][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
+                        [(off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h *
+                                 SrcTileCount::smem_w +
+                         s_w + (off_oh * stride_h + fh + s_h) /
+                                       SrcTileCount::bank_offset_line];
+            }
+        }
...
...
@@ -795,13 +1159,73 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
         for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
 #pragma unroll
             for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
-                reg_flt[f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
-                        [(fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w +
-                         f_w];
+                reg_flt[1][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
+                        [(fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w +
+                         f_w + (fh + f_h) / FilterTileCount::bank_offset_line];
             }
         }
+#pragma unroll
+        for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
+#pragma unroll
+            for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
+#pragma unroll
+                for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) {
+#pragma unroll
+                    for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
+                        sum[oh * OutTileConfig::unroll_w + ow] +=
+                                reg_flt[0][inner_fh * FilterTileConfig::unroll_w + fw] *
+                                reg_src[0][(inner_fh + oh) * SrcTileConfig::unroll_w +
+                                           fw + ow * stride_w];
+                    }
+                }
+            }
+        }
+#pragma unroll
+        for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
+#pragma unroll
+            for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
+                reg_src[0][s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
+                        [(off_oh * stride_h + fh + 1 + s_h) % SrcTileCount::smem_h *
+                                 SrcTileCount::smem_w +
+                         s_w + (off_oh * stride_h + fh + 1 + s_h) /
+                                       SrcTileCount::bank_offset_line];
+            }
+        }
+#pragma unroll
+        for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
+#pragma unroll
+            for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
+                reg_flt[0][f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
+                        [(fh + 1 + f_h) % FilterTileCount::smem_h *
+                                 FilterTileCount::smem_w +
+                         f_w + (fh + 1 + f_h) / FilterTileCount::bank_offset_line];
+            }
+        }
+#pragma unroll
+        for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
+#pragma unroll
+            for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
+#pragma unroll
+                for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) {
+#pragma unroll
+                    for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
+                        sum[oh * OutTileConfig::unroll_w + ow] +=
+                                reg_flt[1][inner_fh * FilterTileConfig::unroll_w + fw] *
+                                reg_src[1][(inner_fh + oh) * SrcTileConfig::unroll_w +
+                                           fw + ow * stride_w];
+                    }
+                }
+            }
+        }
+    }
+
+    if (param.flt_h % 2 != 0) {
 #pragma unroll
         for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
 #pragma unroll
...
...
@@ -811,21 +1235,17 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
 #pragma unroll
                     for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
                         sum[oh * OutTileConfig::unroll_w + ow] +=
-                                reg_flt[inner_fh * FilterTileConfig::unroll_w + fw] *
-                                reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w + fw +
-                                        ow * stride_w];
+                                reg_flt[0][inner_fh * FilterTileConfig::unroll_w + fw] *
+                                reg_src[0][(inner_fh + oh) * SrcTileConfig::unroll_w +
+                                           fw + ow * stride_w];
                     }
                 }
             }
         }
-        __syncthreads();
-
-        gl2sh_src.commit();
-        gl2sh_flt.commit();
-        gl2sh_src.iter_forward();
-        gl2sh_flt.iter_forward();
-        __syncthreads();
     }

     for (int o = 0; o < OutTileConfig::unroll_size; ++o) {
         for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) {
...
...
@@ -901,9 +1321,8 @@ void LaunchDepthwiseConv2dGPU(
 #define INSTANCE_A(type1, type2, a, direction)                              \
     if (param.flt_w > a * 4) {                                              \
-        INSTANCE_AB(type1, type2, a, 15, direction)                         \
-        else INSTANCE_AB(type1, type2, a, 7, direction) else INSTANCE_AB(   \
-                type1, type2, a, 3, direction)                              \
+        INSTANCE_AB(type1, type2, a, 7, direction)                          \
+        else INSTANCE_AB(type1, type2, a, 3, direction)                     \
     }
 #define INSTANCE(type1, type2, direction)                                   \
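INSTANCE_A chains INSTANCE_AB expansions so the launcher picks one template specialization from the runtime filter width; the change simply drops the 15-wide branch, leaving the 7- and 3-wide ones. The dispatch style in miniature, with hypothetical names:

// Hypothetical miniature of the INSTANCE_A / INSTANCE_AB dispatch style: each
// branch instantiates one kernel specialization, and the if / else-if chain
// picks an unroll width from a runtime size. Dropping a branch (as this commit
// drops the 15-wide case) just shortens the chain.
template <int UNROLL_W>
__global__ void demo_kernel(const float* src, float* dst, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) dst[i] = src[i] * UNROLL_W;  // trivial stand-in body
}

#define DISPATCH_ONE(unroll_w, src, dst, n, grid, block)               \
    if ((n) > (unroll_w) * 4) {                                        \
        demo_kernel<unroll_w><<<(grid), (block)>>>((src), (dst), (n)); \
    }

#define DISPATCH(src, dst, n, grid, block)     \
    DISPATCH_ONE(7, src, dst, n, grid, block)  \
    else DISPATCH_ONE(3, src, dst, n, grid, block)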
...
...