Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
ac26bdce
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
ac26bdce
编写于
3月 07, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(cuda): fix direct conv speed and memory problem
GitOrigin-RevId: 6faeeff3b80b9cd2245268bfaa9c017b1d3bac58
上级
f7994683
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
28 addition
and
31 deletion
+28
-31
dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh
dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh
+4
-6
dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh
...c/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh
+24
-25
未找到文件。
dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh
浏览文件 @
ac26bdce
...
...
@@ -142,7 +142,7 @@ struct ConvTraitInner {
}
#define CHECK_AB_BWD(a, b) \
if (param.out_w > b * 4
) {
\
if (param.out_w > b * 4
|| b == 3) {
\
using FilterTileConfig_ = FilterTileConfig<unroll_fh, a + 2>; \
using ThreadConfig_ = ThreadConfig<4, 32>; \
using OutTileConfig_ = OutTileConfig<ThreadConfig_, unroll_oh, b + 1>; \
...
...
@@ -165,11 +165,9 @@ struct ConvTraitInner {
return true; \
}
#define CHECK_A(a, cb) \
if (param.flt_w > a * 4) { \
CHECK_AB_##cb( \
a, \
15) else CHECK_AB_##cb(a, 14) else CHECK_AB_##cb(a, 13) else CHECK_AB_##cb(a, 12) else CHECK_AB_##cb(a, 11) else CHECK_AB_##cb(a, 10) else CHECK_AB_##cb(a, 9) else CHECK_AB_##cb(a, 8) else CHECK_AB_##cb(a, 7) else CHECK_AB_##cb(a, 6) else CHECK_AB_##cb(a, 5) else CHECK_AB_##cb(a, 4) else CHECK_AB_##cb(a, 3) else CHECK_AB_##cb(a, 2) else CHECK_AB_##cb(a, 1) else CHECK_AB_##cb(a, 0) \
#define CHECK_A(a, cb) \
if (param.flt_w > a * 4) { \
CHECK_AB_##cb(a, 15) else CHECK_AB_##cb(a, 7) else CHECK_AB_##cb(a, 3) \
}
#define CHECK(cb) \
...
...
dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh
浏览文件 @
ac26bdce
...
...
@@ -217,7 +217,7 @@ __device__ __forceinline__ void Global2SharedMem<
// Backprop input direction is the same as forward direction with the filter
// rotated by 180°.
#if CUDA_VERSION >= 9000
template
<
typename
ConvTrait
,
DepthwiseConv2dDirection
kDirection
>
template
<
typename
ConvTrait
,
DepthwiseConv2dDirection
kDirection
,
int
stride
>
__global__
void
DepthwiseConv2dGPUKernelNCHW
(
const
Param
param
,
const
__half
*
input
,
const
__half
*
filter
,
__half
*
output
)
{
using
T
=
__half
;
...
...
@@ -230,7 +230,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
using
FilterTileCount
=
typename
ConvTrait
::
FilterTileCount
;
using
SrcGlobal2ShareVisitor
=
typename
ConvTrait
::
SrcGlobal2ShareVisitor
;
using
FilterGlobal2ShareVisitor
=
typename
ConvTrait
::
FilterGlobal2ShareVisitor
;
const
bool
is_fwd
=
(
kDirection
==
DepthwiseConv2dDirection
::
DIRECTION_FORWARD
);
const
expr
bool
is_fwd
=
(
kDirection
==
DepthwiseConv2dDirection
::
DIRECTION_FORWARD
);
int
off_ochannel
=
blockIdx
.
x
,
off_obw
=
blockIdx
.
y
,
off_obh
=
blockIdx
.
z
,
off_oh
=
threadIdx
.
y
,
off_ow
=
threadIdx
.
x
;
...
...
@@ -243,8 +243,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
static_assert
(
sizeof
(
T
)
<=
8
,
"Insufficient alignment detected"
);
T
*
smem_src
=
reinterpret_cast
<
T
*>
(
smem
);
T
*
smem_flt
=
reinterpret_cast
<
T
*>
(
&
smem_src
[
SrcTileCount
::
smem_size
]);
int
stride_h
=
is_fwd
?
param
.
stride_h
:
1
;
int
stride_w
=
is_fwd
?
param
.
stride_w
:
1
;
constexpr
int
stride_h
=
is_fwd
?
stride
:
1
;
constexpr
int
stride_w
=
is_fwd
?
stride
:
1
;
int
off_ichannel
=
off_ochannel
/
param
.
chl_mul
,
off_fchannel
=
off_ichannel
%
param
.
src_chl
,
...
...
@@ -385,7 +385,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
}
}
template
<
typename
ConvTrait
,
DepthwiseConv2dDirection
kDirection
>
template
<
typename
ConvTrait
,
DepthwiseConv2dDirection
kDirection
,
int
stride
>
__global__
void
DepthwiseConv2dGPUKernelNCHWC32
(
const
Param
param
,
const
__half
*
input
,
const
__half
*
filter
,
__half
*
output
)
{
using
T
=
__half
;
...
...
@@ -398,7 +398,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
using
FilterTileCount
=
typename
ConvTrait
::
FilterTileCount
;
using
SrcGlobal2ShareVisitor
=
typename
ConvTrait
::
SrcGlobal2ShareVisitor
;
using
FilterGlobal2ShareVisitor
=
typename
ConvTrait
::
FilterGlobal2ShareVisitor
;
const
bool
is_fwd
=
(
kDirection
==
DepthwiseConv2dDirection
::
DIRECTION_FORWARD
);
const
expr
bool
is_fwd
=
(
kDirection
==
DepthwiseConv2dDirection
::
DIRECTION_FORWARD
);
int
off_ochannel
=
blockIdx
.
x
,
off_obw
=
blockIdx
.
y
,
off_obh
=
blockIdx
.
z
,
off_oh
=
threadIdx
.
y
,
off_ow
=
threadIdx
.
x
;
...
...
@@ -411,8 +411,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
static_assert
(
sizeof
(
T
)
<=
8
,
"Insufficient alignment detected"
);
T
*
smem_src
=
reinterpret_cast
<
T
*>
(
smem
);
T
*
smem_flt
=
reinterpret_cast
<
T
*>
(
&
smem_src
[
SrcTileCount
::
smem_size
]);
int
stride_h
=
is_fwd
?
param
.
stride_h
:
1
;
int
stride_w
=
is_fwd
?
param
.
stride_w
:
1
;
constexpr
int
stride_h
=
is_fwd
?
stride
:
1
;
constexpr
int
stride_w
=
is_fwd
?
stride
:
1
;
int
off_ichannel
=
off_ochannel
/
param
.
chl_mul
,
off_fchannel
=
off_ichannel
%
param
.
src_chl
,
...
...
@@ -555,7 +555,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
}
#endif
template
<
typename
ConvTrait
,
DepthwiseConv2dDirection
kDirection
>
template
<
typename
ConvTrait
,
DepthwiseConv2dDirection
kDirection
,
int
stride
>
__global__
void
DepthwiseConv2dGPUKernelNCHW
(
const
Param
param
,
const
float
*
input
,
const
float
*
filter
,
float
*
output
)
{
using
T
=
float
;
...
...
@@ -568,7 +568,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
using
FilterTileCount
=
typename
ConvTrait
::
FilterTileCount
;
using
SrcGlobal2ShareVisitor
=
typename
ConvTrait
::
SrcGlobal2ShareVisitor
;
using
FilterGlobal2ShareVisitor
=
typename
ConvTrait
::
FilterGlobal2ShareVisitor
;
const
bool
is_fwd
=
(
kDirection
==
DepthwiseConv2dDirection
::
DIRECTION_FORWARD
);
const
expr
bool
is_fwd
=
(
kDirection
==
DepthwiseConv2dDirection
::
DIRECTION_FORWARD
);
int
off_ochannel
=
blockIdx
.
x
,
off_obw
=
blockIdx
.
y
,
off_obh
=
blockIdx
.
z
,
off_oh
=
threadIdx
.
y
,
off_ow
=
threadIdx
.
x
;
...
...
@@ -577,8 +577,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
static_assert
(
sizeof
(
T
)
<=
8
,
"Insufficient alignment detected"
);
T
*
smem_src
=
reinterpret_cast
<
T
*>
(
smem
);
T
*
smem_flt
=
reinterpret_cast
<
T
*>
(
&
smem_src
[
SrcTileCount
::
smem_size
]);
int
stride_h
=
is_fwd
?
param
.
stride_h
:
1
;
int
stride_w
=
is_fwd
?
param
.
stride_w
:
1
;
constexpr
int
stride_h
=
is_fwd
?
stride
:
1
;
constexpr
int
stride_w
=
is_fwd
?
stride
:
1
;
int
off_ichannel
=
off_ochannel
/
param
.
chl_mul
,
off_fchannel
=
off_ichannel
%
param
.
src_chl
,
...
...
@@ -703,7 +703,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
}
}
template
<
typename
ConvTrait
,
DepthwiseConv2dDirection
kDirection
>
template
<
typename
ConvTrait
,
DepthwiseConv2dDirection
kDirection
,
int
stride
>
__global__
void
DepthwiseConv2dGPUKernelNCHWC32
(
const
Param
param
,
const
float
*
input
,
const
float
*
filter
,
float
*
output
)
{
using
T
=
float
;
...
...
@@ -716,7 +716,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
using
FilterTileCount
=
typename
ConvTrait
::
FilterTileCount
;
using
SrcGlobal2ShareVisitor
=
typename
ConvTrait
::
SrcGlobal2ShareVisitor
;
using
FilterGlobal2ShareVisitor
=
typename
ConvTrait
::
FilterGlobal2ShareVisitor
;
const
bool
is_fwd
=
(
kDirection
==
DepthwiseConv2dDirection
::
DIRECTION_FORWARD
);
const
expr
bool
is_fwd
=
(
kDirection
==
DepthwiseConv2dDirection
::
DIRECTION_FORWARD
);
int
off_ochannel
=
blockIdx
.
x
,
off_obw
=
blockIdx
.
y
,
off_obh
=
blockIdx
.
z
,
off_oh
=
threadIdx
.
y
,
off_ow
=
threadIdx
.
x
;
...
...
@@ -725,8 +725,8 @@ __global__ void DepthwiseConv2dGPUKernelNCHWC32(
static_assert
(
sizeof
(
T
)
<=
8
,
"Insufficient alignment detected"
);
T
*
smem_src
=
reinterpret_cast
<
T
*>
(
smem
);
T
*
smem_flt
=
reinterpret_cast
<
T
*>
(
&
smem_src
[
SrcTileCount
::
smem_size
]);
int
stride_h
=
is_fwd
?
param
.
stride_h
:
1
;
int
stride_w
=
is_fwd
?
param
.
stride_w
:
1
;
constexpr
int
stride_h
=
is_fwd
?
stride
:
1
;
constexpr
int
stride_w
=
is_fwd
?
stride
:
1
;
int
off_ichannel
=
off_ochannel
/
param
.
chl_mul
,
off_fchannel
=
off_ichannel
%
param
.
src_chl
,
...
...
@@ -879,16 +879,16 @@ void LaunchDepthwiseConv2dGPU(
void
(
*
kernel
)(
const
Param
,
const
T
*
,
const
T
*
,
T
*
);
if
(
param
.
is_compute_deafult
)
{
kernel
=
DepthwiseConv2dGPUKernelNCHW
<
IConvTrait
,
kDirection
>
;
kernel
=
DepthwiseConv2dGPUKernelNCHW
<
IConvTrait
,
kDirection
,
stride
>
;
}
else
{
kernel
=
DepthwiseConv2dGPUKernelNCHWC32
<
IConvTrait
,
kDirection
>
;
kernel
=
DepthwiseConv2dGPUKernelNCHWC32
<
IConvTrait
,
kDirection
,
stride
>
;
}
kernel
<<<
grid
,
block
,
shared_storage
,
stream
>>>
(
param
,
input
,
filter
,
output
);
after_kernel_launch
();
}
#define INSTANCE_AB(type1, type2, a, b, direction) \
if (param.out_w > b * 4
) {
\
if (param.out_w > b * 4
|| b == 3) {
\
if (direction == DepthwiseConv2dDirection::DIRECTION_BACKWARD || \
(param.stride_h == 1 && param.stride_w == 1)) { \
LaunchDepthwiseConv2dGPU<type1, type2, direction, a + 2, b + 1, 1>( \
...
...
@@ -899,12 +899,11 @@ void LaunchDepthwiseConv2dGPU(
} \
}
#define INSTANCE_A(type1, type2, a, direction) \
if (param.flt_w > a * 4) { \
INSTANCE_AB(type1, type2, a, 15, direction) \
else INSTANCE_AB(type1, type2, a, 14, direction) else INSTANCE_AB(type1, type2, a, 13, direction) else INSTANCE_AB(type1, type2, a, 12, direction) else INSTANCE_AB(type1, type2, a, 11, direction) else INSTANCE_AB(type1, type2, a, 10, direction) else INSTANCE_AB( \
type1, type2, \
a, 9, direction) else INSTANCE_AB(type1, type2, a, 8, direction) else INSTANCE_AB(type1, type2, a, 7, direction) else INSTANCE_AB(type1, type2, a, 6, direction) else INSTANCE_AB(type1, type2, a, 5, direction) else INSTANCE_AB(type1, type2, a, 4, direction) else INSTANCE_AB(type1, type2, a, 3, direction) else INSTANCE_AB(type1, type2, a, 2, direction) else INSTANCE_AB(type1, type2, a, 1, direction) else INSTANCE_AB(type1, type2, a, 0, direction) \
#define INSTANCE_A(type1, type2, a, direction) \
if (param.flt_w > a * 4) { \
INSTANCE_AB(type1, type2, a, 15, direction) \
else INSTANCE_AB(type1, type2, a, 7, direction) else INSTANCE_AB( \
type1, type2, a, 3, direction) \
}
#define INSTANCE(type1, type2, direction) \
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录