reduce_scatter.cpp 2.2 KB
Newer Older
M
Megvii Engine Team 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
/**
 * \file src/ucx/reduce_scatter.cpp
 * MegRay is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "communicator.h"

#include <memory>

#include "utils.h"

namespace MegRay {

/**
 * Reduce-scatter across all ranks, arranged as a ring.
 *
 * sendbuff is read as m_nranks consecutive chunks of recvlen elements. In each
 * of the m_nranks - 1 steps this rank sends one chunk to its right neighbour,
 * receives a chunk from its left neighbour and reduces the received data into
 * the matching local chunk. After the last step the chunk at index m_rank
 * holds the fully reduced values and is copied to recvbuff.
 *
 * \param sendbuff device buffer with recvlen * m_nranks elements (read only)
 * \param recvbuff device buffer receiving recvlen elements
 * \param recvlen  number of elements per chunk
 * \param dtype    element type, used to derive the element size in bytes
 * \param op       reduction forwarded to _reduce
 * \param ctx      must be a CUDA context; its stream is used for all device work
 * \return MEGRAY_OK on success; failures are reported by the check macros
 *         (their exact behaviour is defined outside this file)
 */
Status UcxCommunicator::reduce_scatter(const void* sendbuff, void* recvbuff, size_t recvlen,
        DType dtype, ReduceOp op, std::shared_ptr<Context> ctx) {
    // Scope guard: frees a device allocation if the function is left before
    // the explicit cudaFree calls at the end (e.g. through one of the check
    // macros), so the scratch buffers are not leaked on error paths.
    struct DeviceFree {
        void operator()(char* ptr) const { (void)cudaFree(ptr); }
    };
    using DeviceGuard = std::unique_ptr<char, DeviceFree>;

    // get cuda stream
    MEGRAY_ASSERT(ctx->type() == MEGRAY_CTX_CUDA, "only cuda context supported");
    cudaStream_t stream = static_cast<CudaContext*>(ctx.get())->get_stream();
    CUDA_CHECK(cudaStreamSynchronize(stream));

    // byte size of one chunk and of the whole working buffer
    const size_t chunk_bytes = recvlen * get_dtype_size(dtype);
    const size_t total_bytes = chunk_bytes * m_nranks;

    // lbuffer: landing area for the chunk received from the left neighbour
    // rbuffer: working copy of sendbuff that is reduced in place
    char* lbuffer = nullptr;
    CUDA_CHECK(cudaMalloc(&lbuffer, chunk_bytes));
    DeviceGuard lguard(lbuffer);
    char* rbuffer = nullptr;
    CUDA_CHECK(cudaMalloc(&rbuffer, total_bytes));
    DeviceGuard rguard(rbuffer);

    // A device-to-device cudaMemcpy is not guaranteed to block the host, so
    // the copy is issued on the context stream and waited for explicitly
    // before rbuffer is handed to _send.
    CUDA_CHECK(cudaMemcpyAsync(rbuffer, sendbuff, total_bytes,
                               cudaMemcpyDeviceToDevice, stream));
    CUDA_CHECK(cudaStreamSynchronize(stream));

    // Step i sends chunk (m_rank - i) to the right neighbour and reduces the
    // data received from the left neighbour into chunk (m_rank - i - 1), both
    // taken modulo m_nranks. The last step therefore reduces into chunk
    // m_rank, which then holds the final result for this rank.
    const size_t lrank = ring_sub(m_rank, 1, m_nranks);
    const size_t rrank = ring_add(m_rank, 1, m_nranks);
    for (size_t i = 1; i < m_nranks; ++i) {
        char* send_chunk = rbuffer + chunk_bytes * ring_sub(m_rank, i, m_nranks);
        char* recv_chunk = rbuffer + chunk_bytes * ring_sub(m_rank, i + 1, m_nranks);
        MEGRAY_CHECK(_send(send_chunk, chunk_bytes, rrank));
        MEGRAY_CHECK(_recv(lbuffer, chunk_bytes, lrank));
        MEGRAY_CHECK(_flush());
        _reduce(lbuffer, recv_chunk, recv_chunk, recvlen, dtype, op, stream);
        // the reduced chunk is sent in the next step, so wait for it here
        CUDA_CHECK(cudaStreamSynchronize(stream));
    }

    // hand this rank's reduced chunk to the caller before releasing rbuffer
    CUDA_CHECK(cudaMemcpyAsync(recvbuff, rbuffer + chunk_bytes * m_rank, chunk_bytes,
                               cudaMemcpyDeviceToDevice, stream));
    CUDA_CHECK(cudaStreamSynchronize(stream));

    // Success path: free explicitly so that cudaFree errors are still checked.
    // Each guard is disarmed right before its buffer is freed; the other one
    // stays armed in case the first cudaFree fails.
    lguard.release();
    CUDA_CHECK(cudaFree(lbuffer));
    rguard.release();
    CUDA_CHECK(cudaFree(rbuffer));

    return MEGRAY_OK;
}

} // namespace MegRay