From 85fa9883486f4ccb6771991cf88486ba1ee70121 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 8 Dec 2020 16:05:09 +0800 Subject: [PATCH] refactor(dnn): add get_algorithm_from_desc interface GitOrigin-RevId: 6d211ca1676d43b8b4eeed751d30468097a08b5c --- dnn/include/megdnn/oprs/base.h | 3 + dnn/src/common/algo_base.h | 11 +- dnn/src/common/algo_chooser.h | 4 +- dnn/src/cuda/batch_conv_bias/opr_impl.h | 2 +- dnn/src/cuda/batched_matrix_mul/opr_impl.h | 2 +- dnn/src/cuda/conv_bias/opr_impl.h | 2 +- dnn/src/cuda/convolution/opr_impl.cpp | 22 ++ dnn/src/cuda/convolution/opr_impl.h | 6 +- dnn/src/cuda/convolution3d/opr_impl.h | 6 +- dnn/src/cuda/deformable_conv/opr_impl.h | 6 +- dnn/src/cuda/local_share/opr_impl.h | 6 +- dnn/src/cuda/matrix_mul/opr_impl.h | 2 +- .../fallback/batched_matrix_mul/opr_impl.h | 3 +- dnn/src/fallback/conv_bias/opr_impl.cpp | 6 +- dnn/src/fallback/conv_bias/opr_impl.h | 2 +- dnn/src/fallback/convolution/opr_impl.cpp | 12 +- dnn/src/fallback/convolution/opr_impl.h | 4 +- dnn/src/fallback/matrix_mul/opr_impl.cpp | 5 +- dnn/src/fallback/matrix_mul/opr_impl.h | 3 +- dnn/src/naive/batch_conv_bias/opr_impl.cpp | 8 + dnn/src/naive/batch_conv_bias/opr_impl.h | 2 + dnn/src/naive/batched_matrix_mul/opr_impl.cpp | 9 + dnn/src/naive/batched_matrix_mul/opr_impl.h | 2 + dnn/src/naive/conv_bias/opr_impl.cpp | 9 + dnn/src/naive/conv_bias/opr_impl.h | 2 + dnn/src/naive/convolution/convolution.cpp | 26 +++ dnn/src/naive/convolution/opr_impl.h | 6 + dnn/src/naive/convolution3d/convolution3d.cpp | 198 ++++++++++-------- dnn/src/naive/convolution3d/opr_impl.h | 128 ++++++----- dnn/src/naive/deformable_conv/opr_impl.h | 12 ++ dnn/src/naive/local_share/opr_impl.cpp | 27 +++ dnn/src/naive/local_share/opr_impl.h | 6 +- dnn/src/naive/matrix_mul/opr_impl.cpp | 8 + dnn/src/naive/matrix_mul/opr_impl.h | 2 + dnn/src/rocm/batched_matrix_mul/opr_impl.h | 2 +- dnn/src/rocm/convolution/opr_impl.h | 6 +- dnn/src/rocm/matrix_mul/opr_impl.h | 3 +- src/opr/test/dnn/convolution.cpp | 4 + 38 files changed, 373 insertions(+), 194 deletions(-) diff --git a/dnn/include/megdnn/oprs/base.h b/dnn/include/megdnn/oprs/base.h index e2c106348..734109501 100644 --- a/dnn/include/megdnn/oprs/base.h +++ b/dnn/include/megdnn/oprs/base.h @@ -188,6 +188,7 @@ public: using AlgorithmInfo = detail::Algorithm::Info; using AlgorithmDesc = detail::Algorithm::Info::Desc; using Algorithm = detail::Algorithm; + /*! * \brief get a string representation for current algorithm set; * @@ -209,6 +210,8 @@ public: return m_execution_policy; } + virtual Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) = 0; + protected: ~MultiAlgoOpr() = default; diff --git a/dnn/src/common/algo_base.h b/dnn/src/common/algo_base.h index 39056f570..11724a96a 100644 --- a/dnn/src/common/algo_base.h +++ b/dnn/src/common/algo_base.h @@ -38,11 +38,12 @@ namespace megdnn { return algo_pack().all_algos_map().at(desc); \ } -#define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr) \ - _opr::AlgoBase* _opr::get_algo_from_desc(const AlgorithmDesc& desc) { \ - megdnn_assert(algo_pack().all_algos_map().find(desc) != \ - algo_pack().all_algos_map().end()); \ - return algo_pack().all_algos_map().at(desc); \ +#define MEGDNN_DEF_GET_ALGO_FROM_DESC(_opr) \ + _opr::Algorithm* _opr::get_algorithm_from_desc( \ + const AlgorithmDesc& desc) { \ + megdnn_assert(algo_pack().all_algos_map().find(desc) != \ + algo_pack().all_algos_map().end()); \ + return algo_pack().all_algos_map().at(desc); \ } /** diff --git a/dnn/src/common/algo_chooser.h b/dnn/src/common/algo_chooser.h index 211171dc3..f7486c2ad 100644 --- a/dnn/src/common/algo_chooser.h +++ b/dnn/src/common/algo_chooser.h @@ -34,7 +34,8 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { std::forward(args)..., std::numeric_limits::max(), false); } - return opr->get_algo_from_desc(ret.desc); + return static_cast( + opr->get_algorithm_from_desc(ret.desc)); } /*! @@ -43,7 +44,6 @@ typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { */ template typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) { - typename Opr::AlgorithmInfo ret; auto set = opr->execution_policy().algo; if (set.valid()) { return opr->algo_pack().construct_and_get_algo(set.desc); diff --git a/dnn/src/cuda/batch_conv_bias/opr_impl.h b/dnn/src/cuda/batch_conv_bias/opr_impl.h index 3ddc0a1f5..996bb71dd 100644 --- a/dnn/src/cuda/batch_conv_bias/opr_impl.h +++ b/dnn/src/cuda/batch_conv_bias/opr_impl.h @@ -35,7 +35,7 @@ public: class AlgoPack; static const AlgoPack& algo_pack() { return sm_algo_pack; } - static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; protected: std::vector get_all_algorithms( diff --git a/dnn/src/cuda/batched_matrix_mul/opr_impl.h b/dnn/src/cuda/batched_matrix_mul/opr_impl.h index 62ab45007..5686e148e 100644 --- a/dnn/src/cuda/batched_matrix_mul/opr_impl.h +++ b/dnn/src/cuda/batched_matrix_mul/opr_impl.h @@ -39,7 +39,7 @@ public: bool is_thread_safe() const override { return true; } static const AlgoPack& algo_pack() { return sm_algo_pack; } - static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; protected: std::vector get_all_algorithms(const TensorLayout& A, diff --git a/dnn/src/cuda/conv_bias/opr_impl.h b/dnn/src/cuda/conv_bias/opr_impl.h index 4c36a8933..916e61c37 100644 --- a/dnn/src/cuda/conv_bias/opr_impl.h +++ b/dnn/src/cuda/conv_bias/opr_impl.h @@ -69,7 +69,7 @@ public: static const AlgoPack& algo_pack() { return sm_algo_pack; } - static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; std::vector get_all_algorithms( const TensorLayout& src, const TensorLayout& filter, diff --git a/dnn/src/cuda/convolution/opr_impl.cpp b/dnn/src/cuda/convolution/opr_impl.cpp index 6e4b0c3da..1d76b9d7f 100644 --- a/dnn/src/cuda/convolution/opr_impl.cpp +++ b/dnn/src/cuda/convolution/opr_impl.cpp @@ -86,6 +86,28 @@ ConvolutionForwardImpl::get_algorithm_heuristic(const TensorLayout& src, workspace_limit_in_bytes, reproducible); } +ConvolutionForwardImpl::Algorithm* +ConvolutionForwardImpl::get_algorithm_from_desc( + const ConvolutionForward::AlgorithmDesc& desc) { + auto conv_param = param(); + auto convbias_opr = this->handle()->create_operator(); + convbias_opr->param() = {param::ConvBias::NonlineMode::IDENTITY, + conv_param.mode, + conv_param.sparse, + conv_param.format, + conv_param.pad_h, + conv_param.pad_w, + conv_param.stride_h, + conv_param.stride_w, + conv_param.dilate_h, + conv_param.dilate_w, + conv_param.compute_mode}; + convbias_opr->execution_policy() = {this->execution_policy().algo}; + + return static_cast(convbias_opr.get()) + ->get_algorithm_from_desc(desc); +} + std::vector ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src, const TensorLayout& filter, diff --git a/dnn/src/cuda/convolution/opr_impl.h b/dnn/src/cuda/convolution/opr_impl.h index 58ce5785e..f10a2739f 100644 --- a/dnn/src/cuda/convolution/opr_impl.h +++ b/dnn/src/cuda/convolution/opr_impl.h @@ -46,6 +46,8 @@ class ConvolutionForwardImpl: public ConvolutionForward { megdnn_throw("cuda exec_preprocess has not implemeted yet"); } + Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; + protected: struct ConvBiasExtraData{ std::unique_ptr convbias_opr; @@ -98,7 +100,7 @@ public: static const AlgoPack& algo_pack() { return sm_algo_pack; } - static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; protected: std::vector get_all_algorithms( @@ -152,7 +154,7 @@ public: static const AlgoPack& algo_pack() { return sm_algo_pack; } - static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; protected: std::vector get_all_algorithms( diff --git a/dnn/src/cuda/convolution3d/opr_impl.h b/dnn/src/cuda/convolution3d/opr_impl.h index 8d419b495..4b2f21e93 100644 --- a/dnn/src/cuda/convolution3d/opr_impl.h +++ b/dnn/src/cuda/convolution3d/opr_impl.h @@ -42,7 +42,7 @@ public: class AlgoGroupConvGeneral; class AlgoPack; static const AlgoPack& algo_pack() { return sm_algo_pack; } - static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; protected: std::vector get_all_algorithms( @@ -92,7 +92,7 @@ public: class AlgoPack; static const AlgoPack& algo_pack() { return sm_algo_pack; } - static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; protected: std::vector get_all_algorithms( @@ -143,7 +143,7 @@ public: class AlgoPack; static const AlgoPack& algo_pack() { return sm_algo_pack; } - static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; protected: std::vector get_all_algorithms( diff --git a/dnn/src/cuda/deformable_conv/opr_impl.h b/dnn/src/cuda/deformable_conv/opr_impl.h index 6f85a4e15..04d19efbf 100644 --- a/dnn/src/cuda/deformable_conv/opr_impl.h +++ b/dnn/src/cuda/deformable_conv/opr_impl.h @@ -46,7 +46,7 @@ public: class AlgoPack; static const AlgoPack& algo_pack() { return sm_algo_pack; } - static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; protected: std::vector get_all_algorithms( @@ -97,7 +97,7 @@ public: class AlgoPack; static const AlgoPack& algo_pack() { return sm_algo_pack; } - static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; protected: std::vector get_all_algorithms( @@ -151,7 +151,7 @@ public: class AlgoPack; static const AlgoPack& algo_pack() { return sm_algo_pack; } - static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; protected: std::vector get_all_algorithms( diff --git a/dnn/src/cuda/local_share/opr_impl.h b/dnn/src/cuda/local_share/opr_impl.h index 28bc12bde..f877e2f31 100644 --- a/dnn/src/cuda/local_share/opr_impl.h +++ b/dnn/src/cuda/local_share/opr_impl.h @@ -33,7 +33,7 @@ public: class AlgoPack; static const AlgoPack& algo_pack() { return sm_algo_pack; } - static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; protected: std::vector get_all_algorithms( @@ -65,7 +65,7 @@ public: class AlgoPack; static const AlgoPack& algo_pack() { return sm_algo_pack; } - static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; protected: std::vector get_all_algorithms( @@ -98,7 +98,7 @@ public: class AlgoPack; static const AlgoPack& algo_pack() { return sm_algo_pack; } - static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; protected: std::vector get_all_algorithms( diff --git a/dnn/src/cuda/matrix_mul/opr_impl.h b/dnn/src/cuda/matrix_mul/opr_impl.h index 0f4997b09..6adf44eb5 100644 --- a/dnn/src/cuda/matrix_mul/opr_impl.h +++ b/dnn/src/cuda/matrix_mul/opr_impl.h @@ -46,7 +46,7 @@ public: static const AlgoPack& algo_pack() { return sm_algo_pack; } - static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; protected: std::vector get_all_algorithms(const TensorLayout& A, diff --git a/dnn/src/fallback/batched_matrix_mul/opr_impl.h b/dnn/src/fallback/batched_matrix_mul/opr_impl.h index 801c2c837..6a26fad39 100644 --- a/dnn/src/fallback/batched_matrix_mul/opr_impl.h +++ b/dnn/src/fallback/batched_matrix_mul/opr_impl.h @@ -29,8 +29,7 @@ public: class AlgoDefault; class AlgoPack; static const AlgoPack& algo_pack() { return sm_algo_pack; } - - static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; private: std::vector get_all_algorithms( diff --git a/dnn/src/fallback/conv_bias/opr_impl.cpp b/dnn/src/fallback/conv_bias/opr_impl.cpp index eb88d30ce..7471541cb 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.cpp +++ b/dnn/src/fallback/conv_bias/opr_impl.cpp @@ -454,8 +454,8 @@ std::vector ConvBiasImpl::get_all_algorithms_with_ncb( return algos; } -ConvBiasImpl::Algorithm* ConvBiasImpl::get_algo_from_desc( - const AlgorithmDesc& desc) const { +ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_from_desc( + const AlgorithmDesc& desc) { if (!desc.valid()) { return nullptr; } else { @@ -495,7 +495,7 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algo_from_desc( ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm( const NCBKernSizeParam& param, size_t workspace_size) { - if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) { + if (auto algo = get_algorithm_from_desc(execution_policy().algo.desc)) { return algo; } if (!m_prev_selected_algo || diff --git a/dnn/src/fallback/conv_bias/opr_impl.h b/dnn/src/fallback/conv_bias/opr_impl.h index 850d44902..fb3aff9d5 100644 --- a/dnn/src/fallback/conv_bias/opr_impl.h +++ b/dnn/src/fallback/conv_bias/opr_impl.h @@ -381,7 +381,7 @@ private: bool is_naive_algo(ConvBiasImpl::Algorithm* algo); - Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const; + Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; //! get algorithm set by user or by heuristic Algorithm* get_algorithm( diff --git a/dnn/src/fallback/convolution/opr_impl.cpp b/dnn/src/fallback/convolution/opr_impl.cpp index 7417af0d9..d15c0af92 100644 --- a/dnn/src/fallback/convolution/opr_impl.cpp +++ b/dnn/src/fallback/convolution/opr_impl.cpp @@ -361,8 +361,8 @@ ConvolutionImpl::get_all_algorithms_with_ncb(const NCBKernSizeParam& param) { return ret; } -ConvolutionImpl::Algorithm* ConvolutionImpl::get_algo_from_desc( - const AlgorithmDesc& desc) const { +ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_from_desc( + const AlgorithmDesc& desc) { if (!desc.valid()) { return nullptr; } else { @@ -387,7 +387,7 @@ ConvolutionImpl::Algorithm* ConvolutionImpl::get_algo_from_desc( ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm( const NCBKernSizeParam& param, size_t workspace_size) { - if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) { + if (auto algo = get_algorithm_from_desc(execution_policy().algo.desc)) { return algo; } if (!m_prev_selected_algo || @@ -749,8 +749,8 @@ ConvolutionBackwardDataImpl::ncb_1g_get_algorithm_heuristic( } ConvolutionBackwardDataImpl::Algorithm* -ConvolutionBackwardDataImpl::get_algo_from_desc( - const AlgorithmDesc& desc) const { +ConvolutionBackwardDataImpl::get_algorithm_from_desc( + const AlgorithmDesc& desc) { if (!desc.valid()) { return nullptr; } else { @@ -783,7 +783,7 @@ ConvolutionBackwardDataImpl::get_algo_from_desc( ConvolutionBackwardDataImpl::Algorithm* ConvolutionBackwardDataImpl::get_algorithm(const NCBKernSizeParam& param) { - if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) { + if (auto algo = get_algorithm_from_desc(execution_policy().algo.desc)) { return algo; } if (!m_prev_selected_algo || diff --git a/dnn/src/fallback/convolution/opr_impl.h b/dnn/src/fallback/convolution/opr_impl.h index 868839e14..c0d5d928b 100644 --- a/dnn/src/fallback/convolution/opr_impl.h +++ b/dnn/src/fallback/convolution/opr_impl.h @@ -284,7 +284,7 @@ private: NCBKernSizeParam m_prev_selected_algo_sizep; Algorithm* m_prev_selected_algo = nullptr; - Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const; + Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; bool is_naive_algo(ConvolutionImpl::Algorithm* algo); Algorithm* get_algorithm( const NCBKernSizeParam& param, @@ -493,7 +493,7 @@ private: class AlgoDirect; class AlgoMatrixMul; class AlgoPack; - Algorithm* get_algo_from_desc(const AlgorithmDesc& desc) const; + Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; public: //! maintain all the algos of in the opr of fallback diff --git a/dnn/src/fallback/matrix_mul/opr_impl.cpp b/dnn/src/fallback/matrix_mul/opr_impl.cpp index 09f1a1645..c1df44e90 100644 --- a/dnn/src/fallback/matrix_mul/opr_impl.cpp +++ b/dnn/src/fallback/matrix_mul/opr_impl.cpp @@ -96,7 +96,7 @@ std::vector MatrixMulImpl::get_all_algorithms( return gemv_algos; } -MatrixMulImpl::AlgoBase* MatrixMulImpl::get_algo_from_desc( +MatrixMulImpl::Algorithm* MatrixMulImpl::get_algorithm_from_desc( const AlgorithmDesc& desc) { if (!desc.valid()) { return nullptr; @@ -133,7 +133,8 @@ MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic( const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, size_t workspace_limit_in_bytes, bool reproducible) { auto kern_size_param = make_kern_size_param(A, B, C); - if (auto algo = get_algo_from_desc(execution_policy().algo.desc)) { + if (auto algo = static_cast( + get_algorithm_from_desc(execution_policy().algo.desc))) { megdnn_assert(algo->get_workspace(kern_size_param) < workspace_limit_in_bytes); auto cur = megdnn::get_reproducible_algo(algo, diff --git a/dnn/src/fallback/matrix_mul/opr_impl.h b/dnn/src/fallback/matrix_mul/opr_impl.h index eeedda3ff..48d23b9fb 100644 --- a/dnn/src/fallback/matrix_mul/opr_impl.h +++ b/dnn/src/fallback/matrix_mul/opr_impl.h @@ -238,7 +238,8 @@ private: class AlgoPack; //! maintain all the algos of in the opr of fallback static const AlgoPack& algo_pack(); - static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; + public: /** diff --git a/dnn/src/naive/batch_conv_bias/opr_impl.cpp b/dnn/src/naive/batch_conv_bias/opr_impl.cpp index 51efe2bf2..a4ddbbd7a 100644 --- a/dnn/src/naive/batch_conv_bias/opr_impl.cpp +++ b/dnn/src/naive/batch_conv_bias/opr_impl.cpp @@ -138,4 +138,12 @@ BatchConvBiasForwardImpl::get_algorithm_heuristic( return algo; } +BatchConvBiasForward::Algorithm* +BatchConvBiasForwardImpl::get_algorithm_from_desc(const AlgorithmDesc& desc) { + Algorithm* ret = static_cast(handle()) + ->default_batch_conv_bias_fwd_algo(); + megdnn_assert(desc == ret->info().desc); + return ret; +} + // vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/batch_conv_bias/opr_impl.h b/dnn/src/naive/batch_conv_bias/opr_impl.h index 1e82696d4..555e683d4 100644 --- a/dnn/src/naive/batch_conv_bias/opr_impl.h +++ b/dnn/src/naive/batch_conv_bias/opr_impl.h @@ -39,6 +39,8 @@ public: size_t workspace_limit_in_bytes, bool reproducible) override; + Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; + const char* get_algorithm_set_name() const override { return "DEFAULT"; } private: WorkspaceBundle get_workspace_bundle(dt_byte* raw_ptr, diff --git a/dnn/src/naive/batched_matrix_mul/opr_impl.cpp b/dnn/src/naive/batched_matrix_mul/opr_impl.cpp index bbd616607..a62184766 100644 --- a/dnn/src/naive/batched_matrix_mul/opr_impl.cpp +++ b/dnn/src/naive/batched_matrix_mul/opr_impl.cpp @@ -81,6 +81,15 @@ BatchedMatrixMulForwardImpl::get_algorithm_heuristic( ->default_batched_matmul_fwd_algo(); } +BatchedMatrixMulForward::Algorithm* +BatchedMatrixMulForwardImpl::get_algorithm_from_desc( + const AlgorithmDesc& desc) { + Algorithm* ret = static_cast(handle()) + ->default_batched_matmul_fwd_algo(); + megdnn_assert(desc == ret->info().desc); + return ret; +} + } // namespace naive } // namespace megdnn diff --git a/dnn/src/naive/batched_matrix_mul/opr_impl.h b/dnn/src/naive/batched_matrix_mul/opr_impl.h index 3b1a4411d..03a702189 100644 --- a/dnn/src/naive/batched_matrix_mul/opr_impl.h +++ b/dnn/src/naive/batched_matrix_mul/opr_impl.h @@ -34,6 +34,8 @@ public: size_t /*workspace_limit_in_bytes*/, bool /* reproducible */) override; + Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; + const char* get_algorithm_set_name() const override { return "DEFAULT"; } private: diff --git a/dnn/src/naive/conv_bias/opr_impl.cpp b/dnn/src/naive/conv_bias/opr_impl.cpp index 9c9a12e38..0eb52c419 100644 --- a/dnn/src/naive/conv_bias/opr_impl.cpp +++ b/dnn/src/naive/conv_bias/opr_impl.cpp @@ -256,6 +256,15 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( return algo; } +ConvBiasForward::Algorithm* +ConvBiasForwardImpl::get_algorithm_from_desc( + const AlgorithmDesc& desc) { + Algorithm* ret = + static_cast(handle())->default_conv_bias_fwd_algo(); + megdnn_assert(desc == ret->info().desc); + return ret; +} + const char* ConvBiasForwardImpl::get_algorithm_set_name() const { return "DEFAULT"; } diff --git a/dnn/src/naive/conv_bias/opr_impl.h b/dnn/src/naive/conv_bias/opr_impl.h index 982a2a7d4..fedfa1e70 100644 --- a/dnn/src/naive/conv_bias/opr_impl.h +++ b/dnn/src/naive/conv_bias/opr_impl.h @@ -64,6 +64,8 @@ public: _megdnn_workspace) override {} const char* get_algorithm_set_name() const override; + + Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; }; void handle_z_inp_and_activation_naive( diff --git a/dnn/src/naive/convolution/convolution.cpp b/dnn/src/naive/convolution/convolution.cpp index aa8ec34d5..1bbb821c9 100644 --- a/dnn/src/naive/convolution/convolution.cpp +++ b/dnn/src/naive/convolution/convolution.cpp @@ -285,6 +285,14 @@ ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic( return algo; } +ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_from_desc( + const AlgorithmDesc& desc) { + Algorithm* ret = + static_cast(handle())->default_conv_fwd_algo(); + megdnn_assert(desc == ret->info().desc); + return ret; +} + std::vector ConvolutionBackwardDataImpl:: get_all_algorithms(const TensorLayout &, const TensorLayout &, const TensorLayout &) @@ -309,6 +317,15 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic( return algo; } +ConvolutionBackwardData::Algorithm* +ConvolutionBackwardDataImpl::get_algorithm_from_desc( + const AlgorithmDesc& desc) { + Algorithm* ret = + static_cast(handle())->default_conv_bwd_data_algo(); + megdnn_assert(desc == ret->info().desc); + return ret; +} + std::vector ConvolutionBackwardFilterImpl:: get_all_algorithms(const TensorLayout &, const TensorLayout &, const TensorLayout &) @@ -333,6 +350,15 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( return algo; } +ConvolutionBackwardFilter::Algorithm* +ConvolutionBackwardFilterImpl::get_algorithm_from_desc( + const AlgorithmDesc& desc) { + Algorithm* ret = + static_cast(handle())->default_conv_bwd_filter_algo(); + megdnn_assert(desc == ret->info().desc); + return ret; +} + const char* ConvolutionForwardImpl::get_algorithm_set_name() const { return "DEFAULT"; } diff --git a/dnn/src/naive/convolution/opr_impl.h b/dnn/src/naive/convolution/opr_impl.h index a7427f012..fa74dae82 100644 --- a/dnn/src/naive/convolution/opr_impl.h +++ b/dnn/src/naive/convolution/opr_impl.h @@ -52,6 +52,8 @@ class ConvolutionForwardImpl: public ConvolutionForward { return {}; } + Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; + const char* get_algorithm_set_name() const override; }; @@ -74,6 +76,8 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { const TensorLayout&) override; const char* get_algorithm_set_name() const override; + + Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; }; class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { @@ -95,6 +99,8 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { const TensorLayout&) override; const char* get_algorithm_set_name() const override; + + Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; }; } // namespace naive diff --git a/dnn/src/naive/convolution3d/convolution3d.cpp b/dnn/src/naive/convolution3d/convolution3d.cpp index 81aca98c8..501b3e3bc 100644 --- a/dnn/src/naive/convolution3d/convolution3d.cpp +++ b/dnn/src/naive/convolution3d/convolution3d.cpp @@ -6,15 +6,15 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ -#include "./opr_impl.h" #include "./helper.h" +#include "./opr_impl.h" -#include "src/naive/handle.h" -#include "src/naive/handle.h" -#include "src/common/utils.h" #include "megdnn/dtype.h" +#include "src/common/utils.h" +#include "src/naive/handle.h" #include @@ -25,93 +25,95 @@ using namespace megdnn; using namespace naive; void Convolution3DForwardImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_in filter, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) -{ + _megdnn_tensor_in filter, + _megdnn_tensor_out dst, + _megdnn_workspace workspace) { MIDOUT_BEGIN(megdnn_naive_conv3d_fwd) { - - auto filter_meta = check_exec( - src.layout, filter.layout, dst.layout, workspace.size); - switch (param().data_type) { - case Param::DataType::FLOAT: -#define cb(dt) do { \ - if (src.layout.dtype == dt()) { \ - using ctype = DTypeTrait
::ctype; \ - MEGDNN_DISPATCH_CPU_KERN(static_cast(handle()), \ - convolution3d::forward< \ - ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ - src, filter, dst, filter_meta); \ - ); \ - return; \ - } \ -} while(0); - MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); + auto filter_meta = check_exec(src.layout, filter.layout, dst.layout, + workspace.size); + switch (param().data_type) { + case Param::DataType::FLOAT: +#define cb(dt) \ + do { \ + if (src.layout.dtype == dt()) { \ + using ctype = DTypeTrait
::ctype; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(handle()), \ + convolution3d::forward< \ + ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ + src, filter, dst, filter_meta);); \ + return; \ + } \ + } while (0); + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); #undef cb - break; - case Param::DataType::FLOAT_IO16xC32: - MEGDNN_INC_FLOAT16( - MEGDNN_DISPATCH_CPU_KERN(static_cast(handle()), - convolution3d::forward< - dt_float16 MEGDNN_COMMA dt_float16 MEGDNN_COMMA dt_float32>( - src, filter, dst, filter_meta);)); - return; + break; + case Param::DataType::FLOAT_IO16xC32: + MEGDNN_INC_FLOAT16(MEGDNN_DISPATCH_CPU_KERN( + static_cast(handle()), + convolution3d::forward< + dt_float16 MEGDNN_COMMA dt_float16 MEGDNN_COMMA + dt_float32>(src, filter, dst, + filter_meta);)); + return; + } + megdnn_assert_internal(0); } - megdnn_assert_internal(0); - - } MIDOUT_END(); + MIDOUT_END(); } void Convolution3DBackwardDataImpl::exec(_megdnn_tensor_in filter, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) -{ - auto filter_meta = check_exec( - filter.layout, diff.layout, grad.layout, workspace.size); -#define cb(dt) do { \ - if (filter.layout.dtype == dt()) { \ - using ctype = DTypeTrait
::ctype; \ - MEGDNN_DISPATCH_CPU_KERN(static_cast(handle()), \ - convolution3d::backward_data< \ - ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ - filter, diff, grad, filter_meta);); \ - return; \ - } \ -} while(0); + _megdnn_tensor_in diff, + _megdnn_tensor_out grad, + _megdnn_workspace workspace) { + auto filter_meta = + check_exec(filter.layout, diff.layout, grad.layout, workspace.size); +#define cb(dt) \ + do { \ + if (filter.layout.dtype == dt()) { \ + using ctype = DTypeTrait
::ctype; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(handle()), \ + convolution3d::backward_data< \ + ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ + filter, diff, grad, filter_meta);); \ + return; \ + } \ + } while (0); MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); #undef cb megdnn_assert_internal(0); } void Convolution3DBackwardFilterImpl::exec(_megdnn_tensor_in src, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) -{ - auto filter_meta = check_exec( - src.layout, diff.layout, grad.layout, workspace.size); -#define cb(dt) do { \ - if (src.layout.dtype == dt()) { \ - using ctype = DTypeTrait
::ctype; \ - MEGDNN_DISPATCH_CPU_KERN(static_cast(handle()), \ - convolution3d::backward_filter< \ - ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ - src, diff, grad, filter_meta);); \ - return; \ - } \ -} while(0); + _megdnn_tensor_in diff, + _megdnn_tensor_out grad, + _megdnn_workspace workspace) { + auto filter_meta = + check_exec(src.layout, diff.layout, grad.layout, workspace.size); +#define cb(dt) \ + do { \ + if (src.layout.dtype == dt()) { \ + using ctype = DTypeTrait
::ctype; \ + MEGDNN_DISPATCH_CPU_KERN( \ + static_cast(handle()), \ + convolution3d::backward_filter< \ + ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \ + src, diff, grad, filter_meta);); \ + return; \ + } \ + } while (0); MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); #undef cb megdnn_assert_internal(0); } -std::vector -Convolution3DForwardImpl:: get_all_algorithms(const TensorLayout &, - const TensorLayout &, const TensorLayout &) -{ - return {static_cast(handle())->default_conv3d_fwd_algo()}; +std::vector +Convolution3DForwardImpl::get_all_algorithms(const TensorLayout&, + const TensorLayout&, + const TensorLayout&) { + return {static_cast(handle())->default_conv3d_fwd_algo()}; } Convolution3DForward::Algorithm* @@ -130,11 +132,20 @@ Convolution3DForwardImpl::get_algorithm_heuristic( return algo; } -std::vector -Convolution3DBackwardDataImpl:: get_all_algorithms(const TensorLayout &, - const TensorLayout &, const TensorLayout &) -{ - return {static_cast(handle())->default_conv3d_bwd_data_algo()}; +Convolution3DForward::Algorithm* +Convolution3DForwardImpl::get_algorithm_from_desc( + const AlgorithmDesc& desc) { + Algorithm* ret = + static_cast(handle())->default_conv3d_fwd_algo(); + megdnn_assert(desc == ret->info().desc); + return ret; +} + +std::vector +Convolution3DBackwardDataImpl::get_all_algorithms(const TensorLayout&, + const TensorLayout&, + const TensorLayout&) { + return {static_cast(handle())->default_conv3d_bwd_data_algo()}; } Convolution3DBackwardData::Algorithm* @@ -154,11 +165,21 @@ Convolution3DBackwardDataImpl::get_algorithm_heuristic( return algo; } -std::vector -Convolution3DBackwardFilterImpl:: get_all_algorithms(const TensorLayout &, - const TensorLayout &, const TensorLayout &) -{ - return {static_cast(handle())->default_conv3d_bwd_filter_algo()}; +Convolution3DBackwardData::Algorithm* +Convolution3DBackwardDataImpl::get_algorithm_from_desc( + const AlgorithmDesc& desc) { + Algorithm* ret = + static_cast(handle())->default_conv3d_bwd_data_algo(); + megdnn_assert(desc == ret->info().desc); + return ret; +} + +std::vector +Convolution3DBackwardFilterImpl::get_all_algorithms(const TensorLayout&, + const TensorLayout&, + const TensorLayout&) { + return {static_cast(handle()) + ->default_conv3d_bwd_filter_algo()}; } Convolution3DBackwardFilter::Algorithm* @@ -179,6 +200,15 @@ Convolution3DBackwardFilterImpl::get_algorithm_heuristic( return algo; } +Convolution3DBackwardFilter::Algorithm* +Convolution3DBackwardFilterImpl::get_algorithm_from_desc( + const AlgorithmDesc& desc) { + Algorithm* ret = static_cast(handle()) + ->default_conv3d_bwd_filter_algo(); + megdnn_assert(desc == ret->info().desc); + return ret; +} + const char* Convolution3DForwardImpl::get_algorithm_set_name() const { return "DEFAULT"; } diff --git a/dnn/src/naive/convolution3d/opr_impl.h b/dnn/src/naive/convolution3d/opr_impl.h index da1830dce..992d8f5ad 100644 --- a/dnn/src/naive/convolution3d/opr_impl.h +++ b/dnn/src/naive/convolution3d/opr_impl.h @@ -6,81 +6,79 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once #include "megdnn/oprs.h" namespace megdnn { namespace naive { -class Convolution3DForwardImpl: public Convolution3DForward { - public: - using Convolution3DForward::Convolution3DForward; - void exec(_megdnn_tensor_in src, - _megdnn_tensor_in filter, - _megdnn_tensor_out dst, - _megdnn_workspace workspace) override; - std::vector get_all_algorithms(const TensorLayout &src, - const TensorLayout &filter, - const TensorLayout &dst) override; - Algorithm* get_algorithm_heuristic(const TensorLayout& src, - const TensorLayout& filter, - const TensorLayout& dst, - size_t workspace_limit_in_bytes, - bool reproducible) override; - size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, - const TensorLayout&) override { - return 0; - } - const char* get_algorithm_set_name() const override; +class Convolution3DForwardImpl : public Convolution3DForward { +public: + using Convolution3DForward::Convolution3DForward; + void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, + _megdnn_tensor_out dst, _megdnn_workspace workspace) override; + std::vector get_all_algorithms( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) override; + Algorithm* get_algorithm_heuristic(const TensorLayout& src, + const TensorLayout& filter, + const TensorLayout& dst, + size_t workspace_limit_in_bytes, + bool reproducible) override; + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, + const TensorLayout&) override { + return 0; + } + + Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; + const char* get_algorithm_set_name() const override; }; -class Convolution3DBackwardDataImpl: public Convolution3DBackwardData { - public: - using Convolution3DBackwardData::Convolution3DBackwardData; - void exec(_megdnn_tensor_in filter, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) override; - std::vector get_all_algorithms(const TensorLayout &filter, - const TensorLayout &diff, - const TensorLayout &grad) override; - Algorithm* get_algorithm_heuristic(const TensorLayout& filter, - const TensorLayout& diff, - const TensorLayout& grad, - size_t workspace_limit_in_bytes, - bool reproducible) override; - size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, - const TensorLayout&) override { - return 0; - } +class Convolution3DBackwardDataImpl : public Convolution3DBackwardData { +public: + using Convolution3DBackwardData::Convolution3DBackwardData; + void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff, + _megdnn_tensor_out grad, _megdnn_workspace workspace) override; + std::vector get_all_algorithms( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) override; + Algorithm* get_algorithm_heuristic(const TensorLayout& filter, + const TensorLayout& diff, + const TensorLayout& grad, + size_t workspace_limit_in_bytes, + bool reproducible) override; + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, + const TensorLayout&) override { + return 0; + } - const char* get_algorithm_set_name() const override; + Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; + const char* get_algorithm_set_name() const override; }; -class Convolution3DBackwardFilterImpl: public Convolution3DBackwardFilter { - public: - using Convolution3DBackwardFilter::Convolution3DBackwardFilter; - void exec(_megdnn_tensor_in src, - _megdnn_tensor_in diff, - _megdnn_tensor_out grad, - _megdnn_workspace workspace) override; - std::vector get_all_algorithms(const TensorLayout &src, - const TensorLayout &diff, - const TensorLayout &grad) override; - Algorithm* get_algorithm_heuristic(const TensorLayout& src, - const TensorLayout& diff, - const TensorLayout& grad, - size_t workspace_limit_in_bytes, - bool reproducible) override; - size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, - const TensorLayout&) override { - return 0; - } - const char* get_algorithm_set_name() const override; +class Convolution3DBackwardFilterImpl : public Convolution3DBackwardFilter { +public: + using Convolution3DBackwardFilter::Convolution3DBackwardFilter; + void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, + _megdnn_tensor_out grad, _megdnn_workspace workspace) override; + std::vector get_all_algorithms( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad) override; + Algorithm* get_algorithm_heuristic(const TensorLayout& src, + const TensorLayout& diff, + const TensorLayout& grad, + size_t workspace_limit_in_bytes, + bool reproducible) override; + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, + const TensorLayout&) override { + return 0; + } + Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; + const char* get_algorithm_set_name() const override; }; - -} // namespace naive -} // namespace megdnn -// vim: syntax=cpp.doxygen +} // namespace naive +} // namespace megdnn + // vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/deformable_conv/opr_impl.h b/dnn/src/naive/deformable_conv/opr_impl.h index 77d3786a9..b15a8e231 100644 --- a/dnn/src/naive/deformable_conv/opr_impl.h +++ b/dnn/src/naive/deformable_conv/opr_impl.h @@ -48,6 +48,10 @@ public: return "DEFORMABLE_CONV2_NAIVE"; }; + Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override { + return {}; + } + void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in offset, _megdnn_tensor_in mask, _megdnn_tensor_out dst, _megdnn_workspace workspace) override; @@ -84,6 +88,10 @@ public: return "DEFORMABLE_CONV2_BWD_FILTER_NAIVE"; }; + Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override { + return {}; + } + void exec(_megdnn_tensor_in im, _megdnn_tensor_in offset, _megdnn_tensor_in mask, _megdnn_tensor_in out_grad, _megdnn_tensor_out filter_grad, @@ -130,6 +138,10 @@ public: return "DEFORMABLE_CONV2_BWD_DATA_NAIVE"; }; + Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override { + return {}; + } + void exec(_megdnn_tensor_in im, _megdnn_tensor_in filter, _megdnn_tensor_in offset, _megdnn_tensor_in mask, _megdnn_tensor_in out_grad, _megdnn_tensor_out im_grad, diff --git a/dnn/src/naive/local_share/opr_impl.cpp b/dnn/src/naive/local_share/opr_impl.cpp index ba3413dc9..409fc468e 100644 --- a/dnn/src/naive/local_share/opr_impl.cpp +++ b/dnn/src/naive/local_share/opr_impl.cpp @@ -175,6 +175,15 @@ LocalShareForward::Algorithm* LocalShareForwardImpl::get_algorithm_heuristic( return algo; } +LocalShareForward::Algorithm* +LocalShareForwardImpl::get_algorithm_from_desc( + const AlgorithmDesc& desc) { + Algorithm* ret = + static_cast(handle())->default_local_share_fwd_algo(); + megdnn_assert(desc == ret->info().desc); + return ret; +} + std::vector LocalShareBackwardDataImpl::get_all_algorithms(const TensorLayout&, const TensorLayout&, @@ -200,6 +209,15 @@ LocalShareBackwardDataImpl::get_algorithm_heuristic( return algo; } +LocalShareBackwardData::Algorithm* +LocalShareBackwardDataImpl::get_algorithm_from_desc( + const AlgorithmDesc& desc) { + Algorithm* ret = static_cast(handle()) + ->default_local_share_bwd_data_algo(); + megdnn_assert(desc == ret->info().desc); + return ret; +} + std::vector LocalShareBackwardFilterImpl::get_all_algorithms(const TensorLayout&, const TensorLayout&, @@ -225,4 +243,13 @@ LocalShareBackwardFilterImpl::get_algorithm_heuristic( return algo; } +LocalShareBackwardFilter::Algorithm* +LocalShareBackwardFilterImpl::get_algorithm_from_desc( + const AlgorithmDesc& desc) { + Algorithm* ret = static_cast(handle()) + ->default_local_share_bwd_filter_algo(); + megdnn_assert(desc == ret->info().desc); + return ret; +} + // vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/local_share/opr_impl.h b/dnn/src/naive/local_share/opr_impl.h index 591ee9be5..42ba1d26e 100644 --- a/dnn/src/naive/local_share/opr_impl.h +++ b/dnn/src/naive/local_share/opr_impl.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once #include "megdnn/oprs.h" @@ -35,6 +36,7 @@ public: size_t /*workspace_limit_in_bytes*/, bool /*reproducible*/) override; + Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; const char* get_algorithm_set_name() const override { return "DEFAULT"; } }; @@ -59,6 +61,7 @@ public: size_t /*workspace_limit_in_bytes*/, bool /*reproducible*/) override; + Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; const char* get_algorithm_set_name() const override { return "DEFAULT"; } }; @@ -83,6 +86,7 @@ public: size_t /*workspace_limit_in_bytes*/, bool /*reproducible*/) override; + Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; const char* get_algorithm_set_name() const override { return "DEFAULT"; } }; diff --git a/dnn/src/naive/matrix_mul/opr_impl.cpp b/dnn/src/naive/matrix_mul/opr_impl.cpp index 0b01c654f..52b2c61eb 100644 --- a/dnn/src/naive/matrix_mul/opr_impl.cpp +++ b/dnn/src/naive/matrix_mul/opr_impl.cpp @@ -95,6 +95,14 @@ MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_heuristic( return static_cast(handle())->default_matmul_fwd_algo(); } +MatrixMulForward::Algorithm* MatrixMulForwardImpl::get_algorithm_from_desc( + const AlgorithmDesc& desc) { + Algorithm* ret = + static_cast(handle())->default_matmul_fwd_algo(); + megdnn_assert(desc == ret->info().desc); + return ret; +} + } // namespace naive } // namespace megdnn diff --git a/dnn/src/naive/matrix_mul/opr_impl.h b/dnn/src/naive/matrix_mul/opr_impl.h index 6c0416e21..ae9748edd 100644 --- a/dnn/src/naive/matrix_mul/opr_impl.h +++ b/dnn/src/naive/matrix_mul/opr_impl.h @@ -35,6 +35,8 @@ public: size_t /*workspace_limit_in_bytes*/, bool /* reproducible */) override; + Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; + const char* get_algorithm_set_name() const override { return "DEFAULT"; } private: diff --git a/dnn/src/rocm/batched_matrix_mul/opr_impl.h b/dnn/src/rocm/batched_matrix_mul/opr_impl.h index 4311f3d67..9fffad07b 100644 --- a/dnn/src/rocm/batched_matrix_mul/opr_impl.h +++ b/dnn/src/rocm/batched_matrix_mul/opr_impl.h @@ -29,8 +29,8 @@ public: class AlgoBlas; class AlgoPack; static const AlgoPack& algo_pack() { return sm_algo_pack; } + Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; - static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); private: std::vector get_all_algorithms( const TensorLayout& /*A*/, const TensorLayout& /*B*/, diff --git a/dnn/src/rocm/convolution/opr_impl.h b/dnn/src/rocm/convolution/opr_impl.h index d672a086d..4ddf1bef9 100644 --- a/dnn/src/rocm/convolution/opr_impl.h +++ b/dnn/src/rocm/convolution/opr_impl.h @@ -66,7 +66,7 @@ public: class AlgoPack; static const AlgoPack& algo_pack() { return sm_algo_pack; } - static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; private: std::vector get_all_algorithms( @@ -112,7 +112,7 @@ public: class AlgoPack; static const AlgoPack& algo_pack() { return sm_algo_pack; } - static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; private: std::vector get_all_algorithms( @@ -158,7 +158,7 @@ public: class AlgoPack; - static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override; static const AlgoPack& algo_pack() { return sm_algo_pack; } private: diff --git a/dnn/src/rocm/matrix_mul/opr_impl.h b/dnn/src/rocm/matrix_mul/opr_impl.h index 52bb3e90c..fe6502b24 100644 --- a/dnn/src/rocm/matrix_mul/opr_impl.h +++ b/dnn/src/rocm/matrix_mul/opr_impl.h @@ -29,7 +29,7 @@ public: class AlgoPack; static const AlgoPack& algo_pack() { return sm_algo_pack; } - static AlgoBase* get_algo_from_desc(const AlgorithmDesc& desc); + Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) override; private: std::vector get_all_algorithms( @@ -41,6 +41,7 @@ private: const TensorLayout& /*C*/, size_t /*workspace_limit_in_bytes*/, bool /*reproducible*/) override; + const char* get_algorithm_set_name() const override { return "ROCM MATMUL"; } diff --git a/src/opr/test/dnn/convolution.cpp b/src/opr/test/dnn/convolution.cpp index d7daaf44b..9d8ad5cd2 100644 --- a/src/opr/test/dnn/convolution.cpp +++ b/src/opr/test/dnn/convolution.cpp @@ -2204,6 +2204,10 @@ public: const TensorLayout& p2, size_t workspace_limit_in_bytes, bool reproducible)); + + MOCK_METHOD1(get_algorithm_from_desc, + Algorithm*(const AlgorithmDesc&)); + protected: const char* get_algorithm_set_name() const override { return m_algorithm_set_name; -- GitLab