#pragma once #include "megdnn/basic_types.h" #include "megdnn/oprs/base.h" #include #include #include namespace megdnn { class AlgorithmCache { private: AlgorithmCache() = default; public: MGE_WIN_DECLSPEC_FUC static AlgorithmCache& instance(); struct KeyStorage { size_t k1, k2; bool operator==(const KeyStorage& k) const { return k1 == k.k1 && k2 == k.k2; } }; struct Hash { size_t operator()(const KeyStorage& k) const { size_t h1 = k.k1; size_t h2 = k.k2; h1 ^= h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2); return h1; } }; struct Key { Handle* m_handle; uint32_t m_opr_type; const TensorLayout* m_inp_layouts_ptr; size_t m_inp_layouts_size; const void* m_param_ptr; size_t m_param_size; mutable SmallVector m_buf; public: Key(Handle* opr_handle, Algorithm::OprType opr_type, const TensorLayout* inp_layouts_ptr, size_t inp_layouts_size, const void* param_ptr = nullptr, size_t param_size = 0) : m_handle{opr_handle}, m_opr_type{static_cast(opr_type)}, m_inp_layouts_ptr{inp_layouts_ptr}, m_inp_layouts_size{inp_layouts_size}, m_param_ptr{param_ptr}, m_param_size{param_size} {} KeyStorage build_key_storage() const; }; struct Result { ExecutionPolicy policy; size_t workspace; // for cache collision SmallVector m_buf; SmallVector m_param_buf; }; MGE_WIN_DECLSPEC_FUC void put(const Key& key, Result& result); MGE_WIN_DECLSPEC_FUC Result get(const Key& key); MGE_WIN_DECLSPEC_FUC void clear(); private: std::unordered_map m_heuristic_cache; #if __DEPLOY_ON_XP_SP2__ size_t m_mtx; #else std::mutex m_mtx; #endif }; } // namespace megdnn