Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
c48d58da
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
c48d58da
编写于
10月 28, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/arm_common): add N1HW like elemwise broadcast mode
GitOrigin-RevId: 28951358012c2d085f68260fd723797f943138ca
上级
669c3cda
变更
8
显示空白变更内容
内联
并排
Showing
8 changed files
with
277 additions
and
0 deletions
+277
-0
dnn/src/arm_common/elemwise/binary/algo.cpp
dnn/src/arm_common/elemwise/binary/algo.cpp
+81
-0
dnn/src/arm_common/elemwise/binary/algo.h
dnn/src/arm_common/elemwise/binary/algo.h
+1
-0
dnn/src/arm_common/elemwise/opr_impl.cpp
dnn/src/arm_common/elemwise/opr_impl.cpp
+12
-0
dnn/src/arm_common/elemwise/opr_impl.h
dnn/src/arm_common/elemwise/opr_impl.h
+1
-0
dnn/src/arm_common/elemwise_op.h
dnn/src/arm_common/elemwise_op.h
+136
-0
dnn/src/common/elemwise/opr_impl_helper.cpp
dnn/src/common/elemwise/opr_impl_helper.cpp
+14
-0
dnn/src/common/elemwise/opr_impl_helper.h
dnn/src/common/elemwise/opr_impl_helper.h
+8
-0
dnn/test/arm_common/elemwise.cpp
dnn/test/arm_common/elemwise.cpp
+24
-0
未找到文件。
dnn/src/arm_common/elemwise/binary/algo.cpp
浏览文件 @
c48d58da
...
...
@@ -104,6 +104,21 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available(
return
false
;
}
bool
ElemwiseImpl
::
AlgoBinaryVecBcastX0X
::
is_available
(
const
KernParam
&
kern_param
)
const
{
if
(
!
is_available_common
(
kern_param
.
mode
)
||
((
BcastType
::
VEC_BCASTX0X
!=
kern_param
.
broad_cast_type
)
&&
(
BcastType
::
BCASTX0X_VEC
!=
kern_param
.
broad_cast_type
)))
return
false
;
auto
&
elparam
=
kern_param
.
binary_elparam
;
auto
&
src0
=
elparam
[
0
];
DISPATCH_TYPE
(
"AlgoBinaryVecBcastX0X::is_available"
_hash
);
return
false
;
}
bool
ElemwiseImpl
::
AlgoBinaryVecBcast111C
::
is_available
(
const
KernParam
&
kern_param
)
const
{
if
(
!
is_available_common
(
kern_param
.
mode
)
||
...
...
@@ -348,6 +363,72 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec(const KernParam& kern_param) cons
return
;
}
void
ElemwiseImpl
::
AlgoBinaryVecBcastX0X
::
exec
(
const
KernParam
&
kern_param
)
const
{
auto
&
elparam
=
kern_param
.
binary_elparam
;
auto
&
src0
=
elparam
[
0
],
&
src1
=
elparam
[
1
];
auto
&&
dst
=
*
(
kern_param
.
m_dst
);
BroadcastChannelInfo
binfo
;
// Case: BcastType::VEC + BCAST_X0X
if
(
BcastType
::
VEC_BCASTX0X
==
kern_param
.
broad_cast_type
&&
is_broadcasted_3dim_like
(
src1
.
layout
,
binfo
))
{
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_binary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, _type*, DType, DType, DType, size_t, \
size_t, size_t)> \
run = OpCallerBinary< \
_op<_type, _type>, BcastType::VEC_BCASTX0X>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr), \
static_cast<const _type*>(src1.raw_ptr), \
static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
binfo.z)); \
} \
MIDOUT_END(); \
return
DISPATCH_TYPE
(
"AlgoBinaryVecBcastX0X::exec_vec_b"
_hash
);
#undef DISPATCH_BINARY
}
// BCAST_X0X + BcastType::VEC
if
(
BcastType
::
BCASTX0X_VEC
==
kern_param
.
broad_cast_type
&&
is_broadcasted_3dim_like
(
src0
.
layout
,
binfo
))
{
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \
MIDOUT_BEGIN( \
megdnn_arm_common_elemwise_binary, midout_iv(_case), \
midout_iv(Mode::_mode), _type_midout_id) { \
thin_function<void( \
const _type*, const _type*, _type*, DType, DType, DType, size_t, \
size_t, size_t)> \
run = OpCallerBinary< \
_op<_type, _type>, BcastType::BCASTX0X_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr), \
static_cast<const _type*>(src1.raw_ptr), \
static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
binfo.z)); \
} \
MIDOUT_END(); \
return
DISPATCH_TYPE
(
"AlgoBinaryVecBcastX0X::exec_b_vec"
_hash
);
#undef DISPATCH_BINARY
}
return
;
}
void
ElemwiseImpl
::
AlgoBinaryVecBcast111C
::
exec
(
const
KernParam
&
kern_param
)
const
{
auto
&
elparam
=
kern_param
.
binary_elparam
;
auto
&
src0
=
elparam
[
0
],
&
src1
=
elparam
[
1
];
...
...
dnn/src/arm_common/elemwise/binary/algo.h
浏览文件 @
c48d58da
...
...
@@ -33,6 +33,7 @@ namespace arm_common {
DECL_CB
(
VecVec
);
DECL_CB
(
VecScalar
);
DECL_CB
(
VecBcast101
);
DECL_CB
(
VecBcastX0X
);
DECL_CB
(
VecBcast111C
);
DECL_CB
(
VecBcast101xX
);
#undef DECL_CB
...
...
dnn/src/arm_common/elemwise/opr_impl.cpp
浏览文件 @
c48d58da
...
...
@@ -27,6 +27,7 @@ class ElemwiseImpl::AlgoPack {
AlgoBinaryVecVec
algo_binary_vec_vec
;
AlgoBinaryVecScalar
algo_binary_vec_sca
;
AlgoBinaryVecBcast101
algo_binary_vec_bcast101
;
AlgoBinaryVecBcastX0X
algo_binary_vec_bcastX0X
;
AlgoBinaryVecBcast111C
algo_binary_vec_bcast110
;
AlgoBinaryVecBcast101xX
algo_binary_VEC_BCAST101xX
;
AlgoTernaryFma3VecVecVec
algo_ternaryfma3_vec_vec_vec
;
...
...
@@ -46,6 +47,7 @@ public:
all_algos
.
emplace_back
(
&
algo_binary_vec_vec
);
all_algos
.
emplace_back
(
&
algo_binary_vec_sca
);
all_algos
.
emplace_back
(
&
algo_binary_vec_bcast101
);
all_algos
.
emplace_back
(
&
algo_binary_vec_bcastX0X
);
all_algos
.
emplace_back
(
&
algo_binary_vec_bcast110
);
all_algos
.
emplace_back
(
&
algo_binary_VEC_BCAST101xX
);
all_algos
.
emplace_back
(
&
algo_ternaryfma3_vec_vec_vec
);
...
...
@@ -202,6 +204,16 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
return
kern_param
;
}
if
(
is_vector
(
src0
.
layout
)
&&
is_broadcasted_3dim_like
(
src1
.
layout
,
binfo
))
{
kern_param
.
broad_cast_type
=
BcastType
::
VEC_BCASTX0X
;
return
kern_param
;
}
if
(
is_vector
(
src1
.
layout
)
&&
is_broadcasted_3dim_like
(
src0
.
layout
,
binfo
))
{
kern_param
.
broad_cast_type
=
BcastType
::
BCASTX0X_VEC
;
return
kern_param
;
}
if
(
is_legal_layout_for_nhwc
(
src1
.
layout
)
&&
is_NHWC_broadcasted_channel_like
(
src0
.
layout
,
binfo
))
{
kern_param
.
broad_cast_type
=
BcastType
::
BCAST111C_VEC
;
...
...
dnn/src/arm_common/elemwise/opr_impl.h
浏览文件 @
c48d58da
...
...
@@ -38,6 +38,7 @@ private:
class
AlgoBinaryVecVec
;
class
AlgoBinaryVecScalar
;
class
AlgoBinaryVecBcast101
;
class
AlgoBinaryVecBcastX0X
;
class
AlgoBinaryVecBcast111C
;
class
AlgoBinaryVecBcast101xX
;
class
AlgoTernaryFma3VecVecVec
;
...
...
dnn/src/arm_common/elemwise_op.h
浏览文件 @
c48d58da
...
...
@@ -107,11 +107,13 @@ enum BcastType {
VEC
,
VEC_VEC
,
VEC_BCAST101
,
VEC_BCASTX0X
,
VEC_BCAST111C
,
VEC_BCAST101xX
,
VEC_SCALAR
,
SCALAR_VEC
,
BCAST101_VEC
,
BCASTX0X_VEC
,
BCAST111C_VEC
,
BCAST101xX_VEC
,
VEC_VEC_VEC
,
...
...
@@ -230,6 +232,34 @@ struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST101> {
}
};
/*!
 * \brief PowOp caller for a dense src0 combined with an X0X-broadcast src1.
 *
 * src0 and dst are walked as batch * channel * channel_stride contiguous
 * elements, while src1 supplies batch * channel_stride elements and each of
 * its rows is reused for every channel of the same batch. Only the
 * element-by-element form of the op is used in this specialization.
 */
template <typename ctype>
struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCASTX0X> {
    using Op = PowOp<ctype, ctype>;

    static void run(
            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
            DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
        Op op(src0_dtype, src1_dtype, dst_dtype);
        for (size_t b = 0; b < batch; ++b) {
            // Row of the broadcast operand shared by all channels of batch b.
            const typename Op::src_ctype* const bcast_row = src1 + b * channel_stride;
            for (size_t c = 0; c < channel; ++c) {
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
                for (size_t i = 0; i < channel_stride; ++i) {
                    op(src0[i], bcast_row[i], dst + i);
                }
                // The dense operand and the output move on to the next channel.
                src0 += channel_stride;
                dst += channel_stride;
            }
        }
    }
};
template
<
typename
ctype
>
struct
OpCallerBinary
<
PowOp
<
ctype
,
ctype
>
,
VEC_BCAST111C
>
{
using
Op
=
PowOp
<
ctype
,
ctype
>
;
...
...
@@ -332,6 +362,34 @@ struct OpCallerBinary<PowOp<ctype, ctype>, BCAST101_VEC> {
}
};
/*!
 * \brief PowOp caller for an X0X-broadcast src0 combined with a dense src1.
 *
 * src1 and dst are walked as batch * channel * channel_stride contiguous
 * elements, while src0 supplies batch * channel_stride elements and each of
 * its rows is reused for every channel of the same batch. Only the
 * element-by-element form of the op is used in this specialization.
 */
template <typename ctype>
struct OpCallerBinary<PowOp<ctype, ctype>, BCASTX0X_VEC> {
    using Op = PowOp<ctype, ctype>;

    static void run(
            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
            DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
        Op op(src0_dtype, src1_dtype, dst_dtype);
        for (size_t b = 0; b < batch; ++b) {
            // Row of the broadcast operand shared by all channels of batch b.
            const typename Op::src_ctype* const bcast_row = src0 + b * channel_stride;
            for (size_t c = 0; c < channel; ++c) {
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
                for (size_t i = 0; i < channel_stride; ++i) {
                    op(bcast_row[i], src1[i], dst + i);
                }
                // The dense operand and the output move on to the next channel.
                src1 += channel_stride;
                dst += channel_stride;
            }
        }
    }
};
template
<
typename
Op
>
struct
OpCallerBinary
<
Op
,
VEC_VEC
>
{
static
void
run
(
...
...
@@ -398,6 +456,45 @@ struct OpCallerBinary<Op, VEC_BCAST101> {
}
};
/*!
 * \brief Generic binary caller for a dense src0 combined with an
 *      X0X-broadcast src1.
 *
 * src0 and dst are walked as batch * channel * channel_stride contiguous
 * elements; src1 supplies batch * channel_stride elements and each of its rows
 * is reused for every channel of the same batch. Each row is processed in
 * chunks of 2 * Op::SIMD_WIDTH elements through the vector form of the op,
 * and the remainder goes through the element-by-element form.
 */
template <typename Op>
struct OpCallerBinary<Op, VEC_BCASTX0X> {
    static void run(
            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
            DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
        Op op(src0_dtype, src1_dtype, dst_dtype);
        ParamElemVisitor<typename Op::src_ctype> vis;
        // Two vectors are consumed per SIMD iteration.
        const size_t step = Op::SIMD_WIDTH * 2;

        for (size_t b = 0; b < batch; ++b) {
            // Row of the broadcast operand shared by all channels of batch b.
            const typename Op::src_ctype* const bcast_row = src1 + b * channel_stride;
            for (size_t c = 0; c < channel; ++c) {
                size_t i = 0;
                for (; i + step <= channel_stride; i += step) {
                    auto dense_lo = vis(src0 + i);
                    auto dense_hi = vis(src0 + i + Op::SIMD_WIDTH);
                    auto bcast_lo = vis(bcast_row + i);
                    auto bcast_hi = vis(bcast_row + i + Op::SIMD_WIDTH);
                    op({{dense_lo, dense_hi}}, {{bcast_lo, bcast_hi}}, dst + i);
                }
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
                for (; i < channel_stride; ++i) {
                    op(src0[i], bcast_row[i], dst + i);
                }
                // The dense operand and the output move on to the next channel.
                src0 += channel_stride;
                dst += channel_stride;
            }
        }
    }
};
template
<
typename
Op
>
struct
OpCallerBinary
<
Op
,
VEC_BCAST111C
>
{
static
void
run
(
...
...
@@ -844,6 +941,45 @@ struct OpCallerBinary<Op, BCAST101_VEC> {
}
};
/*!
 * \brief Generic binary caller for an X0X-broadcast src0 combined with a
 *      dense src1.
 *
 * src1 and dst are walked as batch * channel * channel_stride contiguous
 * elements; src0 supplies batch * channel_stride elements and each of its rows
 * is reused for every channel of the same batch. Each row is processed in
 * chunks of 2 * Op::SIMD_WIDTH elements through the vector form of the op,
 * and the remainder goes through the element-by-element form.
 */
template <typename Op>
struct OpCallerBinary<Op, BCASTX0X_VEC> {
    static void run(
            const typename Op::src_ctype* src0, const typename Op::src_ctype* src1,
            typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype,
            DType dst_dtype, size_t batch, size_t channel, size_t channel_stride) {
        Op op(src0_dtype, src1_dtype, dst_dtype);
        ParamElemVisitor<typename Op::src_ctype> vis;
        // Two vectors are consumed per SIMD iteration.
        const size_t step = Op::SIMD_WIDTH * 2;

        for (size_t b = 0; b < batch; ++b) {
            // Row of the broadcast operand shared by all channels of batch b.
            const typename Op::src_ctype* const bcast_row = src0 + b * channel_stride;
            for (size_t c = 0; c < channel; ++c) {
                size_t i = 0;
                for (; i + step <= channel_stride; i += step) {
                    auto bcast_lo = vis(bcast_row + i);
                    auto bcast_hi = vis(bcast_row + i + Op::SIMD_WIDTH);
                    auto dense_lo = vis(src1 + i);
                    auto dense_hi = vis(src1 + i + Op::SIMD_WIDTH);
                    op({{bcast_lo, bcast_hi}}, {{dense_lo, dense_hi}}, dst + i);
                }
#if MEGDNN_FIX_AARCH32_BUG
// FIXME: as llvm may cause cannot select error if enable vectorize
#pragma clang loop vectorize(disable)
#endif
                for (; i < channel_stride; ++i) {
                    op(bcast_row[i], src1[i], dst + i);
                }
                // The dense operand and the output move on to the next channel.
                src1 += channel_stride;
                dst += channel_stride;
            }
        }
    }
};
// Forward declaration of the three-operand caller template, parameterized on
// the op type and on a BcastType value; no definition is visible in this
// excerpt.
template
<
typename
Op
,
BcastType
bcast_type
>
struct
OpCallerTernary
;
...
...
dnn/src/common/elemwise/opr_impl_helper.cpp
浏览文件 @
c48d58da
...
...
@@ -150,6 +150,20 @@ bool ElemwiseLayoutHelper::is_broadcasted_channel_like(
return
false
;
}
/*!
 * \brief Detect a 3-dim layout that is broadcast along its middle axis.
 *
 * A match requires a DEFAULT-format layout with ndim == 3, stride[0] equal to
 * shape[2], stride[1] == 0 and stride[2] == 1. On a match, info.x / info.y /
 * info.z receive shape[0] / shape[1] / shape[2]; otherwise \p info is left
 * untouched.
 *
 * \return true if \p layout matches the pattern described above
 */
bool ElemwiseLayoutHelper::is_broadcasted_3dim_like(
        const TensorLayout& layout, BroadcastChannelInfo& info) {
    if (layout.format.type() != TensorFormat::Type::DEFAULT) {
        return false;
    }
    if (layout.ndim != 3) {
        return false;
    }

    // Compared as unsigned values, exactly like the `stride - shape == 0` form.
    const bool batch_step_is_row =
            static_cast<size_t>(layout.stride[0]) == layout.shape[2];
    const bool middle_is_broadcast = layout.stride[1] == 0;
    const bool row_is_contiguous = layout.stride[2] == 1;
    if (!(batch_step_is_row && middle_is_broadcast && row_is_contiguous)) {
        return false;
    }

    info.x = layout.shape[0];
    info.y = layout.shape[1];
    info.z = layout.shape[2];
    return true;
}
bool
ElemwiseLayoutHelper
::
is_NHWC_broadcasted_channel_like
(
const
TensorLayout
&
layout
,
BroadcastChannelInfo
&
info
)
{
if
(
layout
.
format
.
type
()
==
TensorFormat
::
Type
::
DEFAULT
)
{
...
...
dnn/src/common/elemwise/opr_impl_helper.h
浏览文件 @
c48d58da
...
...
@@ -80,6 +80,14 @@ public:
static
bool
is_broadcasted_channel_like
(
const
TensorLayout
&
layout
,
BroadcastChannelInfo
&
info
);
/*!
 * \brief check whether layout matches BroadcastChannelInfo like N1HW
 *
 * Note layout should be [N, 1, H*W] like, i.e. a 3-dim layout whose middle
 * axis is broadcast. The check only succeeds for a DEFAULT-format layout
 * with ndim == 3, stride[0] == shape[2], stride[1] == 0 and stride[2] == 1.
 *
 * \param layout layout to be tested
 * \param[out] info on success x, y and z are set to shape[0], shape[1] and
 *      shape[2] respectively; it is not modified when the check fails
 * \return true if the layout matches
 */
static
bool
is_broadcasted_3dim_like
(
const
TensorLayout
&
layout
,
BroadcastChannelInfo
&
info
);
/*!
* \brief check whether layout matches BroadcastChannelInfo under NHWC
* layout
...
...
dnn/test/arm_common/elemwise.cpp
浏览文件 @
c48d58da
...
...
@@ -356,6 +356,30 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NHWC_FP32_BCAST) {
run_3d_incontig
(
Mode
::
FUSE_MUL_ADD3
);
}
// Checks float32 binary elemwise forward when one operand has a size-1 second
// axis that is broadcast against the other operand (N1HW-like shapes), in both
// operand orders, for the ADD, MUL and SUB modes.
TEST_F(ARM_COMMON, ELEMWISE_FORWARD_N1HW_FP32_BCAST) {
    using Mode = ElemwiseForward::Param::Mode;

    Checker<ElemwiseForward> checker(handle());
    UniformFloatRNG rng(1e-5, 7e1);
    checker.set_rng(0, &rng);
    checker.set_epsilon(1e-5);
    checker.set_dtype(0, dtype::Float32());
    checker.set_dtype(1, dtype::Float32());

    //! 4-dim and 3-dim shapes, broadcast operand on either side
    auto check_mode = [&](Mode mode) {
        // VEC_BCASTX0X: the second input is the broadcast one
        checker.set_param(mode).execs({{2, 8, 4, 4}, {2, 1, 4, 4}, {}});
        checker.set_param(mode).execs({{4, 21, 78}, {4, 1, 78}, {}});

        // BCASTX0X_VEC: the first input is the broadcast one
        checker.set_param(mode).execs({{2, 1, 4, 4}, {2, 8, 4, 4}, {}});
        checker.set_param(mode).execs({{4, 1, 78}, {4, 21, 78}, {}});
    };

    for (Mode mode : {Mode::ADD, Mode::MUL, Mode::SUB}) {
        check_mode(mode);
    }
}
#if MEGDNN_WITH_BENCHMARK
namespace
{
void
run_elemwise_benchmark
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录