diff --git a/dnn/src/arm_common/conv_bias/fp32/algos.cpp b/dnn/src/arm_common/conv_bias/fp32/algos.cpp index 20f11b81009ecc9350ae2279b3f26ed8a946f1ee..fef97afec66c25c6a7415f4565a9b7a8a2f9b166 100644 --- a/dnn/src/arm_common/conv_bias/fp32/algos.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/algos.cpp @@ -331,6 +331,51 @@ MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP32WinogradF63_4x4_NCHW44, megdnn_arm_common_winograd_fp32, param::MatrixMul::Format::MK4); +/* =================== AlgoFP32WinogradF73_4x4_NCHW44 ===================== */ + +bool ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44::usable( + const NCBKernSizeParam& param, + AlgoSelectionStrategy /*algo_selection_strategy*/) const { + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, + midout_iv("AlgoFP32WinogradF73_4x4_NCHW44"_hash)) { + if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) + return false; + using Strategy = winograd::winograd_F73_mk4_f_nchw44; + Strategy strategy(param.src_type, param.filter_type, param.dst_type); + auto&& matmul_param = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param) + .get_matmul_kern_param(param); + return m_matmul_algo->usable(matmul_param) && + m_matmul_algo->packmode() == + fallback::MatrixMulImpl::AlgoBase::PackMode::NO_PACK && + (param.filter_meta.format == param::ConvBias::Format::NCHW44 || + (param.filter_meta.format == + param::ConvBias::Format::NCHW44_WINOGRAD && + param.output_block_size == 7 && + param.winograd_matmul_format == + param::MatrixMul::Format::MK4)) && + !param.filter_meta.should_flip && + (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && + param.filter_meta.spatial[0] == 3) && + (param.filter_meta.stride[0] == param.filter_meta.stride[1] && + param.filter_meta.stride[0] == 1) && + (param.filter_meta.dilation[0] == + param.filter_meta.dilation[1] && + param.filter_meta.dilation[0] == 1) && + param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && + param.src_type.enumv() == DTypeEnum::Float32; + } + MIDOUT_END(); + return false; +} + +MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP32WinogradF73_4x4_NCHW44, + winograd::winograd_F73_mk4_f_nchw44, + megdnn_arm_common_winograd_fp32, + param::MatrixMul::Format::MK4); + /* ===================== direct algo ===================== */ MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_kimpl); diff --git a/dnn/src/arm_common/conv_bias/fp32/algos.h b/dnn/src/arm_common/conv_bias/fp32/algos.h index c04d009def0427aa06146de46cac7b461f520b40..9cf5fb99fa5851c6dad0e090e09866f93b65baf7 100644 --- a/dnn/src/arm_common/conv_bias/fp32/algos.h +++ b/dnn/src/arm_common/conv_bias/fp32/algos.h @@ -124,6 +124,22 @@ public: } MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); }; + +class ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44 final : public AlgoBase { +public: + AlgoFP32WinogradF73_4x4_NCHW44( + fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size) + : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} + const char* name() const override { + if (m_name.empty()) { + m_name = ConvBiasImpl::algo_name( + m_matmul_algo->name(), {4, 7, m_tile_size}, + param::ConvBias::Format::NCHW44); + } + return m_name.c_str(); + } + MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); +}; // ================================================================= // class ConvBiasImpl::AlgoF32Direct final : public AlgoBase { diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy.h b/dnn/src/arm_common/conv_bias/fp32/strategy.h index 1f68a3e741c41a7f8e2b74607e16c3737d92e4ee..4b7fc9c5919f9224e588dd4aafe1ddddb211fd62 100644 --- 
a/dnn/src/arm_common/conv_bias/fp32/strategy.h +++ b/dnn/src/arm_common/conv_bias/fp32/strategy.h @@ -38,6 +38,9 @@ MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 2, 3, 4, 4, MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 4, 4, winograd_F63_mk4_f_nchw44) + +MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 7, 3, 4, 4, + winograd_F73_mk4_f_nchw44) } // namespace winograd } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp b/dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp index b861a23dcdfb82662bcdca04c0277ad3ceda5dcd..a435d727842badb4c1826df8316efff627035d4d 100644 --- a/dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp @@ -1,6 +1,5 @@ /** - * \file dnn/src/arm_common/conv_bias/fp32/strategy_f - * 63_mk4_nchw44.cpp + * \file dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_f73_mk4_nchw44.cpp b/dnn/src/arm_common/conv_bias/fp32/strategy_f73_mk4_nchw44.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a988a3be96888e5d98aeb9cd05483153575977b6 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/strategy_f73_mk4_nchw44.cpp @@ -0,0 +1,587 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/strategy_f73_mk4_nchw44.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */ + +#include "src/arm_common/conv_bias/fp32/filter_transform.h" +#include "src/arm_common/conv_bias/fp32/helper.h" +#include "src/arm_common/conv_bias/fp32/strategy.h" +#include "src/arm_common/elemwise_helper/op_unary.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/arm_common/utils.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/common/winograd/winograd_helper.h" +#include "src/fallback/conv_bias/winograd/winograd.h" + +#include "midout.h" +MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F73_mk4) + +using namespace megdnn; +using namespace arm_common; + +namespace { + +constexpr size_t alpha = 7 + 3 - 1; +constexpr size_t pack_size = 4; +constexpr float input_parameters[28] = { + 1.5f, 0.75f, 3.0f, 7.875f, 0.5f, 2.5f, 0.125f, + 0.875f, 4.0f, 8.0f, 5.25f, 7.375f, 5.375f, 3.5f, + 7.75f, 0.25f, 2.125f, 10.625f, 0.625f, 4.375f, 5.0f, + 10.0f, 5.75f, 2.75f, 4.25f, 1.75f, 2.0f, 0.0f}; + +struct InputTransformF73_NCHW44 { + template + static void prepare(const float* input, float* patch, float* patchT, + int ih_start, int iw_start, size_t IH, size_t IW, + size_t ic, size_t IC) { + MEGDNN_MARK_USED_VAR(patch); + size_t IW4 = IW * pack_size; + size_t iw4_start = iw_start * pack_size; + size_t icb = ic / pack_size; + if (!(inner && ic + pack_size < IC)) { + memset(patchT, 0, sizeof(float) * pack_size * alpha * alpha); + } + if (inner) { + const float* input_ptr = + input + icb * IH * IW4 + ih_start * IW4 + iw4_start; + for (size_t ih = 0; ih < alpha; ih++) { +#define cb(i) auto v##i = vld1q_f32(input_ptr + pack_size * i); + UNROLL_CALL_NOWRAPPER(9, cb); +#undef cb + +#define cb(i) vst1q_f32(patchT + ih * pack_size * alpha + i * pack_size, v##i); + UNROLL_CALL_NOWRAPPER(9, cb); +#undef cb + input_ptr += IW4; + } + } else { + int ih0_act = std::max(ih_start, 0), + ih1_act = std::min(ih_start + alpha, IH), + iw0_act = std::max(iw_start, 0), + iw1_act = std::min(iw_start + alpha, IW); + const float* input_ptr = input + icb * IH * IW4; + // partial copy + for (int ih = ih0_act; ih < ih1_act; ++ih) { + for (int iw = iw0_act; iw < iw1_act; ++iw) { + size_t iho = ih - ih_start, iwo = iw - iw_start; + auto src = vld1q_f32(input_ptr + ih * IW4 + iw * pack_size); + vst1q_f32( + patchT + iho * pack_size * alpha + iwo * pack_size, + src); + } + } + } + } + + static void transform(const float* patchT, float* input_transform_buf, + size_t unit_idx, size_t nr_units_in_tile, size_t ic, + size_t IC) { + // BT * d * B + + size_t ICB = IC / pack_size; + size_t icb = ic / pack_size; + + float32x4_t d0, d1, d2, d3, d4, d5, d6, d7, d8; + float32x4_t v0 = vld1q_f32(input_parameters + 0); + float32x4_t v1 = vld1q_f32(input_parameters + 4); + float32x4_t v2 = vld1q_f32(input_parameters + 8); + float32x4_t v3 = vld1q_f32(input_parameters + 12); + float32x4_t v4 = vld1q_f32(input_parameters + 16); + float32x4_t v5 = vld1q_f32(input_parameters + 20); + float32x4_t v6 = vld1q_f32(input_parameters + 24); + + //! B + //! 1.5 0 0 0 0 0 0 0 0 + //! -1 -1.5 1.5 -0.75 0.75 -3 3 -1 1.5 + //! -7.875 -0.5 -2.5 0.125 -0.875 -4 -8 0 -1 + //! 5.25 7.375 -5.375 4 -3.5 7.75 0.25 5.25 -7.875 + //! 7.875 2.125 10.625 -0.625 4.375 5 10 0 5.25 + //! -5.25 -5.75 -2.75 -4.25 1.75 -5.75 -4.25 -5.25 7.875 + //! -1.5 -0.5 -2.5 0.5 -3.5 -1 -2 0 -5.25 + //! 1 1 1 1 1 1 1 1 -1.5 + //! 
0 0 0 0 0 0 0 0 1 + + // 1.5f, 0.75f, 3.0f, 7.875f, v0 + // 0.5f, 2.5f, 0.125f, 0.875f, v1 + // 4.0f, 8.0f, 5.25f, 7.375f, v2 + // 5.375f, 3.5f, 7.75f, 0.25f, v3 + // 2.125f, 10.625f, 0.625f, 4.375f, v4 + // 5.0f, 10.0f, 5.75f, 2.75f, v5 + // 4.25f, 1.75f, 2.0f, 0.0f, v6 + +#define cb(i) \ + d0 = vld1q_f32(patchT + i * alpha * pack_size + 0 * pack_size); \ + d1 = vld1q_f32(patchT + i * alpha * pack_size + 1 * pack_size); \ + d2 = vld1q_f32(patchT + i * alpha * pack_size + 2 * pack_size); \ + d3 = vld1q_f32(patchT + i * alpha * pack_size + 3 * pack_size); \ + d4 = vld1q_f32(patchT + i * alpha * pack_size + 4 * pack_size); \ + d5 = vld1q_f32(patchT + i * alpha * pack_size + 5 * pack_size); \ + d6 = vld1q_f32(patchT + i * alpha * pack_size + 6 * pack_size); \ + d7 = vld1q_f32(patchT + i * alpha * pack_size + 7 * pack_size); \ + auto t##i##8 = vld1q_f32(patchT + i * alpha * pack_size + 8 * pack_size); \ + auto t##i##0 = d7; \ + auto t##i##1 = d7; \ + auto t##i##2 = d7; \ + auto t##i##3 = d7; \ + auto t##i##4 = d7; \ + auto t##i##5 = d7; \ + auto t##i##6 = d7; \ + auto t##i##7 = d7; \ + t##i##8 = vfmsq_laneq_f32(t##i##8, d7, v0, 0); \ + t##i##0 = t##i##0 - d1; \ + t##i##1 = vfmsq_laneq_f32(t##i##1, d1, v0, 0); \ + t##i##2 = vfmaq_laneq_f32(t##i##2, d1, v0, 0); \ + t##i##3 = vfmsq_laneq_f32(t##i##3, d1, v0, 1); \ + t##i##4 = vfmaq_laneq_f32(t##i##4, d1, v0, 1); \ + t##i##5 = vfmsq_laneq_f32(t##i##5, d1, v0, 2); \ + t##i##6 = vfmaq_laneq_f32(t##i##6, d1, v0, 2); \ + t##i##7 = t##i##7 - d1; \ + t##i##8 = vfmaq_laneq_f32(t##i##8, d1, v0, 0); \ + t##i##0 = vfmsq_laneq_f32(t##i##0, d2, v0, 3); \ + t##i##1 = vfmsq_laneq_f32(t##i##1, d2, v1, 0); \ + t##i##2 = vfmsq_laneq_f32(t##i##2, d2, v1, 1); \ + t##i##3 = vfmaq_laneq_f32(t##i##3, d2, v1, 2); \ + t##i##4 = vfmsq_laneq_f32(t##i##4, d2, v1, 3); \ + t##i##5 = vfmsq_laneq_f32(t##i##5, d2, v2, 0); \ + t##i##6 = vfmsq_laneq_f32(t##i##6, d2, v2, 1); \ + t##i##8 = t##i##8 - d2; \ + t##i##0 = vfmaq_laneq_f32(t##i##0, d3, v2, 2); \ + t##i##1 = vfmaq_laneq_f32(t##i##1, d3, v2, 3); \ + t##i##2 = vfmsq_laneq_f32(t##i##2, d3, v3, 0); \ + t##i##3 = vfmaq_laneq_f32(t##i##3, d3, v2, 0); \ + t##i##4 = vfmsq_laneq_f32(t##i##4, d3, v3, 1); \ + t##i##5 = vfmaq_laneq_f32(t##i##5, d3, v3, 2); \ + t##i##6 = vfmaq_laneq_f32(t##i##6, d3, v3, 3); \ + t##i##7 = vfmaq_laneq_f32(t##i##7, d3, v2, 2); \ + t##i##8 = vfmsq_laneq_f32(t##i##8, d3, v0, 3); \ + t##i##0 = vfmaq_laneq_f32(t##i##0, d4, v0, 3); \ + t##i##1 = vfmaq_laneq_f32(t##i##1, d4, v4, 0); \ + t##i##2 = vfmaq_laneq_f32(t##i##2, d4, v4, 1); \ + t##i##3 = vfmsq_laneq_f32(t##i##3, d4, v4, 2); \ + t##i##4 = vfmaq_laneq_f32(t##i##4, d4, v4, 3); \ + t##i##5 = vfmaq_laneq_f32(t##i##5, d4, v5, 0); \ + t##i##6 = vfmaq_laneq_f32(t##i##6, d4, v5, 1); \ + t##i##8 = vfmaq_laneq_f32(t##i##8, d4, v2, 2); \ + t##i##0 = vfmsq_laneq_f32(t##i##0, d5, v2, 2); \ + t##i##1 = vfmsq_laneq_f32(t##i##1, d5, v5, 2); \ + t##i##2 = vfmsq_laneq_f32(t##i##2, d5, v5, 3); \ + t##i##3 = vfmsq_laneq_f32(t##i##3, d5, v6, 0); \ + t##i##4 = vfmaq_laneq_f32(t##i##4, d5, v6, 1); \ + t##i##5 = vfmsq_laneq_f32(t##i##5, d5, v5, 2); \ + t##i##6 = vfmsq_laneq_f32(t##i##6, d5, v6, 0); \ + t##i##7 = vfmsq_laneq_f32(t##i##7, d5, v2, 2); \ + t##i##8 = vfmaq_laneq_f32(t##i##8, d5, v0, 3); \ + t##i##0 = vfmsq_laneq_f32(t##i##0, d6, v0, 0); \ + t##i##1 = vfmsq_laneq_f32(t##i##1, d6, v1, 0); \ + t##i##2 = vfmsq_laneq_f32(t##i##2, d6, v1, 1); \ + t##i##3 = vfmaq_laneq_f32(t##i##3, d6, v1, 0); \ + t##i##4 = vfmsq_laneq_f32(t##i##4, d6, v3, 1); \ + t##i##5 = t##i##5 - d6; \ 
+ t##i##6 = vfmsq_laneq_f32(t##i##6, d6, v6, 2); \ + t##i##8 = vfmsq_laneq_f32(t##i##8, d6, v2, 2); \ + t##i##0 = vfmaq_laneq_f32(t##i##0, d0, v0, 0); + + UNROLL_CALL_RAW(9, cb); +#undef cb + +#define cb(i) \ + d8 = t8##i; \ + d0 = t7##i; \ + d1 = t7##i; \ + d2 = t7##i; \ + d3 = t7##i; \ + d4 = t7##i; \ + d5 = t7##i; \ + d6 = t7##i; \ + d7 = t7##i; \ + d8 = vfmsq_laneq_f32(d8, t7##i, v0, 0); \ + d0 = d0 - t1##i; \ + d1 = vfmsq_laneq_f32(d1, t1##i, v0, 0); \ + d2 = vfmaq_laneq_f32(d2, t1##i, v0, 0); \ + d3 = vfmsq_laneq_f32(d3, t1##i, v0, 1); \ + d4 = vfmaq_laneq_f32(d4, t1##i, v0, 1); \ + d5 = vfmsq_laneq_f32(d5, t1##i, v0, 2); \ + d6 = vfmaq_laneq_f32(d6, t1##i, v0, 2); \ + d7 = d7 - t1##i; \ + d8 = vfmaq_laneq_f32(d8, t1##i, v0, 0); \ + d0 = vfmsq_laneq_f32(d0, t2##i, v0, 3); \ + d1 = vfmsq_laneq_f32(d1, t2##i, v1, 0); \ + d2 = vfmsq_laneq_f32(d2, t2##i, v1, 1); \ + d3 = vfmaq_laneq_f32(d3, t2##i, v1, 2); \ + d4 = vfmsq_laneq_f32(d4, t2##i, v1, 3); \ + d5 = vfmsq_laneq_f32(d5, t2##i, v2, 0); \ + d6 = vfmsq_laneq_f32(d6, t2##i, v2, 1); \ + d8 = d8 - t2##i; \ + d0 = vfmaq_laneq_f32(d0, t3##i, v2, 2); \ + d1 = vfmaq_laneq_f32(d1, t3##i, v2, 3); \ + d2 = vfmsq_laneq_f32(d2, t3##i, v3, 0); \ + d3 = vfmaq_laneq_f32(d3, t3##i, v2, 0); \ + d4 = vfmsq_laneq_f32(d4, t3##i, v3, 1); \ + d5 = vfmaq_laneq_f32(d5, t3##i, v3, 2); \ + d6 = vfmaq_laneq_f32(d6, t3##i, v3, 3); \ + d7 = vfmaq_laneq_f32(d7, t3##i, v2, 2); \ + d8 = vfmsq_laneq_f32(d8, t3##i, v0, 3); \ + d0 = vfmaq_laneq_f32(d0, t4##i, v0, 3); \ + d1 = vfmaq_laneq_f32(d1, t4##i, v4, 0); \ + d2 = vfmaq_laneq_f32(d2, t4##i, v4, 1); \ + d3 = vfmsq_laneq_f32(d3, t4##i, v4, 2); \ + d4 = vfmaq_laneq_f32(d4, t4##i, v4, 3); \ + d5 = vfmaq_laneq_f32(d5, t4##i, v5, 0); \ + d6 = vfmaq_laneq_f32(d6, t4##i, v5, 1); \ + d8 = vfmaq_laneq_f32(d8, t4##i, v2, 2); \ + d0 = vfmsq_laneq_f32(d0, t5##i, v2, 2); \ + d1 = vfmsq_laneq_f32(d1, t5##i, v5, 2); \ + d2 = vfmsq_laneq_f32(d2, t5##i, v5, 3); \ + d3 = vfmsq_laneq_f32(d3, t5##i, v6, 0); \ + d4 = vfmaq_laneq_f32(d4, t5##i, v6, 1); \ + d5 = vfmsq_laneq_f32(d5, t5##i, v5, 2); \ + d6 = vfmsq_laneq_f32(d6, t5##i, v6, 0); \ + d7 = vfmsq_laneq_f32(d7, t5##i, v2, 2); \ + d8 = vfmaq_laneq_f32(d8, t5##i, v0, 3); \ + d0 = vfmsq_laneq_f32(d0, t6##i, v0, 0); \ + d1 = vfmsq_laneq_f32(d1, t6##i, v1, 0); \ + d2 = vfmsq_laneq_f32(d2, t6##i, v1, 1); \ + d3 = vfmaq_laneq_f32(d3, t6##i, v1, 0); \ + d4 = vfmsq_laneq_f32(d4, t6##i, v3, 1); \ + d5 = d5 - t6##i; \ + d6 = vfmsq_laneq_f32(d6, t6##i, v6, 2); \ + d8 = vfmsq_laneq_f32(d8, t6##i, v2, 2); \ + d0 = vfmaq_laneq_f32(d0, t0##i, v0, 0); \ + vst1q_f32(input_transform_buf + \ + (0 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + \ + unit_idx * pack_size, \ + d0); \ + vst1q_f32(input_transform_buf + \ + (1 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + \ + unit_idx * pack_size, \ + d1); \ + vst1q_f32(input_transform_buf + \ + (2 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + \ + unit_idx * pack_size, \ + d2); \ + vst1q_f32(input_transform_buf + \ + (3 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + \ + unit_idx * pack_size, \ + d3); \ + vst1q_f32(input_transform_buf + \ + (4 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + \ + unit_idx * pack_size, \ + d4); \ + vst1q_f32(input_transform_buf + \ + (5 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * 
nr_units_in_tile * pack_size + \ + unit_idx * pack_size, \ + d5); \ + vst1q_f32(input_transform_buf + \ + (6 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + \ + unit_idx * pack_size, \ + d6); \ + vst1q_f32(input_transform_buf + \ + (7 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + \ + unit_idx * pack_size, \ + d7); \ + vst1q_f32(input_transform_buf + \ + (8 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + \ + unit_idx * pack_size, \ + d8); + + UNROLL_CALL_RAW(9, cb); +#undef cb + } +}; + +template +struct OutputTransformF73_NCHW44 { + static void transform(const float* output_transform_buf, const float* bias, + float* output, float* transform_mid_buf, + size_t oh_start, size_t ow_start, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, + size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, const DType& src_dtype, + const DType& dst_dtype) { + MEGDNN_MARK_USED_VAR(transform_mid_buf); + Op op(src_dtype, dst_dtype); + //! AT * m * A + + size_t oc = oc_start + oc_index; + size_t OCB = (oc_end - oc_start) / pack_size; + size_t ocb = oc_index / pack_size; + +#define cb(m, n) \ + auto v##m##n = Vector::load( \ + output_transform_buf + \ + (m * alpha + n) * OCB * nr_units_in_tile * pack_size + \ + ocb * nr_units_in_tile * pack_size + unit_idx * pack_size); + + UNROLL_CALL_NOWRAPPER_D2(9, 9, cb); +#undef cb + + /** + * A + * + * 1 0 0 0 0 0 0 + * 1 1 1 1 1 1 1 + * 1 -1 1 -1 1 -1 1 + * 1 2 4 8 16 32 64 + * 1 -2 4 -8 16 -32 64 + * 1 0.5 0.25 0.125 0.0625 0.03125 0.015625 + * 1 -0.5 0.25 -0.125 0.0625 -0.03125 0.015625 + * 1 1.5 2.25 3.375 5.0625 7.59375 11.390625 + * 0 0 0 0 0 0 1 + */ + + Vector v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; +#define cb(m) \ + v1addv2 = v1##m + v2##m; \ + v1subv2 = v1##m - v2##m; \ + v3addv4 = v3##m + v4##m; \ + v3subv4 = v3##m - v4##m; \ + v5addv6 = v5##m + v6##m; \ + v5subv6 = v5##m - v6##m; \ + auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6 + v7##m; \ + auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + v7##m * 1.5f; \ + auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + v7##m * 2.25f; \ + auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f + v7##m * 3.375f; \ + auto t4##m = \ + v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + v7##m * 5.0625f; \ + auto t5##m = \ + v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m * 7.59375f; \ + auto t6##m = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f + \ + v7##m * 11.390625f + v8##m; + + UNROLL_CALL_NOWRAPPER(9, cb); +#undef cb + +#define cb(m) \ + v1addv2 = t##m##1 + t##m##2; \ + v1subv2 = t##m##1 - t##m##2; \ + v3addv4 = t##m##3 + t##m##4; \ + v3subv4 = t##m##3 - t##m##4; \ + v5addv6 = t##m##5 + t##m##6; \ + v5subv6 = t##m##5 - t##m##6; \ + v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6 + t##m##7; \ + v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f + t##m##7 * 1.5f; \ + v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f + t##m##7 * 2.25f; \ + v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f + t##m##7 * 3.375; \ + v##m##4 = \ + v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f + t##m##7 * 5.0625f; \ + v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + \ + t##m##7 * 7.59375f; \ + v##m##6 = v1addv2 + v3addv4 * 64.f + v5addv6 * 0.015625f + \ + t##m##7 * 11.390625f + t##m##8; + + UNROLL_CALL_NOWRAPPER(7, cb); +#undef cb + + Vector vbias; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias = Vector::load(bias + oc); + +#define cb(m, n) 
v##m##n += vbias; + UNROLL_CALL_RAW_D2(7, 7, cb); +#undef cb + } + if (bmode != BiasMode::BIAS) { +#define cb(m, n) v##m##n = op(CONCAT(v##m, n).value); + UNROLL_CALL_RAW_D2(7, 7, cb); +#undef cb + } +#define out_save(oho, owo) \ + do { \ + size_t oh = oh_start + oho; \ + size_t ow = ow_start + owo; \ + if (oh < OH && ow < OW) { \ + if (bmode == BiasMode::BIAS) { \ + v##oho##owo += Vector::load(bias + oc * OH * OW + \ + oh * OW * pack_size + \ + ow * pack_size); \ + v##oho##owo = op(v##oho##owo.value); \ + } \ + v##oho##owo.save(output + oc * OH * OW + oh * OW * pack_size + \ + ow * pack_size); \ + } \ + } while (0); + UNROLL_CALL_RAW_D2(7, 7, out_save); + } +#undef out_save +}; +} // namespace + +namespace megdnn { +namespace arm_common { +namespace winograd { + +MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F73_mk4_f_nchw44) + +void winograd_F73_mk4_f_nchw44::filter(const float* filter, + float* filter_transform_buf, + float* transform_mid_buf, size_t OC, + size_t IC, size_t oc_start, + size_t oc_end) { + constexpr size_t pack_size = 4; + // Gg * GT + // G + // 0.6666667 0.0000000 0.0000000 + // 0.4444444 0.4444444 0.4444444 + // 0.0888889 -0.0888889 0.0888889 + // 0.0222222 0.0444444 0.0888889 + //-0.0031746 0.0063492 -0.0126984 + //-0.7111111 -0.3555556 -0.1777778 + //-0.3555556 0.1777778 -0.0888889 + //-0.1523810 -0.2285714 -0.3428572 + // 0.0000000 0.0000000 1.0000000 + MEGDNN_MARK_USED_VAR(transform_mid_buf); + megdnn_assert((oc_end - oc_start) % pack_size == 0 && + oc_start % pack_size == 0 && + oc_end % pack_size == 0 && IC % pack_size == 0 && + OC % pack_size == 0, + "NCHW44 Winograd filter transform requires both OC and IC " + "are times of 4"); + + size_t ICB = IC / pack_size; + + for (size_t ocb = oc_start / pack_size; ocb < oc_end / pack_size; ocb++) { + for (size_t icb = 0; icb < ICB; icb++) { + for (size_t ic_inner = 0; ic_inner < pack_size; ic_inner++) { + const float* fptr = filter + + (ocb * ICB + icb) * KERNEL_SIZE * + KERNEL_SIZE * pack_size * + pack_size + + ic_inner * pack_size; + +#define cb(m, n) \ + Vector g##m##n = Vector::load( \ + fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size); + UNROLL_CALL_NOWRAPPER_D2(3, 3, cb) +#undef cb + +#define FILTER_TRANSFORM(n, wd, g) \ + auto wd##n##0 = g##0##n * 0.6666667f; \ + auto wd##n##1 = (g##0##n + g##1##n + g##2##n) * 0.4444444f; \ + auto wd##n##2 = (g##0##n - g##1##n + g##2##n) * 0.0888889f; \ + auto wd##n##3 = g##0##n * 0.0222222f + g##1##n * 0.0444444f + \ + g##2##n * 0.0888889f; \ + auto wd##n##4 = g##0##n * -0.0031746f + g##1##n * 0.0063492f + \ + g##2##n * -0.0126984f; \ + auto wd##n##5 = g##0##n * -0.7111111f + g##1##n * -0.3555556f + \ + g##2##n * -0.1777778f; \ + auto wd##n##6 = g##0##n * -0.3555556f + g##1##n * 0.1777778f + \ + g##2##n * -0.0888889f; \ + auto wd##n##7 = g##0##n * -0.1523810f + g##1##n * -0.2285714f + \ + g##2##n * -0.3428572f; \ + auto wd##n##8 = g##2##n; + UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g); + UNROLL_CALL_RAW(9, FILTER_TRANSFORM, ret, wd); +#undef FILTER_TRANSFORM +#define cb_save(m, n) \ + ret##m##n.save(filter_transform_buf + (m * alpha + n) * OC * IC + \ + ocb * IC * pack_size + icb * pack_size * pack_size + \ + ic_inner * pack_size); + UNROLL_CALL_NOWRAPPER_D2(9, 9, cb_save) +#undef cb_save + } + } + } +} + +void winograd_F73_mk4_f_nchw44::input(const float* input, + float* input_transform_buf, + float* transform_mid_buf, size_t IH, + size_t IW, size_t IC, size_t PH, + size_t PW, size_t unit_start_idx, + size_t nr_units_in_tile) { + constexpr size_t pack_size = 4; + 
megdnn_assert(IC % pack_size == 0); + constexpr int alpha = 3 + 7 - 1; + + // OW = IW + 2 * PW - KERNEL_SIZE + 1 + auto units_w = + div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + float* patch = transform_mid_buf; + float* patchT = transform_mid_buf + pack_size * alpha * alpha; + + for (size_t ic = 0; ic < IC; ic += pack_size) { + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + size_t nh = index / units_w; + size_t nw = index % units_w; + int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; + int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; + if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && + iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { + InputTransformF73_NCHW44::prepare(input, patch, patchT, + ih_start, iw_start, IH, + IW, ic, IC); + InputTransformF73_NCHW44::transform(patchT, input_transform_buf, + unit_idx, nr_units_in_tile, + ic, IC); + + } else { + InputTransformF73_NCHW44::prepare(input, patch, patchT, + ih_start, iw_start, IH, + IW, ic, IC); + InputTransformF73_NCHW44::transform(patchT, input_transform_buf, + unit_idx, nr_units_in_tile, + ic, IC); + } + } + } +} + +void winograd_F73_mk4_f_nchw44::output(const float* output_transform_buf, + const float* bias, float* output, + float* transform_mid_buf, BiasMode bmode, + NonlineMode nonline_mode, size_t OH, + size_t OW, size_t oc_start, + size_t oc_end, size_t unit_start_idx, + size_t nr_units_in_tile) { +#define cb(_bmode, _nonline_op, ...) \ + for (size_t oc = oc_start; oc < oc_end; oc += pack_size) { \ + size_t oc_index = oc - oc_start; \ + rep(unit_idx, nr_units_in_tile) { \ + size_t index = unit_start_idx + unit_idx; \ + auto nh = index / units_w; \ + auto nw = index % units_w; \ + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; \ + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; \ + OutputTransformF73_NCHW44<_bmode MEGDNN_COMMA _nonline_op>:: \ + transform(output_transform_buf, bias, output, \ + transform_mid_buf, oh_start, ow_start, OH, OW, \ + oc_start, oc_end, oc_index, unit_idx, \ + nr_units_in_tile, src_dtype, dst_dtype); \ + } \ + } + + auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); + constexpr size_t pack_size = 4; + + size_t OC = oc_end - oc_start; + megdnn_assert(OC % pack_size == 0 && oc_start % pack_size == 0 && + oc_end % pack_size == 0, + "NCHW44 Winograd filter transform requires OC is times of 4"); + + DISPATCH_CONV_WINOGRAD_BIAS(megdnn_arm_common_winograd_fp32_F73_mk4, cb, + float, float, bmode, nonline_mode); +#undef cb +} + +} // namespace winograd +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index 33a00b66a7f0742ff0b3cc37361dc95087fac2b2..515166f6d240be07ffc32e94709ad975237a5489 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -151,6 +151,13 @@ public: static_cast(algo), tile_size)); winograd_algos.emplace_back(refhold.back().get()); +//! 
uncomment this when low precision mode is done +#if 0 + refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44( + static_cast(algo), + tile_size)); + winograd_algos.emplace_back(refhold.back().get()); +#endif #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC refhold.emplace_back(new AlgoFP16WinogradF23( static_cast(algo), diff --git a/dnn/src/arm_common/conv_bias/opr_impl.h b/dnn/src/arm_common/conv_bias/opr_impl.h index f50762e449be7840cc492905608bf676b06a025c..6a6b1d609938443efd4092cd6c2ddd225e35e9e6 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.h +++ b/dnn/src/arm_common/conv_bias/opr_impl.h @@ -50,6 +50,7 @@ private: class AlgoFP32WinogradF23_4x4_NCHW44; class AlgoFP32WinogradF63_4x4_NCHW44; + class AlgoFP32WinogradF73_4x4_NCHW44; class AlgoS8ChanWiseStride1NCHW44; class AlgoS8ChanWiseStride2NCHW44; diff --git a/dnn/src/arm_common/winograd_filter_preprocess/opr_impl.cpp b/dnn/src/arm_common/winograd_filter_preprocess/opr_impl.cpp index 47d570c6adbca9740746b83a01a298f4597932bb..82e7af0eca61c10f262d33bb1419ab3074f818f3 100644 --- a/dnn/src/arm_common/winograd_filter_preprocess/opr_impl.cpp +++ b/dnn/src/arm_common/winograd_filter_preprocess/opr_impl.cpp @@ -94,6 +94,10 @@ void WinogradFilterPreprocessImpl::exec(_megdnn_tensor_in src, DISPATCH(winograd_F63_mk4_f_nchw44, param::Winograd::Format::MK4, 0, 6); } + } else if (m == 7) { + megdnn_assert(pack_c_size == 4, "WINOGRAD F(7,3) Only Supports NCHW44"); + DISPATCH(winograd_F73_mk4_f_nchw44, + param::Winograd::Format::MK4, 0, 7); } } else if (FW == 4) { if (m == 5) { diff --git a/dnn/src/common/unroll_macro.h b/dnn/src/common/unroll_macro.h index d0d0f1c61130e11d436e9878816f7b00942fdcc6..38fbd4657a0778a0ffb49cd7d695491640bc5832 100644 --- a/dnn/src/common/unroll_macro.h +++ b/dnn/src/common/unroll_macro.h @@ -122,6 +122,23 @@ cb(5, 0, ##a) cb(5, 1, ##a) cb(5, 2, ##a) cb(5, 3, ##a) \ cb(5, 4, ##a) cb(5, 5, ##a) \ +#define UNROLL_RAW_7x7(cb, v0, a...) \ + cb(0, 0, ##a) cb(0, 1, ##a) cb(0, 2, ##a) cb(0, 3, ##a) \ + cb(0, 4, ##a) cb(0, 5, ##a) cb(0, 6, ##a) \ + cb(1, 0, ##a) cb(1, 1, ##a) cb(1, 2, ##a) cb(1, 3, ##a) \ + cb(1, 4, ##a) cb(1, 5, ##a) cb(1, 6, ##a) \ + cb(2, 0, ##a) cb(2, 1, ##a) cb(2, 2, ##a) cb(2, 3, ##a) \ + cb(2, 4, ##a) cb(2, 5, ##a) cb(2, 6, ##a) \ + cb(3, 0, ##a) cb(3, 1, ##a) cb(3, 2, ##a) cb(3, 3, ##a) \ + cb(3, 4, ##a) cb(3, 5, ##a) cb(3, 6, ##a) \ + cb(4, 0, ##a) cb(4, 1, ##a) cb(4, 2, ##a) cb(4, 3, ##a) \ + cb(4, 4, ##a) cb(4, 5, ##a) cb(4, 6, ##a) \ + cb(5, 0, ##a) cb(5, 1, ##a) cb(5, 2, ##a) cb(5, 3, ##a) \ + cb(5, 4, ##a) cb(5, 5, ##a) cb(5, 6, ##a) \ + cb(6, 0, ##a) cb(6, 1, ##a) cb(6, 2, ##a) cb(6, 3, ##a) \ + cb(6, 4, ##a) cb(6, 5, ##a) cb(6, 6, ##a) \ + + #define UNROLL_RAW_8x8(cb, v0, a...) \ cb(0, 0, ##a) cb(0, 1, ##a) cb(0, 2, ##a) cb(0, 3, ##a) \ cb(0, 4, ##a) cb(0, 5, ##a) cb(0, 6, ##a) cb(0, 7, ##a) \ @@ -140,6 +157,26 @@ cb(7, 0, ##a) cb(7, 1, ##a) cb(7, 2, ##a) cb(7, 3, ##a) \ cb(7, 4, ##a) cb(7, 5, ##a) cb(7, 6, ##a) cb(7, 7, ##a) +#define UNROLL_RAW_9x9(cb, v0, a...) 
\ + cb(0, 0, ##a) cb(0, 1, ##a) cb(0, 2, ##a) cb(0, 3, ##a) \ + cb(0, 4, ##a) cb(0, 5, ##a) cb(0, 6, ##a) cb(0, 7, ##a) cb(0, 8, ##a) \ + cb(1, 0, ##a) cb(1, 1, ##a) cb(1, 2, ##a) cb(1, 3, ##a) \ + cb(1, 4, ##a) cb(1, 5, ##a) cb(1, 6, ##a) cb(1, 7, ##a) cb(1, 8, ##a) \ + cb(2, 0, ##a) cb(2, 1, ##a) cb(2, 2, ##a) cb(2, 3, ##a) \ + cb(2, 4, ##a) cb(2, 5, ##a) cb(2, 6, ##a) cb(2, 7, ##a) cb(2, 8, ##a) \ + cb(3, 0, ##a) cb(3, 1, ##a) cb(3, 2, ##a) cb(3, 3, ##a) \ + cb(3, 4, ##a) cb(3, 5, ##a) cb(3, 6, ##a) cb(3, 7, ##a) cb(3, 8, ##a) \ + cb(4, 0, ##a) cb(4, 1, ##a) cb(4, 2, ##a) cb(4, 3, ##a) \ + cb(4, 4, ##a) cb(4, 5, ##a) cb(4, 6, ##a) cb(4, 7, ##a) cb(4, 8, ##a) \ + cb(5, 0, ##a) cb(5, 1, ##a) cb(5, 2, ##a) cb(5, 3, ##a) \ + cb(5, 4, ##a) cb(5, 5, ##a) cb(5, 6, ##a) cb(5, 7, ##a) cb(5, 8, ##a) \ + cb(6, 0, ##a) cb(6, 1, ##a) cb(6, 2, ##a) cb(6, 3, ##a) \ + cb(6, 4, ##a) cb(6, 5, ##a) cb(6, 6, ##a) cb(6, 7, ##a) cb(6, 8, ##a) \ + cb(7, 0, ##a) cb(7, 1, ##a) cb(7, 2, ##a) cb(7, 3, ##a) \ + cb(7, 4, ##a) cb(7, 5, ##a) cb(7, 6, ##a) cb(7, 7, ##a) cb(7, 8, ##a) \ + cb(8, 0, ##a) cb(8, 1, ##a) cb(8, 2, ##a) cb(8, 3, ##a) \ + cb(8, 4, ##a) cb(8, 5, ##a) cb(8, 6, ##a) cb(8, 7, ##a) cb(8, 8, ##a) + #define UNROLL_CALL0_D2(step, step2, cb, v...) \ UNROLL_RAW_##step##x##step2(cb, 0, ##v) #define UNROLL_CALL1_D2(step, step2, cb, v...) \ diff --git a/dnn/src/naive/winograd_filter_preprocess/opr_impl.cpp b/dnn/src/naive/winograd_filter_preprocess/opr_impl.cpp index 26490bf153d7164240907a30ad131e660811a2ee..148440e424ef0601b40f66cb2999295475a30c3f 100644 --- a/dnn/src/naive/winograd_filter_preprocess/opr_impl.cpp +++ b/dnn/src/naive/winograd_filter_preprocess/opr_impl.cpp @@ -25,7 +25,7 @@ void WinogradFilterPreprocessImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { check_exec(src.layout, dst.layout, workspace.size); - + //! nchw88 group conv size_t flt_start = 0; size_t pack_c_size = 1; @@ -212,6 +212,10 @@ void WinogradFilterPreprocessImpl::exec(_megdnn_tensor_in src, std::vector interp_points = {0, 1, -1, 2, -2, 0.5, -0.5}; DISPATCH_DTYPE(7); + } else if (m == 7) { + std::vector interp_points = {0, 1, -1, 2, + -2, 0.5, -0.5, 1.5}; + DISPATCH_DTYPE(8); } } #undef cb @@ -221,6 +225,7 @@ void WinogradFilterPreprocessImpl::exec(_megdnn_tensor_in src, #undef DISPATCH_DTYPE } } + megdnn_assert(execed, "Unsupport winograd filter preprocess. 
m: %zu src: %s", m, src.layout.to_string().c_str()); diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp index 1b3094e56bc3c982fec22d52d6e0fbca3731f025..ca33ed8a94aacc9bfeae876ed82a56b33e51b29b 100644 --- a/dnn/test/arm_common/conv_bias.cpp +++ b/dnn/test/arm_common/conv_bias.cpp @@ -777,7 +777,8 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F16_F23_8x8) { } #endif -void benchmark_winograd_nchw_vs_nchw44(const char* algo_name, Handle* handle) { +void benchmark_winograd_nchw_vs_nchw44(const char* algo_name0, + const char* algo_name1, Handle* handle) { using namespace conv_bias; using NLMode = param::ConvBias::NonlineMode; std::vector args_nchw44; @@ -846,9 +847,9 @@ void benchmark_winograd_nchw_vs_nchw44(const char* algo_name, Handle* handle) { benchmark_winograd_nchw44.set_display(false); benchmark_winograd_nchw44.set_times(RUN); - std::string winograd_nchw_algo_name = ssprintf("WINOGRAD:%s", algo_name); + std::string winograd_nchw_algo_name = ssprintf("WINOGRAD:%s", algo_name0); std::string winograd_nchw44_algo_name = - ssprintf("WINOGRAD_NCHW44:%s", algo_name); + ssprintf("WINOGRAD_NCHW44:%s", algo_name1); for (size_t i = 0; i < args_nchw.size(); ++i) { auto arg_nchw = args_nchw[i]; @@ -892,17 +893,31 @@ void benchmark_winograd_nchw_vs_nchw44(const char* algo_name, Handle* handle) { TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23_MK4_NCHW_VS_NCHW44) { #if MEGDNN_AARCH64 - benchmark_winograd_nchw_vs_nchw44("AARCH64_F32_MK4_4x16:4:2", handle()); + benchmark_winograd_nchw_vs_nchw44("AARCH64_F32_MK4_4x16:4:2", + "AARCH64_F32_MK4_4x16:4:2", handle()); #else - benchmark_winograd_nchw_vs_nchw44("ARMV7_F32_MK4_4x8:4:2", handle()); + benchmark_winograd_nchw_vs_nchw44("ARMV7_F32_MK4_4x8:4:2", + "ARMV7_F32_MK4_4x8:4:2", handle()); #endif } TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63_MK4_NCHW_VS_NCHW44) { #if MEGDNN_AARCH64 - benchmark_winograd_nchw_vs_nchw44("AARCH64_F32_MK4_4x16:4:6", handle()); + benchmark_winograd_nchw_vs_nchw44("AARCH64_F32_MK4_4x16:4:6", + "AARCH64_F32_MK4_4x16:4:6", handle()); #else - benchmark_winograd_nchw_vs_nchw44("ARMV7_F32_MK4_4x8:4:6", handle()); + benchmark_winograd_nchw_vs_nchw44("ARMV7_F32_MK4_4x8:4:6", + "ARMV7_F32_MK4_4x8:4:6", handle()); +#endif +} + +TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F73_MK4_NCHW_VS_NCHW44) { +#if MEGDNN_AARCH64 + benchmark_winograd_nchw_vs_nchw44("AARCH64_F32_MK4_4x16:4:6", + "ARM_COMMON_F32_GEMV_MK4:4:7", handle()); +#else + benchmark_winograd_nchw_vs_nchw44("ARMV7_F32_MK4_4x8:4:6", + "ARMV7_F32_MK4_4x8:4:7", handle()); #endif } diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index 4ce397d67d107bf0d469f3aa15761574139a373b..9e99311148d2ed74980783a893fd8b7e1d886fd7 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -750,6 +750,26 @@ TEST_F(ARM_COMMON_MULTI_THREADS, param::ConvBias::Format::NCHW44); } +//! 
uncomment it when low precision mode is ok +#if 0 +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44) { + using namespace conv_bias; + std::vector args = get_nchw44_conv_bias_args({3}, 1); + Checker checker(handle()); + check_winograd("4:7:16", checker, args, param::MatrixMul::Format::MK4, + param::ConvBias::Format::NCHW44); +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44_WEIGHT_PREPROCESS) { + using namespace conv_bias; + std::vector args = get_nchw44_conv_bias_args({3}, 1); + Checker> checker( + handle()); + check_winograd("4:7:16", checker, args, param::MatrixMul::Format::MK4, + param::ConvBias::Format::NCHW44); +} +#endif + TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54) { using namespace conv_bias; std::vector args = get_winograd_args(4); @@ -923,6 +943,12 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_PREPROCESS_NCHW44) { } } }; + + //! uncomment this when low precision mode is ok + // run(handle(), nchw44_args, {2, 6, 7}, dtype::Float32(), dtype::Float32(), + // dtype::Float32(), dtype::Float32(), 1e-2f); + + //! remove this when low precision mode is ok run(handle(), nchw44_args, {2, 6}, dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32(), 1e-3f); } diff --git a/dnn/test/arm_common/matrix_mul.cpp b/dnn/test/arm_common/matrix_mul.cpp index 582b05223e9d548bb4c08058302d33e5afdaa331..25a84f965dd99e94ce4c652cb787d93ea7c2a078 100644 --- a/dnn/test/arm_common/matrix_mul.cpp +++ b/dnn/test/arm_common/matrix_mul.cpp @@ -399,7 +399,7 @@ TEST_F(ARM_COMMON, BENCHMARK_SGEMV_MK4) { .set_param(param); auto run = [&](size_t M, size_t K) { - printf("SGEMV_MK4: (%zu, %zu)\n", M, K); + printf("SGEMV_MK4: (%zu, %zu, 1)\n", M, K); TensorShape A, B; A = TensorShape{M / 4, K / 4, 4, 4}; B = TensorShape{K / 4, 1, 4};
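Editor's note (not part of the patch): the output transform in OutputTransformF73_NCHW44 hard-codes the constants 2, 4, 8, 16, 32, 64, 0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 1.5, 2.25, 3.375, 5.0625, 7.59375 and 11.390625. These are exactly the powers of the interpolation points {0, 1, -1, 2, -2, 0.5, -0.5, 1.5} that the naive winograd_filter_preprocess change registers for m == 7, plus one final row for the extra sample. The standalone sketch below rebuilds the 9x7 matrix A documented in the strategy's comment so the constants can be cross-checked; the file and everything in it is illustrative only and does not belong to the patch.

// Minimal, self-contained sketch (not part of the patch): rebuild the 9x7
// output-transform matrix A of Winograd F(7,3) from the interpolation
// points {0, 1, -1, 2, -2, 0.5, -0.5, 1.5}. Row i is [p_i^0 ... p_i^6];
// the last row, for the extra sample, is [0 ... 0 1].
#include <cmath>
#include <cstdio>

int main() {
    constexpr int m = 7;              // output block size of F(7,3)
    constexpr int alpha = 7 + 3 - 1;  // 9 transform samples
    const double points[alpha - 1] = {0, 1, -1, 2, -2, 0.5, -0.5, 1.5};
    double A[alpha][m] = {};
    for (int i = 0; i + 1 < alpha; ++i)
        for (int j = 0; j < m; ++j)
            A[i][j] = std::pow(points[i], j);
    A[alpha - 1][m - 1] = 1.0;        // last row: [0 ... 0 1]
    for (int i = 0; i < alpha; ++i) {
        for (int j = 0; j < m; ++j)
            std::printf("%10.6f ", A[i][j]);
        std::printf("\n");
    }
    // The row for 1.5 prints 1 1.5 2.25 3.375 5.0625 7.59375 11.390625,
    // matching the v7##m / t##m##7 coefficients in the transform macros.
    return 0;
}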