From c59be192cdf5bb2940f3b44271d716497aa15f9f Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 17 Apr 2020 15:13:17 +0800 Subject: [PATCH] feat(dnn/arm_common/elemwise): add arm_common support chw44 elemwise GitOrigin-RevId: aba44e01233107168b1a2c5e8dbc0c0ef1e71687 --- dnn/src/common/elemwise/opr_impl_helper.cpp | 16 ++++++++++++---- dnn/src/common/elemwise/opr_impl_helper.h | 4 +++- dnn/src/x86/elemwise/opr_impl.cpp | 10 ++++++---- dnn/src/x86/elemwise_op.h | 8 ++++---- 4 files changed, 25 insertions(+), 13 deletions(-) diff --git a/dnn/src/common/elemwise/opr_impl_helper.cpp b/dnn/src/common/elemwise/opr_impl_helper.cpp index 04a9de1f0..bb4ca3438 100644 --- a/dnn/src/common/elemwise/opr_impl_helper.cpp +++ b/dnn/src/common/elemwise/opr_impl_helper.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include "./opr_impl_helper.h" @@ -79,18 +80,19 @@ bool ElemwiseLayoutHelper::is_broadcasted_scalar(const TensorLayout& layout) { } return true; } +template bool ElemwiseLayoutHelper::is_broadcastedx_channel_like( const TensorLayout& layout, BroadcastChannelInfo& info) { if (layout.format.type() == TensorFormat::Type::DEFAULT && - layout.ndim == 3 && layout.stride[0] == 8 && layout.stride[1] == 0 && - layout.stride[2] == 1) { + layout.ndim == 3 && layout.stride[0] == slice_size && + layout.stride[1] == 0 && layout.stride[2] == 1) { info.x = layout.shape[0]; info.y = layout.shape[1]; info.z = layout.shape[2]; return true; } else if (layout.format.type() == TensorFormat::Type::DEFAULT && layout.ndim == 4 && layout.stride[0] == 0 && - layout.stride[1] == 8 && layout.stride[2] == 0 && + layout.stride[1] == slice_size && layout.stride[2] == 0 && layout.stride[3] == 1) { info.x = layout.shape[1]; info.y = layout.shape[2]; @@ -99,6 +101,12 @@ bool ElemwiseLayoutHelper::is_broadcastedx_channel_like( } return false; } +#define INST(n) \ + template bool ElemwiseLayoutHelper::is_broadcastedx_channel_like( \ + const TensorLayout& layout, BroadcastChannelInfo& info) +INST(4); +INST(8); +#undef INST bool ElemwiseLayoutHelper::is_broadcasted_channel_like( const TensorLayout& layout, BroadcastChannelInfo& info) { diff --git a/dnn/src/common/elemwise/opr_impl_helper.h b/dnn/src/common/elemwise/opr_impl_helper.h index eb31983fc..e3e474957 100644 --- a/dnn/src/common/elemwise/opr_impl_helper.h +++ b/dnn/src/common/elemwise/opr_impl_helper.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once @@ -87,6 +88,7 @@ public: * Note that Input can also be 3-dimensional, and must be [x, 1, z] * broadacsted into [x, y, z] */ + template static bool is_broadcastedx_channel_like(const TensorLayout& layout, BroadcastChannelInfo& info); }; diff --git a/dnn/src/x86/elemwise/opr_impl.cpp b/dnn/src/x86/elemwise/opr_impl.cpp index 60eac1be9..28e9a189c 100644 --- a/dnn/src/x86/elemwise/opr_impl.cpp +++ b/dnn/src/x86/elemwise/opr_impl.cpp @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include "src/x86/elemwise/opr_impl.h" #include "src/x86/elemwise_op.h" @@ -360,13 +361,14 @@ bool ElemwiseImpl::exec_binary() { return true; \ } { - bool normal_case = is_vector(src1.layout) && - is_broadcastedx_channel_like(src0.layout, binfo); + bool normal_case = + is_vector(src1.layout) && + is_broadcastedx_channel_like<8>(src0.layout, binfo); bool swap_case = false; bool commutable = mode_trait().commutable; if (!normal_case && commutable) { swap_case = is_vector(src0.layout) && - is_broadcastedx_channel_like(src1.layout, binfo); + is_broadcastedx_channel_like<8>(src1.layout, binfo); } if ((swap_case || normal_case) && diff --git a/dnn/src/x86/elemwise_op.h b/dnn/src/x86/elemwise_op.h index a4a883054..39c59cce3 100644 --- a/dnn/src/x86/elemwise_op.h +++ b/dnn/src/x86/elemwise_op.h @@ -414,7 +414,7 @@ struct OpCallerBinary { const typename Op::src_ctype* src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t batch, - size_t nr_blocks_in_channel, size_t channel_stride, + size_t nr_channel_blocks, size_t channel_stride, size_t channel_block_dim) { megdnn_assert(channel_block_dim == 8, "avx2 only support nchw88"); Op op(src0_dtype, src1_dtype, dst_dtype); @@ -422,7 +422,7 @@ struct OpCallerBinary { ParamElemVisitor vis1; for (size_t b = 0; b < batch; b++) { auto src0_ptr = src0; - for (size_t cb = 0; cb < nr_blocks_in_channel; cb++) { + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { auto src0_block_ptr = src0_ptr + cb * channel_block_dim; auto channel_block_vec = vis0(src0_block_ptr); size_t img_index = 0; @@ -451,12 +451,12 @@ struct OpCallerBinary { const typename Op::src_ctype* src1, typename Op::dst_ctype* dst, DType src0_dtype, DType src1_dtype, DType dst_dtype, size_t batch, - size_t nr_blocks_in_channel, size_t channel_stride, + size_t nr_channel_blocks, size_t channel_stride, size_t channel_block_dim) { Op op(src0_dtype, src1_dtype, dst_dtype); for (size_t b = 0; b < batch; b++) { auto src0_ptr = src0; - for (size_t cb = 0; cb < nr_blocks_in_channel; cb++) { + for (size_t cb = 0; cb < nr_channel_blocks; cb++) { auto src0_block_ptr = src0_ptr + cb * channel_block_dim; for (size_t i = 0; i < channel_stride; i++) { for (size_t c_iter = 0; c_iter < channel_block_dim; -- GitLab