/**
 * \file src/nccl/communicator.cpp
 * MegRay is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "communicator.h"

#include <string.h>

#include "utils.h"

namespace MegRay {

NcclCommunicator::NcclCommunicator(int nranks, int rank) :
        Communicator(nranks, rank), m_inited(false) {
    NCCL_ASSERT(ncclGetUniqueId(&m_uid));
}

NcclCommunicator::~NcclCommunicator() {
    if (m_inited) {
        ncclCommDestroy(m_comm);
    }
}

std::string NcclCommunicator::get_uid() {
    // serialize ncclUniqueId into a string
    return std::string(m_uid.internal, NCCL_UNIQUE_ID_BYTES);
}

Status NcclCommunicator::init(const std::vector<std::string>& uids) {
    MEGRAY_ASSERT(uids.size() == m_nranks, "incorrect size of uids");
    // only use unique id of rank 0 for initialization
    const std::string uid = uids[0];
    MEGRAY_ASSERT(uid.size() == NCCL_UNIQUE_ID_BYTES, "invalid uid");
    memcpy(m_uid.internal, uid.data(), NCCL_UNIQUE_ID_BYTES);
    // initialize nccl communicator
    NCCL_CHECK(ncclCommInitRank(&m_comm, m_nranks, m_uid, m_rank));
    m_inited = true;
    return MEGRAY_OK;
}
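
// A hedged usage sketch (not part of the original file) of the bootstrap flow
// this init() assumes. `exchange_uids` is a hypothetical out-of-band helper
// (e.g. backed by a TCP store or MPI) that gathers every rank's uid string;
// init() then consumes only uids[0].
//
//     auto comm = std::make_unique<MegRay::NcclCommunicator>(nranks, rank);
//     std::string my_uid = comm->get_uid();  // NCCL_UNIQUE_ID_BYTES bytes
//     std::vector<std::string> uids = exchange_uids(my_uid, nranks);  // hypothetical
//     if (comm->init(uids) != MegRay::MEGRAY_OK) {
//         // handle initialization failure
//     }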

Status NcclCommunicator::send(const void* sendbuff, size_t len, uint32_t rank,
        std::shared_ptr<Context> ctx) {
    // overridden from the base class, not implemented for NCCL
    MEGRAY_THROW("not implemented");
    return MEGRAY_NOT_IMPLEMENTED;
}

Status NcclCommunicator::recv(void* recvbuff, size_t len, uint32_t rank,
        std::shared_ptr<Context> ctx) {
    // overridden from the base class, not implemented for NCCL
    MEGRAY_THROW("not implemented");
    return MEGRAY_NOT_IMPLEMENTED;
}

Status NcclCommunicator::all_gather(const void* sendbuff, void* recvbuff, size_t sendlen,
        DType dtype, std::shared_ptr<Context> ctx) {
    // check context type and get cuda stream
    MEGRAY_ASSERT(ctx->type() == MEGRAY_CTX_CUDA, "only cuda context supported");
    cudaStream_t stream = static_cast<CudaContext*>(ctx.get())->get_stream();
    // perform all gather synchronously
    NCCL_CHECK(ncclAllGather(sendbuff, recvbuff, sendlen, get_nccl_dtype(dtype),
            m_comm, stream));
    CUDA_CHECK(cudaStreamSynchronize(stream));
    return MEGRAY_OK;
}
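
// Sketch of the buffer contract assumed here (mirroring ncclAllGather):
// sendlen is the per-rank element count, so recvbuff must hold
// sendlen * nranks elements of dtype, with rank i's contribution landing at
// offset i * sendlen. Illustrative call, with MEGRAY_FLOAT32 standing in for
// whatever float32 DType enumerator the project defines:
//
//     // each rank contributes 4 floats; recvbuff holds 4 * nranks floats
//     comm->all_gather(d_send, d_recv, 4, MegRay::MEGRAY_FLOAT32, cuda_ctx);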

Status NcclCommunicator::all_reduce(const void* sendbuff, void* recvbuff, size_t len,
        DType dtype, ReduceOp op, std::shared_ptr<Context> ctx) {
    // check context type and get cuda stream
    MEGRAY_ASSERT(ctx->type() == MEGRAY_CTX_CUDA, "only cuda context supported");
    cudaStream_t stream = static_cast<CudaContext*>(ctx.get())->get_stream();
    // perform all reduce synchronously
    NCCL_CHECK(ncclAllReduce(sendbuff, recvbuff, len, get_nccl_dtype(dtype),
            get_nccl_reduce_op(op), m_comm, stream));
    CUDA_CHECK(cudaStreamSynchronize(stream));
    return MEGRAY_OK;
}

Status NcclCommunicator::reduce_scatter(const void* sendbuff, void* recvbuff, size_t recvlen,
        DType dtype, ReduceOp op, std::shared_ptr<Context> ctx) {
    // check context type and get cuda stream
    MEGRAY_ASSERT(ctx->type() == MEGRAY_CTX_CUDA, "only cuda context supported");
    cudaStream_t stream = static_cast<CudaContext*>(ctx.get())->get_stream();
    // perform reduce scatter synchronously
    NCCL_CHECK(ncclReduceScatter(sendbuff, recvbuff, recvlen, get_nccl_dtype(dtype),
            get_nccl_reduce_op(op), m_comm, stream));
    CUDA_CHECK(cudaStreamSynchronize(stream));
    return MEGRAY_OK;
}
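
// Sketch of the assumed reduce_scatter contract (mirroring ncclReduceScatter):
// recvlen is the per-rank output element count, so sendbuff is expected to
// hold recvlen * nranks elements and rank i receives the reduction of chunk i.
// MEGRAY_FLOAT32 / MEGRAY_SUM below stand in for the project's actual
// DType / ReduceOp enumerators:
//
//     // each rank ends up with 8 summed floats out of an 8 * nranks input
//     comm->reduce_scatter(d_send, d_recv, 8, MegRay::MEGRAY_FLOAT32,
//             MegRay::MEGRAY_SUM, cuda_ctx);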

Status NcclCommunicator::broadcast(const void* sendbuff, void* recvbuff, size_t len,
        DType dtype, uint32_t root, std::shared_ptr<Context> ctx) {
    // check context type and get cuda stream
    MEGRAY_ASSERT(ctx->type() == MEGRAY_CTX_CUDA, "only cuda context supported");
    cudaStream_t stream = static_cast<CudaContext*>(ctx.get())->get_stream();
    // perform broadcast synchronously
    NCCL_CHECK(ncclBroadcast(sendbuff, recvbuff, len, get_nccl_dtype(dtype), root,
            m_comm, stream));
    CUDA_CHECK(cudaStreamSynchronize(stream));
    return MEGRAY_OK;
}

Status NcclCommunicator::reduce(const void* sendbuff, void* recvbuff, size_t len,
        DType dtype, ReduceOp op, uint32_t root, std::shared_ptr<Context> ctx) {
    // check context type and get cuda stream
    MEGRAY_ASSERT(ctx->type() == MEGRAY_CTX_CUDA, "only cuda context supported");
    cudaStream_t stream = static_cast<CudaContext*>(ctx.get())->get_stream();
    // perform reduce synchronously
    NCCL_CHECK(ncclReduce(sendbuff, recvbuff, len, get_nccl_dtype(dtype),
            get_nccl_reduce_op(op), root, m_comm, stream));
    CUDA_CHECK(cudaStreamSynchronize(stream));
    return MEGRAY_OK;
}

} // namespace MegRay