提交 6f3c7145 编写于 作者: M Megvii Engine Team 提交者: liuqingyi

feat(server): add megray server client

GitOrigin-RevId: 4a4cfe708412c01dd3b5acd620288ed208f5a0d0
上级 e435ee1b
/**
* \file src/client.cpp
* MegRay is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "client.h"
#include <arpa/inet.h>
#include <netinet/in.h>
#include <string.h>
#include <sys/socket.h>
namespace MegRay {
Client::Client(uint32_t nranks, uint32_t rank) :
m_nranks(nranks), m_rank(rank), m_connected(false) {
}
Client::~Client() {
}
Status Client::connect(const char* master_ip, int port) {
std::unique_lock<std::mutex> lock(m_mutex);
if (m_connected) {
MEGRAY_ERROR("Client already connected");
return MEGRAY_INVALID_USAGE;
}
// create socket
SYS_CHECK_RET(socket(AF_INET, SOCK_STREAM, 0), -1, m_conn);
// set server_addr
struct sockaddr_in server_addr;
memset(&server_addr, 0, sizeof(server_addr));
server_addr.sin_family = AF_INET;
server_addr.sin_port = htons(port);
SYS_CHECK(inet_pton(AF_INET, master_ip, &server_addr.sin_addr), -1);
// connect
SYS_CHECK(::connect(m_conn, (struct sockaddr*)&server_addr, sizeof(server_addr)), -1);
// send client rank
SYS_CHECK(send(m_conn, &m_rank, sizeof(uint32_t), 0), -1);
// recv ack from server
uint32_t ack;
SYS_CHECK(recv(m_conn, &ack, sizeof(uint32_t), MSG_WAITALL), -1);
m_connected = true;
return MEGRAY_OK;
}
Status Client::barrier() {
std::unique_lock<std::mutex> lock(m_mutex);
if (!m_connected) {
MEGRAY_ERROR("Client not connected");
return MEGRAY_INVALID_USAGE;
}
// send request_id
uint32_t request_id = 1;
SYS_CHECK(send(m_conn, &request_id, sizeof(uint32_t), 0), -1);
// recv ack
uint32_t ack;
SYS_CHECK(recv(m_conn, &ack, sizeof(uint32_t), MSG_WAITALL), -1);
return MEGRAY_OK;
}
Status Client::broadcast(const void* sendbuff, void* recvbuff, size_t len, uint32_t root) {
std::unique_lock<std::mutex> lock(m_mutex);
if (!m_connected) {
MEGRAY_ERROR("Client not connected");
return MEGRAY_INVALID_USAGE;
}
// send request_id
uint32_t request_id = 2;
SYS_CHECK(send(m_conn, &request_id, sizeof(uint32_t), 0), -1);
// send root
SYS_CHECK(send(m_conn, &root, sizeof(uint32_t), 0), -1);
// send len
uint64_t len64 = len;
SYS_CHECK(send(m_conn, &len64, sizeof(uint64_t), 0), -1);
// send data
if (m_rank == root) {
SYS_CHECK(send(m_conn, sendbuff, len, 0), -1);
}
// recv data
SYS_CHECK(recv(m_conn, recvbuff, len, MSG_WAITALL), -1);
return MEGRAY_OK;
}
Status Client::allgather(const void* sendbuff, void* recvbuff, size_t sendlen) {
std::unique_lock<std::mutex> lock(m_mutex);
if (!m_connected) {
MEGRAY_ERROR("Client not connected");
return MEGRAY_INVALID_USAGE;
}
// send request_id
uint32_t request_id = 3;
SYS_CHECK(send(m_conn, &request_id, sizeof(uint32_t), 0), -1);
// send sendlen
uint64_t sendlen64 = sendlen;
SYS_CHECK(send(m_conn, &sendlen64, sizeof(uint64_t), 0), -1);
// send data
SYS_CHECK(send(m_conn, sendbuff, sendlen, 0), -1);
// recv data
SYS_CHECK(recv(m_conn, recvbuff, sendlen * m_nranks, MSG_WAITALL), -1);
return MEGRAY_OK;
}
} // namespace MegRay
/**
* \file src/client.h
* MegRay is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <mutex>
#include "common.h"
namespace MegRay {
/*!
* synchronize meta information with megray server
*/
class Client {
public:
Client(uint32_t nranks, uint32_t rank);
~Client();
Status connect(const char* master_ip, int port);
// block until all ranks reach this barrier
Status barrier();
// the length of sendbuff = the length of recvbuff = len
Status broadcast(const void* sendbuff, void* recvbuff, size_t sendlen, uint32_t root);
// the length of sendbuff = sendlen
// the length of recvbuff = sendlen * m_nranks
Status allgather(const void* sendbuff, void* recvbuff, size_t sendlen);
private:
uint32_t m_nranks;
uint32_t m_rank;
bool m_connected = false;
int m_conn;
std::mutex m_mutex;
};
} // namespace MegRay
......@@ -11,6 +11,8 @@
#pragma once
#include <errno.h>
#include "cuda_runtime.h"
#include "debug.h"
......@@ -19,12 +21,15 @@ namespace MegRay {
typedef enum {
MEGRAY_OK = 0,
MEGRAY_CUDA_ERR = 1,
MEGRAY_NCCL_ERR = 2,
MEGRAY_UCX_ERR = 3,
MEGRAY_ENV_ERROR = 4,
MEGRAY_INVALID_ARGUMENT = 5,
MEGRAY_NOT_IMPLEMENTED = 6
MEGRAY_SYS_ERROR = 1,
MEGRAY_CUDA_ERR = 2,
MEGRAY_NCCL_ERR = 3,
MEGRAY_UCX_ERR = 4,
MEGRAY_ENV_ERROR = 5,
MEGRAY_INVALID_ARGUMENT = 6,
MEGRAY_INVALID_USAGE = 7,
MEGRAY_UNEXPECTED_ERR = 8,
MEGRAY_NOT_IMPLEMENTED = 9
} Status;
#define MEGRAY_CHECK(expr) \
......@@ -36,6 +41,38 @@ typedef enum {
} \
} while (0)
#define SYS_CHECK_RET(expr, errval, retval) \
do { \
retval = (expr); \
if (retval == errval) { \
MEGRAY_ERROR("system error [%d]: %s", \
errno, strerror(errno)); \
return MEGRAY_SYS_ERROR; \
} \
} while (0)
#define SYS_CHECK(expr, errval) \
do { \
int retval; \
SYS_CHECK_RET(expr, errval, retval); \
} while (0)
#define SYS_ASSERT_RET(expr, errval, retval) \
do { \
retval = (expr); \
if (retval == errval) { \
MEGRAY_ERROR("system error [%d]: %s", \
errno, strerror(errno)); \
MEGRAY_THROW("system error"); \
} \
} while (0)
#define SYS_ASSERT(expr, errval) \
do { \
int retval; \
SYS_ASSERT_RET(expr, errval, retval); \
} while (0)
#define CUDA_CHECK(expr) \
do { \
cudaError_t status = (expr); \
......@@ -58,7 +95,7 @@ typedef enum {
typedef enum {
MEGRAY_NCCL = 0,
MEGRAY_UCX = 1,
MEGRAY_UCX = 1
} Backend;
typedef enum {
......
......@@ -15,6 +15,12 @@
namespace MegRay {
Status Communicator::init(const char* master_ip, int port) {
m_client = std::make_shared<Client>(m_nranks, m_rank);
MEGRAY_CHECK(m_client->connect(master_ip, port));
return do_init();
}
std::shared_ptr<Communicator> get_communicator(uint32_t nranks, uint32_t rank, Backend backend) {
std::shared_ptr<Communicator> comm;
switch (backend) {
......
......@@ -17,6 +17,7 @@
#include "common.h"
#include "context.h"
#include "client.h"
namespace MegRay {
......@@ -37,11 +38,11 @@ class Communicator {
// get the rank of this process
uint32_t rank() { return m_rank; }
// get the unique id of this communicator
virtual std::string get_uid() = 0;
// establish connection with megray server
Status init(const char* master_ip, int port);
// build a group with unique ids of all communicators in the group
virtual Status init(const std::vector<std::string>& uids) = 0;
// implemented in the subclass and called in init()
virtual Status do_init() = 0;
// send data to another communicator in the group
virtual Status send(const void* sendbuff, size_t len, uint32_t rank,
......@@ -90,6 +91,7 @@ class Communicator {
protected:
uint32_t m_nranks;
uint32_t m_rank;
std::shared_ptr<Client> m_client;
};
/*!
......
......@@ -11,4 +11,5 @@
#pragma once
#include "server.h"
#include "communicator.h"
......@@ -28,7 +28,6 @@ namespace MegRay {
NcclCommunicator::NcclCommunicator(int nranks, int rank) :
Communicator(nranks, rank), m_inited(false) {
NCCL_ASSERT(ncclGetUniqueId(&m_uid));
}
NcclCommunicator::~NcclCommunicator() {
......@@ -37,19 +36,14 @@ NcclCommunicator::~NcclCommunicator() {
}
}
std::string NcclCommunicator::get_uid() {
// serialize ncclUniqueId into a string
return std::string(m_uid.internal, NCCL_UNIQUE_ID_BYTES);
}
Status NcclCommunicator::init(const std::vector<std::string>& uids) {
MEGRAY_ASSERT(uids.size() == m_nranks, "incorrect size of uids");
// only use unique id of rank 0 for initialization
const std::string uid = uids[0];
MEGRAY_ASSERT(uid.size() == NCCL_UNIQUE_ID_BYTES, "invalid uid");
memcpy(m_uid.internal, uid.data(), NCCL_UNIQUE_ID_BYTES);
// initialize nccl communicator
NCCL_CHECK(ncclCommInitRank(&m_comm, m_nranks, m_uid, m_rank));
Status NcclCommunicator::do_init() {
uint32_t root = 0;
ncclUniqueId uid;
if (m_rank == root) {
ncclGetUniqueId(&uid);
}
MEGRAY_CHECK(m_client->broadcast(&uid, &uid, NCCL_UNIQUE_ID_BYTES, root));
NCCL_CHECK(ncclCommInitRank(&m_comm, m_nranks, uid, m_rank));
m_inited = true;
return MEGRAY_OK;
}
......
......@@ -29,10 +29,7 @@ class NcclCommunicator : public Communicator {
~NcclCommunicator();
// get a serialized string of ncclUniqueId
std::string get_uid() override;
Status init(const std::vector<std::string>& uids) override;
Status do_init() override;
Status send(const void* sendbuff, size_t len, uint32_t rank,
std::shared_ptr<Context> ctx) override;
......@@ -65,7 +62,6 @@ class NcclCommunicator : public Communicator {
DType dtype, ReduceOp op, uint32_t root, std::shared_ptr<Context> ctx) override;
private:
ncclUniqueId m_uid;
ncclComm_t m_comm;
bool m_inited;
};
......
/**
* \file src/server.cpp
* MegRay is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "server.h"
#include <arpa/inet.h>
#include <ifaddrs.h>
#include <netinet/in.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>
#include <thread>
namespace MegRay {
/************************ get_host_ip ************************/
char* get_host_ip() {
const char* device = getenv("MEGRAY_NET_DEVICE");
if (device and strcmp(device, "lo") == 0) {
MEGRAY_ERROR("illegal net device: lo\n");
MEGRAY_THROW("invalid argument");
}
struct ifaddrs *ifa;
SYS_ASSERT(getifaddrs(&ifa), -1);
for (struct ifaddrs* p = ifa; p != NULL; p = p->ifa_next) {
if (p->ifa_addr and p->ifa_addr->sa_family == AF_INET and p->ifa_name) {
const char* name = p->ifa_name;
if (strcmp(name, "lo") != 0 and
(device == NULL or strcmp(name, device) == 0)) {
struct sockaddr_in* sin = (struct sockaddr_in*)p->ifa_addr;
const char* host_ip = inet_ntoa(sin->sin_addr);
MEGRAY_INFO("using net device %s (%s)", name, host_ip);
char* ret = new char(strlen(host_ip) + 1);
strcpy(ret, host_ip);
freeifaddrs(ifa);
return ret;
}
}
}
if (device) {
MEGRAY_ERROR("failed to get host ip for device %s", device);
} else {
MEGRAY_ERROR("failed to get host ip");
}
MEGRAY_THROW("system error");
return nullptr;
}
/************************ get_free_port ************************/
int get_free_port() {
// create socket
int sock;
SYS_ASSERT_RET(socket(AF_INET, SOCK_STREAM, 0), -1, sock);
// set address
struct sockaddr_in addr;
memset(&addr, 0, sizeof(addr));
addr.sin_family = AF_INET;
addr.sin_addr.s_addr = htonl(INADDR_ANY);
addr.sin_port = htons(0);
// bind
SYS_ASSERT(bind(sock, (struct sockaddr*)&addr, sizeof(addr)), -1);
// get port
socklen_t len = sizeof(addr);
SYS_ASSERT(getsockname(sock, (struct sockaddr*)&addr, &len), -1);
int port = ntohs(addr.sin_port);
// close
SYS_ASSERT(close(sock), -1);
return port;
}
/************************ create_server ************************/
void serve_barrier(uint32_t nranks, int* conns);
void serve_broadcast(uint32_t nranks, int* conns);
void serve_allgather(uint32_t nranks, int* conns);
void server_thread(int listenfd, uint32_t nranks) {
int conns[nranks];
for (uint32_t i = 0; i < nranks; i++) {
// establish connection
int conn;
SYS_ASSERT_RET(accept(listenfd, (struct sockaddr*)NULL, NULL), -1, conn);
// recv rank and save into conns
uint32_t rank;
SYS_ASSERT(recv(conn, &rank, sizeof(uint32_t), MSG_WAITALL), -1);
conns[rank] = conn;
}
// send ack to clients
uint32_t ack = 0;
for (uint32_t i = 0; i < nranks; i++) {
SYS_ASSERT(send(conns[i], &ack, sizeof(uint32_t), 0), -1);
}
while (true) {
// receive a request from rank 0
uint32_t request_id;
SYS_ASSERT(recv(conns[0], &request_id, sizeof(uint32_t), MSG_WAITALL), -1);
if (request_id == 1) {
serve_barrier(nranks, conns);
} else if (request_id == 2) {
serve_broadcast(nranks, conns);
} else if (request_id == 3) {
serve_allgather(nranks, conns);
} else {
MEGRAY_ERROR("unexpected request id: %d", request_id);
MEGRAY_THROW("unexpected error");
}
}
}
Status create_server(uint32_t nranks, int port) {
// create socket
int listenfd;
SYS_CHECK_RET(socket(AF_INET, SOCK_STREAM, 0), -1, listenfd);
// set server_addr
struct sockaddr_in server_addr;
memset(&server_addr, 0, sizeof(server_addr));
server_addr.sin_family = AF_INET;
server_addr.sin_addr.s_addr = htonl(INADDR_ANY);
server_addr.sin_port = htons(port);
// bind and listen
int opt = 1;
SYS_CHECK(setsockopt(listenfd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(int)), -1);
SYS_CHECK(bind(listenfd, (struct sockaddr*)&server_addr, sizeof(server_addr)), -1);
SYS_CHECK(listen(listenfd, nranks), -1);
// start server thread
std::thread th(server_thread, listenfd, nranks);
th.detach();
return MEGRAY_OK;
}
/************************ barrier ************************/
void serve_barrier(uint32_t nranks, int* conns) {
uint32_t request_id;
// recv other requests
for (uint32_t rank = 1; rank < nranks; rank++) {
SYS_ASSERT(recv(conns[rank], &request_id, sizeof(uint32_t), MSG_WAITALL), -1);
MEGRAY_ASSERT(request_id == 1, "inconsistent request_id from rank %d", rank);
}
// send ack
uint32_t ack = 0;
for (uint32_t rank = 0; rank < nranks; rank++) {
SYS_ASSERT(send(conns[rank], &ack, sizeof(uint32_t), 0), -1);
}
}
/************************ broadcast ************************/
void serve_broadcast(uint32_t nranks, int* conns) {
uint32_t request_id, root, root0;
uint64_t len, len0;
// recv request 0
SYS_ASSERT(recv(conns[0], &root0, sizeof(uint32_t), MSG_WAITALL), -1);
SYS_ASSERT(recv(conns[0], &len0, sizeof(uint64_t), MSG_WAITALL), -1);
// recv other requests
for (uint32_t rank = 1; rank < nranks; rank++) {
SYS_ASSERT(recv(conns[rank], &request_id, sizeof(uint32_t), MSG_WAITALL), -1);
MEGRAY_ASSERT(request_id == 2, "inconsistent request_id from rank %d", rank);
SYS_ASSERT(recv(conns[rank], &root, sizeof(uint32_t), MSG_WAITALL), -1);
MEGRAY_ASSERT(root == root0, "inconsistent root from rank %d", rank);
SYS_ASSERT(recv(conns[rank], &len, sizeof(uint64_t), MSG_WAITALL), -1);
MEGRAY_ASSERT(len == len0, "inconsistent len from rank %d", rank);
}
// recv data from root
void* data = malloc(len);
SYS_ASSERT(recv(conns[root], data, len, MSG_WAITALL), -1);
// send data to clients
for (uint32_t rank = 0; rank < nranks; rank++) {
SYS_ASSERT(send(conns[rank], data, len, 0), -1);
}
free(data);
}
/************************ allgather ************************/
void serve_allgather(uint32_t nranks, int* conns) {
uint32_t request_id;
uint64_t len, len0;
// recv request 0
SYS_ASSERT(recv(conns[0], &len0, sizeof(uint64_t), MSG_WAITALL), -1);
// recv other requests
for (uint32_t rank = 1; rank < nranks; rank++) {
SYS_ASSERT(recv(conns[rank], &request_id, sizeof(uint32_t), MSG_WAITALL), -1);
MEGRAY_ASSERT(request_id == 3, "inconsistent request_id from rank %d", rank);
SYS_ASSERT(recv(conns[rank], &len, sizeof(uint64_t), MSG_WAITALL), -1);
MEGRAY_ASSERT(len == len0, "inconsistent len from rank %d", rank);
}
// recv data
void* data = malloc(len * nranks);
for (uint32_t rank = 0; rank < nranks; rank++) {
char* ptr = (char*)data + rank * len;
SYS_ASSERT(recv(conns[rank], ptr, len, MSG_WAITALL), -1);
}
// send data to clients
for (uint32_t rank = 0; rank < nranks; rank++) {
SYS_ASSERT(send(conns[rank], data, len * nranks, 0), -1);
}
free(data);
}
} // namespace MegRay
/**
* \file src/server.h
* MegRay is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <mutex>
#include "common.h"
namespace MegRay {
char* get_host_ip();
int get_free_port();
// create megray server
Status create_server(uint32_t nranks, int port);
} // namespace MegRay
......@@ -73,32 +73,38 @@ UcxCommunicator::~UcxCommunicator() {
ucp_cleanup(m_context);
}
std::string UcxCommunicator:: get_uid() {
size_t addr_len;
ucp_address_t* addr;
Status UcxCommunicator::do_init() {
// get ucp worker address
size_t addr_len, addr_lens[m_nranks];
ucp_address_t* addr;
ucs_status_t status = ucp_worker_get_address(m_worker, &addr, &addr_len);
MEGRAY_ASSERT(status == UCS_OK, "failed to get ucp worker address");
// copy bytes to a string
std::string uid((char*)addr, addr_len);
ucp_worker_release_address(m_worker, addr);
return uid;
}
Status UcxCommunicator::init(const std::vector<std::string>& uids) {
MEGRAY_ASSERT(uids.size() == m_nranks, "incorrect size of uids");
m_eps.resize(m_nranks);
// allgather addr_len
MEGRAY_CHECK(m_client->allgather(&addr_len, addr_lens, sizeof(size_t)));
// find max addr_len
size_t max_len = 0;
for (size_t i = 0; i < m_nranks; i++) {
if (addr_lens[i] > max_len) {
max_len = addr_lens[i];
}
}
// allgather addr
char addrs[max_len * m_nranks];
MEGRAY_CHECK(m_client->allgather(addr, addrs, max_len));
ucp_worker_release_address(m_worker, addr);
// set endpoint params
ucp_ep_params_t ep_params;
ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
ucs_status_t status;
// create ucp endpoint
m_eps.resize(m_nranks);
for (size_t i = 0; i < m_nranks; i++) {
if (i == m_rank) continue;
// set endpoint address
ep_params.address = reinterpret_cast<const ucp_address_t*>(uids[i].data());
// create ucp endpoint
ep_params.address = reinterpret_cast<const ucp_address_t*>(addrs + i * max_len);
status = ucp_ep_create(m_worker, &ep_params, &m_eps[i]);
MEGRAY_ASSERT(status == UCS_OK, "failed to create ucp endpoint");
}
......
......@@ -30,10 +30,7 @@ class UcxCommunicator : public Communicator {
~UcxCommunicator();
// get a serialized string of ucp worker address
std::string get_uid() override;
Status init(const std::vector<std::string>& uids) override;
Status do_init() override;
Status send(const void* sendbuff, size_t len, uint32_t rank,
std::shared_ptr<Context> ctx) override;
......
......@@ -22,23 +22,24 @@ template <typename T>
void run_test(int nranks, MegRay::Backend backend,
std::vector<std::vector<T>>& inputs,
std::vector<std::vector<T>>& expect_outputs,
std::function<void(std::shared_ptr<MegRay::Communicator>,
std::vector<std::string>&, int,
std::function<void(std::shared_ptr<MegRay::Communicator>, int, int,
std::vector<T>&, std::vector<T>&)>
main_func) {
std::vector<std::shared_ptr<MegRay::Communicator>> comms(nranks);
std::vector<std::string> uids(nranks);
std::vector<std::vector<T>> outputs(nranks);
int port = MegRay::get_free_port();
auto ret = MegRay::create_server(nranks, port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
for (int i = 0; i < nranks; i++) {
comms[i] = MegRay::get_communicator(nranks, i, backend);
uids[i] = comms[i]->get_uid();
outputs[i].resize(expect_outputs[i].size());
}
std::vector<std::thread> threads;
for (int i = 0; i < nranks; i++) {
threads.push_back(std::thread(main_func, comms[i], std::ref(uids), i,
threads.push_back(std::thread(main_func, comms[i], port, i,
std::ref(inputs[i]),
std::ref(outputs[i])));
}
......
......@@ -18,22 +18,18 @@
#include <gtest/gtest.h>
#include "../src/megray.h"
#include "test_base.h"
TEST(TestNcclCommunicator, Init) {
const int nranks = 3;
std::vector<std::shared_ptr<MegRay::Communicator>> comms(nranks);
std::vector<std::string> uids(nranks);
for (size_t i = 0; i < nranks; i++) {
comms[i] = MegRay::get_communicator(nranks, i, MegRay::MEGRAY_NCCL);
uids[i] = comms[i]->get_uid();
}
const int port = MegRay::get_free_port();
auto ret = MegRay::create_server(nranks, port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
auto run = [&](int rank) {
cudaSetDevice(rank);
comms[rank]->init(uids);
auto comm = MegRay::get_communicator(nranks, rank, MegRay::MEGRAY_NCCL);
ASSERT_EQ(MegRay::MEGRAY_OK, comm->init("localhost", port));
};
std::vector<std::thread> threads;
......@@ -48,17 +44,14 @@ TEST(TestNcclCommunicator, Init) {
TEST(TestUcxCommunicator, Init) {
const int nranks = 3;
std::vector<std::shared_ptr<MegRay::Communicator>> comms(nranks);
std::vector<std::string> uids(nranks);
for (int i = 0; i < nranks; i++) {
comms[i] = MegRay::get_communicator(nranks, i, MegRay::MEGRAY_UCX);
uids[i] = comms[i]->get_uid();
}
const int port = MegRay::get_free_port();
auto ret = MegRay::create_server(nranks, port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
auto run = [&](int rank) {
cudaSetDevice(rank);
comms[rank]->init(uids);
auto comm = MegRay::get_communicator(nranks, rank, MegRay::MEGRAY_UCX);
ASSERT_EQ(MegRay::MEGRAY_OK, comm->init("localhost", port));
};
std::vector<std::thread> threads;
......@@ -85,11 +78,11 @@ TEST(TestOpr, SendRecv) {
}
auto run = [len](std::shared_ptr<MegRay::Communicator> comm,
std::vector<std::string>& uids, int rank,
int port, int rank,
std::vector<char>& input,
std::vector<char>& output) -> void {
CUDA_ASSERT(cudaSetDevice(rank));
comm->init(uids);
comm->init("localhost", port);
cudaStream_t stream;
CUDA_ASSERT(cudaStreamCreate(&stream));
......@@ -129,11 +122,11 @@ TEST(TestOpr, Scatter) {
}
auto run = [nranks, recvlen, root](std::shared_ptr<MegRay::Communicator> comm,
std::vector<std::string>& uids, int rank,
int port, int rank,
std::vector<float>& input,
std::vector<float>& output) -> void {
CUDA_ASSERT(cudaSetDevice(rank));
comm->init(uids);
comm->init("localhost", port);
cudaStream_t stream;
CUDA_ASSERT(cudaStreamCreate(&stream));
......@@ -180,11 +173,11 @@ TEST(TestOpr, Gather) {
}
auto run = [nranks, sendlen, root](std::shared_ptr<MegRay::Communicator> comm,
std::vector<std::string>& uids, int rank,
int port, int rank,
std::vector<float>& input,
std::vector<float>& output) -> void {
CUDA_ASSERT(cudaSetDevice(rank));
comm->init(uids);
comm->init("localhost", port);
cudaStream_t stream;
CUDA_ASSERT(cudaStreamCreate(&stream));
......@@ -235,11 +228,11 @@ TEST(TestOpr, AllToAll) {
}
auto run = [nranks, len](std::shared_ptr<MegRay::Communicator> comm,
std::vector<std::string>& uids, int rank,
int port, int rank,
std::vector<float>& input,
std::vector<float>& output) -> void {
CUDA_ASSERT(cudaSetDevice(rank));
comm->init(uids);
comm->init("localhost", port);
cudaStream_t stream;
CUDA_ASSERT(cudaStreamCreate(&stream));
......@@ -283,11 +276,11 @@ TEST(TestOpr, AllGather) {
}
auto run = [nranks, sendlen](std::shared_ptr<MegRay::Communicator> comm,
std::vector<std::string>& uids, int rank,
int port, int rank,
std::vector<float>& input,
std::vector<float>& output) -> void {
CUDA_ASSERT(cudaSetDevice(rank));
comm->init(uids);
comm->init("localhost", port);
cudaStream_t stream;
CUDA_ASSERT(cudaStreamCreate(&stream));
......@@ -322,11 +315,11 @@ TEST(TestOpr, AllReduce) {
auto reduce_func = [nranks, len](MegRay::ReduceOp op) {
auto run = [nranks, len, op](std::shared_ptr<MegRay::Communicator> comm,
std::vector<std::string>& uids, int rank,
int port, int rank,
std::vector<float>& input,
std::vector<float>& output) {
CUDA_ASSERT(cudaSetDevice(rank));
comm->init(uids);
comm->init("localhost", port);
cudaStream_t stream;
CUDA_ASSERT(cudaStreamCreate(&stream));
......@@ -407,10 +400,10 @@ TEST(TestOpr, ReduceScatterSum) {
auto reduce_func = [nranks, recvlen](MegRay::ReduceOp op) {
auto run = [nranks, recvlen,
op](std::shared_ptr<MegRay::Communicator> comm,
std::vector<std::string>& uids, int rank,
int port, int rank,
std::vector<float>& input, std::vector<float>& output) {
CUDA_ASSERT(cudaSetDevice(rank));
comm->init(uids);
comm->init("localhost", port);
cudaStream_t stream;
CUDA_ASSERT(cudaStreamCreate(&stream));
......@@ -501,11 +494,11 @@ TEST(TestOpr, Broadcast) {
}
auto run = [nranks, root, len](std::shared_ptr<MegRay::Communicator> comm,
std::vector<std::string>& uids, int rank,
int port, int rank,
std::vector<float>& input,
std::vector<float>& output) {
CUDA_ASSERT(cudaSetDevice(rank));
comm->init(uids);
comm->init("localhost", port);
cudaStream_t stream;
CUDA_ASSERT(cudaStreamCreate(&stream));
......@@ -543,10 +536,10 @@ TEST(TestOpr, ReduceSum) {
auto reduce_func = [nranks, root, len](MegRay::ReduceOp op) {
auto run = [nranks, root, len,
op](std::shared_ptr<MegRay::Communicator> comm,
std::vector<std::string>& uids, int rank,
int port, int rank,
std::vector<float>& input, std::vector<float>& output) {
CUDA_ASSERT(cudaSetDevice(rank));
comm->init(uids);
comm->init("localhost", port);
cudaStream_t stream;
CUDA_ASSERT(cudaStreamCreate(&stream));
......
/**
* \file test/test_server_client.cpp
* MegRay is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include <thread>
#include <vector>
#include <gtest/gtest.h>
#include "../src/server.h"
#include "../src/client.h"
TEST(TestServerClient, GetHostIP) {
char* ip = MegRay::get_host_ip();
ASSERT_TRUE(ip != NULL);
ASSERT_TRUE(strlen(ip) >= 8);
}
TEST(TestServerClient, GetFreePort) {
int port = MegRay::get_free_port();
ASSERT_TRUE(port > 0);
}
TEST(TestServerClient, Connect) {
const int nranks = 3;
const int port = MegRay::get_free_port();
auto ret = MegRay::create_server(nranks, port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
auto run = [nranks, port](int rank) {
auto client = std::make_unique<MegRay::Client>(nranks, rank);
auto ret = client->connect("localhost", port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
};
std::vector<std::thread> threads;
for (size_t i = 0; i < nranks; i++) {
threads.push_back(std::thread(run, i));
}
for (size_t i = 0; i < nranks; i++) {
threads[i].join();
}
}
TEST(TestServerClient, Barrier) {
const int nranks = 3;
const int port = MegRay::get_free_port();
auto ret = MegRay::create_server(nranks, port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
int counter = 0;
auto run = [nranks, port, &counter](int rank) {
auto client = std::make_unique<MegRay::Client>(nranks, rank);
auto ret = client->connect("localhost", port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
ret = client->barrier();
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
sleep(rank);
++counter;
ret = client->barrier();
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
// if the barrier is not working correctly, threads that sleep
// less seconds will arrive here earlier and counter might be
// less than nranks
ASSERT_EQ(nranks, counter);
};
std::vector<std::thread> threads;
for (size_t i = 0; i < nranks; i++) {
threads.push_back(std::thread(run, i));
}
for (size_t i = 0; i < nranks; i++) {
threads[i].join();
}
}
TEST(TestServerClient, Broadcast) {
const int nranks = 3;
const int root = 1;
const int chunk_size = 10;
const int port = MegRay::get_free_port();
auto ret = MegRay::create_server(nranks, port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
std::string str(chunk_size * nranks, '\0');
for (size_t i = 0; i < str.size(); i++) {
str[i] = 'a' + i % 26;
}
auto expected = str.substr(root * chunk_size, chunk_size);
auto run = [nranks, port, &str, &expected](int rank) {
auto client = std::make_unique<MegRay::Client>(nranks, rank);
auto ret = client->connect("localhost", port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
const char* input = str.data() + rank * chunk_size;
char* output = (char*)malloc(chunk_size);
ret = client->broadcast(input, output, chunk_size, root);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
ASSERT_EQ(expected, std::string(output, chunk_size));
free(output);
};
std::vector<std::thread> threads;
for (size_t i = 0; i < nranks; i++) {
threads.push_back(std::thread(run, i));
}
for (size_t i = 0; i < nranks; i++) {
threads[i].join();
}
}
TEST(TestServerClient, AllGather) {
const int nranks = 3;
const int chunk_size = 10;
const int port = MegRay::get_free_port();
auto ret = MegRay::create_server(nranks, port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
std::string str(chunk_size * nranks, '\0');
for (size_t i = 0; i < str.size(); i++) {
str[i] = 'a' + i % 26;
}
auto run = [nranks, port, &str](int rank) {
auto client = std::make_unique<MegRay::Client>(nranks, rank);
auto ret = client->connect("localhost", port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
const char* input = str.data() + rank * chunk_size;
char* output = (char*)malloc(str.size());
ret = client->allgather(input, output, chunk_size);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
ASSERT_EQ(str, std::string(output, str.size()));
free(output);
};
std::vector<std::thread> threads;
for (size_t i = 0; i < nranks; i++) {
threads.push_back(std::thread(run, i));
}
for (size_t i = 0; i < nranks; i++) {
threads[i].join();
}
}
TEST(TestServerClient, Sequence) {
const int nranks = 3;
const int chunk_size = 10;
const int port = MegRay::get_free_port();
auto ret = MegRay::create_server(nranks, port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
std::string str(chunk_size * nranks, '\0');
for (size_t i = 0; i < str.size(); i++) {
str[i] = 'a' + i % 26;
}
auto run = [nranks, port, &str](int rank) {
auto client = std::make_unique<MegRay::Client>(nranks, rank);
auto ret = client->connect("localhost", port);
ASSERT_EQ(MegRay::MEGRAY_OK, ret);
const char* input = str.data() + rank * chunk_size;
char* output = (char*)malloc(str.size());
// send a sequence of requets without checking output
ASSERT_EQ(MegRay::MEGRAY_OK, client->barrier());
ASSERT_EQ(MegRay::MEGRAY_OK, client->broadcast(input, output, chunk_size, 1));
ASSERT_EQ(MegRay::MEGRAY_OK, client->allgather(input, output, chunk_size));
ASSERT_EQ(MegRay::MEGRAY_OK, client->barrier());
ASSERT_EQ(MegRay::MEGRAY_OK, client->allgather(input, output, chunk_size));
ASSERT_EQ(MegRay::MEGRAY_OK, client->broadcast(input, output, chunk_size, 2));
ASSERT_EQ(MegRay::MEGRAY_OK, client->allgather(input, output, chunk_size));
ASSERT_EQ(MegRay::MEGRAY_OK, client->barrier());
free(output);
};
std::vector<std::thread> threads;
for (size_t i = 0; i < nranks; i++) {
threads.push_back(std::thread(run, i));
}
for (size_t i = 0; i < nranks; i++) {
threads[i].join();
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册