megray_helper.cpp 1.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/**
 * \file src/opr-mm/impl/megray_helper.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/opr/megray_helper.h"

using namespace mgb;
using namespace opr;

17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
/**
 * \brief Look up a communicator in m_megray_comms by its hash.
 *
 * m_mtx is held for the whole lookup.
 *
 * \param hash key searched for in m_megray_comms
 * \param[out] comm receives the stored communicator on a hit; it is left
 *      untouched on a miss
 * \return true when an entry for \p hash exists, false otherwise
 */
bool MegRayCommunicatorBuilder::find(uint64_t hash, std::shared_ptr<MegRay::Communicator>& comm) {
    std::lock_guard<std::mutex> guard(m_mtx);
    const auto entry = m_megray_comms.find(hash);
    if (entry == m_megray_comms.end()) {
        // miss: the caller's pointer is deliberately not modified
        return false;
    }
    comm = entry->second;
    return true;
}

/**
 * \brief Insert a communicator into m_megray_comms under \p hash.
 *
 * m_mtx is held while the container is modified.
 * NOTE(review): insertion goes through the container's emplace(); if
 * m_megray_comms is a standard (unordered_)map, an already-present \p hash
 * keeps its old value -- confirm against the member's declaration.
 *
 * \param hash key under which the communicator is stored
 * \param comm communicator to store
 */
void MegRayCommunicatorBuilder::emplace(uint64_t hash,
        std::shared_ptr<MegRay::Communicator> comm) {
    std::lock_guard<std::mutex> guard(m_mtx);
    m_megray_comms.emplace(hash, comm);
}

33 34 35 36
std::shared_ptr<MegRay::Communicator> MegRayCommunicatorBuilder::get_megray_comm(
        uint64_t hash, std::string key, uint32_t size, uint32_t rank,
        MegRay::Backend backend,
        std::shared_ptr<mgb::opr::GroupClient> group_client) {
37 38 39
    std::shared_ptr<MegRay::Communicator> comm;
    if (!find(hash, comm)) {
        comm = MegRay::get_communicator(size, rank, backend);
40 41
        auto uid = comm->get_uid();
        auto uids = group_client->gather_uid(uid, key, size, rank);
42 43
        mgb_assert(comm->init(uids) == MegRay::Status::MEGRAY_OK);
        emplace(hash, comm);
44
    }
45
    return comm;
46 47 48 49 50
}

// Invokes the project's MGB_TYPEINFO_OBJ_IMPL macro for MegRayCommunicatorBuilder.
// NOTE(review): presumably this emits the out-of-line type-info definition
// matching a declaration in the class header -- confirm against the macro.
MGB_TYPEINFO_OBJ_IMPL(MegRayCommunicatorBuilder);

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}