提交 6f3c7145 编写于 作者: M Megvii Engine Team 提交者: liuqingyi

feat(server): add megray server client

GitOrigin-RevId: 4a4cfe708412c01dd3b5acd620288ed208f5a0d0
上级 e435ee1b
/**
* \file src/client.cpp
* MegRay is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "client.h"
#include <arpa/inet.h>
#include <netinet/in.h>
#include <string.h>
#include <sys/socket.h>
namespace MegRay {
Client::Client(uint32_t nranks, uint32_t rank) :
m_nranks(nranks), m_rank(rank), m_connected(false) {
}
Client::~Client() {
}
Status Client::connect(const char* master_ip, int port) {
std::unique_lock<std::mutex> lock(m_mutex);
if (m_connected) {
MEGRAY_ERROR("Client already connected");
return MEGRAY_INVALID_USAGE;
}
// create socket
SYS_CHECK_RET(socket(AF_INET, SOCK_STREAM, 0), -1, m_conn);
// set server_addr
struct sockaddr_in server_addr;
memset(&server_addr, 0, sizeof(server_addr));
server_addr.sin_family = AF_INET;
server_addr.sin_port = htons(port);
SYS_CHECK(inet_pton(AF_INET, master_ip, &server_addr.sin_addr), -1);
// connect
SYS_CHECK(::connect(m_conn, (struct sockaddr*)&server_addr, sizeof(server_addr)), -1);
// send client rank
SYS_CHECK(send(m_conn, &m_rank, sizeof(uint32_t), 0), -1);
// recv ack from server
uint32_t ack;
SYS_CHECK(recv(m_conn, &ack, sizeof(uint32_t), MSG_WAITALL), -1);
m_connected = true;
return MEGRAY_OK;
}
Status Client::barrier() {
std::unique_lock<std::mutex> lock(m_mutex);
if (!m_connected) {
MEGRAY_ERROR("Client not connected");
return MEGRAY_INVALID_USAGE;
}
// send request_id
uint32_t request_id = 1;
SYS_CHECK(send(m_conn, &request_id, sizeof(uint32_t), 0), -1);
// recv ack
uint32_t ack;
SYS_CHECK(recv(m_conn, &ack, sizeof(uint32_t), MSG_WAITALL), -1);
return MEGRAY_OK;
}
Status Client::broadcast(const void* sendbuff, void* recvbuff, size_t len, uint32_t root) {
std::unique_lock<std::mutex> lock(m_mutex);
if (!m_connected) {
MEGRAY_ERROR("Client not connected");
return MEGRAY_INVALID_USAGE;
}
// send request_id
uint32_t request_id = 2;
SYS_CHECK(send(m_conn, &request_id, sizeof(uint32_t), 0), -1);
// send root
SYS_CHECK(send(m_conn, &root, sizeof(uint32_t), 0), -1);
// send len
uint64_t len64 = len;
SYS_CHECK(send(m_conn, &len64, sizeof(uint64_t), 0), -1);
// send data
if (m_rank == root) {
SYS_CHECK(send(m_conn, sendbuff, len, 0), -1);
}
// recv data
SYS_CHECK(recv(m_conn, recvbuff, len, MSG_WAITALL), -1);
return MEGRAY_OK;
}
Status Client::allgather(const void* sendbuff, void* recvbuff, size_t sendlen) {
std::unique_lock<std::mutex> lock(m_mutex);
if (!m_connected) {
MEGRAY_ERROR("Client not connected");
return MEGRAY_INVALID_USAGE;
}
// send request_id
uint32_t request_id = 3;
SYS_CHECK(send(m_conn, &request_id, sizeof(uint32_t), 0), -1);
// send sendlen
uint64_t sendlen64 = sendlen;
SYS_CHECK(send(m_conn, &sendlen64, sizeof(uint64_t), 0), -1);
// send data
SYS_CHECK(send(m_conn, sendbuff, sendlen, 0), -1);
// recv data
SYS_CHECK(recv(m_conn, recvbuff, sendlen * m_nranks, MSG_WAITALL), -1);
return MEGRAY_OK;
}
} // namespace MegRay
/**
* \file src/client.h
* MegRay is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <mutex>
#include "common.h"
namespace MegRay {
/*!
* synchronize meta information with megray server
*/
class Client {
public:
Client(uint32_t nranks, uint32_t rank);
~Client();
Status connect(const char* master_ip, int port);
// block until all ranks reach this barrier
Status barrier();
// the length of sendbuff = the length of recvbuff = len
Status broadcast(const void* sendbuff, void* recvbuff, size_t sendlen, uint32_t root);
// the length of sendbuff = sendlen
// the length of recvbuff = sendlen * m_nranks
Status allgather(const void* sendbuff, void* recvbuff, size_t sendlen);
private:
uint32_t m_nranks;
uint32_t m_rank;
bool m_connected = false;
int m_conn;
std::mutex m_mutex;
};
} // namespace MegRay
...@@ -11,6 +11,8 @@ ...@@ -11,6 +11,8 @@
#pragma once #pragma once
#include <errno.h>
#include "cuda_runtime.h" #include "cuda_runtime.h"
#include "debug.h" #include "debug.h"
...@@ -19,12 +21,15 @@ namespace MegRay { ...@@ -19,12 +21,15 @@ namespace MegRay {
typedef enum { typedef enum {
MEGRAY_OK = 0, MEGRAY_OK = 0,
MEGRAY_CUDA_ERR = 1, MEGRAY_SYS_ERROR = 1,
MEGRAY_NCCL_ERR = 2, MEGRAY_CUDA_ERR = 2,
MEGRAY_UCX_ERR = 3, MEGRAY_NCCL_ERR = 3,
MEGRAY_ENV_ERROR = 4, MEGRAY_UCX_ERR = 4,
MEGRAY_INVALID_ARGUMENT = 5, MEGRAY_ENV_ERROR = 5,
MEGRAY_NOT_IMPLEMENTED = 6 MEGRAY_INVALID_ARGUMENT = 6,
MEGRAY_INVALID_USAGE = 7,
MEGRAY_UNEXPECTED_ERR = 8,
MEGRAY_NOT_IMPLEMENTED = 9
} Status; } Status;
#define MEGRAY_CHECK(expr) \ #define MEGRAY_CHECK(expr) \
...@@ -36,6 +41,38 @@ typedef enum { ...@@ -36,6 +41,38 @@ typedef enum {
} \ } \
} while (0) } while (0)
#define SYS_CHECK_RET(expr, errval, retval) \
do { \
retval = (expr); \
if (retval == errval) { \
MEGRAY_ERROR("system error [%d]: %s", \
errno, strerror(errno)); \
return MEGRAY_SYS_ERROR; \
} \
} while (0)
#define SYS_CHECK(expr, errval) \
do { \
int retval; \
SYS_CHECK_RET(expr, errval, retval); \
} while (0)
#define SYS_ASSERT_RET(expr, errval, retval) \
do { \
retval = (expr); \
if (retval == errval) { \
MEGRAY_ERROR("system error [%d]: %s", \
errno, strerror(errno)); \
MEGRAY_THROW("system error"); \
} \
} while (0)
#define SYS_ASSERT(expr, errval) \
do { \
int retval; \
SYS_ASSERT_RET(expr, errval, retval); \
} while (0)
#define CUDA_CHECK(expr) \ #define CUDA_CHECK(expr) \
do { \ do { \
cudaError_t status = (expr); \ cudaError_t status = (expr); \
...@@ -58,7 +95,7 @@ typedef enum { ...@@ -58,7 +95,7 @@ typedef enum {
typedef enum { typedef enum {
MEGRAY_NCCL = 0, MEGRAY_NCCL = 0,
MEGRAY_UCX = 1, MEGRAY_UCX = 1
} Backend; } Backend;
typedef enum { typedef enum {
......
...@@ -15,6 +15,12 @@ ...@@ -15,6 +15,12 @@
namespace MegRay { namespace MegRay {
Status Communicator::init(const char* master_ip, int port) {
m_client = std::make_shared<Client>(m_nranks, m_rank);
MEGRAY_CHECK(m_client->connect(master_ip, port));
return do_init();
}
std::shared_ptr<Communicator> get_communicator(uint32_t nranks, uint32_t rank, Backend backend) { std::shared_ptr<Communicator> get_communicator(uint32_t nranks, uint32_t rank, Backend backend) {
std::shared_ptr<Communicator> comm; std::shared_ptr<Communicator> comm;
switch (backend) { switch (backend) {
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "common.h" #include "common.h"
#include "context.h" #include "context.h"
#include "client.h"
namespace MegRay { namespace MegRay {
...@@ -37,11 +38,11 @@ class Communicator { ...@@ -37,11 +38,11 @@ class Communicator {
// get the rank of this process // get the rank of this process
uint32_t rank() { return m_rank; } uint32_t rank() { return m_rank; }
// get the unique id of this communicator // establish connection with megray server
virtual std::string get_uid() = 0; Status init(const char* master_ip, int port);
// build a group with unique ids of all communicators in the group // implemented in the subclass and called in init()
virtual Status init(const std::vector<std::string>& uids) = 0; virtual Status do_init() = 0;
// send data to another communicator in the group // send data to another communicator in the group
virtual Status send(const void* sendbuff, size_t len, uint32_t rank, virtual Status send(const void* sendbuff, size_t len, uint32_t rank,
...@@ -90,6 +91,7 @@ class Communicator { ...@@ -90,6 +91,7 @@ class Communicator {
protected: protected:
uint32_t m_nranks; uint32_t m_nranks;
uint32_t m_rank; uint32_t m_rank;
std::shared_ptr<Client> m_client;
}; };
/*! /*!
......
...@@ -11,4 +11,5 @@ ...@@ -11,4 +11,5 @@
#pragma once #pragma once
#include "server.h"
#include "communicator.h" #include "communicator.h"
...@@ -28,7 +28,6 @@ namespace MegRay { ...@@ -28,7 +28,6 @@ namespace MegRay {
NcclCommunicator::NcclCommunicator(int nranks, int rank) : NcclCommunicator::NcclCommunicator(int nranks, int rank) :
Communicator(nranks, rank), m_inited(false) { Communicator(nranks, rank), m_inited(false) {
NCCL_ASSERT(ncclGetUniqueId(&m_uid));
} }
NcclCommunicator::~NcclCommunicator() { NcclCommunicator::~NcclCommunicator() {
...@@ -37,19 +36,14 @@ NcclCommunicator::~NcclCommunicator() { ...@@ -37,19 +36,14 @@ NcclCommunicator::~NcclCommunicator() {
} }
} }
std::string NcclCommunicator::get_uid() { Status NcclCommunicator::do_init() {
// serialize ncclUniqueId into a string uint32_t root = 0;
return std::string(m_uid.internal, NCCL_UNIQUE_ID_BYTES); ncclUniqueId uid;
} if (m_rank == root) {
ncclGetUniqueId(&uid);
Status NcclCommunicator::init(const std::vector<std::string>& uids) { }
MEGRAY_ASSERT(uids.size() == m_nranks, "incorrect size of uids"); MEGRAY_CHECK(m_client->broadcast(&uid, &uid, NCCL_UNIQUE_ID_BYTES, root));
// only use unique id of rank 0 for initialization NCCL_CHECK(ncclCommInitRank(&m_comm, m_nranks, uid, m_rank));
const std::string uid = uids[0];
MEGRAY_ASSERT(uid.size() == NCCL_UNIQUE_ID_BYTES, "invalid uid");
memcpy(m_uid.internal, uid.data(), NCCL_UNIQUE_ID_BYTES);
// initialize nccl communicator
NCCL_CHECK(ncclCommInitRank(&m_comm, m_nranks, m_uid, m_rank));
m_inited = true; m_inited = true;
return MEGRAY_OK; return MEGRAY_OK;
} }
......
...@@ -29,10 +29,7 @@ class NcclCommunicator : public Communicator { ...@@ -29,10 +29,7 @@ class NcclCommunicator : public Communicator {
~NcclCommunicator(); ~NcclCommunicator();
// get a serialized string of ncclUniqueId Status do_init() override;
std::string get_uid() override;
Status init(const std::vector<std::string>& uids) override;
Status send(const void* sendbuff, size_t len, uint32_t rank, Status send(const void* sendbuff, size_t len, uint32_t rank,
std::shared_ptr<Context> ctx) override; std::shared_ptr<Context> ctx) override;
...@@ -65,7 +62,6 @@ class NcclCommunicator : public Communicator { ...@@ -65,7 +62,6 @@ class NcclCommunicator : public Communicator {
DType dtype, ReduceOp op, uint32_t root, std::shared_ptr<Context> ctx) override; DType dtype, ReduceOp op, uint32_t root, std::shared_ptr<Context> ctx) override;
private: private:
ncclUniqueId m_uid;
ncclComm_t m_comm; ncclComm_t m_comm;
bool m_inited; bool m_inited;
}; };
......
/**
* \file src/server.cpp
* MegRay is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "server.h"
#include <arpa/inet.h>
#include <ifaddrs.h>
#include <netinet/in.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>
#include <thread>
namespace MegRay {
/************************ get_host_ip ************************/
char* get_host_ip() {
const char* device = getenv("MEGRAY_NET_DEVICE");
if (device and strcmp(device, "lo") == 0) {
MEGRAY_ERROR("illegal net device: lo\n");
MEGRAY_THROW("invalid argument");
}
struct ifaddrs *ifa;
SYS_ASSERT(getifaddrs(&ifa), -1);
for (struct ifaddrs* p = ifa; p != NULL; p = p->ifa_next) {
if (p->ifa_addr and p->ifa_addr->sa_family == AF_INET and p->ifa_name) {
const char* name = p->ifa_name;
if (strcmp(name, "lo") != 0 and
(device == NULL or strcmp(name, device) == 0)) {
struct sockaddr_in* sin = (struct sockaddr_in*)p->ifa_addr;
const char* host_ip = inet_ntoa(sin->sin_addr);
MEGRAY_INFO("using net device %s (%s)", name, host_ip);
char* ret = new char(strlen(host_ip) + 1);
strcpy(ret, host_ip);
freeifaddrs(ifa);
return ret;
}
}
}
if (device) {
MEGRAY_ERROR("failed to get host ip for device %s", device);
} else {
MEGRAY_ERROR("failed to get host ip");
}
MEGRAY_THROW("system error");
return nullptr;
}
/************************ get_free_port ************************/
int get_free_port() {
// create socket
int sock;
SYS_ASSERT_RET(socket(AF_INET, SOCK_STREAM, 0), -1, sock);
// set address
struct sockaddr_in addr;
memset(&addr, 0, sizeof(addr));
addr.sin_family = AF_INET;
addr.sin_addr.s_addr = htonl(INADDR_ANY);
addr.sin_port = htons(0);
// bind
SYS_ASSERT(bind(sock, (struct sockaddr*)&addr, sizeof(addr)), -1);
// get port
socklen_t len = sizeof(addr);
SYS_ASSERT(getsockname(sock, (struct sockaddr*)&addr, &len), -1);
int port = ntohs(addr.sin_port);
// close
SYS_ASSERT(close(sock), -1);
return port;
}
/************************ create_server ************************/
void serve_barrier(uint32_t nranks, int* conns);
void serve_broadcast(uint32_t nranks, int* conns);
void serve_allgather(uint32_t nranks, int* conns);
void server_thread(int listenfd, uint32_t nranks) {
int conns[nranks];
for (uint32_t i = 0; i < nranks; i++) {
// establish connection
int conn;
SYS_ASSERT_RET(accept(listenfd, (struct sockaddr*)NULL, NULL), -1, conn);
// recv rank and save into conns
uint32_t rank;
SYS_ASSERT(recv(conn, &rank, sizeof(uint32_t), MSG_WAITALL), -1);
conns[rank] = conn;
}
// send ack to clients
uint32_t ack = 0;
for (uint32_t i = 0; i < nranks; i++) {
SYS_ASSERT(send(conns[i], &ack, sizeof(uint32_t), 0), -1);
}
while (true) {
// receive a request from rank 0
uint32_t request_id;
SYS_ASSERT(recv(conns[0], &request_id, sizeof(uint32_t), MSG_WAITALL), -1);
if (request_id == 1) {
serve_barrier(nranks, conns);
} else if (request_id == 2) {
serve_broadcast(nranks, conns);
} else if (request_id == 3) {
serve_allgather(nranks, conns);
} else {
MEGRAY_ERROR("unexpected request id: %d", request_id);
MEGRAY_THROW("unexpected error");
}
}
}
Status create_server(uint32_t nranks, int port) {
// create socket
int listenfd;
SYS_CHECK_RET(socket(AF_INET, SOCK_STREAM, 0), -1, listenfd);
// set server_addr
struct sockaddr_in server_addr;
memset(&server_addr, 0, sizeof(server_addr));
server_addr.sin_family = AF_INET;
server_addr.sin_addr.s_addr = htonl(INADDR_ANY);
server_addr.sin_port = htons(port);
// bind and listen
int opt = 1;
SYS_CHECK(setsockopt(listenfd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(int)), -1);
SYS_CHECK(bind(listenfd, (struct sockaddr*)&server_addr, sizeof(server_addr)), -1);
SYS_CHECK(listen(listenfd, nranks), -1);
// start server thread
std::thread th(server_thread, listenfd, nranks);
th.detach();
return MEGRAY_OK;
}
/************************ barrier ************************/
void serve_barrier(uint32_t nranks, int* conns) {
uint32_t request_id;
// recv other requests
for (uint32_t rank = 1; rank < nranks; rank++) {
SYS_ASSERT(recv(conns[rank], &request_id, sizeof(uint32_t), MSG_WAITALL), -1);
MEGRAY_ASSERT(request_id == 1, "inconsistent request_id from rank %d", rank);
}
// send ack
uint32_t ack = 0;
for (uint32_t rank = 0; rank < nranks; rank++) {
SYS_ASSERT(send(conns[rank], &ack, sizeof(uint32_t), 0), -1);
}
}
/************************ broadcast ************************/
void serve_broadcast(uint32_t nranks, int* conns) {
uint32_t request_id, root, root0;
uint64_t len, len0;
// recv request 0
SYS_ASSERT(recv(conns[0], &root0, sizeof(uint32_t), MSG_WAITALL), -1);
SYS_ASSERT(recv(conns[0], &len0, sizeof(uint64_t), MSG_WAITALL), -1);
// recv other requests
for (uint32_t rank = 1; rank < nranks; rank++) {
SYS_ASSERT(recv(conns[rank], &request_id, sizeof(uint32_t), MSG_WAITALL), -1);
MEGRAY_ASSERT(request_id == 2, "inconsistent request_id from rank %d", rank);
SYS_ASSERT(recv(conns[rank], &root, sizeof(uint32_t), MSG_WAITALL), -1);
MEGRAY_ASSERT(root == root0, "inconsistent root from rank %d", rank);
SYS_ASSERT(recv(conns[rank], &len, sizeof(uint64_t), MSG_WAITALL), -1);
MEGRAY_ASSERT(len == len0, "inconsistent len from rank %d", rank);
}
// recv data from root
void* data = malloc(len);
SYS_ASSERT(recv(conns[root], data, len, MSG_WAITALL), -1);
// send data to clients
for (uint32_t rank = 0; rank < nranks; rank++) {
SYS_ASSERT(send(conns[rank], data, len, 0), -1);
}
free(data);
}
/************************ allgather ************************/
void serve_allgather(uint32_t nranks, int* conns) {
uint32_t request_id;
uint64_t len, len0;
// recv request 0
SYS_ASSERT(recv(conns[0], &len0, sizeof(uint64_t), MSG_WAITALL), -1);
// recv other requests
for (uint32_t rank = 1; rank < nranks; rank++) {
SYS_ASSERT(recv(conns[rank], &request_id, sizeof(uint32_t), MSG_WAITALL), -1);
MEGRAY_ASSERT(request_id == 3, "inconsistent request_id from rank %d", rank);
SYS_ASSERT(recv(conns[rank], &len, sizeof(uint64_t), MSG_WAITALL), -1);
MEGRAY_ASSERT(len == len0, "inconsistent len from rank %d", rank);
}
// recv data
void* data = malloc(len * nranks);
for (uint32_t rank = 0; rank < nranks; rank++) {
char* ptr = (char*)data + rank * len;
SYS_ASSERT(recv(conns[rank], ptr, len, MSG_WAITALL), -1);
}
// send data to clients
for (uint32_t rank = 0; rank < nranks; rank++) {
SYS_ASSERT(send(conns[rank], data, len * nranks, 0), -1);
}
free(data);
}
} // namespace MegRay
/**
* \file src/server.h
* MegRay is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <mutex>
#include "common.h"
namespace MegRay {
char* get_host_ip();
int get_free_port();
// create megray server
Status create_server(uint32_t nranks, int port);
} // namespace MegRay
...@@ -73,32 +73,38 @@ UcxCommunicator::~UcxCommunicator() { ...@@ -73,32 +73,38 @@ UcxCommunicator::~UcxCommunicator() {
ucp_cleanup(m_context); ucp_cleanup(m_context);
} }
std::string UcxCommunicator:: get_uid() { Status UcxCommunicator::do_init() {
size_t addr_len;
ucp_address_t* addr;
// get ucp worker address // get ucp worker address
size_t addr_len, addr_lens[m_nranks];
ucp_address_t* addr;
ucs_status_t status = ucp_worker_get_address(m_worker, &addr, &addr_len); ucs_status_t status = ucp_worker_get_address(m_worker, &addr, &addr_len);
MEGRAY_ASSERT(status == UCS_OK, "failed to get ucp worker address"); MEGRAY_ASSERT(status == UCS_OK, "failed to get ucp worker address");
// copy bytes to a string
std::string uid((char*)addr, addr_len);
ucp_worker_release_address(m_worker, addr);
return uid;
}
Status UcxCommunicator::init(const std::vector<std::string>& uids) { // allgather addr_len
MEGRAY_ASSERT(uids.size() == m_nranks, "incorrect size of uids"); MEGRAY_CHECK(m_client->allgather(&addr_len, addr_lens, sizeof(size_t)));
m_eps.resize(m_nranks);
// find max addr_len
size_t max_len = 0;
for (size_t i = 0; i < m_nranks; i++) {
if (addr_lens[i] > max_len) {
max_len = addr_lens[i];
}
}
// allgather addr
char addrs[max_len * m_nranks];
MEGRAY_CHECK(m_client->allgather(addr, addrs, max_len));
ucp_worker_release_address(m_worker, addr);
// set endpoint params // set endpoint params
ucp_ep_params_t ep_params; ucp_ep_params_t ep_params;
ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS; ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
ucs_status_t status;
// create ucp endpoint
m_eps.resize(m_nranks);
for (size_t i = 0; i < m_nranks; i++) { for (size_t i = 0; i < m_nranks; i++) {
if (i == m_rank) continue; if (i == m_rank) continue;
// set endpoint address ep_params.address = reinterpret_cast<const ucp_address_t*>(addrs + i * max_len);
ep_params.address = reinterpret_cast<const ucp_address_t*>(uids[i].data());
// create ucp endpoint
status = ucp_ep_create(m_worker, &ep_params, &m_eps[i]); status = ucp_ep_create(m_worker, &ep_params, &m_eps[i]);
MEGRAY_ASSERT(status == UCS_OK, "failed to create ucp endpoint"); MEGRAY_ASSERT(status == UCS_OK, "failed to create ucp endpoint");
} }
......
...@@ -30,10 +30,7 @@ class UcxCommunicator : public Communicator { ...@@ -30,10 +30,7 @@ class UcxCommunicator : public Communicator {
~UcxCommunicator(); ~UcxCommunicator();
// get a serialized string of ucp worker address Status do_init() override;
std::string get_uid() override;
Status init(const std::vector<std::string>& uids) override;
Status send(const void* sendbuff, size_t len, uint32_t rank, Status send(const void* sendbuff, size_t len, uint32_t rank,
std::shared_ptr<Context> ctx) override; std::shared_ptr<Context> ctx) override;
......
...@@ -22,23 +22,24 @@ template <typename T> ...@@ -22,23 +22,24 @@ template <typename T>
void run_test(int nranks, MegRay::Backend backend, void run_test(int nranks, MegRay::Backend backend,
std::vector<std::vector<T>>& inputs, std::vector<std::vector<T>>& inputs,
std::vector<std::vector<T>>& expect_outputs, std::vector<std::vector<T>>& expect_outputs,
std::function<void(std::shared_ptr<MegRay::Communicator>, std::function<void(std::shared_ptr<MegRay::Communicator>, int, int,
std::vector<std::string>&, int,
std::vector<T>&, std::vector<T>&)> std::vector<T>&, std::vector<T>&)>
main_func) { main_func) {
std::vector<std::shared_ptr<MegRay::Communicator>> comms(nranks); std::vector<std::shared_ptr<MegRay::Communicator>> comms(nranks);
std::vector<std::string> uids(nranks);
std::vector<std::vector<T>> outputs(nranks); std::vector<std::vector<T>> outputs(nranks);
int port = MegRay::get_free_port();
auto ret = MegRay::create_server(nranks, port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
for (int i = 0; i < nranks; i++) { for (int i = 0; i < nranks; i++) {
comms[i] = MegRay::get_communicator(nranks, i, backend); comms[i] = MegRay::get_communicator(nranks, i, backend);
uids[i] = comms[i]->get_uid();
outputs[i].resize(expect_outputs[i].size()); outputs[i].resize(expect_outputs[i].size());
} }
std::vector<std::thread> threads; std::vector<std::thread> threads;
for (int i = 0; i < nranks; i++) { for (int i = 0; i < nranks; i++) {
threads.push_back(std::thread(main_func, comms[i], std::ref(uids), i, threads.push_back(std::thread(main_func, comms[i], port, i,
std::ref(inputs[i]), std::ref(inputs[i]),
std::ref(outputs[i]))); std::ref(outputs[i])));
} }
......
...@@ -18,22 +18,18 @@ ...@@ -18,22 +18,18 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "../src/megray.h"
#include "test_base.h" #include "test_base.h"
TEST(TestNcclCommunicator, Init) { TEST(TestNcclCommunicator, Init) {
const int nranks = 3; const int nranks = 3;
const int port = MegRay::get_free_port();
std::vector<std::shared_ptr<MegRay::Communicator>> comms(nranks); auto ret = MegRay::create_server(nranks, port);
std::vector<std::string> uids(nranks); ASSERT_EQ(MegRay::MEGRAY_OK, ret);
for (size_t i = 0; i < nranks; i++) {
comms[i] = MegRay::get_communicator(nranks, i, MegRay::MEGRAY_NCCL);
uids[i] = comms[i]->get_uid();
}
auto run = [&](int rank) { auto run = [&](int rank) {
cudaSetDevice(rank); cudaSetDevice(rank);
comms[rank]->init(uids); auto comm = MegRay::get_communicator(nranks, rank, MegRay::MEGRAY_NCCL);
ASSERT_EQ(MegRay::MEGRAY_OK, comm->init("localhost", port));
}; };
std::vector<std::thread> threads; std::vector<std::thread> threads;
...@@ -48,17 +44,14 @@ TEST(TestNcclCommunicator, Init) { ...@@ -48,17 +44,14 @@ TEST(TestNcclCommunicator, Init) {
TEST(TestUcxCommunicator, Init) { TEST(TestUcxCommunicator, Init) {
const int nranks = 3; const int nranks = 3;
const int port = MegRay::get_free_port();
std::vector<std::shared_ptr<MegRay::Communicator>> comms(nranks); auto ret = MegRay::create_server(nranks, port);
std::vector<std::string> uids(nranks); ASSERT_EQ(MegRay::MEGRAY_OK, ret);
for (int i = 0; i < nranks; i++) {
comms[i] = MegRay::get_communicator(nranks, i, MegRay::MEGRAY_UCX);
uids[i] = comms[i]->get_uid();
}
auto run = [&](int rank) { auto run = [&](int rank) {
cudaSetDevice(rank); cudaSetDevice(rank);
comms[rank]->init(uids); auto comm = MegRay::get_communicator(nranks, rank, MegRay::MEGRAY_UCX);
ASSERT_EQ(MegRay::MEGRAY_OK, comm->init("localhost", port));
}; };
std::vector<std::thread> threads; std::vector<std::thread> threads;
...@@ -85,11 +78,11 @@ TEST(TestOpr, SendRecv) { ...@@ -85,11 +78,11 @@ TEST(TestOpr, SendRecv) {
} }
auto run = [len](std::shared_ptr<MegRay::Communicator> comm, auto run = [len](std::shared_ptr<MegRay::Communicator> comm,
std::vector<std::string>& uids, int rank, int port, int rank,
std::vector<char>& input, std::vector<char>& input,
std::vector<char>& output) -> void { std::vector<char>& output) -> void {
CUDA_ASSERT(cudaSetDevice(rank)); CUDA_ASSERT(cudaSetDevice(rank));
comm->init(uids); comm->init("localhost", port);
cudaStream_t stream; cudaStream_t stream;
CUDA_ASSERT(cudaStreamCreate(&stream)); CUDA_ASSERT(cudaStreamCreate(&stream));
...@@ -129,11 +122,11 @@ TEST(TestOpr, Scatter) { ...@@ -129,11 +122,11 @@ TEST(TestOpr, Scatter) {
} }
auto run = [nranks, recvlen, root](std::shared_ptr<MegRay::Communicator> comm, auto run = [nranks, recvlen, root](std::shared_ptr<MegRay::Communicator> comm,
std::vector<std::string>& uids, int rank, int port, int rank,
std::vector<float>& input, std::vector<float>& input,
std::vector<float>& output) -> void { std::vector<float>& output) -> void {
CUDA_ASSERT(cudaSetDevice(rank)); CUDA_ASSERT(cudaSetDevice(rank));
comm->init(uids); comm->init("localhost", port);
cudaStream_t stream; cudaStream_t stream;
CUDA_ASSERT(cudaStreamCreate(&stream)); CUDA_ASSERT(cudaStreamCreate(&stream));
...@@ -180,11 +173,11 @@ TEST(TestOpr, Gather) { ...@@ -180,11 +173,11 @@ TEST(TestOpr, Gather) {
} }
auto run = [nranks, sendlen, root](std::shared_ptr<MegRay::Communicator> comm, auto run = [nranks, sendlen, root](std::shared_ptr<MegRay::Communicator> comm,
std::vector<std::string>& uids, int rank, int port, int rank,
std::vector<float>& input, std::vector<float>& input,
std::vector<float>& output) -> void { std::vector<float>& output) -> void {
CUDA_ASSERT(cudaSetDevice(rank)); CUDA_ASSERT(cudaSetDevice(rank));
comm->init(uids); comm->init("localhost", port);
cudaStream_t stream; cudaStream_t stream;
CUDA_ASSERT(cudaStreamCreate(&stream)); CUDA_ASSERT(cudaStreamCreate(&stream));
...@@ -235,11 +228,11 @@ TEST(TestOpr, AllToAll) { ...@@ -235,11 +228,11 @@ TEST(TestOpr, AllToAll) {
} }
auto run = [nranks, len](std::shared_ptr<MegRay::Communicator> comm, auto run = [nranks, len](std::shared_ptr<MegRay::Communicator> comm,
std::vector<std::string>& uids, int rank, int port, int rank,
std::vector<float>& input, std::vector<float>& input,
std::vector<float>& output) -> void { std::vector<float>& output) -> void {
CUDA_ASSERT(cudaSetDevice(rank)); CUDA_ASSERT(cudaSetDevice(rank));
comm->init(uids); comm->init("localhost", port);
cudaStream_t stream; cudaStream_t stream;
CUDA_ASSERT(cudaStreamCreate(&stream)); CUDA_ASSERT(cudaStreamCreate(&stream));
...@@ -283,11 +276,11 @@ TEST(TestOpr, AllGather) { ...@@ -283,11 +276,11 @@ TEST(TestOpr, AllGather) {
} }
auto run = [nranks, sendlen](std::shared_ptr<MegRay::Communicator> comm, auto run = [nranks, sendlen](std::shared_ptr<MegRay::Communicator> comm,
std::vector<std::string>& uids, int rank, int port, int rank,
std::vector<float>& input, std::vector<float>& input,
std::vector<float>& output) -> void { std::vector<float>& output) -> void {
CUDA_ASSERT(cudaSetDevice(rank)); CUDA_ASSERT(cudaSetDevice(rank));
comm->init(uids); comm->init("localhost", port);
cudaStream_t stream; cudaStream_t stream;
CUDA_ASSERT(cudaStreamCreate(&stream)); CUDA_ASSERT(cudaStreamCreate(&stream));
...@@ -322,11 +315,11 @@ TEST(TestOpr, AllReduce) { ...@@ -322,11 +315,11 @@ TEST(TestOpr, AllReduce) {
auto reduce_func = [nranks, len](MegRay::ReduceOp op) { auto reduce_func = [nranks, len](MegRay::ReduceOp op) {
auto run = [nranks, len, op](std::shared_ptr<MegRay::Communicator> comm, auto run = [nranks, len, op](std::shared_ptr<MegRay::Communicator> comm,
std::vector<std::string>& uids, int rank, int port, int rank,
std::vector<float>& input, std::vector<float>& input,
std::vector<float>& output) { std::vector<float>& output) {
CUDA_ASSERT(cudaSetDevice(rank)); CUDA_ASSERT(cudaSetDevice(rank));
comm->init(uids); comm->init("localhost", port);
cudaStream_t stream; cudaStream_t stream;
CUDA_ASSERT(cudaStreamCreate(&stream)); CUDA_ASSERT(cudaStreamCreate(&stream));
...@@ -407,10 +400,10 @@ TEST(TestOpr, ReduceScatterSum) { ...@@ -407,10 +400,10 @@ TEST(TestOpr, ReduceScatterSum) {
auto reduce_func = [nranks, recvlen](MegRay::ReduceOp op) { auto reduce_func = [nranks, recvlen](MegRay::ReduceOp op) {
auto run = [nranks, recvlen, auto run = [nranks, recvlen,
op](std::shared_ptr<MegRay::Communicator> comm, op](std::shared_ptr<MegRay::Communicator> comm,
std::vector<std::string>& uids, int rank, int port, int rank,
std::vector<float>& input, std::vector<float>& output) { std::vector<float>& input, std::vector<float>& output) {
CUDA_ASSERT(cudaSetDevice(rank)); CUDA_ASSERT(cudaSetDevice(rank));
comm->init(uids); comm->init("localhost", port);
cudaStream_t stream; cudaStream_t stream;
CUDA_ASSERT(cudaStreamCreate(&stream)); CUDA_ASSERT(cudaStreamCreate(&stream));
...@@ -501,11 +494,11 @@ TEST(TestOpr, Broadcast) { ...@@ -501,11 +494,11 @@ TEST(TestOpr, Broadcast) {
} }
auto run = [nranks, root, len](std::shared_ptr<MegRay::Communicator> comm, auto run = [nranks, root, len](std::shared_ptr<MegRay::Communicator> comm,
std::vector<std::string>& uids, int rank, int port, int rank,
std::vector<float>& input, std::vector<float>& input,
std::vector<float>& output) { std::vector<float>& output) {
CUDA_ASSERT(cudaSetDevice(rank)); CUDA_ASSERT(cudaSetDevice(rank));
comm->init(uids); comm->init("localhost", port);
cudaStream_t stream; cudaStream_t stream;
CUDA_ASSERT(cudaStreamCreate(&stream)); CUDA_ASSERT(cudaStreamCreate(&stream));
...@@ -543,10 +536,10 @@ TEST(TestOpr, ReduceSum) { ...@@ -543,10 +536,10 @@ TEST(TestOpr, ReduceSum) {
auto reduce_func = [nranks, root, len](MegRay::ReduceOp op) { auto reduce_func = [nranks, root, len](MegRay::ReduceOp op) {
auto run = [nranks, root, len, auto run = [nranks, root, len,
op](std::shared_ptr<MegRay::Communicator> comm, op](std::shared_ptr<MegRay::Communicator> comm,
std::vector<std::string>& uids, int rank, int port, int rank,
std::vector<float>& input, std::vector<float>& output) { std::vector<float>& input, std::vector<float>& output) {
CUDA_ASSERT(cudaSetDevice(rank)); CUDA_ASSERT(cudaSetDevice(rank));
comm->init(uids); comm->init("localhost", port);
cudaStream_t stream; cudaStream_t stream;
CUDA_ASSERT(cudaStreamCreate(&stream)); CUDA_ASSERT(cudaStreamCreate(&stream));
......
/**
* \file test/test_server_client.cpp
* MegRay is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include <thread>
#include <vector>
#include <gtest/gtest.h>
#include "../src/server.h"
#include "../src/client.h"
TEST(TestServerClient, GetHostIP) {
char* ip = MegRay::get_host_ip();
ASSERT_TRUE(ip != NULL);
ASSERT_TRUE(strlen(ip) >= 8);
}
TEST(TestServerClient, GetFreePort) {
int port = MegRay::get_free_port();
ASSERT_TRUE(port > 0);
}
TEST(TestServerClient, Connect) {
const int nranks = 3;
const int port = MegRay::get_free_port();
auto ret = MegRay::create_server(nranks, port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
auto run = [nranks, port](int rank) {
auto client = std::make_unique<MegRay::Client>(nranks, rank);
auto ret = client->connect("localhost", port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
};
std::vector<std::thread> threads;
for (size_t i = 0; i < nranks; i++) {
threads.push_back(std::thread(run, i));
}
for (size_t i = 0; i < nranks; i++) {
threads[i].join();
}
}
TEST(TestServerClient, Barrier) {
const int nranks = 3;
const int port = MegRay::get_free_port();
auto ret = MegRay::create_server(nranks, port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
int counter = 0;
auto run = [nranks, port, &counter](int rank) {
auto client = std::make_unique<MegRay::Client>(nranks, rank);
auto ret = client->connect("localhost", port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
ret = client->barrier();
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
sleep(rank);
++counter;
ret = client->barrier();
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
// if the barrier is not working correctly, threads that sleep
// less seconds will arrive here earlier and counter might be
// less than nranks
ASSERT_EQ(nranks, counter);
};
std::vector<std::thread> threads;
for (size_t i = 0; i < nranks; i++) {
threads.push_back(std::thread(run, i));
}
for (size_t i = 0; i < nranks; i++) {
threads[i].join();
}
}
TEST(TestServerClient, Broadcast) {
const int nranks = 3;
const int root = 1;
const int chunk_size = 10;
const int port = MegRay::get_free_port();
auto ret = MegRay::create_server(nranks, port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
std::string str(chunk_size * nranks, '\0');
for (size_t i = 0; i < str.size(); i++) {
str[i] = 'a' + i % 26;
}
auto expected = str.substr(root * chunk_size, chunk_size);
auto run = [nranks, port, &str, &expected](int rank) {
auto client = std::make_unique<MegRay::Client>(nranks, rank);
auto ret = client->connect("localhost", port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
const char* input = str.data() + rank * chunk_size;
char* output = (char*)malloc(chunk_size);
ret = client->broadcast(input, output, chunk_size, root);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
ASSERT_EQ(expected, std::string(output, chunk_size));
free(output);
};
std::vector<std::thread> threads;
for (size_t i = 0; i < nranks; i++) {
threads.push_back(std::thread(run, i));
}
for (size_t i = 0; i < nranks; i++) {
threads[i].join();
}
}
TEST(TestServerClient, AllGather) {
const int nranks = 3;
const int chunk_size = 10;
const int port = MegRay::get_free_port();
auto ret = MegRay::create_server(nranks, port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
std::string str(chunk_size * nranks, '\0');
for (size_t i = 0; i < str.size(); i++) {
str[i] = 'a' + i % 26;
}
auto run = [nranks, port, &str](int rank) {
auto client = std::make_unique<MegRay::Client>(nranks, rank);
auto ret = client->connect("localhost", port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
const char* input = str.data() + rank * chunk_size;
char* output = (char*)malloc(str.size());
ret = client->allgather(input, output, chunk_size);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
ASSERT_EQ(str, std::string(output, str.size()));
free(output);
};
std::vector<std::thread> threads;
for (size_t i = 0; i < nranks; i++) {
threads.push_back(std::thread(run, i));
}
for (size_t i = 0; i < nranks; i++) {
threads[i].join();
}
}
TEST(TestServerClient, Sequence) {
const int nranks = 3;
const int chunk_size = 10;
const int port = MegRay::get_free_port();
auto ret = MegRay::create_server(nranks, port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
std::string str(chunk_size * nranks, '\0');
for (size_t i = 0; i < str.size(); i++) {
str[i] = 'a' + i % 26;
}
auto run = [nranks, port, &str](int rank) {
auto client = std::make_unique<MegRay::Client>(nranks, rank);
auto ret = client->connect("localhost", port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
const char* input = str.data() + rank * chunk_size;
char* output = (char*)malloc(str.size());
// send a sequence of requets without checking output
ASSERT_EQ(MegRay::MEGRAY_OK, client->barrier());
ASSERT_EQ(MegRay::MEGRAY_OK, client->broadcast(input, output, chunk_size, 1));
ASSERT_EQ(MegRay::MEGRAY_OK, client->allgather(input, output, chunk_size));
ASSERT_EQ(MegRay::MEGRAY_OK, client->barrier());
ASSERT_EQ(MegRay::MEGRAY_OK, client->allgather(input, output, chunk_size));
ASSERT_EQ(MegRay::MEGRAY_OK, client->broadcast(input, output, chunk_size, 2));
ASSERT_EQ(MegRay::MEGRAY_OK, client->allgather(input, output, chunk_size));
ASSERT_EQ(MegRay::MEGRAY_OK, client->barrier());
free(output);
};
std::vector<std::thread> threads;
for (size_t i = 0; i < nranks; i++) {
threads.push_back(std::thread(run, i));
}
for (size_t i = 0; i < nranks; i++) {
threads[i].join();
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册