Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
8e5410e4
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
8e5410e4
编写于
2月 23, 2022
作者:
M
Megvii Engine Team
提交者:
王彪
2月 27, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(cuda): add fp16 compute 16 kernel
GitOrigin-RevId: e03435be021ccf3d8eff357a80d5203e903aca96
上级
472e2f96
变更
7
隐藏空白更改
内联
并排
Showing
7 changed files
with
374 additions
and
26 deletions
+374
-26
dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh
...c/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh
+331
-10
dnn/src/cuda/conv_bias/chanwise/kern.cuh
dnn/src/cuda/conv_bias/chanwise/kern.cuh
+9
-6
dnn/src/cuda/conv_bias/depthwise_large_filter.cpp
dnn/src/cuda/conv_bias/depthwise_large_filter.cpp
+4
-2
dnn/src/cuda/convolution/chanwise/kern.cuh
dnn/src/cuda/convolution/chanwise/kern.cuh
+9
-6
dnn/src/cuda/fp16_help.cuh
dnn/src/cuda/fp16_help.cuh
+9
-0
dnn/test/cuda/conv_bias.cpp
dnn/test/cuda/conv_bias.cpp
+6
-1
dnn/test/cuda/convolution.cpp
dnn/test/cuda/convolution.cpp
+6
-1
未找到文件。
dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh
浏览文件 @
8e5410e4
...
...
@@ -235,7 +235,175 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
int
off_ochannel
=
blockIdx
.
x
,
off_obw
=
blockIdx
.
y
,
off_obh
=
blockIdx
.
z
,
off_oh
=
threadIdx
.
y
,
off_ow
=
threadIdx
.
x
;
constexpr
int
t2_src_unroll_w
=
(
SrcTileConfig
::
unroll_w
+
1
)
/
2
;
constexpr
int
t2_src_unroll_w
=
(
SrcTileConfig
::
unroll_w
+
3
)
/
2
;
constexpr
int
t2_flt_unroll_w
=
(
FilterTileConfig
::
unroll_w
+
2
)
/
2
;
constexpr
int
t2_out_unroll_w
=
(
OutTileConfig
::
unroll_w
+
1
)
/
2
;
extern
__shared__
__align__
(
8
)
unsigned
char
smem
[];
static_assert
(
sizeof
(
T
)
<=
8
,
"Insufficient alignment detected"
);
T
*
smem_src
=
reinterpret_cast
<
T
*>
(
smem
);
T
*
smem_flt
=
reinterpret_cast
<
T
*>
(
&
smem_src
[
SrcTileCount
::
smem_size
]);
int
stride_h
=
is_fwd
?
param
.
stride_h
:
1
;
int
stride_w
=
is_fwd
?
param
.
stride_w
:
1
;
int
off_ichannel
=
off_ochannel
/
param
.
chl_mul
,
off_fchannel
=
off_ichannel
%
param
.
src_chl
,
out_start_h
=
off_obh
*
OutTileConfig
::
block_h
,
out_start_w
=
off_obw
*
OutTileConfig
::
block_w
,
src_start_h
=
out_start_h
*
stride_h
-
param
.
pad_h
,
src_start_w
=
out_start_w
*
stride_w
-
param
.
pad_w
,
out_base_h_idx
=
out_start_h
+
off_oh
*
OutTileConfig
::
unroll_h
;
T
*
smem_src_ptr
=
smem_src
+
off_ow
*
FilterTileConfig
::
unroll_w
;
T
*
smem_flt_ptr
=
smem_flt
+
off_ow
*
FilterTileConfig
::
unroll_w
;
T
*
out_base_ptr
=
output
+
off_ochannel
*
param
.
out_h
*
param
.
out_w
;
SrcGlobal2ShareVisitor
gl2sh_src
=
{
smem_src
,
param
.
src_w
,
is_fwd
?
src_start_h
:
src_start_h
-
(
param
.
out_h
/
2
+
param
.
flt_h
/
2
-
param
.
pad_h
-
param
.
src_h
*
param
.
stride_h
/
2
),
is_fwd
?
src_start_w
:
src_start_w
-
(
param
.
out_w
/
2
+
param
.
flt_w
/
2
-
param
.
pad_w
-
param
.
src_w
*
param
.
stride_w
/
2
),
is_fwd
?
param
.
src_h
:
param
.
src_h
*
param
.
stride_h
,
is_fwd
?
param
.
src_w
:
param
.
src_w
*
param
.
stride_w
,
is_fwd
?
1
:
param
.
stride_h
,
is_fwd
?
1
:
param
.
stride_w
};
FilterGlobal2ShareVisitor
gl2sh_flt
=
{
smem_flt
,
param
.
flt_w
,
is_fwd
?
0
:
param
.
flt_h
-
2
,
0
,
param
.
flt_h
,
param
.
flt_w
,
1
,
1
};
gl2sh_src
.
g_ptr
=
input
+
off_ichannel
*
param
.
src_h
*
param
.
src_w
;
gl2sh_flt
.
g_ptr
=
filter
+
off_fchannel
*
param
.
flt_h
*
param
.
flt_w
;
gl2sh_src
.
first_copy
();
gl2sh_flt
.
first_copy
();
__syncthreads
();
T2
reg_src
[
SrcTileConfig
::
unroll_h
*
t2_src_unroll_w
],
reg_flt
[
2
][
FilterTileConfig
::
unroll_h
*
t2_flt_unroll_w
];
T2
sum
[
OutTileConfig
::
unroll_size
]
=
{{
0.0
,
0.0
}};
for
(
int
fh
=
0
;
fh
<
param
.
flt_h
;
fh
+=
FilterTileConfig
::
unroll_h
)
{
gl2sh_src
.
copy
();
gl2sh_flt
.
copy
();
#pragma unroll
for
(
int
s_h
=
0
;
s_h
<
SrcTileConfig
::
unroll_h
;
++
s_h
)
{
#pragma unroll
for
(
int
s_w
=
0
;
s_w
<
t2_src_unroll_w
;
++
s_w
)
{
int
src_offset
=
(
off_oh
*
stride_h
+
fh
+
s_h
)
%
SrcTileCount
::
smem_h
*
SrcTileCount
::
smem_w
+
s_w
*
2
;
reg_src
[
s_h
*
t2_src_unroll_w
+
s_w
]
=
*
reinterpret_cast
<
T2
*>
(
smem_src_ptr
+
src_offset
);
}
}
#pragma unroll
for
(
int
f_h
=
0
;
f_h
<
FilterTileConfig
::
unroll_h
;
++
f_h
)
{
#pragma unroll
for
(
int
f_w
=
0
;
f_w
<
t2_flt_unroll_w
-
1
;
++
f_w
)
{
int
flt_offset
=
(
fh
+
f_h
)
%
FilterTileCount
::
smem_h
*
FilterTileCount
::
smem_w
+
f_w
*
2
;
reg_flt
[
0
][
f_h
*
t2_flt_unroll_w
+
f_w
]
=
*
reinterpret_cast
<
T2
*>
(
smem_flt_ptr
+
flt_offset
);
if
(
f_w
>
0
)
{
reg_flt
[
1
][
f_h
*
t2_flt_unroll_w
+
f_w
]
=
T2
{
reg_flt
[
0
][
f_h
*
t2_flt_unroll_w
+
f_w
-
1
].
y
,
reg_flt
[
0
][
f_h
*
t2_flt_unroll_w
+
f_w
].
x
};
}
else
{
reg_flt
[
1
][
f_h
*
t2_flt_unroll_w
+
f_w
]
=
T2
{
0.0
,
reg_flt
[
0
][
f_h
*
t2_flt_unroll_w
+
f_w
].
x
};
}
}
reg_flt
[
0
][
f_h
*
t2_flt_unroll_w
+
t2_flt_unroll_w
-
1
]
=
T2
{
0.0
,
0.0
};
reg_flt
[
1
][
f_h
*
t2_flt_unroll_w
+
t2_flt_unroll_w
-
1
]
=
T2
{
reg_flt
[
0
][
f_h
*
t2_flt_unroll_w
+
t2_flt_unroll_w
-
2
].
y
,
0.0
};
}
#pragma unroll
for
(
int
inner_fh
=
0
;
inner_fh
<
FilterTileConfig
::
unroll_h
;
++
inner_fh
)
{
#pragma unroll
for
(
int
oh
=
0
;
oh
<
OutTileConfig
::
unroll_h
;
++
oh
)
{
#pragma unroll
for
(
int
fw
=
0
;
fw
<
t2_flt_unroll_w
;
++
fw
)
{
#pragma unroll
for
(
int
ow
=
0
;
ow
<
OutTileConfig
::
unroll_w
;
++
ow
)
{
sum
[
oh
*
t2_out_unroll_w
+
ow
]
=
megdnn
::
cuda
::
fma2
(
reg_flt
[
ow
*
stride_w
%
2
]
[
inner_fh
*
t2_flt_unroll_w
+
fw
],
reg_src
[(
inner_fh
+
oh
)
*
t2_src_unroll_w
+
fw
+
ow
*
stride_w
/
2
],
sum
[
oh
*
t2_out_unroll_w
+
ow
]);
}
}
}
}
__syncthreads
();
gl2sh_src
.
commit
();
gl2sh_flt
.
commit
();
gl2sh_src
.
iter_forward
();
gl2sh_flt
.
iter_forward
();
__syncthreads
();
}
for
(
int
o
=
0
;
o
<
OutTileConfig
::
unroll_size
;
++
o
)
{
for
(
int
i
=
1
;
i
<
ThreadConfig
::
thread_x
;
i
=
i
<<
1
)
{
sum
[
o
]
=
megdnn
::
cuda
::
hadd2
(
sum
[
o
],
__shfl_xor
(
sum
[
o
],
i
,
32
));
}
}
if
(
threadIdx
.
x
==
0
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
OutTileConfig
::
unroll_h
;
++
i
)
{
int
out_h_idx
=
out_base_h_idx
+
i
;
if
(
out_h_idx
<
param
.
out_h
)
{
#pragma unroll
for
(
int
j
=
0
;
j
<
OutTileConfig
::
unroll_w
;
++
j
)
{
int
out_w_idx
=
out_start_w
+
j
;
if
(
out_w_idx
>=
param
.
out_w
)
return
;
out_base_ptr
[
out_h_idx
*
param
.
out_w
+
out_w_idx
]
=
__float2half
(
__half2float
(
sum
[
i
*
OutTileConfig
::
unroll_w
+
j
].
x
)
+
__half2float
(
sum
[
i
*
OutTileConfig
::
unroll_w
+
j
].
y
));
}
}
}
}
}
template
<
typename
ConvTrait
,
DepthwiseConv2dDirection
kDirection
>
__global__
void
DepthwiseConv2dGPUKernelNCHWC32
(
const
Param
param
,
const
__half
*
input
,
const
__half
*
filter
,
__half
*
output
)
{
using
T
=
__half
;
using
T2
=
__half2
;
using
ThreadConfig
=
typename
ConvTrait
::
ThreadConfig
;
using
SrcTileConfig
=
typename
ConvTrait
::
SrcTileConfig
;
using
FilterTileConfig
=
typename
ConvTrait
::
FilterTileConfig
;
using
OutTileConfig
=
typename
ConvTrait
::
OutTileConfig
;
using
SrcTileCount
=
typename
ConvTrait
::
SrcTileCount
;
using
FilterTileCount
=
typename
ConvTrait
::
FilterTileCount
;
using
SrcGlobal2ShareVisitor
=
typename
ConvTrait
::
SrcGlobal2ShareVisitor
;
using
FilterGlobal2ShareVisitor
=
typename
ConvTrait
::
FilterGlobal2ShareVisitor
;
const
bool
is_fwd
=
(
kDirection
==
DepthwiseConv2dDirection
::
DIRECTION_FORWARD
);
int
off_ochannel
=
blockIdx
.
x
,
off_obw
=
blockIdx
.
y
,
off_obh
=
blockIdx
.
z
,
off_oh
=
threadIdx
.
y
,
off_ow
=
threadIdx
.
x
;
constexpr
int
t2_src_unroll_w
=
(
SrcTileConfig
::
unroll_w
+
3
)
/
2
;
constexpr
int
t2_flt_unroll_w
=
(
FilterTileConfig
::
unroll_w
+
2
)
/
2
;
constexpr
int
t2_out_unroll_w
=
(
OutTileConfig
::
unroll_w
+
1
)
/
2
;
...
...
@@ -320,17 +488,17 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
reg_flt
[
0
][
f_h
*
t2_flt_unroll_w
+
f_w
]
=
*
reinterpret_cast
<
T2
*>
(
smem_flt_ptr
+
flt_offset
);
if
(
f_w
>
0
)
{
reg_flt
[
1
][
f_h
*
t2_flt_unroll_w
+
f_w
]
=
{
reg_flt
[
0
][
f_h
*
t2_flt_unroll_w
+
f_w
-
1
].
y
,
reg_flt
[
0
][
f_h
*
t2_flt_unroll_w
+
f_w
].
x
};
reg_flt
[
1
][
f_h
*
t2_flt_unroll_w
+
f_w
]
=
T2
{
reg_flt
[
0
][
f_h
*
t2_flt_unroll_w
+
f_w
-
1
].
y
,
reg_flt
[
0
][
f_h
*
t2_flt_unroll_w
+
f_w
].
x
};
}
else
{
reg_flt
[
1
][
f_h
*
t2_flt_unroll_w
+
f_w
]
=
{
0.0
,
reg_flt
[
0
][
f_h
*
t2_flt_unroll_w
+
f_w
].
x
};
reg_flt
[
1
][
f_h
*
t2_flt_unroll_w
+
f_w
]
=
T2
{
0.0
,
reg_flt
[
0
][
f_h
*
t2_flt_unroll_w
+
f_w
].
x
};
}
}
reg_flt
[
0
][
f_h
*
t2_flt_unroll_w
+
t2_flt_unroll_w
-
1
]
=
{
0.0
,
0.0
};
reg_flt
[
1
][
f_h
*
t2_flt_unroll_w
+
t2_flt_unroll_w
-
1
]
=
{
reg_flt
[
0
][
f_h
*
t2_flt_unroll_w
+
t2_flt_unroll_w
-
2
].
y
,
0.0
};
reg_flt
[
0
][
f_h
*
t2_flt_unroll_w
+
t2_flt_unroll_w
-
1
]
=
T2
{
0.0
,
0.0
};
reg_flt
[
1
][
f_h
*
t2_flt_unroll_w
+
t2_flt_unroll_w
-
1
]
=
T2
{
reg_flt
[
0
][
f_h
*
t2_flt_unroll_w
+
t2_flt_unroll_w
-
2
].
y
,
0.0
};
}
#pragma unroll
...
...
@@ -535,6 +703,154 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
}
}
/*!
 * \brief Depthwise 2D convolution kernel, float overload.
 *
 * Work mapping, as read from the indexing below:
 *  - blockIdx.x selects the output plane (`output + blockIdx.x * out_h * out_w`);
 *  - blockIdx.y / blockIdx.z select the output tile along width / height;
 *  - threadIdx.y selects the group of `OutTileConfig::unroll_h` output rows
 *    handled by the thread;
 *  - threadIdx.x selects a chunk of `FilterTileConfig::unroll_w` filter
 *    columns. The partial sums of the threads along x are combined with a
 *    warp shuffle and only the thread with threadIdx.x == 0 stores results.
 *
 * Dynamic shared memory holds the source tile (`SrcTileCount::smem_size`
 * elements) immediately followed by the filter tile; the launcher must
 * provide enough bytes for both.
 *
 * \tparam ConvTrait  bundle of tile/thread configuration types and the
 *                    global-to-shared visitor types used for staging data.
 * \tparam kDirection DIRECTION_FORWARD uses `param.stride_*` for the output
 *                    to source mapping; any other direction uses stride 1
 *                    here and hands the strides to the source visitor.
 * \param param   convolution geometry (sizes, padding, strides).
 * \param input   source tensor; plane `off_ichannel` is read.
 * \param filter  filter tensor; plane `off_fchannel` is read.
 * \param output  destination tensor; plane `blockIdx.x` is written.
 *
 * NOTE(review): the meaning of the individual visitor initializer fields is
 * defined by ConvTrait's visitor types, which are not visible here.
 */
template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
__global__ void DepthwiseConv2dGPUKernelNCHWC32(
        const Param param, const float* input, const float* filter, float* output) {
    using T = float;
    using ThreadConfig = typename ConvTrait::ThreadConfig;
    using SrcTileConfig = typename ConvTrait::SrcTileConfig;
    using FilterTileConfig = typename ConvTrait::FilterTileConfig;
    using OutTileConfig = typename ConvTrait::OutTileConfig;
    using SrcTileCount = typename ConvTrait::SrcTileCount;
    using FilterTileCount = typename ConvTrait::FilterTileCount;
    using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor;
    using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor;

    const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);

    // Block / thread coordinates.
    const int off_ochannel = blockIdx.x;
    const int off_obw = blockIdx.y;
    const int off_obh = blockIdx.z;
    const int off_oh = threadIdx.y;
    const int off_ow = threadIdx.x;

    // Dynamic shared memory: source tile first, filter tile right behind it.
    extern __shared__ __align__(8) unsigned char smem[];
    static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
    T* smem_src = reinterpret_cast<T*>(smem);
    T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);

    // Outside the forward direction the kernel itself walks with unit stride.
    const int stride_h = is_fwd ? param.stride_h : 1;
    const int stride_w = is_fwd ? param.stride_w : 1;

    const int off_ichannel = off_ochannel / param.chl_mul;
    const int off_fchannel = off_ichannel % param.src_chl;
    const int out_start_h = off_obh * OutTileConfig::block_h;
    const int out_start_w = off_obw * OutTileConfig::block_w;
    const int src_start_h = out_start_h * stride_h - param.pad_h;
    const int src_start_w = out_start_w * stride_w - param.pad_w;
    const int out_base_h_idx = out_start_h + off_oh * OutTileConfig::unroll_h;

    // Each thread along x starts FilterTileConfig::unroll_w columns further in.
    T* smem_src_ptr = smem_src + off_ow * FilterTileConfig::unroll_w;
    T* smem_flt_ptr = smem_flt + off_ow * FilterTileConfig::unroll_w;

    T* out_base_ptr = output + off_ochannel * param.out_h * param.out_w;

    SrcGlobal2ShareVisitor gl2sh_src = {
            smem_src,
            param.src_w,
            is_fwd ? src_start_h
                   : src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h -
                                    param.src_h * param.stride_h / 2),
            is_fwd ? src_start_w
                   : src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w -
                                    param.src_w * param.stride_w / 2),
            is_fwd ? param.src_h : param.src_h * param.stride_h,
            is_fwd ? param.src_w : param.src_w * param.stride_w,
            is_fwd ? 1 : param.stride_h,
            is_fwd ? 1 : param.stride_w};

    FilterGlobal2ShareVisitor gl2sh_flt = {
            smem_flt,
            param.flt_w,
            is_fwd ? 0 : param.flt_h - 2,
            0,
            param.flt_h,
            param.flt_w,
            1,
            1};

    gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w;
    gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w;

    // Stage the first tiles and make them visible to the whole block.
    gl2sh_src.first_copy();
    gl2sh_flt.first_copy();

    __syncthreads();

    T reg_src[SrcTileConfig::unroll_h * SrcTileConfig::unroll_w];
    T reg_flt[FilterTileConfig::unroll_h * FilterTileConfig::unroll_w];

    T sum[OutTileConfig::unroll_size] = {0.0};

    for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
        // Issue the loads for the next tiles; they are committed to shared
        // memory only after this iteration's compute (see commit() below).
        gl2sh_src.copy();
        gl2sh_flt.copy();

        // Shared memory -> registers: source rows (row index wraps modulo
        // SrcTileCount::smem_h).
#pragma unroll
        for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
#pragma unroll
            for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
                reg_src[s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
                        [(off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h *
                                 SrcTileCount::smem_w +
                         s_w];
            }
        }

        // Shared memory -> registers: filter rows (row index wraps modulo
        // FilterTileCount::smem_h).
#pragma unroll
        for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
            for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
                reg_flt[f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
                        [(fh + f_h) % FilterTileCount::smem_h *
                                 FilterTileCount::smem_w +
                         f_w];
            }
        }

        // Multiply-accumulate this thread's filter chunk into its outputs.
#pragma unroll
        for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
#pragma unroll
            for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
#pragma unroll
                for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) {
#pragma unroll
                    for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
                        sum[oh * OutTileConfig::unroll_w + ow] +=
                                reg_flt[inner_fh * FilterTileConfig::unroll_w + fw] *
                                reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w +
                                        fw + ow * stride_w];
                    }
                }
            }
        }

        // First barrier: everyone is done reading the current tiles.
        // Second barrier: the newly committed tiles are visible to everyone.
        __syncthreads();
        gl2sh_src.commit();
        gl2sh_flt.commit();
        gl2sh_src.iter_forward();
        gl2sh_flt.iter_forward();
        __syncthreads();
    }

    // Combine the partial sums of the threads along x by exchanging values
    // with lanes at xor distance 1, 2, 4, ... below ThreadConfig::thread_x.
    // NOTE(review): `__shfl_xor` is the legacy (non-`_sync`) spelling;
    // presumably a project compat header maps it to `__shfl_xor_sync` on
    // newer toolkits - confirm.
    for (int o = 0; o < OutTileConfig::unroll_size; ++o) {
        for (int i = 1; i < ThreadConfig::thread_x; i <<= 1) {
            sum[o] += __shfl_xor(sum[o], i, 32);
        }
    }

    // Only the thread with threadIdx.x == 0 writes results. No barrier
    // follows, so leaving early here is safe.
    if (threadIdx.x != 0) {
        return;
    }

#pragma unroll
    for (int i = 0; i < OutTileConfig::unroll_h; ++i) {
        const int out_h_idx = out_base_h_idx + i;
        if (out_h_idx < param.out_h) {
#pragma unroll
            for (int j = 0; j < OutTileConfig::unroll_w; ++j) {
                const int out_w_idx = out_start_w + j;
                // Columns only grow with j, so the rest of this row is out of
                // range too. Use `break` rather than `return`: the following
                // rows still have valid columns that must be stored.
                if (out_w_idx >= param.out_w) {
                    break;
                }
                out_base_ptr[out_h_idx * param.out_w + out_w_idx] =
                        sum[i * OutTileConfig::unroll_w + j];
            }
        }
    }
}
template
<
typename
T
,
typename
T2
,
DepthwiseConv2dDirection
kDirection
,
int
unroll_fw
,
int
unroll_ow
,
int
stride
>
...
...
@@ -561,7 +877,12 @@ void LaunchDepthwiseConv2dGPU(
(
SrcTileCount
::
smem_size
+
FilterTileCount
::
smem_size
)
*
sizeof
(
T
);
void
(
*
kernel
)(
const
Param
,
const
T
*
,
const
T
*
,
T
*
);
kernel
=
DepthwiseConv2dGPUKernelNCHW
<
IConvTrait
,
kDirection
>
;
if
(
param
.
is_compute_deafult
)
{
kernel
=
DepthwiseConv2dGPUKernelNCHW
<
IConvTrait
,
kDirection
>
;
}
else
{
kernel
=
DepthwiseConv2dGPUKernelNCHWC32
<
IConvTrait
,
kDirection
>
;
}
kernel
<<<
grid
,
block
,
shared_storage
,
stream
>>>
(
param
,
input
,
filter
,
output
);
after_kernel_launch
();
}
...
...
dnn/src/cuda/conv_bias/chanwise/kern.cuh
浏览文件 @
8e5410e4
...
...
@@ -27,8 +27,10 @@ namespace chanwise {
struct
Param
{
uint32_t
batch
,
src_chl
,
src_h
,
src_w
,
chl_mul
,
flt_h
,
flt_w
,
out_h
,
out_w
,
pad_h
,
pad_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
;
bool
is_compute_deafult
;
#if MEGDNN_CC_HOST
static
Param
from_fwd_args
(
const
BiasForwardSizeArgs
&
args
)
{
static
Param
from_fwd_args
(
const
BiasForwardSizeArgs
&
args
,
bool
is_compute_deafult_
=
true
)
{
#define U(v) static_cast<uint32_t>(v)
auto
&&
src
=
args
.
src_layout
->
shape
;
auto
&&
dst
=
args
.
dst_layout
->
shape
;
...
...
@@ -42,11 +44,12 @@ struct Param {
hw_pos
=
1
;
}
return
{
U
(
src
[
0
]),
U
(
src
[
c_pos
]),
U
(
src
[
hw_pos
]),
U
(
src
[
hw_pos
+
1
]),
U
(
fm
.
ocpg
),
U
(
fm
.
spatial
[
0
]),
U
(
fm
.
spatial
[
1
]),
U
(
dst
[
hw_pos
]),
U
(
dst
[
hw_pos
+
1
]),
U
(
fm
.
padding
[
0
]),
U
(
fm
.
padding
[
1
]),
U
(
fm
.
stride
[
0
]),
U
(
fm
.
stride
[
1
]),
U
(
fm
.
dilation
[
0
]),
U
(
fm
.
dilation
[
1
]),
U
(
src
[
0
]),
U
(
src
[
c_pos
]),
U
(
src
[
hw_pos
]),
U
(
src
[
hw_pos
+
1
]),
U
(
fm
.
ocpg
),
U
(
fm
.
spatial
[
0
]),
U
(
fm
.
spatial
[
1
]),
U
(
dst
[
hw_pos
]),
U
(
dst
[
hw_pos
+
1
]),
U
(
fm
.
padding
[
0
]),
U
(
fm
.
padding
[
1
]),
U
(
fm
.
stride
[
0
]),
U
(
fm
.
stride
[
1
]),
U
(
fm
.
dilation
[
0
]),
U
(
fm
.
dilation
[
1
]),
is_compute_deafult_
,
};
#undef U
}
...
...
dnn/src/cuda/conv_bias/depthwise_large_filter.cpp
浏览文件 @
8e5410e4
...
...
@@ -47,7 +47,8 @@ bool ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::is_available(
if
(
args
.
z_layout
->
ndim
>
0
)
return
false
;
auto
param
=
chanwise
::
Param
::
from_fwd_args
(
args
);
auto
param
=
chanwise
::
Param
::
from_fwd_args
(
args
,
args
.
opr
->
param
().
compute_mode
==
Param
::
ComputeMode
::
DEFAULT
);
auto
&&
fm
=
args
.
filter_meta
;
return
fm
.
group
>
1
&&
args
.
filter_meta
.
format
==
Param
::
Format
::
NCHW
&&
args
.
src_layout
->
dtype
.
category
()
==
DTypeCategory
::
FLOAT
&&
...
...
@@ -80,7 +81,8 @@ void ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::exec(const ExecArgs& args) c
conv_dst_tensor
.
layout
.
dtype
);
}
{
auto
kparam
=
chanwise
::
Param
::
from_fwd_args
(
args
);
auto
kparam
=
chanwise
::
Param
::
from_fwd_args
(
args
,
args
.
opr
->
param
().
compute_mode
==
Param
::
ComputeMode
::
DEFAULT
);
auto
stream
=
cuda_stream
(
args
.
handle
);
switch
(
args
.
src_layout
->
dtype
.
enumv
())
{
case
DTypeEnum
::
Float32
:
...
...
dnn/src/cuda/convolution/chanwise/kern.cuh
浏览文件 @
8e5410e4
...
...
@@ -27,8 +27,10 @@ namespace chanwise {
struct
Param
{
uint32_t
batch
,
src_chl
,
src_h
,
src_w
,
chl_mul
,
flt_h
,
flt_w
,
out_h
,
out_w
,
pad_h
,
pad_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
;
bool
is_compute_deafult
;
#if MEGDNN_CC_HOST
static
Param
from_fwd_args
(
const
ForwardSizeArgs
&
args
)
{
static
Param
from_fwd_args
(
const
ForwardSizeArgs
&
args
,
bool
is_compute_deafult_
=
true
)
{
#define U(v) static_cast<uint32_t>(v)
auto
&&
src
=
args
.
src_layout
->
shape
;
auto
&&
dst
=
args
.
dst_layout
->
shape
;
...
...
@@ -42,11 +44,12 @@ struct Param {
hw_pos
=
1
;
}
return
{
U
(
src
[
0
]),
U
(
src
[
c_pos
]),
U
(
src
[
hw_pos
]),
U
(
src
[
hw_pos
+
1
]),
U
(
fm
.
ocpg
),
U
(
fm
.
spatial
[
0
]),
U
(
fm
.
spatial
[
1
]),
U
(
dst
[
hw_pos
]),
U
(
dst
[
hw_pos
+
1
]),
U
(
fm
.
padding
[
0
]),
U
(
fm
.
padding
[
1
]),
U
(
fm
.
stride
[
0
]),
U
(
fm
.
stride
[
1
]),
U
(
fm
.
dilation
[
0
]),
U
(
fm
.
dilation
[
1
]),
U
(
src
[
0
]),
U
(
src
[
c_pos
]),
U
(
src
[
hw_pos
]),
U
(
src
[
hw_pos
+
1
]),
U
(
fm
.
ocpg
),
U
(
fm
.
spatial
[
0
]),
U
(
fm
.
spatial
[
1
]),
U
(
dst
[
hw_pos
]),
U
(
dst
[
hw_pos
+
1
]),
U
(
fm
.
padding
[
0
]),
U
(
fm
.
padding
[
1
]),
U
(
fm
.
stride
[
0
]),
U
(
fm
.
stride
[
1
]),
U
(
fm
.
dilation
[
0
]),
U
(
fm
.
dilation
[
1
]),
is_compute_deafult_
,
};
#undef U
}
...
...
dnn/src/cuda/fp16_help.cuh
浏览文件 @
8e5410e4
...
...
@@ -45,6 +45,15 @@ fma2(const __half2 a, const __half2 b, const __half2 c) {
#endif
}
/*!
 * \brief Element-wise sum of two __half2 values.
 *
 * When compiling for __CUDA_ARCH__ >= 530 the `__hadd2` intrinsic is used.
 * Otherwise each component is widened to float, added, and converted back
 * to half.
 *
 * \param a first operand
 * \param b second operand
 * \return __half2 holding {a.x + b.x, a.y + b.y}
 */
__device__ __forceinline__ __half2 hadd2(const __half2 a, const __half2 b) {
#if __CUDA_ARCH__ >= 530
    return __hadd2(a, b);
#else
    // Fallback path: do the addition in float, one component at a time.
    const float sum_x = __half2float(a.x) + __half2float(b.x);
    const float sum_y = __half2float(a.y) + __half2float(b.y);
    return {__float2half(sum_x), __float2half(sum_y)};
#endif
}
__device__
__forceinline__
float2
fma2
(
const
__half2
a
,
const
__half2
b
,
const
float2
c
)
{
return
{
__half2float
(
a
.
x
)
*
__half2float
(
b
.
x
)
+
c
.
x
,
...
...
dnn/test/cuda/conv_bias.cpp
浏览文件 @
8e5410e4
...
...
@@ -701,7 +701,12 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
ConvBiasForward
::
algo_name
<
ConvBias
::
DirectParam
>
(
"DEPTHWISE_LARGE_FILTER"
,
{})
.
c_str
()));
for
(
auto
dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float32
(),
dtype
::
Float16
()})
{
for
(
auto
dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float32
(),
#if CUDA_VERSION >= 9000
dtype
::
Float16
()
#endif
})
{
auto
run
=
[
&
checker
,
&
dtype
](
size_t
n
,
size_t
g
,
size_t
h
,
size_t
fh
,
size_t
padding
,
size_t
stride
)
{
...
...
dnn/test/cuda/convolution.cpp
浏览文件 @
8e5410e4
...
...
@@ -728,7 +728,12 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DEPTHWISE_LARGE_FILTER) {
Checker
<
ConvolutionBackwardData
>
checker
(
handle_cuda
());
checker
.
set_before_exec_callback
(
AlgoChecker
<
ConvolutionBackwardData
>
(
"DEPTHWISE_LARGE_FILTER"
));
for
(
auto
dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float32
(),
dtype
::
Float16
()})
{
for
(
auto
dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float32
(),
#if CUDA_VERSION >= 9000
dtype
::
Float16
()
#endif
})
{
auto
run
=
[
&
checker
,
&
dtype
](
size_t
n
,
size_t
g
,
size_t
h
,
size_t
fh
,
size_t
padding
,
size_t
stride
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录