communicator.cpp 1001 字节
Newer Older
M
Megvii Engine Team 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
/**
 * \file src/communicator.cpp
 * MegRay is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "communicator.h"
#include "nccl/communicator.h"
#include "ucx/communicator.h"

namespace MegRay {

/**
 * \brief Factory that instantiates the communicator implementation for a backend.
 *
 * \param nranks  forwarded unchanged to the selected backend's constructor
 * \param rank    forwarded unchanged to the selected backend's constructor
 * \param backend selects the implementation: MEGRAY_NCCL -> NcclCommunicator,
 *                MEGRAY_UCX -> UcxCommunicator
 * \return shared ownership of the newly constructed communicator
 *
 * Any other backend value is reported through MEGRAY_THROW("unknown backend").
 */
std::shared_ptr<Communicator> get_communicator(uint32_t nranks, uint32_t rank, Backend backend) {
    if (backend == MEGRAY_NCCL) {
        return std::make_shared<NcclCommunicator>(nranks, rank);
    }
    if (backend == MEGRAY_UCX) {
        return std::make_shared<UcxCommunicator>(nranks, rank);
    }

    MEGRAY_THROW("unknown backend");
    // Only reached if MEGRAY_THROW does not leave the function; hand back an
    // empty pointer in that case.
    return std::shared_ptr<Communicator>();
}

} // namespace MegRay