/**
 * \file src/ucx/communicator.cpp
 * MegRay is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "communicator.h"

#include <stdlib.h>  // setenv (POSIX)

#include <cstring>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

namespace MegRay {

// returned handler of ucp requests
// flag "completed" is set in callback functions
// field "status" keeps the completion status reported to the callback, so
// that _flush() can report failed requests instead of dropping the error
struct Request {
    int completed;
    ucs_status_t status;
};

// init request handler
static void request_init(void* request) {
    auto* req = static_cast<Request*>(request);
    req->completed = 0;
    req->status = UCS_OK;
}

// send callback, record the status and set flag "completed" to 1
static void send_cb(void* request, ucs_status_t status) {
    auto* req = static_cast<Request*>(request);
    req->status = status;
    req->completed = 1;
}

// receive callback, record the status and set flag "completed" to 1
static void recv_cb(void* request, ucs_status_t status,
                    ucp_tag_recv_info_t* info) {
    auto* req = static_cast<Request*>(request);
    req->status = status;
    req->completed = 1;
}

// create the ucp context and the ucp worker used by this communicator
// endpoints to the other ranks are created later in init()
UcxCommunicator::UcxCommunicator(int nranks, int rank)
        : Communicator(nranks, rank), m_inited(false) {
    // set UCX_MEMTYPE_CACHE=n in the environment before ucp_init() runs,
    // overwriting any existing value (same effect as the former putenv)
    setenv("UCX_MEMTYPE_CACHE", "n", 1);

    // set ucp context params
    ucp_params_t ucp_params;
    std::memset(&ucp_params, 0, sizeof(ucp_params));
    ucp_params.field_mask = UCP_PARAM_FIELD_FEATURES |
                            UCP_PARAM_FIELD_REQUEST_SIZE |
                            UCP_PARAM_FIELD_REQUEST_INIT;
    ucp_params.features =
            UCP_FEATURE_TAG | UCP_FEATURE_RMA | UCP_FEATURE_WAKEUP;
    ucp_params.request_size = sizeof(Request);
    ucp_params.request_init = request_init;

    // init ucp context
    ucs_status_t status = ucp_init(&ucp_params, nullptr, &m_context);
    MEGRAY_ASSERT(status == UCS_OK, "failed to init ucp context");

    // set ucp worker params
    ucp_worker_params_t worker_params;
    std::memset(&worker_params, 0, sizeof(worker_params));
    worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
    worker_params.thread_mode = UCS_THREAD_MODE_SINGLE;

    // create ucp worker
    status = ucp_worker_create(m_context, &worker_params, &m_worker);
    MEGRAY_ASSERT(status == UCS_OK, "failed to create ucp worker");
}

UcxCommunicator::~UcxCommunicator() {
    // destroy ucp worker and cleanup ucp context
    ucp_worker_destroy(m_worker);
    ucp_cleanup(m_context);
}

// return the ucp worker address of this communicator as a byte string
// the result may contain arbitrary bytes, it is not a printable id
std::string UcxCommunicator::get_uid() {
    size_t addr_len;
    ucp_address_t* addr;

    // get ucp worker address
    ucs_status_t status = ucp_worker_get_address(m_worker, &addr, &addr_len);
    MEGRAY_ASSERT(status == UCS_OK, "failed to get ucp worker address");

    // copy bytes to a string
    std::string uid(reinterpret_cast<const char*>(addr), addr_len);
    ucp_worker_release_address(m_worker, addr);
    return uid;
}

// create one ucp endpoint for every other rank
// uids[i] must hold the worker address (see get_uid) of rank i
Status UcxCommunicator::init(const std::vector<std::string>& uids) {
    MEGRAY_ASSERT(uids.size() == m_nranks, "incorrect size of uids");
    m_eps.resize(m_nranks);

    // set endpoint params
    // zero the struct first, consistent with the other ucp param structs
    ucp_ep_params_t ep_params;
    std::memset(&ep_params, 0, sizeof(ep_params));
    ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;

    for (size_t i = 0; i < m_nranks; i++) {
        // no endpoint to ourselves
        if (i == m_rank)
            continue;

        // set endpoint address
        ep_params.address =
                reinterpret_cast<const ucp_address_t*>(uids[i].data());

        // create ucp endpoint
        ucs_status_t status = ucp_ep_create(m_worker, &ep_params, &m_eps[i]);
        MEGRAY_ASSERT(status == UCS_OK, "failed to create ucp endpoint");
    }

    return MEGRAY_OK;
}

// send len bytes to the given rank and wait for completion
// a one-byte message is received back from the peer before returning, so
// the call returns only after the peer has posted its matching recv()
// ctx is not used by this implementation
Status UcxCommunicator::send(const void* sendbuff, size_t len, uint32_t rank,
                             std::shared_ptr<Context> ctx) {
    char sync = 0;
    MEGRAY_CHECK(_send(sendbuff, len, rank));
    MEGRAY_CHECK(_recv(&sync, sizeof(char), rank));  // synchronize
    MEGRAY_CHECK(_flush());
    return MEGRAY_OK;
}

// receive len bytes from the given rank and wait for completion
// a one-byte message is sent back to the peer to pair with its send()
// ctx is not used by this implementation
Status UcxCommunicator::recv(void* recvbuff, size_t len, uint32_t rank,
                             std::shared_ptr<Context> ctx) {
    // value of the sync byte is irrelevant, it is initialized only so that
    // no indeterminate data is transmitted
    char sync = 0;
    MEGRAY_CHECK(_recv(recvbuff, len, rank));
    MEGRAY_CHECK(_send(&sync, sizeof(char), rank));  // synchronize
    MEGRAY_CHECK(_flush());
    return MEGRAY_OK;
}

// submit a non-blocking tagged send, the tag is the rank of the sender
// a pending request is queued in m_requests and completed by _flush()
Status UcxCommunicator::_send(const void* sendbuff, size_t len,
                              uint32_t rank) {
    // rank must be another valid rank, it indexes m_eps below
    MEGRAY_ASSERT(rank != m_rank && rank < m_nranks, "invalid send rank");
    std::lock_guard<decltype(m_requests_mtx)> lock(m_requests_mtx);

    // submit non-blocking send request to ucp
    void* ptr = ucp_tag_send_nb(m_eps[rank], sendbuff, len,
                                ucp_dt_make_contig(1), m_rank, send_cb);

    if (UCS_PTR_IS_PTR(ptr)) {
        // send request is pending
        m_requests.push_back(ptr);
    } else if (UCS_PTR_STATUS(ptr) != UCS_OK) {
        return MEGRAY_UCX_ERR;
    }
    return MEGRAY_OK;
}

// submit a non-blocking tagged receive that matches the rank of the sender
// a pending request is queued in m_requests and completed by _flush()
Status UcxCommunicator::_recv(void* recvbuff, size_t len, uint32_t rank) {
    MEGRAY_ASSERT(rank != m_rank && rank < m_nranks, "invalid recv rank");
    std::lock_guard<decltype(m_requests_mtx)> lock(m_requests_mtx);

    // submit non-blocking receive request to ucp
    // mask 0xffffffff means matching every bit of uint32
    void* ptr = ucp_tag_recv_nb(m_worker, recvbuff, len,
                                ucp_dt_make_contig(1), rank, 0xffffffff,
                                recv_cb);

    if (UCS_PTR_IS_PTR(ptr)) {
        // receive request is pending
        m_requests.push_back(ptr);
    } else if (UCS_PTR_STATUS(ptr) != UCS_OK) {
        return MEGRAY_UCX_ERR;
    }
    return MEGRAY_OK;
}

// progress the worker until every queued request has completed, then
// release all of them
// returns MEGRAY_UCX_ERR if any request completed with an error status
Status UcxCommunicator::_flush() {
    std::lock_guard<decltype(m_requests_mtx)> lock(m_requests_mtx);

    bool failed = false;
    for (size_t i = 0; i < m_requests.size(); i++) {
        Request* req = static_cast<Request*>(m_requests[i]);

        // check flag "completed" of request handler
        while (req->completed == 0) {
            ucp_worker_progress(m_worker);
        }
        if (req->status != UCS_OK) {
            failed = true;
        }

        // reset the handler by hand before releasing it
        // NOTE(review): presumably request_init is not re-run when ucp
        // reuses a released request - confirm against the UCP docs
        req->completed = 0;
        req->status = UCS_OK;
        ucp_request_release(req);
    }

    // every request is released at this point, even if some of them failed
    m_requests.clear();

    if (failed) {
        return MEGRAY_UCX_ERR;
    }
    return MEGRAY_OK;
}

}  // namespace MegRay