MegEngine 天元 / MegEngine · Commit ed922075
Authored Mar 11, 2021 by Megvii Engine Team
feat(dnn/cuda): add conv bias impl for int4 data type using sass language
GitOrigin-RevId: ae3d3e1c987247add166fe608cd54b8a70513c4e
Parent: 52b55564

Showing 8 changed files with 376 additions and 26 deletions (+376, -26):
dnn/scripts/opr_param_defs.py                                    +5   -3
dnn/src/common/conv_bias.cpp                                     +14  -20
dnn/src/common/convolution.cpp                                   +34  -2
dnn/src/common/utils.cpp                                         +11  -0
dnn/src/common/utils.h                                           +2   -0
dnn/src/cuda/conv_bias/algo.h                                    +0   -1
dnn/src/cuda/conv_bias/conv_bias_int8.cuh                        +8   -0
dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp   +302 -0
dnn/scripts/opr_param_defs.py

@@ -36,7 +36,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
     add_enum(Doc('Format', 'convolution data/filter/output format; see '
                  ':class:`RelayoutFormat` for more details'),
              'NCHW', 'NHWC', 'NHWCD4', 'NCHW4', 'NCHW8', 'NCHW32', 'NCHW88',
-             'NCHW44', 'NCHW44_DOT',
+             'NCHW44', 'NCHW44_DOT',
              Doc('NCHW_WINOGRAD', 'NCHW layout with weights tranformed by winograd'),
              Doc('NCHW88_WINOGRAD', 'NCHW88 layout with weights tranformed by winograd'),
              Doc('NCHW44_WINOGRAD', 'NCHW44 layout with weights tranformed by winograd'),

@@ -95,7 +95,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
     add_enum(Doc('Format', 'convolution data/filter/output format; see '
                  ':class:`RelayoutFormat` for more details'),
              'NCHW', 'NHWC', 'NHWCD4', 'NCHW4', 'NCHW8', 'NCHW32', 'NCHW88',
-             'NCHW44', 'NCHW44_DOT',
+             'NCHW44', 'NCHW44_DOT',
              Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
              Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
              Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),

@@ -106,7 +106,9 @@ pdef('Axis').add_fields('int32', 'axis', 0)
              Doc('NCHW_NCHW4_IC_SMALL',
                  'NCHW_NCHW4_IC_SMALL means input tensors are nchw(c < 4) layout, '
                  'output tensor is nchw4 layout, padding c=4'),
              Doc('CHWN4',
                  'CHWN4 is currently only used on Nvidia platform for fast implementation '
-                 'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.')).
+                 'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.'),
+             Doc('NCHW64',
+                 'NCHW64 is designed for convolution implementation to utilizing TensorCore '
+                 'instructions for 4-bit integers on Nvidia platforms')).
     add_enum_alias('ComputeMode', 'ConvolutionV1', name_field='compute_mode')
     )
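Note: the new NCHW64 member packs 64 consecutive channels into the innermost dimension, so a 4-bit tensor carries 64 * 4 bits = 32 contiguous bytes per spatial position, a convenient vector width for the Tensor Core (IMMA) kernels this commit targets. A minimal sketch of the shape mapping; nchw_to_nchw64 is a hypothetical helper, not part of the commit (megdnn does this conversion through its RelayoutFormat operator):

    #include <array>
    #include <cassert>
    #include <cstdio>

    // Pack an NCHW shape into NCHW64: {N, C, H, W} -> {N, C/64, H, W, 64}.
    std::array<size_t, 5> nchw_to_nchw64(const std::array<size_t, 4>& nchw) {
        assert(nchw[1] % 64 == 0);  // channel count must be a multiple of 64
        return {nchw[0], nchw[1] / 64, nchw[2], nchw[3], 64};
    }

    int main() {
        auto s = nchw_to_nchw64({1, 128, 56, 56});
        printf("{%zu, %zu, %zu, %zu, %zu}\n", s[0], s[1], s[2], s[3], s[4]);
        // prints {1, 2, 56, 56, 64}
    }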
dnn/src/common/conv_bias.cpp

@@ -36,28 +36,15 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
         const TensorLayout& dst, size_t workspace_in_bytes,
         const PreprocessedFilter* preprocessed_filter) {
     megdnn_assert(src.dtype.enumv() == filter.dtype.enumv());
-    if (src.dtype.enumv() == DTypeEnum::QuantizedS8) {
+    // check compatibility of bias's scale
+    if (src.dtype.category() == DTypeCategory::QUANTIZED) {
         if (bias.dtype.enumv() == DTypeEnum::QuantizedS32) {
-            float scale_src = src.dtype.param<dtype::QuantizedS8>().scale;
-            float scale_filter = filter.dtype.param<dtype::QuantizedS8>().scale;
+            float scale_expected = mul_scale(src.dtype, filter.dtype);
             float scale_bias = bias.dtype.param<dtype::QuantizedS32>().scale;
-            megdnn_assert(std::abs(scale_src * scale_filter - scale_bias) < 1e-6,
-                          "scale_src: %f scale_filter: %f scale_bias: %f",
-                          scale_src, scale_filter, scale_bias);
-        } else {
-            megdnn_assert(bias.dtype.enumv() == DTypeEnum::Float32);
-        }
-    } else if (src.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
-        if (bias.dtype.enumv() == DTypeEnum::QuantizedS32) {
-            float scale_src = src.dtype.param<dtype::Quantized8Asymm>().scale;
-            float scale_filter =
-                    filter.dtype.param<dtype::Quantized8Asymm>().scale;
-            float scale_bias = bias.dtype.param<dtype::QuantizedS32>().scale;
-            megdnn_assert(std::abs(scale_src * scale_filter - scale_bias) < 1e-6,
-                          "scale_src: %f scale_filter: %f scale_bias: %f",
-                          scale_src, scale_filter, scale_bias);
+            megdnn_assert(std::abs(scale_expected - scale_bias) < 1e-6,
+                          "scale_src: %f scale_filter: %f scale_bias: %f",
+                          get_scale(src.dtype), get_scale(filter.dtype),
+                          scale_bias);
         } else {
             megdnn_assert(bias.dtype.enumv() == DTypeEnum::Float32);
         }
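Note: the rewritten check covers every quantized dtype at once. For quantized tensors $x \approx s_x x_q$, the convolution accumulates $\sum x_q w_q$ in 32-bit integers, so the accumulator's effective scale is $s_{src} \cdot s_{filter}$; a QuantizedS32 bias can be added in the integer domain only if $s_{bias} = s_{src} \cdot s_{filter}$. That product is exactly what mul_scale(src.dtype, filter.dtype) returns, and the assert enforces it to within $10^{-6}$, now including the new int4 dtypes instead of only the two hard-coded 8-bit branches.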
@@ -127,6 +114,13 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
             megdnn_assert(bias.shape[2] == 1);
             megdnn_assert(bias.shape[3] == 1);
             megdnn_assert(bias.shape[4] == 4);
+        } else if (param().format == param::ConvBias::Format::NCHW64) {
+            megdnn_assert(bias.shape[0] == 1);
+            megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
+                          bias.to_string().c_str(), dst.to_string().c_str());
+            megdnn_assert(bias.shape[2] == 1);
+            megdnn_assert(bias.shape[3] == 1);
+            megdnn_assert(bias.shape[4] == 64);
         } else {
             megdnn_assert(param().format == param::ConvBias::Format::NHWCD4);
             megdnn_assert(bias.shape[0] == 1);
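Note: concretely, for an NCHW64 destination with 128 output channels (dst.shape[1] = 2), these asserts require the bias to have shape {1, 2, 1, 1, 64}: one scalar per output channel, packed identically to the output tensor.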
dnn/src/common/convolution.cpp

@@ -370,7 +370,8 @@ void make_canonized_filter_meta_nchwx(
                           param.format == Param::Format::NCHW32 ||
                           param.format == Param::Format::NCHW4_NCHW ||
                           param.format == Param::Format::NCHW4_NCHW32 ||
-                          param.format == Param::Format::NCHW32_NCHW4);
+                          param.format == Param::Format::NCHW32_NCHW4 ||
+                          param.format == Param::Format::NCHW64);
     auto img_ndim = src_ndim - 3;
     size_t flt_start = 0, flt_spatial_start = 2;
     if (param.sparse == Param::Sparse::DENSE) {

@@ -517,6 +518,9 @@ ConvolutionBase<Parameter>::make_canonized_filter_meta(
     } else if (param().format == Param::Format::CHWN4) {
         make_canonized_filter_meta_chwnx<4, Parameter>(src_ndim, filter,
                                                        param(), ret);
+    } else if (param().format == Param::Format::NCHW64) {
+        make_canonized_filter_meta_nchwx<64, Parameter>(src_ndim, filter,
+                                                        param(), ret);
     } else {
         megdnn_assert(param().format == Param::Format::NHWC ||
                       param().format == Param::Format::NCHW);

@@ -539,6 +543,7 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(DType src,
         supported_dst_dtype = {dtype::Int32(), dtype::Int16()};
     } else if (src.enumv() == DTypeEnum::QuantizedS8 ||
                src.enumv() == DTypeEnum::Quantized8Asymm ||
+               src.enumv() == DTypeEnum::QuantizedS4 ||
                src.enumv() == DTypeEnum::Quantized4Asymm) {
         supported_dst_dtype.push_back(
                 dtype::QuantizedS32(mul_scale(src, filter)));

@@ -614,7 +619,8 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
                       param().format == Param::Format::NCHW32 ||
                       param().format == Param::Format::NCHW32_NCHW4 ||
                       param().format == Param::Format::NCHW88 ||
-                      param().format == Param::Format::CHWN4);
+                      param().format == Param::Format::CHWN4 ||
+                      param().format == Param::Format::NCHW64);
         img_dim = src.ndim - 3;
         if ((param().format == Param::Format::NCHW88 ||
              param().format == Param::Format::NCHW44_DOT ||

@@ -712,6 +718,15 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
                           "but got src %s, filter %s", src.to_string().c_str(),
                           filter.to_string().c_str());
         }
+        if (param().format == Param::Format::NCHW64) {
+            megdnn_assert(src.ndim == 5 &&
+                                  (filter.ndim == 5 || filter.ndim == 6) &&
+                                  src[src.ndim - 1] == 64 &&
+                                  filter[filter.ndim - 1] == 64,
+                          "NCHW64 require src and filter's ndim is 5 or 6, and "
+                          "last shape is 64 but got src %s, filter %s",
+                          src.to_string().c_str(), filter.to_string().c_str());
+        }
     }
     megdnn_assert(img_dim == 2,
                   "currently only convolution on 2D image is supported");

@@ -899,6 +914,23 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
         dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1],
                                   cflt.stride[1], cflt.padding[1]);
         dst[4] = 4;
+    } else if (param().format == Param::Format::NCHW64) {
+        megdnn_assert(src.ndim == 5,
+                      "invalid src ndim for NCHW64, expected=5, got=%zu",
+                      src.ndim);
+        megdnn_assert(cflt.icpg * cflt.group == src[1] * 64,
+                      "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg,
+                      cflt.group);
+        dst.ndim = src.ndim;
+        dst[0] = src[0];
+        auto oc = cflt.ocpg * cflt.group;
+        megdnn_assert(oc % 64 == 0);
+        dst[1] = oc / 64;
+        dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0],
+                                  cflt.stride[0], cflt.padding[0]);
+        dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1],
+                                  cflt.stride[1], cflt.padding[1]);
+        dst[4] = 64;
     } else {
         megdnn_assert(param().format == Param::Format::NHWCD4);
         megdnn_assert(src.ndim == 5,
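Note: a worked example of the new NCHW64 branch (values illustrative, not from the commit). Take src = {16, 2, 28, 28, 64} (128 input channels), a dense 3x3 filter with 128 output channels, stride 1, padding 1, dilation 1. Per spatial axis, infer_conv_shape gives

$$o = \left\lfloor \frac{i + 2p - (d(f-1)+1)}{s} \right\rfloor + 1 = \left\lfloor \frac{28 + 2 - 3}{1} \right\rfloor + 1 = 28,$$

so dst.ndim = 5 and dst = {16, 128/64, 28, 28, 64} = {16, 2, 28, 28, 64}.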
dnn/src/common/utils.cpp

@@ -245,6 +245,17 @@ float megdnn::mul_scale(DType lhs, DType rhs) {
 }
 // clang-format on
 
+float megdnn::get_scale(DType dt) {
+    megdnn_assert(dt.category() == DTypeCategory::QUANTIZED);
+#define cb(_dt)                                   \
+    if (dt.enumv() == DTypeTrait<_dt>::enumv)     \
+        return dt.param<_dt>().scale;
+    MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
+    MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
+#undef cb
+    megdnn_assert_internal(0);
+}
+
 bool megdnn::dtype_almost_equal(DType lhs, DType rhs) {
     if (lhs.enumv() != rhs.enumv())
         return false;
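Note: a minimal usage sketch of the new helper (illustrative values; assumes the megdnn headers, where quantized DType constructors take their scale):

    // get_scale() dispatches over both X-macro lists, so normal and
    // low-bit quantized dtypes are handled uniformly.
    DType s4 = dtype::QuantizedS4(0.25f);
    DType s32 = dtype::QuantizedS32(0.0625f);
    megdnn_assert(megdnn::get_scale(s4) == 0.25f);
    megdnn_assert(megdnn::get_scale(s32) == 0.0625f);
    // get_scale(dtype::Float32()) would trip the QUANTIZED-category assert.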
dnn/src/common/utils.h

@@ -504,6 +504,8 @@ bool vec_contains(const SmallVector<T>& vec, const T& elem) {
 float mul_scale(DType lhs, DType rhs);
 
+float get_scale(DType dt);
+
 template <typename stype, typename dtype>
 dtype convert(stype src, dtype dst, size_t offset);
dnn/src/cuda/conv_bias/algo.h

@@ -807,7 +807,6 @@ public:
     AlgoBatchedMatmul batched_matmul;
     std::vector<AlgoInt8NCHW4DotProdImplicitGemm> int8_nchw4_dotprod;
     AlgoInt8CHWN4DotProdImplicitGemm int8_chwn4_dotprod;
-<<<<<<< HEAD
 #if CUDA_VERSION >= 10000
     AlgoQUInt4x4x32WMMA wmma_quint4x4x32;
     std::vector<AlgoInt8CHWN4IMMAImplicitGemm> int8_chwn4_imma;

(The single deletion in this file removes a stray merge-conflict marker left in the algorithm pack.)
dnn/src/cuda/conv_bias/conv_bias_int8.cuh

@@ -150,4 +150,12 @@ void do_conv_bias_int8_implicit_gemm_imma8x32x16_cdiv4hwn4_unroll_width(
     UNPACK_CONV_PARAMETER(_filter_meta, _param);                            \
     MARK_USED_VAR
 
+#define UNPACK_CONV_BIAS_NCHW64_PARAM(_src, _filter_meta, _dst, _param)     \
+    using Format = param::ConvBias::Format;                                 \
+    megdnn_assert(_param.format == Format::NCHW64);                         \
+    size_t n = (_src)[0], ci = (_src)[1] * 64, hi = (_src)[2],              \
+           wi = (_src)[3];                                                  \
+    size_t co = (_dst)[1] * 64, ho = (_dst)[2], wo = (_dst)[3];             \
+    UNPACK_CONV_PARAMETER(_filter_meta, _param);                            \
+    MARK_USED_VAR
+
 // vim: syntax=cuda.doxygen
dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp (new file, mode 100644)
/**
 * \file dnn/src/cuda/conv_bias/sass_implicit_gemm_int4_nchw64_imma.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */

#include "./algo.h"
#include "src/cuda/conv_bias/sass_helper.cuh"
#include "src/cuda/sass_loader.h"
#include "src/cuda/utils.h"
#include "src/common/conv_bias.h"

using namespace megdnn;
using namespace cuda;
using namespace sass;

namespace {
#if !MEGDNN_TEGRA_X1
// all stride are in bytes
void compute_conv2d_offset(size_t fh, size_t fw, size_t ics, size_t ihs,
                           Conv2dConstantOffset& constant_offset) {
    constexpr int interleaved = 64;
    constexpr int size_bits = 4;
    constexpr int threablock_k = 128;
    constexpr int inc_step = threablock_k / interleaved;
    size_t i = 0;
    int* s32 = reinterpret_cast<int*>(&(constant_offset.c_offset[0]));
    for (; i < inc_step; i++) {
        int c = i / (fh * fw);
        int khkw = i % (fh * fw);
        int kh = khkw / fw;
        int kw = khkw % fw;
        s32[2 * i] = c * ics + kh * ihs + kw * interleaved * size_bits / 8;
        int8_t* s8 = reinterpret_cast<int8_t*>(&(s32[2 * i + 1]));
        s8[0] = kh;
        s8[1] = kw;
        s8[2] = -kh;
        s8[3] = -kw;
    }
    for (; i < (inc_step + fh * fw * inc_step); i++) {
        int c = i / (fh * fw);
        int khkw = i % (fh * fw);
        int kh = khkw / fw;
        int kw = khkw % fw;
        s32[2 * i] = c * ics + kh * ihs + kw * interleaved * size_bits / 8;
        int8_t* s8 = reinterpret_cast<int8_t*>(&(s32[2 * i + 1]));
        s8[0] = kh;
        s8[1] = kw;
        s8[2] = -kh;
        s8[3] = -kw;
        int i_ = i - inc_step;
        c = i_ / (fh * fw);
        khkw = i_ % (fh * fw);
        kh = khkw / fw;
        kw = khkw % fw;
        s32[2 * i] -= c * ics + kh * ihs + kw * interleaved * size_bits / 8;
    }
}
#endif
}  // namespace
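Note on compute_conv2d_offset: with interleaved = 64 channels per slice and a 128-wide threadblock k-tile, inc_step = 2 table entries are emitted per filter position. Each entry packs a byte offset into the int4 feature map in its first 32-bit word and (kh, kw, -kh, -kw) in four int8 lanes of the second (the negated pair presumably serves the kernel's boundary predication). The first inc_step entries are absolute; every later entry stores a delta against the entry inc_step positions earlier, so the SASS kernel advances its pointers by simple addition:

$$o_i = c_i \cdot \mathrm{ics} + kh_i \cdot \mathrm{ihs} + kw_i \cdot \frac{64 \cdot 4}{8}, \qquad e_i = \begin{cases} o_i, & i < 2 \\ o_i - o_{i-2}, & i \ge 2. \end{cases}$$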
std::string ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::kernel_key(
        const SizeArgs& args) const {
    std::string kernel_key;
    using NonlineMode = Param::NonlineMode;
    auto&& param = args.opr->param();
    if (args.z_layout->ndim > 0) {
        kernel_key =
                ssprintf("%s_conv_bias_int4_fuse_z_imma_ldg16_%ux%u",
                         current_device_arch_name(), m_tile_nhw, m_tile_oc);
    } else {
        kernel_key =
                ssprintf("%s_conv_bias_int4_imma_ldg16_%ux%u",
                         current_device_arch_name(), m_tile_nhw, m_tile_oc);
    }
    if (param.nonlineMode == NonlineMode::H_SWISH) {
        kernel_key += "_hswish";
    } else {
        megdnn_assert(param.nonlineMode == NonlineMode::RELU ||
                      param.nonlineMode == NonlineMode::IDENTITY);
        kernel_key += "_relu";
    }
    return kernel_key;
}
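Note, for illustration only (the arch-name string and tile sizes here are assumptions, not values from this commit): on a Turing device with m_tile_nhw = 128 and m_tile_oc = 128, the fused-z h-swish case would look up a SASS kernel named like sm75_conv_bias_int4_fuse_z_imma_ldg16_128x128_hswish, while RELU and IDENTITY share the _relu variant and are distinguished at launch time by a relu flag (see exec below).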
bool ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::is_available(
        const SizeArgs& args) const {
    if (args.bias_layout->ndim <= 0)
        return false;
    using Param = param::ConvBias;
    using Format = Param::Format;
    using Sparse = Param::Sparse;
    using Mode = Param::Mode;
    bool available = true;
    auto&& param = args.opr->param();
    auto&& fm = args.filter_meta;
    if (!check_bias_share_in_channel(*(args.bias_layout), param.format))
        return false;
    if (param.format != Format::NCHW64)
        return false;
    UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout),
                                  param);
    // TODO support group conv
    available &= param.sparse == Sparse::DENSE;
    // mode must be cross correlation
    available &= param.mode == Mode::CROSS_CORRELATION;
    // check data type
    auto src_dtype = args.src_layout->dtype,
         filter_dtype = args.filter_layout->dtype,
         bias_dtype = args.bias_layout->dtype,
         dst_dtype = args.dst_layout->dtype;
    available &= (src_dtype.enumv() == DTypeEnum::QuantizedS4 &&
                  filter_dtype.enumv() == DTypeEnum::QuantizedS4 &&
                  bias_dtype.enumv() == DTypeEnum::QuantizedS32 &&
                  dst_dtype.enumv() == DTypeEnum::QuantizedS4);
    // TODO: support dilation
    available &= dh == 1 && dw == 1;
    // ensure precomputed offsets are positive integers
    available &= hi >= fh && wi >= fw;
    // only support sm_75 or later, platform should have tensorcore int8
    // support
    available &= is_compute_capability_required(7, 5);
    // param buffer size is 4K, use 3K to store precomputed offset, fh * fw <=
    // (3*1024/4/2/2) - 1
    available &= fh * fw <= 191;
    return available;
}
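Note: the fh * fw <= 191 bound follows from the offset-table layout above. Each entry is two 32-bit words (8 bytes) and inc_step = 2 entries exist per filter position plus two leading absolute entries, i.e. $16(1 + f_h f_w)$ bytes (matching c_offset_param.size in exec below). With 3 KB of the 4 KB param buffer reserved for offsets:

$$16\,(1 + f_h f_w) \le 3072 \;\Longrightarrow\; f_h f_w \le \frac{3 \cdot 1024}{4 \cdot 2 \cdot 2} - 1 = 191.$$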
size_t ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::
        get_workspace_in_bytes(const SizeArgs& args) const {
    if (args.preprocessed_filter == nullptr) {
        return args.filter_layout->span().dist_byte() +
               args.bias_layout->span().dist_byte();
    }
    return 0_z;
}
void ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::exec(
        const ExecArgs& args) const {
#if MEGDNN_TEGRA_X1
    megdnn_throw("sass kernel is disabled at compile time for TX1");
#else
    using Format = Param::Format;
    auto&& param = args.opr->param();
    auto&& fm = args.filter_meta;
    UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout),
                                  param);
    auto&& stream = cuda_stream(args.opr->handle());
    constexpr int interleaved = 64;

    void* bias_ptr = nullptr;
    void* filter_ptr = nullptr;
    if (args.preprocessed_filter) {
        megdnn_assert(args.preprocessed_filter->tensors.size() == 2);
        filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr;
        bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr;
    } else {
        // reorder filter and bias
        filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr);
        bias_ptr = reinterpret_cast<void*>(
                args.workspace.raw_ptr + args.filter_layout->span().dist_byte());
        reorder_imma_filter_bias<4, 64>(
                reinterpret_cast<int8_t*>(filter_ptr),
                reinterpret_cast<int32_t*>(bias_ptr),
                args.filter_tensor->compatible_ptr<int8_t>(),
                args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw,
                stream);
    }

    uint32_t u32_n = n, u32_ci = ci, u32_hi = hi, u32_wi = wi, u32_fh = fh,
             u32_fw = fw, u32_sh = sh, u32_sw = sw, u32_ph = ph, u32_pw = pw,
             u32_co = co, u32_ho = ho, u32_wo = wo;
    Conv2dInt4Param kern_param(u32_n, u32_ci, u32_hi, u32_wi, u32_fh, u32_fw,
                               u32_sh, u32_sw, u32_ph, u32_pw, u32_co, u32_ho,
                               u32_wo, interleaved);

    Conv2dConstantOffset kern_coffset;
    compute_conv2d_offset(fh, fw, kern_param.ics, kern_param.ihs, kern_coffset);
    // The starting address of Turing param buffer is c[0x0][0x160]
    kern_coffset.c_offset_param.begin = param_buffer_start_address();
    kern_coffset.c_offset_param.size = 16 * (1 + fh * fw);
    kern_coffset.c_offset_param.max = 16 * fh * fw;
    kern_coffset.c_offset_param.rewind = 16 * (1 - fh * fw);

    auto kern_key = kernel_key(args);
    float src_scale = args.src_layout->dtype.param<dtype::QuantizedS4>().scale,
          filter_scale =
                  args.filter_layout->dtype.param<dtype::QuantizedS4>().scale,
          bias_scale =
                  args.bias_layout->dtype.param<dtype::QuantizedS32>().scale,
          dst_scale = args.dst_layout->dtype.param<dtype::QuantizedS4>().scale;
    float alpha = src_scale * filter_scale / dst_scale,
          beta = bias_scale / dst_scale;
    float inv_dst_scale = 1.f / dst_scale;
    unsigned int tx = m_threads, ty = 1;
    unsigned int gridx = div_ceil<unsigned int>(
            static_cast<unsigned int>(n * ho * wo), m_tile_nhw);
    unsigned int gridy =
            div_ceil<unsigned int>(static_cast<unsigned int>(co), m_tile_oc);
    void* src_ptr = const_cast<void*>(args.src_tensor->raw_ptr);
    void* dst_ptr = const_cast<void*>(args.dst_tensor->raw_ptr);

    using NonlineMode = Param::NonlineMode;
    auto&& kernel = SASSKernelLoader::instance().get_kernel(kern_key, kern_key);
    if (args.z_layout->ndim > 0) {
        void* z_ptr = const_cast<void*>(args.z_tensor->raw_ptr);
        float z_scale = args.z_layout->dtype.param<dtype::QuantizedS4>().scale;
        float gamma = z_scale / dst_scale;
        std::vector<void*> params = {&src_ptr, &filter_ptr, &bias_ptr, &z_ptr,
                                     &dst_ptr, &alpha,      &beta,     &gamma};
        kern_coffset.c_offset_param.begin +=
                sizeof(src_ptr) + sizeof(filter_ptr) + sizeof(bias_ptr) +
                sizeof(z_ptr) + sizeof(dst_ptr) + sizeof(alpha) +
                sizeof(beta) + sizeof(gamma);
        uint32_t relu = param.nonlineMode == NonlineMode::RELU ? 1 : 0;
        if (param.nonlineMode == NonlineMode::H_SWISH) {
            params.push_back(&dst_scale);
            params.push_back(&inv_dst_scale);
            kern_coffset.c_offset_param.begin +=
                    sizeof(dst_scale) + sizeof(inv_dst_scale);
        } else {
            params.push_back(&relu);
            kern_coffset.c_offset_param.begin += sizeof(relu);
        }
        params.push_back(&kern_param);
        kern_coffset.c_offset_param.begin += sizeof(kern_param);
        kern_coffset.c_offset_param.begin +=
                sizeof(kern_coffset.c_offset_param);
        kern_coffset.c_offset_param.max += kern_coffset.c_offset_param.begin;
        params.push_back(&kern_coffset);
        cucheck(cuLaunchKernel(kernel, gridx, gridy, 1, tx, ty, 1, 0, stream,
                               params.data(), 0));
    } else {
        std::vector<void*> params = {&src_ptr, &filter_ptr, &bias_ptr,
                                     &dst_ptr, &alpha,      &beta};
        kern_coffset.c_offset_param.begin +=
                sizeof(src_ptr) + sizeof(filter_ptr) + sizeof(bias_ptr) +
                sizeof(dst_ptr) + sizeof(alpha) + sizeof(beta);
        uint32_t relu = param.nonlineMode == NonlineMode::RELU ? 1 : 0;
        if (param.nonlineMode == NonlineMode::H_SWISH) {
            params.push_back(&dst_scale);
            params.push_back(&inv_dst_scale);
            kern_coffset.c_offset_param.begin +=
                    sizeof(dst_scale) + sizeof(inv_dst_scale);
        } else {
            params.push_back(&relu);
            kern_coffset.c_offset_param.begin += sizeof(relu);
        }
        params.push_back(&kern_param);
        kern_coffset.c_offset_param.begin += sizeof(kern_param);
        kern_coffset.c_offset_param.begin +=
                sizeof(kern_coffset.c_offset_param);
        kern_coffset.c_offset_param.max += kern_coffset.c_offset_param.begin;
        params.push_back(&kern_coffset);
        cucheck(cuLaunchKernel(kernel, gridx, gridy, 1, tx, ty, 1, 0, stream,
                               params.data(), 0));
    }
    after_kernel_launch();
#endif
}
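Note: the scale algebra in exec keeps everything in the destination's quantized domain. With the integer accumulator $\mathrm{acc} = \sum \mathrm{src}_q \cdot \mathrm{flt}_q$, the kernel evaluates

$$\mathrm{dst}_q = f_{\mathrm{act}}\!\left(\alpha\,\mathrm{acc} + \beta\,\mathrm{bias}_q + \gamma\,z_q\right), \qquad \alpha = \frac{s_{\mathrm{src}}\,s_{\mathrm{flt}}}{s_{\mathrm{dst}}}, \quad \beta = \frac{s_{\mathrm{bias}}}{s_{\mathrm{dst}}}, \quad \gamma = \frac{s_z}{s_{\mathrm{dst}}},$$

with $\gamma$ present only in the fuse-z variant; dst_scale and inv_dst_scale are passed in addition for h-swish, which needs the real-valued activation before requantizing.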
size_t ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::
        get_preprocess_workspace_in_bytes(const SizeArgs& args) const {
    return 0_z;
}

SmallVector<TensorLayout> ConvBiasForwardImpl::
        AlgoSASSInt4NCHW64IMMAImplicitGemm::deduce_preprocessed_filter_layout(
                const SizeArgs& args) const {
    return {args.filter_layout->collapse_contiguous(),
            args.bias_layout->collapse_contiguous()};
}

void ConvBiasForwardImpl::AlgoSASSInt4NCHW64IMMAImplicitGemm::exec_preprocess(
        const ExecArgs& args) const {
    using Format = Param::Format;
    auto&& param = args.opr->param();
    auto&& fm = args.filter_meta;
    UNPACK_CONV_BIAS_NCHW64_PARAM(*(args.src_layout), fm, *(args.dst_layout),
                                  param);
    auto&& stream = cuda_stream(args.opr->handle());
    reorder_imma_filter_bias<4, 64>(
            args.preprocessed_filter->tensors[0].compatible_ptr<int8_t>(),
            args.preprocessed_filter->tensors[1].compatible_ptr<int32_t>(),
            args.filter_tensor->compatible_ptr<int8_t>(),
            args.bias_tensor->compatible_ptr<int32_t>(), co, ci, fh, fw,
            stream);
}

// vim: syntax=cpp.doxygen