MegEngine / MegEngine

Commit 5885b137
Authored: Oct 08, 2021, by Megvii Engine Team
Parent: 565466c2

feat(dnn/arm): support layout like NHWC channel like broadcast on arm

GitOrigin-RevId: fb4300004c4e1920d3cd1be40ca33bde822e4c72
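This commit adds ARM fast paths for element-wise ops where one operand is a full NHWC tensor and the other is a per-channel tensor of shape {1, 1, 1, C}, plus int16/uint16 to float32 type conversion and non-contiguous typecvt support. For orientation, a minimal scalar sketch of the channel-like broadcast being vectorized (plain C++; the function and names are illustrative, not MegDNN API):

    #include <cstddef>

    // dst[n][h][w][c] = src[n][h][w][c] + bias[c], with src laid out NHWC.
    // The kernels below see the same problem as N*H*W rows of C contiguous
    // elements, with the {1,1,1,C} operand re-read for every row.
    void add_bias_nhwc(const float* src, const float* bias, float* dst,
                       size_t nhw, size_t channel) {
        for (size_t row = 0; row < nhw; ++row)
            for (size_t c = 0; c < channel; ++c)
                dst[row * channel + c] = src[row * channel + c] + bias[c];
    }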
Showing 20 changed files with 723 additions and 7 deletions (+723, -7)
dnn/src/arm_common/elemwise/binary/algo.cpp        +81   -0
dnn/src/arm_common/elemwise/binary/algo.h           +1   -0
dnn/src/arm_common/elemwise/opr_impl.cpp           +40   -0
dnn/src/arm_common/elemwise/opr_impl.h              +3   -0
dnn/src/arm_common/elemwise/ternary/algo.cpp       +80   -0
dnn/src/arm_common/elemwise/ternary/algo.h          +2   -0
dnn/src/arm_common/elemwise_op.h                  +229   -0
dnn/src/arm_common/quantized_converter.h           +14   -0
dnn/src/arm_common/type_cvt/opr_impl.cpp          +106   -1
dnn/src/common/elemwise/opr_impl_helper.cpp        +13   -0
dnn/src/common/elemwise/opr_impl_helper.h          +10   -0
dnn/src/fallback/type_cvt/opr_impl.cpp              +4   -2
dnn/src/naive/type_cvt/opr_impl.cpp                 +2   -2
dnn/test/arm_common/elemwise.cpp                   +91   -0
dnn/test/arm_common/type_cvt.cpp                   +20   -0
dnn/test/common/checker.cpp                         +2   -1
dnn/test/common/rng.cpp                             +3   -0
dnn/test/cuda/type_cvt.cpp                          +5   -0
dnn/test/x86/type_cvt.cpp                          +14   -0
src/opr/impl/basic_arith.cpp                        +3   -1
dnn/src/arm_common/elemwise/binary/algo.cpp

@@ -104,6 +104,21 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available(
     return false;
 }

+bool ElemwiseImpl::AlgoBinaryVecBcast111C::is_available(
+        const KernParam& kern_param) const {
+    if (!is_available_common(kern_param.mode) ||
+        ((BcastType::VEC_BCAST111C != kern_param.broad_cast_type) &&
+         (BcastType::BCAST111C_VEC != kern_param.broad_cast_type)))
+        return false;
+
+    auto& elparam = kern_param.binary_elparam;
+    auto& src0 = elparam[0];
+
+    DISPATCH_TYPE("AlgoBinaryVecBcast111C::is_available"_hash);
+
+    return false;
+}
+
 bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available(
         const KernParam& kern_param) const {
     if (!is_available_common(kern_param.mode) ||
 ...

@@ -333,6 +348,72 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec(const KernParam& kern_param) cons
     return;
 }

+void ElemwiseImpl::AlgoBinaryVecBcast111C::exec(const KernParam& kern_param) const {
+    auto& elparam = kern_param.binary_elparam;
+    auto &src0 = elparam[0], &src1 = elparam[1];
+    auto&& dst = *(kern_param.m_dst);
+    BroadcastChannelInfo binfo;
+
+    // Case extra: BcastType::VEC + BCAST_111C
+    if (BcastType::VEC_BCAST111C == kern_param.broad_cast_type &&
+        is_NHWC_broadcasted_channel_like(src1.layout, binfo)) {
+#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)             \
+    case Mode::_mode:                                                          \
+        MIDOUT_BEGIN(                                                          \
+                megdnn_arm_common_elemwise_binary, midout_iv(_case),           \
+                midout_iv(Mode::_mode), _type_midout_id) {                     \
+            thin_function<void(                                                \
+                    const _type*, const _type*, _type*, DType, DType, DType,   \
+                    size_t, size_t, size_t)>                                   \
+                    run = OpCallerBinary<                                      \
+                            _op<_type, _type>, BcastType::VEC_BCAST111C>::run; \
+            MEGDNN_DISPATCH_CPU_KERN(                                          \
+                    static_cast<naive::HandleImpl*>(kern_param.handle),        \
+                    run(static_cast<const _type*>(src0.raw_ptr),               \
+                        static_cast<const _type*>(src1.raw_ptr),               \
+                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype,   \
+                        src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
+                        binfo.z));                                             \
+        }                                                                      \
+        MIDOUT_END();                                                          \
+        return
+
+        DISPATCH_TYPE("AlgoBinaryVecBcast111C::exec_vec_b"_hash);
+#undef DISPATCH_BINARY
+    }
+
+    // BCAST_111C + BcastType::VEC
+    if (BcastType::BCAST111C_VEC == kern_param.broad_cast_type &&
+        is_NHWC_broadcasted_channel_like(src0.layout, binfo)) {
+#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op)             \
+    case Mode::_mode:                                                          \
+        MIDOUT_BEGIN(                                                          \
+                megdnn_arm_common_elemwise_binary, midout_iv(_case),           \
+                midout_iv(Mode::_mode), _type_midout_id) {                     \
+            thin_function<void(                                                \
+                    const _type*, const _type*, _type*, DType, DType, DType,   \
+                    size_t, size_t, size_t)>                                   \
+                    run = OpCallerBinary<                                      \
+                            _op<_type, _type>, BcastType::BCAST111C_VEC>::run; \
+            MEGDNN_DISPATCH_CPU_KERN(                                          \
+                    static_cast<naive::HandleImpl*>(kern_param.handle),        \
+                    run(static_cast<const _type*>(src0.raw_ptr),               \
+                        static_cast<const _type*>(src1.raw_ptr),               \
+                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype,   \
+                        src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
+                        binfo.z));                                             \
+        }                                                                      \
+        MIDOUT_END();                                                          \
+        return
+
+        DISPATCH_TYPE("AlgoBinaryVecBcast111C::exec_b_vec"_hash);
+#undef DISPATCH_BINARY
+    }
+
+    return;
+}
+
 void ElemwiseImpl::AlgoBinaryVecBcast101xX::exec(const KernParam& kern_param) const {
     auto& elparam = kern_param.binary_elparam;
     auto &src0 = elparam[0], &src1 = elparam[1];
 ...
dnn/src/arm_common/elemwise/binary/algo.h

@@ -33,6 +33,7 @@ namespace arm_common {
 DECL_CB(VecVec);
 DECL_CB(VecScalar);
 DECL_CB(VecBcast101);
+DECL_CB(VecBcast111C);
 DECL_CB(VecBcast101xX);
 #undef DECL_CB
 }  // namespace arm_common
 ...
dnn/src/arm_common/elemwise/opr_impl.cpp

@@ -27,12 +27,15 @@ class ElemwiseImpl::AlgoPack {
     AlgoBinaryVecVec algo_binary_vec_vec;
     AlgoBinaryVecScalar algo_binary_vec_sca;
     AlgoBinaryVecBcast101 algo_binary_vec_bcast101;
+    AlgoBinaryVecBcast111C algo_binary_vec_bcast110;
     AlgoBinaryVecBcast101xX algo_binary_VEC_BCAST101xX;
     AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec;
     AlgoTernaryFma3VecVecScalar algo_ternaryfma3_vec_vecsca;
     AlgoTernaryFma3Bcast101VecBcast101 algo_ternaryfma3_bcast101_vec_bcast101;
+    AlgoTernaryFma3Bcast111CVecBcast111C algo_ternaryfma3_bcast110_vec_bcast110;
     AlgoTernaryFma3Bcast101xXVecBcast101xX algo_ternaryfma3_bcast101xX_vec_bcast101xX;
     AlgoTernaryFma3VecBcast101Vec algo_ternaryfma3_vec_bcast101_vec;
+    AlgoTernaryFma3VecBcast111CVec algo_ternaryfma3_vec_bcast110_vec;
     AlgoTernaryFma3VecBcast101xXVec algo_ternaryfma3_vec_bcast101xX_vec;
     AlgoTernaryFma3VecScalarVec algo_ternaryfma3_vec_sca_vec;
     AlgoTernaryFma3VecScalarScalar algo_ternaryfma3_vec_sca_sca;
 ...

@@ -43,12 +46,15 @@ public:
         all_algos.emplace_back(&algo_binary_vec_vec);
         all_algos.emplace_back(&algo_binary_vec_sca);
         all_algos.emplace_back(&algo_binary_vec_bcast101);
+        all_algos.emplace_back(&algo_binary_vec_bcast110);
         all_algos.emplace_back(&algo_binary_VEC_BCAST101xX);
         all_algos.emplace_back(&algo_ternaryfma3_vec_vec_vec);
         all_algos.emplace_back(&algo_ternaryfma3_vec_vecsca);
         all_algos.emplace_back(&algo_ternaryfma3_bcast101_vec_bcast101);
+        all_algos.emplace_back(&algo_ternaryfma3_bcast110_vec_bcast110);
         all_algos.emplace_back(&algo_ternaryfma3_bcast101xX_vec_bcast101xX);
         all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101_vec);
+        all_algos.emplace_back(&algo_ternaryfma3_vec_bcast110_vec);
         all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101xX_vec);
         all_algos.emplace_back(&algo_ternaryfma3_vec_sca_vec);
         all_algos.emplace_back(&algo_ternaryfma3_vec_sca_sca);
 ...

@@ -87,6 +93,14 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
     kern_param.mode = opr->param().mode;
     kern_param.handle = opr->handle();

+    auto is_legal_layout_for_nhwc = [](const TensorLayout& l) {
+        if (is_vector(l))
+            return true;
+        if (l.ndim == 2 && l.stride[1] == 1)
+            return true;
+        return false;
+    };
+
     if ((opr->m_src->size() == 3) && (opr->param().mode == Mode::FUSE_MUL_ADD3)) {
         kern_param.ternary_elparam = opr->make_elemwise_op_param<3>();
         bool c_is_scalar;
 ...

@@ -127,6 +141,20 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
             return kern_param;
         }

+        if (is_legal_layout_for_nhwc(src1.layout) &&
+            is_NHWC_broadcasted_channel_like(src0.layout, binfo) &&
+            src0.layout.eq_layout(src2.layout)) {
+            kern_param.broad_cast_type = BcastType::BCAST111C_VEC_BCAST111C;
+            return kern_param;
+        }
+
+        if (is_legal_layout_for_nhwc(src0.layout) &&
+            src2.layout.eq_layout(src0.layout) &&
+            is_NHWC_broadcasted_channel_like(src1.layout, binfo)) {
+            kern_param.broad_cast_type = BcastType::VEC_BCAST111C_VEC;
+            return kern_param;
+        }
+
         if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) &&
             (is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
              is_broadcastedx_channel_like<8>(src1.layout, binfo))) {
 ...

@@ -174,6 +202,18 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
             return kern_param;
         }

+        if (is_legal_layout_for_nhwc(src1.layout) &&
+            is_NHWC_broadcasted_channel_like(src0.layout, binfo)) {
+            kern_param.broad_cast_type = BcastType::BCAST111C_VEC;
+            return kern_param;
+        }
+
+        if (is_legal_layout_for_nhwc(src0.layout) &&
+            is_NHWC_broadcasted_channel_like(src1.layout, binfo)) {
+            kern_param.broad_cast_type = BcastType::VEC_BCAST111C;
+            return kern_param;
+        }
+
         if (is_vector(src0.layout) &&
             (is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
              is_broadcastedx_channel_like<8>(src1.layout, binfo))) {
 ...
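make_kern_param tries these layout patterns in source order, so the first match wins. For example, ADD on shapes {2, 5, 3, 28} and {1, 1, 1, 28} (as in the tests below): src0 collapses to a contiguous vector, src1 collapses to a 2-dim layout with stride {0, 1}, and the pair is classified as VEC_BCAST111C. A minimal restatement of that classification with a simplified layout type (illustrative, not MegDNN's TensorLayout):

    #include <cstddef>

    enum class Bcast { NONE, VEC_BCAST111C, BCAST111C_VEC };

    // Simplified stand-ins for megdnn's TensorLayout and helper predicates.
    struct Layout { size_t ndim; size_t shape[4]; long stride[4]; };

    bool legal_nhwc_vec(const Layout& l) {   // mirrors is_legal_layout_for_nhwc
        return (l.ndim == 1 && l.stride[0] == 1) ||
               (l.ndim == 2 && l.stride[1] == 1);
    }
    bool channel_like(const Layout& l) {     // mirrors is_NHWC_broadcasted_channel_like
        return l.ndim == 2 && l.stride[0] == 0 && l.stride[1] == 1;
    }

    Bcast classify(const Layout& a, const Layout& b) {
        if (legal_nhwc_vec(b) && channel_like(a)) return Bcast::BCAST111C_VEC;
        if (legal_nhwc_vec(a) && channel_like(b)) return Bcast::VEC_BCAST111C;
        return Bcast::NONE;
    }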
dnn/src/arm_common/elemwise/opr_impl.h

@@ -38,12 +38,15 @@ private:
     class AlgoBinaryVecVec;
     class AlgoBinaryVecScalar;
     class AlgoBinaryVecBcast101;
+    class AlgoBinaryVecBcast111C;
     class AlgoBinaryVecBcast101xX;
     class AlgoTernaryFma3VecVecVec;
     class AlgoTernaryFma3VecVecScalar;
     class AlgoTernaryFma3Bcast101VecBcast101;
+    class AlgoTernaryFma3Bcast111CVecBcast111C;
     class AlgoTernaryFma3Bcast101xXVecBcast101xX;
     class AlgoTernaryFma3VecBcast101Vec;
+    class AlgoTernaryFma3VecBcast111CVec;
     class AlgoTernaryFma3VecBcast101xXVec;
     class AlgoTernaryFma3VecScalarVec;
     class AlgoTernaryFma3VecScalarScalar;
 ...
dnn/src/arm_common/elemwise/ternary/algo.cpp

@@ -42,8 +42,10 @@ using namespace arm_common;
 DECL_AVAILABLE(VecVecVec, BcastType::VEC_VEC_VEC);
 DECL_AVAILABLE(VecVecScalar, BcastType::VEC_VEC_SCALAR);
 DECL_AVAILABLE(Bcast101VecBcast101, BcastType::BCAST101_VEC_BCAST101);
+DECL_AVAILABLE(Bcast111CVecBcast111C, BcastType::BCAST111C_VEC_BCAST111C);
 DECL_AVAILABLE(Bcast101xXVecBcast101xX, BcastType::BCAST101xX_VEC_BCAST101xX);
 DECL_AVAILABLE(VecBcast101Vec, BcastType::VEC_BCAST101_VEC);
+DECL_AVAILABLE(VecBcast111CVec, BcastType::VEC_BCAST111C_VEC);
 DECL_AVAILABLE(VecBcast101xXVec, BcastType::VEC_BCAST101xX_VEC);
 DECL_AVAILABLE(VecScalarVec, BcastType::VEC_SCALAR_VEC);
 DECL_AVAILABLE(VecScalarScalar, BcastType::VEC_SCALAR_SCALAR);
 ...

@@ -164,6 +166,45 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec(
     return;
 }

+void ElemwiseImpl::AlgoTernaryFma3Bcast111CVecBcast111C::exec(
+        const KernParam& kern_param) const {
+    auto& elparam = kern_param.ternary_elparam;
+    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];
+
+    // Case 3: shape of src0 and src2 is {1, 1, 1, C}
+    BroadcastChannelInfo binfo;
+    is_NHWC_broadcasted_channel_like(src0.layout, binfo);
+#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                  \
+    case Mode::_mode:                                                                \
+        MIDOUT_BEGIN(                                                                \
+                megdnn_arm_common_elemwise_ternary, midout_iv(_case),                \
+                midout_iv(Mode::_mode), _type_midout_id) {                           \
+            thin_function<void(                                                      \
+                    const _type*, const _type*, size_t, const _type*, _type*, DType, \
+                    DType, DType, DType, size_t, size_t, size_t)>                    \
+                    run = OpCallerTernary<                                           \
+                            _op<_type, _type>,                                       \
+                            BcastType::BCAST111C_VEC_BCAST111C>::run;                \
+            MEGDNN_DISPATCH_CPU_KERN(                                                \
+                    static_cast<naive::HandleImpl*>(kern_param.handle),              \
+                    run(static_cast<const _type*>(src0.raw_ptr),                     \
+                        static_cast<const _type*>(src1.raw_ptr),                     \
+                        is_vector(src1.layout) ? 0 : src1.layout.stride[0] - binfo.z,\
+                        static_cast<const _type*>(src2.raw_ptr),                     \
+                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype,         \
+                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,      \
+                        binfo.x, binfo.y, binfo.z));                                 \
+        }                                                                            \
+        MIDOUT_END();                                                                \
+        return
+
+    auto&& dst = *(kern_param.m_dst);
+    DISPATCH_TYPE("AlgoTernaryFma3Bcast111CVecBcast111C::exec"_hash);
+#undef DISPATCH_TERNARY
+
+    return;
+}
+
 void ElemwiseImpl::AlgoTernaryFma3Bcast101xXVecBcast101xX::exec(
         const KernParam& kern_param) const {
     auto& elparam = kern_param.ternary_elparam;
 ...

@@ -282,6 +323,45 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101Vec::exec(
     return;
 }

+void ElemwiseImpl::AlgoTernaryFma3VecBcast111CVec::exec(
+        const KernParam& kern_param) const {
+    auto& elparam = kern_param.ternary_elparam;
+    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];
+
+    // Case 4: shape of src1 is {1, 1, 1, C}, and src0 and src2 are contig
+    BroadcastChannelInfo binfo;
+    is_NHWC_broadcasted_channel_like(src1.layout, binfo);
+#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                  \
+    case Mode::_mode:                                                                \
+        MIDOUT_BEGIN(                                                                \
+                megdnn_arm_common_elemwise_ternary, midout_iv(_case),                \
+                midout_iv(Mode::_mode), _type_midout_id) {                           \
+            thin_function<void(                                                      \
+                    const _type*, size_t, const _type*, const _type*, size_t, _type*,\
+                    DType, DType, DType, DType, size_t, size_t, size_t)>             \
+                    run = OpCallerTernary<                                           \
+                            _op<_type, _type>, BcastType::VEC_BCAST111C_VEC>::run;   \
+            MEGDNN_DISPATCH_CPU_KERN(                                                \
+                    static_cast<naive::HandleImpl*>(kern_param.handle),              \
+                    run(static_cast<const _type*>(src0.raw_ptr),                     \
+                        is_vector(src0.layout) ? 0 : src0.layout.stride[0] - binfo.z,\
+                        static_cast<const _type*>(src1.raw_ptr),                     \
+                        static_cast<const _type*>(src2.raw_ptr),                     \
+                        is_vector(src2.layout) ? 0 : src2.layout.stride[0] - binfo.z,\
+                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype,         \
+                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,      \
+                        binfo.x, binfo.y, binfo.z));                                 \
+        }                                                                            \
+        MIDOUT_END();                                                                \
+        return
+
+    auto&& dst = *(kern_param.m_dst);
+    DISPATCH_TYPE("AlgoTernaryFma3VecBcast111CVec::exec"_hash);
+#undef DISPATCH_TERNARY
+
+    return;
+}
+
 void ElemwiseImpl::AlgoTernaryFma3VecScalarVec::exec(
         const KernParam& kern_param) const {
     auto& elparam = kern_param.ternary_elparam;
 ...
dnn/src/arm_common/elemwise/ternary/algo.h

@@ -33,8 +33,10 @@ namespace arm_common {
 DECL_CB(VecVecVec);
 DECL_CB(VecVecScalar);
 DECL_CB(Bcast101VecBcast101);
+DECL_CB(Bcast111CVecBcast111C);
 DECL_CB(Bcast101xXVecBcast101xX);
 DECL_CB(VecBcast101Vec);
+DECL_CB(VecBcast111CVec);
 DECL_CB(VecBcast101xXVec);
 DECL_CB(VecScalarVec);
 DECL_CB(VecScalarScalar);
 ...
dnn/src/arm_common/elemwise_op.h

@@ -107,16 +107,20 @@ enum BcastType {
     VEC,
     VEC_VEC,
     VEC_BCAST101,
+    VEC_BCAST111C,
     VEC_BCAST101xX,
     VEC_SCALAR,
     SCALAR_VEC,
     BCAST101_VEC,
+    BCAST111C_VEC,
     BCAST101xX_VEC,
     VEC_VEC_VEC,
     VEC_VEC_SCALAR,
     BCAST101_VEC_BCAST101,
+    BCAST111C_VEC_BCAST111C,
     BCAST101xX_VEC_BCAST101xX,
     VEC_BCAST101_VEC,
+    VEC_BCAST111C_VEC,
     VEC_BCAST101xX_VEC,
     VEC_SCALAR_VEC,
     VEC_SCALAR_SCALAR,
 ...
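The four new entries follow the established naming: each position names the layout of one operand, with BCAST111C denoting a collapsed {1, 1, 1, C} per-channel tensor. The ternary entries serve FUSE_MUL_ADD3, which computes dst = src0 * src1 + src2, so BCAST111C_VEC_BCAST111C pairs a per-channel multiplier and addend with a full NHWC tensor in the middle. A scalar sketch under that reading (illustrative helper, not MegDNN API):

    #include <cstddef>

    // dst[i*C + c] = s0[c] * s1[i*C + c] + s2[c]   (BCAST111C_VEC_BCAST111C)
    void fma3_bcast111c_vec_bcast111c(const float* s0, const float* s1,
                                      const float* s2, float* dst,
                                      size_t nhw, size_t C) {
        for (size_t i = 0; i < nhw; ++i)
            for (size_t c = 0; c < C; ++c)
                dst[i * C + c] = s0[c] * s1[i * C + c] + s2[c];
    }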
@@ -226,6 +230,60 @@ struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST101> {
     }
 };

+template <typename ctype>
+struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST111C> {
+    using Op = PowOp<ctype, ctype>;
+    static void run(
+            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
+            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
+            DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
+        Op op(src0_dtype, src1_dtype, dst_dtype);
+        for (size_t b = 0; b < batch; b++) {
+            for (size_t c = 0; c < channel; c++) {
+                size_t i = 0;
+                const typename Op::src_ctype* src1_ptr = src1;
+#if MEGDNN_FIX_AARCH32_BUG
+// FIXME: as llvm may cause cannot select error if enable vectorize
+#pragma clang loop vectorize(disable)
+#endif
+                for (; i < channel_stride; i++) {
+                    op(*src0, *src1_ptr, dst);
+                    src0++;
+                    src1_ptr++;
+                    dst++;
+                }
+            }
+        }
+    }
+};
+
+template <typename ctype>
+struct OpCallerBinary<PowOp<ctype, ctype>, BCAST111C_VEC> {
+    using Op = PowOp<ctype, ctype>;
+    static void run(
+            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
+            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
+            DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
+        Op op(src0_dtype, src1_dtype, dst_dtype);
+        for (size_t b = 0; b < batch; b++) {
+            for (size_t c = 0; c < channel; c++) {
+                size_t i = 0;
+                const typename Op::src_ctype* src0_ptr = src0;
+#if MEGDNN_FIX_AARCH32_BUG
+// FIXME: as llvm may cause cannot select error if enable vectorize
+#pragma clang loop vectorize(disable)
+#endif
+                for (; i < channel_stride; i++) {
+                    op(*src0_ptr, *src1, dst);
+                    src0_ptr++;
+                    src1++;
+                    dst++;
+                }
+            }
+        }
+    }
+};
+
 template <typename ctype>
 struct OpCallerBinary<PowOp<ctype, ctype>, SCALAR_VEC> {
     using Op = PowOp<ctype, ctype>;
 ...
@@ -340,6 +398,84 @@ struct OpCallerBinary<Op, VEC_BCAST101> {
     }
 };

+template <typename Op>
+struct OpCallerBinary<Op, VEC_BCAST111C> {
+    static void run(
+            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
+            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
+            DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
+        Op op(src0_dtype, src1_dtype, dst_dtype);
+        ParamElemVisitor<typename Op::src_ctype> vis;
+        for (size_t b = 0; b < batch; b++) {
+            for (size_t c = 0; c < channel; c++) {
+                size_t rest = channel_stride;
+                const typename Op::src_ctype* src1_ptr = src1;
+                while (rest >= Op::SIMD_WIDTH * 2) {
+                    auto src0_neon0 = vis(src0);
+                    auto src0_neon1 = vis(src0 + Op::SIMD_WIDTH);
+                    auto src1_neon0 = vis(src1_ptr);
+                    auto src1_neon1 = vis(src1_ptr + Op::SIMD_WIDTH);
+                    src0 += Op::SIMD_WIDTH * 2;
+                    src1_ptr += Op::SIMD_WIDTH * 2;
+                    op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, dst);
+                    dst += Op::SIMD_WIDTH * 2;
+                    rest -= Op::SIMD_WIDTH * 2;
+                }
+#if MEGDNN_FIX_AARCH32_BUG
+// FIXME: as llvm may cause cannot select error if enable vectorize
+#pragma clang loop vectorize(disable)
+#endif
+                while (rest > 0) {
+                    op(*src0, *src1_ptr, dst);
+                    dst++;
+                    src0++;
+                    src1_ptr++;
+                    rest--;
+                }
+            }
+        }
+    }
+};
+
+template <typename Op>
+struct OpCallerBinary<Op, BCAST111C_VEC> {
+    static void run(
+            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
+            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
+            DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
+        Op op(src0_dtype, src1_dtype, dst_dtype);
+        ParamElemVisitor<typename Op::src_ctype> vis;
+        for (size_t b = 0; b < batch; b++) {
+            for (size_t c = 0; c < channel; c++) {
+                size_t rest = channel_stride;
+                const typename Op::src_ctype* src0_ptr = src0;
+                while (rest >= Op::SIMD_WIDTH * 2) {
+                    auto src0_neon0 = vis(src0_ptr);
+                    auto src0_neon1 = vis(src0_ptr + Op::SIMD_WIDTH);
+                    auto src1_neon0 = vis(src1);
+                    auto src1_neon1 = vis(src1 + Op::SIMD_WIDTH);
+                    src0_ptr += Op::SIMD_WIDTH * 2;
+                    src1 += Op::SIMD_WIDTH * 2;
+                    op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}}, dst);
+                    dst += Op::SIMD_WIDTH * 2;
+                    rest -= Op::SIMD_WIDTH * 2;
+                }
+#if MEGDNN_FIX_AARCH32_BUG
+// FIXME: as llvm may cause cannot select error if enable vectorize
+#pragma clang loop vectorize(disable)
+#endif
+                while (rest > 0) {
+                    op(*src0_ptr, *src1, dst);
+                    dst++;
+                    src0_ptr++;
+                    src1++;
+                    rest--;
+                }
+            }
+        }
+    }
+};
+
 template <typename ctype>
 struct OpCallerBinary<PowOp<ctype, ctype>, BCAST101xX_VEC> {
     using Op = PowOp<ctype, ctype>;
 ...
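Both callers above share the loop shape every new kernel in this commit uses: a main loop that consumes two NEON registers (Op::SIMD_WIDTH * 2 lanes) per step, then a scalar tail for the remaining lanes of the channel row. Stripped to a skeleton for float32 addition (a sketch of the pattern, not the OpCaller machinery itself):

    #include <arm_neon.h>
    #include <cstddef>

    // "Two registers per step + scalar tail", hard-coded for float32 add.
    void add_row_f32(const float* a, const float* b, float* dst, size_t n) {
        size_t rest = n;
        while (rest >= 8) {                   // SIMD_WIDTH(4) * 2
            vst1q_f32(dst, vaddq_f32(vld1q_f32(a), vld1q_f32(b)));
            vst1q_f32(dst + 4, vaddq_f32(vld1q_f32(a + 4), vld1q_f32(b + 4)));
            a += 8; b += 8; dst += 8; rest -= 8;
        }
        while (rest > 0) {                    // scalar tail
            *dst++ = *a++ + *b++;
            rest--;
        }
    }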
@@ -824,6 +960,54 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> {
     }
 };

+//! src0: 111C, src1: vector, src2: 111C, src1 may not be contig
+template <typename Op>
+struct OpCallerTernary<Op, BCAST111C_VEC_BCAST111C> {
+    static void run(
+            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
+            size_t src1_offset, const typename Op::src_ctype* src2,
+            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
+            DType src2_dtype, DType dst_dtype, size_t batch_size,
+            size_t channel_size, size_t channel_stride) {
+        Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
+        ParamElemVisitor<typename Op::src_ctype> vis;
+        for (size_t batch = 0; batch < batch_size; batch++) {
+            for (size_t channel = 0; channel < channel_size; channel++) {
+                auto src0_ptr = src0;
+                auto src2_ptr = src2;
+                size_t i = 0;
+                for (; i + Op::SIMD_WIDTH * 2 <= channel_stride;
+                     i += Op::SIMD_WIDTH * 2) {
+                    auto src0_neon0 = vis(src0_ptr);
+                    auto src0_neon1 = vis(src0_ptr + Op::SIMD_WIDTH);
+                    auto src1_neon0 = vis(src1);
+                    auto src1_neon1 = vis(src1 + Op::SIMD_WIDTH);
+                    auto src2_neon0 = vis(src2_ptr);
+                    auto src2_neon1 = vis(src2_ptr + Op::SIMD_WIDTH);
+                    op({{src0_neon0, src0_neon1}}, {{src1_neon0, src1_neon1}},
+                       {{src2_neon0, src2_neon1}}, dst);
+                    src0_ptr += Op::SIMD_WIDTH * 2;
+                    src1 += Op::SIMD_WIDTH * 2;
+                    src2_ptr += Op::SIMD_WIDTH * 2;
+                    dst += Op::SIMD_WIDTH * 2;
+                }
+#if MEGDNN_FIX_AARCH32_BUG
+// FIXME: as llvm may cause cannot select error if enable vectorize
+#pragma clang loop vectorize(disable)
+#endif
+                for (; i < channel_stride; i++) {
+                    op(*src0_ptr, *src1, *src2_ptr, dst);
+                    src0_ptr++;
+                    src1++;
+                    src2_ptr++;
+                    dst++;
+                }
+                src1 += src1_offset;
+            }
+        }
+    }
+};
+
 template <typename src_ctype, size_t channel_block_dim>
 struct OpCallerTernaryBcast101xXVecBcast101xX {
     template <typename Op>
 ...
@@ -992,6 +1176,51 @@ struct OpCallerTernary<Op, VEC_BCAST101_VEC> {
     }
 };

+//! src1: 111C, src0 and src2 may not be contig
+template <typename Op>
+struct OpCallerTernary<Op, VEC_BCAST111C_VEC> {
+    static void run(
+            const typename Op::src_ctype* src0, size_t src0_offset,
+            const typename Op::src_ctype* src1, const typename Op::src_ctype* src2,
+            size_t src2_offset, typename Op::dst_ctype* dst, DType src0_dtype,
+            DType src1_dtype, DType src2_dtype, DType dst_dtype, size_t batch_size,
+            size_t channel_size, size_t channel_stride) {
+        Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
+        ParamElemVisitor<typename Op::src_ctype> vis0;
+        ParamElemVisitor<typename Op::src_ctype> vis1;
+        ParamElemVisitor<typename Op::src_ctype> vis2;
+        for (size_t batch = 0; batch < batch_size; batch++) {
+            for (size_t channel = 0; channel < channel_size; channel++) {
+                auto src1_ptr = src1;
+                size_t i = 0;
+                for (; i + Op::SIMD_WIDTH * 2 <= channel_stride;
+                     i += Op::SIMD_WIDTH * 2) {
+                    op({{vis0(src0), vis0(src0 + Op::SIMD_WIDTH)}},
+                       {{vis1(src1_ptr), vis1(src1_ptr + Op::SIMD_WIDTH)}},
+                       {{vis2(src2), vis2(src2 + Op::SIMD_WIDTH)}}, dst);
+                    src0 += Op::SIMD_WIDTH * 2;
+                    src1_ptr += Op::SIMD_WIDTH * 2;
+                    src2 += Op::SIMD_WIDTH * 2;
+                    dst += Op::SIMD_WIDTH * 2;
+                }
+#if MEGDNN_FIX_AARCH32_BUG
+// FIXME: as llvm may cause cannot select error if enable vectorize
+#pragma clang loop vectorize(disable)
+#endif
+                for (; i < channel_stride; i++) {
+                    op(*src0, *src1_ptr, *src2, dst);
+                    src0++;
+                    src1_ptr++;
+                    src2++;
+                    dst++;
+                }
+                src0 += src0_offset;
+                src2 += src2_offset;
+            }
+        }
+    }
+};
+
 template <typename src_ctype, size_t channel_block_dim>
 struct OpCallerTernaryVecBcast101xXVec {
     template <typename Op>
 ...
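The src0_offset/src2_offset parameters (and src1_offset in the BCAST111C_VEC_BCAST111C caller) come from the dispatch sites above: is_vector(layout) ? 0 : layout.stride[0] - binfo.z, i.e. the per-row padding of a non-contiguous operand. With the incontig test layout below, shape {1, 2, 2, 12} and strides {80, 40, 20, 1}, each row holds z = 12 payload elements in a 20-element slot, so 8 elements are skipped per row. A sketch of that pointer walk (names illustrative):

    #include <cstddef>

    // Worked example: rows of z contiguous elements, then skip the pad.
    float sum_padded_rows(const float* src, size_t rows, size_t z, size_t offset) {
        float acc = 0.f;
        for (size_t r = 0; r < rows; ++r) {
            for (size_t i = 0; i < z; ++i)
                acc += src[i];
            src += z + offset;   // same advance as "src += src_offset" above
        }
        return acc;
    }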
dnn/src/arm_common/quantized_converter.h

@@ -50,6 +50,20 @@ inline dt_qint32 QConverter::convert(const float& src) {
             saturate<int32_t, float>(std::round(src), -2147483648, 2147483647));
 }

+template <>
+inline float32x4x2_t QConverter::convert(const int16x8_t& vsrc) {
+    int32x4_t vhi = vmovl_s16(vget_high_s16(vsrc));
+    int32x4_t vlo = vmovl_s16(vget_low_s16(vsrc));
+    return {{vcvtq_f32_s32(vlo), vcvtq_f32_s32(vhi)}};
+}
+
+template <>
+inline float32x4x2_t QConverter::convert(const uint16x8_t& vsrc) {
+    uint32x4_t vhi = vmovl_u16(vget_high_u16(vsrc));
+    uint32x4_t vlo = vmovl_u16(vget_low_u16(vsrc));
+    return {{vcvtq_f32_u32(vlo), vcvtq_f32_u32(vhi)}};
+}
+
 #if __ARM_ARCH >= 8
 template <>
 inline int8x8_t QConverter::convert(const float32x4x2_t& vsrc) {
 ...
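The two new specializations widen eight 16-bit lanes into two float32x4 registers via vmovl_* and vcvtq_*. The same intrinsics in a standalone form (plain arm_neon.h, no MegDNN types):

    #include <arm_neon.h>
    #include <cstdint>

    // Convert 8 int16 values to 8 floats, mirroring the
    // int16x8_t -> float32x4x2_t specialization above.
    void s16_to_f32x8(const int16_t* src, float* dst) {
        int16x8_t v = vld1q_s16(src);
        int32x4_t lo = vmovl_s16(vget_low_s16(v));   // widen low 4 lanes
        int32x4_t hi = vmovl_s16(vget_high_s16(v));  // widen high 4 lanes
        vst1q_f32(dst, vcvtq_f32_s32(lo));
        vst1q_f32(dst + 4, vcvtq_f32_s32(hi));
    }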
dnn/src/arm_common/type_cvt/opr_impl.cpp

@@ -17,6 +17,7 @@
 #include "src/common/utils.h"
 #include "src/naive/handle.h"

+MIDOUT_DECL(megdnn_arm_typecvt_fix2float)
 MIDOUT_DECL(megdnn_arm_typecvt_quantized)
 MIDOUT_DECL(megdnn_arm_typecvt_float)
 ...
@@ -325,6 +326,48 @@ struct FloatTypeCvter<float, __fp16> {
 };
 #endif

+template <typename ctype, typename dtype>
+struct Fix2FloatTypeCvter;
+
+template <>
+struct Fix2FloatTypeCvter<int16_t, float> {
+    using stype = int16_t;
+    using dst_type = float;
+    static constexpr size_t SIMD_WIDTH = 8;
+
+    Fix2FloatTypeCvter(DType src_dtype, DType dst_dtype) {
+        MEGDNN_MARK_USED_VAR(src_dtype);
+        MEGDNN_MARK_USED_VAR(dst_dtype);
+    }
+
+    void cvt(const int16_t* src, float* dst) {
+        int16x8_t vitem = vld1q_s16(src);
+        auto vres = QConverter::convert<float32x4x2_t, int16x8_t>(vitem);
+        vst1q_f32_x2(dst, vres);
+    }
+
+    void cvt_remain(const int16_t* src, float* dst) { *dst = *src; }
+};
+
+template <>
+struct Fix2FloatTypeCvter<uint16_t, float> {
+    using stype = uint16_t;
+    using dst_type = float;
+    static constexpr size_t SIMD_WIDTH = 8;
+
+    Fix2FloatTypeCvter(DType src_dtype, DType dst_dtype) {
+        MEGDNN_MARK_USED_VAR(src_dtype);
+        MEGDNN_MARK_USED_VAR(dst_dtype);
+    }
+
+    void cvt(const uint16_t* src, float* dst) {
+        uint16x8_t vitem = vld1q_u16(src);
+        auto vres = QConverter::convert<float32x4x2_t, uint16x8_t>(vitem);
+        vst1q_f32_x2(dst, vres);
+    }
+
+    void cvt_remain(const uint16_t* src, float* dst) { *dst = *src; }
+};
+
 template <typename TypeCvter>
 void do_typecvt(
         const typename TypeCvter::stype* src, typename TypeCvter::dst_type* dst,
 ...
@@ -347,6 +390,43 @@ void do_typecvt(
     }
 }

+template <typename TypeCvter>
+void do_typecvt(
+        const typename TypeCvter::stype* src, typename TypeCvter::dst_type* dst,
+        DType src_dtype, DType dst_dtype, const TensorLayout& src_layout) {
+    TypeCvter typecvt(src_dtype, dst_dtype);
+    size_t calc_num = 1;
+    size_t nr_elems = src_layout.total_nr_elems();
+    size_t src_stride = nr_elems;
+    //! adjust calc_num, nr_elems and src_stride according to src_collapse_layout
+    auto src_collapse_layout = src_layout.collapse_contiguous();
+    if (src_collapse_layout.ndim == 2) {
+        calc_num = src_collapse_layout.shape[0];
+        nr_elems = src_collapse_layout.shape[1];
+        src_stride = src_collapse_layout.stride[0];
+    }
+    for (size_t c = 0; c < calc_num; ++c) {
+        size_t i = 0;
+        for (; i + TypeCvter::SIMD_WIDTH <= nr_elems; i += TypeCvter::SIMD_WIDTH) {
+            typecvt.cvt(src, dst);
+            src += TypeCvter::SIMD_WIDTH;
+            dst += TypeCvter::SIMD_WIDTH;
+        }
+#if MEGDNN_FIX_AARCH32_BUG
+// FIXME: as llvm may cause cannot select error if enable vectorize
+#pragma clang loop vectorize(disable)
+#endif
+        for (; i < nr_elems; i++) {
+            typecvt.cvt_remain(src, dst);
+            src++;
+            dst++;
+        }
+        src += src_stride - nr_elems;
+    }
+}
+
 }  // anonymous namespace

 void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
 ...
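The new do_typecvt overload handles a source whose layout collapses to two dims: calc_num rows of nr_elems payload elements, with src_stride elements between row starts. With the test layout below, shape {1, 96, 64, 120} and strides {128*64*96, 128*64, 128, 1} collapse to {96*64, 120} with row stride 128, so each row converts 120 values and then skips src_stride - nr_elems = 8. The same addressing as a plain scalar sketch:

    #include <cstddef>
    #include <cstdint>

    // Strided walk used by the non-contiguous typecvt path: calc_num rows,
    // nr_elems payload elements per row, src_stride elements between rows.
    void cvt_rows_s16_to_f32(const int16_t* src, float* dst, size_t calc_num,
                             size_t nr_elems, size_t src_stride) {
        for (size_t r = 0; r < calc_num; ++r) {
            for (size_t i = 0; i < nr_elems; ++i)
                *dst++ = static_cast<float>(src[i]);
            src += src_stride;   // skips the (src_stride - nr_elems) pad
        }
    }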
@@ -354,7 +434,30 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
     DType dst_dtype = dst.layout.dtype;
     size_t nr_elems = src.layout.total_nr_elems();
     bool execed = false;
-    if (src.layout.is_contiguous()) {
+    auto src_collapse_layout = src.layout.collapse_contiguous();
+    bool has_int16_special_impl =
+            (src.layout.dtype.enumv() == DTypeEnum::Int16 ||
+             src.layout.dtype.enumv() == DTypeEnum::Uint16) &&
+            (src.layout.is_contiguous() || src_collapse_layout.ndim == 2) &&
+            dst.layout.is_contiguous();
+    if (has_int16_special_impl) {
+        using namespace dtype;
+#define DISPATCH_FIX2FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
+    if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv &&                    \
+        dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) {                    \
+        MIDOUT_BEGIN(megdnn_arm_typecvt_fix2float, midout_iv(_midout_iv)) {        \
+            using _TypeCvter = Fix2FloatTypeCvter<_stype, _dtype>;                 \
+            MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>(                   \
+                    src.compatible_ptr<_stype>(), dst.compatible_ptr<_dtype>(),    \
+                    src_dtype, dst_dtype, src.layout));                            \
+            execed = true;                                                         \
+        }                                                                          \
+        MIDOUT_END();                                                              \
+    }
+        DISPATCH_FIX2FLOAT(Int16, int16_t, Float32, float, 0);
+        DISPATCH_FIX2FLOAT(Uint16, uint16_t, Float32, float, 1);
+#undef DISPATCH_FIX2FLOAT
+    } else if (src.layout.is_contiguous()) {
         using namespace dtype;
 #define DISPATCH_QUANTIZED(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
     if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv &&                    \
 ...

@@ -377,6 +480,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
         DISPATCH_QUANTIZED(QuantizedS32, int32_t, QuantizedS32, int32_t, 5);
         DISPATCH_QUANTIZED(float, float, QuantizedS8, int8_t, 6);
+        DISPATCH_QUANTIZED(float, float, Quantized8Asymm, uint8_t, 7);
 #undef DISPATCH_QUANTIZED

 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 #define DISPATCH_FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \
 ...

@@ -394,6 +498,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
     }
         DISPATCH_FLOAT(dt_float16, __fp16, float, float, 0);
         DISPATCH_FLOAT(float, float, dt_float16, __fp16, 1);
 #undef DISPATCH_FLOAT
 #endif
+    }
     if (!execed) {
 ...
dnn/src/common/elemwise/opr_impl_helper.cpp

@@ -150,6 +150,19 @@ bool ElemwiseLayoutHelper::is_broadcasted_channel_like(
     return false;
 }

+bool ElemwiseLayoutHelper::is_NHWC_broadcasted_channel_like(
+        const TensorLayout& layout, BroadcastChannelInfo& info) {
+    if (layout.format.type() == TensorFormat::Type::DEFAULT) {
+        if (layout.ndim == 2 && layout.stride[1] == 1 && layout.stride[0] == 0) {
+            info.x = 1;
+            info.y = layout.shape[0];
+            info.z = layout.shape[1];
+            return true;
+        }
+    }
+    return false;
+}
+
 bool ElemwiseLayoutHelper::is_broadcasted_1x(
         const TensorLayout& layout, Broadcast1xInfo& binfo) {
     if (layout.ndim == 2 && layout.stride[0] == 0 && layout.stride[1] == 1) {
 ...
dnn/src/common/elemwise/opr_impl_helper.h

@@ -80,6 +80,16 @@ public:
     static bool is_broadcasted_channel_like(
             const TensorLayout& layout, BroadcastChannelInfo& info);

+    /*!
+     * \brief check whether layout matches BroadcastChannelInfo under NHWC
+     * layout
+     *
+     * The input must be a 2-dimensional [1, C] layout broadcast into
+     * [M, C]; on success info.x is set to 1, info.y to the repeat count M
+     * and info.z to the contiguous channel extent C.
+     */
+    static bool is_NHWC_broadcasted_channel_like(
+            const TensorLayout& layout, BroadcastChannelInfo& info);
+
     /*!
      * \brief check whether layout matches BroadcastChannelInfo
      *
 ...
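A self-contained restatement of the predicate and the BroadcastChannelInfo it fills (simplified types; the real signature is declared above):

    #include <cstddef>

    struct Info { size_t x, y, z; };   // stand-in for BroadcastChannelInfo

    // Mirrors is_NHWC_broadcasted_channel_like for a default-format 2-dim
    // layout: stride {0, 1} means shape[1] contiguous values repeated
    // shape[0] times.
    bool nhwc_channel_like(size_t ndim, const size_t* shape, const long* stride,
                           Info& info) {
        if (ndim == 2 && stride[0] == 0 && stride[1] == 1) {
            info.x = 1;
            info.y = shape[0];   // number of (n, h, w) positions
            info.z = shape[1];   // channel count C
            return true;
        }
        return false;
    }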
dnn/src/fallback/type_cvt/opr_impl.cpp

@@ -309,7 +309,8 @@ void on_dest_ctype(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
         break;                                                        \
     }
         MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
-        cb(::megdnn::dtype::Bool)
+        cb(::megdnn::dtype::Bool)
+        cb(::megdnn::dtype::Uint16)
         case DTypeEnum::QuantizedS8:
             MIDOUT_BEGIN(
                     megdnn_fb_typecvt_src_dtype, midout_iv(DTypeEnum::QuantizedS8)) {
 ...

@@ -467,7 +468,8 @@ void run_contiguous(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
     }
         MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
-        cb(::megdnn::dtype::Bool)
+        cb(::megdnn::dtype::Bool)
+        cb(::megdnn::dtype::Uint16)
         case DTypeEnum::QuantizedS8:
             MIDOUT_BEGIN(
                     megdnn_fb_typecvt_dst_dtype, midout_iv(DTypeEnum::QuantizedS8)) {
 ...
dnn/src/naive/type_cvt/opr_impl.cpp

@@ -78,7 +78,7 @@ void on_dest_ctype(HandleImpl* handle, const TensorND& dest, const TensorND& src
         MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
         MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
         MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
-        cb(::megdnn::dtype::Bool)
+        cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16)
 #undef cb
         default:
             megdnn_throw("bad dtype");
     }
 ...

@@ -99,7 +99,7 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
         MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
         MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
         MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
-        cb(::megdnn::dtype::Bool)
+        cb(::megdnn::dtype::Bool) cb(::megdnn::dtype::Uint16)
 #undef cb
         default:
             megdnn_throw("bad dtype");
     }
 ...
dnn/test/arm_common/elemwise.cpp

@@ -14,6 +14,7 @@
 #include "test/common/benchmarker.h"
 #include "test/common/checker.h"
 #include "megdnn/opr_param_defs.h"
+#include "megdnn/oprs/general.h"

 using namespace megdnn;
 ...

@@ -298,6 +299,63 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW88_FP) {
 #endif
 }

+TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NHWC_FP32_BCAST) {
+    using Mode = ElemwiseForward::Param::Mode;
+    Checker<ElemwiseForward> checker(handle());
+
+    UniformFloatRNG rng(1e-5, 7e1);
+    checker.set_rng(0, &rng);
+    checker.set_epsilon(1e-5);
+    checker.set_dtype(0, dtype::Float32());
+    checker.set_dtype(1, dtype::Float32());
+
+    //! 2 dim
+    auto run = [&](Mode mode) {
+        // VEC_BCAST111C
+        checker.set_param(mode).execs({{1, 2, 2, 12}, {1, 1, 1, 12}, {}});
+        checker.set_param(mode).execs({{2, 5, 3, 28}, {1, 1, 1, 28}, {}});
+        checker.set_param(mode).execs({{3, 5, 8, 32}, {1, 1, 1, 32}, {}});
+        // BCAST111C_VEC
+        checker.set_param(mode).execs({{1, 1, 1, 12}, {1, 2, 2, 12}, {}});
+        checker.set_param(mode).execs({{1, 1, 1, 28}, {2, 5, 3, 28}, {}});
+        checker.set_param(mode).execs({{1, 1, 1, 32}, {3, 5, 8, 32}, {}});
+    };
+    run(Mode::ADD);
+    run(Mode::MUL);
+    run(Mode::SUB);
+
+    //! 3 dim contig
+    auto run_3d_contig = [&](Mode mode) {
+        // BCAST111C_VEC_BCAST111C
+        checker.set_param(mode).execs(
+                {{1, 1, 1, 12}, {1, 2, 2, 12}, {1, 1, 1, 12}, {}});
+        checker.set_param(mode).execs(
+                {{1, 1, 1, 28}, {2, 5, 3, 28}, {1, 1, 1, 28}, {}});
+        checker.set_param(mode).execs(
+                {{1, 1, 1, 32}, {3, 5, 8, 32}, {1, 1, 1, 32}, {}});
+        // VEC_BCAST111C_VEC
+        checker.set_param(mode).execs(
+                {{1, 2, 2, 12}, {1, 1, 1, 12}, {1, 2, 2, 12}, {}});
+        checker.set_param(mode).execs(
+                {{2, 5, 3, 28}, {1, 1, 1, 28}, {2, 5, 3, 28}, {}});
+        checker.set_param(mode).execs(
+                {{3, 5, 8, 32}, {1, 1, 1, 32}, {3, 5, 8, 32}, {}});
+    };
+    run_3d_contig(Mode::FUSE_MUL_ADD3);
+
+    //! 3 dim incontig
+    auto run_3d_incontig = [&](Mode mode) {
+        megdnn::TensorLayout src0({1, 1, 1, 12}, dtype::Float32());
+        megdnn::TensorLayout src1({1, 2, 2, 12}, {80, 40, 20, 1}, dtype::Float32());
+        // BCAST111C_VEC_BCAST111C
+        checker.set_param(mode).execl({src0, src1, src0, {}});
+        // VEC_BCAST111C_VEC
+        checker.set_param(mode).execl({src1, src0, src1, {}});
+    };
+    run_3d_incontig(Mode::FUSE_MUL_ADD3);
+}
+
 #if MEGDNN_WITH_BENCHMARK
 namespace {
 void run_elemwise_benchmark(
 ...

@@ -354,6 +412,39 @@ void run_elemwise_benchmark(
 }
 }  // namespace

+TEST_F(ARM_COMMON, BENCHMARK_NCHW_VS_NHWC) {
+    Benchmarker<Elemwise> benchmarker(handle());
+    constexpr size_t RUN = 50;
+    benchmarker.set_times(RUN).set_display(false);
+
+    auto run = [&](size_t N, size_t C, size_t H, size_t W,
+                   param::Elemwise::Mode mode, const char* mode_name) {
+        megdnn::param::Elemwise param;
+        param.mode = mode;
+        benchmarker.set_param(param);
+        megdnn::TensorShape nhwc_src0{N, H, W, C};
+        megdnn::TensorShape nhwc_src1{1, 1, 1, C};
+        megdnn::TensorShape nchw_src0{N, C, H, W};
+        megdnn::TensorShape nchw_src1{1, C, 1, 1};
+        float computations = N * C * H * W;
+        auto nhwc_time = benchmarker.execs({nhwc_src1, nhwc_src0, {}}) / RUN;
+        auto nchw_time = benchmarker.execs({nchw_src1, nchw_src0, {}}) / RUN;
+        auto perf_nhwc = computations / nhwc_time / 1e6;
+        auto perf_nchw = computations / nchw_time / 1e6;
+        printf("Elemwise Mode : %s\nNHWC : %fms %fGflops\nNCHW : %fms "
+               "%fGflops\n",
+               mode_name, nhwc_time, perf_nhwc, nchw_time, perf_nchw);
+    };
+    run(1, 120, 16, 24, param::Elemwise::Mode::ADD, "ADD");
+    run(1, 120, 16, 24, param::Elemwise::Mode::MUL, "MUL");
+    run(1, 120, 32, 48, param::Elemwise::Mode::ADD, "ADD");
+    run(1, 120, 32, 48, param::Elemwise::Mode::MUL, "MUL");
+    run(1, 120, 64, 96, param::Elemwise::Mode::ADD, "ADD");
+    run(1, 120, 64, 96, param::Elemwise::Mode::MUL, "MUL");
+}
+
 #define INT_RUN(shape, mode)                                               \
     run_elemwise_benchmark(shape, mode, #mode, dtype::Int8{}, handle());   \
     run_elemwise_benchmark(shape, mode, #mode, dtype::Int16{}, handle());  \
 ...
dnn/test/arm_common/type_cvt.cpp

@@ -88,6 +88,26 @@ TEST_F(ARM_COMMON, TYPE_CVT) {
             .execs({{1, 32, 24, 128}, {1, 32, 24, 128}});
 }

+TEST_F(ARM_COMMON, TYPE_CVT_16_F32) {
+    Checker<TypeCvt> checker(handle());
+    UniformIntRNG rng{INT16_MIN >> 1, INT16_MAX >> 1};
+
+    for (size_t size : {3, 7, 15, 33, 10000}) {
+        checker.set_rng(0, &rng);
+        checker.set_dtype(0, dtype::Int16()).execs({{size}, {size}});
+        checker.set_dtype(0, dtype::Uint16()).execs({{size}, {size}});
+    }
+
+    TensorLayout src_int16{
+            {1, 96, 64, 120}, {128 * 64 * 96, 128 * 64, 128, 1}, dtype::Int16()};
+    TensorLayout dst_int16{{1, 96, 64, 120}, dtype::Float32()};
+    checker.execl({src_int16, dst_int16});
+
+    TensorLayout src_uint16{
+            {1, 96, 64, 120}, {128 * 64 * 96, 128 * 64, 128, 1}, dtype::Uint16()};
+    TensorLayout dst_uint16{{1, 96, 64, 120}, dtype::Float32()};
+    checker.execl({src_uint16, dst_uint16});
+}
+
 #if MEGDNN_WITH_BENCHMARK
 TEST_F(ARM_COMMON, BENCHMARK_TYPE_CVT) {
     auto run = [&](const TensorShapeArray& shapes) {
 ...
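These non-contiguous layouts pad the innermost dim: 120 valid channels stored in 128-element slots. A quick standalone check of that arithmetic (hypothetical test harness, plain C++):

    #include <cassert>
    #include <cstddef>

    // Verify the padded-layout arithmetic used by src_int16/src_uint16 above.
    int main() {
        size_t shape[4] = {1, 96, 64, 120};
        size_t stride[4] = {128 * 64 * 96, 128 * 64, 128, 1};
        // Each W-row holds shape[3] payload elements in stride[2] slots:
        assert(stride[2] - shape[3] == 8);   // 8 pad elements per row
        // Rows collapse to {96 * 64, 120} with row stride 128:
        assert(stride[1] == stride[2] * 64 && stride[0] == stride[1] * 96);
        return 0;
    }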
dnn/test/common/checker.cpp

@@ -158,8 +158,9 @@ void copy_tensors(
         //! In order to avoid an unnecessary increase in binary size, we just
         //! use QuantizedS16 dtype in winograd_filter_preprocess now.
         cb(::megdnn::dtype::QuantizedS16)
         MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
+        cb(::megdnn::dtype::Uint16)
 #undef cb
-        default: megdnn_trap();
+        default: megdnn_trap();
     }
 }
 ...
dnn/test/common/rng.cpp

@@ -202,6 +202,9 @@ void IIDRNG::gen(const TensorND& tensor) {
         memset(tensor.raw_ptr, 0, tensor.layout.access_bytes());
         return;
     }
+    if (tensor.layout.dtype.enumv() == DTypeEnum::Uint16) {
+        return;
+    }
     megdnn_assert(
             0, "IIDRNG does not know how to generate value for DType %s",
             tensor.layout.dtype.name());
 ...
dnn/test/cuda/type_cvt.cpp

@@ -25,6 +25,11 @@ TEST_F(CUDA, TYPE_CVT) {
         TensorLayout src({10, 10}, sdtype), dst({10, 10}, ddtype);
         Checker<TypeCvt> checker(handle_cuda());
         checker.set_rng(0, &init).exec(TensorLayoutArray{src, dst});
+
+        TensorLayout non_contig_src(
+                {1, 96, 64, 120}, {96 * 64 * 128, 64 * 128, 128, 1}, sdtype);
+        TensorLayout non_contig_dst({1, 96, 64, 120}, ddtype);
+        checker.exec(TensorLayoutArray{non_contig_src, non_contig_dst});
     }
 }
 ...
dnn/test/x86/type_cvt.cpp

@@ -37,8 +37,22 @@ TEST_F(X86, TYPE_CVT) {
         for (auto ddtype : dtypes) {
             checker.set_dtype(0, sdtype).set_dtype(1, ddtype).execs(
                     {{size}, {size}});
+            TensorLayout non_contig_src(
+                    {1, 10, 10, 12}, {10 * 10 * 18, 10 * 18, 18, 1}, sdtype);
+            TensorLayout non_contig_dst({1, 10, 10, 12}, ddtype);
+            checker.exec(TensorLayoutArray{non_contig_src, non_contig_dst});
         }
     }
+    for (size_t size : {1, 7, 15, 33}) {
+        checker.set_dtype(0, dtype::Uint16())
+                .set_dtype(1, dtype::Float32())
+                .execs({{size}, {size}});
+    }
+    TensorLayout non_contig_src(
+            {1, 10, 10, 12}, {10 * 10 * 18, 10 * 18, 18, 1}, dtype::Uint16());
+    TensorLayout non_contig_dst({1, 10, 10, 12}, dtype::Float32());
+    checker.exec(TensorLayoutArray{non_contig_src, non_contig_dst});
 }

 TEST_F(X86, TYPE_CVT_NO_CONTIGUOUS) {
 ...
src/opr/impl/basic_arith.cpp

@@ -772,8 +772,10 @@ void TypeCvt::perform(
 }

 void TypeCvt::add_input_layout_constraint() {
+    //! The typecvt implementations on arm/x86/cuda/opencl support
+    //! non-contiguous memory, so the constraint is relaxed to monotone.
     for (auto i : input()) {
-        i->add_layout_constraint_contiguous();
+        i->add_layout_constraint_monotone();
     }
 }
 ...