From a33c3b73bd9fd09dcc9d35a0caa5d940c20b2e5d Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sun, 11 Jul 2021 22:52:32 +0800 Subject: [PATCH] refactor(mgb/dnn): arm pooling rebase algochooser GitOrigin-RevId: 21d17e647afdc349929ebc668639406e088e3c68 --- dnn/src/arm_common/pooling/algo.h | 23 +++++ dnn/src/arm_common/pooling/opr_impl.cpp | 120 ++++++++++++++++-------- dnn/src/arm_common/pooling/opr_impl.h | 89 +++++++++++++++--- 3 files changed, 176 insertions(+), 56 deletions(-) diff --git a/dnn/src/arm_common/pooling/algo.h b/dnn/src/arm_common/pooling/algo.h index d7c0f00f8..71506f7db 100644 --- a/dnn/src/arm_common/pooling/algo.h +++ b/dnn/src/arm_common/pooling/algo.h @@ -28,6 +28,7 @@ public: const char* name() const override { return "ARM_POOLING_STRIDE1"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; + MEGDNN_DECL_ALGO_TYPE(ARM_FilterxModexStride1) }; class PoolingImpl::AlgoFilter2ModexStride2 final : public AlgoBase { @@ -38,6 +39,7 @@ public: const char* name() const override { return "ARM_POOLING_STRIDE2"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; + MEGDNN_DECL_ALGO_TYPE(ARM_Filter2ModexStride2) }; class PoolingImpl::AlgoFilter3MaxStride2 final : public AlgoBase { public: @@ -47,6 +49,7 @@ public: const char* name() const override { return "ARM_POOLING_FILTER3_MAX"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; + MEGDNN_DECL_ALGO_TYPE(ARM_Filter3MaxStride2) }; class PoolingImpl::AlgoFilter3AverageStride2 final : public AlgoBase { @@ -57,6 +60,7 @@ public: const char* name() const override { return "ARM_POOLING_FILTER3_AVERAGE"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; + MEGDNN_DECL_ALGO_TYPE(ARM_Filter3AverageStride2) }; class PoolingImpl::AlgoFilter4MaxStride2 final : public AlgoBase { @@ -67,6 +71,7 @@ public: const char* name() const override { return "ARM_POOLING_FILTER4_MAX"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; + MEGDNN_DECL_ALGO_TYPE(ARM_Filter4MaxStride2) }; class PoolingImpl::AlgoFilter5MaxStride2 final : public AlgoBase { @@ -77,6 +82,7 @@ public: const char* name() const override { return "ARM_POOLING_FILTER5_MAX"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; + MEGDNN_DECL_ALGO_TYPE(ARM_Filter5MaxStride2) }; class PoolingImpl::AlgoInt8Filter2MaxStride2 final : public AlgoBase { @@ -87,6 +93,7 @@ public: const char* name() const override { return "ARM_POOLING_INT8_FILTER2X2"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; + MEGDNN_DECL_ALGO_TYPE(ARM_Int8Filter2MaxStride2) }; class PoolingImpl::AlgoInt8Filter3MaxStride2 final : public AlgoBase { @@ -97,6 +104,7 @@ public: const char* name() const override { return "ARM_POOLING_INT8_FILTER3X3"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; + MEGDNN_DECL_ALGO_TYPE(ARM_Int8Filter3MaxStride2) }; class PoolingImpl::AlgoFilter3ModexStridexNCHW44 final : public AlgoBase { @@ -107,6 +115,7 @@ public: const char* name() const override { return "ARM_POOLING_FILTER3_MODEX_STRIDEX_NCHW44"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; + MEGDNN_DECL_ALGO_TYPE(ARM_Filter3ModexStridexNCHW44) }; class PoolingImpl::AlgoFilter2ModexStridexNCHW44 final : public AlgoBase { @@ -117,6 +126,7 @@ public: const char* name() const override { return "ARM_POOLING_FILTER2_MODEX_STRIDEX_NCHW44"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; + MEGDNN_DECL_ALGO_TYPE(ARM_Filter2ModexStridexNCHW44) }; class PoolingImpl::AlgoFilter4ModexStridexNCHW44 final : public AlgoBase { @@ -127,6 +137,7 @@ public: const char* name() const override { return "ARM_POOLING_FILTER4_MODEX_STRIDEX_NCHW44"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; + MEGDNN_DECL_ALGO_TYPE(ARM_Filter4ModexStridexNCHW44) }; class PoolingImpl::AlgoFilter5ModexStridexNCHW44 final : public AlgoBase { @@ -137,6 +148,7 @@ public: const char* name() const override { return "ARM_POOLING_FILTER5_MODEX_STRIDEX_NCHW44"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; + MEGDNN_DECL_ALGO_TYPE(ARM_Filter5ModexStridexNCHW44) }; class PoolingImpl::AlgoFp32ModexStridexNCHW44 final : public AlgoBase { public: @@ -146,6 +158,17 @@ public: const char* name() const override { return "ARM_POOLING_FP32_MODEX_STRIDEX_NCHW44"; } bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; + MEGDNN_DECL_ALGO_TYPE(ARM_Fp32ModexStridexNCHW44) +}; +class PoolingImpl::AlgoFallback final : public AlgoBase { +public: + AlgoAttribute attribute() const override { + return AlgoAttribute::REPRODUCIBLE; + }; + const char* name() const override { return "FALLBACK_POOLING"; } + bool usable(const PoolingKernSizeParam&) const override { return true; } + void exec(const PoolingKernParam&) const override {} + MEGDNN_DECL_ALGO_TYPE(ARM_Fallback) }; WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param); diff --git a/dnn/src/arm_common/pooling/opr_impl.cpp b/dnn/src/arm_common/pooling/opr_impl.cpp index 4acf320bc..43edb7ac0 100644 --- a/dnn/src/arm_common/pooling/opr_impl.cpp +++ b/dnn/src/arm_common/pooling/opr_impl.cpp @@ -12,11 +12,14 @@ #include "src/arm_common/pooling/opr_impl.h" #include "src/arm_common/pooling/algo.h" #include "src/common/metahelper.h" +#include "src/common/algo_chooser.h" using namespace megdnn; using namespace arm_common; class PoolingImpl::AlgoPack : NonCopyableObj { +private: + AlgoBase::Mapper m_all_algos_map; AlgoFilterxModexStride1 algo_filterx_modex_stride1; AlgoFilter2ModexStride2 algo_filter2_modex_stride2; AlgoFilter3MaxStride2 algo_filter3_max_stride2; @@ -30,6 +33,7 @@ class PoolingImpl::AlgoPack : NonCopyableObj { AlgoFilter4ModexStridexNCHW44 algo_filter4_modex_stridex_nchw4; AlgoFilter5ModexStridexNCHW44 algo_filter5_modex_stridex_nchw4; AlgoFp32ModexStridexNCHW44 algo_fp32_modex_stridex_nchw44; + AlgoFallback algo_fallback; public: AlgoPack() { @@ -46,10 +50,18 @@ public: all_algos.emplace_back(&algo_filter4_modex_stridex_nchw4); all_algos.emplace_back(&algo_filter5_modex_stridex_nchw4); all_algos.emplace_back(&algo_fp32_modex_stridex_nchw44); + all_algos.emplace_back(&algo_fallback); + + for (auto&& algo : all_algos) { + m_all_algos_map.emplace(algo->info().desc, algo); + } } SmallVector all_algos; + const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } }; +PoolingImpl::AlgoPack PoolingImpl::sm_algo_pack; + PoolingImpl::PoolingKernSizeParam PoolingImpl::make_pooling_kern_szie_param( fallback::PoolingImpl* opr, const TensorLayout& src, const TensorLayout& dst) { @@ -89,44 +101,36 @@ PoolingImpl::PoolingKernParam PoolingImpl::make_pooling_kern_param( size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& dst) { - bool find_algo = false; - static AlgoPack m_algo_pack; auto param = make_pooling_kern_szie_param(this, src, dst); - for (auto& m_algo : m_algo_pack.all_algos) { - if (m_algo->usable(param)) { - find_algo = true; - break; - } - } - size_t arm_common_workspace = 0; - - //! When multi-thread, every thread has its own workspace - size_t nr_threads = static_cast(handle()) - ->megcore_dispatcher() - ->nr_threads(); - if ((param.src_type.category() == DTypeCategory::FLOAT || - param.src_type == dtype::Int8{} || - param.src_type.enumv() == DTypeEnum::QuantizedS8 || - param.src_type.enumv() == DTypeEnum::Quantized8Asymm) && - param.filter[0] == param.filter[1] && - (param.filter[0] == 3 || param.filter[0] == 5) && - param.format == Param::Format::NCHW && - (param.mode == Mode::MAX || - (param.mode == Mode::AVERAGE && param.filter[0] == 3)) && - param.stride[0] == 2 && param.stride[1] == 2 && param.isz[0] >= 2 && - param.isz[1] >= 2) { - WorkspaceBundle ws = get_bundle(param); - arm_common_workspace = ws.total_size_in_bytes() * nr_threads; - } + auto algo = get_algorithm(this, src, dst); + if (!is_fallback_algo(algo)) { + size_t arm_common_workspace = 0; - if ((param.src_type.enumv() == DTypeEnum::QuantizedS8 || - param.src_type.enumv() == DTypeEnum::Int8) && - (param.format == param::Pooling::Format::NCHW44)) { - WorkspaceBundle ws = get_bundle_nchw44(param); - arm_common_workspace = ws.total_size_in_bytes() * nr_threads; - } + //! When multi-thread, every thread has its own workspace + size_t nr_threads = static_cast(handle()) + ->megcore_dispatcher() + ->nr_threads(); + if ((param.src_type.category() == DTypeCategory::FLOAT || + param.src_type == dtype::Int8{} || + param.src_type.enumv() == DTypeEnum::QuantizedS8 || + param.src_type.enumv() == DTypeEnum::Quantized8Asymm) && + param.filter[0] == param.filter[1] && + (param.filter[0] == 3 || param.filter[0] == 5) && + param.format == Param::Format::NCHW && + (param.mode == Mode::MAX || + (param.mode == Mode::AVERAGE && param.filter[0] == 3)) && + param.stride[0] == 2 && param.stride[1] == 2 && param.isz[0] >= 2 && + param.isz[1] >= 2) { + WorkspaceBundle ws = get_bundle(param); + arm_common_workspace = ws.total_size_in_bytes() * nr_threads; + } - if (find_algo) { + if ((param.src_type.enumv() == DTypeEnum::QuantizedS8 || + param.src_type.enumv() == DTypeEnum::Int8) && + (param.format == param::Pooling::Format::NCHW44)) { + WorkspaceBundle ws = get_bundle_nchw44(param); + arm_common_workspace = ws.total_size_in_bytes() * nr_threads; + } return arm_common_workspace; } else { auto fallback_worksapce = @@ -139,14 +143,48 @@ void PoolingImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { check_exec(src.layout, dst.layout, workspace.size); auto param = make_pooling_kern_param(this, src, dst, workspace); - static AlgoPack m_algo_pack; - for (auto& m_algo : m_algo_pack.all_algos) { - if (m_algo->usable(param)) { - m_algo->exec(param); - return; + auto algo = get_algorithm(this, src.layout, dst.layout); + if (!is_fallback_algo(algo)) { + algo->exec(param); + } else { + fallback::PoolingImpl::exec(src, dst, workspace); + } +} + +MEGDNN_DEF_GET_ALGO_FROM_DESC(PoolingImpl); + +std::vector PoolingImpl::get_all_algorithms( + const TensorLayout& src, const TensorLayout& dst) { + auto param = make_pooling_kern_szie_param(this, src, dst); + std::vector ret; + ret.reserve(algo_pack().all_algos.size()); + for (auto i : algo_pack().all_algos) { + if (i->usable(param)) { + ret.push_back(i); + } + } + megdnn_assert(!ret.empty(), "no usable pooling fwd algorithm"); + return ret; +} + +Algorithm* PoolingImpl::get_algorithm_heuristic( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, + const AlgoAttribute& negative_attr) { + MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes); + + auto param = make_pooling_kern_szie_param(this, src, dst); + for (auto&& iter : sm_algo_pack.all_algos) { + if (iter->is_available_attribute(param, positive_attr, negative_attr)) { + return iter; } } - fallback::PoolingImpl::exec(src, dst, workspace); + megdnn_throw( + ssprintf("require algorithm with attribute(%s) and without " + "attribute(%s), but can't get suitable algo.\n", + Algorithm::attribute_str(positive_attr).c_str(), + Algorithm::attribute_str(negative_attr).c_str())); + return nullptr; } // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/pooling/opr_impl.h b/dnn/src/arm_common/pooling/opr_impl.h index fca7d282a..04ab72e51 100644 --- a/dnn/src/arm_common/pooling/opr_impl.h +++ b/dnn/src/arm_common/pooling/opr_impl.h @@ -12,11 +12,30 @@ #pragma once #include "megdnn/oprs/base.h" #include "src/fallback/pooling/opr_impl.h" +#include namespace megdnn { namespace arm_common { class PoolingImpl final : public fallback::PoolingImpl { +private: + class AlgoFilterxModexStride1; + class AlgoFilter2ModexStride2; + class AlgoFilter3MaxStride2; + class AlgoFilter3AverageStride2; + class AlgoFilter4MaxStride2; + class AlgoFilter5MaxStride2; + class AlgoInt8Filter2MaxStride2; + class AlgoInt8Filter3MaxStride2; + class AlgoFilter2ModexStridexNCHW44; + class AlgoFilter3ModexStridexNCHW44; + class AlgoFilter4ModexStridexNCHW44; + class AlgoFilter5ModexStridexNCHW44; + class AlgoFp32ModexStridexNCHW44; + class AlgoFallback; + class AlgoPack; + static AlgoPack sm_algo_pack; + public: using fallback::PoolingImpl::PoolingImpl; void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, @@ -70,28 +89,68 @@ public: _megdnn_workspace workspace); class AlgoBase : public detail::Algorithm { public: + enum class AlgoType : uint32_t { + ARM_FilterxModexStride1, + ARM_Filter2ModexStride2, + ARM_Filter3MaxStride2, + ARM_Filter3AverageStride2, + ARM_Filter4MaxStride2, + ARM_Filter5MaxStride2, + ARM_Int8Filter2MaxStride2, + ARM_Int8Filter3MaxStride2, + ARM_Filter2ModexStridexNCHW44, + ARM_Filter3ModexStridexNCHW44, + ARM_Filter4ModexStridexNCHW44, + ARM_Filter5ModexStridexNCHW44, + ARM_Fp32ModexStridexNCHW44, + ARM_Fallback + }; + + using Mapper = std::unordered_map; + AlgoBase() : Algorithm() { m_handle_type = Handle::HandleType::ARM_COMMON; } virtual ~AlgoBase() = default; virtual bool usable(const PoolingKernSizeParam& param) const = 0; virtual void exec(const PoolingKernParam& param) const = 0; uint32_t type() const override { return INVALID_ALGO_TYPE; }; + bool is_available_attribute( + const PoolingKernSizeParam& param, + const AlgoAttribute& positive_attr = + AlgoAttribute::REPRODUCIBLE, + const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { + return contain_attribute_all(positive_attr) && + !contain_attribute_any(negative_attr) && usable(param); + } }; -private: - class AlgoFilterxModexStride1; - class AlgoFilter2ModexStride2; - class AlgoFilter3MaxStride2; - class AlgoFilter3AverageStride2; - class AlgoFilter4MaxStride2; - class AlgoFilter5MaxStride2; - class AlgoInt8Filter2MaxStride2; - class AlgoInt8Filter3MaxStride2; - class AlgoFilter2ModexStridexNCHW44; - class AlgoFilter3ModexStridexNCHW44; - class AlgoFilter4ModexStridexNCHW44; - class AlgoFilter5ModexStridexNCHW44; - class AlgoFp32ModexStridexNCHW44; - class AlgoPack; + const char* get_algorithm_set_name() const override { + return "ARM_POOLING_FORWARD"; + } + + Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; + + std::vector get_all_algorithms( + const TensorLayout& src, const TensorLayout& dst) override; + + Algorithm* get_algorithm_heuristic( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, + const AlgoAttribute& negative_attr) override; + + AlgorithmInfo get_algorithm_info_heuristic( + const TensorLayout& src, const TensorLayout& dst, + size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr, + const AlgoAttribute& negative_attr) { + return get_algorithm_heuristic(src, dst, workspace_limit_in_bytes, + positive_attr, negative_attr) + ->info(); + } + + static const AlgoPack& algo_pack() { return sm_algo_pack; } + + bool is_fallback_algo(Algorithm* algo) { + return strcmp(algo->name(), "FALLBACK_POOLING") == 0; + } }; } // namespace arm_common } // namespace megdnn -- GitLab