#pragma once #include #include #include "megbrain/comp_node.h" #include "megbrain/opr/group_manager.h" #include "megray.h" namespace mgb { namespace opr { MegRay::DType get_megray_dtype(megdnn::DType); MegRay::Backend get_megray_backend(const std::string& backend); std::shared_ptr get_megray_context(CompNode comp_node); /*! * gather MegRay unique ids and build communicator, use hash for deduplication */ class MegRayCommBuilder { private: bool find(uint64_t hash, std::shared_ptr& comm); void emplace(uint64_t hash, std::shared_ptr comm); void remove(uint64_t hash, std::shared_ptr comm); std::unordered_map> m_megray_comms; std::mutex m_map_mtx; static MegRayCommBuilder* sm_instance; static std::mutex sm_instance_mtx; public: static std::shared_ptr get_megray_comm( uint64_t hash, std::string key, uint32_t size, uint32_t rank, MegRay::Backend backend, std::shared_ptr group_client); }; } // namespace opr } // namespace mgb // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}