提交 3344b580 编写于 作者: M Megvii Engine Team

feat(dnn): add elemwise for nchw88+fp16

GitOrigin-RevId: 63587975f8746bd8cf2443e81d433bfc07122b38
上级 682c74df
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#pragma once #pragma once
...@@ -22,7 +23,6 @@ MIDOUT_DECL(arm_common_conv_bias_postprocess_helper) ...@@ -22,7 +23,6 @@ MIDOUT_DECL(arm_common_conv_bias_postprocess_helper)
namespace { namespace {
#define CONCAT_OP(_name) megdnn::arm_common::_name #define CONCAT_OP(_name) megdnn::arm_common::_name
#define CONCAT_NL(_name) megdnn::NonlineMode::_name #define CONCAT_NL(_name) megdnn::NonlineMode::_name
...@@ -57,9 +57,9 @@ namespace { ...@@ -57,9 +57,9 @@ namespace {
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
dst_type, N, OC, OH* OW); dst_type, N, OC, OH* OW);
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW44(_op) \ #define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, \ megdnn::arm_common::OpCallerBinary<_op<ctype>, \
megdnn::arm_common::VEC_BCAST101x4>:: \ megdnn::arm_common::VEC_BCAST101xX>:: \
run(static_cast<ctype*>(conv_dst_ptr), \ run(static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \ reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
...@@ -86,9 +86,9 @@ namespace { ...@@ -86,9 +86,9 @@ namespace {
if (pack_oc_size == 1) { \ if (pack_oc_size == 1) { \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \
} else { \ } else { \
megdnn_assert(pack_oc_size == 4, \ megdnn_assert(pack_oc_size == 4 || pack_oc_size == 8, \
"Only support nchw44 in ARM"); \ "Only support nchw44/nchw88 in ARM"); \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \ FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX); \
} \ } \
} \ } \
MIDOUT_END(); \ MIDOUT_END(); \
...@@ -160,7 +160,7 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { ...@@ -160,7 +160,7 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
#undef FOR_NONLINEAR_UNARY #undef FOR_NONLINEAR_UNARY
#undef FOR_NONLINEAR_BINARY_BROADCAST #undef FOR_NONLINEAR_BINARY_BROADCAST
#undef FOR_NONLINEAR_BINARY_BROADCAST_NCHW44 #undef FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX
#undef FOR_NONLINEAR_BINARY #undef FOR_NONLINEAR_BINARY
#undef FOR_NONLINEAR_NOBIAS #undef FOR_NONLINEAR_NOBIAS
#undef FOR_NONLINEAR #undef FOR_NONLINEAR
...@@ -183,9 +183,17 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { ...@@ -183,9 +183,17 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \ reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \
dst_type, N, OC, OH* OW); dst_type, N, OC, OH* OW);
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW44(_op) \ #define FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX(_op) \
megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, \
megdnn::arm_common::VEC_BCAST101xX>:: \
run(static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<const opctype*>(bias_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \
dst_type, N, OC, OH* OW, pack_oc_size);
#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW88(_op) \
megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, \ megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, \
megdnn::arm_common::VEC_BCAST101x4>:: \ megdnn::arm_common::VEC_BCAST101xX>:: \
run(static_cast<opctype*>(conv_dst_ptr), \ run(static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<const opctype*>(bias_ptr), \ reinterpret_cast<const opctype*>(bias_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \ reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \
...@@ -220,9 +228,9 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { ...@@ -220,9 +228,9 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
if (pack_oc_size == 1) { \ if (pack_oc_size == 1) { \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \
} else { \ } else { \
megdnn_assert(pack_oc_size == 4, \ megdnn_assert(pack_oc_size == 4 || pack_oc_size == 8, \
"Only support nchw44 in ARM"); \ "Only support nchw44/nchw88 in ARM"); \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \ FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX); \
} \ } \
break; \ break; \
default: \ default: \
...@@ -230,9 +238,9 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { ...@@ -230,9 +238,9 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
if (pack_oc_size == 1) { \ if (pack_oc_size == 1) { \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \ FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \
} else { \ } else { \
megdnn_assert(pack_oc_size == 4, \ megdnn_assert(pack_oc_size == 4 || pack_oc_size == 8, \
"Only support nchw44 in ARM"); \ "Only support nchw44/nchw88 in ARM"); \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \ FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX); \
} \ } \
break; \ break; \
} \ } \
...@@ -254,7 +262,7 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> { ...@@ -254,7 +262,7 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
#undef FOR_NONLINEAR_UNARY #undef FOR_NONLINEAR_UNARY
#undef FOR_NONLINEAR_BINARY_BROADCAST #undef FOR_NONLINEAR_BINARY_BROADCAST
#undef FOR_NONLINEAR_BINARY_BROADCAST_NCHW44 #undef FOR_NONLINEAR_BINARY_BROADCAST_NCHWXX
#undef FOR_NONLINEAR_BINARY #undef FOR_NONLINEAR_BINARY
#undef FOR_NONLINEAR_NOBIAS #undef FOR_NONLINEAR_NOBIAS
#undef FOR_NONLINEAR #undef FOR_NONLINEAR
...@@ -268,9 +276,9 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> { ...@@ -268,9 +276,9 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
dst_type, N, OC, OH* OW); dst_type, N, OC, OH* OW);
#define FOR_BINARY_BROADCAST_NCHW44(_op) \ #define FOR_BINARY_BROADCAST_NCHWXX(_op) \
megdnn::arm_common::OpCallerBinary<_op<ctype>, \ megdnn::arm_common::OpCallerBinary<_op<ctype>, \
megdnn::arm_common::VEC_BCAST101x4>:: \ megdnn::arm_common::VEC_BCAST101xX>:: \
run(static_cast<ctype*>(conv_dst_ptr), \ run(static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \ reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
...@@ -292,9 +300,9 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> { ...@@ -292,9 +300,9 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
if (pack_oc_size == 1) { \ if (pack_oc_size == 1) { \
FOR_BINARY_BROADCAST(CONCAT_OP(AddOp)); \ FOR_BINARY_BROADCAST(CONCAT_OP(AddOp)); \
} else { \ } else { \
megdnn_assert(pack_oc_size == 4, \ megdnn_assert(pack_oc_size == 4 || pack_oc_size == 8, \
"Only support nchw44 in ARM"); \ "Only support nchw44/nchw88 in ARM"); \
FOR_BINARY_BROADCAST_NCHW44(CONCAT_OP(AddOp)); \ FOR_BINARY_BROADCAST_NCHWXX(CONCAT_OP(AddOp)); \
} \ } \
break; \ break; \
case megdnn::BiasMode::BIAS: \ case megdnn::BiasMode::BIAS: \
...@@ -318,7 +326,7 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> { ...@@ -318,7 +326,7 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> {
}; };
#undef FOR_BINARY_BROADCAST #undef FOR_BINARY_BROADCAST
#undef FOR_BINARY_BROADCAST_NCHW44 #undef FOR_BINARY_BROADCAST_NCHWXX
#undef FOR_BINARY #undef FOR_BINARY
#undef FOR_BIAS #undef FOR_BIAS
#undef CB #undef CB
......
...@@ -105,25 +105,21 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available( ...@@ -105,25 +105,21 @@ bool ElemwiseImpl::AlgoBinaryVecBcast101::is_available(
return false; return false;
} }
bool ElemwiseImpl::AlgoBinaryVecBcast101x4::is_available( bool ElemwiseImpl::AlgoBinaryVecBcast101xX::is_available(
const KernParam& kern_param) const { const KernParam& kern_param) const {
if (!is_available_common(kern_param.mode) || if (!is_available_common(kern_param.mode) ||
((BcastType::VEC_BCAST101x4 != kern_param.broad_cast_type) && ((BcastType::VEC_BCAST101xX != kern_param.broad_cast_type) &&
(BcastType::BCAST101x4_VEC != kern_param.broad_cast_type))) (BcastType::BCAST101xX_VEC != kern_param.broad_cast_type)))
return false; return false;
auto& elparam = kern_param.binary_elparam; auto& elparam = kern_param.binary_elparam;
auto& src0 = elparam[0]; auto& src0 = elparam[0];
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
if (DNN_FLOAT16_SELECT(src0.layout.dtype == dtype::Float16{}, false)) {
return false;
}
#endif
DISPATCH_TYPE("AlgoBinaryVecBcast101x::is_available"_hash); DISPATCH_TYPE("AlgoBinaryVecBcast101xX::is_available"_hash);
return false; return false;
} }
#undef DISPATCH_MODE_FLOAT #undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_INT #undef DISPATCH_MODE_INT
...@@ -334,16 +330,19 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec( ...@@ -334,16 +330,19 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec(
return; return;
} }
void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec( void ElemwiseImpl::AlgoBinaryVecBcast101xX::exec(
const KernParam& kern_param) const { const KernParam& kern_param) const {
auto& elparam = kern_param.binary_elparam; auto& elparam = kern_param.binary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1]; auto &src0 = elparam[0], &src1 = elparam[1];
auto&& dst = *(kern_param.m_dst); auto&& dst = *(kern_param.m_dst);
BroadcastChannelInfo binfo; BroadcastChannelInfo binfo;
// BcastType::VEC + BCAST_101x // BcastType::VEC + BCAST_101X
if (BcastType::VEC_BCAST101x4 == kern_param.broad_cast_type && if (BcastType::VEC_BCAST101xX == kern_param.broad_cast_type) {
is_broadcastedx_channel_like<4>(src1.layout, binfo)) { megdnn_assert(
is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
is_broadcastedx_channel_like<8>(src1.layout, binfo),
"only nchw44 and nchw88 supported");
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ #define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \ case Mode::_mode: \
MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \ MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \
...@@ -351,7 +350,7 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec( ...@@ -351,7 +350,7 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec(
thin_function<void(const _type*, const _type*, _type*, DType, \ thin_function<void(const _type*, const _type*, _type*, DType, \
DType, DType, size_t, size_t, size_t, size_t)> \ DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerBinary<_op<_type, _type>, \ run = OpCallerBinary<_op<_type, _type>, \
BcastType::VEC_BCAST101x4>::run; \ BcastType::VEC_BCAST101xX>::run; \
MEGDNN_DISPATCH_CPU_KERN( \ MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \ static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr), \ run(static_cast<const _type*>(src0.raw_ptr), \
...@@ -362,17 +361,19 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec( ...@@ -362,17 +361,19 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec(
} \ } \
MIDOUT_END(); \ MIDOUT_END(); \
return return
size_t batch_size = size_t batch_size =
src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
DISPATCH_TYPE("AlgoBinaryVecBcast101x::exec_vec_b"_hash); DISPATCH_TYPE("AlgoBinaryVecBcast101xX::exec_vec_b"_hash);
#undef DISPATCH_BINARY #undef DISPATCH_BINARY
} }
// BCAST_101x + BcastType::VEC // BCAST_101x + BcastType::VEC
if (BcastType::BCAST101x4_VEC == kern_param.broad_cast_type && if (BcastType::BCAST101xX_VEC == kern_param.broad_cast_type) {
is_broadcastedx_channel_like<4>(src0.layout, binfo)) { megdnn_assert(
is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
is_broadcastedx_channel_like<8>(src0.layout, binfo),
"only nchw44 and nchw88 supported");
#define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \ #define DISPATCH_BINARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \ case Mode::_mode: \
MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \ MIDOUT_BEGIN(megdnn_arm_common_elemwise_binary, midout_iv(_case), \
...@@ -380,7 +381,7 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec( ...@@ -380,7 +381,7 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec(
thin_function<void(const _type*, const _type*, _type*, DType, \ thin_function<void(const _type*, const _type*, _type*, DType, \
DType, DType, size_t, size_t, size_t, size_t)> \ DType, DType, size_t, size_t, size_t, size_t)> \
run = OpCallerBinary<_op<_type, _type>, \ run = OpCallerBinary<_op<_type, _type>, \
BcastType::BCAST101x4_VEC>::run; \ BcastType::BCAST101xX_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN( \ MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \ static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr), \ run(static_cast<const _type*>(src0.raw_ptr), \
...@@ -394,12 +395,13 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec( ...@@ -394,12 +395,13 @@ void ElemwiseImpl::AlgoBinaryVecBcast101x4::exec(
size_t batch_size = size_t batch_size =
src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
DISPATCH_TYPE("AlgoBinaryVecBcast101x::exec_b_vec"_hash); DISPATCH_TYPE("AlgoBinaryVecBcast101xX::exec_b_vec"_hash);
#undef DISPATCH_BINARY #undef DISPATCH_BINARY
} }
return; return;
} }
#undef DISPATCH_MODE_FLOAT #undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_INT #undef DISPATCH_MODE_INT
......
...@@ -34,7 +34,7 @@ namespace arm_common { ...@@ -34,7 +34,7 @@ namespace arm_common {
DECL_CB(VecVec); DECL_CB(VecVec);
DECL_CB(VecScalar); DECL_CB(VecScalar);
DECL_CB(VecBcast101); DECL_CB(VecBcast101);
DECL_CB(VecBcast101x4); DECL_CB(VecBcast101xX);
#undef DECL_CB #undef DECL_CB
} // namespace arm_common } // namespace arm_common
} // namespace megdnn } // namespace megdnn
......
...@@ -27,14 +27,14 @@ class ElemwiseImpl::AlgoPack { ...@@ -27,14 +27,14 @@ class ElemwiseImpl::AlgoPack {
AlgoBinaryVecVec algo_binary_vec_vec; AlgoBinaryVecVec algo_binary_vec_vec;
AlgoBinaryVecScalar algo_binary_vec_sca; AlgoBinaryVecScalar algo_binary_vec_sca;
AlgoBinaryVecBcast101 algo_binary_vec_bcast101; AlgoBinaryVecBcast101 algo_binary_vec_bcast101;
AlgoBinaryVecBcast101x4 algo_binary_VEC_BCAST101x4; AlgoBinaryVecBcast101xX algo_binary_VEC_BCAST101xX;
AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec; AlgoTernaryFma3VecVecVec algo_ternaryfma3_vec_vec_vec;
AlgoTernaryFma3VecVecScalar algo_ternaryfma3_vec_vecsca; AlgoTernaryFma3VecVecScalar algo_ternaryfma3_vec_vecsca;
AlgoTernaryFma3Bcast101VecBcast101 algo_ternaryfma3_bcast101_vec_bcast101; AlgoTernaryFma3Bcast101VecBcast101 algo_ternaryfma3_bcast101_vec_bcast101;
AlgoTernaryFma3Bcast101x4VecBcast101x4 AlgoTernaryFma3Bcast101xXVecBcast101xX
algo_ternaryfma3_bcast101x4_vec_bcast101x4; algo_ternaryfma3_bcast101xX_vec_bcast101xX;
AlgoTernaryFma3VecBcast101Vec algo_ternaryfma3_vec_bcast101_vec; AlgoTernaryFma3VecBcast101Vec algo_ternaryfma3_vec_bcast101_vec;
AlgoTernaryFma3VecBcast101x4Vec algo_ternaryfma3_vec_bcast101x4_vec; AlgoTernaryFma3VecBcast101xXVec algo_ternaryfma3_vec_bcast101xX_vec;
AlgoTernaryFma3VecScalarVec algo_ternaryfma3_vec_sca_vec; AlgoTernaryFma3VecScalarVec algo_ternaryfma3_vec_sca_vec;
AlgoTernaryFma3VecScalarScalar algo_ternaryfma3_vec_sca_sca; AlgoTernaryFma3VecScalarScalar algo_ternaryfma3_vec_sca_sca;
...@@ -44,13 +44,13 @@ public: ...@@ -44,13 +44,13 @@ public:
all_algos.emplace_back(&algo_binary_vec_vec); all_algos.emplace_back(&algo_binary_vec_vec);
all_algos.emplace_back(&algo_binary_vec_sca); all_algos.emplace_back(&algo_binary_vec_sca);
all_algos.emplace_back(&algo_binary_vec_bcast101); all_algos.emplace_back(&algo_binary_vec_bcast101);
all_algos.emplace_back(&algo_binary_VEC_BCAST101x4); all_algos.emplace_back(&algo_binary_VEC_BCAST101xX);
all_algos.emplace_back(&algo_ternaryfma3_vec_vec_vec); all_algos.emplace_back(&algo_ternaryfma3_vec_vec_vec);
all_algos.emplace_back(&algo_ternaryfma3_vec_vecsca); all_algos.emplace_back(&algo_ternaryfma3_vec_vecsca);
all_algos.emplace_back(&algo_ternaryfma3_bcast101_vec_bcast101); all_algos.emplace_back(&algo_ternaryfma3_bcast101_vec_bcast101);
all_algos.emplace_back(&algo_ternaryfma3_bcast101x4_vec_bcast101x4); all_algos.emplace_back(&algo_ternaryfma3_bcast101xX_vec_bcast101xX);
all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101_vec); all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101_vec);
all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101x4_vec); all_algos.emplace_back(&algo_ternaryfma3_vec_bcast101xX_vec);
all_algos.emplace_back(&algo_ternaryfma3_vec_sca_vec); all_algos.emplace_back(&algo_ternaryfma3_vec_sca_vec);
all_algos.emplace_back(&algo_ternaryfma3_vec_sca_sca); all_algos.emplace_back(&algo_ternaryfma3_vec_sca_sca);
} }
...@@ -118,9 +118,10 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { ...@@ -118,9 +118,10 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
} }
if (is_vector(src1.layout) && if (is_vector(src1.layout) &&
is_broadcastedx_channel_like<4>(src0.layout, binfo) && (is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
is_broadcastedx_channel_like<8>(src0.layout, binfo)) &&
src0.layout.eq_layout(src2.layout)) { src0.layout.eq_layout(src2.layout)) {
kern_param.broad_cast_type = BcastType::BCAST101x4_VEC_BCAST101x4; kern_param.broad_cast_type = BcastType::BCAST101xX_VEC_BCAST101xX;
return kern_param; return kern_param;
} }
...@@ -131,8 +132,9 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { ...@@ -131,8 +132,9 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
} }
if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) && if (is_vector(src0.layout) && src0.layout.eq_layout(src2.layout) &&
is_broadcastedx_channel_like<4>(src1.layout, binfo)) { (is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
kern_param.broad_cast_type = BcastType::VEC_BCAST101x4_VEC; is_broadcastedx_channel_like<8>(src1.layout, binfo))) {
kern_param.broad_cast_type = BcastType::VEC_BCAST101xX_VEC;
return kern_param; return kern_param;
} }
...@@ -180,17 +182,18 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) { ...@@ -180,17 +182,18 @@ ElemwiseImpl::KernParam ElemwiseImpl::make_kern_param(ElemwiseImpl* opr) {
} }
if (is_vector(src0.layout) && if (is_vector(src0.layout) &&
is_broadcastedx_channel_like<4>(src1.layout, binfo)) { (is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
kern_param.broad_cast_type = BcastType::VEC_BCAST101x4; is_broadcastedx_channel_like<8>(src1.layout, binfo))) {
kern_param.broad_cast_type = BcastType::VEC_BCAST101xX;
return kern_param; return kern_param;
} }
if (is_vector(src1.layout) && if (is_vector(src1.layout) &&
is_broadcastedx_channel_like<4>(src0.layout, binfo)) { (is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
kern_param.broad_cast_type = BcastType::BCAST101x4_VEC; is_broadcastedx_channel_like<8>(src0.layout, binfo))) {
kern_param.broad_cast_type = BcastType::BCAST101xX_VEC;
return kern_param; return kern_param;
} }
} else if (opr->m_src->size() == 1) { } else if (opr->m_src->size() == 1) {
kern_param.broad_cast_type = BcastType::VEC; kern_param.broad_cast_type = BcastType::VEC;
kern_param.unary_elparam = opr->make_elemwise_op_param<1>(); kern_param.unary_elparam = opr->make_elemwise_op_param<1>();
......
...@@ -10,7 +10,9 @@ ...@@ -10,7 +10,9 @@
* implied. * implied.
*/ */
#pragma once #pragma once
#include "src/fallback/elemwise/opr_impl.h" #include "src/fallback/elemwise/opr_impl.h"
#include "src/arm_common/elemwise_op.h" #include "src/arm_common/elemwise_op.h"
namespace megdnn { namespace megdnn {
...@@ -37,13 +39,13 @@ private: ...@@ -37,13 +39,13 @@ private:
class AlgoBinaryVecVec; class AlgoBinaryVecVec;
class AlgoBinaryVecScalar; class AlgoBinaryVecScalar;
class AlgoBinaryVecBcast101; class AlgoBinaryVecBcast101;
class AlgoBinaryVecBcast101x4; class AlgoBinaryVecBcast101xX;
class AlgoTernaryFma3VecVecVec; class AlgoTernaryFma3VecVecVec;
class AlgoTernaryFma3VecVecScalar; class AlgoTernaryFma3VecVecScalar;
class AlgoTernaryFma3Bcast101VecBcast101; class AlgoTernaryFma3Bcast101VecBcast101;
class AlgoTernaryFma3Bcast101x4VecBcast101x4; class AlgoTernaryFma3Bcast101xXVecBcast101xX;
class AlgoTernaryFma3VecBcast101Vec; class AlgoTernaryFma3VecBcast101Vec;
class AlgoTernaryFma3VecBcast101x4Vec; class AlgoTernaryFma3VecBcast101xXVec;
class AlgoTernaryFma3VecScalarVec; class AlgoTernaryFma3VecScalarVec;
class AlgoTernaryFma3VecScalarScalar; class AlgoTernaryFma3VecScalarScalar;
class AlgoPack; class AlgoPack;
......
...@@ -42,9 +42,9 @@ using namespace arm_common; ...@@ -42,9 +42,9 @@ using namespace arm_common;
DECL_AVAILABLE(VecVecVec, BcastType::VEC_VEC_VEC); DECL_AVAILABLE(VecVecVec, BcastType::VEC_VEC_VEC);
DECL_AVAILABLE(VecVecScalar, BcastType::VEC_VEC_SCALAR); DECL_AVAILABLE(VecVecScalar, BcastType::VEC_VEC_SCALAR);
DECL_AVAILABLE(Bcast101VecBcast101, BcastType::BCAST101_VEC_BCAST101); DECL_AVAILABLE(Bcast101VecBcast101, BcastType::BCAST101_VEC_BCAST101);
DECL_AVAILABLE(Bcast101x4VecBcast101x4, BcastType::BCAST101x4_VEC_BCAST101x4); DECL_AVAILABLE(Bcast101xXVecBcast101xX, BcastType::BCAST101xX_VEC_BCAST101xX);
DECL_AVAILABLE(VecBcast101Vec, BcastType::VEC_BCAST101_VEC); DECL_AVAILABLE(VecBcast101Vec, BcastType::VEC_BCAST101_VEC);
DECL_AVAILABLE(VecBcast101x4Vec, BcastType::VEC_BCAST101x4_VEC); DECL_AVAILABLE(VecBcast101xXVec, BcastType::VEC_BCAST101xX_VEC);
DECL_AVAILABLE(VecScalarVec, BcastType::VEC_SCALAR_VEC); DECL_AVAILABLE(VecScalarVec, BcastType::VEC_SCALAR_VEC);
DECL_AVAILABLE(VecScalarScalar, BcastType::VEC_SCALAR_SCALAR); DECL_AVAILABLE(VecScalarScalar, BcastType::VEC_SCALAR_SCALAR);
#undef DECL_CB #undef DECL_CB
...@@ -161,13 +161,15 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec( ...@@ -161,13 +161,15 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec(
return; return;
} }
void ElemwiseImpl::AlgoTernaryFma3Bcast101x4VecBcast101x4::exec( void ElemwiseImpl::AlgoTernaryFma3Bcast101xXVecBcast101xX::exec(
const KernParam& kern_param) const { const KernParam& kern_param) const {
auto& elparam = kern_param.ternary_elparam; auto& elparam = kern_param.ternary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];
BroadcastChannelInfo binfo; BroadcastChannelInfo binfo;
is_broadcastedx_channel_like<4>(src0.layout, binfo); megdnn_assert(is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
is_broadcastedx_channel_like<8>(src0.layout, binfo),
"only nchw44 and nchw88 supported");
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ #define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \ case Mode::_mode: \
MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \
...@@ -177,7 +179,7 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101x4VecBcast101x4::exec( ...@@ -177,7 +179,7 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101x4VecBcast101x4::exec(
size_t, size_t, size_t)> \ size_t, size_t, size_t)> \
run = OpCallerTernary< \ run = OpCallerTernary< \
_op<_type, _type>, \ _op<_type, _type>, \
BcastType::BCAST101x4_VEC_BCAST101x4>::run; \ BcastType::BCAST101xX_VEC_BCAST101xX>::run; \
MEGDNN_DISPATCH_CPU_KERN( \ MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \ static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr), \ run(static_cast<const _type*>(src0.raw_ptr), \
...@@ -193,19 +195,21 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101x4VecBcast101x4::exec( ...@@ -193,19 +195,21 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101x4VecBcast101x4::exec(
size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
auto&& dst = *(kern_param.m_dst); auto&& dst = *(kern_param.m_dst);
DISPATCH_TYPE("AlgoTernaryFma3Bcast101x4VecBcast101x4::exec"_hash); DISPATCH_TYPE("AlgoTernaryFma3Bcast101xXVecBcast101xX::exec"_hash);
#undef DISPATCH_TERNARY #undef DISPATCH_TERNARY
return; return;
} }
void ElemwiseImpl::AlgoTernaryFma3VecBcast101x4Vec::exec( void ElemwiseImpl::AlgoTernaryFma3VecBcast101xXVec::exec(
const KernParam& kern_param) const { const KernParam& kern_param) const {
auto& elparam = kern_param.ternary_elparam; auto& elparam = kern_param.ternary_elparam;
auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2]; auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];
BroadcastChannelInfo binfo; BroadcastChannelInfo binfo;
is_broadcastedx_channel_like<4>(src1.layout, binfo); megdnn_assert(is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
is_broadcastedx_channel_like<8>(src1.layout, binfo),
"only nchw44 and nchw88 supported");
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \ #define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op) \
case Mode::_mode: \ case Mode::_mode: \
MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \ MIDOUT_BEGIN(megdnn_arm_common_elemwise_ternary, midout_iv(_case), \
...@@ -214,7 +218,7 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101x4Vec::exec( ...@@ -214,7 +218,7 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101x4Vec::exec(
_type*, DType, DType, DType, DType, size_t, \ _type*, DType, DType, DType, DType, size_t, \
size_t, size_t, size_t)> \ size_t, size_t, size_t)> \
run = OpCallerTernary<_op<_type, _type>, \ run = OpCallerTernary<_op<_type, _type>, \
BcastType::VEC_BCAST101x4_VEC>::run; \ BcastType::VEC_BCAST101xX_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN( \ MEGDNN_DISPATCH_CPU_KERN( \
static_cast<naive::HandleImpl*>(kern_param.handle), \ static_cast<naive::HandleImpl*>(kern_param.handle), \
run(static_cast<const _type*>(src0.raw_ptr), \ run(static_cast<const _type*>(src0.raw_ptr), \
...@@ -230,7 +234,7 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101x4Vec::exec( ...@@ -230,7 +234,7 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101x4Vec::exec(
size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
auto&& dst = *(kern_param.m_dst); auto&& dst = *(kern_param.m_dst);
DISPATCH_TYPE("AlgoTernaryFma3VecBcast101x4Vec::exec"_hash); DISPATCH_TYPE("AlgoTernaryFma3VecBcast101xXVec::exec"_hash);
#undef DISPATCH_TERNARY #undef DISPATCH_TERNARY
return; return;
......
...@@ -34,9 +34,9 @@ namespace arm_common { ...@@ -34,9 +34,9 @@ namespace arm_common {
DECL_CB(VecVecVec); DECL_CB(VecVecVec);
DECL_CB(VecVecScalar); DECL_CB(VecVecScalar);
DECL_CB(Bcast101VecBcast101); DECL_CB(Bcast101VecBcast101);
DECL_CB(Bcast101x4VecBcast101x4); DECL_CB(Bcast101xXVecBcast101xX);
DECL_CB(VecBcast101Vec); DECL_CB(VecBcast101Vec);
DECL_CB(VecBcast101x4Vec); DECL_CB(VecBcast101xXVec);
DECL_CB(VecScalarVec); DECL_CB(VecScalarVec);
DECL_CB(VecScalarScalar); DECL_CB(VecScalarScalar);
#undef DECL_CB #undef DECL_CB
......
...@@ -644,7 +644,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param, ...@@ -644,7 +644,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param,
{ {
BroadcastChannelInfo binfo; BroadcastChannelInfo binfo;
if (is_vector(src0.layout) && if (is_vector(src0.layout) &&
is_broadcastedx_channel_like<4>(src1.layout, binfo)) { (is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
is_broadcastedx_channel_like<8>(src1.layout, binfo))) {
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \
case _mode: { \ case _mode: { \
using src_ctype = typename DTypeTrait<_src_dt>::ctype; \ using src_ctype = typename DTypeTrait<_src_dt>::ctype; \
...@@ -653,14 +654,13 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param, ...@@ -653,14 +654,13 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param,
DType, DType, DType, size_t, size_t, size_t, \ DType, DType, DType, size_t, size_t, size_t, \
size_t)> \ size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, \ run = OpCallerBinary<_op<src_ctype, dst_ctype>, \
VEC_BCAST101x4>::run; \ VEC_BCAST101xX>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ MEGDNN_DISPATCH_CPU_KERN_OPR(run( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \ return; \
} }
size_t batch_size = size_t batch_size =
src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z); src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
DISPATCH() DISPATCH()
...@@ -679,14 +679,13 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param, ...@@ -679,14 +679,13 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<2>& param,
DType, DType, DType, size_t, size_t, size_t, \ DType, DType, DType, size_t, size_t, size_t, \
size_t)> \ size_t)> \
run = OpCallerBinary<_op<src_ctype, dst_ctype>, \ run = OpCallerBinary<_op<src_ctype, dst_ctype>, \
BCAST101x4_VEC>::run; \ BCAST101xX_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR(run( \ MEGDNN_DISPATCH_CPU_KERN_OPR(run( \
src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \ dst.ptr<dst_ctype>(), src0.layout.dtype, src1.layout.dtype, \
dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \ dst.layout.dtype, batch_size, binfo.x, binfo.y, binfo.z)); \
return; \ return; \
} }
size_t batch_size = size_t batch_size =
src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z); src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
DISPATCH() DISPATCH()
...@@ -818,7 +817,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param, ...@@ -818,7 +817,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param,
{ {
BroadcastChannelInfo binfo; BroadcastChannelInfo binfo;
if (is_vector(src0.layout) && if (is_vector(src0.layout) &&
is_broadcastedx_channel_like<4>(src1.layout, binfo) && (is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
is_broadcastedx_channel_like<8>(src1.layout, binfo)) &&
src0.layout.eq_shape(src2.layout)) { src0.layout.eq_shape(src2.layout)) {
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \
case _mode: { \ case _mode: { \
...@@ -828,7 +828,7 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param, ...@@ -828,7 +828,7 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param,
const src_ctype*, dst_ctype*, DType, DType, DType, \ const src_ctype*, dst_ctype*, DType, DType, DType, \
DType, size_t, size_t, size_t, size_t)> \ DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, \ run = OpCallerTernary<_op<src_ctype, dst_ctype>, \
VEC_BCAST101x4_VEC>::run; \ VEC_BCAST101xX_VEC>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \ src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
...@@ -846,7 +846,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param, ...@@ -846,7 +846,8 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param,
//! BCAST101x + VEC +BCAST101x //! BCAST101x + VEC +BCAST101x
if (is_vector(src1.layout) && if (is_vector(src1.layout) &&
is_broadcastedx_channel_like<4>(src0.layout, binfo) && (is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
is_broadcastedx_channel_like<8>(src0.layout, binfo)) &&
src0.layout.eq_shape(src2.layout)) { src0.layout.eq_shape(src2.layout)) {
#define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \ #define DISPATCH_SINGLE_MODE(_src_dt, _dst_dt, _mode, _op) \
case _mode: { \ case _mode: { \
...@@ -856,7 +857,7 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param, ...@@ -856,7 +857,7 @@ void ElemwiseMultiTypeImpl::on_quantized_mode(const ElemwiseOpParamN<3>& param,
const src_ctype*, dst_ctype*, DType, DType, DType, \ const src_ctype*, dst_ctype*, DType, DType, DType, \
DType, size_t, size_t, size_t, size_t)> \ DType, size_t, size_t, size_t, size_t)> \
run = OpCallerTernary<_op<src_ctype, dst_ctype>, \ run = OpCallerTernary<_op<src_ctype, dst_ctype>, \
BCAST101x4_VEC_BCAST101x4>::run; \ BCAST101xX_VEC_BCAST101xX>::run; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \ MEGDNN_DISPATCH_CPU_KERN_OPR( \
run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \ run(src0.ptr<src_ctype>(), src1.ptr<src_ctype>(), \
src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \ src2.ptr<src_ctype>(), dst.ptr<dst_ctype>(), \
......
...@@ -89,6 +89,21 @@ cb(dt_float32, float32_t, float32x4_t, f32); ...@@ -89,6 +89,21 @@ cb(dt_float32, float32_t, float32x4_t, f32);
cb(dt_int32, int32_t, int32x4_t, s32); cb(dt_int32, int32_t, int32x4_t, s32);
#undef cb #undef cb
template <typename ctype>
struct ParamElemVisitorBcast101x8;
#define cb(_ctype, _inner_ctype, _neon_type, _fun_suffix) \
template <> \
struct ParamElemVisitorBcast101x8<_ctype> { \
_neon_type operator()(const _ctype* src) const { \
return vld1q_##_fun_suffix( \
reinterpret_cast<const _inner_ctype*>(src)); \
} \
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
cb(__fp16, __fp16, float16x8_t, f16);
#endif
#undef cb
/*! /*!
* \brief broadcast type * \brief broadcast type
* BCAST_x[0]x[1]...: x[i] == !stride[i] * BCAST_x[0]x[1]...: x[i] == !stride[i]
...@@ -97,17 +112,17 @@ enum BcastType { ...@@ -97,17 +112,17 @@ enum BcastType {
VEC, VEC,
VEC_VEC, VEC_VEC,
VEC_BCAST101, VEC_BCAST101,
VEC_BCAST101x4, VEC_BCAST101xX,
VEC_SCALAR, VEC_SCALAR,
SCALAR_VEC, SCALAR_VEC,
BCAST101_VEC, BCAST101_VEC,
BCAST101x4_VEC, BCAST101xX_VEC,
VEC_VEC_VEC, VEC_VEC_VEC,
VEC_VEC_SCALAR, VEC_VEC_SCALAR,
BCAST101_VEC_BCAST101, BCAST101_VEC_BCAST101,
BCAST101x4_VEC_BCAST101x4, BCAST101xX_VEC_BCAST101xX,
VEC_BCAST101_VEC, VEC_BCAST101_VEC,
VEC_BCAST101x4_VEC, VEC_BCAST101xX_VEC,
VEC_SCALAR_VEC, VEC_SCALAR_VEC,
VEC_SCALAR_SCALAR, VEC_SCALAR_SCALAR,
UNKNOWN_BCAST_TYPE UNKNOWN_BCAST_TYPE
...@@ -334,7 +349,7 @@ struct OpCallerBinary<Op, VEC_BCAST101> { ...@@ -334,7 +349,7 @@ struct OpCallerBinary<Op, VEC_BCAST101> {
}; };
template <typename ctype> template <typename ctype>
struct OpCallerBinary<PowOp<ctype, ctype>, BCAST101x4_VEC> { struct OpCallerBinary<PowOp<ctype, ctype>, BCAST101xX_VEC> {
using Op = PowOp<ctype, ctype>; using Op = PowOp<ctype, ctype>;
static void run(const typename Op::src_ctype* src0, static void run(const typename Op::src_ctype* src0,
const typename Op::src_ctype* src1, const typename Op::src_ctype* src1,
...@@ -360,18 +375,37 @@ struct OpCallerBinary<PowOp<ctype, ctype>, BCAST101x4_VEC> { ...@@ -360,18 +375,37 @@ struct OpCallerBinary<PowOp<ctype, ctype>, BCAST101x4_VEC> {
} }
}; };
template <typename Op> template <typename src_ctype, size_t channel_block_dim>
struct OpCallerBinary<Op, BCAST101x4_VEC> { struct OpCallerBinaryBcast101xXVec {
static void run(const typename Op::src_ctype* src0, template <typename Op>
const typename Op::src_ctype* src1, static void run(const src_ctype* src0, const src_ctype* src1,
typename Op::dst_ctype* dst, DType src0_dtype, typename Op::dst_ctype* dst, const Op& op, size_t batch,
DType src1_dtype, DType dst_dtype, size_t batch, size_t nr_channel_blocks, size_t channel_stride) {
size_t nr_channel_blocks, size_t channel_stride, for (size_t b = 0; b < batch; b++) {
size_t channel_block_dim) { auto src0_ptr = src0;
megdnn_assert(channel_block_dim == 4, "only imp for nchw44"); for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
Op op(src0_dtype, src1_dtype, dst_dtype); auto src0_block_ptr = src0_ptr + cb * channel_block_dim;
ParamElemVisitorBcast101x4<typename Op::src_ctype> vis0; for (size_t img_index = 0; img_index < channel_stride;
ParamElemVisitor<typename Op::src_ctype> vis1; img_index++) {
for (size_t c_iter = 0; c_iter < channel_block_dim;
c_iter++) {
op(*(src0_block_ptr + c_iter), *src1, dst);
src1++;
dst++;
}
}
}
}
}
};
template <typename src_ctype, size_t channel_block_dim>
struct OpCallerBinaryBcast101xDVec {
template <typename Op, typename Vis0, typename Vis1>
static void run(const src_ctype* src0, const src_ctype* src1,
typename Op::dst_ctype* dst, const Op& op, const Vis0& vis0,
const Vis1& vis1, size_t batch, size_t nr_channel_blocks,
size_t channel_stride) {
for (size_t b = 0; b < batch; b++) { for (size_t b = 0; b < batch; b++) {
auto src0_ptr = src0; auto src0_ptr = src0;
for (size_t cb = 0; cb < nr_channel_blocks; cb++) { for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
...@@ -400,8 +434,63 @@ struct OpCallerBinary<Op, BCAST101x4_VEC> { ...@@ -400,8 +434,63 @@ struct OpCallerBinary<Op, BCAST101x4_VEC> {
} }
}; };
template <typename src_ctype>
struct OpCallerBinaryBcast101xXVec<src_ctype, 4> {
template <typename Op>
static void run(const src_ctype* src0, const src_ctype* src1,
typename Op::dst_ctype* dst, const Op& op, size_t batch,
size_t nr_channel_blocks, size_t channel_stride) {
ParamElemVisitorBcast101x4<typename Op::src_ctype> vis0;
ParamElemVisitor<typename Op::src_ctype> vis1;
OpCallerBinaryBcast101xDVec<src_ctype, 4>::run(
src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks,
channel_stride);
}
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <>
struct OpCallerBinaryBcast101xXVec<__fp16, 8> {
using src_ctype = __fp16;
template <typename Op>
static void run(const src_ctype* src0, const src_ctype* src1,
typename Op::dst_ctype* dst, const Op& op, size_t batch,
size_t nr_channel_blocks, size_t channel_stride) {
ParamElemVisitorBcast101x8<src_ctype> vis0;
ParamElemVisitor<src_ctype> vis1;
OpCallerBinaryBcast101xDVec<src_ctype, 8>::run(
src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks,
channel_stride);
}
};
#endif
template <typename Op>
struct OpCallerBinary<Op, BCAST101xX_VEC> {
static void run(const typename Op::src_ctype* src0,
const typename Op::src_ctype* src1,
typename Op::dst_ctype* dst, DType src0_dtype,
DType src1_dtype, DType dst_dtype, size_t batch,
size_t nr_channel_blocks, size_t channel_stride,
size_t channel_block_dim) {
megdnn_assert(channel_block_dim == 4 || channel_block_dim == 8,
"only imp for nchw44/nchw88");
Op op(src0_dtype, src1_dtype, dst_dtype);
if (channel_block_dim == 4) {
OpCallerBinaryBcast101xXVec<typename Op::src_ctype, 4>::run(
src0, src1, dst, op, batch, nr_channel_blocks,
channel_stride);
} else {
OpCallerBinaryBcast101xXVec<typename Op::src_ctype, 8>::run(
src0, src1, dst, op, batch, nr_channel_blocks,
channel_stride);
}
}
};
template <typename ctype> template <typename ctype>
struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST101x4> { struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST101xX> {
using Op = PowOp<ctype, ctype>; using Op = PowOp<ctype, ctype>;
static void run(const typename Op::src_ctype* src0, static void run(const typename Op::src_ctype* src0,
const typename Op::src_ctype* src1, const typename Op::src_ctype* src1,
...@@ -427,18 +516,37 @@ struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST101x4> { ...@@ -427,18 +516,37 @@ struct OpCallerBinary<PowOp<ctype, ctype>, VEC_BCAST101x4> {
} }
}; };
template <typename Op> template <typename src_ctype, size_t channel_block_dim>
struct OpCallerBinary<Op, VEC_BCAST101x4> { struct OpCallerBinaryVecBcast101xX {
static void run(const typename Op::src_ctype* src0, template <typename Op>
const typename Op::src_ctype* src1, static void run(const src_ctype* src0, const src_ctype* src1,
typename Op::dst_ctype* dst, DType src0_dtype, typename Op::dst_ctype* dst, const Op& op, size_t batch,
DType src1_dtype, DType dst_dtype, size_t batch, size_t nr_channel_blocks, size_t channel_stride) {
size_t nr_channel_blocks, size_t channel_stride, for (size_t b = 0; b < batch; b++) {
size_t channel_block_dim) { auto src1_ptr = src1;
megdnn_assert(channel_block_dim == 4, "only imp for nchw44"); for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
Op op(src0_dtype, src1_dtype, dst_dtype); auto src1_block_ptr = src1_ptr + cb * channel_block_dim;
ParamElemVisitor<typename Op::src_ctype> vis0; for (size_t img_index = 0; img_index < channel_stride;
ParamElemVisitorBcast101x4<typename Op::src_ctype> vis1; img_index++) {
for (size_t c_iter = 0; c_iter < channel_block_dim;
c_iter++) {
op(*src0, *(src1_block_ptr + c_iter), dst);
src0++;
dst++;
}
}
}
}
}
};
template <typename src_ctype, size_t channel_block_dim>
struct OpCallerBinaryVecBcast101xD {
template <typename Op, typename Vis0, typename Vis1>
static void run(const src_ctype* src0, const src_ctype* src1,
typename Op::dst_ctype* dst, const Op& op, const Vis0& vis0,
const Vis1& vis1, size_t batch, size_t nr_channel_blocks,
size_t channel_stride) {
for (size_t b = 0; b < batch; b++) { for (size_t b = 0; b < batch; b++) {
auto src1_ptr = src1; auto src1_ptr = src1;
for (size_t cb = 0; cb < nr_channel_blocks; cb++) { for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
...@@ -467,6 +575,60 @@ struct OpCallerBinary<Op, VEC_BCAST101x4> { ...@@ -467,6 +575,60 @@ struct OpCallerBinary<Op, VEC_BCAST101x4> {
} }
}; };
template <typename src_ctype>
struct OpCallerBinaryVecBcast101xX<src_ctype, 4> {
template <typename Op>
static void run(const src_ctype* src0, const src_ctype* src1,
typename Op::dst_ctype* dst, const Op& op, size_t batch,
size_t nr_channel_blocks, size_t channel_stride) {
ParamElemVisitor<src_ctype> vis0;
ParamElemVisitorBcast101x4<src_ctype> vis1;
OpCallerBinaryVecBcast101xD<src_ctype, 4>::run(
src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks,
channel_stride);
}
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <>
struct OpCallerBinaryVecBcast101xX<__fp16, 8> {
using src_ctype = __fp16;
template <typename Op>
static void run(const src_ctype* src0, const src_ctype* src1,
typename Op::dst_ctype* dst, const Op& op, size_t batch,
size_t nr_channel_blocks, size_t channel_stride) {
ParamElemVisitor<src_ctype> vis0;
ParamElemVisitorBcast101x8<src_ctype> vis1;
OpCallerBinaryVecBcast101xD<src_ctype, 8>::run(
src0, src1, dst, op, vis0, vis1, batch, nr_channel_blocks,
channel_stride);
}
};
#endif
template <typename Op>
struct OpCallerBinary<Op, VEC_BCAST101xX> {
static void run(const typename Op::src_ctype* src0,
const typename Op::src_ctype* src1,
typename Op::dst_ctype* dst, DType src0_dtype,
DType src1_dtype, DType dst_dtype, size_t batch,
size_t nr_channel_blocks, size_t channel_stride,
size_t channel_block_dim) {
megdnn_assert(channel_block_dim == 4 || channel_block_dim == 8,
"only imp for nchw44/nchw88");
Op op(src0_dtype, src1_dtype, dst_dtype);
if (channel_block_dim == 4) {
OpCallerBinaryVecBcast101xX<typename Op::src_ctype, 4>::run(
src0, src1, dst, op, batch, nr_channel_blocks,
channel_stride);
} else {
OpCallerBinaryVecBcast101xX<typename Op::src_ctype, 8>::run(
src0, src1, dst, op, batch, nr_channel_blocks,
channel_stride);
}
}
};
template <typename Op> template <typename Op>
struct OpCallerBinary<Op, VEC_SCALAR> { struct OpCallerBinary<Op, VEC_SCALAR> {
static void run(const typename Op::src_ctype* src0, static void run(const typename Op::src_ctype* src0,
...@@ -683,21 +845,42 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> { ...@@ -683,21 +845,42 @@ struct OpCallerTernary<Op, BCAST101_VEC_BCAST101> {
} }
}; };
//! src0: CHW44, src1: vector, src2: CHW44 template <typename src_ctype, size_t channel_block_dim>
template <typename Op> struct OpCallerTernaryBcast101xXVecBcast101xX {
struct OpCallerTernary<Op, BCAST101x4_VEC_BCAST101x4> { template <typename Op>
static void run(const typename Op::src_ctype* src0, static void run(const src_ctype* src0, const src_ctype* src1,
const typename Op::src_ctype* src1, const src_ctype* src2, typename Op::dst_ctype* dst,
const typename Op::src_ctype* src2, const Op& op, size_t batch, size_t nr_channel_blocks,
typename Op::dst_ctype* dst, DType src0_dtype, size_t channel_stride) {
DType src1_dtype, DType src2_dtype, DType dst_dtype, for (size_t b = 0; b < batch; b++) {
size_t batch, size_t nr_channel_blocks, auto src0_ptr = src0;
size_t channel_stride, size_t channel_block_dim) { auto src2_ptr = src2;
megdnn_assert(channel_block_dim == 4, "only imp for nchw44"); for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype); auto src0_block_ptr = src0_ptr + cb * channel_block_dim;
ParamElemVisitorBcast101x4<typename Op::src_ctype> vis0; auto src2_block_ptr = src2_ptr + cb * channel_block_dim;
ParamElemVisitor<typename Op::src_ctype> vis1; for (size_t img_index = 0; img_index < channel_stride;
ParamElemVisitorBcast101x4<typename Op::src_ctype> vis2; img_index++) {
for (size_t c_iter = 0; c_iter < channel_block_dim;
c_iter++) {
op(*(src0_block_ptr + c_iter), *src1,
*(src2_block_ptr + c_iter), dst);
src1++;
dst++;
}
}
}
}
}
};
template <typename src_ctype, size_t channel_block_dim>
struct OpCallerTernaryBcast101xDVecBcast101xD {
template <typename Op, typename Vis0, typename Vis1, typename Vis2>
static void run(const src_ctype* src0, const src_ctype* src1,
const src_ctype* src2, typename Op::dst_ctype* dst,
const Op& op, const Vis0& vis0, const Vis1& vis1,
const Vis2& vis2, size_t batch, size_t nr_channel_blocks,
size_t channel_stride) {
for (size_t b = 0; b < batch; b++) { for (size_t b = 0; b < batch; b++) {
auto src0_ptr = src0; auto src0_ptr = src0;
auto src2_ptr = src2; auto src2_ptr = src2;
...@@ -731,6 +914,70 @@ struct OpCallerTernary<Op, BCAST101x4_VEC_BCAST101x4> { ...@@ -731,6 +914,70 @@ struct OpCallerTernary<Op, BCAST101x4_VEC_BCAST101x4> {
} }
}; };
//! src0: CHW44, src1: vector, src2: CHW44
template <typename src_ctype>
struct OpCallerTernaryBcast101xXVecBcast101xX<src_ctype, 4> {
template <typename Op>
static void run(const src_ctype* src0, const src_ctype* src1,
const src_ctype* src2, typename Op::dst_ctype* dst,
const Op& op, size_t batch, size_t nr_channel_blocks,
size_t channel_stride) {
ParamElemVisitorBcast101x4<src_ctype> vis0;
ParamElemVisitor<src_ctype> vis1;
ParamElemVisitorBcast101x4<src_ctype> vis2;
OpCallerTernaryBcast101xDVecBcast101xD<src_ctype, 4>::run(
src0, src1, src2, dst, op, vis0, vis1, vis2, batch,
nr_channel_blocks, channel_stride);
}
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <>
struct OpCallerTernaryBcast101xXVecBcast101xX<__fp16, 8> {
using src_ctype = __fp16;
template <typename Op>
static void run(const src_ctype* src0, const src_ctype* src1,
const src_ctype* src2, typename Op::dst_ctype* dst,
const Op& op, size_t batch, size_t nr_channel_blocks,
size_t channel_stride) {
ParamElemVisitorBcast101x8<src_ctype> vis0;
ParamElemVisitor<src_ctype> vis1;
ParamElemVisitorBcast101x8<src_ctype> vis2;
OpCallerTernaryBcast101xDVecBcast101xD<src_ctype, 8>::run(
src0, src1, src2, dst, op, vis0, vis1, vis2, batch,
nr_channel_blocks, channel_stride);
}
};
#endif
template <typename Op>
struct OpCallerTernary<Op, BCAST101xX_VEC_BCAST101xX> {
static void run(const typename Op::src_ctype* src0,
const typename Op::src_ctype* src1,
const typename Op::src_ctype* src2,
typename Op::dst_ctype* dst, DType src0_dtype,
DType src1_dtype, DType src2_dtype, DType dst_dtype,
size_t batch, size_t nr_channel_blocks,
size_t channel_stride, size_t channel_block_dim) {
megdnn_assert(channel_block_dim == 4 || channel_block_dim == 8,
"only imp for nchw44/nchw88");
Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
if (channel_block_dim == 4) {
OpCallerTernaryBcast101xXVecBcast101xX<typename Op::src_ctype,
4>::run(src0, src1, src2,
dst, op, batch,
nr_channel_blocks,
channel_stride);
} else {
OpCallerTernaryBcast101xXVecBcast101xX<typename Op::src_ctype,
8>::run(src0, src1, src2,
dst, op, batch,
nr_channel_blocks,
channel_stride);
}
}
};
//! src1: 1C11, src0 and src2 are contig //! src1: 1C11, src0 and src2 are contig
template <typename Op> template <typename Op>
struct OpCallerTernary<Op, VEC_BCAST101_VEC> { struct OpCallerTernary<Op, VEC_BCAST101_VEC> {
...@@ -775,21 +1022,41 @@ struct OpCallerTernary<Op, VEC_BCAST101_VEC> { ...@@ -775,21 +1022,41 @@ struct OpCallerTernary<Op, VEC_BCAST101_VEC> {
} }
}; };
template <typename src_ctype, size_t channel_block_dim>
struct OpCallerTernaryVecBcast101xXVec {
template <typename Op>
static void run(const src_ctype* src0, const src_ctype* src1,
const src_ctype* src2, typename Op::dst_ctype* dst,
const Op& op, size_t batch, size_t nr_channel_blocks,
size_t channel_stride) {
for (size_t b = 0; b < batch; b++) {
auto src1_ptr = src1;
for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
auto src1_block_ptr = src1_ptr + cb * channel_block_dim;
for (size_t img_index = 0; img_index < channel_stride;
img_index++) {
for (size_t c_iter = 0; c_iter < channel_block_dim;
c_iter++) {
op(*src0, *(src1_block_ptr + c_iter), *src2, dst);
src0++;
src2++;
dst++;
}
}
}
}
}
};
//! src1: CHW44, src0 and src2 are contig //! src1: CHW44, src0 and src2 are contig
template <typename Op> template <typename src_ctype, size_t channel_block_dim>
struct OpCallerTernary<Op, VEC_BCAST101x4_VEC> { struct OpCallerTernaryVecBcast101xDVec {
static void run(const typename Op::src_ctype* src0, template <typename Op, typename Vis0, typename Vis1, typename Vis2>
const typename Op::src_ctype* src1, static void run(const src_ctype* src0, const src_ctype* src1,
const typename Op::src_ctype* src2, const src_ctype* src2, typename Op::dst_ctype* dst,
typename Op::dst_ctype* dst, DType src0_dtype, const Op& op, const Vis0& vis0, const Vis1& vis1,
DType src1_dtype, DType src2_dtype, DType dst_dtype, const Vis2& vis2, size_t batch, size_t nr_channel_blocks,
size_t batch, size_t nr_channel_blocks, size_t channel_stride) {
size_t channel_stride, size_t channel_block_dim) {
megdnn_assert(channel_block_dim == 4, "only imp for nchw44");
Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
ParamElemVisitor<typename Op::src_ctype> vis0;
ParamElemVisitorBcast101x4<typename Op::src_ctype> vis1;
ParamElemVisitor<typename Op::src_ctype> vis2;
for (size_t b = 0; b < batch; b++) { for (size_t b = 0; b < batch; b++) {
auto src1_ptr = src1; auto src1_ptr = src1;
for (size_t cb = 0; cb < nr_channel_blocks; cb++) { for (size_t cb = 0; cb < nr_channel_blocks; cb++) {
...@@ -821,6 +1088,66 @@ struct OpCallerTernary<Op, VEC_BCAST101x4_VEC> { ...@@ -821,6 +1088,66 @@ struct OpCallerTernary<Op, VEC_BCAST101x4_VEC> {
} }
}; };
template <typename src_ctype>
struct OpCallerTernaryVecBcast101xXVec<src_ctype, 4> {
template <typename Op>
static void run(const src_ctype* src0, const src_ctype* src1,
const src_ctype* src2, typename Op::dst_ctype* dst,
const Op& op, size_t batch, size_t nr_channel_blocks,
size_t channel_stride) {
ParamElemVisitor<src_ctype> vis0;
ParamElemVisitorBcast101x4<src_ctype> vis1;
ParamElemVisitor<src_ctype> vis2;
OpCallerTernaryVecBcast101xDVec<src_ctype, 4>::run(
src0, src1, src2, dst, op, vis0, vis1, vis2, batch,
nr_channel_blocks, channel_stride);
}
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <>
struct OpCallerTernaryVecBcast101xXVec<__fp16, 8> {
using src_ctype = __fp16;
template <typename Op>
static void run(const src_ctype* src0, const src_ctype* src1,
const src_ctype* src2, typename Op::dst_ctype* dst,
const Op& op, size_t batch, size_t nr_channel_blocks,
size_t channel_stride) {
ParamElemVisitor<src_ctype> vis0;
ParamElemVisitorBcast101x8<src_ctype> vis1;
ParamElemVisitor<src_ctype> vis2;
OpCallerTernaryVecBcast101xDVec<src_ctype, 8>::run(
src0, src1, src2, dst, op, vis0, vis1, vis2, batch,
nr_channel_blocks, channel_stride);
}
};
#endif
template <typename Op>
struct OpCallerTernary<Op, VEC_BCAST101xX_VEC> {
static void run(const typename Op::src_ctype* src0,
const typename Op::src_ctype* src1,
const typename Op::src_ctype* src2,
typename Op::dst_ctype* dst, DType src0_dtype,
DType src1_dtype, DType src2_dtype, DType dst_dtype,
size_t batch, size_t nr_channel_blocks,
size_t channel_stride, size_t channel_block_dim) {
megdnn_assert(channel_block_dim == 4 || channel_block_dim == 8,
"only imp for nchw44/nchw88");
Op op(src0_dtype, src1_dtype, src2_dtype, dst_dtype);
if (channel_block_dim == 4) {
OpCallerTernaryVecBcast101xXVec<typename Op::src_ctype, 4>::run(
src0, src1, src2, dst, op, batch, nr_channel_blocks,
channel_stride);
} else {
OpCallerTernaryVecBcast101xXVec<typename Op::src_ctype, 8>::run(
src0, src1, src2, dst, op, batch, nr_channel_blocks,
channel_stride);
}
}
};
//! src1: scalar, src0 and src2 has the same shape //! src1: scalar, src0 and src2 has the same shape
template <typename Op> template <typename Op>
struct OpCallerTernary<Op, VEC_SCALAR_VEC> { struct OpCallerTernary<Op, VEC_SCALAR_VEC> {
......
...@@ -53,6 +53,20 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY) { ...@@ -53,6 +53,20 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_TERNARY) {
checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}}); checker.execs({{3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {3, 4, 5, 7, 4}, {}});
checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}}); checker.execs({{1, 2, 5, 7, 4}, {1, 2, 1, 1, 4}, {1, 2, 5, 7, 4}, {}});
//! nchw88
checker.execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
checker.execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
checker.execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
checker.execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
//! nchw88
checker.execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
checker.execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
checker.execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
checker.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
checker.execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});
checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}}); checker.execs({{3, 4, 7}, {3, 4, 7}, {3, 4, 7}, {}});
checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}}); checker.execs({{1, 4, 1, 1}, {3, 4, 5, 7}, {1, 4, 1, 1}, {}});
checker.execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}}); checker.execs({{1, 4, 1}, {3, 4, 7}, {1, 4, 1}, {}});
...@@ -227,6 +241,78 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_FP32) { ...@@ -227,6 +241,78 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_FP32) {
run(Mode::POW); run(Mode::POW);
} }
TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW88_FP) {
using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle());
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
checker.set_param(Mode::FUSE_ADD_RELU)
.execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
auto run = [&](Mode mode) {
// VEC_BCAST101x
checker.set_param(mode).execs({{1, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
checker.set_param(mode).execs({{2, 3, 2, 2, 8}, {1, 3, 1, 1, 8}, {}});
checker.set_param(mode).execs({{3, 8, 5, 3, 8}, {1, 8, 1, 1, 8}, {}});
checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
checker.set_param(mode).execs({{1, 2, 5, 7, 8}, {1, 2, 1, 1, 8}, {}});
// BCAST101x_VEC not powOp
checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {1, 3, 2, 2, 8}, {}});
checker.set_param(mode).execs({{1, 3, 1, 1, 8}, {2, 3, 2, 2, 8}, {}});
checker.set_param(mode).execs({{1, 8, 1, 1, 8}, {3, 8, 5, 3, 8}, {}});
checker.set_param(mode).execs({{3, 4, 5, 7, 8}, {3, 4, 5, 7, 8}, {}});
checker.set_param(mode).execs({{1, 2, 1, 1, 8}, {1, 2, 5, 7, 8}, {}});
};
auto run_all = [&]() {
run(Mode::ADD);
run(Mode::FUSE_ADD_H_SWISH);
run(Mode::FUSE_ADD_RELU);
run(Mode::MAX);
run(Mode::MIN);
run(Mode::MUL);
run(Mode::SUB);
run(Mode::TRUE_DIV);
run(Mode::POW);
};
{
UniformFloatRNG rng(1e-5, 7e1);
checker.set_rng(0, &rng);
checker.set_epsilon(1e-5);
checker.set_dtype(0, dtype::Float32());
checker.set_dtype(1, dtype::Float32());
run_all();
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
{
UniformFloatRNG rng(1, 2);
checker.set_rng(0, &rng);
checker.set_epsilon(3e-3);
checker.set_dtype(0, dtype::Float16());
checker.set_dtype(1, dtype::Float16());
run_all();
}
#endif
}
#if MEGDNN_WITH_BENCHMARK #if MEGDNN_WITH_BENCHMARK
namespace { namespace {
void run_elemwise_benchmark(const TensorShapeArray& shapes, void run_elemwise_benchmark(const TensorShapeArray& shapes,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册