Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
2a3f4d09
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
411
Star
4707
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
2a3f4d09
编写于
8月 03, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(dnn/arm): refactor CPU heuristic algo selection
GitOrigin-RevId: 60d2646bb33316411caa18686eec724dc1f6c430
上级
981f487b
变更
47
隐藏空白更改
内联
并排
Showing
47 changed file
with
856 addition
and
172 deletion
+856
-172
dnn/include/megdnn/oprs/base.h
dnn/include/megdnn/oprs/base.h
+12
-0
dnn/src/aarch64/conv_bias/fp16/algos.h
dnn/src/aarch64/conv_bias/fp16/algos.h
+4
-0
dnn/src/aarch64/conv_bias/fp32/algos.h
dnn/src/aarch64/conv_bias/fp32/algos.h
+4
-0
dnn/src/aarch64/conv_bias/int8/algos.h
dnn/src/aarch64/conv_bias/int8/algos.h
+3
-0
dnn/src/aarch64/conv_bias/opr_impl.cpp
dnn/src/aarch64/conv_bias/opr_impl.cpp
+2
-3
dnn/src/aarch64/conv_bias/quint8/algos.h
dnn/src/aarch64/conv_bias/quint8/algos.h
+3
-0
dnn/src/aarch64/matrix_mul/algos.cpp
dnn/src/aarch64/matrix_mul/algos.cpp
+32
-18
dnn/src/aarch64/matrix_mul/algos.h
dnn/src/aarch64/matrix_mul/algos.h
+4
-4
dnn/src/arm_common/conv_bias/f16/algos.h
dnn/src/arm_common/conv_bias/f16/algos.h
+11
-4
dnn/src/arm_common/conv_bias/fp32/algos.h
dnn/src/arm_common/conv_bias/fp32/algos.h
+26
-8
dnn/src/arm_common/conv_bias/int8/algos.h
dnn/src/arm_common/conv_bias/int8/algos.h
+35
-3
dnn/src/arm_common/conv_bias/int8x8x16/algos.h
dnn/src/arm_common/conv_bias/int8x8x16/algos.h
+18
-0
dnn/src/arm_common/conv_bias/opr_impl.cpp
dnn/src/arm_common/conv_bias/opr_impl.cpp
+92
-15
dnn/src/arm_common/conv_bias/opr_impl.h
dnn/src/arm_common/conv_bias/opr_impl.h
+4
-1
dnn/src/arm_common/conv_bias/quint8/algos.h
dnn/src/arm_common/conv_bias/quint8/algos.h
+12
-0
dnn/src/arm_common/matrix_mul/algos.h
dnn/src/arm_common/matrix_mul/algos.h
+14
-8
dnn/src/arm_common/matrix_mul/opr_impl.cpp
dnn/src/arm_common/matrix_mul/opr_impl.cpp
+4
-3
dnn/src/armv7/conv_bias/int8/algos.h
dnn/src/armv7/conv_bias/int8/algos.h
+3
-0
dnn/src/armv7/conv_bias/quint8/algos.h
dnn/src/armv7/conv_bias/quint8/algos.h
+4
-0
dnn/src/armv7/matrix_mul/algos.cpp
dnn/src/armv7/matrix_mul/algos.cpp
+23
-14
dnn/src/armv7/matrix_mul/algos.h
dnn/src/armv7/matrix_mul/algos.h
+3
-3
dnn/src/armv7/matrix_mul/opr_impl.h
dnn/src/armv7/matrix_mul/opr_impl.h
+0
-1
dnn/src/common/utils.h
dnn/src/common/utils.h
+28
-0
dnn/src/fallback/conv_bias/algos.h
dnn/src/fallback/conv_bias/algos.h
+26
-0
dnn/src/fallback/conv_bias/common.h
dnn/src/fallback/conv_bias/common.h
+4
-1
dnn/src/fallback/conv_bias/conv1x1/algos.cpp
dnn/src/fallback/conv_bias/conv1x1/algos.cpp
+2
-1
dnn/src/fallback/conv_bias/conv1x1/algos.h
dnn/src/fallback/conv_bias/conv1x1/algos.h
+5
-0
dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h
dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h
+10
-0
dnn/src/fallback/conv_bias/im2col/algos.h
dnn/src/fallback/conv_bias/im2col/algos.h
+18
-8
dnn/src/fallback/conv_bias/opr_impl.cpp
dnn/src/fallback/conv_bias/opr_impl.cpp
+90
-14
dnn/src/fallback/conv_bias/opr_impl.h
dnn/src/fallback/conv_bias/opr_impl.h
+19
-0
dnn/src/fallback/convolution/algos.h
dnn/src/fallback/convolution/algos.h
+19
-2
dnn/src/fallback/convolution/opr_impl.cpp
dnn/src/fallback/convolution/opr_impl.cpp
+83
-15
dnn/src/fallback/convolution/opr_impl.h
dnn/src/fallback/convolution/opr_impl.h
+31
-0
dnn/src/fallback/matrix_mul/algos.cpp
dnn/src/fallback/matrix_mul/algos.cpp
+1
-1
dnn/src/fallback/matrix_mul/algos.h
dnn/src/fallback/matrix_mul/algos.h
+9
-1
dnn/src/fallback/matrix_mul/gemm_common.h
dnn/src/fallback/matrix_mul/gemm_common.h
+17
-13
dnn/src/fallback/matrix_mul/opr_impl.cpp
dnn/src/fallback/matrix_mul/opr_impl.cpp
+59
-5
dnn/src/fallback/matrix_mul/opr_impl.h
dnn/src/fallback/matrix_mul/opr_impl.h
+18
-1
dnn/src/x86/conv_bias/f32/algos.h
dnn/src/x86/conv_bias/f32/algos.h
+17
-2
dnn/src/x86/conv_bias/int8/algo_usable_preferred.cpp
dnn/src/x86/conv_bias/int8/algo_usable_preferred.cpp
+0
-2
dnn/src/x86/conv_bias/int8/algos.h
dnn/src/x86/conv_bias/int8/algos.h
+24
-0
dnn/src/x86/conv_bias/opr_impl.cpp
dnn/src/x86/conv_bias/opr_impl.cpp
+39
-2
dnn/src/x86/conv_bias/opr_impl.h
dnn/src/x86/conv_bias/opr_impl.h
+2
-0
dnn/src/x86/matrix_mul/algos.cpp
dnn/src/x86/matrix_mul/algos.cpp
+13
-10
dnn/src/x86/matrix_mul/algos.h
dnn/src/x86/matrix_mul/algos.h
+4
-4
src/opr/impl/dnn/convolution.cpp
src/opr/impl/dnn/convolution.cpp
+5
-5
未找到文件。
dnn/include/megdnn/oprs/base.h
浏览文件 @
2a3f4d09
...
...
@@ -76,6 +76,18 @@ enum class AlgoSelectionStrategy {
FULL_RUN
=
2
,
};
/**
* \brief separate algo by datatype for Matmul and conv
*/
enum
class
AlgoDataType
:
uint32_t
{
FLOAT32
=
1
<<
0
,
FLOAT16
=
1
<<
1
,
QINT8X8X32
=
1
<<
2
,
QUINT8X8X32
=
1
<<
3
,
INT8X8X16
=
1
<<
4
,
INT16X16X32
=
1
<<
5
,
};
/*!
* \brief Abstract representation of an algorithm for implementing
* the operator
...
...
dnn/src/aarch64/conv_bias/fp16/algos.h
浏览文件 @
2a3f4d09
...
...
@@ -27,6 +27,10 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
FLOAT16
,
AlgoCategory
::
DIRECT
};
}
};
}
// namespace aarch64
}
// namespace megdnn
...
...
dnn/src/aarch64/conv_bias/fp32/algos.h
浏览文件 @
2a3f4d09
...
...
@@ -32,6 +32,10 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
FLOAT32
,
AlgoCategory
::
DIRECT
};
}
};
}
// namespace aarch64
...
...
dnn/src/aarch64/conv_bias/int8/algos.h
浏览文件 @
2a3f4d09
...
...
@@ -45,6 +45,9 @@ public:
return
static_cast
<
ConvBiasImpl
*>
(
conv_bias_opr
)
->
is_matmul_quantized_prefer
(
param
);
}
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
QINT8X8X32
,
AlgoCategory
::
IM2COL
};
}
};
}
// namespace aarch64
...
...
dnn/src/aarch64/conv_bias/opr_impl.cpp
浏览文件 @
2a3f4d09
...
...
@@ -50,10 +50,9 @@ SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
auto
&&
algos
=
arm_common
::
ConvBiasImpl
::
algo_pack
();
algos
.
insert
(
algos
.
begin
(),
sl_algo_pack
.
direct_algos
.
begin
(),
sl_algo_pack
.
direct_algos
.
end
());
//! We put matmul algos at the
end
. Because matmul will get privilege when
//! We put matmul algos at the
begin
. Because matmul will get privilege when
//! prefer return true. See
//! fallback::ConvolutionImpl::ncb_1g_get_all_algorithms for more details.
algos
.
insert
(
algos
.
end
(),
sl_algo_pack
.
matmul_algos
.
begin
(),
algos
.
insert
(
algos
.
begin
(),
sl_algo_pack
.
matmul_algos
.
begin
(),
sl_algo_pack
.
matmul_algos
.
end
());
return
std
::
move
(
algos
);
}
...
...
dnn/src/aarch64/conv_bias/quint8/algos.h
浏览文件 @
2a3f4d09
...
...
@@ -45,6 +45,9 @@ public:
return
static_cast
<
ConvBiasImpl
*>
(
conv_bias_opr
)
->
is_matmul_quantized_prefer
(
param
);
}
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
QUINT8X8X32
,
AlgoCategory
::
IM2COL
};
}
};
}
// namespace aarch64
}
// namespace megdnn
...
...
dnn/src/aarch64/matrix_mul/algos.cpp
浏览文件 @
2a3f4d09
...
...
@@ -89,7 +89,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern(
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL
(
AlgoF32K8x12x1
,
megdnn_aarch64_matmul_kern
,
"AlgoF32K8x12x1Impl"
_hash
,
aarch64
::
matmul
::
sgemm_8x12
,
float
,
float
);
aarch64
::
matmul
::
sgemm_8x12
,
float
,
float
,
AlgoDataType
::
FLOAT32
,
DEFAULT
);
/* ===================== F32_MK4_8X12X1 algo ===================== */
bool
MatrixMulImpl
::
AlgoF32MK4_8x12x1
::
usable
(
...
...
@@ -151,7 +152,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32MK4_8x12x1,
megdnn_aarch64_matmul_kern
,
"AlgoF32MK4_8x12x1Impl"
_hash
,
aarch64
::
matmul
::
sgemm_mk4_8x12
,
float
,
float
);
float
,
AlgoDataType
::
FLOAT32
,
MK4
);
/* ===================== F32K4X16X1 algo ===================== */
...
...
@@ -210,7 +211,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K4x16x1::get_kern(
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL
(
AlgoF32K4x16x1
,
megdnn_aarch64_matmul_kern
,
"AlgoF32K4x16x1Impl"
_hash
,
aarch64
::
matmul
::
sgemm_4x16
,
float
,
float
);
aarch64
::
matmul
::
sgemm_4x16
,
float
,
float
,
AlgoDataType
::
FLOAT32
,
MK4
);
/* ===================== F32MK4_4x16 algo ===================== */
...
...
@@ -328,7 +330,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K8x24x1::get_kern(
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL
(
AlgoF16K8x24x1
,
megdnn_aarch64_matmul_kern
,
"AlogF16K8x24x1Impl"
_hash
,
aarch64
::
matmul
::
hgemm_8x24
,
dt_float16
,
dt_float16
);
dt_float16
,
AlgoDataType
::
FLOAT16
,
DEFAULT
);
/* ===================== F16_MK8_8x8 algo ===================== */
bool
MatrixMulImpl
::
AlgoF16MK8_8x8
::
usable
(
...
...
@@ -449,7 +452,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x12x4DotProd,
megdnn_aarch64_matmul_kern
,
"AlgoInt8x8x32K8x12x4DotProdImpl"
_hash
,
aarch64
::
matmul
::
gemm_s8_8x12
,
int8_t
,
int32_t
);
int32_t
,
AlgoDataType
::
QINT8X8X32
,
DEFAULT
);
/* =================== Int8x8x32 MK4 8X12X4 Dotprod algo =================== */
namespace
{
...
...
@@ -520,7 +524,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x12x4DotProd,
megdnn_aarch64_matmul_kern
,
"AlgoInt8x8x32MK4_8x12x4DotProdImpl"
_hash
,
aarch64
::
matmul
::
gemm_mk4_s8_8x12
,
int8_t
,
int32_t
);
int32_t
,
AlgoDataType
::
QINT8X8X32
,
MK4_DOT
);
#else
/* ===================== Int8x8x32 MK4 4x4x16 algo ===================== */
...
...
@@ -593,7 +598,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_4x4x16,
megdnn_aarch64_matmul_kern
,
"AlgoInt8x8x32MK4_4x4x16Impl"
_hash
,
aarch64
::
matmul
::
gemm_mk4_s8_4x4
,
int8_t
,
int32_t
);
int32_t
,
AlgoDataType
::
QINT8X8X32
,
MK4
);
/* ===================== Int8x8x32 K4x4x16 algo ===================== */
namespace
{
...
...
@@ -656,7 +662,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x4x16,
megdnn_aarch64_matmul_kern
,
"AlgoInt8x8x32K4x4x16Impl"
_hash
,
aarch64
::
matmul
::
gemm_s8_4x4
,
int8_t
,
int32_t
);
int32_t
,
AlgoDataType
::
QINT8X8X32
,
DEFAULT
);
/* ===================== Int8x8x32 K8x8x8 algo ===================== */
namespace
{
void
int8x8x32_k8x8x8_kern
(
const
MatrixMulImpl
::
KernParam
&
kern_param
)
{
...
...
@@ -717,7 +724,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x8x8,
megdnn_aarch64_matmul_kern
,
"AlgoInt8x8x32K8x8x8Impl"
_hash
,
aarch64
::
matmul
::
gemm_s8_8x8
,
int8_t
,
int32_t
);
int32_t
,
AlgoDataType
::
QINT8X8X32
,
DEFAULT
);
#endif
/* ===================== Int8x8x16 K8x8x8 algo ===================== */
...
...
@@ -785,7 +793,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K8x8x8,
megdnn_aarch64_matmul_kern
,
"AlgoInt8x8x16K8x8x8Impl"
_hash
,
aarch64
::
matmul
::
gemm_s8x8x16_8x8
,
int8_t
,
int16_t
);
int16_t
,
AlgoDataType
::
INT8X8X16
,
DEFAULT
);
/* ===================== Int8x8x16 K4x4x16 algo ===================== */
namespace
{
void
int8x8x16_k4x4x16_kern
(
const
MatrixMulImpl
::
KernParam
&
kern_param
)
{
...
...
@@ -852,7 +860,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x4x16,
megdnn_aarch64_matmul_kern
,
"AlgoInt8x8x16K4x4x16Impl"
_hash
,
aarch64
::
matmul
::
gemm_s8x8x16_4x4
,
int8_t
,
int16_t
);
int16_t
,
AlgoDataType
::
INT8X8X16
,
DEFAULT
);
/* ===================== Int8x8x16 K16x12x4 algo ===================== */
namespace
{
...
...
@@ -929,7 +937,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_kern(
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL
(
AlgoInt8x8x16MK4_16x12x4
,
megdnn_aarch64_matmul_kern
,
"AlgoInt8x8x16MK4_16x12x4Impl"
_hash
,
aarch64
::
matmul
::
gemm_s8x8x16_mk4_16x12_a53
,
int8_t
,
int16_t
,
int16_t
);
aarch64
::
matmul
::
gemm_s8x8x16_mk4_16x12_a53
,
int8_t
,
int16_t
,
int16_t
,
AlgoDataType
::
INT8X8X16
,
MK4
);
/* ===================== Int8x8x16 MK4 4x4x8 algo ===================== */
namespace
{
...
...
@@ -1007,7 +1016,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_4x4x8,
megdnn_aarch64_matmul_kern
,
"AlgoInt8x8x16MK4_4x4x8_Impl"
_hash
,
aarch64
::
matmul
::
gemm_s8x8x16_mk4_4x4_a72
,
int8_t
,
int16_t
);
int8_t
,
int16_t
,
AlgoDataType
::
INT8X8X16
,
MK4
);
/* ===================== Int16x16x32 K12x8x1 algo ===================== */
namespace
{
...
...
@@ -1078,7 +1088,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt16x16x32K12x8x1,
megdnn_aarch64_matmul_kern
,
"AlgoInt16x16x32K12x8x1Impl"
_hash
,
aarch64
::
matmul
::
gemm_s16_12x8x1
,
int16_t
,
int32_t
);
int32_t
,
AlgoDataType
::
INT16X16X32
,
DEFAULT
);
/* ===================== Int16x16x32MK8_8x8 algo ===================== */
...
...
@@ -1201,7 +1212,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x4DotProd,
megdnn_aarch64_matmul_kern
,
"AlgoQuint8K8x8x4DotProdImpl"
_hash
,
aarch64
::
matmul
::
gemm_u8_8x8
,
uint8_t
,
int32_t
);
int32_t
,
AlgoDataType
::
QUINT8X8X32
,
DEFAULT
);
/* ===================== Quint8 Gemv DotProd algo ===================== */
namespace
{
void
quint8_gemv_dotprod_kern
(
const
MatrixMulImpl
::
KernParam
&
kern_param
)
{
...
...
@@ -1307,7 +1319,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x8,
megdnn_aarch64_matmul_kern
,
"AlgoQuint8K8x8x8Impl"
_hash
,
aarch64
::
matmul
::
gemm_u8_8x8
,
uint8_t
,
int32_t
);
int32_t
,
AlgoDataType
::
QUINT8X8X32
,
DEFAULT
);
#endif
/* ===================== Int8x8x16 K8x8x8 algo ===================== */
...
...
@@ -1378,6 +1391,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::get_kern(
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL
(
AlgoInt8x8x16MK4_K8x8x8
,
megdnn_aarch64_matmul_kern
,
"AlgoInt8x8x16MK4_K8x8x8Impl"
_hash
,
aarch64
::
matmul
::
gemm_s8x8x16_mk4_8x8x8
,
int8_t
,
int16_t
);
aarch64
::
matmul
::
gemm_s8x8x16_mk4_8x8x8
,
int8_t
,
int16_t
,
AlgoDataType
::
INT8X8X16
,
MK4
);
// vim: syntax=cpp.doxygen
dnn/src/aarch64/matrix_mul/algos.h
浏览文件 @
2a3f4d09
...
...
@@ -61,7 +61,7 @@ public:
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
4
,
16
,
4
,
4
)
MEGDNN_OVERRIDE_MATMUL_DESC
(
4
,
16
,
4
,
4
,
AlgoDataType
::
FLOAT32
,
MK4
)
};
class
MatrixMulImpl
::
AlgoF32Gemv
final
...
...
@@ -88,7 +88,7 @@ public:
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
8
,
8
,
2
)
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
8
,
8
,
2
,
AlgoDataType
::
FLOAT16
,
MK8
)
};
#endif
...
...
@@ -253,7 +253,7 @@ public:
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
8
,
8
,
2
)
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
8
,
8
,
2
,
AlgoDataType
::
INT16X16X32
,
MK8
)
};
#if __ARM_FEATURE_DOTPROD
...
...
@@ -281,7 +281,7 @@ public:
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
2
)
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
2
,
AlgoDataType
::
QUINT8X8X32
,
DEFAULT
)
};
#else
...
...
dnn/src/arm_common/conv_bias/f16/algos.h
浏览文件 @
2a3f4d09
...
...
@@ -29,7 +29,7 @@ public:
}
return
m_name
.
c_str
();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
(
AlgoDataType
::
FLOAT16
);
};
class
ConvBiasImpl
::
AlgoFP16WinogradF45
final
:
public
AlgoBase
{
...
...
@@ -44,7 +44,7 @@ public:
}
return
m_name
.
c_str
();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
(
AlgoDataType
::
FLOAT16
);
};
class
ConvBiasImpl
::
AlgoFP16WinogradF63
final
:
public
AlgoBase
{
...
...
@@ -60,7 +60,7 @@ public:
return
m_name
.
c_str
();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
(
AlgoDataType
::
FLOAT16
);
};
class
ConvBiasImpl
::
AlgoFP16WinogradF23_8x8
final
:
public
AlgoBase
{
public:
...
...
@@ -74,7 +74,7 @@ public:
}
return
m_name
.
c_str
();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
(
AlgoDataType
::
FLOAT16
);
};
class
ConvBiasImpl
::
AlgoF16Direct
final
:
public
AlgoBase
{
...
...
@@ -90,6 +90,10 @@ public:
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
FLOAT16
,
AlgoCategory
::
DIRECT
};
}
};
class
ConvBiasImpl
::
AlgoF16DirectStride1
final
:
public
AlgoBase
{
...
...
@@ -103,6 +107,9 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
FLOAT16
,
AlgoCategory
::
DIRECT
};
}
};
}
// namespace arm_common
...
...
dnn/src/arm_common/conv_bias/fp32/algos.h
浏览文件 @
2a3f4d09
...
...
@@ -29,7 +29,7 @@ public:
}
return
m_name
.
c_str
();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
(
AlgoDataType
::
FLOAT32
);
};
class
ConvBiasImpl
::
AlgoFP32WinogradF63
final
:
public
AlgoBase
{
...
...
@@ -44,7 +44,7 @@ public:
}
return
m_name
.
c_str
();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
(
AlgoDataType
::
FLOAT32
);
};
class
ConvBiasImpl
::
AlgoFP32WinogradF63_4x4
final
:
public
AlgoBase
{
...
...
@@ -59,7 +59,7 @@ public:
}
return
m_name
.
c_str
();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
(
AlgoDataType
::
FLOAT32
);
};
class
ConvBiasImpl
::
AlgoFP32WinogradF54
final
:
public
AlgoBase
{
...
...
@@ -74,7 +74,7 @@ public:
}
return
m_name
.
c_str
();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
(
AlgoDataType
::
FLOAT32
);
};
class
ConvBiasImpl
::
AlgoFP32WinogradF45
final
:
public
AlgoBase
{
...
...
@@ -89,7 +89,7 @@ public:
}
return
m_name
.
c_str
();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
(
AlgoDataType
::
FLOAT32
);
};
//===================== NCHW44 Winograd Support =====================//
...
...
@@ -106,7 +106,7 @@ public:
}
return
m_name
.
c_str
();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
(
AlgoDataType
::
FLOAT32
);
};
class
ConvBiasImpl
::
AlgoFP32WinogradF63_4x4_NCHW44
final
:
public
AlgoBase
{
...
...
@@ -122,7 +122,7 @@ public:
}
return
m_name
.
c_str
();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
(
AlgoDataType
::
FLOAT32
);
};
class
ConvBiasImpl
::
AlgoFP32WinogradF73_4x4_NCHW44
final
:
public
AlgoBase
{
...
...
@@ -138,7 +138,7 @@ public:
}
return
m_name
.
c_str
();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
(
AlgoDataType
::
FLOAT32
);
};
// ================================================================= //
...
...
@@ -154,6 +154,9 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
FLOAT32
,
AlgoCategory
::
DIRECT
};
}
};
class
ConvBiasImpl
::
AlgoF32DirectStride1
final
:
public
AlgoBase
{
...
...
@@ -168,6 +171,9 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
FLOAT32
,
AlgoCategory
::
DIRECT
};
}
};
class
ConvBiasImpl
::
AlgoF32DirectStride2
final
:
public
AlgoBase
{
...
...
@@ -182,6 +188,9 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
FLOAT32
,
AlgoCategory
::
DIRECT
};
}
};
class
ConvBiasImpl
::
AlgoF32DirectNCHW44
final
:
public
AlgoBase
{
...
...
@@ -197,6 +206,9 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
FLOAT32
,
AlgoCategory
::
DIRECT
};
}
};
class
ConvBiasImpl
::
AlgoF32DirectNCHWNCHW44
final
:
public
AlgoBase
{
...
...
@@ -212,6 +224,9 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
FLOAT32
,
AlgoCategory
::
DIRECT
};
}
};
class
ConvBiasImpl
::
AlgoF32ChannelWiseNCHW44
final
:
public
AlgoBase
{
...
...
@@ -226,6 +241,9 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
FLOAT32
,
AlgoCategory
::
DIRECT
};
}
};
}
// namespace arm_common
...
...
dnn/src/arm_common/conv_bias/int8/algos.h
浏览文件 @
2a3f4d09
...
...
@@ -29,6 +29,10 @@ public:
const
NCBKernSizeParam
&
param
)
const
override
;
bool
is_preferred
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
QINT8X8X32
,
AlgoCategory
::
DIRECT
};
}
};
class
ConvBiasImpl
::
AlgoS8DirectStride2
final
:
public
AlgoBase
{
...
...
@@ -42,6 +46,9 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
QINT8X8X32
,
AlgoCategory
::
DIRECT
};
}
};
class
ConvBiasImpl
::
AlgoS8DirectNCHW44
final
:
public
AlgoBase
{
...
...
@@ -55,6 +62,9 @@ public:
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
bool
is_preferred
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
QINT8X8X32
,
AlgoCategory
::
DIRECT
};
}
};
class
ConvBiasImpl
::
AlgoS8DirectNCHWNCHW44
final
:
public
AlgoBase
{
...
...
@@ -68,6 +78,9 @@ public:
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
bool
is_preferred
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
QINT8X8X32
,
AlgoCategory
::
DIRECT
};
}
};
class
ConvBiasImpl
::
AlgoS8ChanWiseStride1NCHW44
final
:
public
AlgoBase
{
...
...
@@ -79,6 +92,9 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
QINT8X8X32
,
AlgoCategory
::
DIRECT
};
}
};
class
ConvBiasImpl
::
AlgoS8ChanWiseStride2NCHW44
final
:
public
AlgoBase
{
...
...
@@ -90,6 +106,9 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
QINT8X8X32
,
AlgoCategory
::
DIRECT
};
}
};
#if __ARM_FEATURE_DOTPROD
...
...
@@ -104,6 +123,9 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
QINT8X8X32
,
AlgoCategory
::
DIRECT
};
}
};
class
ConvBiasImpl
::
AlgoDotS8DirectStride1
final
:
public
AlgoBase
{
...
...
@@ -117,6 +139,9 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
QINT8X8X32
,
AlgoCategory
::
DIRECT
};
}
};
class
ConvBiasImpl
::
AlgoDotS8DirectStride2
final
:
public
AlgoBase
{
...
...
@@ -131,6 +156,9 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
QINT8X8X32
,
AlgoCategory
::
DIRECT
};
}
};
class
ConvBiasImpl
::
AlgoDotS8Direct_NCHW44
final
:
public
AlgoBase
{
...
...
@@ -148,6 +176,10 @@ public:
const
NCBKernSizeParam
&
param
)
const
override
;
bool
is_preferred
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
QINT8X8X32
,
AlgoCategory
::
DIRECT
};
}
};
#endif
...
...
@@ -163,7 +195,7 @@ public:
}
return
m_name
.
c_str
();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
(
AlgoDataType
::
QINT8X8X32
);
};
//=======================input int8 compute fp32 output int8============
...
...
@@ -180,7 +212,7 @@ public:
}
return
m_name
.
c_str
();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
(
AlgoDataType
::
QINT8X8X32
);
};
//=======================input int8 compute int16 output int8============
...
...
@@ -198,7 +230,7 @@ public:
return
m_name
.
c_str
();
}
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
(
AlgoDataType
::
QINT8X8X32
);
};
}
// namespace arm_common
...
...
dnn/src/arm_common/conv_bias/int8x8x16/algos.h
浏览文件 @
2a3f4d09
...
...
@@ -36,6 +36,9 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
INT8X8X16
,
AlgoCategory
::
DIRECT
};
}
};
class
ConvBiasImpl
::
AlgoS8x8x16DirectNCHW44
final
:
public
AlgoBase
{
...
...
@@ -48,6 +51,9 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
INT8X8X16
,
AlgoCategory
::
DIRECT
};
}
};
class
ConvBiasImpl
::
AlgoI8x8x16Stride2
final
:
public
AlgoBase
{
...
...
@@ -71,6 +77,9 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
INT8X8X16
,
AlgoCategory
::
DIRECT
};
}
};
class
ConvBiasImpl
::
AlgoI8x8x16Stride2Filter2
final
:
public
AlgoBase
{
...
...
@@ -84,6 +93,9 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
INT8X8X16
,
AlgoCategory
::
DIRECT
};
}
};
class
ConvBiasImpl
::
AlgoS8x8x16ChanWiseStride1Stride2NCHW44
final
:
public
AlgoBase
{
...
...
@@ -96,6 +108,9 @@ public:
const
NCBKernSizeParam
&
param
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
INT8X8X16
,
AlgoCategory
::
DIRECT
};
}
};
class
ConvBiasImpl
::
AlgoI8x8x16DirectNCHWNCHW44
final
:
public
AlgoBase
{
...
...
@@ -111,6 +126,9 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
INT8X8X16
,
AlgoCategory
::
DIRECT
};
}
};
}
// namespace arm_common
...
...
dnn/src/arm_common/conv_bias/opr_impl.cpp
浏览文件 @
2a3f4d09
...
...
@@ -10,6 +10,7 @@
* implied.
*/
#include "megdnn/opr_param_defs.h"
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8x8x16/algos.h"
#include "src/arm_common/conv_bias/quint8/algos.h"
...
...
@@ -122,9 +123,11 @@ public:
static
CpuOprDelegationStorage
<
2
>
storage
;
auto
matmul_opr
=
storage
.
get
<
MatrixMul
,
0
>
();
using
MatmulFormat
=
param
::
MatrixMul
::
Format
;
auto
&&
matmul_algos
=
static_cast
<
arm_common
::
MatrixMulImpl
*>
(
matmul_opr
)
->
algo_pack
();
->
select_algo_type
(
{
AlgoDataType
::
FLOAT32
,
MatmulFormat
::
MK4
});
for
(
auto
&&
algo
:
matmul_algos
)
{
if
(
algo
->
type
()
==
nullptr
)
continue
;
...
...
@@ -133,38 +136,62 @@ public:
static_cast
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
(
algo
),
tile_size
));
winograd_algos
.
emplace_back
(
refhold
.
back
().
get
());
refhold
.
emplace_back
(
new
AlgoFP32WinogradF63
(
refhold
.
emplace_back
(
new
AlgoFP32WinogradF63
_4x4
(
static_cast
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
(
algo
),
tile_size
));
winograd_algos
.
emplace_back
(
refhold
.
back
().
get
());
refhold
.
emplace_back
(
new
AlgoFP32WinogradF63_4x4
(
refhold
.
emplace_back
(
new
AlgoFP32WinogradF63_4x4
_NCHW44
(
static_cast
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
(
algo
),
tile_size
));
winograd_algos
.
emplace_back
(
refhold
.
back
().
get
());
refhold
.
emplace_back
(
new
AlgoFP32WinogradF
5
4
(
refhold
.
emplace_back
(
new
AlgoFP32WinogradF
23_4x4_NCHW4
4
(
static_cast
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
(
algo
),
tile_size
));
winograd_algos
.
emplace_back
(
refhold
.
back
().
get
());
refhold
.
emplace_back
(
new
AlgoFP32WinogradF45
(
//! uncomment this when low precision mode is done
#if 0
refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44(
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
tile_size));
winograd_algos.emplace_back(refhold.back().get());
refhold
.
emplace_back
(
new
AlgoFP32WinogradF23_4x4_NCHW44
(
#endif
//! Qint8x8x32 winograd compute with fp32
refhold
.
emplace_back
(
new
AlgoS8CF32WinogradF23_4x4_NCHW44
(
static_cast
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
(
algo
),
tile_size
));
winograd_algos
.
emplace_back
(
refhold
.
back
().
get
());
refhold
.
emplace_back
(
new
AlgoFP32WinogradF63_4x4_NCHW44
(
}
}
matmul_algos
=
static_cast
<
arm_common
::
MatrixMulImpl
*>
(
matmul_opr
)
->
select_algo_type
({
AlgoDataType
::
FLOAT32
,
MatmulFormat
::
DEFAULT
});
for
(
auto
&&
algo
:
matmul_algos
)
{
if
(
algo
->
type
()
==
nullptr
)
continue
;
for
(
uint32_t
tile_size
:
{
16
,
8
,
24
,
32
})
{
refhold
.
emplace_back
(
new
AlgoFP32WinogradF63
(
static_cast
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
(
algo
),
tile_size
));
winograd_algos
.
emplace_back
(
refhold
.
back
().
get
());
//! uncomment this when low precision mode is done
#if 0
refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44(
refhold
.
emplace_back
(
new
AlgoFP32WinogradF54
(
static_cast
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
(
algo
),
tile_size
));
winograd_algos
.
emplace_back
(
refhold
.
back
().
get
());
#endif
refhold
.
emplace_back
(
new
AlgoFP32WinogradF45
(
static_cast
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
(
algo
),
tile_size
));
winograd_algos
.
emplace_back
(
refhold
.
back
().
get
());
}
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
matmul_algos
=
static_cast
<
arm_common
::
MatrixMulImpl
*>
(
matmul_opr
)
->
select_algo_type
({
AlgoDataType
::
FLOAT16
,
MatmulFormat
::
DEFAULT
});
for
(
auto
&&
algo
:
matmul_algos
)
{
if
(
algo
->
type
()
==
nullptr
)
continue
;
for
(
uint32_t
tile_size
:
{
16
,
8
,
24
,
32
})
{
refhold
.
emplace_back
(
new
AlgoFP16WinogradF23
(
static_cast
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
(
algo
),
tile_size
));
...
...
@@ -177,19 +204,33 @@ public:
static_cast
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
(
algo
),
tile_size
));
winograd_algos
.
emplace_back
(
refhold
.
back
().
get
());
}
}
matmul_algos
=
static_cast
<
arm_common
::
MatrixMulImpl
*>
(
matmul_opr
)
->
select_algo_type
({
AlgoDataType
::
FLOAT16
,
MatmulFormat
::
MK8
});
for
(
auto
&&
algo
:
matmul_algos
)
{
if
(
algo
->
type
()
==
nullptr
)
continue
;
for
(
uint32_t
tile_size
:
{
16
,
8
,
24
,
32
})
{
refhold
.
emplace_back
(
new
AlgoFP16WinogradF23_8x8
(
static_cast
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
(
algo
),
tile_size
));
winograd_algos
.
emplace_back
(
refhold
.
back
().
get
());
}
}
#endif
matmul_algos
=
static_cast
<
arm_common
::
MatrixMulImpl
*>
(
matmul_opr
)
->
select_algo_type
({
AlgoDataType
::
INT16X16X32
,
MatmulFormat
::
MK8
});
for
(
auto
&&
algo
:
matmul_algos
)
{
if
(
algo
->
type
()
==
nullptr
)
continue
;
for
(
uint32_t
tile_size
:
{
16
,
8
,
24
,
32
})
{
refhold
.
emplace_back
(
new
AlgoS8WinogradF23_8x8
(
static_cast
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
(
algo
),
tile_size
));
winograd_algos
.
emplace_back
(
refhold
.
back
().
get
());
refhold
.
emplace_back
(
new
AlgoS8CF32WinogradF23_4x4_NCHW44
(
static_cast
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
(
algo
),
tile_size
));
winograd_algos
.
emplace_back
(
refhold
.
back
().
get
());
refhold
.
emplace_back
(
new
AlgoS8WinogradF23_8x8_NCHW44
(
static_cast
<
fallback
::
MatrixMulImpl
::
AlgoBase
*>
(
algo
),
tile_size
));
...
...
@@ -240,6 +281,42 @@ bool ConvBiasImpl::is_matmul_quantized_prefer(
return
conv_direct_unusable
;
}
/*!
 * \brief suggest the order in which algo categories should be tried
 *
 * Winograd filter formats yield only the WINOGRAD category. Otherwise the
 * result contains IM2COL, DIRECT and NAIVE, with IM2COL first when im2col is
 * preferred and DIRECT first when it is not; NAIVE is always last.
 *
 * \param param kernel size param describing the convolution
 * \return algo categories, most preferred first
 */
SmallVector<AlgoCategory> ConvBiasImpl::suggest_algo_category_order(
        const NCBKernSizeParam& param) const {
    using Format = param::ConvBias::Format;
    const auto& meta = param.filter_meta;
    const auto layout = meta.format;

    //! TODO: now winograd only support fast-run
    const bool is_winograd = layout == Format::NCHW_WINOGRAD ||
                             layout == Format::NCHW44_WINOGRAD ||
                             layout == Format::NCHW88_WINOGRAD;
    if (is_winograd) {
        return {AlgoCategory::WINOGRAD};
    }

    const auto in_channels = meta.icpg;
    const auto out_channels = meta.ocpg;
    const auto filter_h = meta.spatial[0];
    const auto filter_w = meta.spatial[1];

    //! quantized algo use matmul when direct algo is unusable; other source
    //! types prefer im2col once either channel count per group reaches 32
    bool prefer_im2col =
            param.src_type.category() == DTypeCategory::QUANTIZED
                    ? is_matmul_quantized_prefer(param)
                    : (in_channels >= 32 || out_channels >= 32);

    //! conv1x1 always prefers im2col
    if (filter_h == 1 && filter_w == 1) {
        prefer_im2col = true;
    }

    //! nchw44 and nchw44-dot with fewer than 4 input channels per group
    //! (the hybrid mode, per the original comment) go to direct
    const bool is_nchw44 =
            layout == Format::NCHW44 || layout == Format::NCHW44_DOT;
    if (is_nchw44 && in_channels < 4) {
        prefer_im2col = false;
    }

    if (!prefer_im2col) {
        return {AlgoCategory::DIRECT, AlgoCategory::IM2COL,
                AlgoCategory::NAIVE};
    }
    return {AlgoCategory::IM2COL, AlgoCategory::DIRECT, AlgoCategory::NAIVE};
}
const
char
*
ConvBiasImpl
::
get_algorithm_set_name
()
const
{
// arm common version 0
return
"AC0"
;
...
...
dnn/src/arm_common/conv_bias/opr_impl.h
浏览文件 @
2a3f4d09
...
...
@@ -28,6 +28,9 @@ public:
bool
is_matmul_quantized_prefer
(
const
ConvBiasImpl
::
NCBKernSizeParam
&
ncb_param
)
const
override
;
SmallVector
<
AlgoCategory
>
suggest_algo_category_order
(
const
NCBKernSizeParam
&
param
)
const
override
;
class
AlgoPack
;
protected:
...
...
@@ -90,7 +93,7 @@ private:
class
AlgoF16Direct
;
class
AlgoF16DirectStride1
;
#endif
};
};
}
// namespace arm_common
}
// namespace megdnn
...
...
dnn/src/arm_common/conv_bias/quint8/algos.h
浏览文件 @
2a3f4d09
...
...
@@ -29,6 +29,9 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
QUINT8X8X32
,
AlgoCategory
::
DIRECT
};
}
};
class
ConvBiasImpl
::
AlgoQU8DirectStride2
final
:
public
AlgoBase
{
...
...
@@ -42,6 +45,9 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
QUINT8X8X32
,
AlgoCategory
::
DIRECT
};
}
};
#if __ARM_FEATURE_DOTPROD
class
ConvBiasImpl
::
AlgoDotU8DirectStride1
final
:
public
AlgoBase
{
...
...
@@ -56,6 +62,9 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
QUINT8X8X32
,
AlgoCategory
::
DIRECT
};
}
};
class
ConvBiasImpl
::
AlgoDotU8DirectStride2
final
:
public
AlgoBase
{
...
...
@@ -69,6 +78,9 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
virtual
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
QUINT8X8X32
,
AlgoCategory
::
DIRECT
};
}
};
#endif
}
// namespace arm_common
...
...
dnn/src/arm_common/matrix_mul/algos.h
浏览文件 @
2a3f4d09
...
...
@@ -26,7 +26,7 @@ public:
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
4
)
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
4
,
AlgoDataType
::
INT8X8X16
,
DEFAULT
)
};
class
MatrixMulImpl
::
AlgoInt8x8x32Gemv
:
public
AlgoBase
{
...
...
@@ -40,7 +40,7 @@ public:
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
2
)
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
2
,
AlgoDataType
::
QINT8X8X32
,
DEFAULT
)
};
class
MatrixMulImpl
::
AlgoInt8x8x32GemvMK4
:
public
AlgoBase
{
...
...
@@ -54,7 +54,7 @@ public:
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
2
)
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
2
,
AlgoDataType
::
QINT8X8X32
,
MK4
)
};
#if __ARM_FEATURE_DOTPROD
...
...
@@ -69,7 +69,7 @@ public:
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
2
)
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
2
,
AlgoDataType
::
QINT8X8X32
,
MK4_DOT
)
};
#endif
...
...
@@ -87,7 +87,7 @@ public:
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
4
)
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
4
,
AlgoDataType
::
FLOAT32
,
DEFAULT
)
};
class
MatrixMulImpl
::
AlgoF32GemvMK4
:
public
AlgoBase
{
...
...
@@ -101,7 +101,7 @@ public:
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
4
,
1
,
1
,
4
)
MEGDNN_OVERRIDE_MATMUL_DESC
(
4
,
1
,
1
,
4
,
AlgoDataType
::
FLOAT32
,
MK4
)
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
...
...
@@ -116,7 +116,7 @@ public:
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
2
)
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
2
,
AlgoDataType
::
FLOAT16
,
DEFAULT
)
};
#endif
...
...
@@ -131,7 +131,13 @@ public:
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
1
,
1
,
1
,
4
)
MEGDNN_OVERRIDE_MATMUL_DESC
(
1
,
1
,
1
,
4
,
static_cast
<
AlgoDataType
>
(
static_cast
<
uint32_t
>
(
AlgoDataType
::
FLOAT16
)
|
static_cast
<
uint32_t
>
(
AlgoDataType
::
FLOAT32
)
|
static_cast
<
uint32_t
>
(
AlgoDataType
::
QINT8X8X32
)),
DEFAULT
)
};
}
// namespace arm_common
...
...
dnn/src/arm_common/matrix_mul/opr_impl.cpp
浏览文件 @
2a3f4d09
...
...
@@ -25,7 +25,7 @@ void* const MatrixMulImpl::sm_arm_common_algo_type =
class
MatrixMulImpl
::
AlgoPack
:
NonCopyableObj
{
AlgoInt8x8x16
int8x8x16
;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
AlgoF16Gemv
f16gemv
;
AlgoF16Gemv
f16gemv
;
#endif
AlgoInt8x8x32Gemv
int8x8x32_gemv
;
AlgoInt8x8x32GemvMK4
int8x8x32_gemv_mk4
;
...
...
@@ -34,10 +34,11 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#endif
AlgoGevm
gevm
;
AlgoF32GemvMK4
f32_gemv_mk4
;
public:
AlgoPack
()
{
all_algos
.
emplace_back
(
&
int8x8x16
);
#if
__ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
all_algos
.
emplace_back
(
&
f16gemv
);
#endif
#if __ARM_FEATURE_DOTPROD
...
...
@@ -47,7 +48,7 @@ public:
all_algos
.
emplace_back
(
&
int8x8x32_gemv_mk4
);
all_algos
.
emplace_back
(
&
f32_gemv_mk4
);
all_algos
.
emplace_back
(
&
gevm
);
}
}
SmallVector
<
AlgoBase
*>
all_algos
;
};
...
...
dnn/src/armv7/conv_bias/int8/algos.h
浏览文件 @
2a3f4d09
...
...
@@ -37,6 +37,9 @@ public:
size_t
group
=
param
.
filter_meta
.
group
;
return
{{
kimpl
,
{
group
,
1
_z
,
1
_z
}}};
}
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
QINT8X8X32
,
AlgoCategory
::
IM2COL
};
}
};
}
// namespace armv7
...
...
dnn/src/armv7/conv_bias/quint8/algos.h
浏览文件 @
2a3f4d09
...
...
@@ -38,6 +38,10 @@ public:
size_t
group
=
param
.
filter_meta
.
group
;
return
{{
kimpl
,
{
group
,
1
_z
,
1
_z
}}};
}
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
QUINT8X8X32
,
AlgoCategory
::
IM2COL
};
}
};
}
// namespace armv7
...
...
dnn/src/armv7/matrix_mul/algos.cpp
浏览文件 @
2a3f4d09
...
...
@@ -85,7 +85,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32::get_kern(
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL
(
AlgoF32
,
megdnn_armv7_matmul_kern
,
"AlgoF32Impl"
_hash
,
armv7
::
matmul
::
sgemm_4x12
,
float
,
float
);
armv7
::
matmul
::
sgemm_4x12
,
float
,
float
,
AlgoDataType
::
FLOAT32
,
DEFAULT
);
/* ===================== F32 algo mk4 K4x12 ===================== */
...
...
@@ -154,7 +155,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32MK4Pack4x12,
megdnn_armv7_matmul_kern
,
"AlgoF32MK4Pack4x12"
_hash
,
armv7
::
matmul
::
sgemm_mk4_pack_4x12
,
float
,
float
);
float
,
AlgoDataType
::
FLOAT32
,
MK4
);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/* ===================== F16 K4x16x1 algo ===================== */
...
...
@@ -215,7 +216,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K4x16x1::get_kern(
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL
(
AlgoF16K4x16x1
,
megdnn_armv7_matmul_kern
,
"AlgoF16K4x16x1"
_hash
,
armv7
::
matmul
::
hgemm_4x16
,
dt_float16
,
dt_float16
);
dt_float16
,
AlgoDataType
::
FLOAT16
,
DEFAULT
);
#endif
...
...
@@ -280,7 +282,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x2x16,
megdnn_armv7_matmul_kern
,
"AlgoInt8x8x32K4x2x16"
_hash
,
armv7
::
matmul
::
gemm_s8_4x2
,
int8_t
,
int32_t
);
int32_t
,
AlgoDataType
::
QINT8X8X32
,
DEFAULT
);
/* ===================== Int8x8x32 Kernel 4x8x8 algo ===================== */
namespace
{
...
...
@@ -342,7 +345,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x8x8,
megdnn_armv7_matmul_kern
,
"AlgoInt8x8x32K4x8x8"
_hash
,
armv7
::
matmul
::
gemm_s8_4x8
,
int8_t
,
int32_t
);
int32_t
,
AlgoDataType
::
QINT8X8X32
,
DEFAULT
);
/* ===================== Quint8 Kernel 4x8x8 algo ===================== */
namespace
{
...
...
@@ -402,7 +406,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K4x8x8::get_kern(
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL
(
AlgoQuint8K4x8x8
,
megdnn_armv7_matmul_kern
,
"AlgoQuint8K4x8x8"
_hash
,
armv7
::
matmul
::
gemm_u8_4x8
,
uint8_t
,
int32_t
);
int32_t
,
AlgoDataType
::
QUINT8X8X32
,
DEFAULT
);
/* ===================== Int8x8x16 Kernel 2x4x16 algo ===================== */
namespace
{
...
...
@@ -468,7 +473,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x2x16,
megdnn_armv7_matmul_kern
,
"AlgoInt8x8x16K4x2x16"
_hash
,
armv7
::
matmul
::
gemm_s8x8x16_4x2
,
int8_t
,
int16_t
);
int16_t
,
AlgoDataType
::
INT8X8X16
,
DEFAULT
);
/* ===================== Int8x8x16 Kernel 4x8x8 algo ===================== */
namespace
{
...
...
@@ -534,7 +539,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x8x8,
megdnn_armv7_matmul_kern
,
"AlgoInt8x8x16K4x8x8"
_hash
,
armv7
::
matmul
::
gemm_s8x8x16_4x8
,
int8_t
,
int16_t
);
int16_t
,
AlgoDataType
::
INT8X8X16
,
DEFAULT
);
/* =================== Int8x8x16 Kernel MK4 8x8x4 algo ===================*/
...
...
@@ -602,7 +607,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x16MK4_8x8x4,
megdnn_armv7_matmul_kern
,
"AlgoInt8x8x16MK4_8x8x4"
_hash
,
armv7
::
matmul
::
gemm_s8x8x16_mk4_8x8
,
int8_t
,
int16_t
,
int16_t
);
int8_t
,
int16_t
,
int16_t
,
AlgoDataType
::
INT8X8X16
,
MK4
);
/* ===================== Int16x16x32 Kernel 12x4x1 algo ===================== */
...
...
@@ -668,7 +674,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt16x16x32K12x4x1,
megdnn_armv7_matmul_kern
,
"AlgoInt16x16x32K12x4x1"
_hash
,
armv7
::
matmul
::
gemm_s16x16x32_12x4
,
int16_t
,
int32_t
);
int16_t
,
int32_t
,
AlgoDataType
::
INT16X16X32
,
DEFAULT
);
#if __ARM_FEATURE_DOTPROD
/* ===================== Int8 K6x8x4 algo ===================== */
namespace
{
...
...
@@ -724,7 +731,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K6x8x4,
megdnn_armv7_matmul_kern
,
"AlgoInt8x8x32K6x8x4"
_hash
,
armv7
::
matmul
::
gemm_dots8_6x8
,
int8_t
,
int32_t
);
int32_t
,
AlgoDataType
::
QINT8X8X32
,
DEFAULT
);
/* ===================== Quint8 K4x8x4 algo ===================== */
namespace
{
void
quint8_dot_k4x8x4_kern
(
const
MatrixMulImpl
::
KernParam
&
kern_param
)
{
...
...
@@ -786,7 +794,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8DotK4x8x4,
megdnn_armv7_matmul_kern
,
"AlgoQuint8DotK4x8x4"
_hash
,
armv7
::
matmul
::
gemm_dot_quint8_4x8
,
uint8_t
,
int32_t
);
uint8_t
,
int32_t
,
AlgoDataType
::
QUINT8X8X32
,
DEFAULT
);
/* ======================== Int8 MK4 8x4x4 dot algo ======================== */
namespace
{
...
...
@@ -854,7 +863,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x4x4DotProd,
megdnn_armv7_matmul_kern
,
"AlgoInt8x8x32MK4_8x4x4DotProd"
_hash
,
armv7
::
matmul
::
gemm_mk4_dots8_8x4
,
int8_t
,
int32_t
);
int32_t
,
AlgoDataType
::
QINT8X8X32
,
MK4_DOT
);
#endif
/* ===================== F32 algo K4x8 ===================== */
...
...
@@ -1099,6 +1108,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_4x2x16,
megdnn_armv7_matmul_kern
,
"AlgoInt8x8x32MK4_4x2x16"
_hash
,
armv7
::
matmul
::
gemm_mk4_s8_4x2
,
int8_t
,
int32_t
);
int32_t
,
AlgoDataType
::
QINT8X8X32
,
MK4
);
// vim: syntax=cpp.doxygen
dnn/src/armv7/matrix_mul/algos.h
浏览文件 @
2a3f4d09
...
...
@@ -50,7 +50,7 @@ public:
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
4
,
8
,
4
,
4
)
MEGDNN_OVERRIDE_MATMUL_DESC
(
4
,
8
,
4
,
4
,
AlgoDataType
::
FLOAT32
,
MK4
)
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
...
...
@@ -73,7 +73,7 @@ public:
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
4
,
8
,
8
,
2
)
MEGDNN_OVERRIDE_MATMUL_DESC
(
4
,
8
,
8
,
2
,
AlgoDataType
::
FLOAT16
,
MK8
)
};
#endif
#if __ARM_FEATURE_DOTPROD
...
...
@@ -205,7 +205,7 @@ public:
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
4
,
8
,
8
,
2
)
MEGDNN_OVERRIDE_MATMUL_DESC
(
4
,
8
,
8
,
2
,
AlgoDataType
::
INT16X16X32
,
MK8
)
};
class
MatrixMulImpl
::
AlgoInt8x8x32MK4_4x2x16
final
:
public
AlgoBase
{
...
...
dnn/src/armv7/matrix_mul/opr_impl.h
浏览文件 @
2a3f4d09
...
...
@@ -18,7 +18,6 @@ namespace armv7 {
class
MatrixMulImpl
:
public
arm_common
::
MatrixMulImpl
{
public:
using
arm_common
::
MatrixMulImpl
::
MatrixMulImpl
;
SmallVector
<
AlgoBase
*>
algo_pack
()
override
;
private:
...
...
dnn/src/common/utils.h
浏览文件 @
2a3f4d09
...
...
@@ -110,6 +110,11 @@ void __log__(LogLevel level, const char* file, const char* func, int line,
} while (0)
#endif // megdnn_ENABLE_LOGGING
/*!
 * \brief convert \p value to int32_t through a static_cast
 *
 * The function is constexpr, so it can be used in constant expressions.
 */
template <typename T>
constexpr auto cast_int(T value) -> int32_t {
    return static_cast<int32_t>(value);
}
/* helper functions */
/**
* \brief Get the next `stride' index lexicographically.
...
...
@@ -187,6 +192,29 @@ std::unique_ptr<T> make_unique(Args&&... args) {
return
std
::
unique_ptr
<
T
>
(
new
T
(
std
::
forward
<
Args
>
(
args
)...));
}
/*!
 * \brief check whether the \p source enum contains the \p target data type
 *
 * Both enums are treated as uint32_t bit masks.
 *
 * \return true when the two masks share at least one set bit
 */
inline bool contain_data_type(detail::AlgoDataType source,
                              detail::AlgoDataType target) {
    const uint32_t source_bits = static_cast<uint32_t>(source);
    const uint32_t target_bits = static_cast<uint32_t>(target);
    return (source_bits & target_bits) != 0;
}
/*!
 * \brief count how many data types the enum \p index contains
 *
 * \p index is converted to uint32_t and the number of set bits is returned.
 */
template <typename T>
size_t nr_type_contain(T index) {
    size_t set_bits = 0;
    //! every step clears the lowest set bit, so the loop body runs once per
    //! set bit
    for (uint32_t remaining = static_cast<uint32_t>(index); remaining != 0;
         remaining &= (remaining - 1)) {
        ++set_bits;
    }
    return set_bits;
}
/**
* \brief Aligned workspace bundle.
*
...
...
dnn/src/fallback/conv_bias/algos.h
浏览文件 @
2a3f4d09
...
...
@@ -26,6 +26,16 @@ public:
AlgoSelectionStrategy
algo_selection_strategy
)
const
override
;
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
)
const
override
;
//! get the type of the algo: the data type is the bitwise OR of FLOAT16,
//! FLOAT32, INT8X8X16, QINT8X8X32 and QUINT8X8X32, the category is NAIVE
ConvAlgoTypePack get_algo_type() const override {
    const auto bit = [](AlgoDataType type) {
        return static_cast<uint32_t>(type);
    };
    const uint32_t mask = bit(AlgoDataType::FLOAT16) |
                          bit(AlgoDataType::FLOAT32) |
                          bit(AlgoDataType::INT8X8X16) |
                          bit(AlgoDataType::QINT8X8X32) |
                          bit(AlgoDataType::QUINT8X8X32);
    return {static_cast<AlgoDataType>(mask), AlgoCategory::NAIVE};
}
};
class
ConvBiasImpl
::
AlgoWinogradF32
final
:
public
AlgoBase
{
...
...
@@ -46,6 +56,10 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
FLOAT32
,
AlgoCategory
::
WINOGRAD
};
}
private:
MatrixMulImpl
::
AlgoBase
*
m_matmul_algo
;
mutable
std
::
string
m_name
;
...
...
@@ -70,6 +84,10 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
FLOAT32
,
AlgoCategory
::
WINOGRAD
};
}
private:
MatrixMulImpl
::
AlgoBase
*
m_matmul_algo
;
mutable
std
::
string
m_name
;
...
...
@@ -94,6 +112,10 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
QINT8X8X32
,
AlgoCategory
::
WINOGRAD
};
}
private:
MatrixMulImpl
::
AlgoBase
*
m_matmul_algo
;
mutable
std
::
string
m_name
;
...
...
@@ -118,6 +140,10 @@ public:
size_t
get_workspace
(
const
NCBKernSizeParam
&
param
)
const
override
;
SmallVector
<
NCBKern
>
dispatch_kerns
(
const
NCBKernSizeParam
&
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
AlgoDataType
::
QINT8X8X32
,
AlgoCategory
::
WINOGRAD
};
}
private:
MatrixMulImpl
::
AlgoBase
*
m_matmul_algo
;
mutable
std
::
string
m_name
;
...
...
dnn/src/fallback/conv_bias/common.h
浏览文件 @
2a3f4d09
...
...
@@ -140,7 +140,7 @@ using BiasMode = ConvBiasForward::BiasMode;
break; \
}
#define MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(
)
\
#define MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(
_algo_data_type)
\
bool is_reproducible() const override { return true; } \
bool usable(const NCBKernSizeParam& param, \
AlgoSelectionStrategy algo_selection_strategy) const override; \
...
...
@@ -153,6 +153,9 @@ using BiasMode = ConvBiasForward::BiasMode;
const override; \
virtual SmallVector<NCBKern> dispatch_preprocess_kerns( \
const NCBKernSizeParam& param) const override; \
ConvAlgoTypePack get_algo_type() const override { \
return {_algo_data_type, AlgoCategory::WINOGRAD}; \
} \
\
private: \
fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; \
...
...
dnn/src/fallback/conv_bias/conv1x1/algos.cpp
浏览文件 @
2a3f4d09
...
...
@@ -288,7 +288,8 @@ bool ConvBiasImpl::AlgoConv1x1::is_preferred(
size_t
OH
=
param
.
osz
[
0
];
size_t
OW
=
param
.
osz
[
1
];
if
(
OH
*
OW
!=
1
)
{
return
true
;
return
m_matmul_algo
->
algoset
()
!=
MatrixMulImpl
::
AlgoBase
::
AlgoSet
::
ALGO_TYPE_GEMV
;
}
else
{
#if (MEGDNN_ARMV7 || MEGDNN_AARCH64)
if
(
param
.
src_type
.
enumv
()
==
DTypeEnum
::
Int8
&&
...
...
dnn/src/fallback/conv_bias/conv1x1/algos.h
浏览文件 @
2a3f4d09
...
...
@@ -56,6 +56,11 @@ public:
SmallVector
<
NCBKern
>
dispatch_preprocess_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
m_matmul_algo
->
matmul_description
().
algo_type
.
data_type
,
AlgoCategory
::
IM2COL
};
}
protected:
size_t
get_oc_tile_size_heuristic
(
const
NCBKernSizeParam
&
param
)
const
;
...
...
dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h
浏览文件 @
2a3f4d09
...
...
@@ -34,6 +34,16 @@ public:
bool
is_preferred
(
const
NCBKernSizeParam
&
)
const
override
;
//! get the type of the algo: the data type is the bitwise OR of FLOAT16,
//! FLOAT32, INT8X8X16, QINT8X8X32 and QUINT8X8X32, the category is IM2COL
ConvAlgoTypePack get_algo_type() const override {
    const auto bit = [](AlgoDataType type) {
        return static_cast<uint32_t>(type);
    };
    const uint32_t mask = bit(AlgoDataType::FLOAT16) |
                          bit(AlgoDataType::FLOAT32) |
                          bit(AlgoDataType::INT8X8X16) |
                          bit(AlgoDataType::QINT8X8X32) |
                          bit(AlgoDataType::QUINT8X8X32);
    return {static_cast<AlgoDataType>(mask), AlgoCategory::IM2COL};
}
protected:
size_t
get_oc_tile_size_heuristic
(
const
NCBKernSizeParam
&
param
)
const
;
};
...
...
dnn/src/fallback/conv_bias/im2col/algos.h
浏览文件 @
2a3f4d09
...
...
@@ -48,15 +48,25 @@ public:
SmallVector
<
NCBKern
>
dispatch_preprocess_kerns
(
const
NCBKernSizeParam
&
param
)
const
override
;
bool
is_preferred
(
const
NCBKernSizeParam
&
param
)
const
override
{
if
(
param
.
src_type
.
category
()
==
DTypeCategory
::
QUANTIZED
)
{
static
CpuOprDelegationStorage
<
1
>
storage
;
auto
conv_bias_opr
=
storage
.
get
<
ConvBias
,
0
>
();
return
static_cast
<
ConvBiasImpl
*>
(
conv_bias_opr
)
->
is_matmul_quantized_prefer
(
param
);
size_t
OH
=
param
.
osz
[
0
];
size_t
OW
=
param
.
osz
[
1
];
//! a gemm matmul algo is preferred when oh * ow > 1
//! a gemv matmul algo is preferred when oh * ow == 1
if
((
m_matmul_algo
->
algoset
()
!=
MatrixMulImpl
::
AlgoBase
::
AlgoSet
::
ALGO_TYPE_GEMV
&&
OH
*
OW
>
1
)
||
(
m_matmul_algo
->
algoset
()
==
MatrixMulImpl
::
AlgoBase
::
AlgoSet
::
ALGO_TYPE_GEMV
&&
OH
*
OW
==
1
))
{
return
true
;
}
else
{
return
false
;
}
auto
&&
fm
=
param
.
filter_meta
;
auto
OC
=
fm
.
ocpg
,
IC
=
fm
.
icpg
;
return
OC
>=
32
||
IC
>=
32
;
}
ConvAlgoTypePack
get_algo_type
()
const
override
{
return
{
m_matmul_algo
->
matmul_description
().
algo_type
.
data_type
,
AlgoCategory
::
IM2COL
};
}
private:
...
...
dnn/src/fallback/conv_bias/opr_impl.cpp
浏览文件 @
2a3f4d09
...
...
@@ -48,11 +48,26 @@ void incr_ptr(T*& dst, ptrdiff_t delta) {
}
// namespace
#if MEGDNN_X86
#define SKIP_GEMV()
//! There is no direct conv for int8x8x16 yet, so skipping gemv here could make
//! the selection fall back to the naive implementation, which is very slow.
//! Therefore gemv stays enabled for im2col in the x86 backend.
//! FIXME: remove this once direct conv supports int8x8x16
#else
#define SKIP_GEMV() \
if (algo->algoset() == MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { \
continue; \
}
#endif
class
ConvBiasImpl
::
AlgoPack
:
NonCopyableObj
{
AlgoNaive
algo_naive
;
SmallVector
<
std
::
unique_ptr
<
AlgoBase
>>
refhold
;
public:
AlgoPack
()
{
refhold
.
emplace_back
(
new
AlgoConv1x1Gemv
());
all_algos
.
emplace_back
(
refhold
.
back
().
get
());
...
...
@@ -110,8 +125,6 @@ public:
all_algos.emplace_back(refhold.back().get());
#endif
}
//! reverse matmul algo, when the algo is_prefer can be selected first
std
::
reverse
(
all_algos
.
begin
(),
all_algos
.
end
());
all_algos
.
emplace_back
(
&
algo_naive
);
}
SmallVector
<
AlgoBase
*>
all_algos
;
...
...
@@ -121,6 +134,22 @@ SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
static
AlgoPack
sl_algo_pack
;
return
sl_algo_pack
.
all_algos
;
}
/*!
 * \brief select the algos from algo_pack() that match \p target_type
 *
 * An algo is selected when its data type mask contains the target data type
 * and its category equals the target category.
 *
 * \param target_type data type and category to look for; the data type must
 *      contain exactly one type, otherwise the assertion fails
 * \return the matching algos, in the order algo_pack() returns them
 */
SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::select_algo_type(
        ConvAlgoTypePack target_type) {
    //! the message promises a single type, so compare the number of set bits
    //! against 1 instead of only rejecting an empty mask
    megdnn_assert(nr_type_contain(target_type.data_type) == 1,
                  "ConvBias algo selection only support one type");
    SmallVector<ConvBiasImpl::AlgoBase*> algos;
    for (auto&& algo : algo_pack()) {
        auto algo_type = algo->get_algo_type();
        if (contain_data_type(algo_type.data_type, target_type.data_type) &&
            algo_type.algo_category == target_type.algo_category) {
            algos.push_back(algo);
        }
    }
    return algos;
}
/*!
 * \brief tell whether \p algo counts as the naive algo
 * \return true when \p algo is null or its name equals "DEFAULT"
 */
bool ConvBiasImpl::is_naive_algo(ConvBiasImpl::Algorithm* algo) {
    if (!algo) {
        return true;
    }
    return !strcmp(algo->name(), "DEFAULT");
}
...
...
@@ -248,12 +277,32 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic(
ConvBiasImpl
::
Algorithm
*
ConvBiasImpl
::
get_algorithm_heuristic_with_ncb
(
const
NCBKernSizeParam
&
param
,
size_t
workspace_limit_in_bytes
,
bool
reproducible
)
{
for
(
auto
i
:
get_all_algorithms_with_ncb
(
param
))
{
if
(
static_cast
<
AlgoBase
*>
(
i
)
->
usable_reproducible
(
param
,
AlgoSelectionStrategy
::
HEURISTIC
,
reproducible
)
&&
NCB_ALGO_FUNC
(
get_workspace
,
i
,
param
)
<=
workspace_limit_in_bytes
)
{
return
i
;
auto
algo_data_type
=
param
.
deduce_algo_data_type
();
auto
suggest_category_order
=
suggest_algo_category_order
(
param
);
for
(
auto
category
:
suggest_category_order
)
{
auto
&&
origin_algos
=
select_algo_type
({
algo_data_type
,
category
});
ConvBiasImpl
::
Algorithm
*
heuristic_algo
=
nullptr
;
for
(
auto
i
:
origin_algos
)
{
bool
usable_reproducible
=
static_cast
<
AlgoBase
*>
(
i
)
->
usable_reproducible
(
param
,
AlgoSelectionStrategy
::
HEURISTIC
,
reproducible
);
if
(
usable_reproducible
&&
static_cast
<
AlgoBase
*>
(
i
)
->
get_workspace
(
param
)
<=
workspace_limit_in_bytes
)
{
//! store the first usable algo if no prefer algo, choose it as
//! the target algo
if
(
!
heuristic_algo
)
{
heuristic_algo
=
i
;
}
//! choose the first prefer algo
if
(
i
->
is_preferred
(
param
))
{
return
i
;
}
}
}
if
(
heuristic_algo
)
{
return
heuristic_algo
;
}
}
return
nullptr
;
...
...
@@ -300,9 +349,8 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param(
sizeof
(
ConvolutionImpl
::
CanonizedFilterMeta
),
"sizeof CanonizedFilterMeta in convolution and conv_bias "
"should be equal"
);
CanonizedFilterMeta
fm
=
check_layout_fwd
(
src
,
filter
,
dst
);
ConvolutionImpl
::
CanonizedFilterMeta
conv_fm
;
conv_fm
.
copy_from
(
fm
);
auto
&&
fm
=
check_layout_fwd
(
src
,
filter
,
dst
);
auto
&
conv_fm
=
reinterpret_cast
<
ConvolutionImpl
::
CanonizedFilterMeta
&>
(
fm
);
param
::
MatrixMul
::
Format
format
=
param
::
MatrixMul
::
Format
::
DEFAULT
;
if
(
param
().
format
==
Param
::
Format
::
NCHW_WINOGRAD
||
...
...
@@ -367,7 +415,7 @@ ConvBiasImpl::NCBKernParam ConvBiasImpl::make_ncb_kern_param(
void
ConvBiasImpl
::
exec_with_ncb_kern
(
const
NCBKernParam
&
param
,
ConvBiasImpl
::
Algorithm
*
algo
)
{
auto
ncb_kerns
=
NCB_ALGO_FUNC
(
dispatch_kerns
,
algo
,
param
);
auto
&&
ncb_kerns
=
NCB_ALGO_FUNC
(
dispatch_kerns
,
algo
,
param
);
for
(
auto
&&
kernel
:
ncb_kerns
)
{
auto
run
=
[
kernel
,
param
](
size_t
index
,
size_t
thread_id
)
{
CpuNDRange
ndrange_id
(
kernel
.
global_size
,
index
);
...
...
@@ -380,7 +428,7 @@ void ConvBiasImpl::exec_with_ncb_kern(const NCBKernParam& param,
void
ConvBiasImpl
::
exec_preprocess_with_ncb_kern
(
const
NCBKernParam
&
param
,
ConvBiasImpl
::
Algorithm
*
algo
)
{
auto
ncb_kerns
=
NCB_ALGO_FUNC
(
dispatch_preprocess_kerns
,
algo
,
param
);
auto
&&
ncb_kerns
=
NCB_ALGO_FUNC
(
dispatch_preprocess_kerns
,
algo
,
param
);
for
(
auto
&&
kernel
:
ncb_kerns
)
{
auto
run
=
[
kernel
,
param
](
size_t
index
,
size_t
thread_id
)
{
CpuNDRange
ndrange_id
(
kernel
.
global_size
,
index
);
...
...
@@ -405,7 +453,6 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_with_ncb(
}
}
}
std
::
reverse
(
prefer_algos
.
begin
(),
prefer_algos
.
end
());
//! Prefer algo inserted from begin
algos
.
insert
(
algos
.
begin
(),
prefer_algos
.
begin
(),
prefer_algos
.
end
());
return
algos
;
...
...
@@ -425,6 +472,35 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm(
return
m_prev_selected_algo
;
}
/*!
 * \brief suggest the order in which algo categories should be tried
 *
 * Winograd filter formats yield only the WINOGRAD category. Otherwise the
 * result contains IM2COL, DIRECT and NAIVE, with IM2COL first when im2col is
 * preferred and DIRECT first when it is not; NAIVE is always last.
 *
 * \param param kernel size param describing the convolution
 * \return algo categories, most preferred first
 */
SmallVector<AlgoCategory> ConvBiasImpl::suggest_algo_category_order(
        const NCBKernSizeParam& param) const {
    using Format = param::ConvBias::Format;
    const auto& meta = param.filter_meta;
    const auto layout = meta.format;

    //! TODO: now winograd only support in fast-run
    const bool is_winograd = layout == Format::NCHW_WINOGRAD ||
                             layout == Format::NCHW44_WINOGRAD ||
                             layout == Format::NCHW88_WINOGRAD;
    if (is_winograd) {
        return {AlgoCategory::WINOGRAD};
    }

    const auto in_channels = meta.icpg;
    const auto out_channels = meta.ocpg;
    const auto filter_h = meta.spatial[0];
    const auto filter_w = meta.spatial[1];

    //! im2col + matmul: quantized algo use matmul when direct algo is
    //! unusable; other source types prefer it once either channel count per
    //! group reaches 32
    bool prefer_im2col =
            param.src_type.category() == DTypeCategory::QUANTIZED
                    ? is_matmul_quantized_prefer(param)
                    : (in_channels >= 32 || out_channels >= 32);

    //! conv1x1 always prefers im2col
    if (filter_h == 1 && filter_w == 1) {
        prefer_im2col = true;
    }

    if (!prefer_im2col) {
        return {AlgoCategory::DIRECT, AlgoCategory::IM2COL,
                AlgoCategory::NAIVE};
    }
    return {AlgoCategory::IM2COL, AlgoCategory::DIRECT, AlgoCategory::NAIVE};
}
const
char
*
ConvBiasImpl
::
get_algorithm_set_name
()
const
{
// fallback version 0
return
"F0"
;
...
...
dnn/src/fallback/conv_bias/opr_impl.h
浏览文件 @
2a3f4d09
...
...
@@ -18,6 +18,8 @@
#include "src/fallback/matrix_mul/opr_impl.h"
#include "src/naive/conv_bias/opr_impl.h"
#include <unordered_map>
namespace
megdnn
{
namespace
fallback
{
...
...
@@ -44,6 +46,7 @@ class ConvBiasImpl : public naive::ConvBiasForwardImpl {
public:
using
naive
::
ConvBiasForwardImpl
::
ConvBiasForwardImpl
;
using
AlgoSelectionStrategy
=
detail
::
AlgoSelectionStrategy
;
using
AlgoDataType
=
detail
::
AlgoDataType
;
//! implemented by exec_with_ncb_kern()
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
filter
,
...
...
@@ -94,6 +97,8 @@ public:
size_t
workspace_limit_in_bytes
,
bool
reproducible
)
override
;
//! size param for kernels with non-contiguous batch
struct
NCBKernSizeParam
:
ConvolutionImpl
::
NCBKernSizeParam
{
NCBKernSizeParam
()
=
default
;
...
...
@@ -244,6 +249,9 @@ public:
return
(
!
reproducible
||
is_reproducible
())
&&
usable
(
param
,
algo_selection_strategy
);
}
//! get the type of the algo
virtual
ConvAlgoTypePack
get_algo_type
()
const
=
0
;
};
/**
...
...
@@ -251,6 +259,17 @@ public:
*/
virtual
SmallVector
<
AlgoBase
*>
algo_pack
();
/**
* \brief select algo according to input algo type
*/
SmallVector
<
AlgoBase
*>
select_algo_type
(
ConvAlgoTypePack
algo_type
);
/**
* \brief suggest algo category according to the param
*/
virtual
SmallVector
<
AlgoCategory
>
suggest_algo_category_order
(
const
NCBKernSizeParam
&
param
)
const
;
protected:
virtual
void
exec_with_ncb_kern
(
const
NCBKernParam
&
param
,
ConvBiasImpl
::
Algorithm
*
algo
);
...
...
dnn/src/fallback/convolution/algos.h
浏览文件 @
2a3f4d09
...
...
@@ -83,6 +83,10 @@ public:
SmallVector
<
NCBKern
>
dispatch_kern
(
const
NCBKernSizeParam
&
/*param*/
)
const
override
;
//! this algo is registered as the float32 algo of the NAIVE category
ConvAlgoTypePack get_algo_type() const override {
    return ConvAlgoTypePack{AlgoDataType::FLOAT32, AlgoCategory::NAIVE};
}
};
class
ConvolutionImpl
::
AlgoNaive
final
:
public
AlgoBase
{
...
...
@@ -96,11 +100,17 @@ public:
SmallVector
<
NCBKern
>
dispatch_kern
(
const
NCBKernSizeParam
&
/*param*/
)
const
override
;
//! this algo is registered in the NAIVE category for several integer data
//! types at once
ConvAlgoTypePack get_algo_type() const override {
    //! bitwise union of every data type flag reported by this algo
    const uint32_t type_mask =
            static_cast<uint32_t>(AlgoDataType::INT8X8X16) |
            static_cast<uint32_t>(AlgoDataType::QINT8X8X32) |
            static_cast<uint32_t>(AlgoDataType::QUINT8X8X32);
    return {static_cast<AlgoDataType>(type_mask), AlgoCategory::NAIVE};
}
};
class
ConvolutionImpl
::
AlgoDefault
final
:
public
AlgoBase
{
static
ConvBiasImpl
::
NCBKernSizeParam
init_conv_bias_param
(
const
NCBKernSizeParam
&
param
);
WorkspaceBundle
get_bundle
(
const
NCBKernSizeParam
&
param
)
const
;
static
SmallVector
<
NCBKern
>
get_kimpl
(
ConvBiasImpl
::
AlgoBase
*
algo
,
const
NCBKernSizeParam
&
param
);
...
...
@@ -136,6 +146,13 @@ public:
//! select matmul to the highest preference
bool
is_preferred
(
const
NCBKernSizeParam
&
param
)
const
override
;
static
ConvBiasImpl
::
NCBKernSizeParam
init_conv_bias_param
(
const
NCBKernSizeParam
&
param
);
//! report the algo type of the wrapped conv_bias algorithm
ConvAlgoTypePack get_algo_type() const override {
    const ConvAlgoTypePack wrapped_type = m_algorithm->get_algo_type();
    return wrapped_type;
}
private:
std
::
string
m_name
;
ConvBiasImpl
::
AlgoBase
*
m_algorithm
;
...
...
dnn/src/fallback/convolution/opr_impl.cpp
浏览文件 @
2a3f4d09
...
...
@@ -23,6 +23,7 @@
#include "midout.h"
#include <cstring>
#include <unordered_map>
MIDOUT_DECL
(
megdnn_fb_convbwd_float
)
...
...
@@ -75,6 +76,22 @@ SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::algo_pack() {
static
AlgoPack
sl_algo_pack
;
return
sl_algo_pack
.
all_algos
;
}
/**
 * \brief collect the algos from algo_pack() that match the requested type
 *
 * An algo is kept when its reported data type contains the requested data
 * type and its category equals the requested category.
 *
 * \param target_type data type and category to look for
 * \return matching algos, in algo_pack() order
 */
SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::select_algo_type(
        ConvAlgoTypePack target_type) {
    megdnn_assert(nr_type_contain(target_type.data_type),
                  "ConvBias algo selection only support one type");
    SmallVector<ConvolutionImpl::AlgoBase*> matched;
    for (auto&& candidate : algo_pack()) {
        const auto candidate_type = candidate->get_algo_type();
        if (!contain_data_type(candidate_type.data_type,
                               target_type.data_type)) {
            continue;
        }
        if (candidate_type.algo_category != target_type.algo_category) {
            continue;
        }
        matched.push_back(candidate);
    }
    return matched;
}
/**
 * \brief whether \p algo should be treated as the naive algo
 *
 * \return true when \p algo is null or its name is exactly "DEFAULT"
 */
bool ConvolutionImpl::is_naive_algo(ConvolutionImpl::Algorithm* algo) {
    if (algo == nullptr) {
        return true;
    }
    const bool named_default = strcmp(algo->name(), "DEFAULT") == 0;
    return named_default;
}
...
...
@@ -249,9 +266,9 @@ ConvolutionImpl::NCBKernParam ConvolutionImpl::make_ncb_kern_param(
void
ConvolutionImpl
::
exec_preprocess_with_ncb_kern
(
const
NCBKernParam
&
param
,
Algorithm
*
algo
)
{
auto
kerns
=
NCB_ALGO_FUNC
(
dispatch_preprocess_kern
,
algo
,
param
);
auto
fallback_handle
=
handle
();
for
(
auto
kernel
:
kerns
)
{
auto
&&
kerns
=
NCB_ALGO_FUNC
(
dispatch_preprocess_kern
,
algo
,
param
);
auto
&&
fallback_handle
=
handle
();
for
(
auto
&&
kernel
:
kerns
)
{
megdnn_assert
(
param
.
filter_meta
.
format
==
Param
::
Format
::
NCHW
||
param
.
filter_meta
.
format
==
Param
::
Format
::
NHWC
||
...
...
@@ -270,9 +287,9 @@ void ConvolutionImpl::exec_preprocess_with_ncb_kern(const NCBKernParam& param,
void
ConvolutionImpl
::
exec_with_ncb_kern
(
const
NCBKernParam
&
param
,
Algorithm
*
algo
)
{
auto
kerns
=
NCB_ALGO_FUNC
(
dispatch_kern
,
algo
,
param
);
auto
fallback_handle
=
handle
();
for
(
auto
kernel
:
kerns
)
{
auto
&&
kerns
=
NCB_ALGO_FUNC
(
dispatch_kern
,
algo
,
param
);
auto
&&
fallback_handle
=
handle
();
for
(
auto
&&
kernel
:
kerns
)
{
megdnn_assert
(
param
.
filter_meta
.
format
==
Param
::
Format
::
NCHW
||
param
.
filter_meta
.
format
==
Param
::
Format
::
NHWC
||
...
...
@@ -292,13 +309,32 @@ void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param,
ConvolutionImpl
::
Algorithm
*
ConvolutionImpl
::
get_algorithm_heuristic_with_ncb
(
const
NCBKernSizeParam
&
param
,
size_t
workspace_limit_in_bytes
,
bool
reproducible
)
{
for
(
auto
i
:
get_all_algorithms_with_ncb
(
param
))
{
bool
usable_reproducible
=
static_cast
<
AlgoBase
*>
(
i
)
->
usable_reproducible
(
param
,
AlgoSelectionStrategy
::
HEURISTIC
,
reproducible
);
if
(
usable_reproducible
&&
NCB_ALGO_FUNC
(
get_workspace
,
i
,
param
)
<=
workspace_limit_in_bytes
)
{
return
i
;
auto
algo_data_type
=
param
.
deduce_algo_data_type
();
auto
suggest_category_order
=
suggest_algo_category_order
(
param
);
for
(
auto
category
:
suggest_category_order
)
{
auto
&&
origin_algos
=
select_algo_type
({
algo_data_type
,
category
});
ConvolutionImpl
::
Algorithm
*
heuristic_algo
=
nullptr
;
for
(
auto
i
:
origin_algos
)
{
bool
usable_reproducible
=
static_cast
<
AlgoBase
*>
(
i
)
->
usable_reproducible
(
param
,
AlgoSelectionStrategy
::
HEURISTIC
,
reproducible
);
if
(
usable_reproducible
&&
static_cast
<
AlgoBase
*>
(
i
)
->
get_workspace
(
param
)
<=
workspace_limit_in_bytes
)
{
//! store the first usable algo if no prefer algo, choose it as
//! the target algo
if
(
!
heuristic_algo
)
{
heuristic_algo
=
i
;
}
//! choose the first prefer algo
if
(
i
->
is_preferred
(
param
))
{
return
i
;
}
}
}
if
(
heuristic_algo
)
{
return
heuristic_algo
;
}
}
return
nullptr
;
...
...
@@ -317,8 +353,6 @@ ConvolutionImpl::get_all_algorithms_with_ncb(const NCBKernSizeParam& param) {
}
}
}
std
::
reverse
(
prefer_algos
.
begin
(),
prefer_algos
.
end
());
//! Prefer algo inserted from begin
ret
.
insert
(
ret
.
begin
(),
prefer_algos
.
begin
(),
prefer_algos
.
end
());
return
ret
;
}
...
...
@@ -337,11 +371,45 @@ ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm(
return
m_prev_selected_algo
;
}
/**
 * \brief suggest the algo category order by delegating to conv_bias
 *
 * The convolution param is converted to a conv_bias param and the answer of
 * the delegated ConvBiasImpl is returned unchanged.
 */
SmallVector<AlgoCategory> ConvolutionImpl::suggest_algo_category_order(
        const NCBKernSizeParam& param) const {
    //! delegated conv_bias operator, kept in function-local static storage
    static CpuOprDelegationStorage<1> storage;
    auto* delegate = static_cast<ConvBiasImpl*>(storage.get<ConvBias, 0>());
    const auto bias_param = AlgoDefault::init_conv_bias_param(param);
    return delegate->suggest_algo_category_order(bias_param);
}
//! name identifying this set of algorithms: fallback version 0
const char* ConvolutionImpl::get_algorithm_set_name() const {
    static constexpr const char* kFallbackSetName = "F0";
    return kFallbackSetName;
}
/**
 * \brief map the src/dst dtypes of this param to the AlgoDataType used for
 * algo selection
 *
 * Throws through megdnn_throw when the src dtype is not one of the handled
 * enumerations.
 */
ConvolutionImpl::AlgoDataType
ConvolutionImpl::NCBKernSizeParam::deduce_algo_data_type() const {
    using DataType = ConvolutionImpl::AlgoDataType;
    const auto src_enum = src_type.enumv();

    if (src_enum == DTypeEnum::Float32) {
        return DataType::FLOAT32;
    }
#if !MEGDNN_DISABLE_FLOAT16
    if (src_enum == DTypeEnum::Float16) {
        return DataType::FLOAT16;
    }
#endif
    if (src_enum == DTypeEnum::Int8 || src_enum == DTypeEnum::QuantizedS8) {
        //! an int16 dst selects 8x8x16, any other dst is treated as 8x8x32
        return dst_type.enumv() == DTypeEnum::Int16 ? DataType::INT8X8X16
                                                    : DataType::QINT8X8X32;
    }
    if (src_enum == DTypeEnum::Quantized8Asymm) {
        return DataType::QUINT8X8X32;
    }
    megdnn_throw(ssprintf("megdnn not support data type of %s * %s -> %s\n",
                          src_type.name(), filter_type.name(),
                          dst_type.name()));
}
/* ===================== ConvolutionBackwardData ===================== */
void
*
const
ConvolutionBackwardDataImpl
::
sm_fallback_deconv_algo_type
=
...
...
dnn/src/fallback/convolution/opr_impl.h
浏览文件 @
2a3f4d09
...
...
@@ -10,11 +10,28 @@
*/
#pragma once
#include "megdnn/oprs/base.h"
#include "src/common/utils.h"
#include "src/fallback/handle.h"
#include "src/naive/convolution/opr_impl.h"
namespace
megdnn
{
/**
 * \brief Convolution algo category
 *
 * Used together with the data type to group convolution algos for the
 * heuristic selection.
 */
enum class AlgoCategory : int32_t {
    DIRECT = 0,
    IM2COL = 1,  //!< im2col + matmul
    WINOGRAD = 2,
    NAIVE = 3,
};
/**
 * \brief data type and category of a convolution algo, each stored in a
 * 32-bit bit-field
 */
struct ConvAlgoTypePack {
    //! data type(s) the algo is registered for
    detail::AlgoDataType data_type : 32;
    //! category the algo belongs to
    AlgoCategory algo_category : 32;
};
namespace
fallback
{
/*!
...
...
@@ -33,6 +50,7 @@ class ConvolutionImpl : public naive::ConvolutionForwardImpl {
public:
using
naive
::
ConvolutionForwardImpl
::
ConvolutionForwardImpl
;
using
AlgoSelectionStrategy
=
detail
::
AlgoSelectionStrategy
;
using
AlgoDataType
=
detail
::
AlgoDataType
;
//! implemented by exec_with_ncb_kern()
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
filter
,
...
...
@@ -86,6 +104,8 @@ public:
size_t
nr_threads
;
//! weight_preprocess info
const
PreprocessedFilter
*
preprocessed_filter
;
//! get the data type category of the param for select the algo
AlgoDataType
deduce_algo_data_type
()
const
;
};
//! memory param for kernels with non-contiguous batch
...
...
@@ -211,6 +231,9 @@ public:
return
(
!
reproducible
||
is_reproducible
())
&&
usable
(
param
,
algo_selection_strategy
);
}
//! get the type of the algo
virtual
ConvAlgoTypePack
get_algo_type
()
const
=
0
;
};
/**
...
...
@@ -218,6 +241,11 @@ public:
*/
virtual
SmallVector
<
AlgoBase
*>
algo_pack
();
/**
* \brief select algo according to input algo type
*/
SmallVector
<
AlgoBase
*>
select_algo_type
(
ConvAlgoTypePack
algo_type
);
protected:
virtual
void
exec_with_ncb_kern
(
const
NCBKernParam
&
param
,
Algorithm
*
algo
);
...
...
@@ -258,6 +286,9 @@ private:
_megdnn_tensor_out
dst
,
const
PreprocessedFilter
*
preprocessed_filter
,
_megdnn_workspace
workspace
);
SmallVector
<
AlgoCategory
>
suggest_algo_category_order
(
const
NCBKernSizeParam
&
param
)
const
;
};
class
ConvolutionBackwardDataImpl
:
public
naive
::
ConvolutionBackwardDataImpl
{
...
...
dnn/src/fallback/matrix_mul/algos.cpp
浏览文件 @
2a3f4d09
...
...
@@ -76,7 +76,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern(
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL
(
AlgoF32K8x12x1
,
megdnn_fb_matmul_f32_kern
,
5
,
matmul
::
fallback
::
sgemm_8x12
,
float
,
float
);
float
,
AlgoDataType
::
FLOAT32
,
DEFAULT
);
/* ===================== gemv algo ===================== */
bool
MatrixMulImpl
::
AlgoGemv
::
usable
(
...
...
dnn/src/fallback/matrix_mul/algos.h
浏览文件 @
2a3f4d09
...
...
@@ -37,7 +37,15 @@ public:
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
4
)
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
4
,
static_cast
<
AlgoDataType
>
(
static_cast
<
uint32_t
>
(
AlgoDataType
::
FLOAT16
)
|
static_cast
<
uint32_t
>
(
AlgoDataType
::
FLOAT32
)
|
static_cast
<
uint32_t
>
(
AlgoDataType
::
INT8X8X16
)
|
static_cast
<
uint32_t
>
(
AlgoDataType
::
QINT8X8X32
)
|
static_cast
<
uint32_t
>
(
AlgoDataType
::
QUINT8X8X32
)),
DEFAULT
)
};
}
// namespace fallback
...
...
dnn/src/fallback/matrix_mul/gemm_common.h
浏览文件 @
2a3f4d09
...
...
@@ -352,13 +352,15 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K,
DType dtype_c) \
: A_dtype(dtype_a), B_dtype(dtype_b), C_dtype(dtype_c) {}
#define MEGDNN_OVERRIDE_MATMUL_DESC(_m, _n, _k, _packa_type_size) \
MatmulDescription matmul_description() const override { \
MatmulDescription mdesc; \
mdesc.packmode = packmode(); \
mdesc.innerblocksize = {_m, _n, _k}; \
mdesc.packa_type_size = _packa_type_size; \
return mdesc; \
#define MEGDNN_OVERRIDE_MATMUL_DESC(_m, _n, _k, _packa_type_size, _data_type, \
_format) \
MatmulDescription matmul_description() const override { \
MatmulDescription mdesc; \
mdesc.packmode = packmode(); \
mdesc.innerblocksize = {_m, _n, _k}; \
mdesc.packa_type_size = _packa_type_size; \
mdesc.algo_type = {_data_type, Param::Format::_format}; \
return mdesc; \
}
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL() \
...
...
@@ -373,7 +375,7 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K,
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( \
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \
_packa_type
)
\
_packa_type
, _support_data_type, _format)
\
\
MatrixMulImpl::kern_naked_t MatrixMulImpl::_algo_name::get_kern_naked( \
const KernSizeParam&) const { \
...
...
@@ -474,14 +476,16 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K,
mdesc.innerblocksize = {_strategy::KERNEL_H, _strategy::KERNEL_W, \
_strategy::UNROLL_K}; \
mdesc.packa_type_size = sizeof(_packa_type); \
mdesc.algo_type = {_support_data_type, Param::Format::_format}; \
return mdesc; \
}
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( \
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type) \
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(_algo_name, _midout_name, \
_mid_index, _strategy, \
_i_type, _c_type, _i_type)
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( \
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \
_support_data_type, _format) \
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( \
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \
_i_type, _support_data_type, _format)
}
// namespace matmul
}
// namespace megdnn
...
...
dnn/src/fallback/matrix_mul/opr_impl.cpp
浏览文件 @
2a3f4d09
...
...
@@ -38,6 +38,22 @@ SmallVector<MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
return
s_algo_pack
.
all_algos
;
}
/**
 * \brief collect the matmul algos from algo_pack() that match the request
 *
 * An algo is kept when the data type in its matmul description contains the
 * requested data type and its format equals the requested format.
 *
 * \param index data type and format to look for
 * \return matching algos, in algo_pack() order
 */
SmallVector<MatrixMulImpl::AlgoBase*> MatrixMulImpl::select_algo_type(
        AlgoTypePack index) {
    megdnn_assert(nr_type_contain(index.data_type),
                  "Matmul algo selection only support one type");
    SmallVector<MatrixMulImpl::AlgoBase*> matched;
    for (auto&& candidate : algo_pack()) {
        const auto desc = candidate->matmul_description();
        if (!contain_data_type(desc.algo_type.data_type, index.data_type)) {
            continue;
        }
        if (desc.algo_type.format != index.format) {
            continue;
        }
        matched.push_back(candidate);
    }
    return matched;
}
std
::
vector
<
MatrixMul
::
Algorithm
*>
MatrixMulImpl
::
get_all_algorithms
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
const
TensorLayout
&
C
)
{
std
::
vector
<
Algorithm
*>
gemm_algos
,
gemv_algos
;
...
...
@@ -71,17 +87,25 @@ MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic(
"require reproducible algorithm, but given algorithm is not "
"reproducible"
);
}
auto
algos
=
get_all_algorithms
(
A
,
B
,
C
);
AlgoTypePack
algo_type
;
algo_type
.
data_type
=
kern_size_param
.
deduce_algo_data_type
();
algo_type
.
format
=
kern_size_param
.
format
;
auto
algos
=
select_algo_type
(
algo_type
);
Algorithm
*
heuristic_algo
=
nullptr
;
for
(
auto
&&
algo
:
algos
)
{
if
(
static_cast
<
AlgoBase
*>
(
algo
)
->
preferred_reproducible
(
if
(
static_cast
<
AlgoBase
*>
(
algo
)
->
usable
(
kern_size_param
)
&&
static_cast
<
AlgoBase
*>
(
algo
)
->
preferred_reproducible
(
kern_size_param
,
reproducible
)
&&
static_cast
<
AlgoBase
*>
(
algo
)
->
get_workspace
(
kern_size_param
)
<=
workspace_limit_in_bytes
)
{
return
algo
;
if
(
algo
->
algoset
()
==
AlgoBase
::
AlgoSet
::
ALGO_TYPE_GEMV
)
{
return
algo
;
}
else
if
(
!
heuristic_algo
)
{
heuristic_algo
=
algo
;
}
}
}
return
nullptr
;
return
heuristic_algo
;
}
MatrixMulImpl
::
KernSizeParam
MatrixMulImpl
::
make_kern_size_param
(
...
...
@@ -150,4 +174,34 @@ void MatrixMulImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
naive
::
MatrixMulForwardImpl
::
exec
(
A
,
B
,
C
,
workspace
);
}
/**
 * \brief map the A/B/C dtypes of this param to the AlgoDataType used for
 * algo selection
 *
 * Asserts that A and B share the same dtype enumeration, and throws through
 * megdnn_throw when the A dtype is not one of the handled enumerations.
 */
MatrixMulImpl::AlgoDataType
MatrixMulImpl::KernSizeParam::deduce_algo_data_type() const {
    using DataType = MatrixMulImpl::AlgoDataType;
    megdnn_assert(A_type.enumv() == B_type.enumv(),
                  "Matmul A type and B type of different ctype\n");
    const auto a_enum = A_type.enumv();

    if (a_enum == DTypeEnum::Float32) {
        return DataType::FLOAT32;
    }
#if !MEGDNN_DISABLE_FLOAT16
    if (a_enum == DTypeEnum::Float16) {
        return DataType::FLOAT16;
    }
#endif
    if (a_enum == DTypeEnum::Int8 || a_enum == DTypeEnum::QuantizedS8) {
        if (C_type.enumv() == DTypeEnum::Int16) {
            return DataType::INT8X8X16;
        }
        //! every remaining int8 case must accumulate into a 32-bit type
        megdnn_assert(C_type.enumv() == DTypeEnum::Int32 ||
                      C_type.enumv() == DTypeEnum::QuantizedS32);
        return DataType::QINT8X8X32;
    }
    if (a_enum == DTypeEnum::Quantized8Asymm) {
        return DataType::QUINT8X8X32;
    }
    if (a_enum == DTypeEnum::Int16) {
        return DataType::INT16X16X32;
    }
    megdnn_throw(ssprintf(
            "megdnn matmul not support data type of %s * %s -> %s\n",
            A_type.name(), B_type.name(), C_type.name()));
}
// vim: syntax=cpp.doxygen
dnn/src/fallback/matrix_mul/opr_impl.h
浏览文件 @
2a3f4d09
...
...
@@ -10,14 +10,23 @@
* implied.
*/
#pragma once
#include "megdnn/opr_param_defs.h"
#include "src/common/utils.h"
#include "src/naive/matrix_mul/opr_impl.h"
#include <unordered_map>
namespace
megdnn
{
namespace
fallback
{
/**
 * \brief data type and matmul format of a matmul algo, each stored in a
 * 32-bit bit-field
 */
struct AlgoTypePack {
    //! data type(s) the algo is registered for
    detail::AlgoDataType data_type : 32;
    //! matmul format the algo works on
    param::MatrixMul::Format format : 32;
};
namespace
fallback
{
class
MatrixMulImpl
:
public
naive
::
MatrixMulForwardImpl
{
public:
using
naive
::
MatrixMulForwardImpl
::
MatrixMulForwardImpl
;
using
AlgoDataType
=
detail
::
AlgoDataType
;
bool
is_thread_safe
()
const
override
{
return
true
;
}
...
...
@@ -34,6 +43,8 @@ public:
bool
trA
,
trB
;
Param
::
ComputeMode
compute_mode
;
Param
::
Format
format
;
//! get the data type category of the param for select the algo
AlgoDataType
deduce_algo_data_type
()
const
;
};
struct
KernParam
:
public
KernSizeParam
{
...
...
@@ -110,6 +121,7 @@ public:
struct
MatmulDescription
{
PackMode
packmode
;
InnerBlockSize
innerblocksize
;
AlgoTypePack
algo_type
;
size_t
packa_type_size
;
};
...
...
@@ -146,6 +158,11 @@ public:
*/
virtual
SmallVector
<
AlgoBase
*>
algo_pack
();
/**
* \brief select algo according to input algo type
*/
SmallVector
<
AlgoBase
*>
select_algo_type
(
AlgoTypePack
algo_type
);
protected:
KernSizeParam
make_kern_size_param
(
const
TensorLayout
&
A
,
const
TensorLayout
&
B
,
...
...
dnn/src/x86/conv_bias/f32/algos.h
浏览文件 @
2a3f4d09
...
...
@@ -48,6 +48,10 @@ public:
}
void
*
type
()
const
override
;
//! float32 algo of the DIRECT category
ConvAlgoTypePack get_algo_type() const override {
    return ConvAlgoTypePack{AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
};
/* ===================== direct-stride2 algo ===================== */
...
...
@@ -81,6 +85,10 @@ public:
}
void
*
type
()
const
override
;
//! float32 algo of the DIRECT category
ConvAlgoTypePack get_algo_type() const override {
    return ConvAlgoTypePack{AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
};
/* =========================== winograd ======================== */
class
ConvBiasImpl
::
AlgoFP32WinogradF63_8x8
final
:
public
AlgoBase
{
...
...
@@ -96,7 +104,7 @@ public:
return
m_name
.
c_str
();
}
void
*
type
()
const
override
;
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
(
AlgoDataType
::
FLOAT32
);
};
class
ConvBiasImpl
::
AlgoFP32WinogradF23_8x8
final
:
public
AlgoBase
{
...
...
@@ -112,7 +120,7 @@ public:
return
m_name
.
c_str
();
}
void
*
type
()
const
override
;
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
();
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE
(
AlgoDataType
::
FLOAT32
);
};
/* ===================== matmul algo ===================== */
...
...
@@ -151,6 +159,9 @@ public:
}
void
*
type
()
const
override
;
//! float32 algo of the IM2COL category
ConvAlgoTypePack get_algo_type() const override {
    return ConvAlgoTypePack{AlgoDataType::FLOAT32, AlgoCategory::IM2COL};
}
};
#if MEGDNN_X86_WITH_MKL_DNN
...
...
@@ -192,6 +203,10 @@ public:
return
{{
kern
,
{
1
_z
,
1
_z
,
1
_z
}}};
}
void
*
type
()
const
override
;
//! float32 algo of the DIRECT category
ConvAlgoTypePack get_algo_type() const override {
    return ConvAlgoTypePack{AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
}
};
#endif
// vim: syntax=cpp.doxygen
dnn/src/x86/conv_bias/int8/algo_usable_preferred.cpp
浏览文件 @
2a3f4d09
...
...
@@ -224,8 +224,6 @@ bool mkldnn_matmul_qint8_preferred(
const
ConvBiasImpl
::
NCBKernSizeParam
&
param
)
{
auto
is_preferred
=
true
;
auto
&&
fm
=
param
.
filter_meta
;
megdnn_assert_internal
(
fm
.
group
==
1
&&
fm
.
dilation
[
0
]
==
1
&&
fm
.
dilation
[
1
]
==
1
);
// single channel conv should never use matrix mul
if
(
fm
.
ocpg
==
1
||
fm
.
icpg
==
1
)
...
...
dnn/src/x86/conv_bias/int8/algos.h
浏览文件 @
2a3f4d09
...
...
@@ -34,6 +34,10 @@ public:
}
void
*
type
()
const
override
;
bool
is_preferred
(
const
NCBKernSizeParam
&
param
)
const
override
;
//! qint8x8x32 algo of the DIRECT category
ConvAlgoTypePack get_algo_type() const override {
    return ConvAlgoTypePack{AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
};
/* ===================== avx2 stride2 chanwise algo ===================== */
...
...
@@ -55,6 +59,10 @@ public:
}
void
*
type
()
const
override
;
bool
is_preferred
(
const
NCBKernSizeParam
&
param
)
const
override
;
//! qint8x8x32 algo of the DIRECT category
ConvAlgoTypePack get_algo_type() const override {
    return ConvAlgoTypePack{AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
};
/* ===================== avx2 stride1 direct algo ===================== */
...
...
@@ -76,6 +84,10 @@ public:
}
void
*
type
()
const
override
;
bool
is_preferred
(
const
NCBKernSizeParam
&
param
)
const
override
;
//! qint8x8x32 algo of the DIRECT category
ConvAlgoTypePack get_algo_type() const override {
    return ConvAlgoTypePack{AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
};
/* ================== avx2 int8 direct conv stride2 algo ================== */
...
...
@@ -97,6 +109,10 @@ public:
}
void
*
type
()
const
override
;
bool
is_preferred
(
const
NCBKernSizeParam
&
param
)
const
override
;
//! qint8x8x32 algo of the DIRECT category
ConvAlgoTypePack get_algo_type() const override {
    return ConvAlgoTypePack{AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
};
#if MEGDNN_X86_WITH_MKL_DNN
...
...
@@ -134,6 +150,10 @@ public:
}
void
*
type
()
const
override
;
bool
is_preferred
(
const
NCBKernSizeParam
&
param
)
const
override
;
//! qint8x8x32 algo of the DIRECT category
ConvAlgoTypePack get_algo_type() const override {
    return ConvAlgoTypePack{AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
}
};
/* ===================== mkldnn qint8 matmul algo ===================== */
class
ConvBiasImpl
::
AlgoMkldnnMatmulQint8
final
:
public
AlgoBase
{
...
...
@@ -160,6 +180,10 @@ public:
bool
is_preferred
(
const
NCBKernSizeParam
&
param
)
const
override
;
void
*
type
()
const
override
;
//! qint8x8x32 algo of the IM2COL category
ConvAlgoTypePack get_algo_type() const override {
    return ConvAlgoTypePack{AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL};
}
};
#endif
...
...
dnn/src/x86/conv_bias/opr_impl.cpp
浏览文件 @
2a3f4d09
...
...
@@ -103,10 +103,10 @@ public:
#endif
all_algos
.
emplace_back
(
&
stride1_direct
);
all_algos
.
emplace_back
(
&
stride2_direct
);
all_algos
.
emplace_back
(
&
avx2_stride1_direct_int8
);
all_algos
.
emplace_back
(
&
avx2_stride2_direct
);
all_algos
.
emplace_back
(
&
avx2_stride1_chanwsie_qint8
);
all_algos
.
emplace_back
(
&
avx2_stride2_chanwsie_qint8
);
all_algos
.
emplace_back
(
&
avx2_stride1_direct_int8
);
all_algos
.
emplace_back
(
&
avx2_stride2_direct
);
all_algos
.
emplace_back
(
&
matmul
);
static
CpuOprDelegationStorage
<>
storage
;
...
...
@@ -182,4 +182,41 @@ bool ConvBiasImpl::is_matmul_quantized_prefer(
!
chanwise_avx2_stride2_qint8_usable_preferred
(
param
));
}
/**
 * \brief suggest the order in which algo categories should be tried on x86
 * for the given kernel size param
 *
 * \param param size param of the convolution to run
 * \return categories listed from most to least preferred
 */
SmallVector<AlgoCategory> ConvBiasImpl::suggest_algo_category_order(
        const NCBKernSizeParam& param) const {
    using ConvFormat = param::ConvBias::Format;
    const auto& meta = param.filter_meta;
    const auto layout = meta.format;

    //! TODO: now winograd only support fast-run
    const bool winograd_layout = layout == ConvFormat::NCHW_WINOGRAD ||
                                 layout == ConvFormat::NCHW44_WINOGRAD ||
                                 layout == ConvFormat::NCHW88_WINOGRAD;
    if (winograd_layout) {
        return {AlgoCategory::WINOGRAD};
    }

    //! nchw88 use mkl-dnn which algo is direct
    if (layout == ConvFormat::NCHW88) {
        return {AlgoCategory::DIRECT, AlgoCategory::IM2COL};
    }

    //! quantized algo use matmul when direct algo is unusable, every other
    //! data type prefers im2col + matmul once there are enough channels
    const bool quantized_src =
            param.src_type.category() == DTypeCategory::QUANTIZED;
    bool matmul_first = quantized_src
                                ? is_matmul_quantized_prefer(param)
                                : (meta.icpg >= 32 || meta.ocpg >= 32);

    //! conv1x1
    if (meta.spatial[0] == 1 && meta.spatial[1] == 1) {
        matmul_first = true;
    }

    //! x86 8x8x16 is not optimized, so it uses the fallback im2col+matmul;
    //! the data type is always deduced, regardless of the checks above
    if (param.deduce_algo_data_type() == AlgoDataType::INT8X8X16) {
        matmul_first = true;
    }

    if (matmul_first) {
        return {AlgoCategory::IM2COL, AlgoCategory::DIRECT,
                AlgoCategory::NAIVE};
    }
    return {AlgoCategory::DIRECT, AlgoCategory::IM2COL, AlgoCategory::NAIVE};
}
// vim: syntax=cpp.doxygen
dnn/src/x86/conv_bias/opr_impl.h
浏览文件 @
2a3f4d09
...
...
@@ -24,6 +24,8 @@ public:
bool
is_thread_safe
()
const
override
{
return
true
;
}
SmallVector
<
AlgoBase
*>
algo_pack
()
override
;
SmallVector
<
AlgoCategory
>
suggest_algo_category_order
(
const
NCBKernSizeParam
&
param
)
const
override
;
class
AlgoDirect
;
class
AlgoDirectStride2
;
...
...
dnn/src/x86/matrix_mul/algos.cpp
浏览文件 @
2a3f4d09
...
...
@@ -184,11 +184,10 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Vnni::get_kern(
return
int8x8x32_kern_vnni
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL
(
AlgoInt8x8x32Vnni
,
megdnn_x86_matmul_kern
,
"AlgoInt8x8x32Vnni"
_hash
,
x86
::
matmul
::
gemm_int8_vnni_12x32x4
,
dt_int8
,
dt_int32
,
dt_uint8
);
//! register the im2col gemm functions of the VNNI int8x8x32 algo.
//! Arguments: algo name, midout name, midout index, gemm strategy, input
//! ctype, output ctype, packed-A ctype, supported algo data type, matmul
//! format. The packed-A ctype (dt_uint8) and the data type are separate
//! arguments and must be separated by a comma.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
        AlgoInt8x8x32Vnni, megdnn_x86_matmul_kern, "AlgoInt8x8x32Vnni"_hash,
        x86::matmul::gemm_int8_vnni_12x32x4, dt_int8, dt_int32, dt_uint8,
        AlgoDataType::QINT8X8X32, DEFAULT);
#endif
/* ===================== Int8 mkldnn algo ===================== */
...
...
@@ -397,7 +396,8 @@ size_t MatrixMulImpl::AlgoInt8x8x16AVX2::get_workspace(
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL
(
AlgoInt8x8x16AVX2
,
megdnn_x86_matmul_kern
,
"AlgoInt8x8x16AVX2"
_hash
,
x86
::
matmul
::
gemm_avx2_s8s8s16_4x16x2
,
dt_int8
,
dt_int16
,
dt_int16
);
x86
::
matmul
::
gemm_avx2_s8s8s16_4x16x2
,
dt_int8
,
dt_int16
,
dt_int16
,
AlgoDataType
::
INT8X8X16
,
DEFAULT
);
/*************************AlgoInt8x8x16SSE********************/
void
MatrixMulImpl
::
AlgoInt8x8x16SSE
::
gemm_s8s8s16_sse_4x8x2
(
...
...
@@ -474,7 +474,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x16SSE,
megdnn_x86_matmul_kern
,
"AlgoInt8x8x16SSE"
_hash
,
x86
::
matmul
::
gemm_sse_s8s8s16_4x8x2
,
dt_int8
,
dt_int16
,
dt_int16
);
dt_int8
,
dt_int16
,
dt_int16
,
AlgoDataType
::
INT8X8X16
,
DEFAULT
);
/*************************AlgoInt8x8x32AVX2M4N16K2********************/
MatrixMulImpl
::
kern_t
MatrixMulImpl
::
AlgoInt8x8x32AVX2M4N16K2
::
get_kern
(
...
...
@@ -516,7 +517,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_workspace(
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL
(
AlgoInt8x8x32AVX2M4N16K2
,
megdnn_x86_matmul_kern
,
"AlgoInt8x8x32AVX2M4N16K2"
_hash
,
x86
::
matmul
::
gemm_avx2_s8s8s32_4x16x2
,
dt_int8
,
dt_int32
,
dt_int16
);
dt_int8
,
dt_int32
,
dt_int16
,
AlgoDataType
::
QINT8X8X32
,
DEFAULT
);
MatrixMulImpl
::
kern_t
MatrixMulImpl
::
AlgoInt8x8x32AVX2M2N4K16
::
get_kern
(
const
KernSizeParam
&
)
const
{
...
...
@@ -556,7 +557,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32AVX2M2N4K16,
megdnn_x86_matmul_kern
,
"AlgoInt8x8x32AVX2M2N4K16"
_hash
,
x86
::
matmul
::
gemm_avx2_s8s8s32_2x4x16
,
dt_int8
,
dt_int32
);
dt_int8
,
dt_int32
,
AlgoDataType
::
QINT8X8X32
,
DEFAULT
);
/*************************AlgoInt8x8x32SSEM4N8K2********************/
MatrixMulImpl
::
kern_t
MatrixMulImpl
::
AlgoInt8x8x32SSEM4N8K2
::
get_kern
(
...
...
@@ -596,7 +598,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32SSEM4N8K2,
megdnn_x86_matmul_kern
,
"AlgoInt8x8x32SSEM4N8K2"
_hash
,
x86
::
matmul
::
gemm_sse_s8s8s32_4x8x2
,
dt_int8
,
dt_int32
,
dt_int16
);
dt_int8
,
dt_int32
,
dt_int16
,
AlgoDataType
::
QINT8X8X32
,
DEFAULT
);
/*************************AlgoF32MK8_8x8********************/
MatrixMulImpl
::
kern_t
MatrixMulImpl
::
AlgoF32MK8_8x8
::
get_kern
(
...
...
dnn/src/x86/matrix_mul/algos.h
浏览文件 @
2a3f4d09
...
...
@@ -27,7 +27,7 @@ public:
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_x86_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
4
)
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
4
,
AlgoDataType
::
FLOAT32
,
DEFAULT
)
};
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
...
...
@@ -49,7 +49,7 @@ public:
WorkspaceBundle
get_bundle
(
const
KernSizeParam
&
param
)
const
override
;
InnerBlockSize
get_inner_block_size
()
const
override
{
return
{
8
,
16
,
1
};
};
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
4
)
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
4
,
AlgoDataType
::
FLOAT32
,
DEFAULT
)
};
#endif
...
...
@@ -127,7 +127,7 @@ public:
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_x86_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
8
,
8
,
4
)
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
8
,
8
,
4
,
AlgoDataType
::
FLOAT32
,
MK8
)
};
#if MEGDNN_X86_WITH_VNNI
...
...
@@ -153,7 +153,7 @@ public:
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_x86_algo_type
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
2
)
MEGDNN_OVERRIDE_MATMUL_DESC
(
8
,
16
,
1
,
2
,
AlgoDataType
::
QINT8X8X32
,
DEFAULT
)
};
#endif
}
// namespace x86
...
...
src/opr/impl/dnn/convolution.cpp
浏览文件 @
2a3f4d09
...
...
@@ -495,8 +495,9 @@ class AlgoChooser {
}
}
mgb_assert
(
found
,
"algo got by heuristic not found in "
"candidate list"
);
"algo %s got by heuristic not found in "
"candidate list"
,
heu
->
name
());
return
std
::
move
(
ret
);
}
...
...
@@ -628,7 +629,7 @@ public:
auto
algo
=
get_algo
(
ctx
);
size_t
workspace
=
ctx
.
get_workspace_size_bytes
(
algo
);
mgb_log_debug
(
"%s:
input shapes (%s %s, %s %s) -> (%s %s):
algo=%s "
"%s:
tensor layouts (%s %s, %s %s)->(%s %s) :
algo=%s "
"workspace=%.2fMiB reproducible=%d"
,
mgb_opr
->
dyn_typeinfo
()
->
name
,
layouts
[
0
].
to_string
().
c_str
(),
...
...
@@ -636,8 +637,7 @@ public:
layouts
[
1
].
to_string
().
c_str
(),
layouts
[
1
].
dtype
.
name
(),
layouts
[
layouts
.
size
()
-
1
].
to_string
().
c_str
(),
layouts
[
layouts
.
size
()
-
1
].
dtype
.
name
(),
algo
->
name
(),
layouts
[
layouts
.
size
()
-
1
].
dtype
.
name
(),
algo
->
name
(),
workspace
/
(
1024
*
1024.0
),
algo
->
is_reproducible
());
megdnn_opr
->
execution_policy
()
=
{
algo
};
return
workspace
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录