Commit b18feaab
Authored on Jul 06, 2021 by Megvii Engine Team

feat(dnn/cuda): use cutlass remove shared load imma conv kernel

GitOrigin-RevId: 0b5574f52669ba88237967d3f00fceca0857b80f
Parent: 6b843ccd
Showing 19 changed files, with 1063 additions and 715 deletions (+1063 −715).
dnn/scripts/cutlass_generator/conv2d_operation.py                 +18   −11
dnn/scripts/cutlass_generator/generator.py                        +37   −21
dnn/scripts/cutlass_generator/library.py                          +10   −10
dnn/scripts/cutlass_generator/list.bzl                             +0    −0
dnn/src/cuda/conv_bias/algo.cpp                                   +35   −25
dnn/src/cuda/conv_bias/algo.h                                      +3    −0
dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh             +6    −6
dnn/src/cuda/conv_bias/cutlass_convolution_wrapper_int4.cu       +595    −0
dnn/src/cuda/conv_bias/cutlass_convolution_wrapper_int8.cu        +37  −533
dnn/src/cuda/conv_bias/cutlass_reorder_filter.cu                 +194    −0
dnn/src/cuda/conv_bias/cutlass_reorder_filter.cuh                 +33    −0
dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nchw64_imma.cpp     +1    −1
dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nhwc_imma.cpp       +2    −2
dnn/src/cuda/conv_bias/implicit_gemm_int4_nchw64_imma_base.cpp    +15   −29
dnn/src/cuda/conv_bias/implicit_gemm_int4_nhwc_imma_base.cpp      +28   −12
dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw32_imma.cpp         +29   −46
dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nchw64_imma.cpp    +2    −1
dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nhwc_imma.cpp      +2    −2
dnn/test/cuda/conv_bias_int8.cpp                                  +16   −16
dnn/scripts/cutlass_generator/conv2d_operation.py

@@ -20,7 +20,7 @@ class Conv2dOperation:
   #
   def __init__(self, conv_kind, conv_type, arch, tile_description, src, flt, bias, dst, element_epilogue, \
     epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4, \
-    need_load_from_const = True, implicit_gemm_mode = ImplicitGemmMode.GemmNt):
+    need_load_from_const = True, implicit_gemm_mode = ImplicitGemmMode.GemmNT, without_shared_load = False):
     self.operation_kind = OperationKind.Conv2d
     self.conv_kind = conv_kind
@@ -36,6 +36,7 @@ class Conv2dOperation:
     self.swizzling_functor = swizzling_functor
     self.need_load_from_const = need_load_from_const
     self.implicit_gemm_mode = implicit_gemm_mode
+    self.without_shared_load = without_shared_load

   #
   def accumulator_type(self):
     accum = self.tile_description.math_instruction.element_accumulator
@@ -58,11 +59,15 @@ class Conv2dOperation:
     unity_kernel = ''
     if not self.need_load_from_const:
       unity_kernel = '_1x1'

-    return "%s%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \
+    reorder_k = ''
+    if self.without_shared_load:
+      reorder_k = '_roc'
+
+    return "%s%s%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \
       inst_shape, intermediate_type, ConvKindNames[self.conv_kind], unity_kernel, \
-      ShortEpilogueNames[self.epilogue_functor])
+      reorder_k, ShortEpilogueNames[self.epilogue_functor])

   #
   def extended_name(self):
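The hunk above changes how generated kernels are named: operations built with without_shared_load=True pick up an extra '_roc' token in their procedural name. A minimal sketch of the resulting logic follows; the concrete strings are illustrative stand-ins, not entries from the generator's real ShortDataTypeNames/ConvKindNames/ShortEpilogueNames tables.

# Sketch only: mirrors the suffix composition in the hunk above; all
# literal values here are assumed example values, not generator output.
def procedural_name_sketch(accum='s8', inst_shape='8816', intermediate='',
                           conv_kind='fprop', need_load_from_const=True,
                           without_shared_load=False, epilogue='id'):
    unity_kernel = '' if need_load_from_const else '_1x1'
    reorder_k = '_roc' if without_shared_load else ''
    return "%s%s%s%s%s%s_%s" % (accum, inst_shape, intermediate, conv_kind,
                                unity_kernel, reorder_k, epilogue)

print(procedural_name_sketch())                          # s88816fprop_id
print(procedural_name_sketch(without_shared_load=True))  # s88816fprop_roc_id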
@@ -177,7 +182,8 @@ using Convolution =
     ${alignment_filter},
     ${nonuninity_kernel},
     ${math_operator},
-    ${implicit_gemm_mode}>;
+    ${implicit_gemm_mode},
+    ${without_shared_load}>;
"""

@@ -219,7 +225,8 @@ using Convolution =
       'alignment_filter': str(operation.flt.alignment),
       'nonuninity_kernel': str(operation.need_load_from_const).lower(),
       'math_operator': MathOperationTag[operation.tile_description.math_instruction.math_operation],
-      'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode]
+      'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode],
+      'without_shared_load': str(operation.without_shared_load).lower()
     }

     return SubstituteTemplate(self.template, values)
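For context, SubstituteTemplate fills the ${...} placeholders of the C++ template string from the values dict, so the new 'without_shared_load' entry lands as the final template argument of the emitted cutlass::conv::device::Convolution type. A rough stand-in for that substitution step (an assumed re-implementation for illustration; the real helper lives elsewhere in the generator):

import re

# Assumed behavior: replace every ${key} with values[key].
def substitute_template_sketch(template, values):
    return re.sub(r'\$\{(\w+)\}', lambda m: values[m.group(1)], template)

tpl = "${implicit_gemm_mode},\n    ${without_shared_load}>;"
print(substitute_template_sketch(tpl, {
    'implicit_gemm_mode': 'cutlass::conv::ImplicitGemmMode::GEMM_TN',
    'without_shared_load': 'true',
}))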
@@ -312,13 +319,13 @@ using Deconvolution =
 #
 def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_layout, dst_type, min_cc, src_align = 32, flt_align = 32, dst_align = 128, \
-                   skip_unity_kernel = False, implicit_gemm_mode = ImplicitGemmMode.GemmNt):
+                   skip_unity_kernel = False, implicit_gemm_mode = ImplicitGemmMode.GemmNT, without_shared_load = False):
   operations = []

   element_epilogue = DataType.f32
   if conv_kind == ConvKind.Fprop:
-    if src_layout == LayoutType.TensorNHWC:
-      swizzling_functor = SwizzlingFunctor.ConvFpropNHWC
+    if implicit_gemm_mode == ImplicitGemmMode.GemmTN:
+      swizzling_functor = SwizzlingFunctor.ConvFpropTrans
     else:
       swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx
   else:

@@ -399,10 +406,10 @@ def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_lay
     bias = TensorDescription(bias_type, dst_layout, max(1, int(32 / DataTypeSize[bias_type])))
     dst = TensorDescription(dst_type, dst_layout, int(dst_align / DataTypeSize[dst_type]))

-    new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst,
-                                    element_epilogue, epilogue, swizzling_functor, True, implicit_gemm_mode)
+    new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst,
+                                    element_epilogue, epilogue, swizzling_functor, True, implicit_gemm_mode,
+                                    without_shared_load)
     operations.append(new_operation)
     if not skip_unity_kernel:
-      new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst,
-                                      element_epilogue, epilogue, swizzling_functor, False, implicit_gemm_mode)
+      new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst,
+                                      element_epilogue, epilogue, swizzling_functor, False, implicit_gemm_mode,
+                                      without_shared_load)
       operations.append(new_operation)

   return operations
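Taken together, GenerateConv2d now threads without_shared_load through every operation it emits: per tile, one general kernel (need_load_from_const=True) and, unless skip_unity_kernel is set, a _1x1 variant (need_load_from_const=False). A simplified model of that fan-out, using a stub type rather than the real Conv2dOperation:

from collections import namedtuple

# Stand-in for Conv2dOperation, keeping only the fields relevant here.
OpStub = namedtuple('OpStub', 'tile need_load_from_const without_shared_load')

def generate_conv2d_sketch(tiles, skip_unity_kernel=False,
                           without_shared_load=False):
    operations = []
    for tile in tiles:
        operations.append(OpStub(tile, True, without_shared_load))
        if not skip_unity_kernel:
            operations.append(OpStub(tile, False, without_shared_load))
    return operations

ops = generate_conv2d_sketch(['128x128x64'], without_shared_load=True)
print(len(ops))  # 2: the general kernel plus its 1x1 variant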
dnn/scripts/cutlass_generator/generator.py
@@ -175,12 +175,10 @@ def GenerateConv2d_Simt(args):
     TileDescription([128, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc),
     TileDescription([128,  64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc),
     TileDescription([ 64, 128, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc),
-    TileDescription([ 64,  64, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc),
     TileDescription([128,  32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc),
     TileDescription([ 32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc),
-    TileDescription([ 32,  64, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc),
-    TileDescription([ 64,  32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc),
-    TileDescription([ 32,  32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc),
+    TileDescription([ 16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc),
+    TileDescription([ 16,  64,  8], 2, [1, 1, 1], math_inst, min_cc, max_cc),
   ]
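A note on reading these entries: in TileDescription([128, 128, 32], 2, [2, 4, 1], ...) the first list is the threadblock tile M x N x K, the scalar is the pipeline stage count, and the second list appears to be the warp count per dimension, making the warp tile the elementwise quotient. That reading is an inference from the threadblock/warp pairs in the dispatch macros further down, not something the diff states explicitly:

# Assumed helper for illustration; derives the warp tile the way the
# DISPATCH_KERNEL_WITH_TILE_SHAPE tables below pair shapes.
def warp_tile(threadblock_shape, warp_count):
    return [tb // wc for tb, wc in zip(threadblock_shape, warp_count)]

print(warp_tile([256, 128, 64], [4, 2, 1]))  # [64, 64, 64]
print(warp_tile([128, 32, 32], [2, 1, 1]))   # [64, 32, 32]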
@@ -223,28 +221,36 @@ def GenerateConv2d_TensorOp_8816(args):
     for dst_type, dst_layout in zip(dst_types, dst_layouts):
       if dst_layout == LayoutType.TensorNC32HW32:
         tile_descriptions = [
-          TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc),
           TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc),
+          TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc),
           TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
-          TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
           TileDescription([128,  64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
-          TileDescription([ 64,  64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
-          TileDescription([ 32,  64, 64], 2, [1, 4, 1], math_inst, min_cc, max_cc),
+          TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+          TileDescription([128,  64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc),
+          TileDescription([128,  32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc),
+          TileDescription([ 64, 128, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc),
+          TileDescription([ 32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc),
         ]
+        operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1],
+                                     dst_layout, dst_type, min_cc, 128, 128, 64,
+                                     False, ImplicitGemmMode.GemmTN, True)
       else:
         assert dst_layout == LayoutType.TensorNC4HW4
         tile_descriptions = [
-          TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc),
           TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc),
+          TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc),
           TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
-          TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
           TileDescription([128,  64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
-          TileDescription([ 64,  64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
-          TileDescription([ 32,  64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+          TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+          TileDescription([128,  64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc),
+          TileDescription([128,  32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc),
+          TileDescription([ 64, 128, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc),
+          TileDescription([ 32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc),
         ]
-      operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1],
-                                   dst_layout, dst_type, min_cc, 128, 128, 64, False)
+        operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1],
+                                     dst_layout, dst_type, min_cc, 128, 128, 64, False)

   return operations

 def GenerateConv2d_TensorOp_8832(args):
@@ -279,12 +285,14 @@ def GenerateConv2d_TensorOp_8832(args):
     for dst_layout in dst_layouts:
       dst_type = math_inst.element_b
       tile_descriptions = [
-        TileDescription([256, 128, 128], 2, [4, 2, 1], math_inst, min_cc, max_cc),
+        TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc),
         TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc),
+        TileDescription([128,  64, 128], 2, [2, 1, 1], math_inst, min_cc, max_cc),
+        TileDescription([128,  64,  64], 1, [2, 1, 1], math_inst, min_cc, max_cc),
       ]
       operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1],
                                    dst_layout, dst_type, min_cc, 128, 128, 64,
-                                   True)
+                                   True, ImplicitGemmMode.GemmTN, True)

   layouts_nhwc = [
     (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32),
@@ -299,14 +307,21 @@ def GenerateConv2d_TensorOp_8832(args):
   for math_inst in math_instructions:
     for layout in layouts_nhwc:
       for dst_layout in dst_layouts_nhwc:
         dst_type = math_inst.element_b
         tile_descriptions = [
-          TileDescription([128, 32, 64], 2, [2, 1, 1], math_inst, min_cc, max_cc),
-          TileDescription([128, 64, 64], 2, [2, 1, 1], math_inst, min_cc, max_cc),
+          TileDescription([128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc),
+          TileDescription([128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc),
         ]
-        operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1],
-                                     dst_layout, dst_type, min_cc, layout[2], layout[2],
-                                     32, False, ImplicitGemmMode.GemmTn)
+        for tile in tile_descriptions:
+          operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1],
+                                       dst_layout, dst_type, min_cc, layout[2], layout[2],
+                                       32, False, ImplicitGemmMode.GemmTN, False)
+          if tile.threadblock_shape[1] == 32 or tile.threadblock_shape[1] == 64:
+            dst_align = 32 if tile.threadblock_shape[1] == 32 else 64
+            operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1],
+                                         dst_layout, dst_type, min_cc, layout[2], layout[2],
+                                         dst_align, False, ImplicitGemmMode.GemmTN, True)
   return operations

 def GenerateDeconv_Simt(args):
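The NHWC branch above now emits each tile twice: once on the ordinary shared-load path (without_shared_load=False, dst_align fixed at 32) and once on the new reordered path, whose destination alignment follows the threadblock N dimension. A condensed restatement of that rule (sketch, not generator code):

def nhwc_variants_sketch(threadblock_n):
    # Every tile gets the baseline variant.
    variants = [dict(without_shared_load=False, dst_align=32)]
    # Tiles with N of 32 or 64 also get the shared-load-free variant.
    if threadblock_n in (32, 64):
        variants.append(dict(without_shared_load=True,
                             dst_align=32 if threadblock_n == 32 else 64))
    return variants

print(nhwc_variants_sketch(64))
# [{'without_shared_load': False, 'dst_align': 32},
#  {'without_shared_load': True, 'dst_align': 64}]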
@@ -649,3 +664,4 @@ if __name__ == "__main__":
#
###################################################################################################
\ No newline at end of file
dnn/scripts/cutlass_generator/library.py

@@ -464,10 +464,10 @@ EpilogueFunctorTag = {
 ShortEpilogueNames = {
   EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: 'hswish',
   EpilogueFunctor.BiasAddLinearCombinationReluClamp: 'relu',
-  EpilogueFunctor.BiasAddLinearCombinationClamp: 'identity',
+  EpilogueFunctor.BiasAddLinearCombinationClamp: 'id',
   EpilogueFunctor.BiasAddLinearCombinationHSwish: 'hswish',
   EpilogueFunctor.BiasAddLinearCombinationRelu: 'relu',
-  EpilogueFunctor.BiasAddLinearCombination: 'identity',
+  EpilogueFunctor.BiasAddLinearCombination: 'id',
 }
@@ -482,7 +482,7 @@ class SwizzlingFunctor(enum.Enum):
   Identity4 = enum_auto()
   Identity8 = enum_auto()
   ConvFpropNCxHWx = enum_auto()
-  ConvFpropNHWC = enum_auto()
+  ConvFpropTrans = enum_auto()
   ConvDgradNCxHWx = enum_auto()

 #

@@ -492,7 +492,7 @@ SwizzlingFunctorTag = {
   SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>',
   SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>',
   SwizzlingFunctor.ConvFpropNCxHWx: 'cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle',
-  SwizzlingFunctor.ConvFpropNHWC: 'cutlass::conv::threadblock::ConvolutionFpropNHWCThreadblockSwizzle',
+  SwizzlingFunctor.ConvFpropTrans: 'cutlass::conv::threadblock::ConvolutionFpropTransThreadblockSwizzle',
   SwizzlingFunctor.ConvDgradNCxHWx: 'cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle',
 }
@@ -563,17 +563,17 @@ StrideSupportNames = {
 }

 class ImplicitGemmMode(enum.Enum):
-  GemmNt = enum_auto()
-  GemmTn = enum_auto()
+  GemmNT = enum_auto()
+  GemmTN = enum_auto()

 ImplicitGemmModeNames = {
-  ImplicitGemmMode.GemmNt: 'gemm_nt',
-  ImplicitGemmMode.GemmTn: 'gemm_tn',
+  ImplicitGemmMode.GemmNT: 'gemm_nt',
+  ImplicitGemmMode.GemmTN: 'gemm_tn',
 }

 ImplicitGemmModeTag = {
-  ImplicitGemmMode.GemmNt: 'cutlass::conv::ImplicitGemmMode::GEMM_NT',
-  ImplicitGemmMode.GemmTn: 'cutlass::conv::ImplicitGemmMode::GEMM_TN',
+  ImplicitGemmMode.GemmNT: 'cutlass::conv::ImplicitGemmMode::GEMM_NT',
+  ImplicitGemmMode.GemmTN: 'cutlass::conv::ImplicitGemmMode::GEMM_TN',
 }

###################################################################################################
dnn/scripts/cutlass_generator/list.bzl

This diff is suppressed by .gitattributes.
dnn/src/cuda/conv_bias/algo.cpp

@@ -217,56 +217,68 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() {
 #if CUDA_VERSION >= 10020
     {
         using AlgoParam = AlgoInt8NCHW32IMMAImplicitGemm::AlgoParam;
-        int8_nchw32_imma.emplace_back(AlgoParam{128, 256, 64, 64, 64, 64});
-        int8_nchw32_imma.emplace_back(AlgoParam{256, 128, 64, 64, 64, 64});
-        int8_nchw32_imma.emplace_back(AlgoParam{128, 128, 64, 64, 64, 64});
-        int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 64, 32, 64, 64});
-        int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 64, 64, 32, 64});
-        int8_nchw32_imma.emplace_back(AlgoParam{64, 64, 64, 32, 32, 64});
-        int8_nchw32_imma.emplace_back(AlgoParam{32, 64, 64, 32, 16, 64});
+        int8_nchw32_imma.emplace_back(AlgoParam{128, 256, 64, 64, 64, 64, 2});
+        int8_nchw32_imma.emplace_back(AlgoParam{256, 128, 64, 64, 64, 64, 2});
+        int8_nchw32_imma.emplace_back(AlgoParam{128, 128, 64, 64, 64, 64, 2});
+        int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 64, 64, 32, 64, 2});
+        int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 64, 32, 64, 64, 2});
+        int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 1});
+        int8_nchw32_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1});
+        int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 32, 32, 64, 32, 1});
+        int8_nchw32_imma.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 1});
     }
     {
         using AlgoParam = AlgoInt4Int4NCHW64IMMAImplicitGemm::AlgoParam;
         int4_int4_nchw64_imma.emplace_back(
-                AlgoParam{128, 128, 128, 64, 64, 128});
+                AlgoParam{128, 128, 128, 64, 64, 128, 2});
         int4_int4_nchw64_imma.emplace_back(
-                AlgoParam{256, 128, 128, 64, 64, 128});
+                AlgoParam{128, 256, 128, 64, 64, 128, 2});
+        int4_int4_nchw64_imma.emplace_back(
+                AlgoParam{128, 64, 128, 64, 64, 128, 2});
+        int4_int4_nchw64_imma.emplace_back(
+                AlgoParam{128, 64, 64, 64, 64, 64, 1});
     }
     {
         using AlgoParam = AlgoUInt4Int4NCHW64IMMAImplicitGemm::AlgoParam;
         uint4_int4_nchw64_imma.emplace_back(
-                AlgoParam{128, 128, 128, 64, 64, 128});
+                AlgoParam{128, 128, 128, 64, 64, 128, 2});
+        uint4_int4_nchw64_imma.emplace_back(
+                AlgoParam{128, 256, 128, 64, 64, 128, 2});
         uint4_int4_nchw64_imma.emplace_back(
-                AlgoParam{256, 128, 128, 64, 64, 128});
+                AlgoParam{128, 64, 128, 64, 64, 128, 2});
+        uint4_int4_nchw64_imma.emplace_back(
+                AlgoParam{128, 64, 64, 64, 64, 64, 1});
     }
     {
         using AlgoParam = AlgoInt4Int4NHWCIMMAImplicitGemm::AlgoParam;
         int4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 32, 64, 64, 32, 64, 32});
+                AlgoParam{128, 32, 64, 64, 32, 64, 1, 32});
         int4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 32, 64, 64, 32, 64, 16});
+                AlgoParam{128, 32, 64, 64, 32, 64, 1, 16});
         int4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 32, 64, 64, 32, 64, 8});
+                AlgoParam{128, 32, 64, 64, 32, 64, 1, 8});
         int4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 64, 64, 64, 64, 64, 32});
+                AlgoParam{128, 64, 64, 64, 64, 64, 1, 32});
         int4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 64, 64, 64, 64, 64, 16});
+                AlgoParam{128, 64, 64, 64, 64, 64, 1, 16});
         int4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 64, 64, 64, 64, 64, 8});
+                AlgoParam{128, 64, 64, 64, 64, 64, 1, 8});
     }
     {
         using AlgoParam = AlgoUInt4Int4NHWCIMMAImplicitGemm::AlgoParam;
         uint4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 32, 64, 64, 32, 64, 32});
+                AlgoParam{128, 32, 64, 64, 32, 64, 1, 32});
         uint4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 32, 64, 64, 32, 64, 16});
+                AlgoParam{128, 32, 64, 64, 32, 64, 1, 16});
         uint4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 32, 64, 64, 32, 64, 8});
+                AlgoParam{128, 32, 64, 64, 32, 64, 1, 8});
         uint4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 64, 64, 64, 64, 64, 32});
+                AlgoParam{128, 64, 64, 64, 64, 64, 1, 32});
         uint4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 64, 64, 64, 64, 64, 16});
+                AlgoParam{128, 64, 64, 64, 64, 64, 1, 16});
         uint4_int4_nhwc_imma.emplace_back(
-                AlgoParam{128, 64, 64, 64, 64, 64, 8});
+                AlgoParam{128, 64, 64, 64, 64, 64, 1, 8});
     }
 #endif
 }
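AlgoParam gains a trailing field in every initializer above, matching the `int stage;` member added in algo.h below: each tuning entry now pins its pipeline depth, with the wide-K tiles keeping 2 stages and the new narrow-K tiles running single-stage. A stand-in model with assumed field names (Python, for illustration only):

from dataclasses import dataclass

@dataclass
class AlgoParamSketch:
    # Threadblock tile, warp tile, and the new pipeline stage count.
    threadblock_m: int
    threadblock_n: int
    threadblock_k: int
    warp_m: int
    warp_n: int
    warp_k: int
    stage: int

int8_nchw32_examples = [
    AlgoParamSketch(128, 256, 64, 64, 64, 64, stage=2),
    AlgoParamSketch(128, 64, 32, 64, 32, 32, stage=1),
]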
@@ -279,10 +291,8 @@ void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() {
     int8_nchw4_dotprod.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32, 2});
     int8_nchw4_dotprod.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 2});
     int8_nchw4_dotprod.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 2});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{64, 64, 32, 64, 32, 32, 2});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{32, 64, 32, 32, 64, 32, 2});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{64, 32, 32, 64, 32, 32, 2});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{32, 32, 32, 32, 32, 32, 2});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{16, 128, 16, 16, 128, 16, 1});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 2});
 }
dnn/src/cuda/conv_bias/algo.h

@@ -723,6 +723,7 @@ public:
         int warp_m;
         int warp_n;
         int warp_k;
+        int stage;
     };
     AlgoInt8NCHW32IMMAImplicitGemm(AlgoParam algo_param)
             : m_algo_param{algo_param} {

@@ -770,6 +771,7 @@ public:
         int warp_m;
         int warp_n;
         int warp_k;
+        int stage;
     };
     AlgoInt4NCHW64IMMAImplicitGemmBase(AlgoParam algo_param)

@@ -897,6 +899,7 @@ public:
         int warp_m;
         int warp_n;
         int warp_k;
+        int stage;
         int access_size;
     };
dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh

@@ -38,7 +38,7 @@ void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32(
         const convolution::ConvParam& param, uint32_t nonlinear_mode,
         float alpha, float beta, float gamma, float scale,
         const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
-        cudaStream_t stream);
+        int stages, cudaStream_t stream);

 template <bool NeedLoadFromConstMem>
 void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4(

@@ -47,7 +47,7 @@ void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4(
         const convolution::ConvParam& param, uint32_t nonlinear_mode,
         float alpha, float beta, float gamma, float scale,
         const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
-        cudaStream_t stream);
+        int stages, cudaStream_t stream);

 template <bool NeedLoadFromConstMem>
 void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4(

@@ -83,7 +83,7 @@ void do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64(
         const convolution::ConvParam& param, uint32_t nonlinear_mode,
         float alpha, float beta, float gamma, float scale,
         const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
-        cudaStream_t stream);
+        int stages, cudaStream_t stream);

 template <bool NeedLoadFromConstMem>
 void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64(

@@ -92,7 +92,7 @@ void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64(
         const convolution::ConvParam& param, uint32_t nonlinear_mode,
         float alpha, float beta, float gamma, float delta, float theta,
         float scale, uint8_t src_zero_point,
-        const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
-        cudaStream_t stream);
+        const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
+        int stages, cudaStream_t stream);

 template <bool signedness>
 void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc(

@@ -110,7 +110,7 @@ void do_conv_bias_int4_int4_implicit_gemm_imma_nhwc(
         const convolution::ConvParam& param, uint32_t nonlinear_mode,
         float alpha, float beta, float gamma, float scale,
         const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
-        const int32_t access_size, cudaStream_t stream);
+        const int32_t access_size, int stages, cudaStream_t stream);

 template <bool NeedLoadFromConstMem>
 void do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc(

@@ -119,7 +119,7 @@ void do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc(
         const convolution::ConvParam& param, uint32_t nonlinear_mode,
         float alpha, float beta, float gamma, float delta, float theta,
         float scale, uint8_t src_zero_point,
-        const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
-        const int32_t access_size, cudaStream_t stream);
+        const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
+        const int32_t access_size, int stages, cudaStream_t stream);

 }  // namespace cutlass_wrapper
dnn/src/cuda/conv_bias/cutlass_convolution_wrapper_int4.cu (new file, 0 → 100644)
/**
* \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#if !MEGDNN_TEGRA_X1
#include "cutlass/convolution/device/convolution.h"
#endif
#include "src/common/opr_param_defs_enumv.cuh"
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
#pragma GCC diagnostic pop
using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;
/* ====== cutlass kernel wrapper for int4 x int4 nchw64 layout ====== */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64(
                const int8_t* /* d_src */, const int8_t* /* d_filter */,
                const int32_t* /* d_bias */, const int8_t* /* d_z */,
                int8_t* /* d_dst */, int* /* workspace */,
                const convolution::ConvParam& /* param */,
                uint32_t /* nonlinear_mode */, float /* alpha */,
                float /* beta */, float /* gamma */, float /* scale */,
                const GemmCoord& /* threadblock_shape */,
                const GemmCoord& /* warp_shape */, int /* stages */,
                cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64(
                const int8_t* d_src, const int8_t* d_filter,
                const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
                int* workspace, const convolution::ConvParam& param,
                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                float scale, const GemmCoord& threadblock_shape,
                const GemmCoord& warp_shape, int stages, cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
threadblock_k_, warp_m_, warp_n_, \
warp_k_, stage_) \
if (threadblock_shape.m() == threadblock_m_ && \
threadblock_shape.n() == threadblock_n_ && \
threadblock_shape.k() == threadblock_k_ && \
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
warp_shape.k() == warp_k_ && stages == stage_) { \
using ThreadBlockShape = \
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
threadblock_k_>; \
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \
using Convolution = cutlass::conv::device::Convolution< \
cutlass::int4b_t, cutlass::layout::TensorNCxHWx<64>, \
cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \
ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \
cutlass::layout::TensorNCxHWx<64>, int32_t, \
cutlass::conv::ConvType::kConvolution, \
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
cutlass::conv::threadblock:: \
ConvolutionFpropTransThreadblockSwizzle, \
stage_, 32, 32, NeedLoadFromConstMem, \
cutlass::arch::OpMultiplyAddSaturate, \
cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \
typename Convolution::ConvolutionParameter conv_param( \
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
return cutlass_convolution_wrapper<Convolution>( \
reinterpret_cast<const cutlass::int4b_t*>(d_src), \
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \
reinterpret_cast<const cutlass::int4b_t*>(d_z), \
reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, \
conv_param, epilogue, stream); \
}
#define DISPATCH_KERNEL \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128, 2); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 128, 64, 64, 128, 2); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 128, 64, 64, 128, 2); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1); \
megdnn_assert(false, \
"unsupported threadblock shape (%dx%dx%d) and warp shape " \
"(%dx%dx%d)", \
threadblock_shape.m(), threadblock_shape.n(), \
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
warp_shape.k());
    using ElementOutput = cutlass::int4b_t;
    using ElementAccumulator = int32_t;
    using ElementBias = int32_t;
    using ElementCompute = float;
    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
    switch (nonlinear_mode) {
        case NonlineMode::IDENTITY: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationClamp<ElementOutput, 16,
                                                  ElementAccumulator,
                                                  ElementBias, ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma};
            DISPATCH_KERNEL;
        }
        case NonlineMode::RELU: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationReluClamp<ElementOutput, 16,
                                                      ElementAccumulator,
                                                      ElementBias,
                                                      ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
            DISPATCH_KERNEL;
        }
        case NonlineMode::H_SWISH: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationHSwishClamp<ElementOutput, 16,
                                                        ElementAccumulator,
                                                        ElementBias,
                                                        ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
            DISPATCH_KERNEL;
        }
        default:
            megdnn_assert(false,
                          "unsupported nonlinear mode for conv bias operator");
    }
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif
#define INST(need_load_from_const_mem) \
template void megdnn::cuda::cutlass_wrapper:: \
do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64< \
need_load_from_const_mem>( \
const int8_t* d_src, const int8_t* d_filter, \
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
int* workspace, const convolution::ConvParam& param, \
uint32_t nonlinear_mode, float alpha, float beta, \
float gamma, float scale, \
const GemmCoord& threadblock_shape, \
const GemmCoord& warp_shape, int stages, \
cudaStream_t stream);
INST(true);
#undef INST
/* ====== cutlass kernel wrapper for uint4 x int4 nchw64 layout ====== */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64(
                const uint8_t* /* d_src */, const int8_t* /* d_filter */,
                const int32_t* /* d_bias */, const uint8_t* /* d_z */,
                uint8_t* /* d_dst */, int* /* workspace */,
                const convolution::ConvParam& /* param */,
                uint32_t /* nonlinear_mode */, float /* alpha */,
                float /* beta */, float /* gamma */, float /* delta */,
                float /* theta */, float /* scale */,
                uint8_t /* src_zero_point */,
                const GemmCoord& /* threadblock_shape */,
                const GemmCoord& /* warp_shape */, int /* stages */,
                cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64(
                const uint8_t* d_src, const int8_t* d_filter,
                const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst,
                int* workspace, const convolution::ConvParam& param,
                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                float delta, float theta, float /* scale */,
                uint8_t src_zero_point, const GemmCoord& threadblock_shape,
                const GemmCoord& warp_shape, int stages, cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
threadblock_k_, warp_m_, warp_n_, \
warp_k_, stage_) \
if (threadblock_shape.m() == threadblock_m_ && \
threadblock_shape.n() == threadblock_n_ && \
threadblock_shape.k() == threadblock_k_ && \
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
warp_shape.k() == warp_k_ && stages == stage_) { \
using ThreadBlockShape = \
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
threadblock_k_>; \
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \
using Convolution = cutlass::conv::device::Convolution< \
cutlass::uint4b_t, cutlass::layout::TensorNCxHWx<64>, \
cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \
ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \
cutlass::layout::TensorNCxHWx<64>, int32_t, \
cutlass::conv::ConvType::kConvolution, \
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
cutlass::conv::threadblock:: \
ConvolutionFpropTransThreadblockSwizzle, \
stage_, 32, 32, NeedLoadFromConstMem, \
cutlass::arch::OpMultiplyAddSaturate, \
cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \
typename Convolution::ConvolutionParameter conv_param( \
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
return cutlass_convolution_wrapper<Convolution>( \
reinterpret_cast<const cutlass::uint4b_t*>(d_src), \
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \
reinterpret_cast<const cutlass::uint4b_t*>(d_z), \
reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \
conv_param, epilogue, stream, {src_zero_point}); \
}
#define DISPATCH_KERNEL \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128, 2); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 128, 64, 64, 128, 2); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 128, 64, 64, 128, 2); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1); \
megdnn_assert(false, \
"unsupported threadblock shape (%dx%dx%d) and warp shape " \
"(%dx%dx%d)", \
threadblock_shape.m(), threadblock_shape.n(), \
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
warp_shape.k());
    using ElementOutput = cutlass::uint4b_t;
    using ElementAccumulator = int32_t;
    using ElementBias = int32_t;
    using ElementCompute = float;
    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
    switch (nonlinear_mode) {
        case NonlineMode::IDENTITY: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationClamp<ElementOutput, 16,
                                                  ElementAccumulator,
                                                  ElementBias, ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma,
                                                 delta + theta};
            DISPATCH_KERNEL;
        }
        case NonlineMode::RELU: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationReluClamp<ElementOutput, 16,
                                                      ElementAccumulator,
                                                      ElementBias,
                                                      ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0, delta,
                                                 theta};
            DISPATCH_KERNEL;
        }
        default:
            megdnn_assert(false,
                          "unsupported nonlinear mode for conv bias operator");
    }
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif
#define INST(need_load_from_const_mem) \
template void megdnn::cuda::cutlass_wrapper:: \
do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64< \
need_load_from_const_mem>( \
const uint8_t* d_src, const int8_t* d_filter, \
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \
int* workspace, const convolution::ConvParam& param, \
uint32_t nonlinear_mode, float alpha, float beta, \
float gamma, float delta, float theta, float scale, \
uint8_t src_zero_point, \
const GemmCoord& threadblock_shape, \
const GemmCoord& warp_shape, int stages, \
cudaStream_t stream);
INST(true);
#undef INST
/* ====== cutlass kernel wrapper for int4 x int4 nhwc layout ====== */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_nhwc(
        const int8_t* /* d_src */, const int8_t* /* d_filter */,
        const int32_t* /* d_bias */, const int8_t* /* d_z */,
        int8_t* /* d_dst */, int* /* workspace */,
        const convolution::ConvParam& /* param */,
        uint32_t /* nonlinear_mode */, float /* alpha */, float /* beta */,
        float /* gamma */, float /* scale */,
        const GemmCoord& /* threadblock_shape */,
        const GemmCoord& /* warp_shape */, const int32_t /* access_size */,
        int /* stages */, cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_nhwc(
        const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias,
        const int8_t* d_z, int8_t* d_dst, int* workspace,
        const convolution::ConvParam& param, uint32_t nonlinear_mode,
        float alpha, float beta, float gamma, float scale,
        const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
        const int32_t access_size, int stages, cudaStream_t stream) {
    bool without_shared_load = ((param.co % threadblock_shape.n() == 0) &&
                                (threadblock_shape.n() == 32 ||
                                 threadblock_shape.n() == 64));
    int out_elements_per_access =
            without_shared_load ? threadblock_shape.n() / 4 : 8;
#define RUN_CUTLASS_WRAPPER(stage_, access_size_, without_shared_load_) \
using Convolution = cutlass::conv::device::Convolution< \
cutlass::int4b_t, cutlass::layout::TensorNHWC, cutlass::int4b_t, \
cutlass::layout::TensorNCxHWx<access_size_>, ElementOutput, \
cutlass::layout::TensorNHWC, int32_t, cutlass::layout::TensorNHWC, \
int32_t, cutlass::conv::ConvType::kConvolution, \
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
cutlass::conv::threadblock:: \
ConvolutionFpropTransThreadblockSwizzle, \
stage_, access_size_, access_size_, NeedLoadFromConstMem, \
cutlass::arch::OpMultiplyAddSaturate, \
cutlass::conv::ImplicitGemmMode::GEMM_TN, without_shared_load_>; \
typename Convolution::ConvolutionParameter conv_param( \
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
return cutlass_convolution_wrapper<Convolution>( \
reinterpret_cast<const cutlass::int4b_t*>(d_src), \
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \
reinterpret_cast<const cutlass::int4b_t*>(d_z), \
reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, conv_param, \
epilogue, stream);
#define DISPATCH_KERNEL_WITH_TILE_SHAPE( \
threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \
warp_k_, stage_, access_size_, out_elements_per_access_, \
without_shared_load_) \
if (threadblock_shape.m() == threadblock_m_ && \
threadblock_shape.n() == threadblock_n_ && \
threadblock_shape.k() == threadblock_k_ && \
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
warp_shape.k() == warp_k_ && stages == stage_ && \
access_size == access_size_ && \
out_elements_per_access == out_elements_per_access_ && \
without_shared_load == without_shared_load_) { \
using ThreadBlockShape = \
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
threadblock_k_>; \
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \
using ElementOutput = cutlass::int4b_t; \
using ElementAccumulator = int32_t; \
using ElementBias = int32_t; \
using ElementCompute = float; \
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; \
switch (nonlinear_mode) { \
case NonlineMode::IDENTITY: { \
using EpilogueOp = cutlass::epilogue::thread:: \
BiasAddLinearCombinationClamp< \
ElementOutput, out_elements_per_access_, \
ElementAccumulator, ElementBias, \
ElementCompute>; \
typename EpilogueOp::Params epilogue{alpha, beta, gamma}; \
RUN_CUTLASS_WRAPPER(stage_, access_size_, \
without_shared_load_); \
} \
case NonlineMode::RELU: { \
using EpilogueOp = cutlass::epilogue::thread:: \
BiasAddLinearCombinationReluClamp< \
ElementOutput, out_elements_per_access_, \
ElementAccumulator, ElementBias, \
ElementCompute>; \
typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; \
RUN_CUTLASS_WRAPPER(stage_, access_size_, \
without_shared_load_); \
} \
case NonlineMode::H_SWISH: { \
using EpilogueOp = cutlass::epilogue::thread:: \
BiasAddLinearCombinationHSwishClamp< \
ElementOutput, out_elements_per_access_, \
ElementAccumulator, ElementBias, \
ElementCompute>; \
typename EpilogueOp::Params epilogue{alpha, beta, gamma, \
scale}; \
RUN_CUTLASS_WRAPPER(stage_, access_size_, \
without_shared_load_); \
} \
default: \
megdnn_assert( \
false, \
"unsupported nonlinear mode for conv bias operator"); \
} \
}
#define DISPATCH_KERNEL \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, false); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, false); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, false); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 8, false); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 8, false); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 8, false); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, true); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, true); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, true); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 16, true); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 16, true); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 16, true); \
megdnn_assert(false, \
"unsupported threadblock shape (%dx%dx%d) and warp shape " \
"(%dx%dx%d) and access_size (%d)", \
threadblock_shape.m(), threadblock_shape.n(), \
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
warp_shape.k(), access_size);
    DISPATCH_KERNEL;
#undef RUN_CUTLASS_WRAPPER
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif
#define INST(need_load_from_const_mem) \
template void megdnn::cuda::cutlass_wrapper:: \
do_conv_bias_int4_int4_implicit_gemm_imma_nhwc< \
need_load_from_const_mem>( \
const int8_t* d_src, const int8_t* d_filter, \
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
int* workspace, const convolution::ConvParam& param, \
uint32_t nonlinear_mode, float alpha, float beta, \
float gamma, float scale, \
const GemmCoord& threadblock_shape, \
const GemmCoord& warp_shape, const int32_t access_size, \
int stages, cudaStream_t stream);
INST(true);
INST(false);
#undef INST
/* ====== cutlass kernel wrapper for uint4 x int4 nhwc layout ====== */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc(
        const uint8_t* /* d_src */, const int8_t* /* d_filter */,
        const int32_t* /* d_bias */, const uint8_t* /* d_z */,
        uint8_t* /* d_dst */, int* /* workspace */,
        const convolution::ConvParam& /* param */,
        uint32_t /* nonlinear_mode */, float /* alpha */, float /* beta */,
        float /* gamma */, float /* delta */, float /* theta */,
        float /* scale */, uint8_t /* src_zero_point */,
        const GemmCoord& /* threadblock_shape */,
        const GemmCoord& /* warp_shape */, const int32_t /* access_size */,
        int /* stages */, cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc(
        const uint8_t* d_src, const int8_t* d_filter, const int32_t* d_bias,
        const uint8_t* d_z, uint8_t* d_dst, int* workspace,
        const convolution::ConvParam& param, uint32_t nonlinear_mode,
        float alpha, float beta, float gamma, float delta, float theta,
        float /* scale */, uint8_t src_zero_point,
        const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
        const int32_t access_size, int stages, cudaStream_t stream) {
    bool without_shared_load = ((param.co % threadblock_shape.n() == 0) &&
                                (threadblock_shape.n() == 32 ||
                                 threadblock_shape.n() == 64));
    int out_elements_per_access =
            without_shared_load ? threadblock_shape.n() / 4 : 8;
#define RUN_CUTLASS_WRAPPER(stage_, access_size_, without_shared_load_) \
using Convolution = cutlass::conv::device::Convolution< \
cutlass::uint4b_t, cutlass::layout::TensorNHWC, cutlass::int4b_t, \
cutlass::layout::TensorNCxHWx<access_size_>, ElementOutput, \
cutlass::layout::TensorNHWC, int32_t, cutlass::layout::TensorNHWC, \
int32_t, cutlass::conv::ConvType::kConvolution, \
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
cutlass::conv::threadblock:: \
ConvolutionFpropTransThreadblockSwizzle, \
stage_, access_size_, access_size_, NeedLoadFromConstMem, \
cutlass::arch::OpMultiplyAddSaturate, \
cutlass::conv::ImplicitGemmMode::GEMM_TN, without_shared_load_>; \
typename Convolution::ConvolutionParameter conv_param( \
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
return cutlass_convolution_wrapper<Convolution>( \
reinterpret_cast<const cutlass::uint4b_t*>(d_src), \
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \
reinterpret_cast<const cutlass::uint4b_t*>(d_z), \
reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \
conv_param, epilogue, stream, {src_zero_point});
#define DISPATCH_KERNEL_WITH_TILE_SHAPE( \
threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \
warp_k_, stage_, access_size_, out_elements_per_access_, \
without_shared_load_) \
if (threadblock_shape.m() == threadblock_m_ && \
threadblock_shape.n() == threadblock_n_ && \
threadblock_shape.k() == threadblock_k_ && \
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
warp_shape.k() == warp_k_ && stages == stage_ && \
access_size == access_size_ && \
out_elements_per_access == out_elements_per_access_ && \
without_shared_load == without_shared_load_) { \
using ThreadBlockShape = \
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
threadblock_k_>; \
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \
using ElementOutput = cutlass::uint4b_t; \
using ElementAccumulator = int32_t; \
using ElementBias = int32_t; \
using ElementCompute = float; \
using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; \
switch (nonlinear_mode) { \
case NonlineMode::IDENTITY: { \
using EpilogueOp = cutlass::epilogue::thread:: \
BiasAddLinearCombinationClamp< \
ElementOutput, out_elements_per_access_, \
ElementAccumulator, ElementBias, \
ElementCompute>; \
typename EpilogueOp::Params epilogue{alpha, beta, gamma, \
delta + theta}; \
RUN_CUTLASS_WRAPPER(stage_, access_size_, \
without_shared_load_); \
} \
case NonlineMode::RELU: { \
using EpilogueOp = cutlass::epilogue::thread:: \
BiasAddLinearCombinationReluClamp< \
ElementOutput, out_elements_per_access_, \
ElementAccumulator, ElementBias, \
ElementCompute>; \
typename EpilogueOp::Params epilogue{alpha, beta, gamma, \
0, delta, theta}; \
RUN_CUTLASS_WRAPPER(stage_, access_size_, \
without_shared_load_); \
} \
default: \
megdnn_assert( \
false, \
"unsupported nonlinear mode for conv bias operator"); \
} \
}
#define DISPATCH_KERNEL \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, false); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, false); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, false); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 8, false); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 8, false); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 8, false); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, true); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, true); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, true); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 16, true); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 16, true); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 16, true); \
megdnn_assert(false, \
"unsupported threadblock shape (%dx%dx%d) and warp shape " \
"(%dx%dx%d) and access_size (%d)", \
threadblock_shape.m(), threadblock_shape.n(), \
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
warp_shape.k(), access_size);
    DISPATCH_KERNEL;
#undef RUN_CUTLASS_WRAPPER
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif
#define INST(need_load_from_const_mem) \
template void megdnn::cuda::cutlass_wrapper:: \
do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc< \
need_load_from_const_mem>( \
const uint8_t* d_src, const int8_t* d_filter, \
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \
int* workspace, const convolution::ConvParam& param, \
uint32_t nonlinear_mode, float alpha, float beta, \
float gamma, float delta, float theta, float scale, \
uint8_t src_zero_point, \
const GemmCoord& threadblock_shape, \
const GemmCoord& warp_shape, const int32_t access_size, \
int stages, cudaStream_t stream);
INST(true);
INST(false);
#undef INST
// vim: syntax=cuda.doxygen
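Unlike the NCHW64 wrappers, the NHWC wrappers in this new file pick the kernel variant at run time: the shared-load-free epilogue is used only when the output channel count tiles the threadblock N exactly and N is 32 or 64, and the per-access output width follows from that choice. Restated as a Python sketch that mirrors the C++ predicate above:

def nhwc_kernel_choice(co, threadblock_n):
    # The reordered, shared-load-free epilogue is only legal when the output
    # channels divide evenly into threadblock columns of width 32 or 64.
    without_shared_load = (co % threadblock_n == 0
                           and threadblock_n in (32, 64))
    out_elements_per_access = threadblock_n // 4 if without_shared_load else 8
    return without_shared_load, out_elements_per_access

print(nhwc_kernel_choice(co=256, threadblock_n=64))  # (True, 16)
print(nhwc_kernel_choice(co=48, threadblock_n=64))   # (False, 8)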
dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu → dnn/src/cuda/conv_bias/cutlass_convolution_wrapper_int8.cu (renamed)
@@ -38,7 +38,8 @@ void megdnn::cuda::cutlass_wrapper::
                 uint32_t /* nonlinear_mode */, float /* alpha */,
                 float /* beta */, float /* gamma */, float /* scale */,
                 const GemmCoord& /* threadblock_shape */,
-                const GemmCoord& /* warp_shape */,
-                cudaStream_t /* stream */) {}
+                const GemmCoord& /* warp_shape */, int /* stages */,
+                cudaStream_t /* stream */) {}
 #else
 template <bool NeedLoadFromConstMem>
 void megdnn::cuda::cutlass_wrapper::

@@ -48,15 +49,15 @@ void megdnn::cuda::cutlass_wrapper::
                 int* workspace, const convolution::ConvParam& param,
                 uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                 float scale, const GemmCoord& threadblock_shape,
-                const GemmCoord& warp_shape, cudaStream_t stream) {
+                const GemmCoord& warp_shape, int stages, cudaStream_t stream) {
 #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_,      \
                                         threadblock_k_, warp_m_, warp_n_,   \
-                                        warp_k_)                            \
+                                        warp_k_, stage_)                    \
     if (threadblock_shape.m() == threadblock_m_ &&                          \
         threadblock_shape.n() == threadblock_n_ &&                          \
         threadblock_shape.k() == threadblock_k_ &&                          \
         warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ &&           \
-        warp_shape.k() == warp_k_) {                                        \
+        warp_shape.k() == warp_k_ && stages == stage_) {                    \
         using ThreadBlockShape =                                            \
                 cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_,    \
                                          threadblock_k_>;                   \

@@ -71,8 +72,10 @@ void megdnn::cuda::cutlass_wrapper::
                 cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,        \
                 ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp,  \
                 cutlass::conv::threadblock::                                \
-                        ConvolutionFpropNCxHWxThreadblockSwizzle,           \
-                2, 16, 16, NeedLoadFromConstMem>;                           \
+                        ConvolutionFpropTransThreadblockSwizzle,            \
+                stage_, 16, 16, NeedLoadFromConstMem,                       \
+                cutlass::arch::OpMultiplyAddSaturate,                       \
+                cutlass::conv::ImplicitGemmMode::GEMM_TN, true>;            \
         typename Convolution::ConvolutionParameter conv_param(              \
                 param.n, param.hi, param.wi, param.ci, param.co, param.fh,  \
                 param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \

@@ -82,13 +85,15 @@ void megdnn::cuda::cutlass_wrapper::
                 epilogue, stream);                                          \
     }
 #define DISPATCH_KERNEL                                                     \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64);              \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64);              \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64);              \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64);               \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64);               \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 64, 32, 32, 64);                \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 64, 32, 16, 64);                \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64, 2);           \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64, 2);           \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64, 2);           \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64, 2);            \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64, 2);            \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 1);            \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 1);            \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 32, 64, 32, 1);            \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 1);            \
     megdnn_assert(false,                                                    \
                   "unsupported threadblock shape (%dx%dx%d) and warp shape " \
                   "(%dx%dx%d)",                                             \

@@ -144,7 +149,8 @@ void megdnn::cuda::cutlass_wrapper::
                     uint32_t nonlinear_mode, float alpha, float beta,       \
                     float gamma, float scale,                               \
                     const GemmCoord& threadblock_shape,                     \
-                    const GemmCoord& warp_shape, cudaStream_t stream);
+                    const GemmCoord& warp_shape, int stages,                \
+                    cudaStream_t stream);
 INST(true);
 INST(false);
 #undef INST
@@ -162,7 +168,8 @@ void megdnn::cuda::cutlass_wrapper::
uint32_t
/* nonlinear_mode */
,
float
/* alpha */
,
float
/* beta */
,
float
/* gamma */
,
float
/* scale */
,
const
GemmCoord
&
/* threadblock_shape */
,
const
GemmCoord
&
/* warp_shape */
,
cudaStream_t
/* stream */
)
{}
const
GemmCoord
&
/* warp_shape */
,
int
/* stages */
,
cudaStream_t
/* stream */
)
{}
#else
template
<
bool
NeedLoadFromConstMem
>
void
megdnn
::
cuda
::
cutlass_wrapper
::
...
...
@@ -172,15 +179,15 @@ void megdnn::cuda::cutlass_wrapper::
int
*
workspace
,
const
convolution
::
ConvParam
&
param
,
uint32_t
nonlinear_mode
,
float
alpha
,
float
beta
,
float
gamma
,
float
scale
,
const
GemmCoord
&
threadblock_shape
,
const
GemmCoord
&
warp_shape
,
cudaStream_t
stream
)
{
const
GemmCoord
&
warp_shape
,
int
stages
,
cudaStream_t
stream
)
{
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
threadblock_k_, warp_m_, warp_n_, \
warp_k_
)
\
warp_k_
, stage_)
\
if (threadblock_shape.m() == threadblock_m_ && \
threadblock_shape.n() == threadblock_n_ && \
threadblock_shape.k() == threadblock_k_ && \
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
warp_shape.k() == warp_k_
) {
\
warp_shape.k() == warp_k_
&& stages == stage_) {
\
using ThreadBlockShape = \
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
threadblock_k_>; \
...
...
@@ -196,7 +203,7 @@ void megdnn::cuda::cutlass_wrapper::
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
cutlass::conv::threadblock:: \
ConvolutionFpropNCxHWxThreadblockSwizzle, \
2, 16, 16, NeedLoadFromConstMem>;
\
stage_, 16, 16, NeedLoadFromConstMem>;
\
typename Convolution::ConvolutionParameter conv_param( \
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
...
...
@@ -206,13 +213,15 @@ void megdnn::cuda::cutlass_wrapper::
epilogue, stream); \
}
#define DISPATCH_KERNEL \
DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 64, 32, 32, 64); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 64, 16, 32, 64); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64, 2); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64, 2); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64, 2); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64, 2); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64, 2); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 1); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 1); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 32, 64, 32, 1); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 1); \
megdnn_assert(false, \
"unsupported threadblock shape (%dx%dx%d) and warp shape " \
"(%dx%dx%d)", \
...
...
@@ -268,7 +277,8 @@ void megdnn::cuda::cutlass_wrapper::
uint32_t nonlinear_mode, float alpha, float beta, \
float gamma, float scale, \
const GemmCoord& threadblock_shape, \
const GemmCoord& warp_shape, cudaStream_t stream);
const GemmCoord& warp_shape, int stages, \
cudaStream_t stream);
INST(true);
INST(false);
#undef INST
...
...
@@ -337,10 +347,8 @@ void megdnn::cuda::cutlass_wrapper::
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \
megdnn_assert(false, \
...
...
@@ -468,10 +476,8 @@ void megdnn::cuda::cutlass_wrapper::
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \
megdnn_assert(false, \
...
...
@@ -599,10 +605,8 @@ void megdnn::cuda::cutlass_wrapper::
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \
megdnn_assert(false, \
"unsupported threadblock shape (%dx%dx%d) and warp shape " \
"(%dx%dx%d)", \
...
...
@@ -664,246 +668,6 @@ INST(true);
INST(false);
#undef INST
/* ====== cutlass kernel wrapper for int4 x int4 nchw64 layout ====== */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64(
                const int8_t* /* d_src */, const int8_t* /* d_filter */,
                const int32_t* /* d_bias */, const int8_t* /* d_z */,
                int8_t* /* d_dst */, int* /* workspace */,
                const convolution::ConvParam& /* param */,
                uint32_t /* nonlinear_mode */, float /* alpha */,
                float /* beta */, float /* gamma */, float /* scale */,
                const GemmCoord& /* threadblock_shape */,
                const GemmCoord& /* warp_shape */,
                cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64(
                const int8_t* d_src, const int8_t* d_filter,
                const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
                int* workspace, const convolution::ConvParam& param,
                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                float scale, const GemmCoord& threadblock_shape,
                const GemmCoord& warp_shape, cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
threadblock_k_, warp_m_, warp_n_, \
warp_k_) \
if (threadblock_shape.m() == threadblock_m_ && \
threadblock_shape.n() == threadblock_n_ && \
threadblock_shape.k() == threadblock_k_ && \
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
warp_shape.k() == warp_k_) { \
using ThreadBlockShape = \
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
threadblock_k_>; \
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \
using Convolution = cutlass::conv::device::Convolution< \
cutlass::int4b_t, cutlass::layout::TensorNCxHWx<64>, \
cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \
ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \
cutlass::layout::TensorNCxHWx<64>, int32_t, \
cutlass::conv::ConvType::kConvolution, \
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
cutlass::conv::threadblock:: \
ConvolutionFpropNCxHWxThreadblockSwizzle, \
2, 32, 32, NeedLoadFromConstMem>; \
typename Convolution::ConvolutionParameter conv_param( \
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
return cutlass_convolution_wrapper<Convolution>( \
reinterpret_cast<const cutlass::int4b_t*>(d_src), \
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \
reinterpret_cast<const cutlass::int4b_t*>(d_z), \
reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, \
conv_param, epilogue, stream); \
}
#define DISPATCH_KERNEL \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 128, 64, 64, 128); \
megdnn_assert(false, \
"unsupported threadblock shape (%dx%dx%d) and warp shape " \
"(%dx%dx%d)", \
threadblock_shape.m(), threadblock_shape.n(), \
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
warp_shape.k());
    using ElementOutput = cutlass::int4b_t;
    using ElementAccumulator = int32_t;
    using ElementBias = int32_t;
    using ElementCompute = float;
    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
    switch (nonlinear_mode) {
        case NonlineMode::IDENTITY: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationClamp<ElementOutput, 16,
                                                  ElementAccumulator,
                                                  ElementBias, ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma};
            DISPATCH_KERNEL;
        }
        case NonlineMode::RELU: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationReluClamp<ElementOutput, 16,
                                                      ElementAccumulator,
                                                      ElementBias,
                                                      ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
            DISPATCH_KERNEL;
        }
        case NonlineMode::H_SWISH: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationHSwishClamp<ElementOutput, 16,
                                                        ElementAccumulator,
                                                        ElementBias,
                                                        ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
            DISPATCH_KERNEL;
        }
        default:
            megdnn_assert(false,
                          "unsupported nonlinear mode for conv bias operator");
    }
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif
#define INST(need_load_from_const_mem) \
template void megdnn::cuda::cutlass_wrapper:: \
do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64< \
need_load_from_const_mem>( \
const int8_t* d_src, const int8_t* d_filter, \
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
int* workspace, const convolution::ConvParam& param, \
uint32_t nonlinear_mode, float alpha, float beta, \
float gamma, float scale, \
const GemmCoord& threadblock_shape, \
const GemmCoord& warp_shape, cudaStream_t stream);
INST(true);
#undef INST
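The three NonlineMode branches above differ only in the epilogue functor and its Params pack. A scalar model of what those packs mean, as an illustrative reading of the initializer lists rather than the CUTLASS implementation:

#include <algorithm>

// identity: {alpha, beta, gamma}        -> alpha*acc + beta*bias + gamma*z
// relu:     {alpha, beta, gamma, 0}     -> max(identity, threshold = 0)
// h_swish:  {alpha, beta, gamma, scale} -> h-swish applied at output scale
float epilogue_identity(float acc, float bias, float z, float alpha,
                        float beta, float gamma) {
    return alpha * acc + beta * bias + gamma * z;
}

float epilogue_relu(float acc, float bias, float z, float alpha, float beta,
                    float gamma, float threshold) {
    return std::max(epilogue_identity(acc, bias, z, alpha, beta, gamma),
                    threshold);
}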
/* ====== cutlass kernel wrapper for uint4 x int4 nchw64 layout ====== */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64(
                const uint8_t* /* d_src */, const int8_t* /* d_filter */,
                const int32_t* /* d_bias */, const uint8_t* /* d_z */,
                uint8_t* /* d_dst */, int* /* workspace */,
                const convolution::ConvParam& /* param */,
                uint32_t /* nonlinear_mode */, float /* alpha */,
                float /* beta */, float /* gamma */, float /* delta */,
                float /* theta */, float /* scale */,
                uint8_t /* src_zero_point */,
                const GemmCoord& /* threadblock_shape */,
                const GemmCoord& /* warp_shape */,
                cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64(
                const uint8_t* d_src, const int8_t* d_filter,
                const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst,
                int* workspace, const convolution::ConvParam& param,
                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                float delta, float theta, float /* scale */,
                uint8_t src_zero_point, const GemmCoord& threadblock_shape,
                const GemmCoord& warp_shape, cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
threadblock_k_, warp_m_, warp_n_, \
warp_k_) \
if (threadblock_shape.m() == threadblock_m_ && \
threadblock_shape.n() == threadblock_n_ && \
threadblock_shape.k() == threadblock_k_ && \
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
warp_shape.k() == warp_k_) { \
using ThreadBlockShape = \
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
threadblock_k_>; \
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \
using Convolution = cutlass::conv::device::Convolution< \
cutlass::uint4b_t, cutlass::layout::TensorNCxHWx<64>, \
cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \
ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \
cutlass::layout::TensorNCxHWx<64>, int32_t, \
cutlass::conv::ConvType::kConvolution, \
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
cutlass::conv::threadblock:: \
ConvolutionFpropNCxHWxThreadblockSwizzle, \
2, 32, 32, NeedLoadFromConstMem>; \
typename Convolution::ConvolutionParameter conv_param( \
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
return cutlass_convolution_wrapper<Convolution>( \
reinterpret_cast<const cutlass::uint4b_t*>(d_src), \
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \
reinterpret_cast<const cutlass::uint4b_t*>(d_z), \
reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \
conv_param, epilogue, stream, {src_zero_point}); \
}
#define DISPATCH_KERNEL \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 128, 64, 64, 128); \
megdnn_assert(false, \
"unsupported threadblock shape (%dx%dx%d) and warp shape " \
"(%dx%dx%d)", \
threadblock_shape.m(), threadblock_shape.n(), \
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
warp_shape.k());
    using ElementOutput = cutlass::uint4b_t;
    using ElementAccumulator = int32_t;
    using ElementBias = int32_t;
    using ElementCompute = float;
    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
    switch (nonlinear_mode) {
        case NonlineMode::IDENTITY: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationClamp<ElementOutput, 16,
                                                  ElementAccumulator,
                                                  ElementBias, ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma,
                                                 delta + theta};
            DISPATCH_KERNEL;
        }
        case NonlineMode::RELU: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationReluClamp<ElementOutput, 16,
                                                      ElementAccumulator,
                                                      ElementBias,
                                                      ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0,
                                                 delta, theta};
            DISPATCH_KERNEL;
        }
        default:
            megdnn_assert(false,
                          "unsupported nonlinear mode for conv bias operator");
    }
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif
#define INST(need_load_from_const_mem) \
template void megdnn::cuda::cutlass_wrapper:: \
do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64< \
need_load_from_const_mem>( \
const uint8_t* d_src, const int8_t* d_filter, \
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \
int* workspace, const convolution::ConvParam& param, \
uint32_t nonlinear_mode, float alpha, float beta, \
float gamma, float delta, float theta, float scale, \
uint8_t src_zero_point, \
const GemmCoord& threadblock_shape, \
const GemmCoord& warp_shape, cudaStream_t stream);
INST(true);
#undef INST
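Compared with the signed path, the unsigned epilogues above take two extra terms; the identity Params fold them into one additive constant while relu passes them separately. A scalar reading of the identity initializer list (an assumption drawn from the code above, not the CUTLASS source):

// identity: {alpha, beta, gamma, delta + theta}
float uint4_epilogue_identity(float acc, float bias, float z, float alpha,
                              float beta, float gamma, float delta,
                              float theta) {
    // delta/theta carry the zero-point corrections for the uint4 source
    return alpha * acc + beta * bias + gamma * z + (delta + theta);
}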
/* ===== cutlass kernel wrapper for nchw4 layout and nhwc output ===== */
#if MEGDNN_TEGRA_X1
template <bool signedness>
...
...
@@ -970,10 +734,8 @@ void megdnn::cuda::cutlass_wrapper::
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \
megdnn_assert(false, \
...
...
@@ -1039,262 +801,4 @@ INST(true);
INST(false);
#undef INST
/* ====== cutlass kernel wrapper for int4 x int4 nhwc layout ====== */
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int4_int4_implicit_gemm_imma_nhwc(
                const int8_t* /* d_src */, const int8_t* /* d_filter */,
                const int32_t* /* d_bias */, const int8_t* /* d_z */,
                int8_t* /* d_dst */, int* /* workspace */,
                const convolution::ConvParam& /* param */,
                uint32_t /* nonlinear_mode */, float /* alpha */,
                float /* beta */, float /* gamma */, float /* scale */,
                const GemmCoord& /* threadblock_shape */,
                const GemmCoord& /* warp_shape */,
                const int32_t /* access_size */,
                cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_int4_int4_implicit_gemm_imma_nhwc(
                const int8_t* d_src, const int8_t* d_filter,
                const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
                int* workspace, const convolution::ConvParam& param,
                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                float scale, const GemmCoord& threadblock_shape,
                const GemmCoord& warp_shape, const int32_t access_size,
                cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
threadblock_k_, warp_m_, warp_n_, \
warp_k_, access_size_) \
if (threadblock_shape.m() == threadblock_m_ && \
threadblock_shape.n() == threadblock_n_ && \
threadblock_shape.k() == threadblock_k_ && \
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
warp_shape.k() == warp_k_ && access_size == access_size_) { \
using ThreadBlockShape = \
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
threadblock_k_>; \
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \
using Convolution = cutlass::conv::device::Convolution< \
cutlass::int4b_t, cutlass::layout::TensorNHWC, \
cutlass::int4b_t, cutlass::layout::TensorNCxHWx<access_size_>, \
ElementOutput, cutlass::layout::TensorNHWC, int32_t, \
cutlass::layout::TensorNHWC, int32_t, \
cutlass::conv::ConvType::kConvolution, \
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
cutlass::conv::threadblock:: \
ConvolutionFpropNHWCThreadblockSwizzle, \
2, access_size_, access_size_, NeedLoadFromConstMem, \
cutlass::arch::OpMultiplyAddSaturate, \
cutlass::conv::ImplicitGemmMode::GEMM_TN>; \
typename Convolution::ConvolutionParameter conv_param( \
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
return cutlass_convolution_wrapper<Convolution>( \
reinterpret_cast<const cutlass::int4b_t*>(d_src), \
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \
reinterpret_cast<const cutlass::int4b_t*>(d_z), \
reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, \
conv_param, epilogue, stream); \
}
#define DISPATCH_KERNEL \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 32); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 8); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 32); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 8); \
megdnn_assert(false, \
"unsupported threadblock shape (%dx%dx%d) and warp shape " \
"(%dx%dx%d) and access_size (%d)", \
threadblock_shape.m(), threadblock_shape.n(), \
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
warp_shape.k(), access_size);
    using ElementOutput = cutlass::int4b_t;
    using ElementAccumulator = int32_t;
    using ElementBias = int32_t;
    using ElementCompute = float;
    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
    switch (nonlinear_mode) {
        case NonlineMode::IDENTITY: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationClamp<ElementOutput, 8,
                                                  ElementAccumulator,
                                                  ElementBias, ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma};
            DISPATCH_KERNEL;
        }
        case NonlineMode::RELU: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationReluClamp<ElementOutput, 8,
                                                      ElementAccumulator,
                                                      ElementBias,
                                                      ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
            DISPATCH_KERNEL;
        }
        case NonlineMode::H_SWISH: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationHSwishClamp<ElementOutput, 8,
                                                        ElementAccumulator,
                                                        ElementBias,
                                                        ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
            DISPATCH_KERNEL;
        }
        default:
            megdnn_assert(false,
                          "unsupported nonlinear mode for conv bias operator");
    }
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif
#define INST(need_load_from_const_mem) \
template void megdnn::cuda::cutlass_wrapper:: \
do_conv_bias_int4_int4_implicit_gemm_imma_nhwc< \
need_load_from_const_mem>( \
const int8_t* d_src, const int8_t* d_filter, \
const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
int* workspace, const convolution::ConvParam& param, \
uint32_t nonlinear_mode, float alpha, float beta, \
float gamma, float scale, \
const GemmCoord& threadblock_shape, \
const GemmCoord& warp_shape, const int32_t access_size, \
cudaStream_t stream);
INST(true);
INST(false);
#undef INST
#if MEGDNN_TEGRA_X1
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc(
                const uint8_t* /* d_src */, const int8_t* /* d_filter */,
                const int32_t* /* d_bias */, const uint8_t* /* d_z */,
                uint8_t* /* d_dst */, int* /* workspace */,
                const convolution::ConvParam& /* param */,
                uint32_t /* nonlinear_mode */, float /* alpha */,
                float /* beta */, float /* gamma */, float /* delta */,
                float /* theta */, float /* scale */,
                uint8_t /* src_zero_point */,
                const GemmCoord& /* threadblock_shape */,
                const GemmCoord& /* warp_shape */,
                const int32_t /* access_size */,
                cudaStream_t /* stream */) {}
#else
template <bool NeedLoadFromConstMem>
void megdnn::cuda::cutlass_wrapper::
        do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc(
                const uint8_t* d_src, const int8_t* d_filter,
                const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst,
                int* workspace, const convolution::ConvParam& param,
                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                float delta, float theta, float /* scale */,
                uint8_t src_zero_point, const GemmCoord& threadblock_shape,
                const GemmCoord& warp_shape, const int32_t access_size,
                cudaStream_t stream) {
#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
threadblock_k_, warp_m_, warp_n_, \
warp_k_, access_size_) \
if (threadblock_shape.m() == threadblock_m_ && \
threadblock_shape.n() == threadblock_n_ && \
threadblock_shape.k() == threadblock_k_ && \
warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
warp_shape.k() == warp_k_ && access_size == access_size_) { \
using ThreadBlockShape = \
cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
threadblock_k_>; \
using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \
using Convolution = cutlass::conv::device::Convolution< \
cutlass::uint4b_t, cutlass::layout::TensorNHWC, \
cutlass::int4b_t, cutlass::layout::TensorNCxHWx<access_size_>, \
ElementOutput, cutlass::layout::TensorNHWC, int32_t, \
cutlass::layout::TensorNHWC, int32_t, \
cutlass::conv::ConvType::kConvolution, \
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
cutlass::conv::threadblock:: \
ConvolutionFpropNHWCThreadblockSwizzle, \
2, access_size_, access_size_, NeedLoadFromConstMem, \
cutlass::arch::OpMultiplyAddSaturate, \
cutlass::conv::ImplicitGemmMode::GEMM_TN>; \
typename Convolution::ConvolutionParameter conv_param( \
param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
return cutlass_convolution_wrapper<Convolution>( \
reinterpret_cast<const cutlass::uint4b_t*>(d_src), \
reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \
reinterpret_cast<const cutlass::uint4b_t*>(d_z), \
reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \
conv_param, epilogue, stream, {src_zero_point}); \
}
#define DISPATCH_KERNEL \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 32); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 8); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 32); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 16); \
DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 8); \
megdnn_assert(false, \
"unsupported threadblock shape (%dx%dx%d) and warp shape " \
"(%dx%dx%d) and access_size (%d)", \
threadblock_shape.m(), threadblock_shape.n(), \
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
warp_shape.k(), access_size);
    using ElementOutput = cutlass::uint4b_t;
    using ElementAccumulator = int32_t;
    using ElementBias = int32_t;
    using ElementCompute = float;
    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
    switch (nonlinear_mode) {
        case NonlineMode::IDENTITY: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationClamp<ElementOutput, 8,
                                                  ElementAccumulator,
                                                  ElementBias, ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma,
                                                 delta + theta};
            DISPATCH_KERNEL;
        }
        case NonlineMode::RELU: {
            using EpilogueOp = cutlass::epilogue::thread::
                    BiasAddLinearCombinationReluClamp<ElementOutput, 8,
                                                      ElementAccumulator,
                                                      ElementBias,
                                                      ElementCompute>;
            typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0,
                                                 delta, theta};
            DISPATCH_KERNEL;
        }
        default:
            megdnn_assert(false,
                          "unsupported nonlinear mode for conv bias operator");
    }
#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
#undef DISPATCH_KERNEL
}
#endif
#define INST(need_load_from_const_mem) \
template void megdnn::cuda::cutlass_wrapper:: \
do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc< \
need_load_from_const_mem>( \
const uint8_t* d_src, const int8_t* d_filter, \
const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \
int* workspace, const convolution::ConvParam& param, \
uint32_t nonlinear_mode, float alpha, float beta, \
float gamma, float delta, float theta, float scale, \
uint8_t src_zero_point, \
const GemmCoord& threadblock_shape, \
const GemmCoord& warp_shape, const int32_t access_size, \
cudaStream_t stream);
INST(true);
INST(false);
#undef INST
// vim: syntax=cuda.doxygen
dnn/src/cuda/conv_bias/cutlass_reorder_filter.cu
0 → 100644
/**
* \file dnn/src/cuda/conv_bias/cutlass_reorder_filter.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh"
#include "src/cuda/query_blocksize.cuh"
#include "src/cuda/integer_subbyte_utils.cuh"
using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;

namespace {
template <uint32_t size_bits, uint32_t interleaved>
__device__ __forceinline__ void reorder_ncxhwx_imma_filter_func(
        int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH,
        uint32_t FW, uint32_t lane, bool trans_oc) {
    static constexpr uint32_t elements_per_lane = 128 / size_bits;
    static constexpr uint32_t threads_per_interleaved =
            interleaved / elements_per_lane;
    static constexpr uint32_t instruction_shape_col = 8;
    // 4 threads per Quad
    static constexpr uint32_t elements_per_thread = instruction_shape_col / 4;
    // 4 threads per Quad
    static constexpr uint32_t reordered_elements_per_thread = interleaved / 4;

    uint32_t id = lane / threads_per_interleaved;
    uint32_t residue = lane % threads_per_interleaved;
    uint32_t ICx = IC / interleaved;
    uint32_t row = id / (ICx * FH * FW);
    uint32_t col = id - row * ICx * FH * FW;
    // transpose ncxhwx to cxhwnx
    uint32_t src_offset = id * interleaved + residue * elements_per_lane;
    row = (trans_oc) ? (row / interleaved) * interleaved +
                               ((row % reordered_elements_per_thread) /
                                elements_per_thread) *
                                       instruction_shape_col +
                               ((row % interleaved) /
                                reordered_elements_per_thread) *
                                       elements_per_thread +
                               (row % elements_per_thread)
                     : row;
    uint32_t dst_offset =
            (col * OC + row) * interleaved + residue * elements_per_lane;

    *(reinterpret_cast<int4*>(dst + dst_offset * size_bits / 8)) =
            *(reinterpret_cast<const int4*>(src + src_offset * size_bits / 8));
}
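For reference, the trans_oc permutation can be checked on the host; this standalone mirror of the index arithmetic above (same constants, assuming interleaved is a multiple of 4) prints the row mapping:

#include <cstdint>
#include <cstdio>

uint32_t reorder_row(uint32_t row, uint32_t interleaved) {
    const uint32_t instruction_shape_col = 8;
    const uint32_t elements_per_thread = instruction_shape_col / 4;
    const uint32_t reordered_elements_per_thread = interleaved / 4;
    return (row / interleaved) * interleaved +
           ((row % reordered_elements_per_thread) / elements_per_thread) *
                   instruction_shape_col +
           ((row % interleaved) / reordered_elements_per_thread) *
                   elements_per_thread +
           (row % elements_per_thread);
}

int main() {
    for (uint32_t r = 0; r < 32; ++r)
        std::printf("%2u -> %2u\n", r, reorder_row(r, 32));
    return 0;
}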
template <uint32_t size_bits, uint32_t interleaved>
__global__ void reorder_ncxhwx_imma_filter_kernel(
        int8_t* __restrict__ dst_filter, const int8_t* __restrict__ src_filter,
        uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, bool trans_oc) {
    static constexpr uint32_t elements_per_lane = 128 / size_bits;
    const uint32_t size = OC * IC * FH * FW / elements_per_lane;
    uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x;
    if (lane < size) {
        reorder_ncxhwx_imma_filter_func<size_bits, interleaved>(
                dst_filter, src_filter, OC, IC, FH, FW, lane, trans_oc);
    }
}
template <uint32_t size_bits, uint32_t alignbits, uint32_t interleaved>
__device__ __forceinline__ void reorder_nhwc_imma_filter_func(
        int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH,
        uint32_t FW, uint32_t lane, bool trans_oc) {
    static constexpr uint32_t elements_per_access = alignbits / size_bits;
    static constexpr uint32_t instruction_shape_col = 8;
    // 4 threads per Quad
    static constexpr uint32_t elements_per_thread = instruction_shape_col / 4;
    // 4 threads per Quad
    static constexpr uint32_t reordered_elements_per_thread = interleaved / 4;

    uint32_t ICx = IC / elements_per_access;
    uint32_t k = lane / (ICx * FH * FW);
    uint32_t cxrs = lane - k * ICx * FH * FW;
    uint32_t rs = cxrs / ICx;
    uint32_t cx = cxrs - rs * ICx;
    // transpose nhwc to ncxhwx
    uint32_t src_offset = lane * elements_per_access;
    // reorder k
    k = (trans_oc) ? (k / interleaved) * interleaved +
                             ((k % reordered_elements_per_thread) /
                              elements_per_thread) *
                                     instruction_shape_col +
                             ((k % interleaved) /
                              reordered_elements_per_thread) *
                                     elements_per_thread +
                             (k % elements_per_thread)
                   : k;
    uint32_t dst_offset =
            (k * ICx * FH * FW + cx * FH * FW + rs) * elements_per_access;

    if (alignbits == 32) {
        *(reinterpret_cast<int*>(dst + dst_offset * size_bits / 8)) =
                *(reinterpret_cast<const int*>(src +
                                               src_offset * size_bits / 8));
    } else if (alignbits == 64) {
        *(reinterpret_cast<int2*>(dst + dst_offset * size_bits / 8)) =
                *(reinterpret_cast<const int2*>(src +
                                                src_offset * size_bits / 8));
    } else {
        *(reinterpret_cast<int4*>(dst + dst_offset * size_bits / 8)) =
                *(reinterpret_cast<const int4*>(src +
                                                src_offset * size_bits / 8));
    }
}
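The three branches above differ only in the width of the vectorized copy (int = 4 bytes for 32-bit alignment, int2 = 8 bytes for 64-bit, int4 = 16 bytes for 128-bit). A portable host-side analogue of the same dispatch, for illustration:

#include <cstdint>
#include <cstring>

template <uint32_t alignbits>
void copy_aligned(void* dst, const void* src) {
    // 32 -> 4 B (int), 64 -> 8 B (int2), 128 -> 16 B (int4)
    std::memcpy(dst, src, alignbits / 8);
}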
template <uint32_t size_bits, uint32_t alignbits, uint32_t interleaved>
__global__ void reorder_nhwc_imma_filter_kernel(
        int8_t* __restrict__ dst_filter, const int8_t* __restrict__ src_filter,
        uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, bool trans_oc) {
    static constexpr uint32_t elements_per_access = alignbits / size_bits;
    const uint32_t size = OC * IC * FH * FW / elements_per_access;
    uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x;
    if (lane < size) {
        reorder_nhwc_imma_filter_func<size_bits, alignbits, interleaved>(
                dst_filter, src_filter, OC, IC, FH, FW, lane, trans_oc);
    }
}
}  // namespace
template <uint32_t size_bits, uint32_t interleaved>
void megdnn::cuda::cutlass_wrapper::reorder_ncxhwx_imma_filter(
        int8_t* dst_filter, const int8_t* src_filter, uint32_t OC, uint32_t IC,
        uint32_t FH, uint32_t FW, bool trans_oc, cudaStream_t stream) {
    static constexpr uint32_t elements_per_lane = 128 / size_bits;
    uint32_t nr_threads =
            query_blocksize_for_kernel(reinterpret_cast<const void*>(
                    reorder_ncxhwx_imma_filter_kernel<size_bits, interleaved>));
    uint32_t vthreads = DIVUP(OC * IC * FH * FW, elements_per_lane);
    nr_threads = std::min(nr_threads, vthreads);
    uint32_t nr_blocks = DIVUP(vthreads, nr_threads);
    reorder_ncxhwx_imma_filter_kernel<size_bits, interleaved>
            <<<nr_blocks, nr_threads, 0, stream>>>(dst_filter, src_filter, OC,
                                                   IC, FH, FW, trans_oc);
    after_kernel_launch();
}
template <uint32_t size_bits, uint32_t alignbits>
void megdnn::cuda::cutlass_wrapper::reorder_nhwc_imma_filter(
        int8_t* dst_filter, const int8_t* src_filter, uint32_t OC, uint32_t IC,
        uint32_t FH, uint32_t FW, bool trans_oc, uint32_t oc_interleaved,
        cudaStream_t stream) {
    static constexpr uint32_t elements_per_access = alignbits / size_bits;
    uint32_t nr_threads =
            query_blocksize_for_kernel(reinterpret_cast<const void*>(
                    reorder_nhwc_imma_filter_kernel<size_bits, alignbits, 32>));
    uint32_t vthreads = DIVUP(OC * IC * FH * FW, elements_per_access);
    nr_threads = std::min(nr_threads, vthreads);
    uint32_t nr_blocks = DIVUP(vthreads, nr_threads);
    if (oc_interleaved == 32) {
        reorder_nhwc_imma_filter_kernel<size_bits, alignbits, 32>
                <<<nr_blocks, nr_threads, 0, stream>>>(dst_filter, src_filter,
                                                       OC, IC, FH, FW,
                                                       trans_oc);
    } else {
        reorder_nhwc_imma_filter_kernel<size_bits, alignbits, 64>
                <<<nr_blocks, nr_threads, 0, stream>>>(dst_filter, src_filter,
                                                       OC, IC, FH, FW,
                                                       trans_oc);
    }
    after_kernel_launch();
}
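Both launchers size the grid the same way: one thread per aligned chunk, with the block size capped by the occupancy query. The arithmetic in isolation (DIVUP is MegDNN's ceil-division macro; this is a sketch with illustrative names):

#include <algorithm>
#include <cstdint>

constexpr uint32_t divup(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

void launch_dims(uint32_t total_elems, uint32_t elems_per_thread,
                 uint32_t max_block_size, uint32_t& nr_blocks,
                 uint32_t& nr_threads) {
    uint32_t vthreads = divup(total_elems, elems_per_thread);
    nr_threads = std::min(max_block_size, vthreads);
    nr_blocks = divup(vthreads, nr_threads);
}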
#define INST(_size_bits, _interleaved) \
template void megdnn::cuda::cutlass_wrapper::reorder_ncxhwx_imma_filter< \
_size_bits, _interleaved>(int8_t * dst_filter, \
const int8_t* src_filter, uint32_t OC, \
uint32_t IC, uint32_t FH, uint32_t FW, \
bool trans_oc, cudaStream_t stream);
INST(8, 32)
INST(4, 64)
#undef INST
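The two instantiations cover the int8 NCHW32 and int4 NCHW64 paths. A hypothetical call site for the 8-bit/32-interleaved variant, with sizes chosen purely for illustration:

#include <cuda_runtime.h>
#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh"

// Reorder an int8 NCHW32 filter (OC=256, IC=256, 3x3) into CRSK32 with the
// tensor-core oc shuffle applied.
void reorder_example(int8_t* dst, const int8_t* src, cudaStream_t stream) {
    megdnn::cuda::cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>(
            dst, src, /*OC=*/256, /*IC=*/256, /*FH=*/3, /*FW=*/3,
            /*trans_oc=*/true, stream);
}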
#define INST(_size_bits, _alignbits) \
template void megdnn::cuda::cutlass_wrapper::reorder_nhwc_imma_filter< \
_size_bits, _alignbits>( \
int8_t * dst_filter, const int8_t* src_filter, uint32_t OC, \
uint32_t IC, uint32_t FH, uint32_t FW, bool trans_oc, \
uint32_t oc_interleaved, cudaStream_t stream);
INST(4, 32)
INST(4, 64)
INST(4, 128)
#undef INST
// vim: syntax=cuda.doxygen
dnn/src/cuda/conv_bias/cutlass_reorder_filter.cuh
0 → 100644
/**
* \file dnn/src/cuda/conv_bias/cutlass_reorder_filter.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/cuda/utils.cuh"
namespace megdnn {
namespace cuda {
namespace cutlass_wrapper {

template <uint32_t size_bits, uint32_t interleaved>
void reorder_ncxhwx_imma_filter(int8_t* dst_filter, const int8_t* src_filter,
                                uint32_t OC, uint32_t IC, uint32_t FH,
                                uint32_t FW, bool trans_oc,
                                cudaStream_t stream);

template <uint32_t size_bits, uint32_t alignbits>
void reorder_nhwc_imma_filter(int8_t* dst_filter, const int8_t* src_filter,
                              uint32_t OC, uint32_t IC, uint32_t FH,
                              uint32_t FW, bool trans_oc,
                              uint32_t oc_interleaved, cudaStream_t stream);

}  // namespace cutlass_wrapper
}  // namespace cuda
}  // namespace megdnn
dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nchw64_imma.cpp
...
...
@@ -102,7 +102,7 @@ void ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::do_exec(
            reinterpret_cast<int8_t*>(z_ptr),
            reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr,
            kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale,
            threadblock_shape, warp_shape, stream);
            threadblock_shape, warp_shape, m_algo_param.stage, stream);
}
#endif
...
...
dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nhwc_imma.cpp
...
...
@@ -104,7 +104,7 @@ void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::do_exec(
                reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr,
                kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale,
                threadblock_shape, warp_shape, m_algo_param.access_size,
                stream);
                m_algo_param.stage, stream);
    } else {
        cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_nhwc<true>(
                reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr),
...
...
@@ -114,7 +114,7 @@ void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::do_exec(
                reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr,
                kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale,
                threadblock_shape, warp_shape, m_algo_param.access_size,
                stream);
                m_algo_param.stage, stream);
    }
}
#endif
...
...
dnn/src/cuda/conv_bias/implicit_gemm_int4_nchw64_imma_base.cpp
...
...
@@ -12,6 +12,7 @@
#include "./algo.h"
#include "src/common/conv_bias.h"
#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh"
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
#include "src/cuda/conv_bias/reduce_filter.cuh"
#include "src/cuda/convolution_helper/parameter.cuh"
...
...
@@ -121,41 +122,26 @@ void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::exec(
std::string ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::to_string(
        AlgoParam algo_param) {
    return ssprintf("%dX%dX%d_%dX%dX%d", algo_param.threadblock_m,
    return ssprintf("%dX%dX%d_%dX%dX%d_%d", algo_param.threadblock_m,
                    algo_param.threadblock_n, algo_param.threadblock_k,
                    algo_param.warp_m, algo_param.warp_n, algo_param.warp_k);
                    algo_param.warp_m, algo_param.warp_n, algo_param.warp_k,
                    algo_param.stage);
}

void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::reorder_filter(
        const ExecArgs& args, void* reordered_filter) const {
    auto&& param = args.opr->param();
    size_t ci = args.src_layout->operator[](1) * 64;
    size_t co = args.dst_layout->operator[](1) * 64;
    auto&& fm = args.filter_meta;
    size_t n = args.src_layout->operator[](0),
           ci = args.src_layout->operator[](1) * 64,
           hi = args.src_layout->operator[](2),
           wi = args.src_layout->operator[](3);
    size_t co = args.dst_layout->operator[](1) * 64,
           ho = args.dst_layout->operator[](2),
           wo = args.dst_layout->operator[](3);
    UNPACK_CONV_PARAMETER(fm, param);
    MARK_USED_VAR
    // filter: KCRS64 => CRSK64
    TensorLayout src{{co, ci / 64, fh, fw, 64}, dtype::QuantizedS4()};
    src.init_contiguous_stride();
    TensorLayout dst = src;
    dst.stride[0] = 64;
    dst.stride[1] = co * fh * fw * 64;
    dst.stride[2] = co * fw * 64;
    dst.stride[3] = co * 64;
    dst.stride[4] = 1;
    TensorND ts_src, ts_dst;
    ts_src.raw_ptr = args.filter_tensor->raw_ptr;
    ts_src.layout = src;
    ts_dst.raw_ptr = reordered_filter;
    ts_dst.layout = dst;
    auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>();
    transpose->exec(ts_src, ts_dst);
    size_t fh = fm.spatial[0], fw = fm.spatial[1];
    cudaStream_t stream = cuda_stream(args.opr->handle());
    // filter: KCRS64 => CRSK64 and reorder oc
    cutlass_wrapper::reorder_ncxhwx_imma_filter<4, 64>(
            reinterpret_cast<int8_t*>(reordered_filter),
            reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci,
            fh, fw, true, stream);
}
#endif
...
...
dnn/src/cuda/conv_bias/implicit_gemm_int4_nhwc_imma_base.cpp
...
...
@@ -12,6 +12,7 @@
#include "./algo.h"
#include "src/common/conv_bias.h"
#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh"
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
#include "src/cuda/conv_bias/reduce_filter.cuh"
#include "src/cuda/convolution_helper/parameter.cuh"
...
...
@@ -128,10 +129,10 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::exec(
std::string ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::to_string(
        AlgoParam algo_param) {
    return ssprintf("%dX%dX%d_%dX%dX%d_%d", algo_param.threadblock_m,
    return ssprintf("%dX%dX%d_%dX%dX%d_%d_%d", algo_param.threadblock_m,
                    algo_param.threadblock_n, algo_param.threadblock_k,
                    algo_param.warp_m, algo_param.warp_n, algo_param.warp_k,
                    algo_param.access_size);
                    algo_param.stage, algo_param.access_size);
}

void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::reorder_filter(
...
...
@@ -142,17 +143,32 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::reorder_filter(
           fh = args.filter_layout->operator[](1),
           fw = args.filter_layout->operator[](2);
    // reformat grad from nhwc to ncxhwx
    TensorLayout exec_src{{co, fh, fw, ci / iterleaved,
                           (size_t)iterleaved / 2},
                          dtype::Int8()};
    TensorLayout exec_dst{{co, ci / iterleaved, fh, fw,
                           (size_t)iterleaved / 2},
                          dtype::Int8()};
    exec_src = exec_src.dimshuffle({0, 3, 1, 2, 4});
    cudaStream_t stream = cuda_stream(args.opr->handle());
    auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>();
    relayout->exec({args.filter_tensor->raw_ptr, exec_src},
                   {reordered_filter, exec_dst});
    // reformat filter from nhwc to ncxhwx and reorder oc
    // to use trans_oc, threadblock_n must be 32 or 64
    bool trans_oc = ((co % m_algo_param.threadblock_n == 0) &&
                     (m_algo_param.threadblock_n == 32 ||
                      m_algo_param.threadblock_n == 64));
    uint32_t oc_iterleave = (m_algo_param.threadblock_n == 64) ? 64 : 32;
    if (iterleaved == 8) {
        cutlass_wrapper::reorder_nhwc_imma_filter<4, 32>(
                reinterpret_cast<int8_t*>(reordered_filter),
                reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co,
                ci, fh, fw, trans_oc, oc_iterleave, stream);
    } else if (iterleaved == 16) {
        cutlass_wrapper::reorder_nhwc_imma_filter<4, 64>(
                reinterpret_cast<int8_t*>(reordered_filter),
                reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co,
                ci, fh, fw, trans_oc, oc_iterleave, stream);
    } else {
        megdnn_assert(iterleaved == 32);
        cutlass_wrapper::reorder_nhwc_imma_filter<4, 128>(
                reinterpret_cast<int8_t*>(reordered_filter),
                reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co,
                ci, fh, fw, trans_oc, oc_iterleave, stream);
    }
}
#endif
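The guard above, factored into a predicate for clarity (a sketch; the member names mirror the algorithm parameters used in this file):

#include <cstddef>
#include <cstdint>

// oc reordering is only valid when the output channels tile evenly into the
// threadblock's n dimension, and only for the 32/64 tiles.
bool can_trans_oc(size_t co, uint32_t threadblock_n) {
    return co % threadblock_n == 0 &&
           (threadblock_n == 32 || threadblock_n == 64);
}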
...
...
dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw32_imma.cpp
...
...
@@ -11,6 +11,7 @@
*/
#include "./algo.h"
#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh"
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
#include "src/cuda/convolution_helper/parameter.cuh"
#include "src/cuda/utils.h"
...
...
@@ -110,11 +111,14 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
    size_t ho = args.dst_layout->operator[](2),
           wo = args.dst_layout->operator[](3);
    size_t co;
    bool trans_oc;
    if (param.format == Format::NCHW32) {
        co = args.dst_layout->operator[](1) * 32;
        trans_oc = true;
    } else {
        megdnn_assert(param.format == Format::NCHW32_NCHW4);
        co = args.dst_layout->operator[](1) * 4;
        trans_oc = false;
    }
    UNPACK_CONV_PARAMETER(fm, param);
    MARK_USED_VAR
...
...
@@ -123,23 +127,11 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
    int8_t* filter_ptr = nullptr;
    if (args.preprocessed_filter == nullptr) {
        filter_ptr = reinterpret_cast<int8_t*>(args.workspace.raw_ptr);
        // reformat filter from nchw32 to chwn32
        TensorLayout src{{co, ci / 32, fh, fw, 32}, dtype::Int8()};
        src.init_contiguous_stride();
        TensorLayout dst = src;
        dst.stride[0] = 32;
        dst.stride[1] = co * fh * fw * 32;
        dst.stride[2] = co * fw * 32;
        dst.stride[3] = co * 32;
        dst.stride[4] = 1;
        TensorND ts_src, ts_dst;
        ts_src.raw_ptr = args.filter_tensor->raw_ptr;
        ts_src.layout = src;
        ts_dst.raw_ptr = args.workspace.raw_ptr;
        ts_dst.layout = dst;
        auto&& transpose =
                args.opr->handle()->create_operator<RelayoutForward>();
        transpose->exec(ts_src, ts_dst);
        // filter: KCRS32 => CRSK32 and reorder oc
        cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>(
                filter_ptr,
                reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co,
                ci, fh, fw, trans_oc, stream);
    } else {
        filter_ptr = reinterpret_cast<int8_t*>(
                args.preprocessed_filter->tensors[0].raw_ptr);
...
...
@@ -182,7 +174,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
                cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
                                           m_algo_param.warp_n,
                                           m_algo_param.warp_k},
                stream);
                m_algo_param.stage, stream);
    } else {
        megdnn_assert(param.format == Format::NCHW32_NCHW4);
        cutlass_wrapper::
...
@@ -202,7 +194,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
cutlass_wrapper
::
GemmCoord
{
m_algo_param
.
warp_m
,
m_algo_param
.
warp_n
,
m_algo_param
.
warp_k
},
stream
);
m_algo_param
.
stage
,
stream
);
}
}
else
{
if
(
param
.
format
==
Format
::
NCHW32
)
{
...
...
@@ -218,7 +210,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
cutlass_wrapper
::
GemmCoord
{
m_algo_param
.
warp_m
,
m_algo_param
.
warp_n
,
m_algo_param
.
warp_k
},
stream
);
m_algo_param
.
stage
,
stream
);
}
else
{
megdnn_assert
(
param
.
format
==
Format
::
NCHW32_NCHW4
);
cutlass_wrapper
::
...
...
@@ -238,7 +230,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
                    cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
                                               m_algo_param.warp_n,
                                               m_algo_param.warp_k},
                    stream);
                    m_algo_param.stage, stream);
        }
    }
    after_kernel_launch();
...
...
@@ -246,9 +238,10 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
std::string ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::to_string(
        AlgoParam algo_param) {
    return ssprintf("%uX%uX%u_%uX%uX%u", algo_param.threadblock_m,
    return ssprintf("%uX%uX%u_%uX%uX%u_%u", algo_param.threadblock_m,
                    algo_param.threadblock_n, algo_param.threadblock_k,
                    algo_param.warp_m, algo_param.warp_n, algo_param.warp_k);
                    algo_param.warp_m, algo_param.warp_n, algo_param.warp_k,
                    algo_param.stage);
}

size_t ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::
...
@@ -267,36 +260,26 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec_preprocess(
    using Format = Param::Format;
    auto&& param = args.opr->param();
    auto&& fm = args.filter_meta;
    size_t n = args.src_layout->operator[](0),
           ci = args.src_layout->operator[](1) * 32,
           hi = args.src_layout->operator[](2),
           wi = args.src_layout->operator[](3);
    size_t ho = args.dst_layout->operator[](2),
           wo = args.dst_layout->operator[](3);
    size_t ci = args.src_layout->operator[](1) * 32;
    size_t co;
    bool trans_oc;
    if (param.format == Format::NCHW32) {
        co = args.dst_layout->operator[](1) * 32;
        trans_oc = true;
    } else {
        megdnn_assert(param.format == Format::NCHW32_NCHW4);
        co = args.dst_layout->operator[](1) * 4;
        trans_oc = false;
    }
    UNPACK_CONV_PARAMETER(fm, param);
    MARK_USED_VAR
    TensorLayout src{{co, ci / 32, fh, fw, 32}, dtype::Int8()};
    src.init_contiguous_stride();
    TensorLayout dst = src;
    dst.stride[0] = 32;
    dst.stride[1] = co * fh * fw * 32;
    dst.stride[2] = co * fw * 32;
    dst.stride[3] = co * 32;
    dst.stride[4] = 1;
    TensorND ts_src, ts_dst;
    ts_src.raw_ptr = args.filter_tensor->raw_ptr;
    ts_src.layout = src;
    ts_dst.raw_ptr = args.preprocessed_filter->tensors[0].raw_ptr;
    ts_dst.layout = dst;
    auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>();
    transpose->exec(ts_src, ts_dst);
    size_t fh = fm.spatial[0], fw = fm.spatial[1];
    cudaStream_t stream = cuda_stream(args.opr->handle());
    // filter: KCRS32 => CRSK32 and reorder oc
    cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>(
            reinterpret_cast<int8_t*>(
                    args.preprocessed_filter->tensors[0].raw_ptr),
            reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci,
            fh, fw, trans_oc, stream);
}
#endif
...
...
dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nchw64_imma.cpp
...
...
@@ -144,7 +144,8 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::do_exec(
            reinterpret_cast<uint8_t*>(z_ptr),
            reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr,
            kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta,
            dst_scale, src_zero, threadblock_shape, warp_shape, stream);
            dst_scale, src_zero, threadblock_shape, warp_shape,
            m_algo_param.stage, stream);
}

void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::update_bias(
...
...
dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nhwc_imma.cpp
...
...
@@ -147,7 +147,7 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::do_exec(
                reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr,
                kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta,
                dst_scale, src_zero, threadblock_shape, warp_shape,
                m_algo_param.access_size, stream);
                m_algo_param.access_size, m_algo_param.stage, stream);
    } else {
        cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc<true>(
                reinterpret_cast<uint8_t*>(args.src_tensor->raw_ptr),
...
...
@@ -157,7 +157,7 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::do_exec(
                reinterpret_cast<uint8_t*>(args.dst_tensor->raw_ptr), nullptr,
                kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta,
                dst_scale, src_zero, threadblock_shape, warp_shape,
                m_algo_param.access_size, stream);
                m_algo_param.access_size, m_algo_param.stage, stream);
    }
}
...
...
dnn/test/cuda/conv_bias_int8.cpp
...
...
@@ -840,21 +840,21 @@ TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_IMMA) {
    param.pad_h = param.pad_w = 1;
    param.stride_h = param.stride_w = 1;
    param.format = param::ConvBias::Format::NCHW32;
    checker.set_param(param).execs({{16, 16, 7, 7, 32},
                                    {512, 16, 3, 3, 32},
                                    {1, 16, 1, 1, 32},
    checker.set_param(param).execs({{16, 8, 7, 7, 32},
                                    {256, 8, 3, 3, 32},
                                    {1, 8, 1, 1, 32},
                                    {},
                                    {}});
    param.nonlineMode = param::ConvBias::NonlineMode::RELU;
    checker.set_param(param).execs({{16, 16, 7, 7, 32},
                                    {512, 16, 1, 1, 32},
                                    {1, 16, 1, 1, 32},
    checker.set_param(param).execs({{16, 8, 7, 7, 32},
                                    {256, 8, 1, 1, 32},
                                    {1, 8, 1, 1, 32},
                                    {},
                                    {}});
    param.nonlineMode = param::ConvBias::NonlineMode::H_SWISH;
    checker.set_param(param).execs({{16, 16, 7, 7, 32},
                                    {512, 16, 3, 3, 32},
                                    {1, 16, 1, 1, 32},
    checker.set_param(param).execs({{16, 8, 7, 7, 32},
                                    {256, 8, 3, 3, 32},
                                    {1, 8, 1, 1, 32},
                                    {},
                                    {}});
// use non integer scale
...
...
@@ -867,18 +867,18 @@ TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_IMMA) {
            .set_epsilon(1 + 1e-3)
            .set_max_avg_error(1e-1)
            .set_max_avg_biased_error(1e-1)
            .execs({{16, 16, 7, 7, 32},
                    {512, 16, 3, 3, 32},
                    {1, 16, 1, 1, 32},
                    {16, 16, 7, 7, 32},
            .execs({{16, 8, 7, 7, 32},
                    {256, 8, 3, 3, 32},
                    {1, 8, 1, 1, 32},
                    {16, 8, 7, 7, 32},
                    {}});
    };
    std::string algo = ConvBias::algo_name<ConvBias::DirectParam>(
            "INT8_NCHW32_IMMA_IMPLICIT_GEMM_256X128X64_64X64X64",
            "INT8_NCHW32_IMMA_IMPLICIT_GEMM_128X128X64_64X64X64_2",
            ConvBias::DirectParam{});
    check(algo);
    algo = ConvBias::algo_name<ConvBias::DirectParam>(
            "INT8_NCHW32_IMMA_IMPLICIT_GEMM_32X64X64_32X16X64",
            "INT8_NCHW32_IMMA_IMPLICIT_GEMM_128X32X32_64X32X32_1",
            ConvBias::DirectParam{});
    check(algo);
}
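As the old/new algorithm strings above show, the stage count is now encoded as a suffix in the algorithm name. A summary of the (illustrative) naming scheme implied by the to_string changes in this commit:

// INT8_NCHW32_IMMA_IMPLICIT_GEMM_<tbM>X<tbN>X<tbK>_<wM>X<wN>X<wK>_<stage>
// e.g. "INT8_NCHW32_IMMA_IMPLICIT_GEMM_128X32X32_64X32X32_1" selects the
// 1-stage kernel with a 128x32x32 threadblock and 64x32x32 warp tile.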
...
...
@@ -969,7 +969,7 @@ TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_NCHW4) {
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
                    ConvBias::algo_name<ConvBias::DirectParam>(
                            "INT8_NCHW32_IMMA_IMPLICIT_GEMM_256X128X64_64X64X64",
                            "INT8_NCHW32_IMMA_IMPLICIT_GEMM_128X128X64_64X64X64_2",
                            ConvBias::DirectParam{})
                            .c_str()));
    checker.set_dtype(0, dtype::QuantizedS8(1.9980618f))
...
...