提交 d7bb62cf 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

refactor(mgb): move mm_handler from python module into opr-mm

GitOrigin-RevId: f401ce86033da83a91ebea3c119fc7af54a66ba0
上级 84068a6b
......@@ -14,6 +14,7 @@ ExternalProject_add(
)
set(ZMQ_INC ${ZMQ_BUILD_DIR}/include)
include_directories(${ZMQ_INC})
file(MAKE_DIRECTORY ${ZMQ_INC})
add_library(libzmq STATIC IMPORTED GLOBAL)
......
......@@ -12,14 +12,6 @@ set(SWIG_SRC src/swig/mgb.i)
set(CMAKE_SWIG_FLAGS -Wall -threads -py3 -modern -DSWIGWORDSIZE64)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter")
if(MGE_WITH_DISTRIBUTED)
file(GLOB_RECURSE PROTO_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "src/proto/*.proto")
PROTOBUF_GENERATE_CPP_WITH_ROOT(GRPC_SRCS GRPC_HDRS ${CMAKE_CURRENT_SOURCE_DIR} ${PROTO_FILES})
add_custom_target(mgb_proto_target DEPENDS ${GRPC_SRCS} ${GRPC_HDRS} ${PROTOBUF_PROTOC_EXECUTABLE})
endif()
file(GLOB_RECURSE OPR_DECL_SRCS "${PROJECT_SOURCE_DIR}/src/**/*.oprdecl")
file(GLOB_RECURSE PYTHON_SRCS setup.py
src/python/*.py
......@@ -55,11 +47,7 @@ add_custom_command(
add_custom_target(mgb_opr_py DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/opr.py)
set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/bfloat16.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/mm_handler.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp)
if(MGE_WITH_DISTRIBUTED)
list(APPEND SRCS src/cpp/zmq_rpc.cpp)
endif()
set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/bfloat16.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp)
include(UseSWIG)
set_property(SOURCE ${SWIG_SRC} PROPERTY CPLUSPLUS ON)
......@@ -70,7 +58,7 @@ set_property(SOURCE ${SWIG_SRC} PROPERTY SWIG_FLAGS -I${PROJECT_SOURCE_DIR}/src/
set(SWIG_OUTFILE_DIR ${CMAKE_CURRENT_BINARY_DIR})
set(CMAKE_SWIG_OUTDIR ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal)
swig_add_library(mgb LANGUAGE python SOURCES ${SWIG_SRC} ${GRPC_SRCS} ${SRCS})
swig_add_library(mgb LANGUAGE python SOURCES ${SWIG_SRC} ${SRCS})
set(VERSION_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/src/version.ld)
add_custom_target(version_ld SOURCES ${VERSION_SCRIPT})
......@@ -81,12 +69,6 @@ target_include_directories(_mgb PRIVATE ${PYTHON_INCLUDE_DIRS} src/cpp ${CMAKE_C
target_link_libraries(_mgb ${PYTHON_LIBRARIES})
add_dependencies(_mgb mgb_opr_py version_ld)
if(MGE_WITH_DISTRIBUTED)
add_dependencies(_mgb mgb_proto_target)
target_link_libraries (_mgb libprotobuf libzmq)
set(CPPZMQ_INC ${PROJECT_SOURCE_DIR}/third_party/cppzmq)
target_include_directories(_mgb PRIVATE ${CPPZMQ_INC})
endif()
add_custom_command(
TARGET _mgb POST_BUILD
......
......@@ -19,6 +19,10 @@
#include <dlfcn.h>
#if MGB_ENABLE_OPR_MM
#include "megbrain/opr/mm_handler.h"
#endif
#if MGB_CUDA
#include <cuda.h>
#endif
......@@ -276,4 +280,37 @@ std::vector<std::pair<uint64_t, std::string>> _config::dump_registered_oprs() {
#endif
}
#if MGB_ENABLE_OPR_MM
/*! see definition : src/cpp/megbrain_config.h.
* Create mm server. port 0 is permitted, leave zmqrpc to decide which port
* should be used.
*/
int _config::create_mm_server(const std::string& server_addr, int port) {
return create_zmqrpc_server(server_addr, port);
}
void _config::group_barrier(const std::string& server_addr,
int port, uint32_t size, uint32_t rank) {
mgb_assert(rank < size, "invalid rank %d", rank);
auto group_mgr = std::make_shared<GroupClientProxy>(
ssprintf("%s:%d", server_addr.c_str(), port));
uint32_t rsp = group_mgr->group_barrier(size, rank);
mgb_assert(rsp != 0, "rank already registered: %d", rank);
mgb_assert(size == rsp, "inconsistent size: %d, expect %d", size, rsp);
}
#else
int _config::create_mm_server(const std::string& server_addr, int port) {
mgb_throw(mgb::MegBrainError, "OPR_MM suppport disable at compile time");
return 0;
}
void _config::group_barrier(const std::string& server_addr,
int port, uint32_t size, uint32_t rank) {
mgb_throw(mgb::MegBrainError, "OPR_MM suppport disable at compile time");
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -12,7 +12,7 @@
#include "./python_helper.h"
#if MGB_ENABLE_OPR_MM
#include "mm_handler.h"
#include "megbrain/opr/mm_handler.h"
#endif
#include "megbrain/opr/io.h"
......
......@@ -10,6 +10,10 @@ endif()
if(MGE_WITH_DISTRIBUTED)
file(GLOB_RECURSE SOURCES_ opr-mm/impl/*.cpp opr-mm/impl/*.inl)
list(APPEND SOURCES ${SOURCES_})
file(GLOB_RECURSE PROTO_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "../src/opr-mm/proto/*.proto")
PROTOBUF_GENERATE_CPP_WITH_ROOT(GRPC_SRCS GRPC_HDRS ${CMAKE_CURRENT_SOURCE_DIR} ${PROTO_FILES})
add_custom_target(mgb_proto_target DEPENDS ${GRPC_SRCS} ${GRPC_HDRS} ${PROTOBUF_PROTOC_EXECUTABLE})
list(APPEND SOURCES ${GRPC_SRCS})
endif()
set(MGB_INC ${PROJECT_BINARY_DIR}/genfiles core/include gopt/include opr/include plugin/include serialization/include)
......@@ -52,6 +56,11 @@ if(CXX_SUPPORT_WCLASS_MEMACCESS)
endif()
target_link_libraries(megbrain megdnn)
if(MGE_WITH_DISTRIBUTED)
add_dependencies(megbrain mgb_proto_target)
target_link_libraries (megbrain libprotobuf libzmq)
set(CPPZMQ_INC ${PROJECT_SOURCE_DIR}/third_party/cppzmq)
# FIXME: add CMAKE_CURRENT_BINARY_DIR for including mm_handler.pb.h
target_include_directories(megbrain PRIVATE ${CPPZMQ_INC} ${CMAKE_CURRENT_BINARY_DIR})
target_link_libraries (megbrain megray)
endif()
target_link_libraries(megbrain ${MGE_CUDA_LIBS})
......
......@@ -7,13 +7,14 @@
*
*/
#include "mm_handler.h"
#include "megbrain/opr/mm_handler.h"
#include "megbrain/exception.h"
#include "megbrain_config.h"
#include "megbrain_build_config.h"
#if MGB_ENABLE_OPR_MM
#include "zmq_rpc.h"
#include "megbrain/opr/zmq_rpc.h"
#include "mm_handler.pb.h"
#include <future>
/* ======================== GroupServerProxy ========================== */
......@@ -128,17 +129,22 @@ void GroupServerProxy::group_barrier(void* input_ptr, size_t input_len,
Request req; \
Response rsp;
#define SOLVE_REQUEST(name, req, rsp) \
std::string req_str; \
mgb_assert(req.SerializeToString(&req_str)); \
zmq::message_t send(req_str.length() + name.length() + 1); \
zmq::message_t recv; \
memcpy(send.data(), name.data(), name.length() + 1); \
memcpy((char*)send.data() + name.length() + 1, req_str.data(), \
req_str.length()); \
m_stub->request(send, recv); \
#define SOLVE_REQUEST(name, req, rsp) \
std::string req_str; \
mgb_assert(req.SerializeToString(&req_str)); \
zmq::message_t send(req_str.length() + name.length() + 1); \
zmq::message_t recv; \
memcpy(send.data(), name.data(), name.length() + 1); \
memcpy((char*)send.data() + name.length() + 1, req_str.data(), \
req_str.length()); \
static_cast<ZmqRpc::ZmqRpcClient*>(m_stub)->request(send, recv); \
mgb_assert(rsp.ParseFromArray(recv.data(), recv.size()));
GroupClientProxy::GroupClientProxy(const std::string& server_addr)
: m_addr(server_addr),
m_stub{ZmqRpc::ZmqRpcClient::get_client("tcp://" + server_addr)} {
}
uint64_t GroupClientProxy::opr_register(const std::string& key, size_t nr_devices,
uint32_t rank, uintptr_t stream) {
INFO_INIT(mm_handler, opr_register, OprRegister)
......@@ -199,78 +205,26 @@ uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) {
#undef INFO_INIT
#undef SOLVE_REQUEST
/* ======================== ZmqRpcServerMgr ========================== */
class ZmqRpcServerMgr {
struct ServerInfo {
std::unique_ptr<ZmqRpc::ZmqRpcServer> server;
};
public:
int create_zmqrpc_server(const std::string& server_addr, int port,
std::unique_ptr<ZmqRpc::ZmqRpcServerImpl> service) {
MGB_LOCK_GUARD(m_mtx);
auto server =
std::make_unique<ZmqRpc::ZmqRpcServer>("tcp://" + server_addr, port,
std::move(service));
port = server->port();
if (port == -1) {
return -1;
}
auto full_srv_addr = ssprintf("%s:%d", server_addr.c_str(), port);
server->run();
auto ins = m_addr2server.emplace(
full_srv_addr, ServerInfo{std::move(server)});
mgb_assert(ins.second);
return port;
}
static ZmqRpcServerMgr* get_zmqrpc_server_mgr() {
static ZmqRpcServerMgr mgr;
return &mgr;
}
private:
std::unordered_map<std::string, ServerInfo> m_addr2server;
std::mutex m_mtx;
struct ServerInfo {
std::unique_ptr<ZmqRpc::ZmqRpcServer> server;
};
/*! see definition : src/cpp/megbrain_config.h.
* Create mm server. port 0 is permitted, leave zmqrpc to decide which port
* should be used.
*/
int _config::create_mm_server(const std::string& server_addr, int port) {
return ZmqRpcServerMgr::get_zmqrpc_server_mgr()->create_zmqrpc_server(
server_addr, port, std::make_unique<GroupServerProxy>());
}
/* ======================== Group Barrier ========================== */
/*! see definition : src/cpp/megbrain_config.h.
* Block until all ranks in the group reach this barrier
*/
void _config::group_barrier(const std::string& server_addr,
int port, uint32_t size, uint32_t rank) {
mgb_assert(rank < size, "invalid rank %d", rank);
auto group_mgr = std::make_shared<GroupClientProxy>(
ssprintf("%s:%d", server_addr.c_str(), port));
uint32_t rsp = group_mgr->group_barrier(size, rank);
mgb_assert(rsp != 0, "rank already registered: %d", rank);
mgb_assert(size == rsp, "inconsistent size: %d, expect %d", size, rsp);
}
#else
int _config::create_mm_server(const std::string& server_addr, int port) {
mgb_throw(mgb::MegBrainError, "distributed mode disabled at compile time");
return 0;
}
void _config::group_barrier(const std::string& server_addr,
int port, uint32_t size, uint32_t rank) {
mgb_throw(mgb::MegBrainError, "distributed mode disabled at compile time");
int create_zmqrpc_server(const std::string& server_addr, int port) {
static std::unordered_map<std::string, ServerInfo> addr2server;
static std::mutex mtx;
MGB_LOCK_GUARD(mtx);
auto service = std::make_unique<GroupServerProxy>();
auto server =
std::make_unique<ZmqRpc::ZmqRpcServer>("tcp://" + server_addr, port,
std::move(service));
port = server->port();
auto full_srv_addr = ssprintf("%s:%d", server_addr.c_str(), port);
server->run();
auto ins = addr2server.emplace(
full_srv_addr, ServerInfo{std::move(server)});
mgb_assert(ins.second);
return port;
}
#endif
......
#include "zmq_rpc.h"
#include "megbrain/opr/zmq_rpc.h"
#include "megbrain/exception.h"
#include "megbrain_config.h"
#include "megbrain_build_config.h"
#if MGB_CUDA
#include <unistd.h>
......
......@@ -13,10 +13,7 @@
#if MGB_ENABLE_OPR_MM
#include "zmq_rpc.h"
#include "megbrain/opr/collective_comm.h"
#include "mm_handler.pb.h"
using namespace mgb;
using namespace opr;
......@@ -31,10 +28,7 @@ class GroupClientProxy
public:
virtual ~GroupClientProxy() = default;
GroupClientProxy(const std::string& server_addr)
: m_addr(server_addr),
m_stub{ZmqRpc::ZmqRpcClient::get_client("tcp://" + server_addr)} {
}
GroupClientProxy(const std::string& server_addr);
//! graph registration, assign graph_id to worker.
uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank,
......@@ -50,33 +44,20 @@ public:
uint32_t group_barrier(uint32_t size, uint32_t rank) override;
//! thread safe to create handler with address
static GroupClientProxy* get_handler(const std::string& addr) {
static std::unordered_map<std::string,
std::unique_ptr<GroupClientProxy>>
addr2handler;
static std::mutex mtx;
MGB_LOCK_GUARD(mtx);
auto it = addr2handler.emplace(addr, nullptr);
if (!it.second) {
mgb_assert(it.first->second->m_addr == addr);
return it.first->second.get();
} else {
auto handler = std::make_unique<GroupClientProxy>(addr);
auto handler_ptr = handler.get();
it.first->second = std::move(handler);
return handler_ptr;
}
}
const std::string& get_addr() const {
return m_addr;
}
private:
const std::string m_addr;
ZmqRpc::ZmqRpcClient* m_stub;
void* m_stub;
};
/* ======================== ZmqRpcServerMgr ========================== */
int create_zmqrpc_server(const std::string& server_addr, int port);
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -101,4 +101,4 @@ private:
std::vector<std::shared_ptr<zmq::socket_t>> m_own_sockets;
};
} // namespace ZmqRpc
#endif
\ No newline at end of file
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册