#pragma once #include #include "megbrain/tensor.h" namespace mgb { namespace opr { /*! * GroupInfo: stream and shape information from all ranks of a group */ class GroupInfo { public: struct OprInfo { uint64_t comp_node_hash; bool is_root; int rank; }; void add_opr( const std::string& key, size_t nr_expected_devices, bool is_root, int rank, uint64_t comp_node_hash); void set_output_shape(const std::string& key, const TensorShape& shape); TensorShape get_output_shape(const std::string& key); void clear(); const std::vector& opr_infos() const { return m_opr_infos; } int get_root_rank() const { return m_root_rank; } int get_rank(uint64_t hash) const { return m_rank_map.at(hash); } uint64_t get_group_hash() const { return m_hash; } private: void sort_opr_infos(); void gen_infos_from_opr_infos(); std::vector m_opr_infos; std::unordered_map m_rank_map; uint64_t m_hash; uint32_t m_nr_registered_devs; uint32_t m_nr_expected_devs; Maybe m_output_shape; uint32_t m_count = 0; int m_root_rank = -1; std::mutex m_group_mtx; std::condition_variable m_register_cv; std::condition_variable m_clear_cv; std::mutex m_output_shape_mtx; std::condition_variable m_output_shape_cv; }; /*! * GroupManager: build groups and exchange meta information */ class GroupManager { public: ~GroupManager() = default; struct RegisterInfo { uint64_t hash; int rank, root_rank; }; //! register oprs' info to server, return deduplicated hash RegisterInfo opr_register( const std::string& key, size_t nr_devices, bool is_root, int rank, uint64_t comp_node_hash); //! broadcast master_ip and port void bcast_addr( std::string& master_ip, int& port, const std::string& key, uint32_t size, uint32_t rank, uint32_t root); //! bcast uid void bcast_nccluniqueid( const std::string& key, std::string& id, uint32_t size, uint32_t rank, uint32_t root); //! Set output shape of this key void set_output_shape(const std::string& key, const TensorShape& shape); //! Get output shape of this key, blocks until output shape is set TensorShape get_output_shape(const std::string& key); //! Block clients until all ranks reach this barrier uint32_t group_barrier(uint32_t size, uint32_t rank); private: GroupInfo& get_group(const std::string& key); //! key -> group info. std::unordered_map m_key2group_info; std::mutex m_key2group_info_mtx; //! key -> addr std::unordered_map m_key2master_ip; std::unordered_map m_key2port; std::unordered_map m_key2addr_size; std::unordered_map m_key2addr_flag; std::mutex m_key2addr_mtx; std::condition_variable m_bcast_cv; //! key -> ncclid std::unordered_map m_key2nccl_id; std::unordered_map m_key2nccl_id_size; std::unordered_map m_key2nccl_id_flag; std::mutex m_key2nccl_id_mtx; //! barrier uint32_t m_barrier_size; std::set m_barrier_set; std::mutex m_barrier_mtx; std::condition_variable m_barrier_cv; }; /*! * Client interface to interact with GroupManager. * All the methods below should be overrided by subclasses * Test cases mock the interface to directly interact with GroupManager */ class GroupClient { protected: virtual ~GroupClient() = default; public: virtual const std::string& get_addr() const = 0; virtual GroupManager::RegisterInfo opr_register( const std::string& key, size_t nr_devices, bool is_root, int rank, uint64_t comp_node_hash) = 0; virtual void bcast_addr( std::string& master_ip, int& port, const std::string& key, uint32_t size, uint32_t rank, uint32_t root) = 0; virtual void bcast_nccluniqueid( const std::string& key, std::string& id, uint32_t size, uint32_t rank, uint32_t root) = 0; virtual void set_output_shape(const std::string& key, const TensorShape& shape) = 0; virtual TensorShape get_output_shape(const std::string& key) = 0; virtual uint32_t group_barrier(uint32_t size, uint32_t rank) = 0; }; /*! * Cache RegisterInfo returned from GroupManager. This feature is only enabled * in imperative runtime mode, so that multi-machine operators do not have to * call opr_register repeatedly in each iter */ namespace RegInfoCache { static std::mutex mtx; static std::unordered_map key2info; void set_info(const std::string& key, const GroupManager::RegisterInfo& info); bool has_info(const std::string& key); GroupManager::RegisterInfo get_info(const std::string& key); } // namespace RegInfoCache } // namespace opr } // namespace mgb // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}