diff --git a/dnn/include/megdnn/oprs/nn.h b/dnn/include/megdnn/oprs/nn.h index 78afc2eafdd584daa092b634d0d5a13eb1ff4874..e3c6e0461e19e92134e33cc7da63cfd880a2b3fd 100644 --- a/dnn/include/megdnn/oprs/nn.h +++ b/dnn/include/megdnn/oprs/nn.h @@ -387,7 +387,9 @@ public: //! get algo name, the format is ParamTrait::category:base:p.to_string() //! \warning: base must not contain :. template - static std::string algo_name(const std::string& base, const T& p); + static std::string algo_name( + const std::string& base, const T& p, + param::ConvBias::Format format = param::ConvBias::Format::NCHW); /*! * \brief parse algo_name and get WinogradParam from algo name. * diff --git a/dnn/src/arm_common/conv_bias/fp32/algos.cpp b/dnn/src/arm_common/conv_bias/fp32/algos.cpp index af3de8433eb5f04126b86db2138f54cdae7cbd06..a142b9a6e4f46e586daa8384c66805ce5b613b2f 100644 --- a/dnn/src/arm_common/conv_bias/fp32/algos.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/algos.cpp @@ -388,6 +388,166 @@ ConvBiasImpl::AlgoFP32WinogradF63_4x4::dispatch_kerns( return {}; } +/* =================== AlgoFP32WinogradF23_4x4_NCHW44 =================== */ + +bool ConvBiasImpl::AlgoFP32WinogradF23_4x4_NCHW44::usable( + fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy /*algo_selection_strategy*/) const { + MEGDNN_MARK_USED_VAR(opr); + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, + midout_iv("AlgoFP32WinogradF23_4x4_NCHW44"_hash)) { + if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) + return false; + using Strategy = winograd::winograd_F23_mk4_f_nchw44; + Strategy strategy(param.src_type, param.filter_type, param.dst_type); + auto&& matmul_param = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_matmul_kern_param(param); + return m_matmul_algo->usable(matmul_param) && + m_matmul_algo->packmode() == + 
fallback::MatrixMulImpl::AlgoBase::PackMode::NO_PACK && + (opr->param().format == param::ConvBias::Format::NCHW44 || + (opr->param().format == + param::ConvBias::Format::NCHW44_WINOGRAD && + opr->param().output_block_size == 2 && + param.winograd_matmul_format == + param::MatrixMul::Format::MK4)) && + opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION && + (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && + param.filter_meta.spatial[0] == 3) && + (param.filter_meta.stride[0] == param.filter_meta.stride[1] && + param.filter_meta.stride[0] == 1) && + (param.filter_meta.dilation[0] == + param.filter_meta.dilation[1] && + param.filter_meta.dilation[0] == 1) && + param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && + param.src_type.enumv() == DTypeEnum::Float32; + } + MIDOUT_END(); + return false; +} + +size_t ConvBiasImpl::AlgoFP32WinogradF23_4x4_NCHW44::get_workspace( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, + midout_iv("AlgoFP32WinogradF23_4x4_NCHW44"_hash)) { + winograd::winograd_F23_mk4_f_nchw44 strategy( + param.src_type, param.filter_type, param.dst_type); + return megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_workspace_size(param, m_matmul_algo); + } + MIDOUT_END(); + return 0; +} + +SmallVector +ConvBiasImpl::AlgoFP32WinogradF23_4x4_NCHW44::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, + midout_iv("AlgoFP32WinogradF23_4x4_NCHW44"_hash)) { + winograd::winograd_F23_mk4_f_nchw44 strategy( + param.src_type, param.filter_type, param.dst_type); + auto winograd_impl = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg); + return winograd_impl.get_kerns(param, 
m_matmul_algo); + } + MIDOUT_END(); + return {}; +} + +/* =================== AlgoFP32WinogradF63_4x4_NCHW44 ===================== */ + +bool ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44::usable( + fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy /*algo_selection_strategy*/) const { + MEGDNN_MARK_USED_VAR(param); + MEGDNN_MARK_USED_VAR(opr); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, + midout_iv("AlgoFP32WinogradF63_4x4_NCHW44"_hash)) { + if (param.filter_meta.icpg % 4 != 0 || param.filter_meta.ocpg % 4 != 0) + return false; + using Strategy = winograd::winograd_F63_mk4_f_nchw44; + Strategy strategy(param.src_type, param.filter_type, param.dst_type); + auto&& matmul_param = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_matmul_kern_param(param); + return m_matmul_algo->usable(matmul_param) && + m_matmul_algo->packmode() == + fallback::MatrixMulImpl::AlgoBase::PackMode::NO_PACK && + (opr->param().format == param::ConvBias::Format::NCHW44 || + (opr->param().format == + param::ConvBias::Format::NCHW44_WINOGRAD && + opr->param().output_block_size == 6 && + param.winograd_matmul_format == + param::MatrixMul::Format::MK4)) && + opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION && + (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] && + param.filter_meta.spatial[0] == 3) && + (param.filter_meta.stride[0] == param.filter_meta.stride[1] && + param.filter_meta.stride[0] == 1) && + (param.filter_meta.dilation[0] == + param.filter_meta.dilation[1] && + param.filter_meta.dilation[0] == 1) && + param.compute_mode == param::ConvBias::ComputeMode::DEFAULT && + param.src_type.enumv() == DTypeEnum::Float32 && + param.filter_meta.icpg % 4 == 0 && + param.filter_meta.ocpg % 4 == 0; + } + MIDOUT_END(); + return false; +} + +size_t ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44::get_workspace( + fallback::ConvBiasImpl*, const 
NCBKernSizeParam& param) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, + midout_iv("AlgoFP32WinogradF63_4x4_NCHW44"_hash)) { + winograd::winograd_F63_mk4_f_nchw44 strategy( + param.src_type, param.filter_type, param.dst_type); + return megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg) + .get_workspace_size(param, m_matmul_algo); + } + MIDOUT_END(); + return 0; +} + +SmallVector +ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44::dispatch_kerns( + fallback::ConvBiasImpl*, const NCBKernSizeParam& param) const { + MEGDNN_MARK_USED_VAR(param); + MIDOUT_BEGIN(megdnn_arm_common_winograd_fp32, + midout_iv("AlgoFP32WinogradF63_4x4_NCHW44"_hash)) { + winograd::winograd_F63_mk4_f_nchw44 strategy( + param.src_type, param.filter_type, param.dst_type); + auto winograd_impl = + megdnn::winograd::ConvBias( + strategy, m_tile_size, param.nr_threads, param.osz[0], + param.osz[1], param.filter_meta.ocpg); + return winograd_impl.get_kerns(param, m_matmul_algo); + } + MIDOUT_END(); + return {}; +} + /* ===================== direct algo ===================== */ MIDOUT_DECL(megdnn_arm_common_conv_bias_f32_kimpl); diff --git a/dnn/src/arm_common/conv_bias/fp32/algos.h b/dnn/src/arm_common/conv_bias/fp32/algos.h index 4573ece243a57b5aad1a270eee79e993838d0062..0229898cd7fee51f680006f6b373c1574b553779 100644 --- a/dnn/src/arm_common/conv_bias/fp32/algos.h +++ b/dnn/src/arm_common/conv_bias/fp32/algos.h @@ -157,6 +157,64 @@ private: uint32_t m_tile_size; }; +//===================== NCHW44 Winograd Support =====================// +class ConvBiasImpl::AlgoFP32WinogradF23_4x4_NCHW44 final : public AlgoBase { +public: + AlgoFP32WinogradF23_4x4_NCHW44( + fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size) + : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} + bool is_reproducible() const override { return true; } + const char* name() const override { + if 
(m_name.empty()) { + m_name = ConvBiasImpl::algo_name( + m_matmul_algo->name(), {4, 2, m_tile_size}, + param::ConvBias::Format::NCHW44); + } + return m_name.c_str(); + } + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; + +private: + fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; + mutable std::string m_name; + uint32_t m_tile_size; +}; + +class ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44 final : public AlgoBase { +public: + AlgoFP32WinogradF63_4x4_NCHW44( + fallback::MatrixMulImpl::AlgoBase* matmul_algo, uint32_t tile_size) + : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {} + bool is_reproducible() const override { return true; } + const char* name() const override { + if (m_name.empty()) { + m_name = ConvBiasImpl::algo_name( + m_matmul_algo->name(), {4, 6, m_tile_size}, + param::ConvBias::Format::NCHW44); + } + return m_name.c_str(); + } + bool usable(fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param, + AlgoSelectionStrategy algo_selection_strategy) const override; + size_t get_workspace(fallback::ConvBiasImpl*, + const NCBKernSizeParam& param) const override; + virtual SmallVector dispatch_kerns( + fallback::ConvBiasImpl* opr, + const NCBKernSizeParam& param) const override; + +private: + fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; + mutable std::string m_name; + uint32_t m_tile_size; +}; +// ================================================================= // + class ConvBiasImpl::AlgoF32Direct final : public AlgoBase { SmallVector get_kimpls(const NCBKernSizeParam& param) const; bool m_large_group; diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy.h b/dnn/src/arm_common/conv_bias/fp32/strategy.h index 
43b109e96348efd9819fea6fb12db636f0a6d01f..1f68a3e741c41a7f8e2b74607e16c3737d92e4ee 100644 --- a/dnn/src/arm_common/conv_bias/fp32/strategy.h +++ b/dnn/src/arm_common/conv_bias/fp32/strategy.h @@ -32,6 +32,12 @@ MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 5, 4, 1, 1, MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 4, 5, 1, 1, winograd_4x5_1x1_f) + +MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 2, 3, 4, 4, + winograd_F23_mk4_f_nchw44) + +MEGDNN_REG_WINOGRAD_STRATEGY(float, float, float, float, 6, 3, 4, 4, + winograd_F63_mk4_f_nchw44) } // namespace winograd } // namespace arm_common } // namespace megdnn diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_f23_mk4_nchw44.cpp b/dnn/src/arm_common/conv_bias/fp32/strategy_f23_mk4_nchw44.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5b00b5dd0c5f77c249e1e9a0c5bdd1a27234ee66 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/strategy_f23_mk4_nchw44.cpp @@ -0,0 +1,349 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/strategy_f23_mk4_nchw44.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/arm_common/conv_bias/fp32/strategy.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/arm_common/utils.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/fallback/conv_bias/winograd/winograd.h" + +#include "src/naive/matrix_mul/matrix_mul_helper.h" +#include "src/arm_common/elemwise_helper/op_unary.h" +#include "src/arm_common/conv_bias/fp32/helper.h" + +#include "midout.h" +MIDOUT_DECL(megdnn_arm_common_winograd_nchw44_fp32_F23_mk4) + +using namespace megdnn; +using namespace arm_common; +namespace { + +constexpr size_t alpha = 2 + 3 - 1; +constexpr size_t pack_size = 4; + +struct InputTransformF23_NCHW44 { + template + static void prepare(const float* input, float* patch, float* patchT, + int ih_start, int iw_start, size_t IH, size_t IW, + size_t ic, size_t IC) { + MEGDNN_MARK_USED_VAR(patch); + size_t IW4 = IW * pack_size; + size_t iw4_start = iw_start * pack_size; + size_t icb = ic / pack_size; + if (!(inner && ic + pack_size < IC)) { + memset(patchT, 0, sizeof(float) * pack_size * alpha * alpha); + } + if (inner) { + const float* input_ptr = + input + icb * IH * IW4 + ih_start * IW4 + iw4_start; + for (size_t ih = 0; ih < alpha; ih++) { +#define cb(i) auto v##i = vld1q_f32(input_ptr + pack_size * i); + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + +#define cb(i) vst1q_f32(patchT + ih * alpha * pack_size + i * pack_size, v##i); + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + input_ptr += IW4; + } + } else { + int ih0_act = std::max(ih_start, 0), + ih1_act = std::min(ih_start + alpha, IH), + iw0_act = std::max(iw_start, 0), + iw1_act = std::min(iw_start + alpha, IW); + const float* input_ptr = input + icb * IH * IW4; + // partial copy + for (int ih = ih0_act; ih < ih1_act; ++ih) { + for (int iw = iw0_act; iw < iw1_act; ++iw) { + size_t iho = ih - ih_start, iwo = iw - iw_start; + auto src = vld1q_f32(input_ptr + ih * IW4 + iw * pack_size); + vst1q_f32( + patchT + iho * alpha * pack_size + iwo * 
pack_size, + src); + } + } + } + } + + static void transform(const float* patchT, float* input_transform_buf, + size_t unit_idx, size_t nr_units_in_tile, size_t ic, + size_t IC) { + // BT * d * B +#define cb(m, n) \ + Vector d##m##n = Vector::load( \ + patchT + m * alpha * pack_size + n * pack_size); + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); +#undef cb + + //! 1 0 -1 0 d00 d01 d02 d03 1 0 0 0 + //! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1 + //! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0 + //! 0 -1 0 1 d30 d31 d32 d33 0 0 0 1 +#define cb(m) \ + auto t0##m = d0##m - d2##m; \ + auto t1##m = d1##m + d2##m; \ + auto t2##m = d2##m - d1##m; \ + auto t3##m = d3##m - d1##m; + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + +#define cb(m) \ + d##m##0 = t##m##0 - t##m##2; \ + d##m##1 = t##m##1 + t##m##2; \ + d##m##2 = t##m##2 - t##m##1; \ + d##m##3 = t##m##3 - t##m##1; + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + + size_t ICB = IC / 4; + size_t icb = ic / 4; +#define cb(m, n) \ + d##m##n.save(input_transform_buf + \ + (m * alpha + n) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + unit_idx * pack_size); + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb) +#undef cb + } +}; + +#define CONCAT(a, idx) a##idx +template +struct OutputTransformF23_NCHW44 { + static void transform(const float* output_transform_buf, const float* bias, + float* output, float* transform_mid_buf, + size_t oh_start, size_t ow_start, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, + size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, const DType& src_dtype, + const DType& dst_dtype) { + MEGDNN_MARK_USED_VAR(transform_mid_buf); + Op op(src_dtype, dst_dtype); + //! 
AT * m * A + size_t OCB = (oc_end - oc_start) / pack_size; + size_t oc = oc_start + oc_index; + size_t ocb = oc_index / pack_size; + +#define cb(m, n) \ + auto v##m##n = Vector::load( \ + output_transform_buf + \ + (m * alpha + n) * OCB * nr_units_in_tile * pack_size + \ + ocb * nr_units_in_tile * pack_size + unit_idx * pack_size); + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); +#undef cb + + //! 1 1 1 0 v00 v01 v02 v03 1 0 + //! 0 1 -1 1 v10 v11 v12 v13 1 1 + //! v20 v21 v22 v23 1 -1 + //! v30 v31 v32 v33 0 1 + +#define cb(m) \ + auto t0##m = v0##m + v1##m + v2##m; \ + auto t1##m = v1##m - v2##m + v3##m; + + UNROLL_CALL_NOWRAPPER(4, cb); +#undef cb + +#define cb(m) \ + v##m##0 = t##m##0 + t##m##1 + t##m##2; \ + v##m##1 = t##m##1 - t##m##2 + t##m##3; + + UNROLL_CALL_NOWRAPPER(2, cb); +#undef cb + + Vector vbias; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias = Vector::load(bias + oc); + +#define cb(m, n) v##m##n += vbias; + UNROLL_CALL_RAW_D2(2, 2, cb); +#undef cb + } + if (bmode != BiasMode::BIAS) { +#define cb(m, n) v##m##n = op(CONCAT(v##m, n).value); + UNROLL_CALL_RAW_D2(2, 2, cb); +#undef cb + } +#define out_save(oho, owo) \ + do { \ + size_t oh = oh_start + oho; \ + size_t ow = ow_start + owo; \ + if (oh < OH && ow < OW) { \ + if (bmode == BiasMode::BIAS) { \ + v##oho##owo += Vector::load(bias + oc * OH * OW + \ + oh * OW * pack_size + \ + ow * pack_size); \ + v##oho##owo = op(v##oho##owo.value); \ + } \ + v##oho##owo.save(output + oc * OH * OW + oh * OW * pack_size + \ + ow * pack_size); \ + } \ + } while (0); + UNROLL_CALL_RAW_D2(2, 2, out_save); +#undef out_save + } +}; +#undef CONCAT +} // namespace + +namespace megdnn { +namespace arm_common { +namespace winograd { + +MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F23_mk4_f_nchw44) +void winograd_F23_mk4_f_nchw44::filter(const float* filter, + float* filter_transform_buf, + float* transform_mid_buf, size_t OC, + size_t IC, size_t oc_start, + size_t oc_end) { + //! 1 0 0 v00 v01 v02 1 0.5 0.5 0 + //! 
0.5 0.5 0.5 v10 v11 v12 0 0.5 -0.5 0 + //! 0.5 -0.5 0.5 v20 v21 v22 0 0.5 0.5 1 + //! 0 0 1 + + constexpr size_t pack_size = 4; + + MEGDNN_MARK_USED_VAR(transform_mid_buf); + megdnn_assert((oc_end - oc_start) % pack_size == 0 && + oc_start % pack_size == 0 && + oc_end % pack_size == 0 && IC % pack_size == 0 && + OC % pack_size == 0, + "NCHW44 Winograd filter transform requires both OC and IC " + "are times of 4"); + size_t OCB = OC / pack_size; + size_t ICB = IC / pack_size; + + for (size_t ocb = oc_start / pack_size; ocb < oc_end / pack_size; ocb++) { + for (size_t icb = 0; icb < ICB; icb++) { + for (size_t ic_inner = 0; ic_inner < pack_size; ic_inner++) { + const float* fptr = filter + + (ocb * ICB + icb) * KERNEL_SIZE * + KERNEL_SIZE * pack_size * + pack_size + + ic_inner * pack_size; + +#define cb(m, n) \ + Vector g##m##n = Vector::load( \ + fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size); + UNROLL_CALL_NOWRAPPER_D2(3, 3, cb) +#undef cb + +#define FILTER_TRANSFORM(n, wd, g) \ + auto wd##n##0 = g##0##n; \ + tmp0 = (g##0##n + g##2##n) * 0.5; \ + tmp1 = g##1##n * 0.5; \ + auto wd##n##1 = tmp0 + tmp1; \ + auto wd##n##2 = tmp0 - tmp1; \ + auto wd##n##3 = g##2##n; + Vector tmp0, tmp1; + UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g); + UNROLL_CALL_RAW(4, FILTER_TRANSFORM, ret, wd); +#undef FILTER_TRANSFORM +#define cb_save(m, n) \ + ret##m##n.save(filter_transform_buf + \ + (m * ALPHA + n) * OCB * ICB * pack_size * pack_size + \ + ocb * ICB * pack_size * pack_size + \ + icb * pack_size * pack_size + ic_inner * pack_size); + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb_save) +#undef cb_save + } + } + } +} + +void winograd_F23_mk4_f_nchw44::input(const float* input, + float* input_transform_buf, + float* transform_mid_buf, size_t IH, + size_t IW, size_t IC, size_t PH, + size_t PW, size_t unit_start_idx, + size_t nr_units_in_tile) { + megdnn_assert(IC % 4 == 0); + constexpr int alpha = 3 + 2 - 1; + + // OW = IW + 2 * PW - KERNEL_SIZE + 1 + auto units_w = + div_ceil(IW + 2 * PW 
- KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + float* patch = transform_mid_buf; + float* patchT = transform_mid_buf + 4 * alpha * alpha; + + for (size_t ic = 0; ic < IC; ic += 4) { + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + size_t nh = index / units_w; + size_t nw = index % units_w; + int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; + int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; + if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && + iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { + InputTransformF23_NCHW44::prepare(input, patch, patchT, + ih_start, iw_start, IH, + IW, ic, IC); + InputTransformF23_NCHW44::transform(patchT, input_transform_buf, + unit_idx, nr_units_in_tile, + ic, IC); + + } else { + InputTransformF23_NCHW44::prepare(input, patch, patchT, + ih_start, iw_start, IH, + IW, ic, IC); + InputTransformF23_NCHW44::transform(patchT, input_transform_buf, + unit_idx, nr_units_in_tile, + ic, IC); + } + } + } +} + +void winograd_F23_mk4_f_nchw44::output(const float* output_transform_buf, + const float* bias, float* output, + float* transform_mid_buf, BiasMode bmode, + NonlineMode nonline_mode, size_t OH, + size_t OW, size_t oc_start, + size_t oc_end, size_t unit_start_idx, + size_t nr_units_in_tile) { +#define cb(_bmode, _nonline_op, ...) 
\ + OutputTransformF23_NCHW44<_bmode MEGDNN_COMMA _nonline_op>::transform( \ + __VA_ARGS__); + + auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); + constexpr size_t pack_size = 4; + + size_t OC = oc_end - oc_start; + megdnn_assert(OC % pack_size == 0 && oc_start % pack_size == 0 && + oc_end % pack_size == 0, + "NCHW44 Winograd filter transform requires OC is times of 4"); + + for (size_t oc = oc_start; oc < oc_end; oc += 4) { + size_t oc_index = oc - oc_start; + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + auto nh = index / units_w; + auto nw = index % units_w; + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; + DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_arm_common_winograd_nchw44_fp32_F23_mk4, cb, float, + float, bmode, nonline_mode, output_transform_buf, bias, + output, transform_mid_buf, oh_start, ow_start, OH, OW, + oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, + src_dtype, dst_dtype); + } + } +#undef cb +} + +} // namespace winograd +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp b/dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp new file mode 100644 index 0000000000000000000000000000000000000000..df5aa713ef7680e3ce0df37968536a559751f751 --- /dev/null +++ b/dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp @@ -0,0 +1,410 @@ +/** + * \file dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp + * + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#include "src/arm_common/conv_bias/fp32/filter_transform.h" +#include "src/arm_common/conv_bias/fp32/helper.h" +#include "src/arm_common/conv_bias/fp32/strategy.h" +#include "src/arm_common/elemwise_helper/op_unary.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/arm_common/utils.h" +#include "src/common/unroll_macro.h" +#include "src/common/utils.h" +#include "src/common/winograd/winograd_helper.h" +#include "src/fallback/conv_bias/winograd/winograd.h" + +#include "midout.h" +MIDOUT_DECL(megdnn_arm_common_winograd_fp32_F63_mk4) + +using namespace megdnn; +using namespace arm_common; + +namespace { + +constexpr size_t alpha = 6 + 3 - 1; +constexpr size_t pack_size = 4; + +struct InputTransformF63_NCHW44 { + template + static void prepare(const float* input, float* patch, float* patchT, + int ih_start, int iw_start, size_t IH, size_t IW, + size_t ic, size_t IC) { + MEGDNN_MARK_USED_VAR(patch); + size_t IW4 = IW * pack_size; + size_t iw4_start = iw_start * pack_size; + size_t icb = ic / pack_size; + if (!(inner && ic + pack_size < IC)) { + memset(patchT, 0, sizeof(float) * pack_size * alpha * alpha); + } + if (inner) { + const float* input_ptr = + input + icb * IH * IW4 + ih_start * IW4 + iw4_start; + for (size_t ih = 0; ih < alpha; ih++) { +#define cb(i) auto v##i = vld1q_f32(input_ptr + pack_size * i); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + +#define cb(i) vst1q_f32(patchT + ih * pack_size * alpha + i * pack_size, v##i); + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + input_ptr += IW4; + } + } else { + int ih0_act = std::max(ih_start, 0), + ih1_act = std::min(ih_start + alpha, IH), + iw0_act = std::max(iw_start, 0), + iw1_act = std::min(iw_start + alpha, IW); + const float* input_ptr = input + icb * IH * IW4; + // partial copy + for (int ih = ih0_act; ih < ih1_act; ++ih) { + for (int iw = iw0_act; iw < iw1_act; ++iw) { + size_t iho = ih - ih_start, iwo = iw - iw_start; + auto src = vld1q_f32(input_ptr + ih * IW4 + iw * pack_size); + 
vst1q_f32( + patchT + iho * pack_size * alpha + iwo * pack_size, + src); + } + } + } + } + + static void transform(const float* patchT, float* input_transform_buf, + size_t unit_idx, size_t nr_units_in_tile, size_t ic, + size_t IC) { + // BT * d * B +#define cb(m, n) \ + Vector d##m##n = Vector::load( \ + patchT + m * alpha * pack_size + n * pack_size); + + UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); +#undef cb + + //! B + //! 1 0 0 0 0 0 0 0 + //! 0 1 -1 0.5 -0.5 2 -2 -1 + //! -5.25 1 1 0.25 0.25 4 4 0 + //! 0 -4.25 4.25 -2.5 2.5 -2.5 2.5 5.25 + //! 5.25 -4.25 -4.25 -1.25 -1.25 -5 -5 0 + //! 0 1 -1 2 -2 0.5 -0.5 -5.25 + //! -1 1 1 1 1 1 1 0 + //! 0 0 0 0 0 0 0 1 +#define cb(m) \ + auto t0##m = d0##m + (d4##m - d2##m) * 5.25f - d6##m; \ + auto t1##m = d1##m + d2##m + d5##m + d6##m - (d3##m + d4##m) * 4.25f; \ + auto t2##m = d2##m + d6##m - (d1##m + d5##m) + (d3##m - d4##m) * 4.25f; \ + auto t3##m = d1##m * 0.5f + d2##m * 0.25f - d3##m * 2.5f - d4##m * 1.25f + \ + d5##m * 2.f + d6##m; \ + auto t4##m = d1##m * (-0.5f) + d2##m * 0.25f + d3##m * 2.5f - \ + d4##m * 1.25f - d5##m * 2.f + d6##m; \ + auto t5##m = d1##m * 2.f + d2##m * 4.f - d3##m * 2.5f - d4##m * 5.f + \ + d5##m * 0.5f + d6##m; \ + auto t6##m = d1##m * (-2.f) + d2##m * 4.f + d3##m * 2.5f - d4##m * 5.f - \ + d5##m * 0.5f + d6##m; \ + auto t7##m = (d7##m - d1##m) + (d3##m - d5##m) * 5.25f; + + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + +#define cb(m) \ + d##m##0 = t##m##0 + (t##m##4 - t##m##2) * 5.25f - t##m##6; \ + d##m##1 = t##m##1 + t##m##2 + t##m##5 + t##m##6 - \ + (t##m##3 + t##m##4) * 4.25f; \ + d##m##2 = t##m##2 + t##m##6 - (t##m##1 + t##m##5) + \ + (t##m##3 - t##m##4) * 4.25f; \ + d##m##3 = t##m##1 * 0.5f + t##m##2 * 0.25f - t##m##3 * 2.5f - \ + t##m##4 * 1.25f + t##m##5 * 2.f + t##m##6; \ + d##m##4 = t##m##1 * (-0.5f) + t##m##2 * 0.25f + t##m##3 * 2.5f - \ + t##m##4 * 1.25f - t##m##5 * 2.f + t##m##6; \ + d##m##5 = t##m##1 * 2.f + t##m##2 * 4.f - t##m##3 * 2.5f - t##m##4 * 5.f + \ + t##m##5 * 0.5f + t##m##6; 
\ + d##m##6 = t##m##1 * (-2.f) + t##m##2 * 4.f + t##m##3 * 2.5f - \ + t##m##4 * 5.f - t##m##5 * 0.5f + t##m##6; \ + d##m##7 = (t##m##7 - t##m##1) + (t##m##3 - t##m##5) * 5.25f; + + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + + size_t ICB = IC / pack_size; + size_t icb = ic / pack_size; +#define cb(m, n) \ + d##m##n.save(input_transform_buf + \ + (m * alpha + n) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + unit_idx * pack_size); + UNROLL_CALL_NOWRAPPER_D2(8, 8, cb) +#undef cb + } +}; + +template +struct OutputTransformF63_NCHW44 { + static void transform(const float* output_transform_buf, const float* bias, + float* output, float* transform_mid_buf, + size_t oh_start, size_t ow_start, size_t OH, + size_t OW, size_t oc_start, size_t oc_end, + size_t oc_index, size_t unit_idx, + size_t nr_units_in_tile, const DType& src_dtype, + const DType& dst_dtype) { + MEGDNN_MARK_USED_VAR(transform_mid_buf); + Op op(src_dtype, dst_dtype); + //! AT * m * A + + size_t oc = oc_start + oc_index; + size_t OCB = (oc_end - oc_start) / pack_size; + size_t ocb = oc_index / pack_size; + +#define cb(m, n) \ + auto v##m##n = Vector::load( \ + output_transform_buf + \ + (m * alpha + n) * OCB * nr_units_in_tile * pack_size + \ + ocb * nr_units_in_tile * pack_size + unit_idx * pack_size); + UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); +#undef cb + + /** + * A + * + * 1 0 0 0 0 0 + * 1 1 1 1 1 1 + * 1 -1 1 -1 1 -1 + * 1 2 4 8 16 32 + * 1 -2 4 -8 16 -32 + * 1 0.5 0.25 0.125 0.0625 0.03125 + * 1 -0.5 0.25 -0.125 0.0625 -0.03125 + * 0 0.0 0 0 0 1 + */ + + Vector v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; +#define cb(m) \ + v1addv2 = v1##m + v2##m; \ + v1subv2 = v1##m - v2##m; \ + v3addv4 = v3##m + v4##m; \ + v3subv4 = v3##m - v4##m; \ + v5addv6 = v5##m + v6##m; \ + v5subv6 = v5##m - v6##m; \ + auto t0##m = v0##m + v1addv2 + v3addv4 + v5addv6; \ + auto t1##m = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \ + auto t2##m = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \ + 
auto t3##m = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \ + auto t4##m = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ + auto t5##m = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + v7##m; + + UNROLL_CALL_NOWRAPPER(8, cb); +#undef cb + +#define cb(m) \ + v1addv2 = t##m##1 + t##m##2; \ + v1subv2 = t##m##1 - t##m##2; \ + v3addv4 = t##m##3 + t##m##4; \ + v3subv4 = t##m##3 - t##m##4; \ + v5addv6 = t##m##5 + t##m##6; \ + v5subv6 = t##m##5 - t##m##6; \ + v##m##0 = t##m##0 + v1addv2 + v3addv4 + v5addv6; \ + v##m##1 = v1subv2 + v3subv4 * 2.f + v5subv6 * 0.5f; \ + v##m##2 = v1addv2 + v3addv4 * 4.f + v5addv6 * 0.25f; \ + v##m##3 = v1subv2 + v3subv4 * 8.f + v5subv6 * 0.125f; \ + v##m##4 = v1addv2 + v3addv4 * 16.f + v5addv6 * 0.0625f; \ + v##m##5 = v1subv2 + v3subv4 * 32.f + v5subv6 * 0.03125f + t##m##7; + + UNROLL_CALL_NOWRAPPER(6, cb); +#undef cb + + Vector vbias; + if (bmode == BiasMode::BROADCAST_CHANNEL_BIAS) { + vbias = Vector::load(bias + oc); + +#define cb(m, n) v##m##n += vbias; + UNROLL_CALL_RAW_D2(6, 6, cb); +#undef cb + } + if (bmode != BiasMode::BIAS) { +#define cb(m, n) v##m##n = op(CONCAT(v##m, n).value); + UNROLL_CALL_RAW_D2(6, 6, cb); +#undef cb + } +#define out_save(oho, owo) \ + do { \ + size_t oh = oh_start + oho; \ + size_t ow = ow_start + owo; \ + if (oh < OH && ow < OW) { \ + if (bmode == BiasMode::BIAS) { \ + v##oho##owo += Vector::load(bias + oc * OH * OW + \ + oh * OW * pack_size + \ + ow * pack_size); \ + v##oho##owo = op(v##oho##owo.value); \ + } \ + v##oho##owo.save(output + oc * OH * OW + oh * OW * pack_size + \ + ow * pack_size); \ + } \ + } while (0); + UNROLL_CALL_RAW_D2(6, 6, out_save); + } +#undef out_save +}; +} // namespace + +namespace megdnn { +namespace arm_common { +namespace winograd { + +MEGDNN_REG_WINOGRAD_STRATEGY_IMPL(winograd_F63_mk4_f_nchw44) + +void winograd_F63_mk4_f_nchw44::filter(const float* filter, + float* filter_transform_buf, + float* transform_mid_buf, size_t OC, + size_t IC, size_t oc_start, + size_t oc_end) { + 
constexpr size_t pack_size = 4; + // Gg * GT + // G + // 1.0000000 0.0000000 0.0000000 + // -0.2222222 -0.2222222 -0.2222222 + // -0.2222222 0.2222222 -0.2222222 + // 0.0111111 0.0222222 0.0444444 + // 0.0111111 -0.0222222 0.0444444 + // 0.7111111 0.3555556 0.1777778 + // 0.7111111 -0.3555556 0.1777778 + // 0.0000000 0.0000000 1.0000000 + MEGDNN_MARK_USED_VAR(transform_mid_buf); + megdnn_assert((oc_end - oc_start) % pack_size == 0 && + oc_start % pack_size == 0 && + oc_end % pack_size == 0 && IC % pack_size == 0 && + OC % pack_size == 0, + "NCHW44 Winograd filter transform requires both OC and IC " + "are times of 4"); + + size_t ICB = IC / pack_size; + + for (size_t ocb = oc_start / pack_size; ocb < oc_end / pack_size; ocb++) { + for (size_t icb = 0; icb < ICB; icb++) { + for (size_t ic_inner = 0; ic_inner < pack_size; ic_inner++) { + const float* fptr = filter + + (ocb * ICB + icb) * KERNEL_SIZE * + KERNEL_SIZE * pack_size * + pack_size + + ic_inner * pack_size; + +#define cb(m, n) \ + Vector g##m##n = Vector::load( \ + fptr + (m * KERNEL_SIZE + n) * pack_size * pack_size); + UNROLL_CALL_NOWRAPPER_D2(3, 3, cb) +#undef cb + +#define FILTER_TRANSFORM(n, wd, g) \ + auto wd##n##0 = g##0##n; \ + tmp0 = (g##0##n + g##2##n) * -0.2222222f; \ + tmp1 = g##1##n * -0.2222222f; \ + auto wd##n##1 = tmp0 + tmp1; \ + auto wd##n##2 = tmp0 - tmp1; \ + tmp0 = g##0##n * 0.0111111f + g##2##n * 0.0444444f; \ + tmp1 = g##1##n * 0.0222222f; \ + auto wd##n##3 = tmp0 + tmp1; \ + auto wd##n##4 = tmp0 - tmp1; \ + tmp0 = g##0##n * 0.7111111f + g##2##n * 0.1777778f; \ + tmp1 = g##1##n * 0.3555556f; \ + auto wd##n##5 = tmp0 + tmp1; \ + auto wd##n##6 = tmp0 - tmp1; \ + auto wd##n##7 = g##2##n; + Vector tmp0, tmp1; + UNROLL_CALL_RAW(3, FILTER_TRANSFORM, wd, g); + UNROLL_CALL_RAW(8, FILTER_TRANSFORM, ret, wd); +#undef FILTER_TRANSFORM +#define cb_save(m, n) \ + ret##m##n.save(filter_transform_buf + (m * alpha + n) * OC * IC + \ + ocb * IC * pack_size + icb * pack_size * pack_size + \ + ic_inner * 
pack_size); + UNROLL_CALL_NOWRAPPER_D2(8, 8, cb_save) +#undef cb_save + } + } + } +} + +void winograd_F63_mk4_f_nchw44::input(const float* input, + float* input_transform_buf, + float* transform_mid_buf, size_t IH, + size_t IW, size_t IC, size_t PH, + size_t PW, size_t unit_start_idx, + size_t nr_units_in_tile) { + constexpr size_t pack_size = 4; + megdnn_assert(IC % pack_size == 0); + constexpr int alpha = 3 + 6 - 1; + + // OW = IW + 2 * PW - KERNEL_SIZE + 1 + auto units_w = + div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); + float* patch = transform_mid_buf; + float* patchT = transform_mid_buf + pack_size * alpha * alpha; + + for (size_t ic = 0; ic < IC; ic += pack_size) { + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + size_t nh = index / units_w; + size_t nw = index % units_w; + int ih_start = nh * OUTPUT_BLOCK_SIZE - PH; + int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; + if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && + iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { + InputTransformF63_NCHW44::prepare(input, patch, patchT, + ih_start, iw_start, IH, + IW, ic, IC); + InputTransformF63_NCHW44::transform(patchT, input_transform_buf, + unit_idx, nr_units_in_tile, + ic, IC); + + } else { + InputTransformF63_NCHW44::prepare(input, patch, patchT, + ih_start, iw_start, IH, + IW, ic, IC); + InputTransformF63_NCHW44::transform(patchT, input_transform_buf, + unit_idx, nr_units_in_tile, + ic, IC); + } + } + } +} + +void winograd_F63_mk4_f_nchw44::output(const float* output_transform_buf, + const float* bias, float* output, + float* transform_mid_buf, BiasMode bmode, + NonlineMode nonline_mode, size_t OH, + size_t OW, size_t oc_start, + size_t oc_end, size_t unit_start_idx, + size_t nr_units_in_tile) { + constexpr size_t pack_size = 4; +#define cb(_bmode, _nonline_op, ...) 
\ + OutputTransformF63_NCHW44<_bmode MEGDNN_COMMA _nonline_op>::transform( \ + __VA_ARGS__); + + auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); + for (size_t oc = oc_start; oc < oc_end; oc += pack_size) { + size_t oc_index = oc - oc_start; + rep(unit_idx, nr_units_in_tile) { + size_t index = unit_start_idx + unit_idx; + auto nh = index / units_w; + auto nw = index % units_w; + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; + DISPATCH_CONV_WINOGRAD_BIAS( + megdnn_arm_common_winograd_fp32_F63_mk4, cb, float, float, + bmode, nonline_mode, output_transform_buf, bias, output, + transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, + oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, + dst_dtype); + } + } +#undef cb +} + +} // namespace winograd +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index e3af5f4f42f3f92b9de533af5c537eec864fbc43..b902a8720adec8eb233316ae5f312c2579e64eea 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -169,6 +169,14 @@ public: static_cast(algo), tile_size)); winograd_algos.emplace_back(refhold.back().get()); + refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44( + static_cast(algo), + tile_size)); + winograd_algos.emplace_back(refhold.back().get()); + refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44( + static_cast(algo), + tile_size)); + winograd_algos.emplace_back(refhold.back().get()); #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC refhold.emplace_back(new AlgoFP16WinogradF23( static_cast(algo), diff --git a/dnn/src/arm_common/conv_bias/opr_impl.h b/dnn/src/arm_common/conv_bias/opr_impl.h index dc185aaa3983b54abbe521f158949b01e906a03e..58db42a72374c76ed9c25877cf0ec0d59f5c93bd 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.h +++ b/dnn/src/arm_common/conv_bias/opr_impl.h @@ -49,6 +49,9 @@ private: class 
AlgoFP32WinogradF54; class AlgoFP32WinogradF45; + class AlgoFP32WinogradF23_4x4_NCHW44; + class AlgoFP32WinogradF63_4x4_NCHW44; + class AlgoS8ChanWiseStride1NCHW44; class AlgoS8ChanWiseStride2NCHW44; diff --git a/dnn/src/arm_common/winograd_filter_preprocess/opr_impl.cpp b/dnn/src/arm_common/winograd_filter_preprocess/opr_impl.cpp index ce6e4a0c7fb5be9a5e37008175108ebe4c1fa11a..3842a65013bcb8c4af79fe51e25bb349a192478e 100644 --- a/dnn/src/arm_common/winograd_filter_preprocess/opr_impl.cpp +++ b/dnn/src/arm_common/winograd_filter_preprocess/opr_impl.cpp @@ -28,13 +28,22 @@ void WinogradFilterPreprocessImpl::exec(_megdnn_tensor_in src, using namespace winograd; check_exec(src.layout, dst.layout, workspace.size); + //! NCHW44 group conv or NCHW group conv or both dense conv size_t flt_start = 0; + size_t pack_c_size = 1; size_t group = 1; - if (src.layout.ndim == 5) { + if (src.layout.ndim == 5) { //! {g, OC, IC, FH, FW} flt_start = 1; group = src.layout[0]; + } else if (src.layout.ndim == 6) { //! {OC/4, IC/4, FH, FW, 4, 4} + pack_c_size = src.layout[5]; + } else if (src.layout.ndim == 7) { //! 
{g, OC/4, IC/4, FH, FW, 4, 4} + flt_start = 1; + group = src.layout[0]; + pack_c_size = src.layout[6]; } - size_t OC = src.layout[flt_start], IC = src.layout[flt_start + 1], + size_t OC = src.layout[flt_start] * pack_c_size, + IC = src.layout[flt_start + 1] * pack_c_size, FW = src.layout[flt_start + 3]; size_t m = param().output_block_size; @@ -68,13 +77,23 @@ void WinogradFilterPreprocessImpl::exec(_megdnn_tensor_in src, float* workspace_ptr = workspace.ptr(); if (FW == 3) { if (m == 2) { - DISPATCH(winograd_2x3_4x4_f, param::Winograd::Format::MK4, 0, - 0); + if (pack_c_size == 1) { + DISPATCH(winograd_2x3_4x4_f, param::Winograd::Format::MK4, + 0, 0); + } else if (pack_c_size == 4) { + DISPATCH(winograd_F23_mk4_f_nchw44, + param::Winograd::Format::MK4, 0, 5); + } } else if (m == 6) { DISPATCH(winograd_6x3_1x1_f, param::Winograd::Format::DEFAULT, 0, 1); - DISPATCH(winograd_6x3_4x4_f, param::Winograd::Format::MK4, 0, - 2); + if (pack_c_size == 1) { + DISPATCH(winograd_6x3_4x4_f, param::Winograd::Format::MK4, + 0, 2); + } else if (pack_c_size == 4) { + DISPATCH(winograd_F63_mk4_f_nchw44, + param::Winograd::Format::MK4, 0, 6); + } } } else if (FW == 4) { if (m == 5) { diff --git a/dnn/src/common/conv_bias.cpp b/dnn/src/common/conv_bias.cpp index ee834fa8100b9e3123edb909bb9b0b224511e267..e82aaef5ac80fc1761454b6e0c1429dc636d9183 100644 --- a/dnn/src/common/conv_bias.cpp +++ b/dnn/src/common/conv_bias.cpp @@ -158,7 +158,10 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( } template -struct ParamTrait; +struct NCHWParamTrait; + +template +struct NCHW44ParamTrait; std::string ConvBias::WinogradParam::to_string() const { return ssprintf("%u:%u:%u", channel_block_size, output_block_size, @@ -166,32 +169,51 @@ std::string ConvBias::WinogradParam::to_string() const { } template -std::string ConvBias::algo_name(const std::string& base, const T& p) { - return ssprintf("%s:%s:%s", ParamTrait::category.c_str(), base.c_str(), - p.to_string().c_str()); 
+std::string ConvBias::algo_name(const std::string& base, const T& p, + param::ConvBias::Format format) { + if (format == param::ConvBias::Format::NCHW) { + return ssprintf("%s:%s:%s", NCHWParamTrait::category.c_str(), + base.c_str(), p.to_string().c_str()); + } else if (format == param::ConvBias::Format::NCHW44) { + return ssprintf("%s:%s:%s", NCHW44ParamTrait::category.c_str(), + base.c_str(), p.to_string().c_str()); + } + megdnn_throw("Invalid format"); + return ""; } #define FOREACH_CONV_BIAS_PARAM(cb) \ cb(WinogradParam) cb(DirectParam) cb(MatmulParam) cb(DefaultParam) -#define cb(pt) \ - template <> \ - struct ParamTrait { \ - static const std::string category; \ +#define cb(pt) \ + template <> \ + struct NCHWParamTrait { \ + static const std::string category; \ + }; \ + template <> \ + struct NCHW44ParamTrait { \ + static const std::string category; \ }; FOREACH_CONV_BIAS_PARAM(cb) #undef cb -#define cb(pt, ct) const std::string ParamTrait::category = ct -cb(WinogradParam, "WINOGRAD"); +#define cb(pt, ct) \ + const std::string NCHWParamTrait::category = ct; \ + const std::string NCHW44ParamTrait::category = ct cb(DirectParam, "DIRECT"); cb(MatmulParam, "MATMUL"); cb(DefaultParam, "DEFAULT"); #undef cb +const std::string NCHWParamTrait::category = + "WINOGRAD"; +const std::string NCHW44ParamTrait::category = + "WINOGRAD_NCHW44"; + #define cb(t) \ template std::string ConvBias::algo_name( \ - const std::string& base, const ConvBias::t& p); + const std::string& base, const ConvBias::t& p, \ + param::ConvBias::Format format); FOREACH_CONV_BIAS_PARAM(cb) #undef cb @@ -199,17 +221,37 @@ ConvBias::WinogradParam ConvBias::parse_winograd_name( const std::string& algo_name) { ConvBias::WinogradParam ret = INVALID_WINOGRAD_PARAM; char base[128]; - sscanf(algo_name.c_str(), "WINOGRAD:%[^:]:%u:%u:%u", base, - &(ret.channel_block_size), &(ret.output_block_size), - &(ret.tile_size)); - if (ret.tile_size == 0 || ret.output_block_size == 0 || - ret.channel_block_size == 0) { 
- megdnn_log_warn("the algo name %s is not suitable for winograd", - algo_name.c_str()); - return INVALID_WINOGRAD_PARAM; + char name[128]; + + auto parse = [&](const std::string& algo_name, + const std::string& pre) -> auto { + memset(name, 0, 128); + sscanf(algo_name.c_str(), "%[^:]:%[^:]:%u:%u:%u", name, base, + &(ret.channel_block_size), &(ret.output_block_size), + &(ret.tile_size)); + if (strcmp(name, pre.c_str())) { + megdnn_log_warn("algo %s is not %s algo", name, pre.c_str()); + ret = INVALID_WINOGRAD_PARAM; + return false; + } + if (ret.tile_size == 0 || ret.output_block_size == 0 || + ret.channel_block_size == 0) { + megdnn_log_warn("the algo name %s is not suitable for %s", + algo_name.c_str(), pre.c_str()); + ret = INVALID_WINOGRAD_PARAM; + return false; + } + return true; + }; + + if (parse(algo_name, "WINOGRAD_NCHW44")) { + return ret; + } else { + parse(algo_name, "WINOGRAD"); + return ret; } - return ret; } + constexpr ConvBias::WinogradParam ConvBias::INVALID_WINOGRAD_PARAM; void handle_bias_and_nonlinear(Handle* handle, param::ConvBias args, diff --git a/dnn/src/common/convolution.cpp b/dnn/src/common/convolution.cpp index 9b8140d7f3dc42240b19906f2e4b337c5cecc64e..c80b5b02621cff4eba01ab60ab3c63ed3f11af9d 100644 --- a/dnn/src/common/convolution.cpp +++ b/dnn/src/common/convolution.cpp @@ -299,6 +299,7 @@ void make_canonized_filter_meta_nchwxx( megdnn_assert(param.format == Param::Format::NCHW88 || param.format == Param::Format::NCHW44 || + param.format == Param::Format::NCHW44_WINOGRAD || param.format == Param::Format::NCHW88_WINOGRAD); size_t img_ndim = 2; size_t flt_start = 0; @@ -663,6 +664,7 @@ ConvolutionBase::deduce_layout_fwd(const TensorLayout& src, param().format == Param::Format::NCHW32 || param().format == Param::Format::NCHW88 || param().format == Param::Format::NCHW88_WINOGRAD || + param().format == Param::Format::NCHW44_WINOGRAD || param().format == Param::Format::CHWN4); img_dim = src.ndim - 3; if ((param().format == 
Param::Format::NCHW88 || diff --git a/dnn/src/common/winograd/winograd_helper.cpp b/dnn/src/common/winograd/winograd_helper.cpp index b9a779c33f2c7eca36e0d50fa58b5daad3718e3d..10ba9af1541a625b3dcf01b98d181861a1d949cd 100644 --- a/dnn/src/common/winograd/winograd_helper.cpp +++ b/dnn/src/common/winograd/winograd_helper.cpp @@ -68,6 +68,7 @@ constexpr size_t layout_pack_size(param::ConvBias::Format layout) { case param::ConvBias::Format::NHWCD4: return 4; case param::ConvBias::Format::NCHW4: + case param::ConvBias::Format::NCHW44: return 4; case param::ConvBias::Format::NCHW32: return 32; @@ -363,6 +364,7 @@ INST(uint8_t, uint8_t, int16_t, int) _ctype, _dst_type, _input_filter_compute_type, \ _output_compute_type, layout, param::MatrixMul::Format::MK4>; INST(float, float, float, float, param::ConvBias::Format::NCHW) +INST(float, float, float, float, param::ConvBias::Format::NCHW44) #undef INST #define INST(_ctype, _dst_type, _input_filter_compute_type, \ diff --git a/dnn/src/fallback/conv_bias/opr_impl.cpp b/dnn/src/fallback/conv_bias/opr_impl.cpp index f2466b77e51bf41c58ced3e64828a70305c35570..3f68c07fa77aa9556ef1b3e54d0a6ca28e6dbae0 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.cpp +++ b/dnn/src/fallback/conv_bias/opr_impl.cpp @@ -425,6 +425,7 @@ const T* ConvBiasImpl::NCBKernParam::filter(size_t group_pack_id, break; } case ConvBiasImpl::Param::Format::NCHW_WINOGRAD: + case ConvBiasImpl::Param::Format::NCHW44_WINOGRAD: case ConvBiasImpl::Param::Format::NCHW88_WINOGRAD: { //! four format of weight layout //! 1. 
{g, alpha, alpha, ocpg/8, icpg/8, 8, 8} diff --git a/dnn/src/fallback/convolution/opr_impl.cpp b/dnn/src/fallback/convolution/opr_impl.cpp index 707623cb0808ce609374edce2b8f9304b6abdf44..322ceeaba33efcd9be430cbd0ed4ec5774a98ba8 100644 --- a/dnn/src/fallback/convolution/opr_impl.cpp +++ b/dnn/src/fallback/convolution/opr_impl.cpp @@ -198,7 +198,8 @@ void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param, for (auto kernel : kerns) { megdnn_assert(param.filter_meta.format == Param::Format::NCHW || param.filter_meta.format == Param::Format::NHWC || - param.filter_meta.format == Param::Format::NCHW88, + param.filter_meta.format == Param::Format::NCHW88 || + param.filter_meta.format == Param::Format::NCHW44, "invalid conv format"); auto run = [param, kernel](size_t index, size_t thread_id) { CpuNDRange ndrange_id(kernel.global_size, index); diff --git a/dnn/src/naive/winograd_filter_preprocess/opr_impl.cpp b/dnn/src/naive/winograd_filter_preprocess/opr_impl.cpp index b5db83e7f3375c5d56ccb51178bab6b6458dc268..0ebf578d698d60a78a480a8e3b76787fa2297c48 100644 --- a/dnn/src/naive/winograd_filter_preprocess/opr_impl.cpp +++ b/dnn/src/naive/winograd_filter_preprocess/opr_impl.cpp @@ -137,6 +137,7 @@ void WinogradFilterPreprocessImpl::exec(_megdnn_tensor_in src, #undef DISPATCH_FORMAT_MK8 #undef DISPATCH_DTYPE } else { + megdnn_assert(src.layout.ndim == 6 || src.layout.ndim == 7); #define cb(_ctype, _dst_type, _input_filter_compute_type, \ _output_compute_type, _format, rescale) \ if (param().format == _format) { \ @@ -158,20 +159,58 @@ void WinogradFilterPreprocessImpl::exec(_megdnn_tensor_in src, DISPATCH_KERNEL(dt_float32, dt_float32, dt_float32, dt_float32, \ DISPATCH_FORMAT_MK8, 1.0f, _midout_tag, 0); \ } - megdnn_assert(src.layout.ndim == 6 || src.layout.ndim == 7); - if (FW == 3) { - if (m == 2) { - std::vector interp_points = {0, 1, -1}; - DISPATCH_DTYPE(4); - } else if (m == 6) { - std::vector interp_points = {0, 1, -1, 2, -2, 0.5, -0.5}; - 
DISPATCH_DTYPE(5); + if (pack_c_size == 8) { //! NCHW88 + if (FW == 3) { + if (m == 2) { + std::vector interp_points = {0, 1, -1}; + DISPATCH_DTYPE(4); + } else if (m == 6) { + std::vector interp_points = {0, 1, -1, 2, + -2, 0.5, -0.5}; + DISPATCH_DTYPE(5); + } } +#undef cb +#undef DISPATCH_FORMAT_MK8 +#undef DISPATCH_DTYPE } + else if (pack_c_size == 4) { //! NCHW44 +#define cb(_ctype, _dst_type, _input_filter_compute_type, \ + _output_compute_type, _format, rescale) \ + if (param().format == _format) { \ + return winograd::StrategyHelper< \ + _ctype, _dst_type, _input_filter_compute_type, \ + _output_compute_type, param::ConvBias::Format::NCHW44, \ + _format>::filter(src_ptr, dst_ptr, workspace_ptr, OC, IC, 0, \ + OC, m, FW, interp_points, src.layout.dtype, \ + rescale); \ + } + +#define DISPATCH_FORMAT_MK4(_ctype, _dst_type, _input_filter_compute_type, \ + _output_compute_type, _rescale) \ + cb(_ctype, _dst_type, _input_filter_compute_type, _output_compute_type, \ + param::Winograd::Format::MK4, _rescale); + +#define DISPATCH_DTYPE(_midout_tag) \ + if (src.layout.dtype.enumv() == DTypeEnum::Float32) { \ + DISPATCH_KERNEL(dt_float32, dt_float32, dt_float32, dt_float32, \ + DISPATCH_FORMAT_MK4, 1.0f, _midout_tag, 0); \ + } + if (FW == 3) { + if (m == 2) { + std::vector interp_points = {0, 1, -1}; + DISPATCH_DTYPE(6); + } else if (m == 6) { + std::vector interp_points = {0, 1, -1, 2, + -2, 0.5, -0.5}; + DISPATCH_DTYPE(7); + } + } #undef cb #undef DISPATCH_FORMAT_MK8 #undef DISPATCH_KERNEL #undef DISPATCH_DTYPE + } } megdnn_assert(execed, "Unsupport winograd filter preprocess. 
m: %zu src: %s", m, diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp index f9807c8392bf5af389b5f8159e1a3f8d74ea0e04..00f720f622a21a05587ae84f712fcfa0539ddcbc 100644 --- a/dnn/test/arm_common/conv_bias.cpp +++ b/dnn/test/arm_common/conv_bias.cpp @@ -699,6 +699,135 @@ TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F16_F23_8x8) { } #endif +void benchmark_winograd_nchw_vs_nchw44(const char* algo_name, Handle* handle) { + using namespace conv_bias; + using NLMode = param::ConvBias::NonlineMode; + std::vector args_nchw44; + std::vector args_nchw; + + auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, + size_t group, NLMode nlmode) { + param::ConvBias param; + param.format = param::ConvBias::Format::NCHW44; + param.stride_h = 1; + param.stride_w = 1; + param.pad_h = 1; + param.pad_w = 1; + param.nonlineMode = nlmode; + + if (group == 1) { + param.sparse = param::ConvBias::Sparse::DENSE; + args_nchw44.emplace_back(param, TensorShape{n, ic / 4, h, w, 4}, + TensorShape{oc / 4, ic / 4, 3, 3, 4, 4}, + TensorShape{}); + param.format = param::ConvBias::Format::NCHW; + args_nchw.emplace_back(param, TensorShape{n, ic, h, w}, + TensorShape{oc, ic, 3, 3}, TensorShape{}); + } else { + auto oc_per_group = oc / group; + auto ic_per_group = ic / group; + param.sparse = param::ConvBias::Sparse::GROUP; + args_nchw44.emplace_back(param, + TensorShape{n, ic_per_group / 4, h, w, 4}, + TensorShape{group, oc_per_group / 4, + ic_per_group / 4, 3, 3, 4, 4}, + TensorShape{}); + param.format = param::ConvBias::Format::NCHW; + args_nchw.emplace_back( + param, TensorShape{n, ic, h, w}, + TensorShape{group, oc_per_group, ic_per_group, 3, 3}, + TensorShape{}); + } + }; + + std::vector nonlinemode = {NLMode::IDENTITY}; + for (auto nlmode : nonlinemode) + for (size_t n : {1, 2}) + for (size_t group = 1; group <= 2; ++group) { + pack(n, 512, 512, 15, 15, group, nlmode); + pack(n, 512, 256, 15, 15, group, nlmode); + pack(n, 256, 256, 29, 29, group, 
nlmode); + pack(n, 256, 128, 29, 29, group, nlmode); + pack(n, 128, 128, 57, 57, group, nlmode); + pack(n, 128, 64, 57, 57, group, nlmode); + pack(n, 24, 24, 224, 224, group, nlmode); + pack(n, 64, 24, 123, 123, group, nlmode); + pack(n, 64, 64, 56, 56, group, nlmode); + pack(n, 128, 128, 28, 28, group, nlmode); + pack(n, 256, 256, 14, 14, group, nlmode); + pack(n, 512, 512, 7, 7, group, nlmode); + } + + using namespace conv_bias; + constexpr size_t RUN = 10; + Benchmarker benchmark_winograd_nchw(handle); + benchmark_winograd_nchw.set_display(false); + benchmark_winograd_nchw.set_times(RUN); + + Benchmarker benchmark_winograd_nchw44(handle); + benchmark_winograd_nchw44.set_display(false); + benchmark_winograd_nchw44.set_times(RUN); + + std::string winograd_nchw_algo_name = ssprintf("WINOGRAD:%s", algo_name); + std::string winograd_nchw44_algo_name = + ssprintf("WINOGRAD_NCHW44:%s", algo_name); + + for (size_t i = 0; i < args_nchw.size(); ++i) { + auto arg_nchw = args_nchw[i]; + auto arg_nchw44 = args_nchw44[i]; + + TensorLayout dst_layout; + auto opr = handle->create_operator(); + opr->param() = arg_nchw.param; + opr->deduce_layout({arg_nchw.src, dtype::Float32()}, + {arg_nchw.filter, dtype::Float32()}, + {arg_nchw.bias, dtype::Float32()}, {}, dst_layout); + //! 
dst.nr_elems * IC * FH * FW * 2 + float computations = dst_layout.total_nr_elems() * arg_nchw.filter[1] * + arg_nchw.filter[2] * arg_nchw.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; + + benchmark_winograd_nchw.set_param(arg_nchw.param); + auto nchw_used = algo_benchmark( + benchmark_winograd_nchw, + {arg_nchw.src, arg_nchw.filter, {}, {}, {}}, + winograd_nchw_algo_name.c_str()) / + RUN; + + benchmark_winograd_nchw44.set_param(arg_nchw44.param); + auto nchw44_used = + algo_benchmark( + benchmark_winograd_nchw44, + {arg_nchw44.src, arg_nchw44.filter, {}, {}, {}}, + winograd_nchw44_algo_name.c_str()) / + RUN; + + printf("%s %s: nchw: %f ms %f Gflops nchw44: %f ms %f GFlops " + "speedup: " + "%f\n", + arg_nchw.src.to_string().c_str(), + arg_nchw.filter.to_string().c_str(), nchw_used, + computations / nchw_used, nchw44_used, + computations / nchw44_used, nchw_used / nchw44_used); + } +} + +TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23_MK4_NCHW_VS_NCHW44) { +#if MEGDNN_AARCH64 + benchmark_winograd_nchw_vs_nchw44("AARCH64_F32_MK4_4x16:4:2", handle()); +#else + benchmark_winograd_nchw_vs_nchw44("ARMV7_F32_MK4_4x8:4:2", handle()); +#endif +} + +TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F63_MK4_NCHW_VS_NCHW44) { +#if MEGDNN_AARCH64 + benchmark_winograd_nchw_vs_nchw44("AARCH64_F32_MK4_4x16:4:6", handle()); +#else + benchmark_winograd_nchw_vs_nchw44("ARMV7_F32_MK4_4x8:4:6", handle()); +#endif +} + TEST_F(ARM_COMMON, BENCHMARK_CONVBIAS_WINOGRAD_F23_8x8) { auto benchmark_winograd_quantized = [](const char* algo_name_fp32, const char* algo_name_quantized, diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index 1f4dc9f4ebec3692232b32219b55ea09dd5d2dfc..553547eb1a55b1bd94848e354701fbd9a86830a9 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -654,6 +654,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4) { check_winograd("4:2:32", 
checker, args, param::MatrixMul::Format::MK4); } + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4_NCHW44) { + using namespace conv_bias; + std::vector args = get_nchw44_conv_bias_args({3}, 1); + Checker checker(handle()); + check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4, + param::ConvBias::Format::NCHW44); +} + TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63) { using namespace conv_bias; std::vector args = get_winograd_args(3); @@ -667,7 +676,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4) { std::vector args = get_winograd_mk_packed_args(); Checker checker(handle()); - check_winograd("4:6:32", checker, args, param::MatrixMul::Format::MK4); + check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4); +} + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44) { + using namespace conv_bias; + std::vector args = get_nchw44_conv_bias_args({3}, 1); + Checker checker(handle()); + check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4, + param::ConvBias::Format::NCHW44); } TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54) { @@ -761,6 +778,75 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD) { #endif } +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_PREPROCESS_NCHW44) { + using namespace conv_bias; + std::vector nchw44_args = get_nchw44_conv_bias_args({3}, 1); + + Checker checker(handle()); + + auto extra_impl = [](const TensorNDArray& tensors, uint32_t m, + param::ConvBias param, Handle* handle) { + megdnn_assert(param.format == param::ConvBias::Format::NCHW44); + auto winograd_preprocess_opr = + handle->create_operator(); + winograd_preprocess_opr->param().output_block_size = m; + winograd_preprocess_opr->param().format = param::MatrixMul::Format::MK4; + TensorLayout filter_transform_layout; + winograd_preprocess_opr->deduce_layout(tensors[1].layout, + filter_transform_layout); + size_t winograd_preprocess_workspace_in_bytes = + 
winograd_preprocess_opr->get_workspace_in_bytes( + tensors[1].layout, filter_transform_layout); + + auto conv_bias_opr = handle->create_operator(); + conv_bias_opr->param() = param; + conv_bias_opr->param().format = param::ConvBias::Format::NCHW44_WINOGRAD; + conv_bias_opr->param().output_block_size = m; + size_t conv_bias_workspace_in_bytes = + conv_bias_opr->get_workspace_in_bytes( + tensors[0].layout, filter_transform_layout, + tensors[2].layout, tensors[3].layout, + tensors[4].layout, nullptr); + + WorkspaceBundle wb(nullptr, {filter_transform_layout.span().dist_byte(), + conv_bias_workspace_in_bytes, + winograd_preprocess_workspace_in_bytes}); + wb.set(malloc(wb.total_size_in_bytes())); + + TensorND filter_transform_tensor(wb.get(0), + std::move(filter_transform_layout)); + winograd_preprocess_opr->exec(tensors[1], filter_transform_tensor, + wb.get_workspace(2)); + conv_bias_opr->exec(tensors[0], filter_transform_tensor, tensors[2], + tensors[3], tensors[4], nullptr, + wb.get_workspace(1)); + free(wb.ptr()); + }; + + auto run = [&checker, &extra_impl]( + Handle* handle, const std::vector& args, + const std::vector& out_size, DType A_dtype, + DType B_dtype, DType C_dtype, DType D_dtype, + const float eps) { + for (auto&& arg : args) { + for (uint32_t m : out_size) { + checker.set_extra_opr_impl(std::bind(extra_impl, + std::placeholders::_1, m, + arg.param, handle)); + checker.set_dtype(0, A_dtype) + .set_dtype(1, B_dtype) + .set_dtype(2, C_dtype) + .set_dtype(4, D_dtype) + .set_epsilon(eps) + .set_param(arg.param) + .execs({arg.src, arg.filter, arg.bias, {}, {}}); + } + } + }; + run(handle(), nchw44_args, {2, 6}, dtype::Float32(), dtype::Float32(), + dtype::Float32(), dtype::Float32(), 1e-3f); +} + TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1) { using namespace conv_bias; diff --git a/src/gopt/impl/weights_preprocess.cpp b/src/gopt/impl/weights_preprocess.cpp index 
136ec52b4453eb6194cb22ea2b78c1c6910e924c..3f8c84f3f379f6f94fea6003061934918f6e231d 100644 --- a/src/gopt/impl/weights_preprocess.cpp +++ b/src/gopt/impl/weights_preprocess.cpp @@ -72,7 +72,7 @@ void WinogradTransformReplacePass::apply(OptState& opt) const { auto&& inputs = conv_bias_opr.input(); VarNodeArray new_inp; new_inp.reserve(inputs.size()); - for (auto i: inputs) { + for (auto i : inputs) { new_inp.push_back(rewriter.get_var(i)); } @@ -86,11 +86,15 @@ void WinogradTransformReplacePass::apply(OptState& opt) const { megdnn::ConvBias::parse_winograd_name(algo_name); if (winograd_param == megdnn::ConvBias::INVALID_WINOGRAD_PARAM) break; - mgb_assert(conv_bias_opr.param().format == - megdnn::ConvBias::Param::Format::NCHW || - conv_bias_opr.param().format == - megdnn::ConvBias::Param::Format::NCHW88, - "currently winograd only suppport NCHW and nchw88"); + mgb_assert( + conv_bias_opr.param().format == + megdnn::ConvBias::Param::Format::NCHW || + conv_bias_opr.param().format == + megdnn::ConvBias::Param::Format::NCHW88 || + conv_bias_opr.param().format == + megdnn::ConvBias::Param::Format::NCHW44, + "currently winograd only suppport NCHW and NCHW44 and " + "NCHW88"); opr::ConvBiasForward::check_winograd_param_valid( winograd_param, conv_bias_opr.input(0)->dtype()); megdnn::param::Winograd winograd_preprocess_param; @@ -110,8 +114,17 @@ void WinogradTransformReplacePass::apply(OptState& opt) const { megdnn::ConvBias::Param::Format::NCHW_WINOGRAD; } else { mgb_assert(new_inp[0]->shape().ndim == 5); - conv_bias_param.format = - megdnn::ConvBias::Param::Format::NCHW88_WINOGRAD; + size_t pack_size = new_inp[0]->shape()[4]; + if (pack_size == 8) { + conv_bias_param.format = + megdnn::ConvBias::Param::Format::NCHW88_WINOGRAD; + } else if (pack_size == 4) { + conv_bias_param.format = + megdnn::ConvBias::Param::Format::NCHW44_WINOGRAD; + } else { + mgb_assert(0, "Invalid pack size %zu in algo %s", pack_size, + algo_name.c_str()); + } } conv_bias_param.output_block_size = 
winograd_param.output_block_size;