Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
3344b580
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
3344b580
编写于
8月 25, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn): add elemwise for nchw88+fp16
GitOrigin-RevId: 63587975f8746bd8cf2443e81d433bfc07122b38
上级
682c74df
变更
10
展开全部
隐藏空白更改
内联
并排
Showing
10 changed file
with
595 addition
and
162 deletion
+595
-162
dnn/src/arm_common/conv_bias/postprocess_helper.h
dnn/src/arm_common/conv_bias/postprocess_helper.h
+50
-42
dnn/src/arm_common/elemwise/binary/algo.cpp
dnn/src/arm_common/elemwise/binary/algo.cpp
+22
-20
dnn/src/arm_common/elemwise/binary/algo.h
dnn/src/arm_common/elemwise/binary/algo.h
+1
-1
dnn/src/arm_common/elemwise/opr_impl.cpp
dnn/src/arm_common/elemwise/opr_impl.cpp
+19
-16
dnn/src/arm_common/elemwise/opr_impl.h
dnn/src/arm_common/elemwise/opr_impl.h
+5
-3
dnn/src/arm_common/elemwise/ternary/algo.cpp
dnn/src/arm_common/elemwise/ternary/algo.cpp
+14
-10
dnn/src/arm_common/elemwise/ternary/algo.h
dnn/src/arm_common/elemwise/ternary/algo.h
+2
-2
dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp
dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp
+10
-9
dnn/src/arm_common/elemwise_op.h
dnn/src/arm_common/elemwise_op.h
+386
-59
dnn/test/arm_common/elemwise.cpp
dnn/test/arm_common/elemwise.cpp
+86
-0
未找到文件。
dnn/src/arm_common/conv_bias/postprocess_helper.h
浏览文件 @
3344b580
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#pragma once
#pragma once
...
@@ -22,7 +23,6 @@ MIDOUT_DECL(arm_common_conv_bias_postprocess_helper)
...
@@ -22,7 +23,6 @@ MIDOUT_DECL(arm_common_conv_bias_postprocess_helper)
namespace
{
namespace
{
#define CONCAT_OP(_name) megdnn::arm_common::_name
#define CONCAT_OP(_name) megdnn::arm_common::_name
#define CONCAT_NL(_name) megdnn::NonlineMode::_name
#define CONCAT_NL(_name) megdnn::NonlineMode::_name
...
@@ -57,9 +57,9 @@ namespace {
...
@@ -57,9 +57,9 @@ namespace {
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
dst_type, N, OC, OH* OW);
dst_type, N, OC, OH* OW);
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW
44
(_op) \
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW
XX
(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, \
megdnn::arm_common::OpCallerBinary<_op<ctype>, \
megdnn::arm_common::VEC_BCAST101x
4
>:: \
megdnn::arm_common::VEC_BCAST101x
X
>:: \
run(static_cast<ctype*>(conv_dst_ptr), \
run(static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
...
@@ -86,9 +86,9 @@ namespace {
...
@@ -86,9 +86,9 @@ namespace {
if (pack_oc_size == 1) { \
if (pack_oc_size == 1) { \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \
} else { \
} else { \
megdnn_assert(pack_oc_size == 4
,
\
megdnn_assert(pack_oc_size == 4
|| pack_oc_size == 8,
\
"Only support nchw44
in ARM");
\
"Only support nchw44
/nchw88 in ARM");
\
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW
44
); \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW
XX
); \
} \
} \
} \
} \
MIDOUT_END(); \
MIDOUT_END(); \
...
@@ -100,7 +100,7 @@ namespace {
...
@@ -100,7 +100,7 @@ namespace {
MIDOUT_END(); \
MIDOUT_END(); \
break; \
break; \
default: \
default: \
megdnn_throw("unknow biasmode"); \
megdnn_throw("unknow biasmode");
\
break; \
break; \
}
}
...
@@ -160,7 +160,7 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
...
@@ -160,7 +160,7 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
#undef FOR_NONLINEAR_UNARY
#undef FOR_NONLINEAR_UNARY
#undef FOR_NONLINEAR_BINARY_BROADCAST
#undef FOR_NONLINEAR_BINARY_BROADCAST
#undef FOR_NONLINEAR_BINARY_BROADCAST_NCHW
44
#undef FOR_NONLINEAR_BINARY_BROADCAST_NCHW
XX
#undef FOR_NONLINEAR_BINARY
#undef FOR_NONLINEAR_BINARY
#undef FOR_NONLINEAR_NOBIAS
#undef FOR_NONLINEAR_NOBIAS
#undef FOR_NONLINEAR
#undef FOR_NONLINEAR
...
@@ -183,16 +183,24 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
...
@@ -183,16 +183,24 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \
dst_type, N, OC, OH* OW);
dst_type, N, OC, OH* OW);
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW44(_op) \
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \
megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, \
megdnn::arm_common::VEC_BCAST101xX>:: \
run(static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<const opctype*>(bias_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \
dst_type, N, OC, OH* OW, pack_oc_size);
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \
megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, \
megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, \
megdnn::arm_common::VEC_BCAST101x
4
>:: \
megdnn::arm_common::VEC_BCAST101x
X
>:: \
run(static_cast<opctype*>(conv_dst_ptr), \
run(static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<const opctype*>(bias_ptr), \
reinterpret_cast<const opctype*>(bias_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \
dst_type, N, OC, OH* OW, pack_oc_size);
dst_type, N, OC, OH* OW, pack_oc_size);
#define HANDLE_IDENTITY(_caller, _op)
\
#define HANDLE_IDENTITY(_caller, _op) \
case megdnn::NonlineMode::IDENTITY:
\
case megdnn::NonlineMode::IDENTITY: \
_caller(_op) break;
_caller(_op) break;
#define FOR_NONLINEAR(_caller) \
#define FOR_NONLINEAR(_caller) \
...
@@ -220,9 +228,9 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
...
@@ -220,9 +228,9 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
if (pack_oc_size == 1) { \
if (pack_oc_size == 1) { \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \
} else { \
} else { \
megdnn_assert(pack_oc_size == 4
,
\
megdnn_assert(pack_oc_size == 4
|| pack_oc_size == 8,
\
"Only support nchw44
in ARM");
\
"Only support nchw44
/nchw88 in ARM");
\
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW
44
); \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW
XX
); \
} \
} \
break; \
break; \
default: \
default: \
...
@@ -230,9 +238,9 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
...
@@ -230,9 +238,9 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
if (pack_oc_size == 1) { \
if (pack_oc_size == 1) { \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \
} else { \
} else { \
megdnn_assert(pack_oc_size == 4
,
\
megdnn_assert(pack_oc_size == 4
|| pack_oc_size == 8,
\
"Only support nchw44
in ARM");
\
"Only support nchw44
/nchw88 in ARM");
\
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW
44
); \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW
XX
); \
} \
} \
break; \
break; \
} \
} \
...
@@ -254,7 +262,7 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
...
@@ -254,7 +262,7 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
#undef FOR_NONLINEAR_UNARY
#undef FOR_NONLINEAR_UNARY
#undef FOR_NONLINEAR_BINARY_BROADCAST
#undef FOR_NONLINEAR_BINARY_BROADCAST
#undef FOR_NONLINEAR_BINARY_BROADCAST_NCHW
44
#undef FOR_NONLINEAR_BINARY_BROADCAST_NCHW
XX
#undef FOR_NONLINEAR_BINARY
#undef FOR_NONLINEAR_BINARY
#undef FOR_NONLINEAR_NOBIAS
#undef FOR_NONLINEAR_NOBIAS
#undef FOR_NONLINEAR
#undef FOR_NONLINEAR
...
@@ -268,9 +276,9 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
...
@@ -268,9 +276,9 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
dst_type, N, OC, OH* OW);
dst_type, N, OC, OH* OW);
#define FOR_BINARY_BROADCAST_NCHW
44
(_op) \
#define FOR_BINARY_BROADCAST_NCHW
XX
(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, \
megdnn::arm_common::OpCallerBinary<_op<ctype>, \
megdnn::arm_common::VEC_BCAST101x
4
>:: \
megdnn::arm_common::VEC_BCAST101x
X
>:: \
run(static_cast<ctype*>(conv_dst_ptr), \
run(static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
...
@@ -284,25 +292,25 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
...
@@ -284,25 +292,25 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
dst_type, N* OC* OH* OW* pack_oc_size);
dst_type, N* OC* OH* OW* pack_oc_size);
#define FOR_BIAS(_bias_mode, OH, OW) \
#define FOR_BIAS(_bias_mode, OH, OW)
\
switch (_bias_mode) { \
switch (_bias_mode) {
\
case megdnn::BiasMode::NO_BIAS: \
case megdnn::BiasMode::NO_BIAS:
\
break; \
break;
\
case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \
case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS:
\
if (pack_oc_size == 1) { \
if (pack_oc_size == 1) {
\
FOR_BINARY_BROADCAST(CONCAT_OP(AddOp)); \
FOR_BINARY_BROADCAST(CONCAT_OP(AddOp));
\
} else { \
} else {
\
megdnn_assert(pack_oc_size == 4
,
\
megdnn_assert(pack_oc_size == 4
|| pack_oc_size == 8,
\
"Only support nchw44 in ARM"); \
"Only support nchw44
/nchw88
in ARM"); \
FOR_BINARY_BROADCAST_NCHW
44(CONCAT_OP(AddOp));
\
FOR_BINARY_BROADCAST_NCHW
XX(CONCAT_OP(AddOp));
\
} \
}
\
break; \
break;
\
case megdnn::BiasMode::BIAS: \
case megdnn::BiasMode::BIAS:
\
FOR_BINARY(CONCAT_OP(AddOp)); \
FOR_BINARY(CONCAT_OP(AddOp));
\
break; \
break;
\
default: \
default:
\
megdnn_throw("unknow biasmode"); \
megdnn_throw("unknow biasmode");
\
break; \
break;
\
}
}
template
<
typename
ctype
,
typename
dtype
>
template
<
typename
ctype
,
typename
dtype
>
...
@@ -318,7 +326,7 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> {
...
@@ -318,7 +326,7 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> {
};
};
#undef FOR_BINARY_BROADCAST
#undef FOR_BINARY_BROADCAST
#undef FOR_BINARY_BROADCAST_NCHW
44
#undef FOR_BINARY_BROADCAST_NCHW
XX
#undef FOR_BINARY
#undef FOR_BINARY
#undef FOR_BIAS
#undef FOR_BIAS
#undef CB
#undef CB
...
...
dnn/src/arm_common/elemwise/binary/algo.cpp
浏览文件 @
3344b580
...
@@ -105,25 +105,21 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available(
...
@@ -105,25 +105,21 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available(
return
false
;
return
false
;
}
}
bool
ElemwiseImpl
::
AlgoBinaryVecBcast101x
4
::
is_available
(
bool
ElemwiseImpl
::
AlgoBinaryVecBcast101x
X
::
is_available
(
const
KernParam
&
kern_param
)
const
{
const
KernParam
&
kern_param
)
const
{
if
(
!
is_available_common
(
kern_param
.
mode
)
||
if
(
!
is_available_common
(
kern_param
.
mode
)
||
((
BcastType
::
VEC_BCAST101x
4
!=
kern_param
.
broad_cast_type
)
&&
((
BcastType
::
VEC_BCAST101x
X
!=
kern_param
.
broad_cast_type
)
&&
(
BcastType
::
BCAST101x
4
_VEC
!=
kern_param
.
broad_cast_type
)))
(
BcastType
::
BCAST101x
X
_VEC
!=
kern_param
.
broad_cast_type
)))
return
false
;
return
false
;
auto
&
elparam
=
kern_param
.
binary_elparam
;
auto
&
elparam
=
kern_param
.
binary_elparam
;
auto
&
src0
=
elparam
[
0
];
auto
&
src0
=
elparam
[
0
];
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
if
(
DNN_FLOAT16_SELECT
(
src0
.
layout
.
dtype
==
dtype
::
Float16
{},
false
))
{
return
false
;
}
#endif
DISPATCH_TYPE
(
"AlgoBinaryVecBcast101x::is_available"
_hash
);
DISPATCH_TYPE
(
"AlgoBinaryVecBcast101x
X
::is_available"
_hash
);
return
false
;
return
false
;
}
}
#undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_INT
#undef DISPATCH_MODE_INT
...
@@ -334,16 +330,19 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec(
...
@@ -334,16 +330,19 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec(
return
;
return
;
}
}
void
ElemwiseImpl
::
AlgoBinaryVecBcast101x
4
::
exec
(
void
ElemwiseImpl
::
AlgoBinaryVecBcast101x
X
::
exec
(
const
KernParam
&
kern_param
)
const
{
const
KernParam
&
kern_param
)
const
{
auto
&
elparam
=
kern_param
.
binary_elparam
;
auto
&
elparam
=
kern_param
.
binary_elparam
;
auto
&
src0
=
elparam
[
0
],
&
src1
=
elparam
[
1
];
auto
&
src0
=
elparam
[
0
],
&
src1
=
elparam
[
1
];
auto
&&
dst
=
*
(
kern_param
.
m_dst
);
auto
&&
dst
=
*
(
kern_param
.
m_dst
);
BroadcastChannelInfo
binfo
;
BroadcastChannelInfo
binfo
;
// BcastType::VEC + BCAST_101x
// BcastType::VEC + BCAST_101X
if
(
BcastType
::
VEC_BCAST101x4
==
kern_param
.
broad_cast_type
&&
if
(
BcastType
::
VEC_BCAST101xX
==
kern_param
.
broad_cast_type
)
{
is_broadcastedx_channel_like
<
4
>
(
src1
.
layout
,
binfo
))
{
megdnn_assert
(
is_broadcastedx_channel_like
<
4
>
(
src1
.
layout
,
binfo
)
||
is_broadcastedx_channel_like
<
8
>
(
src1
.
layout
,
binfo
),
"only nchw44 and nchw88 supported"
);
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
case Mode::_mode: \
MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \
MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \
...
@@ -351,7 +350,7 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec(
...
@@ -351,7 +350,7 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec(
thin_function<void(const _type*, const _type*, _type*, DType, \
thin_function<void(const _type*, const _type*, _type*, DType, \
DType, DType, size_t, size_t, size_t, size_t)> \
DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerBinary<_op<_type, _type>, \
run = OpCallerBinary<_op<_type, _type>, \
BcastType::VEC_BCAST101x
4
>::run; \
BcastType::VEC_BCAST101x
X
>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr), \
run(static_cast<const _type*>(src0.raw_ptr), \
...
@@ -362,17 +361,19 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec(
...
@@ -362,17 +361,19 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec(
} \
} \
MIDOUT_END(); \
MIDOUT_END(); \
return
return
size_t
batch_size
=
size_t
batch_size
=
src0
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
src0
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
DISPATCH_TYPE
(
"AlgoBinaryVecBcast101x::exec_vec_b"
_hash
);
DISPATCH_TYPE
(
"AlgoBinaryVecBcast101x
X
::exec_vec_b"
_hash
);
#undef DISPATCH_BINARY
#undef DISPATCH_BINARY
}
}
// BCAST_101x + BcastType::VEC
// BCAST_101x + BcastType::VEC
if
(
BcastType
::
BCAST101x4_VEC
==
kern_param
.
broad_cast_type
&&
if
(
BcastType
::
BCAST101xX_VEC
==
kern_param
.
broad_cast_type
)
{
is_broadcastedx_channel_like
<
4
>
(
src0
.
layout
,
binfo
))
{
megdnn_assert
(
is_broadcastedx_channel_like
<
4
>
(
src0
.
layout
,
binfo
)
||
is_broadcastedx_channel_like
<
8
>
(
src0
.
layout
,
binfo
),
"only nchw44 and nchw88 supported"
);
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
case Mode::_mode: \
MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \
MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \
...
@@ -380,7 +381,7 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec(
...
@@ -380,7 +381,7 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec(
thin_function<void(const _type*, const _type*, _type*, DType, \
thin_function<void(const _type*, const _type*, _type*, DType, \
DType, DType, size_t, size_t, size_t, size_t)> \
DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerBinary<_op<_type, _type>, \
run = OpCallerBinary<_op<_type, _type>, \
BcastType::BCAST101x
4
_VEC>::run; \
BcastType::BCAST101x
X
_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr), \
run(static_cast<const _type*>(src0.raw_ptr), \
...
@@ -394,12 +395,13 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec(
...
@@ -394,12 +395,13 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec(
size_t
batch_size
=
size_t
batch_size
=
src1
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
src1
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
DISPATCH_TYPE
(
"AlgoBinaryVecBcast101x::exec_b_vec"
_hash
);
DISPATCH_TYPE
(
"AlgoBinaryVecBcast101x
X
::exec_b_vec"
_hash
);
#undef DISPATCH_BINARY
#undef DISPATCH_BINARY
}
}
return
;
return
;
}
}
#undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_INT
#undef DISPATCH_MODE_INT
...
...
dnn/src/arm_common/elemwise/binary/algo.h
浏览文件 @
3344b580
...
@@ -34,7 +34,7 @@ namespace arm_common {
...
@@ -34,7 +34,7 @@ namespace arm_common {
DECL_CB
(
VecVec
);
DECL_CB
(
VecVec
);
DECL_CB
(
VecScalar
);
DECL_CB
(
VecScalar
);
DECL_CB
(
VecBcast101
);
DECL_CB
(
VecBcast101
);
DECL_CB
(
VecBcast101x
4
);
DECL_CB
(
VecBcast101x
X
);
#undef DECL_CB
#undef DECL_CB
}
// namespace arm_common
}
// namespace arm_common
}
// namespace megdnn
}
// namespace megdnn
...
...
dnn/src/arm_common/elemwise/opr_impl.cpp
浏览文件 @
3344b580
...
@@ -27,14 +27,14 @@ class ElemwiseImpl::AlgoPack {
...
@@ -27,14 +27,14 @@ class ElemwiseImpl::AlgoPack {
AlgoBinaryVecVec
algo_binary_vec_vec
;
AlgoBinaryVecVec
algo_binary_vec_vec
;
AlgoBinaryVecScalar
algo_binary_vec_sca
;
AlgoBinaryVecScalar
algo_binary_vec_sca
;
AlgoBinaryVecBcast101
algo_binary_vec_bcast101
;
AlgoBinaryVecBcast101
algo_binary_vec_bcast101
;
AlgoBinaryVecBcast101x
4
algo_binary_VEC_BCAST101x4
;
AlgoBinaryVecBcast101x
X
algo_binary_VEC_BCAST101xX
;
AlgoTernaryFma3VecVecVec
algo_ternaryfma3_vec_vec_vec
;
AlgoTernaryFma3VecVecVec
algo_ternaryfma3_vec_vec_vec
;
AlgoTernaryFma3VecVecScalar
algo_ternaryfma3_vec_vecsca
;
AlgoTernaryFma3VecVecScalar
algo_ternaryfma3_vec_vecsca
;
AlgoTernaryFma3Bcast101VecBcast101
algo_ternaryfma3_bcast101_vec_bcast101
;
AlgoTernaryFma3Bcast101VecBcast101
algo_ternaryfma3_bcast101_vec_bcast101
;
AlgoTernaryFma3Bcast101x
4VecBcast101x4
AlgoTernaryFma3Bcast101x
XVecBcast101xX
algo_ternaryfma3_bcast101x
4_vec_bcast101x4
;
algo_ternaryfma3_bcast101x
X_vec_bcast101xX
;
AlgoTernaryFma3VecBcast101Vec
algo_ternaryfma3_vec_bcast101_vec
;
AlgoTernaryFma3VecBcast101Vec
algo_ternaryfma3_vec_bcast101_vec
;
AlgoTernaryFma3VecBcast101x
4Vec
algo_ternaryfma3_vec_bcast101x4
_vec
;
AlgoTernaryFma3VecBcast101x
XVec
algo_ternaryfma3_vec_bcast101xX
_vec
;
AlgoTernaryFma3VecScalarVec
algo_ternaryfma3_vec_sca_vec
;
AlgoTernaryFma3VecScalarVec
algo_ternaryfma3_vec_sca_vec
;
AlgoTernaryFma3VecScalarScalar
algo_ternaryfma3_vec_sca_sca
;
AlgoTernaryFma3VecScalarScalar
algo_ternaryfma3_vec_sca_sca
;
...
@@ -44,13 +44,13 @@ public:
...
@@ -44,13 +44,13 @@ public:
all_algos
.
emplace_back
(
&
algo_binary_vec_vec
);
all_algos
.
emplace_back
(
&
algo_binary_vec_vec
);
all_algos
.
emplace_back
(
&
algo_binary_vec_sca
);
all_algos
.
emplace_back
(
&
algo_binary_vec_sca
);
all_algos
.
emplace_back
(
&
algo_binary_vec_bcast101
);
all_algos
.
emplace_back
(
&
algo_binary_vec_bcast101
);
all_algos
.
emplace_back
(
&
algo_binary_VEC_BCAST101x
4
);
all_algos
.
emplace_back
(
&
algo_binary_VEC_BCAST101x
X
);
all_algos
.
emplace_back
(
&
algo_ternaryfma3_vec_vec_vec
);
all_algos
.
emplace_back
(
&
algo_ternaryfma3_vec_vec_vec
);
all_algos
.
emplace_back
(
&
algo_ternaryfma3_vec_vecsca
);
all_algos
.
emplace_back
(
&
algo_ternaryfma3_vec_vecsca
);
all_algos
.
emplace_back
(
&
algo_ternaryfma3_bcast101_vec_bcast101
);
all_algos
.
emplace_back
(
&
algo_ternaryfma3_bcast101_vec_bcast101
);
all_algos
.
emplace_back
(
&
algo_ternaryfma3_bcast101x
4_vec_bcast101x4
);
all_algos
.
emplace_back
(
&
algo_ternaryfma3_bcast101x
X_vec_bcast101xX
);
all_algos
.
emplace_back
(
&
algo_ternaryfma3_vec_bcast101_vec
);
all_algos
.
emplace_back
(
&
algo_ternaryfma3_vec_bcast101_vec
);
all_algos
.
emplace_back
(
&
algo_ternaryfma3_vec_bcast101x
4
_vec
);
all_algos
.
emplace_back
(
&
algo_ternaryfma3_vec_bcast101x
X
_vec
);
all_algos
.
emplace_back
(
&
algo_ternaryfma3_vec_sca_vec
);
all_algos
.
emplace_back
(
&
algo_ternaryfma3_vec_sca_vec
);
all_algos
.
emplace_back
(
&
algo_ternaryfma3_vec_sca_sca
);
all_algos
.
emplace_back
(
&
algo_ternaryfma3_vec_sca_sca
);
}
}
...
@@ -118,9 +118,10 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
...
@@ -118,9 +118,10 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
}
}
if
(
is_vector
(
src1
.
layout
)
&&
if
(
is_vector
(
src1
.
layout
)
&&
is_broadcastedx_channel_like
<
4
>
(
src0
.
layout
,
binfo
)
&&
(
is_broadcastedx_channel_like
<
4
>
(
src0
.
layout
,
binfo
)
||
is_broadcastedx_channel_like
<
8
>
(
src0
.
layout
,
binfo
))
&&
src0
.
layout
.
eq_layout
(
src2
.
layout
))
{
src0
.
layout
.
eq_layout
(
src2
.
layout
))
{
kern_param
.
broad_cast_type
=
BcastType
::
BCAST101x
4_VEC_BCAST101x4
;
kern_param
.
broad_cast_type
=
BcastType
::
BCAST101x
X_VEC_BCAST101xX
;
return
kern_param
;
return
kern_param
;
}
}
...
@@ -131,8 +132,9 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
...
@@ -131,8 +132,9 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
}
}
if
(
is_vector
(
src0
.
layout
)
&&
src0
.
layout
.
eq_layout
(
src2
.
layout
)
&&
if
(
is_vector
(
src0
.
layout
)
&&
src0
.
layout
.
eq_layout
(
src2
.
layout
)
&&
is_broadcastedx_channel_like
<
4
>
(
src1
.
layout
,
binfo
))
{
(
is_broadcastedx_channel_like
<
4
>
(
src1
.
layout
,
binfo
)
||
kern_param
.
broad_cast_type
=
BcastType
::
VEC_BCAST101x4_VEC
;
is_broadcastedx_channel_like
<
8
>
(
src1
.
layout
,
binfo
)))
{
kern_param
.
broad_cast_type
=
BcastType
::
VEC_BCAST101xX_VEC
;
return
kern_param
;
return
kern_param
;
}
}
...
@@ -180,17 +182,18 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
...
@@ -180,17 +182,18 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
}
}
if
(
is_vector
(
src0
.
layout
)
&&
if
(
is_vector
(
src0
.
layout
)
&&
is_broadcastedx_channel_like
<
4
>
(
src1
.
layout
,
binfo
))
{
(
is_broadcastedx_channel_like
<
4
>
(
src1
.
layout
,
binfo
)
||
kern_param
.
broad_cast_type
=
BcastType
::
VEC_BCAST101x4
;
is_broadcastedx_channel_like
<
8
>
(
src1
.
layout
,
binfo
)))
{
kern_param
.
broad_cast_type
=
BcastType
::
VEC_BCAST101xX
;
return
kern_param
;
return
kern_param
;
}
}
if
(
is_vector
(
src1
.
layout
)
&&
if
(
is_vector
(
src1
.
layout
)
&&
is_broadcastedx_channel_like
<
4
>
(
src0
.
layout
,
binfo
))
{
(
is_broadcastedx_channel_like
<
4
>
(
src0
.
layout
,
binfo
)
||
kern_param
.
broad_cast_type
=
BcastType
::
BCAST101x4_VEC
;
is_broadcastedx_channel_like
<
8
>
(
src0
.
layout
,
binfo
)))
{
kern_param
.
broad_cast_type
=
BcastType
::
BCAST101xX_VEC
;
return
kern_param
;
return
kern_param
;
}
}
}
else
if
(
opr
->
m_src
->
size
()
==
1
)
{
}
else
if
(
opr
->
m_src
->
size
()
==
1
)
{
kern_param
.
broad_cast_type
=
BcastType
::
VEC
;
kern_param
.
broad_cast_type
=
BcastType
::
VEC
;
kern_param
.
unary_elparam
=
opr
->
make_elemwise_op_param
<
1
>
();
kern_param
.
unary_elparam
=
opr
->
make_elemwise_op_param
<
1
>
();
...
...
dnn/src/arm_common/elemwise/opr_impl.h
浏览文件 @
3344b580
...
@@ -10,7 +10,9 @@
...
@@ -10,7 +10,9 @@
* implied.
* implied.
*/
*/
#pragma once
#pragma once
#include "src/fallback/elemwise/opr_impl.h"
#include "src/fallback/elemwise/opr_impl.h"
#include "src/arm_common/elemwise_op.h"
#include "src/arm_common/elemwise_op.h"
namespace
megdnn
{
namespace
megdnn
{
...
@@ -37,13 +39,13 @@ private:
...
@@ -37,13 +39,13 @@ private:
class
AlgoBinaryVecVec
;
class
AlgoBinaryVecVec
;
class
AlgoBinaryVecScalar
;
class
AlgoBinaryVecScalar
;
class
AlgoBinaryVecBcast101
;
class
AlgoBinaryVecBcast101
;
class
AlgoBinaryVecBcast101x
4
;
class
AlgoBinaryVecBcast101x
X
;
class
AlgoTernaryFma3VecVecVec
;
class
AlgoTernaryFma3VecVecVec
;
class
AlgoTernaryFma3VecVecScalar
;
class
AlgoTernaryFma3VecVecScalar
;
class
AlgoTernaryFma3Bcast101VecBcast101
;
class
AlgoTernaryFma3Bcast101VecBcast101
;
class
AlgoTernaryFma3Bcast101x
4VecBcast101x4
;
class
AlgoTernaryFma3Bcast101x
XVecBcast101xX
;
class
AlgoTernaryFma3VecBcast101Vec
;
class
AlgoTernaryFma3VecBcast101Vec
;
class
AlgoTernaryFma3VecBcast101x
4
Vec
;
class
AlgoTernaryFma3VecBcast101x
X
Vec
;
class
AlgoTernaryFma3VecScalarVec
;
class
AlgoTernaryFma3VecScalarVec
;
class
AlgoTernaryFma3VecScalarScalar
;
class
AlgoTernaryFma3VecScalarScalar
;
class
AlgoPack
;
class
AlgoPack
;
...
...
dnn/src/arm_common/elemwise/ternary/algo.cpp
浏览文件 @
3344b580
...
@@ -42,9 +42,9 @@ using namespace arm_common;
...
@@ -42,9 +42,9 @@ using namespace arm_common;
DECL_AVAILABLE
(
VecVecVec
,
BcastType
::
VEC_VEC_VEC
);
DECL_AVAILABLE
(
VecVecVec
,
BcastType
::
VEC_VEC_VEC
);
DECL_AVAILABLE
(
VecVecScalar
,
BcastType
::
VEC_VEC_SCALAR
);
DECL_AVAILABLE
(
VecVecScalar
,
BcastType
::
VEC_VEC_SCALAR
);
DECL_AVAILABLE
(
Bcast101VecBcast101
,
BcastType
::
BCAST101_VEC_BCAST101
);
DECL_AVAILABLE
(
Bcast101VecBcast101
,
BcastType
::
BCAST101_VEC_BCAST101
);
DECL_AVAILABLE
(
Bcast101x
4VecBcast101x4
,
BcastType
::
BCAST101x4_VEC_BCAST101x4
);
DECL_AVAILABLE
(
Bcast101x
XVecBcast101xX
,
BcastType
::
BCAST101xX_VEC_BCAST101xX
);
DECL_AVAILABLE
(
VecBcast101Vec
,
BcastType
::
VEC_BCAST101_VEC
);
DECL_AVAILABLE
(
VecBcast101Vec
,
BcastType
::
VEC_BCAST101_VEC
);
DECL_AVAILABLE
(
VecBcast101x
4Vec
,
BcastType
::
VEC_BCAST101x4
_VEC
);
DECL_AVAILABLE
(
VecBcast101x
XVec
,
BcastType
::
VEC_BCAST101xX
_VEC
);
DECL_AVAILABLE
(
VecScalarVec
,
BcastType
::
VEC_SCALAR_VEC
);
DECL_AVAILABLE
(
VecScalarVec
,
BcastType
::
VEC_SCALAR_VEC
);
DECL_AVAILABLE
(
VecScalarScalar
,
BcastType
::
VEC_SCALAR_SCALAR
);
DECL_AVAILABLE
(
VecScalarScalar
,
BcastType
::
VEC_SCALAR_SCALAR
);
#undef DECL_CB
#undef DECL_CB
...
@@ -161,13 +161,15 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec(
...
@@ -161,13 +161,15 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec(
return
;
return
;
}
}
void
ElemwiseImpl
::
AlgoTernaryFma3Bcast101x
4VecBcast101x4
::
exec
(
void
ElemwiseImpl
::
AlgoTernaryFma3Bcast101x
XVecBcast101xX
::
exec
(
const
KernParam
&
kern_param
)
const
{
const
KernParam
&
kern_param
)
const
{
auto
&
elparam
=
kern_param
.
ternary_elparam
;
auto
&
elparam
=
kern_param
.
ternary_elparam
;
auto
&
src0
=
elparam
[
0
],
&
src1
=
elparam
[
1
],
&
src2
=
elparam
[
2
];
auto
&
src0
=
elparam
[
0
],
&
src1
=
elparam
[
1
],
&
src2
=
elparam
[
2
];
BroadcastChannelInfo
binfo
;
BroadcastChannelInfo
binfo
;
is_broadcastedx_channel_like
<
4
>
(
src0
.
layout
,
binfo
);
megdnn_assert
(
is_broadcastedx_channel_like
<
4
>
(
src0
.
layout
,
binfo
)
||
is_broadcastedx_channel_like
<
8
>
(
src0
.
layout
,
binfo
),
"only nchw44 and nchw88 supported"
);
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
case Mode::_mode: \
MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \
MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \
...
@@ -177,7 +179,7 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101x4VecBcast101x4::exec(
...
@@ -177,7 +179,7 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101x4VecBcast101x4::exec(
size_t, size_t, size_t)> \
size_t, size_t, size_t)> \
run = OpCallerTernary< \
run = OpCallerTernary< \
_op<_type, _type>, \
_op<_type, _type>, \
BcastType::BCAST101x
4_VEC_BCAST101x4
>::run; \
BcastType::BCAST101x
X_VEC_BCAST101xX
>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr), \
run(static_cast<const _type*>(src0.raw_ptr), \
...
@@ -193,19 +195,21 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101x4VecBcast101x4::exec(
...
@@ -193,19 +195,21 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101x4VecBcast101x4::exec(
size_t
batch_size
=
src1
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
size_t
batch_size
=
src1
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
auto
&&
dst
=
*
(
kern_param
.
m_dst
);
auto
&&
dst
=
*
(
kern_param
.
m_dst
);
DISPATCH_TYPE
(
"AlgoTernaryFma3Bcast101x
4VecBcast101x4
::exec"
_hash
);
DISPATCH_TYPE
(
"AlgoTernaryFma3Bcast101x
XVecBcast101xX
::exec"
_hash
);
#undef DISPATCH_TERNARY
#undef DISPATCH_TERNARY
return
;
return
;
}
}
void
ElemwiseImpl
::
AlgoTernaryFma3VecBcast101x
4
Vec
::
exec
(
void
ElemwiseImpl
::
AlgoTernaryFma3VecBcast101x
X
Vec
::
exec
(
const
KernParam
&
kern_param
)
const
{
const
KernParam
&
kern_param
)
const
{
auto
&
elparam
=
kern_param
.
ternary_elparam
;
auto
&
elparam
=
kern_param
.
ternary_elparam
;
auto
&
src0
=
elparam
[
0
],
&
src1
=
elparam
[
1
],
&
src2
=
elparam
[
2
];
auto
&
src0
=
elparam
[
0
],
&
src1
=
elparam
[
1
],
&
src2
=
elparam
[
2
];
BroadcastChannelInfo
binfo
;
BroadcastChannelInfo
binfo
;
is_broadcastedx_channel_like
<
4
>
(
src1
.
layout
,
binfo
);
megdnn_assert
(
is_broadcastedx_channel_like
<
4
>
(
src1
.
layout
,
binfo
)
||
is_broadcastedx_channel_like
<
8
>
(
src1
.
layout
,
binfo
),
"only nchw44 and nchw88 supported"
);
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
case Mode::_mode: \
MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \
MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \
...
@@ -214,7 +218,7 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101x4Vec::exec(
...
@@ -214,7 +218,7 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101x4Vec::exec(
_type*, DType, DType, DType, DType, size_t, \
_type*, DType, DType, DType, DType, size_t, \
size_t, size_t, size_t)> \
size_t, size_t, size_t)> \
run = OpCallerTernary<_op<_type, _type>, \
run = OpCallerTernary<_op<_type, _type>, \
BcastType::VEC_BCAST101x
4
_VEC>::run; \
BcastType::VEC_BCAST101x
X
_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr), \
run(static_cast<const _type*>(src0.raw_ptr), \
...
@@ -230,7 +234,7 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101x4Vec::exec(
...
@@ -230,7 +234,7 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101x4Vec::exec(
size_t
batch_size
=
src0
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
size_t
batch_size
=
src0
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
auto
&&
dst
=
*
(
kern_param
.
m_dst
);
auto
&&
dst
=
*
(
kern_param
.
m_dst
);
DISPATCH_TYPE
(
"AlgoTernaryFma3VecBcast101x
4
Vec::exec"
_hash
);
DISPATCH_TYPE
(
"AlgoTernaryFma3VecBcast101x
X
Vec::exec"
_hash
);
#undef DISPATCH_TERNARY
#undef DISPATCH_TERNARY
return
;
return
;
...
...
dnn/src/arm_common/elemwise/ternary/algo.h
浏览文件 @
3344b580
...
@@ -34,9 +34,9 @@ namespace arm_common {
...
@@ -34,9 +34,9 @@ namespace arm_common {
DECL_CB
(
VecVecVec
);
DECL_CB
(
VecVecVec
);
DECL_CB
(
VecVecScalar
);
DECL_CB
(
VecVecScalar
);
DECL_CB
(
Bcast101VecBcast101
);
DECL_CB
(
Bcast101VecBcast101
);
DECL_CB
(
Bcast101x
4VecBcast101x4
);
DECL_CB
(
Bcast101x
XVecBcast101xX
);
DECL_CB
(
VecBcast101Vec
);
DECL_CB
(
VecBcast101Vec
);
DECL_CB
(
VecBcast101x
4
Vec
);
DECL_CB
(
VecBcast101x
X
Vec
);
DECL_CB
(
VecScalarVec
);
DECL_CB
(
VecScalarVec
);
DECL_CB
(
VecScalarScalar
);
DECL_CB
(
VecScalarScalar
);
#undef DECL_CB
#undef DECL_CB
...
...
dnn/src/arm_common/elemwise_multi_type/opr_impl.cpp
浏览文件 @
3344b580
...
@@ -644,7 +644,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param,
...
@@ -644,7 +644,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param,
{
{
BroadcastChannelInfo
binfo
;
BroadcastChannelInfo
binfo
;
if
(
is_vector
(
src0
.
layout
)
&&
if
(
is_vector
(
src0
.
layout
)
&&
is_broadcastedx_channel_like
<
4
>
(
src1
.
layout
,
binfo
))
{
(
is_broadcastedx_channel_like
<
4
>
(
src1
.
layout
,
binfo
)
||
is_broadcastedx_channel_like
<
8
>
(
src1
.
layout
,
binfo
)))
{
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \
case _mode: { \
case _mode: { \
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \
...
@@ -653,14 +654,13 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param,
...
@@ -653,14 +654,13 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param,
DType, DType, DType, size_t, size_t, size_t, \
DType, DType, DType, size_t, size_t, size_t, \
size_t)> \
size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, \
VEC_BCAST101x
4
>::run; \
VEC_BCAST101x
X
>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
return; \
}
}
size_t
batch_size
=
size_t
batch_size
=
src0
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
src0
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
DISPATCH
()
DISPATCH
()
...
@@ -679,14 +679,13 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param,
...
@@ -679,14 +679,13 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param,
DType, DType, DType, size_t, size_t, size_t, \
DType, DType, DType, size_t, size_t, size_t, \
size_t)> \
size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, \
BCAST101x
4
_VEC>::run; \
BCAST101x
X
_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \
return; \
}
}
size_t
batch_size
=
size_t
batch_size
=
src1
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
src1
.
layout
.
shape
[
0
]
/
(
binfo
.
x
*
binfo
.
y
*
binfo
.
z
);
DISPATCH
()
DISPATCH
()
...
@@ -818,7 +817,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param,
...
@@ -818,7 +817,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param,
{
{
BroadcastChannelInfo
binfo
;
BroadcastChannelInfo
binfo
;
if
(
is_vector
(
src0
.
layout
)
&&
if
(
is_vector
(
src0
.
layout
)
&&
is_broadcastedx_channel_like
<
4
>
(
src1
.
layout
,
binfo
)
&&
(
is_broadcastedx_channel_like
<
4
>
(
src1
.
layout
,
binfo
)
||
is_broadcastedx_channel_like
<
8
>
(
src1
.
layout
,
binfo
))
&&
src0
.
layout
.
eq_shape
(
src2
.
layout
))
{
src0
.
layout
.
eq_shape
(
src2
.
layout
))
{
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \
case _mode: { \
case _mode: { \
...
@@ -828,7 +828,7 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param,
...
@@ -828,7 +828,7 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param,
const src_ctype*, dst_ctype*, DType, DType, DType, \
const src_ctype*, dst_ctype*, DType, DType, DType, \
DType, size_t, size_t, size_t, size_t)> \
DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, \
VEC_BCAST101x
4
_VEC>::run; \
VEC_BCAST101x
X
_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
...
@@ -846,7 +846,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param,
...
@@ -846,7 +846,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param,
//! BCAST101x + VEC +BCAST101x
//! BCAST101x + VEC +BCAST101x
if
(
is_vector
(
src1
.
layout
)
&&
if
(
is_vector
(
src1
.
layout
)
&&
is_broadcastedx_channel_like
<
4
>
(
src0
.
layout
,
binfo
)
&&
(
is_broadcastedx_channel_like
<
4
>
(
src0
.
layout
,
binfo
)
||
is_broadcastedx_channel_like
<
8
>
(
src0
.
layout
,
binfo
))
&&
src0
.
layout
.
eq_shape
(
src2
.
layout
))
{
src0
.
layout
.
eq_shape
(
src2
.
layout
))
{
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \
case _mode: { \
case _mode: { \
...
@@ -856,7 +857,7 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param,
...
@@ -856,7 +857,7 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param,
const src_ctype*, dst_ctype*, DType, DType, DType, \
const src_ctype*, dst_ctype*, DType, DType, DType, \
DType, size_t, size_t, size_t, size_t)> \
DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, \
BCAST101x
4_VEC_BCAST101x4
>::run; \
BCAST101x
X_VEC_BCAST101xX
>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
...
...
dnn/src/arm_common/elemwise_op.h
浏览文件 @
3344b580
此差异已折叠。
点击以展开。
dnn/test/arm_common/elemwise.cpp
浏览文件 @
3344b580
...
@@ -53,6 +53,20 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY) {
...
@@ -53,6 +53,20 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY) {
checker
.
execs
({{
3
,
4
,
5
,
7
,
4
},
{
3
,
4
,
5
,
7
,
4
},
{
3
,
4
,
5
,
7
,
4
},
{}});
checker
.
execs
({{
3
,
4
,
5
,
7
,
4
},
{
3
,
4
,
5
,
7
,
4
},
{
3
,
4
,
5
,
7
,
4
},
{}});
checker
.
execs
({{
1
,
2
,
5
,
7
,
4
},
{
1
,
2
,
1
,
1
,
4
},
{
1
,
2
,
5
,
7
,
4
},
{}});
checker
.
execs
({{
1
,
2
,
5
,
7
,
4
},
{
1
,
2
,
1
,
1
,
4
},
{
1
,
2
,
5
,
7
,
4
},
{}});
//! nchw88
checker
.
execs
({{
1
,
3
,
1
,
1
,
8
},
{
1
,
3
,
2
,
2
,
8
},
{
1
,
3
,
1
,
1
,
8
},
{}});
checker
.
execs
({{
1
,
3
,
1
,
1
,
8
},
{
2
,
3
,
2
,
2
,
8
},
{
1
,
3
,
1
,
1
,
8
},
{}});
checker
.
execs
({{
1
,
8
,
1
,
1
,
8
},
{
3
,
8
,
5
,
3
,
8
},
{
1
,
8
,
1
,
1
,
8
},
{}});
checker
.
execs
({{
3
,
4
,
5
,
7
,
8
},
{
3
,
4
,
5
,
7
,
8
},
{
3
,
4
,
5
,
7
,
8
},
{}});
checker
.
execs
({{
1
,
2
,
1
,
1
,
8
},
{
1
,
2
,
5
,
7
,
8
},
{
1
,
2
,
1
,
1
,
8
},
{}});
//! nchw88
checker
.
execs
({{
1
,
3
,
2
,
2
,
8
},
{
1
,
3
,
1
,
1
,
8
},
{
1
,
3
,
2
,
2
,
8
},
{}});
checker
.
execs
({{
2
,
3
,
2
,
2
,
8
},
{
1
,
3
,
1
,
1
,
8
},
{
2
,
3
,
2
,
2
,
8
},
{}});
checker
.
execs
({{
3
,
8
,
5
,
3
,
8
},
{
1
,
8
,
1
,
1
,
8
},
{
3
,
8
,
5
,
3
,
8
},
{}});
checker
.
execs
({{
3
,
4
,
5
,
7
,
8
},
{
3
,
4
,
5
,
7
,
8
},
{
3
,
4
,
5
,
7
,
8
},
{}});
checker
.
execs
({{
1
,
2
,
5
,
7
,
8
},
{
1
,
2
,
1
,
1
,
8
},
{
1
,
2
,
5
,
7
,
8
},
{}});
checker
.
execs
({{
3
,
4
,
7
},
{
3
,
4
,
7
},
{
3
,
4
,
7
},
{}});
checker
.
execs
({{
3
,
4
,
7
},
{
3
,
4
,
7
},
{
3
,
4
,
7
},
{}});
checker
.
execs
({{
1
,
4
,
1
,
1
},
{
3
,
4
,
5
,
7
},
{
1
,
4
,
1
,
1
},
{}});
checker
.
execs
({{
1
,
4
,
1
,
1
},
{
3
,
4
,
5
,
7
},
{
1
,
4
,
1
,
1
},
{}});
checker
.
execs
({{
1
,
4
,
1
},
{
3
,
4
,
7
},
{
1
,
4
,
1
},
{}});
checker
.
execs
({{
1
,
4
,
1
},
{
3
,
4
,
7
},
{
1
,
4
,
1
},
{}});
...
@@ -227,6 +241,78 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_FP32) {
...
@@ -227,6 +241,78 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_FP32) {
run
(
Mode
::
POW
);
run
(
Mode
::
POW
);
}
}
TEST_F
(
ARM_COMMON
,
ELEMWISE_FORWARD_NCHW88_FP
)
{
using
Mode
=
ElemwiseForward
::
Param
::
Mode
;
Checker
<
ElemwiseForward
>
checker
(
handle
());
checker
.
set_param
(
Mode
::
FUSE_ADD_RELU
)
.
execs
({{
1
,
3
,
1
,
1
,
8
},
{
1
,
3
,
2
,
2
,
8
},
{}});
checker
.
set_param
(
Mode
::
FUSE_ADD_RELU
)
.
execs
({{
1
,
3
,
1
,
1
,
8
},
{
2
,
3
,
2
,
2
,
8
},
{}});
checker
.
set_param
(
Mode
::
FUSE_ADD_RELU
)
.
execs
({{
1
,
8
,
1
,
1
,
8
},
{
3
,
8
,
5
,
3
,
8
},
{}});
checker
.
set_param
(
Mode
::
FUSE_ADD_RELU
)
.
execs
({{
3
,
4
,
5
,
7
,
8
},
{
3
,
4
,
5
,
7
,
8
},
{}});
checker
.
set_param
(
Mode
::
FUSE_ADD_RELU
)
.
execs
({{
1
,
2
,
1
,
1
,
8
},
{
1
,
2
,
5
,
7
,
8
},
{}});
checker
.
set_param
(
Mode
::
FUSE_ADD_RELU
)
.
execs
({{
1
,
3
,
2
,
2
,
8
},
{
1
,
3
,
1
,
1
,
8
},
{}});
checker
.
set_param
(
Mode
::
FUSE_ADD_RELU
)
.
execs
({{
2
,
3
,
2
,
2
,
8
},
{
1
,
3
,
1
,
1
,
8
},
{}});
checker
.
set_param
(
Mode
::
FUSE_ADD_RELU
)
.
execs
({{
3
,
8
,
5
,
3
,
8
},
{
1
,
8
,
1
,
1
,
8
},
{}});
checker
.
set_param
(
Mode
::
FUSE_ADD_RELU
)
.
execs
({{
3
,
4
,
5
,
7
,
8
},
{
3
,
4
,
5
,
7
,
8
},
{}});
checker
.
set_param
(
Mode
::
FUSE_ADD_RELU
)
.
execs
({{
1
,
2
,
5
,
7
,
8
},
{
1
,
2
,
1
,
1
,
8
},
{}});
auto
run
=
[
&
](
Mode
mode
)
{
// VEC_BCAST101x
checker
.
set_param
(
mode
).
execs
({{
1
,
3
,
2
,
2
,
8
},
{
1
,
3
,
1
,
1
,
8
},
{}});
checker
.
set_param
(
mode
).
execs
({{
2
,
3
,
2
,
2
,
8
},
{
1
,
3
,
1
,
1
,
8
},
{}});
checker
.
set_param
(
mode
).
execs
({{
3
,
8
,
5
,
3
,
8
},
{
1
,
8
,
1
,
1
,
8
},
{}});
checker
.
set_param
(
mode
).
execs
({{
3
,
4
,
5
,
7
,
8
},
{
3
,
4
,
5
,
7
,
8
},
{}});
checker
.
set_param
(
mode
).
execs
({{
1
,
2
,
5
,
7
,
8
},
{
1
,
2
,
1
,
1
,
8
},
{}});
// BCAST101x_VEC not powOp
checker
.
set_param
(
mode
).
execs
({{
1
,
3
,
1
,
1
,
8
},
{
1
,
3
,
2
,
2
,
8
},
{}});
checker
.
set_param
(
mode
).
execs
({{
1
,
3
,
1
,
1
,
8
},
{
2
,
3
,
2
,
2
,
8
},
{}});
checker
.
set_param
(
mode
).
execs
({{
1
,
8
,
1
,
1
,
8
},
{
3
,
8
,
5
,
3
,
8
},
{}});
checker
.
set_param
(
mode
).
execs
({{
3
,
4
,
5
,
7
,
8
},
{
3
,
4
,
5
,
7
,
8
},
{}});
checker
.
set_param
(
mode
).
execs
({{
1
,
2
,
1
,
1
,
8
},
{
1
,
2
,
5
,
7
,
8
},
{}});
};
auto
run_all
=
[
&
]()
{
run
(
Mode
::
ADD
);
run
(
Mode
::
FUSE_ADD_H_SWISH
);
run
(
Mode
::
FUSE_ADD_RELU
);
run
(
Mode
::
MAX
);
run
(
Mode
::
MIN
);
run
(
Mode
::
MUL
);
run
(
Mode
::
SUB
);
run
(
Mode
::
TRUE_DIV
);
run
(
Mode
::
POW
);
};
{
UniformFloatRNG
rng
(
1e-5
,
7e1
);
checker
.
set_rng
(
0
,
&
rng
);
checker
.
set_epsilon
(
1e-5
);
checker
.
set_dtype
(
0
,
dtype
::
Float32
());
checker
.
set_dtype
(
1
,
dtype
::
Float32
());
run_all
();
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
{
UniformFloatRNG
rng
(
1
,
2
);
checker
.
set_rng
(
0
,
&
rng
);
checker
.
set_epsilon
(
3e-3
);
checker
.
set_dtype
(
0
,
dtype
::
Float16
());
checker
.
set_dtype
(
1
,
dtype
::
Float16
());
run_all
();
}
#endif
}
#if MEGDNN_WITH_BENCHMARK
#if MEGDNN_WITH_BENCHMARK
namespace
{
namespace
{
void
run_elemwise_benchmark
(
const
TensorShapeArray
&
shapes
,
void
run_elemwise_benchmark
(
const
TensorShapeArray
&
shapes
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录