// communicator.h
/**
 * \file src/ucx/communicator.h
 * MegRay is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include <mutex>
#include <vector>

#include <ucp/api/ucp.h>

#include "../communicator.h"

namespace MegRay {

/*!
 * Simple implementation of collective communications using the ucp api.
 * A ucx communicator corresponds to a ucp worker.
 *
 * All public operations override the Communicator interface; refer to the
 * base class for the contract of each operation.
 */
class UcxCommunicator : public Communicator {
    public:
        UcxCommunicator(int nranks, int rank);

        ~UcxCommunicator();

        // Copying is already ill-formed because of the std::mutex member;
        // spell it out so the intent is visible at the declaration.
        UcxCommunicator(const UcxCommunicator&) = delete;
        UcxCommunicator& operator=(const UcxCommunicator&) = delete;

        // get a serialized string of ucp worker address
        std::string get_uid() override;

        // uids: one string per rank, presumably the values returned by
        // get_uid() on every rank (confirm against the implementation)
        Status init(const std::vector<std::string>& uids) override;

        // point-to-point transfer of len units with the peer `rank`
        Status send(const void* sendbuff, size_t len, uint32_t rank,
                std::shared_ptr<Context> ctx) override;

        Status recv(void* recvbuff, size_t len, uint32_t rank,
                std::shared_ptr<Context> ctx) override;

        // collective operations
        Status scatter(const void* sendbuff, void* recvbuff, size_t recvlen,
                DType dtype, uint32_t root, std::shared_ptr<Context> ctx) override;

        Status gather(const void* sendbuff, void* recvbuff, size_t sendlen,
                DType dtype, uint32_t root, std::shared_ptr<Context> ctx) override;

        Status all_to_all(const void* sendbuff, void* recvbuff, size_t len,
                DType dtype, std::shared_ptr<Context> ctx) override;

        Status all_gather(const void* sendbuff, void* recvbuff, size_t sendlen,
                DType dtype, std::shared_ptr<Context> ctx) override;

        Status all_reduce(const void* sendbuff, void* recvbuff, size_t len,
                DType dtype, ReduceOp op, std::shared_ptr<Context> ctx) override;

        Status reduce_scatter(const void* sendbuff, void* recvbuff, size_t recvlen,
                DType dtype, ReduceOp op, std::shared_ptr<Context> ctx) override;

        Status broadcast(const void* sendbuff, void* recvbuff, size_t len,
                DType dtype, uint32_t root, std::shared_ptr<Context> ctx) override;

        Status reduce(const void* sendbuff, void* recvbuff, size_t len,
                DType dtype, ReduceOp op, uint32_t root, std::shared_ptr<Context> ctx) override;

    private:
        // internal non-blocking send method
        Status _send(const void* sendbuff, size_t len, uint32_t rank);

        // internal non-blocking receive method
        Status _recv(void* recvbuff, size_t len, uint32_t rank);

        // flush _send and _recv requests
        Status _flush();

        // launch cuda kernel for reduce operations
        // i0, i1: the two input buffers; o: the output buffer
        void _reduce(void* i0, void* i1, void* o, size_t len, DType dtype,
                ReduceOp op, cudaStream_t stream);

        // Handles and flags get default values so they never hold
        // indeterminate data, whichever members a constructor sets.
        ucp_context_h m_context = nullptr;
        ucp_worker_h m_worker = nullptr;
        bool m_inited = false;
        std::vector<ucp_ep_h> m_eps;  // ucp endpoints
        // requests produced by _send/_recv and drained by _flush
        // (presumably; confirm against the implementation)
        std::vector<void*> m_requests;
        // NOTE(review): appears to guard m_requests -- confirm in the .cpp
        std::mutex m_requests_mtx;
};

} // namespace MegRay