Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
b778d225
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
b778d225
编写于
8月 07, 2020
作者:
M
Megvii Engine Team
提交者:
Xinran Xu
8月 25, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb/fallback): add conv1x1_gemv, conv1x1 and im2col 8x8x16/8x8x32 support bias
GitOrigin-RevId: 3d97fedc8f33d0b41f94680d6710c56bc32b62e7
上级
c357db01
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
294 addition
and
118 deletion
+294
-118
dnn/src/arm_common/conv_bias/postprocess_helper.h
dnn/src/arm_common/conv_bias/postprocess_helper.h
+60
-1
dnn/src/fallback/conv_bias/common.h
dnn/src/fallback/conv_bias/common.h
+5
-3
dnn/src/fallback/conv_bias/conv1x1/algos.cpp
dnn/src/fallback/conv_bias/conv1x1/algos.cpp
+1
-2
dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp
dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp
+27
-14
dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.cpp
dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.cpp
+29
-15
dnn/src/fallback/conv_bias/im2col/algos.cpp
dnn/src/fallback/conv_bias/im2col/algos.cpp
+1
-2
dnn/src/fallback/conv_bias/im2col/factory.h
dnn/src/fallback/conv_bias/im2col/factory.h
+35
-18
dnn/src/fallback/conv_bias/im2col/strategy_default.cpp
dnn/src/fallback/conv_bias/im2col/strategy_default.cpp
+4
-6
dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp
...src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp
+4
-7
dnn/src/fallback/conv_bias/im2col/strategy_nopack.cpp
dnn/src/fallback/conv_bias/im2col/strategy_nopack.cpp
+2
-2
dnn/src/x86/conv_bias/postprocess_helper.h
dnn/src/x86/conv_bias/postprocess_helper.h
+67
-0
dnn/src/x86/elemwise_helper/kimpl/add.h
dnn/src/x86/elemwise_helper/kimpl/add.h
+2
-0
dnn/test/arm_common/conv_bias_multi_thread.cpp
dnn/test/arm_common/conv_bias_multi_thread.cpp
+48
-46
dnn/test/x86/conv_bias.cpp
dnn/test/x86/conv_bias.cpp
+9
-2
未找到文件。
dnn/src/arm_common/conv_bias/postprocess_helper.h
浏览文件 @
b778d225
...
...
@@ -100,7 +100,6 @@ namespace {
MIDOUT_END(); \
break; \
default: \
megdnn_throw("no quantized unsupported biasmode"); \
break; \
}
...
...
@@ -258,6 +257,66 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
#undef FOR_NONLINEAR_NOBIAS
#undef FOR_NONLINEAR
#undef FOR_BIAS
#define FOR_BINARY_BROADCAST(_op) \
megdnn::arm_common:: \
OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_BCAST101>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
dst_type, N, OC, OH* OW);
#define FOR_BINARY_BROADCAST_NCHW44(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, \
megdnn::arm_common::VEC_BCAST101x4>:: \
run(static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
dst_type, N, OC, OH* OW, pack_oc_size);
#define FOR_BINARY(_op) \
megdnn::arm_common:: \
OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_VEC>::run( \
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
dst_type, N* OC* OH* OW* pack_oc_size);
#define FOR_BIAS(_bias_mode, OH, OW) \
switch (_bias_mode) { \
case megdnn::BiasMode::NO_BIAS: \
break; \
case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \
if (pack_oc_size == 1) { \
FOR_BINARY_BROADCAST(CONCAT_OP(AddOp)); \
} else { \
megdnn_assert(pack_oc_size == 4, \
"Only support nchw44 in ARM"); \
FOR_BINARY_BROADCAST_NCHW44(CONCAT_OP(AddOp)); \
} \
break; \
case megdnn::BiasMode::BIAS: \
FOR_BINARY(CONCAT_OP(AddOp)); \
break; \
default: \
break; \
}
template
<
typename
ctype
,
typename
dtype
>
struct
PostProcess
<
ctype
,
dtype
,
megdnn
::
PostprocessMode
::
ADD_BIAS
>
{
static
void
run
(
void
*
conv_dst_ptr
,
void
*
bias_ptr
,
void
*
dst_ptr
,
megdnn
::
BiasMode
bias_mode
,
megdnn
::
NonlineMode
nonlineMode
,
megdnn
::
DType
bias_type
,
megdnn
::
DType
dst_type
,
size_t
N
,
size_t
OC
,
size_t
OH
,
size_t
OW
,
size_t
pack_oc_size
=
1
)
{
megdnn_assert
(
nonlineMode
==
megdnn
::
NonlineMode
::
IDENTITY
);
FOR_BIAS
(
bias_mode
,
OH
,
OW
);
}
};
#undef FOR_BINARY_BROADCAST
#undef FOR_BINARY_BROADCAST_NCHW44
#undef FOR_BINARY
#undef FOR_BIAS
#undef CB
#undef CONCAT_OP
#undef CONCAT_NL
...
...
dnn/src/fallback/conv_bias/common.h
浏览文件 @
b778d225
...
...
@@ -159,8 +159,10 @@ private: \
enum
class
PostprocessMode
:
uint8_t
{
FLOAT
=
0
,
///< support all biasmode and no_nonlinemode
NO_PROCESS
,
///<support non bias and identity
QUANTIZED
,
///<support NOBIAS ,BROADCAST_CHANNEL_BIAS and relu hswish identify nonline mode
NO_PROCESS
,
///< support non bias and identity
QUANTIZED
,
///< support NOBIAS ,BROADCAST_CHANNEL_BIAS and relu hswish
///< identify nonline mode
ADD_BIAS
,
///< only add bias
};
}
// namespace megdnn
...
...
dnn/src/fallback/conv_bias/conv1x1/algos.cpp
浏览文件 @
b778d225
...
...
@@ -227,8 +227,7 @@ bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param,
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
QuantizedS16
||
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
Int32
||
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
QuantizedS32
)
{
if
(
param
.
bias_mode
!=
megdnn
::
BiasMode
::
NO_BIAS
||
param
.
nonlineMode
!=
megdnn
::
NonlineMode
::
IDENTITY
)
{
if
(
param
.
nonlineMode
!=
megdnn
::
NonlineMode
::
IDENTITY
)
{
return
false
;
}
}
...
...
dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp
浏览文件 @
b778d225
...
...
@@ -310,6 +310,19 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns(
} \
} \
MIDOUT_END()
#define cb3(_format, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \
_bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \
MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv, midout_iv(_midout_tag)) { \
if (param.filter_type.enumv() == param.src_type.enumv() && \
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \
conv1x1_gemv_worker = \
Conv1x1GemvWorker<_src_ctype, _bias_ctype, _dst_ctype, \
_bias_ctype, _dst_ctype, \
_postprocess_mode, _format>::exec; \
} \
} \
MIDOUT_END()
switch
(
param
.
filter_meta
.
format
)
{
case
param
::
ConvBias
::
Format
::
NCHW
:
...
...
@@ -324,23 +337,23 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns(
PostprocessMode
::
NO_PROCESS
,
"NCHW::GEMV::FLOAT16_FLOAT16"
_hash
);
#endif
#endif
cb
2
(
param
::
ConvBias
::
Format
::
NCHW
,
dt_int8
,
dt_int32
,
dt_int32
,
dt_int8
,
dt_int32
,
dt_int32
,
PostprocessMode
::
NO_PROCES
S
,
cb
3
(
param
::
ConvBias
::
Format
::
NCHW
,
dt_int8
,
dt_int32
,
dt_int32
,
dt_int8
,
dt_int32
,
dt_int32
,
PostprocessMode
::
ADD_BIA
S
,
"NCHW::GEMV::INT8x8x32_INT32"
_hash
);
cb
2
(
param
::
ConvBias
::
Format
::
NCHW
,
dt_int8
,
dt_int16
,
dt_int16
,
dt_int8
,
dt_int16
,
dt_int16
,
PostprocessMode
::
NO_PROCES
S
,
cb
3
(
param
::
ConvBias
::
Format
::
NCHW
,
dt_int8
,
dt_int16
,
dt_int16
,
dt_int8
,
dt_int16
,
dt_int16
,
PostprocessMode
::
ADD_BIA
S
,
"NCHW::GEMV::INT8x8x16_INT16"
_hash
);
cb
2
(
param
::
ConvBias
::
Format
::
NCHW
,
dtype
::
QuantizedS8
,
cb
3
(
param
::
ConvBias
::
Format
::
NCHW
,
dtype
::
QuantizedS8
,
dtype
::
QuantizedS32
,
dtype
::
QuantizedS32
,
dt_int8
,
dt_int32
,
dt_int32
,
PostprocessMode
::
NO_PROCES
S
,
dt_int32
,
PostprocessMode
::
ADD_BIA
S
,
"NCHW::GEMV::QINT8x8x32_QINT32"
_hash
);
cb2
(
param
::
ConvBias
::
Format
::
NCHW
,
dtype
::
QuantizedS8
,
dtype
::
QuantizedS32
,
dtype
::
QuantizedS8
,
dt_int8
,
dt_int32
,
dt_int8
,
PostprocessMode
::
QUANTIZED
,
"NCHW::GEMV::QINT8x8x32_QINT8"
_hash
);
cb
2
(
param
::
ConvBias
::
Format
::
NCHW
,
dtype
::
Quantized8Asymm
,
cb
3
(
param
::
ConvBias
::
Format
::
NCHW
,
dtype
::
Quantized8Asymm
,
dtype
::
QuantizedS32
,
dtype
::
QuantizedS32
,
dt_uint8
,
dt_int32
,
dt_int32
,
PostprocessMode
::
NO_PROCES
S
,
dt_int32
,
PostprocessMode
::
ADD_BIA
S
,
"NCHW::GEMV::QUINT8x8x32_QINT32"
_hash
);
cb2
(
param
::
ConvBias
::
Format
::
NCHW
,
dtype
::
Quantized8Asymm
,
dtype
::
QuantizedS32
,
dtype
::
Quantized8Asymm
,
dt_uint8
,
dt_int32
,
...
...
@@ -365,13 +378,13 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns(
break
;
case
param
::
ConvBias
::
Format
::
NCHW44_DOT
:
cb
2
(
param
::
ConvBias
::
Format
::
NCHW44_DOT
,
dt_int8
,
dt_int32
,
cb
3
(
param
::
ConvBias
::
Format
::
NCHW44_DOT
,
dt_int8
,
dt_int32
,
dt_int32
,
dt_int8
,
dt_int32
,
dt_int32
,
PostprocessMode
::
NO_PROCES
S
,
PostprocessMode
::
ADD_BIA
S
,
"NCHW44_DOT::GEMV::INT8x8x32_INT32"
_hash
);
cb
2
(
param
::
ConvBias
::
Format
::
NCHW44_DOT
,
dtype
::
QuantizedS8
,
cb
3
(
param
::
ConvBias
::
Format
::
NCHW44_DOT
,
dtype
::
QuantizedS8
,
dtype
::
QuantizedS32
,
dtype
::
QuantizedS32
,
dt_int8
,
dt_int32
,
dt_int32
,
PostprocessMode
::
NO_PROCES
S
,
dt_int32
,
PostprocessMode
::
ADD_BIA
S
,
"NCHW44_DOT::GEMV::QINT8x8x32_QINT32"
_hash
);
cb2
(
param
::
ConvBias
::
Format
::
NCHW44_DOT
,
dtype
::
QuantizedS8
,
dtype
::
QuantizedS32
,
dtype
::
QuantizedS8
,
dt_int8
,
dt_int32
,
...
...
@@ -385,6 +398,7 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns(
}
#undef cb1
#undef cb2
#undef cb3
megdnn_assert
(
conv1x1_gemv_worker
,
"No suitable gemv worker"
);
...
...
@@ -448,8 +462,7 @@ bool ConvBiasImpl::AlgoConv1x1Gemv::usable(const NCBKernSizeParam& param,
if
(
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
Int16
||
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
Int32
||
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
QuantizedS32
)
{
if
(
param
.
bias_mode
!=
megdnn
::
BiasMode
::
NO_BIAS
||
param
.
nonlineMode
!=
megdnn
::
NonlineMode
::
IDENTITY
)
{
if
(
param
.
nonlineMode
!=
megdnn
::
NonlineMode
::
IDENTITY
)
{
return
false
;
}
}
...
...
dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.cpp
浏览文件 @
b778d225
...
...
@@ -56,6 +56,19 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy(
} \
} \
MIDOUT_END()
#define cb3(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \
_bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \
midout_iv(_midout_tag)) { \
if (param.filter_type.enumv() == param.src_type.enumv() && \
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \
return std::make_unique<Conv1x1Strategy< \
_src_ctype, _bias_ctype, _dst_ctype, _bias_ctype, \
_dst_ctype, _postprocess_mode, _packmode>>(pack_c_size); \
} \
} \
MIDOUT_END()
switch
(
pack_mode
)
{
case
MatrixMulImpl
::
AlgoBase
::
PackMode
::
DEFAULT
:
...
...
@@ -71,26 +84,26 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy(
"Default::FLOAT16_FLOAT16"
_hash
);
#endif
#endif
cb
2
(
MatrixMulImpl
::
AlgoBase
::
PackMode
::
DEFAULT
,
dt_int8
,
dt_int32
,
cb
3
(
MatrixMulImpl
::
AlgoBase
::
PackMode
::
DEFAULT
,
dt_int8
,
dt_int32
,
dt_int32
,
dt_int8
,
dt_int32
,
dt_int32
,
PostprocessMode
::
NO_PROCES
S
,
"Default::INT8x8x32_INT32"
_hash
);
cb
2
(
MatrixMulImpl
::
AlgoBase
::
PackMode
::
DEFAULT
,
dt_int8
,
dt_int16
,
PostprocessMode
::
ADD_BIA
S
,
"Default::INT8x8x32_INT32"
_hash
);
cb
3
(
MatrixMulImpl
::
AlgoBase
::
PackMode
::
DEFAULT
,
dt_int8
,
dt_int16
,
dt_int16
,
dt_int8
,
dt_int16
,
dt_int16
,
PostprocessMode
::
NO_PROCES
S
,
"Default::INT8x8x16_INT16"
_hash
);
PostprocessMode
::
ADD_BIA
S
,
"Default::INT8x8x16_INT16"
_hash
);
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
cb
2
(
MatrixMulImpl
::
AlgoBase
::
PackMode
::
DEFAULT
,
cb
3
(
MatrixMulImpl
::
AlgoBase
::
PackMode
::
DEFAULT
,
dtype
::
Quantized8Asymm
,
dtype
::
QuantizedS32
,
dtype
::
QuantizedS32
,
dt_uint8
,
dt_int32
,
dt_int32
,
PostprocessMode
::
NO_PROCES
S
,
PostprocessMode
::
ADD_BIA
S
,
"Default::QUINT8x8x32_QINT32"
_hash
);
cb2
(
MatrixMulImpl
::
AlgoBase
::
PackMode
::
DEFAULT
,
dtype
::
Quantized8Asymm
,
dtype
::
QuantizedS32
,
dtype
::
Quantized8Asymm
,
dt_uint8
,
dt_int32
,
dt_uint8
,
PostprocessMode
::
QUANTIZED
,
"Default::QUINT8x8x32_QUINT8"
_hash
);
#endif
cb
2
(
MatrixMulImpl
::
AlgoBase
::
PackMode
::
DEFAULT
,
dtype
::
QuantizedS8
,
cb
3
(
MatrixMulImpl
::
AlgoBase
::
PackMode
::
DEFAULT
,
dtype
::
QuantizedS8
,
dtype
::
QuantizedS32
,
dtype
::
QuantizedS32
,
dt_int8
,
dt_int32
,
dt_int32
,
PostprocessMode
::
NO_PROCES
S
,
dt_int32
,
PostprocessMode
::
ADD_BIA
S
,
"Default::QINT8x8x32_QINT32"
_hash
);
cb2
(
MatrixMulImpl
::
AlgoBase
::
PackMode
::
DEFAULT
,
dtype
::
QuantizedS8
,
dtype
::
QuantizedS32
,
dtype
::
QuantizedS8
,
dt_int8
,
dt_int32
,
...
...
@@ -107,17 +120,17 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy(
cb1
(
MatrixMulImpl
::
AlgoBase
::
PackMode
::
NO_PACK
,
dt_float32
,
dt_float32
,
PostprocessMode
::
FLOAT
,
"NoPack::FLOAT"
_hash
);
cb
2
(
MatrixMulImpl
::
AlgoBase
::
PackMode
::
NO_PACK
,
dt_int8
,
dt_int16
,
cb
3
(
MatrixMulImpl
::
AlgoBase
::
PackMode
::
NO_PACK
,
dt_int8
,
dt_int16
,
dt_int16
,
dt_int8
,
dt_int16
,
dt_int16
,
PostprocessMode
::
NO_PROCES
S
,
"NoPack::INT8x8x16_INT16"
_hash
);
PostprocessMode
::
ADD_BIA
S
,
"NoPack::INT8x8x16_INT16"
_hash
);
cb
2
(
MatrixMulImpl
::
AlgoBase
::
PackMode
::
NO_PACK
,
dt_int8
,
dt_int32
,
cb
3
(
MatrixMulImpl
::
AlgoBase
::
PackMode
::
NO_PACK
,
dt_int8
,
dt_int32
,
dt_int32
,
dt_int8
,
dt_int32
,
dt_int32
,
PostprocessMode
::
NO_PROCES
S
,
"NoPack::INT8x8x32_INT32"
_hash
);
PostprocessMode
::
ADD_BIA
S
,
"NoPack::INT8x8x32_INT32"
_hash
);
cb
2
(
MatrixMulImpl
::
AlgoBase
::
PackMode
::
NO_PACK
,
dtype
::
QuantizedS8
,
cb
3
(
MatrixMulImpl
::
AlgoBase
::
PackMode
::
NO_PACK
,
dtype
::
QuantizedS8
,
dtype
::
QuantizedS32
,
dtype
::
QuantizedS32
,
dt_int8
,
dt_int32
,
dt_int32
,
PostprocessMode
::
NO_PROCES
S
,
dt_int32
,
PostprocessMode
::
ADD_BIA
S
,
"NoPack::QINT8x8x32_QINT32"
_hash
);
break
;
...
...
@@ -127,6 +140,7 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy(
}
#undef cb1
#undef cb2
#undef cb3
megdnn_throw
(
"Invalid Data Type"
);
return
nullptr
;
}
...
...
dnn/src/fallback/conv_bias/im2col/algos.cpp
浏览文件 @
b778d225
...
...
@@ -746,8 +746,7 @@ bool ConvBiasImpl::AlgoIm2col::usable(
if
(
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
Int16
||
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
Int32
||
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
QuantizedS32
)
{
if
(
param
.
bias_mode
!=
megdnn
::
BiasMode
::
NO_BIAS
||
param
.
nonlineMode
!=
megdnn
::
NonlineMode
::
IDENTITY
)
{
if
(
param
.
nonlineMode
!=
megdnn
::
NonlineMode
::
IDENTITY
)
{
return
false
;
}
}
...
...
dnn/src/fallback/conv_bias/im2col/factory.h
浏览文件 @
b778d225
...
...
@@ -213,6 +213,22 @@ public:
} \
MIDOUT_END(); \
return {};
#define cb3(_format, _packmode, _i_src_type, _i_bias_type, _i_dst_type, \
_src_ctype, _bias_ctype, _dst_ctype, _postprocess_mode, \
_midout_tag) \
MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy, \
midout_iv(_midout_tag)) { \
if (param.filter_type.enumv() == param.src_type.enumv() && \
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \
return std::make_unique< \
Strategy<_src_ctype, _bias_ctype, _dst_ctype, _bias_ctype, \
_dst_ctype, _postprocess_mode, \
PackMode::_packmode, FormatMode::_format>>(); \
} \
} \
MIDOUT_END(); \
return {};
static
std
::
unique_ptr
<
StrategyBase
>
make_default_strategy
(
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
...
...
@@ -279,13 +295,13 @@ public:
#endif
case
StrategyType
::
INT8x8x32
:
if
(
format
==
param
::
ConvBias
::
Format
::
NCHW
)
{
cb
2
(
NCHW
,
DEFAULT
,
dt_int8
,
dt_int32
,
dt_int32
,
dt_int8
,
dt_int32
,
dt_int32
,
PostprocessMode
::
NO_PROCES
S
,
cb
3
(
NCHW
,
DEFAULT
,
dt_int8
,
dt_int32
,
dt_int32
,
dt_int8
,
dt_int32
,
dt_int32
,
PostprocessMode
::
ADD_BIA
S
,
"DefaultStrategyType::INT8x8x32"
_hash
);
}
else
if
(
format
==
param
::
ConvBias
::
Format
::
NCHW44
||
format
==
param
::
ConvBias
::
Format
::
NCHW44_DOT
)
{
cb
2
(
NCHW44
,
DEFAULT
,
dt_int8
,
dt_int32
,
dt_int32
,
dt_int8
,
dt_int32
,
dt_int32
,
PostprocessMode
::
NO_PROCES
S
,
cb
3
(
NCHW44
,
DEFAULT
,
dt_int8
,
dt_int32
,
dt_int32
,
dt_int8
,
dt_int32
,
dt_int32
,
PostprocessMode
::
ADD_BIA
S
,
"DefaultStrategyType::INT8x8x32"
_hash
);
}
else
{
megdnn_throw
(
...
...
@@ -299,12 +315,12 @@ public:
case
StrategyType
::
INT8x8x16
:
if
(
format
==
param
::
ConvBias
::
Format
::
NCHW
)
{
cb
2
(
NCHW
,
DEFAULT
,
dt_int8
,
dt_int16
,
dt_int16
,
dt_int8
,
dt_int16
,
dt_int16
,
PostprocessMode
::
NO_PROCES
S
,
cb
3
(
NCHW
,
DEFAULT
,
dt_int8
,
dt_int16
,
dt_int16
,
dt_int8
,
dt_int16
,
dt_int16
,
PostprocessMode
::
ADD_BIA
S
,
"DefaultStrategyType::INT8x8x16"
_hash
);
}
else
if
(
format
==
param
::
ConvBias
::
Format
::
NCHW44
)
{
cb
2
(
NCHW44
,
DEFAULT
,
dt_int8
,
dt_int16
,
dt_int16
,
dt_int8
,
dt_int16
,
dt_int16
,
PostprocessMode
::
NO_PROCES
S
,
cb
3
(
NCHW44
,
DEFAULT
,
dt_int8
,
dt_int16
,
dt_int16
,
dt_int8
,
dt_int16
,
dt_int16
,
PostprocessMode
::
ADD_BIA
S
,
"DefaultStrategyType::INT8x8x16"
_hash
);
}
else
{
megdnn_throw
(
...
...
@@ -316,9 +332,9 @@ public:
break
;
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
case
StrategyType
::
QUINT8x8x32
:
cb
2
(
NCHW
,
DEFAULT
,
dtype
::
Quantized8Asymm
,
dtype
::
QuantizedS32
,
cb
3
(
NCHW
,
DEFAULT
,
dtype
::
Quantized8Asymm
,
dtype
::
QuantizedS32
,
dtype
::
QuantizedS32
,
dt_uint8
,
dt_int32
,
dt_int32
,
PostprocessMode
::
NO_PROCES
S
,
PostprocessMode
::
ADD_BIA
S
,
"DefaultStrategyType::QUINT8x8x32"
_hash
);
break
;
...
...
@@ -331,15 +347,15 @@ public:
#endif
case
StrategyType
::
QINT8x8x32
:
if
(
format
==
param
::
ConvBias
::
Format
::
NCHW
)
{
cb
2
(
NCHW
,
DEFAULT
,
dtype
::
QuantizedS8
,
dtype
::
QuantizedS32
,
cb
3
(
NCHW
,
DEFAULT
,
dtype
::
QuantizedS8
,
dtype
::
QuantizedS32
,
dtype
::
QuantizedS32
,
dt_int8
,
dt_int32
,
dt_int32
,
PostprocessMode
::
NO_PROCES
S
,
PostprocessMode
::
ADD_BIA
S
,
"DefaultStrategyTypeNCHW::QINT8x8x32"
_hash
);
}
else
if
(
format
==
param
::
ConvBias
::
Format
::
NCHW44
||
format
==
param
::
ConvBias
::
Format
::
NCHW44_DOT
)
{
cb
2
(
NCHW44
,
DEFAULT
,
dtype
::
QuantizedS8
,
cb
3
(
NCHW44
,
DEFAULT
,
dtype
::
QuantizedS8
,
dtype
::
QuantizedS32
,
dtype
::
QuantizedS32
,
dt_int8
,
dt_int32
,
dt_int32
,
PostprocessMode
::
NO_PROCES
S
,
dt_int32
,
dt_int32
,
PostprocessMode
::
ADD_BIA
S
,
"DefaultStrategyTypeHCHW44::QINT8x8x32"
_hash
);
}
else
{
megdnn_throw
(
...
...
@@ -467,13 +483,13 @@ public:
#endif
#endif
case
StrategyType
::
INT8x8x16
:
cb
2
(
NCHW
,
NO_PACK
,
dt_int8
,
dt_int16
,
dt_int16
,
dt_int8
,
dt_int16
,
dt_int16
,
PostprocessMode
::
NO_PROCES
S
,
cb
3
(
NCHW
,
NO_PACK
,
dt_int8
,
dt_int16
,
dt_int16
,
dt_int8
,
dt_int16
,
dt_int16
,
PostprocessMode
::
ADD_BIA
S
,
"NoPackStrategyType::INT8x8x16"
_hash
);
break
;
case
StrategyType
::
INT8x8x32
:
cb
2
(
NCHW
,
NO_PACK
,
dt_int8
,
dt_int32
,
dt_int32
,
dt_int8
,
dt_int32
,
dt_int32
,
PostprocessMode
::
NO_PROCES
S
,
cb
3
(
NCHW
,
NO_PACK
,
dt_int8
,
dt_int32
,
dt_int32
,
dt_int8
,
dt_int32
,
dt_int32
,
PostprocessMode
::
ADD_BIA
S
,
"NoPackStrategyType::INT8x8x32"
_hash
);
break
;
default:
...
...
@@ -509,6 +525,7 @@ public:
#undef cb1
#undef cb2
#undef cb3
static
std
::
unique_ptr
<
StrategyBase
>
make_strategy
(
fallback
::
MatrixMulImpl
::
AlgoBase
*
matmul_algo
,
...
...
dnn/src/fallback/conv_bias/im2col/strategy_default.cpp
浏览文件 @
b778d225
...
...
@@ -203,18 +203,16 @@ INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16,
//! x86 do not have uint8 matmul so only armv7 armv8 support uint8
INSTANTIAL_CLASS
(
dt_uint8
,
dt_int32
,
dt_uint8
,
dt_qint32
,
dt_quint8
,
megdnn
::
PostprocessMode
::
QUANTIZED
)
INSTANTIAL_CLASS
(
dt_uint8
,
dt_int32
,
dt_int32
,
dt_
qint32
,
dt_q
int32
,
megdnn
::
PostprocessMode
::
NO_PROCES
S
)
INSTANTIAL_CLASS
(
dt_uint8
,
dt_int32
,
dt_int32
,
dt_
int32
,
dt_
int32
,
megdnn
::
PostprocessMode
::
ADD_BIA
S
)
#endif
INSTANTIAL_CLASS
(
dt_int8
,
dt_int32
,
dt_int8
,
dt_qint32
,
dt_qint8
,
megdnn
::
PostprocessMode
::
QUANTIZED
)
INSTANTIAL_CLASS
(
dt_int8
,
dt_int32
,
dt_int32
,
dt_int32
,
dt_int32
,
megdnn
::
PostprocessMode
::
NO_PROCES
S
)
megdnn
::
PostprocessMode
::
ADD_BIA
S
)
INSTANTIAL_CLASS
(
dt_int8
,
dt_int16
,
dt_int16
,
dt_int16
,
dt_int16
,
megdnn
::
PostprocessMode
::
NO_PROCESS
)
INSTANTIAL_CLASS
(
dt_int8
,
dt_int32
,
dt_int32
,
dt_qint32
,
dt_qint32
,
megdnn
::
PostprocessMode
::
NO_PROCESS
)
megdnn
::
PostprocessMode
::
ADD_BIAS
)
#undef INSTANTIAL_CLASS
}
// namespace megdnn
...
...
dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp
浏览文件 @
b778d225
...
...
@@ -119,19 +119,16 @@ INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16,
//! x86 do not have uint8 matmul so only armv7 armv8 support uint8
INSTANTIAL_CLASS
(
dt_uint8
,
dt_int32
,
dt_uint8
,
dt_qint32
,
dt_quint8
,
megdnn
::
PostprocessMode
::
QUANTIZED
)
INSTANTIAL_CLASS
(
dt_uint8
,
dt_int32
,
dt_int32
,
dt_
qint32
,
dt_q
int32
,
megdnn
::
PostprocessMode
::
NO_PROCES
S
)
INSTANTIAL_CLASS
(
dt_uint8
,
dt_int32
,
dt_int32
,
dt_
int32
,
dt_
int32
,
megdnn
::
PostprocessMode
::
ADD_BIA
S
)
#endif
INSTANTIAL_CLASS
(
dt_int8
,
dt_int32
,
dt_int8
,
dt_qint32
,
dt_qint8
,
megdnn
::
PostprocessMode
::
QUANTIZED
)
INSTANTIAL_CLASS
(
dt_int8
,
dt_int32
,
dt_int32
,
dt_int32
,
dt_int32
,
megdnn
::
PostprocessMode
::
NO_PROCES
S
)
megdnn
::
PostprocessMode
::
ADD_BIA
S
)
INSTANTIAL_CLASS
(
dt_int8
,
dt_int16
,
dt_int16
,
dt_int16
,
dt_int16
,
megdnn
::
PostprocessMode
::
NO_PROCESS
)
INSTANTIAL_CLASS
(
dt_int8
,
dt_int32
,
dt_int32
,
dt_qint32
,
dt_qint32
,
megdnn
::
PostprocessMode
::
NO_PROCESS
)
megdnn
::
PostprocessMode
::
ADD_BIAS
)
#undef INSTANTIAL_CLASS
}
// namespace megdnn
...
...
dnn/src/fallback/conv_bias/im2col/strategy_nopack.cpp
浏览文件 @
b778d225
...
...
@@ -162,9 +162,9 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
INSTANTIAL_CLASS
(
dt_float32
,
dt_float32
,
dt_float32
,
dt_float32
,
dt_float32
,
megdnn
::
PostprocessMode
::
FLOAT
)
INSTANTIAL_CLASS
(
dt_int8
,
dt_int16
,
dt_int16
,
dt_int16
,
dt_int16
,
megdnn
::
PostprocessMode
::
NO_PROCES
S
)
megdnn
::
PostprocessMode
::
ADD_BIA
S
)
INSTANTIAL_CLASS
(
dt_int8
,
dt_int32
,
dt_int32
,
dt_int32
,
dt_int32
,
megdnn
::
PostprocessMode
::
NO_PROCES
S
)
megdnn
::
PostprocessMode
::
ADD_BIA
S
)
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#else
#if !MEGDNN_DISABLE_FLOAT16
...
...
dnn/src/x86/conv_bias/postprocess_helper.h
浏览文件 @
b778d225
...
...
@@ -294,6 +294,73 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::QUANTIZED> {
#undef FOR_BIAS
}
};
#undef CALL_BINARY
#undef CALL_BINARY_BROADCAST
#define CALL_BINARY(_op, _simd_type) \
thin_function<void(const ctype*, const ctype*, dtype*, DType, DType, \
DType, size_t)> \
run = OpCallerBinary<_op<_simd_type, ctype, dtype>, _simd_type, \
megdnn::x86::BcastType::VEC_VEC>::run; \
run(static_cast<ctype*>(conv_dst_ptr), static_cast<ctype*>(bias_ptr), \
reinterpret_cast<dtype*>(dst_ptr), bias_type, bias_type, dst_type, \
N* OC* OH* OW);
#define CALL_BINARY_BROADCAST(_op, _simd_type) \
thin_function<void(const ctype*, const ctype*, dtype*, DType, DType, \
DType, size_t, size_t, size_t)> \
run = OpCallerBinary<_op<_simd_type, ctype, dtype>, _simd_type, \
megdnn::x86::BcastType::VEC_BCAST101>::run; \
run(static_cast<ctype*>(conv_dst_ptr), static_cast<ctype*>(bias_ptr), \
reinterpret_cast<dtype*>(dst_ptr), bias_type, bias_type, dst_type, N, \
OC, OH* OW);
#define FOR_SIMD(CALLER) \
if (is_supported(SIMDType::AVX2)) { \
CALLER(AddOp, SIMDType::AVX2) \
} else if (is_supported(SIMDType::SSE4_2)) { \
CALLER(AddOp, SIMDType::SSE4_2) \
} else { \
CALLER(AddOp, SIMDType::NONE) \
}
#define FOR_BIAS(bias_mode) \
switch (bias_mode) { \
case BiasMode::BIAS: \
FOR_SIMD(CALL_BINARY); \
break; \
case BiasMode::BROADCAST_CHANNEL_BIAS: \
FOR_SIMD(CALL_BINARY_BROADCAST); \
break; \
default: \
break; \
}
template
<
typename
ctype
,
typename
dtype
>
struct
PostProcess
<
ctype
,
dtype
,
megdnn
::
PostprocessMode
::
ADD_BIAS
>
{
static
void
run
(
void
*
conv_dst_ptr
,
void
*
bias_ptr
,
void
*
dst_ptr
,
megdnn
::
ConvBiasForward
::
BiasMode
bias_mode
,
megdnn
::
param
::
ConvBiasV0
::
NonlineMode
nonlineMode
,
DType
bias_type
,
DType
dst_type
,
size_t
N
,
size_t
OC
,
size_t
OH
,
size_t
OW
,
size_t
pack_oc_size
=
1
)
{
MEGDNN_MARK_USED_VAR
(
pack_oc_size
);
megdnn_assert
(
pack_oc_size
==
1
,
"PostProcess only support nchw in x86"
);
megdnn_assert
(
nonlineMode
==
megdnn
::
param
::
ConvBiasV0
::
NonlineMode
::
IDENTITY
,
"Add bias PostProcess only support IDENTITY"
);
if
(
bias_mode
==
megdnn
::
ConvBiasForward
::
BiasMode
::
NO_BIAS
)
{
return
;
}
FOR_BIAS
(
bias_mode
);
#undef CALL_BINARY
#undef CALL_BINARY_BROADCAST
#undef FOR_SIMD
#undef FOR_BIAS
}
};
#undef cb_unary
#undef cb_binary
#undef BIAS_CASE
...
...
dnn/src/x86/elemwise_helper/kimpl/add.h
浏览文件 @
b778d225
...
...
@@ -92,6 +92,8 @@ OP(dt_int8, SIMDType::AVX2, "avx2", __m256i, __m256ix2, __m256i, mm256, epi8,
using AddOpBase::operator(); \
};
OP
(
dt_int32
,
SIMDType
::
NONE
);
OP
(
dt_int16
,
SIMDType
::
NONE
);
OP
(
dt_float32
,
SIMDType
::
NONE
);
#undef OP
}
// namespace x86
...
...
dnn/test/arm_common/conv_bias_multi_thread.cpp
浏览文件 @
b778d225
...
...
@@ -1992,13 +1992,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_S8x8x32_MK4_DOT) {
#define cb(name) \
checker_conv_bias( \
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false,
false,
\
true, false, true,
false, false, tru
e), \
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false,
true,
\
true, false, true,
true, false, fals
e), \
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); \
checker_conv_bias( \
get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \
false, false, tru
e), \
true, false, fals
e), \
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name);
...
...
@@ -2041,13 +2041,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32_MK4_DOT) {
#define cb(name) \
checker_conv_bias( \
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false,
false,
\
true, false, true,
false, false, tru
e), \
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false,
true,
\
true, false, true,
true, false, fals
e), \
handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \
dtype::Int32(), {}, name); \
checker_conv_bias( \
get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \
false, false, tru
e), \
true, false, fals
e), \
handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \
dtype::Int32(), {}, name);
...
...
@@ -2118,7 +2118,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CONV1x1_QUANTIZEDSYM_MK4_DOT) {
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM
)
{
NormalRNG
rng
(
128.
f
);
#define cb(name) \
checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \
false, true, true), \
...
...
@@ -2189,14 +2188,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) {
UniformIntRNG
rng
{
-
50
,
50
};
float
epsilon
=
0.001
;
#define cb(name) \
checker_conv_bias(
\
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true),
\
checker_conv_bias(
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false,
\
true, true, false),
\
handle(), &rng, epsilon, \
dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
dtype::QuantizedS32(1.2 * 1.3), {}, name); \
checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true), handle(), \
&rng, epsilon, \
checker_conv_bias( \
get_conv_bias_args({1}, 2, false, false, true, true, false), \
handle(), &rng, epsilon, \
dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
dtype::QuantizedS32(1.2 * 1.3), {}, name);
...
...
@@ -2252,18 +2252,18 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) {
UniformIntRNG
rng
{
-
50
,
50
};
float
epsilon
=
0.001
;
std
::
vector
<
conv_bias
::
TestArg
>
args_nchw44
=
get_nchw44_conv_bias_args
({
2
,
3
,
4
,
5
,
6
,
7
},
1
,
true
,
tru
e
,
true
,
false
,
false
,
false
,
false
,
tru
e
);
get_nchw44_conv_bias_args
({
2
,
3
,
4
,
5
,
6
,
7
},
1
,
true
,
fals
e
,
true
,
false
,
false
,
true
,
false
,
fals
e
);
std
::
vector
<
conv_bias
::
TestArg
>
args_nchw44_1x1s2
=
get_nchw44_conv_bias_args
({
1
},
2
,
true
,
tru
e
,
true
,
false
,
false
,
false
,
false
,
tru
e
);
get_nchw44_conv_bias_args
({
1
},
2
,
true
,
fals
e
,
true
,
false
,
false
,
true
,
false
,
fals
e
);
#define cb(name) \
checker_conv_bias( \
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false,
true, true),
\
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false,
false, true),
\
handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \
dtype::Int16{}, dtype::Int16{}, name); \
checker_conv_bias(get_conv_bias_args({1}, 2, false,
true, true), handle(),
\
&rng, epsilon, dtype::Int8{}, dtype::Int8{},
\
checker_conv_bias(get_conv_bias_args({1}, 2, false,
false, true),
\
handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{},
\
dtype::Int16{}, dtype::Int16{}, name);
#define cb_nchw44(name) \
...
...
@@ -2316,10 +2316,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_NOPACK_FILTERPR
float
epsilon
=
0.001
;
#define cb(name) \
check_conv_bias_preprocess( \
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false,
tru
e, true), \
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false,
fals
e, true), \
handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \
dtype::Int16{}, dtype::Int16{}, name); \
check_conv_bias_preprocess(get_conv_bias_args({1}, 2, false,
tru
e, true), \
check_conv_bias_preprocess(get_conv_bias_args({1}, 2, false,
fals
e, true), \
handle(), &rng, epsilon, dtype::Int8{}, \
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, \
name);
...
...
@@ -2406,7 +2406,7 @@ void checker_conv_bias_mul_int8x8x32(std::vector<conv_bias::TestArg> args,
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS8
(
2.5
f
))
.
set_dtype
(
1
,
dtype
::
QuantizedS8
(
2.5
f
))
.
set_dtype
(
2
,
dtype
::
QuantizedS32
(
6.25
f
))
.
set_dtype
(
4
,
{}
)
.
set_dtype
(
4
,
dtype
::
QuantizedS32
(
6.25
f
)
)
.
set_rng
(
0
,
&
rng
)
.
set_rng
(
1
,
&
rng
)
.
set_rng
(
2
,
&
rng
)
...
...
@@ -2436,7 +2436,7 @@ void checker_conv_bias_int8x8x32_preprocess(std::vector<conv_bias::TestArg> args
checker
.
set_dtype
(
0
,
dtype
::
QuantizedS8
(
2.5
f
))
.
set_dtype
(
1
,
dtype
::
QuantizedS8
(
2.5
f
))
.
set_dtype
(
2
,
dtype
::
QuantizedS32
(
6.25
f
))
.
set_dtype
(
4
,
{}
)
.
set_dtype
(
4
,
dtype
::
QuantizedS32
(
6.25
f
)
)
.
set_rng
(
0
,
&
rng
)
.
set_rng
(
1
,
&
rng
)
.
set_rng
(
2
,
&
rng
)
...
...
@@ -2450,7 +2450,7 @@ void checker_conv_bias_int8x8x32_preprocess(std::vector<conv_bias::TestArg> args
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2
)
{
using
namespace
conv_bias
;
std
::
vector
<
conv_bias
::
TestArg
>
args
=
get_nchw44_conv_bias_args
({
2
,
5
,
7
},
2
,
false
,
tru
e
,
true
);
get_nchw44_conv_bias_args
({
2
,
5
,
7
},
2
,
false
,
fals
e
,
true
);
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
#if MEGDNN_AARCH64
...
...
@@ -2464,7 +2464,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) {
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPROCESS
)
{
using
namespace
conv_bias
;
std
::
vector
<
conv_bias
::
TestArg
>
args
=
get_nchw44_conv_bias_args
({
2
,
5
,
7
},
2
,
false
,
tru
e
,
true
);
get_nchw44_conv_bias_args
({
2
,
5
,
7
},
2
,
false
,
fals
e
,
true
);
#define cb(name) checker_conv_bias_int8x8x32_preprocess(args, handle(), name);
#if MEGDNN_AARCH64
...
...
@@ -2478,7 +2478,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPR
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1
)
{
using
namespace
conv_bias
;
std
::
vector
<
conv_bias
::
TestArg
>
args
=
get_nchw44_conv_bias_args
({
3
,
4
,
6
},
1
,
false
,
tru
e
,
true
);
get_nchw44_conv_bias_args
({
3
,
4
,
6
},
1
,
false
,
fals
e
,
true
);
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
#if MEGDNN_AARCH64
...
...
@@ -3080,9 +3080,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32_PREPROCESS) {
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONVBIAS_1X1_S1_INT8x8x16
)
{
UniformIntRNG
rng
{
-
50
,
50
};
float
epsilon
=
0.001
;
std
::
vector
<
conv_bias
::
TestArg
>
args
=
get_conv_bias_1x1_args
(
true
,
true
);
std
::
vector
<
conv_bias
::
TestArg
>
args
=
get_conv_bias_1x1_args
(
false
,
true
,
false
,
false
);
std
::
vector
<
conv_bias
::
TestArg
>
args_nchw44
=
get_nchw44_conv_bias_args
(
{
1
},
1
,
true
,
true
,
true
,
false
,
false
,
false
,
false
,
tru
e
);
{
1
},
1
,
true
,
true
,
true
,
false
,
false
,
true
,
false
,
fals
e
);
#define cb(name) \
checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, \
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name);
...
...
@@ -3140,7 +3141,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16_PREPROCESS) {
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_1X1_S1_INT8x8x32
)
{
using
namespace
conv_bias
;
std
::
vector
<
conv_bias
::
TestArg
>
args
=
get_conv_bias_1x1_args
(
true
,
true
);
std
::
vector
<
conv_bias
::
TestArg
>
args
=
get_conv_bias_1x1_args
(
false
,
true
,
false
,
false
);
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
...
...
dnn/test/x86/conv_bias.cpp
浏览文件 @
b778d225
...
...
@@ -834,6 +834,13 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8X8X32) {
//! no bias
args
.
emplace_back
(
param
,
TensorShape
{
1
,
ic
,
h
,
w
},
TensorShape
{
oc
,
ic
,
kernel
,
kernel
},
TensorShape
{});
args
.
emplace_back
(
param
,
TensorShape
{
1
,
ic
,
h
,
w
},
TensorShape
{
oc
,
ic
,
kernel
,
kernel
},
TensorShape
{
1
,
oc
,
1
,
1
});
args
.
emplace_back
(
param
,
TensorShape
{
1
,
ic
,
h
,
w
},
TensorShape
{
oc
,
ic
,
kernel
,
kernel
},
TensorShape
{
1
,
oc
,
(
h
+
2
*
p
-
kernel
)
+
1
,
(
h
+
2
*
p
-
kernel
)
+
1
});
};
for
(
size_t
kernel
:
{
2
,
3
,
4
,
5
,
6
,
7
})
...
...
@@ -1384,7 +1391,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32) {
using
namespace
conv_bias
;
UniformIntRNG
rng
{
-
50
,
50
};
float
epsilon
=
0.001
;
std
::
vector
<
conv_bias
::
TestArg
>
args
=
get_conv_bias_1x1_args
(
tru
e
,
true
);
std
::
vector
<
conv_bias
::
TestArg
>
args
=
get_conv_bias_1x1_args
(
fals
e
,
true
);
#if MEGDNN_X86_WITH_MKL_DNN
if
(
x86
::
is_supported
(
x86
::
SIMDType
::
VNNI
))
{
checker_conv_bias
(
args
,
handle
(),
&
rng
,
epsilon
,
dtype
::
Int8
{},
...
...
@@ -1422,7 +1429,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32_PREPROCESS) {
using
namespace
conv_bias
;
UniformIntRNG
rng
{
-
50
,
50
};
float
epsilon
=
0.001
;
std
::
vector
<
conv_bias
::
TestArg
>
args
=
get_conv_bias_1x1_args
(
tru
e
,
true
);
std
::
vector
<
conv_bias
::
TestArg
>
args
=
get_conv_bias_1x1_args
(
fals
e
,
true
);
#if MEGDNN_X86_WITH_VNNI
if
(
x86
::
is_supported
(
x86
::
SIMDType
::
VNNI
))
{
checker_conv_bias_preprocess
(
args
,
handle
(),
&
rng
,
epsilon
,
dtype
::
Int8
{},
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录