diff --git a/src/client.cpp b/src/client.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5bf567570a21cc3ea529eff5f6f7a46df6c113f5 --- /dev/null +++ b/src/client.cpp @@ -0,0 +1,134 @@ +/** + * \file src/client.cpp + * MegRay is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "client.h" + +#include +#include +#include +#include + +namespace MegRay { + +Client::Client(uint32_t nranks, uint32_t rank) : + m_nranks(nranks), m_rank(rank), m_connected(false) { +} + +Client::~Client() { +} + +Status Client::connect(const char* master_ip, int port) { + std::unique_lock lock(m_mutex); + + if (m_connected) { + MEGRAY_ERROR("Client already connected"); + return MEGRAY_INVALID_USAGE; + } + + // create socket + SYS_CHECK_RET(socket(AF_INET, SOCK_STREAM, 0), -1, m_conn); + + // set server_addr + struct sockaddr_in server_addr; + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(port); + SYS_CHECK(inet_pton(AF_INET, master_ip, &server_addr.sin_addr), -1); + + // connect + SYS_CHECK(::connect(m_conn, (struct sockaddr*)&server_addr, sizeof(server_addr)), -1); + + // send client rank + SYS_CHECK(send(m_conn, &m_rank, sizeof(uint32_t), 0), -1); + + // recv ack from server + uint32_t ack; + SYS_CHECK(recv(m_conn, &ack, sizeof(uint32_t), MSG_WAITALL), -1); + + m_connected = true; + return MEGRAY_OK; +} + +Status Client::barrier() { + std::unique_lock lock(m_mutex); + + if (!m_connected) { + MEGRAY_ERROR("Client not connected"); + return MEGRAY_INVALID_USAGE; + } + + // send request_id + uint32_t request_id = 1; + SYS_CHECK(send(m_conn, &request_id, sizeof(uint32_t), 0), -1); + + // recv ack + uint32_t ack; + SYS_CHECK(recv(m_conn, &ack, sizeof(uint32_t), MSG_WAITALL), -1); + + return MEGRAY_OK; +} + +Status Client::broadcast(const void* sendbuff, void* recvbuff, size_t len, uint32_t root) { + std::unique_lock lock(m_mutex); + + if (!m_connected) { + MEGRAY_ERROR("Client not connected"); + return MEGRAY_INVALID_USAGE; + } + + // send request_id + uint32_t request_id = 2; + SYS_CHECK(send(m_conn, &request_id, sizeof(uint32_t), 0), -1); + + // send root + SYS_CHECK(send(m_conn, &root, sizeof(uint32_t), 0), -1); + + // send len + uint64_t len64 = len; + SYS_CHECK(send(m_conn, &len64, sizeof(uint64_t), 0), -1); + + // send data + if (m_rank == root) { + SYS_CHECK(send(m_conn, sendbuff, len, 0), -1); + } + + // recv data + SYS_CHECK(recv(m_conn, recvbuff, len, MSG_WAITALL), -1); + + return MEGRAY_OK; +} + +Status Client::allgather(const void* sendbuff, void* recvbuff, size_t sendlen) { + std::unique_lock lock(m_mutex); + + if (!m_connected) { + MEGRAY_ERROR("Client not connected"); + return MEGRAY_INVALID_USAGE; + } + + // send request_id + uint32_t request_id = 3; + SYS_CHECK(send(m_conn, &request_id, sizeof(uint32_t), 0), -1); + + // send sendlen + uint64_t sendlen64 = sendlen; + SYS_CHECK(send(m_conn, &sendlen64, sizeof(uint64_t), 0), -1); + + // send data + SYS_CHECK(send(m_conn, sendbuff, sendlen, 0), -1); + + // recv data + SYS_CHECK(recv(m_conn, recvbuff, sendlen * m_nranks, MSG_WAITALL), -1); + + return MEGRAY_OK; +} + +} // namespace MegRay diff --git a/src/client.h b/src/client.h new file mode 100644 index 0000000000000000000000000000000000000000..af09af916c2f8bcd0d3b1690bd2389b12779791e --- /dev/null +++ b/src/client.h @@ -0,0 +1,49 @@ +/** + * \file src/client.h + * MegRay is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include + +#include "common.h" + +namespace MegRay { + +/*! + * synchronize meta information with megray server + */ +class Client { + public: + Client(uint32_t nranks, uint32_t rank); + + ~Client(); + + Status connect(const char* master_ip, int port); + + // block until all ranks reach this barrier + Status barrier(); + + // the length of sendbuff = the length of recvbuff = len + Status broadcast(const void* sendbuff, void* recvbuff, size_t sendlen, uint32_t root); + + // the length of sendbuff = sendlen + // the length of recvbuff = sendlen * m_nranks + Status allgather(const void* sendbuff, void* recvbuff, size_t sendlen); + + private: + uint32_t m_nranks; + uint32_t m_rank; + bool m_connected = false; + int m_conn; + std::mutex m_mutex; +}; + +} // namespace MegRay diff --git a/src/common.h b/src/common.h index dfed736a4a4f0f2fd97426e83c70f31486a0539d..63592ad019dc35bcb0f127850b72d3ca81e75fa3 100644 --- a/src/common.h +++ b/src/common.h @@ -11,6 +11,8 @@ #pragma once +#include + #include "cuda_runtime.h" #include "debug.h" @@ -19,12 +21,15 @@ namespace MegRay { typedef enum { MEGRAY_OK = 0, - MEGRAY_CUDA_ERR = 1, - MEGRAY_NCCL_ERR = 2, - MEGRAY_UCX_ERR = 3, - MEGRAY_ENV_ERROR = 4, - MEGRAY_INVALID_ARGUMENT = 5, - MEGRAY_NOT_IMPLEMENTED = 6 + MEGRAY_SYS_ERROR = 1, + MEGRAY_CUDA_ERR = 2, + MEGRAY_NCCL_ERR = 3, + MEGRAY_UCX_ERR = 4, + MEGRAY_ENV_ERROR = 5, + MEGRAY_INVALID_ARGUMENT = 6, + MEGRAY_INVALID_USAGE = 7, + MEGRAY_UNEXPECTED_ERR = 8, + MEGRAY_NOT_IMPLEMENTED = 9 } Status; #define MEGRAY_CHECK(expr) \ @@ -36,6 +41,38 @@ typedef enum { } \ } while (0) +#define SYS_CHECK_RET(expr, errval, retval) \ + do { \ + retval = (expr); \ + if (retval == errval) { \ + MEGRAY_ERROR("system error [%d]: %s", \ + errno, strerror(errno)); \ + return MEGRAY_SYS_ERROR; \ + } \ + } while (0) + +#define SYS_CHECK(expr, errval) \ + do { \ + int retval; \ + SYS_CHECK_RET(expr, errval, retval); \ + } while (0) + +#define SYS_ASSERT_RET(expr, errval, retval) \ + do { \ + retval = (expr); \ + if (retval == errval) { \ + MEGRAY_ERROR("system error [%d]: %s", \ + errno, strerror(errno)); \ + MEGRAY_THROW("system error"); \ + } \ + } while (0) + +#define SYS_ASSERT(expr, errval) \ + do { \ + int retval; \ + SYS_ASSERT_RET(expr, errval, retval); \ + } while (0) + #define CUDA_CHECK(expr) \ do { \ cudaError_t status = (expr); \ @@ -58,7 +95,7 @@ typedef enum { typedef enum { MEGRAY_NCCL = 0, - MEGRAY_UCX = 1, + MEGRAY_UCX = 1 } Backend; typedef enum { diff --git a/src/communicator.cpp b/src/communicator.cpp index ff287a9968a15ed16eb2348688a01bc91d68bd50..97da7200ee07b27ae494d9340f3c49b98b7a91fb 100644 --- a/src/communicator.cpp +++ b/src/communicator.cpp @@ -15,6 +15,12 @@ namespace MegRay { +Status Communicator::init(const char* master_ip, int port) { + m_client = std::make_shared(m_nranks, m_rank); + MEGRAY_CHECK(m_client->connect(master_ip, port)); + return do_init(); +} + std::shared_ptr get_communicator(uint32_t nranks, uint32_t rank, Backend backend) { std::shared_ptr comm; switch (backend) { diff --git a/src/communicator.h b/src/communicator.h index 8d23c1cb1cad9ad7ccc6ca90a8b666b2f44a1475..fc19b42e114ff2512cfe67526eaf369f401b5629 100644 --- a/src/communicator.h +++ b/src/communicator.h @@ -17,6 +17,7 @@ #include "common.h" #include "context.h" +#include "client.h" namespace MegRay { @@ -37,11 +38,11 @@ class Communicator { // get the rank of this process uint32_t rank() { return m_rank; } - // get the unique id of this communicator - virtual std::string get_uid() = 0; + // establish connection with megray server + Status init(const char* master_ip, int port); - // build a group with unique ids of all communicators in the group - virtual Status init(const std::vector& uids) = 0; + // implemented in the subclass and called in init() + virtual Status do_init() = 0; // send data to another communicator in the group virtual Status send(const void* sendbuff, size_t len, uint32_t rank, @@ -90,6 +91,7 @@ class Communicator { protected: uint32_t m_nranks; uint32_t m_rank; + std::shared_ptr m_client; }; /*! diff --git a/src/megray.h b/src/megray.h index cf01c6516340bd83d7f6e97505941028601db1b0..ce7d8773d57a00195dd0d9802c673076be37594e 100644 --- a/src/megray.h +++ b/src/megray.h @@ -11,4 +11,5 @@ #pragma once +#include "server.h" #include "communicator.h" diff --git a/src/nccl/communicator.cpp b/src/nccl/communicator.cpp index da3679fd94de58d8e75270839e2c4cf6e7342a9b..4a1b70d997387a47643611282c8e996bfe7ae347 100644 --- a/src/nccl/communicator.cpp +++ b/src/nccl/communicator.cpp @@ -28,7 +28,6 @@ namespace MegRay { NcclCommunicator::NcclCommunicator(int nranks, int rank) : Communicator(nranks, rank), m_inited(false) { - NCCL_ASSERT(ncclGetUniqueId(&m_uid)); } NcclCommunicator::~NcclCommunicator() { @@ -37,19 +36,14 @@ NcclCommunicator::~NcclCommunicator() { } } -std::string NcclCommunicator::get_uid() { - // serialize ncclUniqueId into a string - return std::string(m_uid.internal, NCCL_UNIQUE_ID_BYTES); -} - -Status NcclCommunicator::init(const std::vector& uids) { - MEGRAY_ASSERT(uids.size() == m_nranks, "incorrect size of uids"); - // only use unique id of rank 0 for initialization - const std::string uid = uids[0]; - MEGRAY_ASSERT(uid.size() == NCCL_UNIQUE_ID_BYTES, "invalid uid"); - memcpy(m_uid.internal, uid.data(), NCCL_UNIQUE_ID_BYTES); - // initialize nccl communicator - NCCL_CHECK(ncclCommInitRank(&m_comm, m_nranks, m_uid, m_rank)); +Status NcclCommunicator::do_init() { + uint32_t root = 0; + ncclUniqueId uid; + if (m_rank == root) { + ncclGetUniqueId(&uid); + } + MEGRAY_CHECK(m_client->broadcast(&uid, &uid, NCCL_UNIQUE_ID_BYTES, root)); + NCCL_CHECK(ncclCommInitRank(&m_comm, m_nranks, uid, m_rank)); m_inited = true; return MEGRAY_OK; } diff --git a/src/nccl/communicator.h b/src/nccl/communicator.h index 935a9c0eeb363524eded7ad816116acb7628075b..130d5c248f9244fb0a6ab6013c419f389a3da62d 100644 --- a/src/nccl/communicator.h +++ b/src/nccl/communicator.h @@ -29,10 +29,7 @@ class NcclCommunicator : public Communicator { ~NcclCommunicator(); - // get a serialized string of ncclUniqueId - std::string get_uid() override; - - Status init(const std::vector& uids) override; + Status do_init() override; Status send(const void* sendbuff, size_t len, uint32_t rank, std::shared_ptr ctx) override; @@ -65,7 +62,6 @@ class NcclCommunicator : public Communicator { DType dtype, ReduceOp op, uint32_t root, std::shared_ptr ctx) override; private: - ncclUniqueId m_uid; ncclComm_t m_comm; bool m_inited; }; diff --git a/src/server.cpp b/src/server.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e8bf51071510225805636cacb043859014f86369 --- /dev/null +++ b/src/server.cpp @@ -0,0 +1,246 @@ +/** + * \file src/server.cpp + * MegRay is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "server.h" + +#include +#include +#include +#include +#include +#include + +#include + +namespace MegRay { + +/************************ get_host_ip ************************/ + +char* get_host_ip() { + const char* device = getenv("MEGRAY_NET_DEVICE"); + if (device and strcmp(device, "lo") == 0) { + MEGRAY_ERROR("illegal net device: lo\n"); + MEGRAY_THROW("invalid argument"); + } + + struct ifaddrs *ifa; + SYS_ASSERT(getifaddrs(&ifa), -1); + + for (struct ifaddrs* p = ifa; p != NULL; p = p->ifa_next) { + if (p->ifa_addr and p->ifa_addr->sa_family == AF_INET and p->ifa_name) { + const char* name = p->ifa_name; + if (strcmp(name, "lo") != 0 and + (device == NULL or strcmp(name, device) == 0)) { + struct sockaddr_in* sin = (struct sockaddr_in*)p->ifa_addr; + const char* host_ip = inet_ntoa(sin->sin_addr); + MEGRAY_INFO("using net device %s (%s)", name, host_ip); + char* ret = new char(strlen(host_ip) + 1); + strcpy(ret, host_ip); + freeifaddrs(ifa); + return ret; + } + } + } + + if (device) { + MEGRAY_ERROR("failed to get host ip for device %s", device); + } else { + MEGRAY_ERROR("failed to get host ip"); + } + MEGRAY_THROW("system error"); + return nullptr; +} + +/************************ get_free_port ************************/ + +int get_free_port() { + // create socket + int sock; + SYS_ASSERT_RET(socket(AF_INET, SOCK_STREAM, 0), -1, sock); + + // set address + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = htonl(INADDR_ANY); + addr.sin_port = htons(0); + + // bind + SYS_ASSERT(bind(sock, (struct sockaddr*)&addr, sizeof(addr)), -1); + + // get port + socklen_t len = sizeof(addr); + SYS_ASSERT(getsockname(sock, (struct sockaddr*)&addr, &len), -1); + int port = ntohs(addr.sin_port); + + // close + SYS_ASSERT(close(sock), -1); + + return port; +} + +/************************ create_server ************************/ + +void serve_barrier(uint32_t nranks, int* conns); + +void serve_broadcast(uint32_t nranks, int* conns); + +void serve_allgather(uint32_t nranks, int* conns); + +void server_thread(int listenfd, uint32_t nranks) { + int conns[nranks]; + + for (uint32_t i = 0; i < nranks; i++) { + // establish connection + int conn; + SYS_ASSERT_RET(accept(listenfd, (struct sockaddr*)NULL, NULL), -1, conn); + + // recv rank and save into conns + uint32_t rank; + SYS_ASSERT(recv(conn, &rank, sizeof(uint32_t), MSG_WAITALL), -1); + conns[rank] = conn; + } + + // send ack to clients + uint32_t ack = 0; + for (uint32_t i = 0; i < nranks; i++) { + SYS_ASSERT(send(conns[i], &ack, sizeof(uint32_t), 0), -1); + } + + while (true) { + // receive a request from rank 0 + uint32_t request_id; + SYS_ASSERT(recv(conns[0], &request_id, sizeof(uint32_t), MSG_WAITALL), -1); + + if (request_id == 1) { + serve_barrier(nranks, conns); + } else if (request_id == 2) { + serve_broadcast(nranks, conns); + } else if (request_id == 3) { + serve_allgather(nranks, conns); + } else { + MEGRAY_ERROR("unexpected request id: %d", request_id); + MEGRAY_THROW("unexpected error"); + } + } +} + +Status create_server(uint32_t nranks, int port) { + // create socket + int listenfd; + SYS_CHECK_RET(socket(AF_INET, SOCK_STREAM, 0), -1, listenfd); + + // set server_addr + struct sockaddr_in server_addr; + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_addr.s_addr = htonl(INADDR_ANY); + server_addr.sin_port = htons(port); + + // bind and listen + int opt = 1; + SYS_CHECK(setsockopt(listenfd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(int)), -1); + SYS_CHECK(bind(listenfd, (struct sockaddr*)&server_addr, sizeof(server_addr)), -1); + SYS_CHECK(listen(listenfd, nranks), -1); + + // start server thread + std::thread th(server_thread, listenfd, nranks); + th.detach(); + + return MEGRAY_OK; +} + +/************************ barrier ************************/ + +void serve_barrier(uint32_t nranks, int* conns) { + uint32_t request_id; + + // recv other requests + for (uint32_t rank = 1; rank < nranks; rank++) { + SYS_ASSERT(recv(conns[rank], &request_id, sizeof(uint32_t), MSG_WAITALL), -1); + MEGRAY_ASSERT(request_id == 1, "inconsistent request_id from rank %d", rank); + } + + // send ack + uint32_t ack = 0; + for (uint32_t rank = 0; rank < nranks; rank++) { + SYS_ASSERT(send(conns[rank], &ack, sizeof(uint32_t), 0), -1); + } +} + +/************************ broadcast ************************/ + +void serve_broadcast(uint32_t nranks, int* conns) { + uint32_t request_id, root, root0; + uint64_t len, len0; + + // recv request 0 + SYS_ASSERT(recv(conns[0], &root0, sizeof(uint32_t), MSG_WAITALL), -1); + SYS_ASSERT(recv(conns[0], &len0, sizeof(uint64_t), MSG_WAITALL), -1); + + // recv other requests + for (uint32_t rank = 1; rank < nranks; rank++) { + SYS_ASSERT(recv(conns[rank], &request_id, sizeof(uint32_t), MSG_WAITALL), -1); + MEGRAY_ASSERT(request_id == 2, "inconsistent request_id from rank %d", rank); + + SYS_ASSERT(recv(conns[rank], &root, sizeof(uint32_t), MSG_WAITALL), -1); + MEGRAY_ASSERT(root == root0, "inconsistent root from rank %d", rank); + + SYS_ASSERT(recv(conns[rank], &len, sizeof(uint64_t), MSG_WAITALL), -1); + MEGRAY_ASSERT(len == len0, "inconsistent len from rank %d", rank); + } + + // recv data from root + void* data = malloc(len); + SYS_ASSERT(recv(conns[root], data, len, MSG_WAITALL), -1); + + // send data to clients + for (uint32_t rank = 0; rank < nranks; rank++) { + SYS_ASSERT(send(conns[rank], data, len, 0), -1); + } + + free(data); +} + +/************************ allgather ************************/ + +void serve_allgather(uint32_t nranks, int* conns) { + uint32_t request_id; + uint64_t len, len0; + + // recv request 0 + SYS_ASSERT(recv(conns[0], &len0, sizeof(uint64_t), MSG_WAITALL), -1); + + // recv other requests + for (uint32_t rank = 1; rank < nranks; rank++) { + SYS_ASSERT(recv(conns[rank], &request_id, sizeof(uint32_t), MSG_WAITALL), -1); + MEGRAY_ASSERT(request_id == 3, "inconsistent request_id from rank %d", rank); + + SYS_ASSERT(recv(conns[rank], &len, sizeof(uint64_t), MSG_WAITALL), -1); + MEGRAY_ASSERT(len == len0, "inconsistent len from rank %d", rank); + } + + // recv data + void* data = malloc(len * nranks); + for (uint32_t rank = 0; rank < nranks; rank++) { + char* ptr = (char*)data + rank * len; + SYS_ASSERT(recv(conns[rank], ptr, len, MSG_WAITALL), -1); + } + + // send data to clients + for (uint32_t rank = 0; rank < nranks; rank++) { + SYS_ASSERT(send(conns[rank], data, len * nranks, 0), -1); + } + + free(data); +} + +} // namespace MegRay diff --git a/src/server.h b/src/server.h new file mode 100644 index 0000000000000000000000000000000000000000..8db15f63ebaec4283290249d9377ae99235cbca0 --- /dev/null +++ b/src/server.h @@ -0,0 +1,27 @@ +/** + * \file src/server.h + * MegRay is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include + +#include "common.h" + +namespace MegRay { + +char* get_host_ip(); + +int get_free_port(); + +// create megray server +Status create_server(uint32_t nranks, int port); + +} // namespace MegRay diff --git a/src/ucx/communicator.cpp b/src/ucx/communicator.cpp index e25d970cf95523385d97b24be74ef28551abe04a..0c678d9e1c9e8eaab09cc7ca630f0e6a2e2bc51c 100644 --- a/src/ucx/communicator.cpp +++ b/src/ucx/communicator.cpp @@ -73,32 +73,38 @@ UcxCommunicator::~UcxCommunicator() { ucp_cleanup(m_context); } -std::string UcxCommunicator:: get_uid() { - size_t addr_len; - ucp_address_t* addr; +Status UcxCommunicator::do_init() { // get ucp worker address + size_t addr_len, addr_lens[m_nranks]; + ucp_address_t* addr; ucs_status_t status = ucp_worker_get_address(m_worker, &addr, &addr_len); MEGRAY_ASSERT(status == UCS_OK, "failed to get ucp worker address"); - // copy bytes to a string - std::string uid((char*)addr, addr_len); - ucp_worker_release_address(m_worker, addr); - return uid; -} -Status UcxCommunicator::init(const std::vector& uids) { - MEGRAY_ASSERT(uids.size() == m_nranks, "incorrect size of uids"); - m_eps.resize(m_nranks); + // allgather addr_len + MEGRAY_CHECK(m_client->allgather(&addr_len, addr_lens, sizeof(size_t))); + + // find max addr_len + size_t max_len = 0; + for (size_t i = 0; i < m_nranks; i++) { + if (addr_lens[i] > max_len) { + max_len = addr_lens[i]; + } + } + + // allgather addr + char addrs[max_len * m_nranks]; + MEGRAY_CHECK(m_client->allgather(addr, addrs, max_len)); + ucp_worker_release_address(m_worker, addr); // set endpoint params ucp_ep_params_t ep_params; ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS; - ucs_status_t status; + // create ucp endpoint + m_eps.resize(m_nranks); for (size_t i = 0; i < m_nranks; i++) { if (i == m_rank) continue; - // set endpoint address - ep_params.address = reinterpret_cast(uids[i].data()); - // create ucp endpoint + ep_params.address = reinterpret_cast(addrs + i * max_len); status = ucp_ep_create(m_worker, &ep_params, &m_eps[i]); MEGRAY_ASSERT(status == UCS_OK, "failed to create ucp endpoint"); } diff --git a/src/ucx/communicator.h b/src/ucx/communicator.h index a0e6251ea33d11bbf794aa0fc9bf0a8ff7869912..48b1f72f5e9efe3ef72f702d8ce3bb626bc255dc 100644 --- a/src/ucx/communicator.h +++ b/src/ucx/communicator.h @@ -30,10 +30,7 @@ class UcxCommunicator : public Communicator { ~UcxCommunicator(); - // get a serialized string of ucp worker address - std::string get_uid() override; - - Status init(const std::vector& uids) override; + Status do_init() override; Status send(const void* sendbuff, size_t len, uint32_t rank, std::shared_ptr ctx) override; diff --git a/test/test_base.h b/test/test_base.h index 236ae6ab8c81ac4540d88f38f8fd02845d7094d0..3c73a372e8a5c011a6e1ada637eaa04b0ce1b2f1 100644 --- a/test/test_base.h +++ b/test/test_base.h @@ -22,23 +22,24 @@ template void run_test(int nranks, MegRay::Backend backend, std::vector>& inputs, std::vector>& expect_outputs, - std::function, - std::vector&, int, + std::function, int, int, std::vector&, std::vector&)> main_func) { std::vector> comms(nranks); - std::vector uids(nranks); std::vector> outputs(nranks); + int port = MegRay::get_free_port(); + auto ret = MegRay::create_server(nranks, port); + ASSERT_EQ(MegRay::MEGRAY_OK, ret); + for (int i = 0; i < nranks; i++) { comms[i] = MegRay::get_communicator(nranks, i, backend); - uids[i] = comms[i]->get_uid(); outputs[i].resize(expect_outputs[i].size()); } std::vector threads; for (int i = 0; i < nranks; i++) { - threads.push_back(std::thread(main_func, comms[i], std::ref(uids), i, + threads.push_back(std::thread(main_func, comms[i], port, i, std::ref(inputs[i]), std::ref(outputs[i]))); } diff --git a/test/test_opr.cpp b/test/test_opr.cpp index 1facb0dea3784deac2152ef1e72a6363a0bafeb3..ef0fce0ff5d2165f257e08920855f61ba8494194 100644 --- a/test/test_opr.cpp +++ b/test/test_opr.cpp @@ -18,22 +18,18 @@ #include -#include "../src/megray.h" #include "test_base.h" TEST(TestNcclCommunicator, Init) { const int nranks = 3; - - std::vector> comms(nranks); - std::vector uids(nranks); - for (size_t i = 0; i < nranks; i++) { - comms[i] = MegRay::get_communicator(nranks, i, MegRay::MEGRAY_NCCL); - uids[i] = comms[i]->get_uid(); - } + const int port = MegRay::get_free_port(); + auto ret = MegRay::create_server(nranks, port); + ASSERT_EQ(MegRay::MEGRAY_OK, ret); auto run = [&](int rank) { cudaSetDevice(rank); - comms[rank]->init(uids); + auto comm = MegRay::get_communicator(nranks, rank, MegRay::MEGRAY_NCCL); + ASSERT_EQ(MegRay::MEGRAY_OK, comm->init("localhost", port)); }; std::vector threads; @@ -48,17 +44,14 @@ TEST(TestNcclCommunicator, Init) { TEST(TestUcxCommunicator, Init) { const int nranks = 3; - - std::vector> comms(nranks); - std::vector uids(nranks); - for (int i = 0; i < nranks; i++) { - comms[i] = MegRay::get_communicator(nranks, i, MegRay::MEGRAY_UCX); - uids[i] = comms[i]->get_uid(); - } + const int port = MegRay::get_free_port(); + auto ret = MegRay::create_server(nranks, port); + ASSERT_EQ(MegRay::MEGRAY_OK, ret); auto run = [&](int rank) { cudaSetDevice(rank); - comms[rank]->init(uids); + auto comm = MegRay::get_communicator(nranks, rank, MegRay::MEGRAY_UCX); + ASSERT_EQ(MegRay::MEGRAY_OK, comm->init("localhost", port)); }; std::vector threads; @@ -85,11 +78,11 @@ TEST(TestOpr, SendRecv) { } auto run = [len](std::shared_ptr comm, - std::vector& uids, int rank, + int port, int rank, std::vector& input, std::vector& output) -> void { CUDA_ASSERT(cudaSetDevice(rank)); - comm->init(uids); + comm->init("localhost", port); cudaStream_t stream; CUDA_ASSERT(cudaStreamCreate(&stream)); @@ -129,11 +122,11 @@ TEST(TestOpr, Scatter) { } auto run = [nranks, recvlen, root](std::shared_ptr comm, - std::vector& uids, int rank, + int port, int rank, std::vector& input, std::vector& output) -> void { CUDA_ASSERT(cudaSetDevice(rank)); - comm->init(uids); + comm->init("localhost", port); cudaStream_t stream; CUDA_ASSERT(cudaStreamCreate(&stream)); @@ -180,11 +173,11 @@ TEST(TestOpr, Gather) { } auto run = [nranks, sendlen, root](std::shared_ptr comm, - std::vector& uids, int rank, + int port, int rank, std::vector& input, std::vector& output) -> void { CUDA_ASSERT(cudaSetDevice(rank)); - comm->init(uids); + comm->init("localhost", port); cudaStream_t stream; CUDA_ASSERT(cudaStreamCreate(&stream)); @@ -235,11 +228,11 @@ TEST(TestOpr, AllToAll) { } auto run = [nranks, len](std::shared_ptr comm, - std::vector& uids, int rank, + int port, int rank, std::vector& input, std::vector& output) -> void { CUDA_ASSERT(cudaSetDevice(rank)); - comm->init(uids); + comm->init("localhost", port); cudaStream_t stream; CUDA_ASSERT(cudaStreamCreate(&stream)); @@ -283,11 +276,11 @@ TEST(TestOpr, AllGather) { } auto run = [nranks, sendlen](std::shared_ptr comm, - std::vector& uids, int rank, + int port, int rank, std::vector& input, std::vector& output) -> void { CUDA_ASSERT(cudaSetDevice(rank)); - comm->init(uids); + comm->init("localhost", port); cudaStream_t stream; CUDA_ASSERT(cudaStreamCreate(&stream)); @@ -322,11 +315,11 @@ TEST(TestOpr, AllReduce) { auto reduce_func = [nranks, len](MegRay::ReduceOp op) { auto run = [nranks, len, op](std::shared_ptr comm, - std::vector& uids, int rank, + int port, int rank, std::vector& input, std::vector& output) { CUDA_ASSERT(cudaSetDevice(rank)); - comm->init(uids); + comm->init("localhost", port); cudaStream_t stream; CUDA_ASSERT(cudaStreamCreate(&stream)); @@ -407,10 +400,10 @@ TEST(TestOpr, ReduceScatterSum) { auto reduce_func = [nranks, recvlen](MegRay::ReduceOp op) { auto run = [nranks, recvlen, op](std::shared_ptr comm, - std::vector& uids, int rank, + int port, int rank, std::vector& input, std::vector& output) { CUDA_ASSERT(cudaSetDevice(rank)); - comm->init(uids); + comm->init("localhost", port); cudaStream_t stream; CUDA_ASSERT(cudaStreamCreate(&stream)); @@ -501,11 +494,11 @@ TEST(TestOpr, Broadcast) { } auto run = [nranks, root, len](std::shared_ptr comm, - std::vector& uids, int rank, + int port, int rank, std::vector& input, std::vector& output) { CUDA_ASSERT(cudaSetDevice(rank)); - comm->init(uids); + comm->init("localhost", port); cudaStream_t stream; CUDA_ASSERT(cudaStreamCreate(&stream)); @@ -543,10 +536,10 @@ TEST(TestOpr, ReduceSum) { auto reduce_func = [nranks, root, len](MegRay::ReduceOp op) { auto run = [nranks, root, len, op](std::shared_ptr comm, - std::vector& uids, int rank, + int port, int rank, std::vector& input, std::vector& output) { CUDA_ASSERT(cudaSetDevice(rank)); - comm->init(uids); + comm->init("localhost", port); cudaStream_t stream; CUDA_ASSERT(cudaStreamCreate(&stream)); diff --git a/test/test_server_client.cpp b/test/test_server_client.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8febc0b08f331a1f34713ade90aaa6703ffffa73 --- /dev/null +++ b/test/test_server_client.cpp @@ -0,0 +1,211 @@ + /** + * \file test/test_server_client.cpp + * MegRay is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include +#include + +#include + +#include "../src/server.h" +#include "../src/client.h" + +TEST(TestServerClient, GetHostIP) { + char* ip = MegRay::get_host_ip(); + ASSERT_TRUE(ip != NULL); + ASSERT_TRUE(strlen(ip) >= 8); +} + +TEST(TestServerClient, GetFreePort) { + int port = MegRay::get_free_port(); + ASSERT_TRUE(port > 0); +} + +TEST(TestServerClient, Connect) { + const int nranks = 3; + + const int port = MegRay::get_free_port(); + auto ret = MegRay::create_server(nranks, port); + ASSERT_EQ(MegRay::MEGRAY_OK, ret); + + auto run = [nranks, port](int rank) { + auto client = std::make_unique(nranks, rank); + auto ret = client->connect("localhost", port); + ASSERT_EQ(MegRay::MEGRAY_OK, ret); + }; + + std::vector threads; + for (size_t i = 0; i < nranks; i++) { + threads.push_back(std::thread(run, i)); + } + + for (size_t i = 0; i < nranks; i++) { + threads[i].join(); + } +} + +TEST(TestServerClient, Barrier) { + const int nranks = 3; + + const int port = MegRay::get_free_port(); + auto ret = MegRay::create_server(nranks, port); + ASSERT_EQ(MegRay::MEGRAY_OK, ret); + + int counter = 0; + + auto run = [nranks, port, &counter](int rank) { + auto client = std::make_unique(nranks, rank); + auto ret = client->connect("localhost", port); + ASSERT_EQ(MegRay::MEGRAY_OK, ret); + + ret = client->barrier(); + ASSERT_EQ(MegRay::MEGRAY_OK, ret); + + sleep(rank); + ++counter; + + ret = client->barrier(); + ASSERT_EQ(MegRay::MEGRAY_OK, ret); + + // if the barrier is not working correctly, threads that sleep + // less seconds will arrive here earlier and counter might be + // less than nranks + ASSERT_EQ(nranks, counter); + }; + + std::vector threads; + for (size_t i = 0; i < nranks; i++) { + threads.push_back(std::thread(run, i)); + } + + for (size_t i = 0; i < nranks; i++) { + threads[i].join(); + } +} + +TEST(TestServerClient, Broadcast) { + const int nranks = 3; + const int root = 1; + const int chunk_size = 10; + + const int port = MegRay::get_free_port(); + auto ret = MegRay::create_server(nranks, port); + ASSERT_EQ(MegRay::MEGRAY_OK, ret); + + std::string str(chunk_size * nranks, '\0'); + for (size_t i = 0; i < str.size(); i++) { + str[i] = 'a' + i % 26; + } + auto expected = str.substr(root * chunk_size, chunk_size); + + auto run = [nranks, port, &str, &expected](int rank) { + auto client = std::make_unique(nranks, rank); + auto ret = client->connect("localhost", port); + ASSERT_EQ(MegRay::MEGRAY_OK, ret); + + const char* input = str.data() + rank * chunk_size; + char* output = (char*)malloc(chunk_size); + ret = client->broadcast(input, output, chunk_size, root); + ASSERT_EQ(MegRay::MEGRAY_OK, ret); + + ASSERT_EQ(expected, std::string(output, chunk_size)); + free(output); + }; + + std::vector threads; + for (size_t i = 0; i < nranks; i++) { + threads.push_back(std::thread(run, i)); + } + + for (size_t i = 0; i < nranks; i++) { + threads[i].join(); + } +} + +TEST(TestServerClient, AllGather) { + const int nranks = 3; + const int chunk_size = 10; + + const int port = MegRay::get_free_port(); + auto ret = MegRay::create_server(nranks, port); + ASSERT_EQ(MegRay::MEGRAY_OK, ret); + + std::string str(chunk_size * nranks, '\0'); + for (size_t i = 0; i < str.size(); i++) { + str[i] = 'a' + i % 26; + } + + auto run = [nranks, port, &str](int rank) { + auto client = std::make_unique(nranks, rank); + auto ret = client->connect("localhost", port); + ASSERT_EQ(MegRay::MEGRAY_OK, ret); + + const char* input = str.data() + rank * chunk_size; + char* output = (char*)malloc(str.size()); + ret = client->allgather(input, output, chunk_size); + ASSERT_EQ(MegRay::MEGRAY_OK, ret); + + ASSERT_EQ(str, std::string(output, str.size())); + free(output); + }; + + std::vector threads; + for (size_t i = 0; i < nranks; i++) { + threads.push_back(std::thread(run, i)); + } + + for (size_t i = 0; i < nranks; i++) { + threads[i].join(); + } +} + +TEST(TestServerClient, Sequence) { + const int nranks = 3; + const int chunk_size = 10; + + const int port = MegRay::get_free_port(); + auto ret = MegRay::create_server(nranks, port); + ASSERT_EQ(MegRay::MEGRAY_OK, ret); + + std::string str(chunk_size * nranks, '\0'); + for (size_t i = 0; i < str.size(); i++) { + str[i] = 'a' + i % 26; + } + + auto run = [nranks, port, &str](int rank) { + auto client = std::make_unique(nranks, rank); + auto ret = client->connect("localhost", port); + ASSERT_EQ(MegRay::MEGRAY_OK, ret); + + const char* input = str.data() + rank * chunk_size; + char* output = (char*)malloc(str.size()); + + // send a sequence of requets without checking output + ASSERT_EQ(MegRay::MEGRAY_OK, client->barrier()); + ASSERT_EQ(MegRay::MEGRAY_OK, client->broadcast(input, output, chunk_size, 1)); + ASSERT_EQ(MegRay::MEGRAY_OK, client->allgather(input, output, chunk_size)); + ASSERT_EQ(MegRay::MEGRAY_OK, client->barrier()); + ASSERT_EQ(MegRay::MEGRAY_OK, client->allgather(input, output, chunk_size)); + ASSERT_EQ(MegRay::MEGRAY_OK, client->broadcast(input, output, chunk_size, 2)); + ASSERT_EQ(MegRay::MEGRAY_OK, client->allgather(input, output, chunk_size)); + ASSERT_EQ(MegRay::MEGRAY_OK, client->barrier()); + + free(output); + }; + + std::vector threads; + for (size_t i = 0; i < nranks; i++) { + threads.push_back(std::thread(run, i)); + } + + for (size_t i = 0; i < nranks; i++) { + threads[i].join(); + } +}