From d7bb62cfa145543f821cb57dc7d309630cbe14f2 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 2 Jun 2020 15:52:33 +0800 Subject: [PATCH] refactor(mgb): move mm_handler from python module into opr-mm GitOrigin-RevId: f401ce86033da83a91ebea3c119fc7af54a66ba0 --- cmake/zmq.cmake | 1 + python_module/CMakeLists.txt | 22 +--- python_module/src/cpp/megbrain_config.cpp | 37 ++++++ python_module/src/cpp/opr_defs.cpp | 2 +- src/CMakeLists.txt | 9 ++ .../cpp => src/opr-mm/impl}/mm_handler.cpp | 118 ++++++------------ .../src/cpp => src/opr-mm/impl}/zmq_rpc.cpp | 4 +- .../opr-mm/include/megbrain/opr}/mm_handler.h | 35 ++---- .../opr-mm/include/megbrain/opr}/zmq_rpc.h | 2 +- .../src => src/opr-mm}/proto/mm_handler.proto | 0 10 files changed, 97 insertions(+), 133 deletions(-) rename {python_module/src/cpp => src/opr-mm/impl}/mm_handler.cpp (72%) rename {python_module/src/cpp => src/opr-mm/impl}/zmq_rpc.cpp (98%) rename {python_module/src/cpp => src/opr-mm/include/megbrain/opr}/mm_handler.h (58%) rename {python_module/src/cpp => src/opr-mm/include/megbrain/opr}/zmq_rpc.h (99%) rename {python_module/src => src/opr-mm}/proto/mm_handler.proto (100%) diff --git a/cmake/zmq.cmake b/cmake/zmq.cmake index 92a90bac5..d46775532 100644 --- a/cmake/zmq.cmake +++ b/cmake/zmq.cmake @@ -14,6 +14,7 @@ ExternalProject_add( ) set(ZMQ_INC ${ZMQ_BUILD_DIR}/include) +include_directories(${ZMQ_INC}) file(MAKE_DIRECTORY ${ZMQ_INC}) add_library(libzmq STATIC IMPORTED GLOBAL) diff --git a/python_module/CMakeLists.txt b/python_module/CMakeLists.txt index 26f573f64..a02f8282e 100644 --- a/python_module/CMakeLists.txt +++ b/python_module/CMakeLists.txt @@ -12,14 +12,6 @@ set(SWIG_SRC src/swig/mgb.i) set(CMAKE_SWIG_FLAGS -Wall -threads -py3 -modern -DSWIGWORDSIZE64) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") -if(MGE_WITH_DISTRIBUTED) - file(GLOB_RECURSE PROTO_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "src/proto/*.proto") - - PROTOBUF_GENERATE_CPP_WITH_ROOT(GRPC_SRCS GRPC_HDRS ${CMAKE_CURRENT_SOURCE_DIR} ${PROTO_FILES}) - add_custom_target(mgb_proto_target DEPENDS ${GRPC_SRCS} ${GRPC_HDRS} ${PROTOBUF_PROTOC_EXECUTABLE}) -endif() - - file(GLOB_RECURSE OPR_DECL_SRCS "${PROJECT_SOURCE_DIR}/src/**/*.oprdecl") file(GLOB_RECURSE PYTHON_SRCS setup.py src/python/*.py @@ -55,11 +47,7 @@ add_custom_command( add_custom_target(mgb_opr_py DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/opr.py) -set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/bfloat16.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/mm_handler.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp) - -if(MGE_WITH_DISTRIBUTED) - list(APPEND SRCS src/cpp/zmq_rpc.cpp) -endif() +set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/bfloat16.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp) include(UseSWIG) set_property(SOURCE ${SWIG_SRC} PROPERTY CPLUSPLUS ON) @@ -70,7 +58,7 @@ set_property(SOURCE ${SWIG_SRC} PROPERTY SWIG_FLAGS -I${PROJECT_SOURCE_DIR}/src/ set(SWIG_OUTFILE_DIR ${CMAKE_CURRENT_BINARY_DIR}) set(CMAKE_SWIG_OUTDIR ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal) -swig_add_library(mgb LANGUAGE python SOURCES ${SWIG_SRC} ${GRPC_SRCS} ${SRCS}) +swig_add_library(mgb LANGUAGE python SOURCES ${SWIG_SRC} ${SRCS}) set(VERSION_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/src/version.ld) add_custom_target(version_ld SOURCES ${VERSION_SCRIPT}) @@ -81,12 +69,6 @@ target_include_directories(_mgb PRIVATE ${PYTHON_INCLUDE_DIRS} src/cpp ${CMAKE_C target_link_libraries(_mgb ${PYTHON_LIBRARIES}) add_dependencies(_mgb mgb_opr_py version_ld) -if(MGE_WITH_DISTRIBUTED) - add_dependencies(_mgb mgb_proto_target) - target_link_libraries (_mgb libprotobuf libzmq) - set(CPPZMQ_INC ${PROJECT_SOURCE_DIR}/third_party/cppzmq) - target_include_directories(_mgb PRIVATE ${CPPZMQ_INC}) -endif() add_custom_command( TARGET _mgb POST_BUILD diff --git a/python_module/src/cpp/megbrain_config.cpp b/python_module/src/cpp/megbrain_config.cpp index ed2b40ac3..b6c70da1e 100644 --- a/python_module/src/cpp/megbrain_config.cpp +++ b/python_module/src/cpp/megbrain_config.cpp @@ -19,6 +19,10 @@ #include +#if MGB_ENABLE_OPR_MM +#include "megbrain/opr/mm_handler.h" +#endif + #if MGB_CUDA #include #endif @@ -276,4 +280,37 @@ std::vector> _config::dump_registered_oprs() { #endif } +#if MGB_ENABLE_OPR_MM +/*! see definition : src/cpp/megbrain_config.h. + * Create mm server. port 0 is permitted, leave zmqrpc to decide which port + * should be used. + */ +int _config::create_mm_server(const std::string& server_addr, int port) { + return create_zmqrpc_server(server_addr, port); +} + +void _config::group_barrier(const std::string& server_addr, + int port, uint32_t size, uint32_t rank) { + mgb_assert(rank < size, "invalid rank %d", rank); + auto group_mgr = std::make_shared( + ssprintf("%s:%d", server_addr.c_str(), port)); + uint32_t rsp = group_mgr->group_barrier(size, rank); + mgb_assert(rsp != 0, "rank already registered: %d", rank); + mgb_assert(size == rsp, "inconsistent size: %d, expect %d", size, rsp); +} + +#else + +int _config::create_mm_server(const std::string& server_addr, int port) { + mgb_throw(mgb::MegBrainError, "OPR_MM suppport disable at compile time"); + return 0; +} + +void _config::group_barrier(const std::string& server_addr, + int port, uint32_t size, uint32_t rank) { + mgb_throw(mgb::MegBrainError, "OPR_MM suppport disable at compile time"); +} + +#endif + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/python_module/src/cpp/opr_defs.cpp b/python_module/src/cpp/opr_defs.cpp index 57414eafd..be74e8c93 100644 --- a/python_module/src/cpp/opr_defs.cpp +++ b/python_module/src/cpp/opr_defs.cpp @@ -12,7 +12,7 @@ #include "./python_helper.h" #if MGB_ENABLE_OPR_MM -#include "mm_handler.h" +#include "megbrain/opr/mm_handler.h" #endif #include "megbrain/opr/io.h" diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 52c402dbc..6b68f2bc6 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -10,6 +10,10 @@ endif() if(MGE_WITH_DISTRIBUTED) file(GLOB_RECURSE SOURCES_ opr-mm/impl/*.cpp opr-mm/impl/*.inl) list(APPEND SOURCES ${SOURCES_}) + file(GLOB_RECURSE PROTO_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "../src/opr-mm/proto/*.proto") + PROTOBUF_GENERATE_CPP_WITH_ROOT(GRPC_SRCS GRPC_HDRS ${CMAKE_CURRENT_SOURCE_DIR} ${PROTO_FILES}) + add_custom_target(mgb_proto_target DEPENDS ${GRPC_SRCS} ${GRPC_HDRS} ${PROTOBUF_PROTOC_EXECUTABLE}) + list(APPEND SOURCES ${GRPC_SRCS}) endif() set(MGB_INC ${PROJECT_BINARY_DIR}/genfiles core/include gopt/include opr/include plugin/include serialization/include) @@ -52,6 +56,11 @@ if(CXX_SUPPORT_WCLASS_MEMACCESS) endif() target_link_libraries(megbrain megdnn) if(MGE_WITH_DISTRIBUTED) + add_dependencies(megbrain mgb_proto_target) + target_link_libraries (megbrain libprotobuf libzmq) + set(CPPZMQ_INC ${PROJECT_SOURCE_DIR}/third_party/cppzmq) + # FIXME: add CMAKE_CURRENT_BINARY_DIR for including mm_handler.pb.h + target_include_directories(megbrain PRIVATE ${CPPZMQ_INC} ${CMAKE_CURRENT_BINARY_DIR}) target_link_libraries (megbrain megray) endif() target_link_libraries(megbrain ${MGE_CUDA_LIBS}) diff --git a/python_module/src/cpp/mm_handler.cpp b/src/opr-mm/impl/mm_handler.cpp similarity index 72% rename from python_module/src/cpp/mm_handler.cpp rename to src/opr-mm/impl/mm_handler.cpp index 8f5562da7..4f8e3bbe8 100644 --- a/python_module/src/cpp/mm_handler.cpp +++ b/src/opr-mm/impl/mm_handler.cpp @@ -7,13 +7,14 @@ * */ -#include "mm_handler.h" +#include "megbrain/opr/mm_handler.h" #include "megbrain/exception.h" -#include "megbrain_config.h" +#include "megbrain_build_config.h" #if MGB_ENABLE_OPR_MM -#include "zmq_rpc.h" +#include "megbrain/opr/zmq_rpc.h" +#include "mm_handler.pb.h" #include /* ======================== GroupServerProxy ========================== */ @@ -128,17 +129,22 @@ void GroupServerProxy::group_barrier(void* input_ptr, size_t input_len, Request req; \ Response rsp; -#define SOLVE_REQUEST(name, req, rsp) \ - std::string req_str; \ - mgb_assert(req.SerializeToString(&req_str)); \ - zmq::message_t send(req_str.length() + name.length() + 1); \ - zmq::message_t recv; \ - memcpy(send.data(), name.data(), name.length() + 1); \ - memcpy((char*)send.data() + name.length() + 1, req_str.data(), \ - req_str.length()); \ - m_stub->request(send, recv); \ +#define SOLVE_REQUEST(name, req, rsp) \ + std::string req_str; \ + mgb_assert(req.SerializeToString(&req_str)); \ + zmq::message_t send(req_str.length() + name.length() + 1); \ + zmq::message_t recv; \ + memcpy(send.data(), name.data(), name.length() + 1); \ + memcpy((char*)send.data() + name.length() + 1, req_str.data(), \ + req_str.length()); \ + static_cast(m_stub)->request(send, recv); \ mgb_assert(rsp.ParseFromArray(recv.data(), recv.size())); +GroupClientProxy::GroupClientProxy(const std::string& server_addr) + : m_addr(server_addr), + m_stub{ZmqRpc::ZmqRpcClient::get_client("tcp://" + server_addr)} { + } + uint64_t GroupClientProxy::opr_register(const std::string& key, size_t nr_devices, uint32_t rank, uintptr_t stream) { INFO_INIT(mm_handler, opr_register, OprRegister) @@ -199,78 +205,26 @@ uint32_t GroupClientProxy::group_barrier(uint32_t size, uint32_t rank) { #undef INFO_INIT #undef SOLVE_REQUEST -/* ======================== ZmqRpcServerMgr ========================== */ - -class ZmqRpcServerMgr { - struct ServerInfo { - std::unique_ptr server; - }; - -public: - int create_zmqrpc_server(const std::string& server_addr, int port, - std::unique_ptr service) { - MGB_LOCK_GUARD(m_mtx); - auto server = - std::make_unique("tcp://" + server_addr, port, - std::move(service)); - port = server->port(); - if (port == -1) { - return -1; - } - - auto full_srv_addr = ssprintf("%s:%d", server_addr.c_str(), port); - server->run(); - auto ins = m_addr2server.emplace( - full_srv_addr, ServerInfo{std::move(server)}); - mgb_assert(ins.second); - - return port; - } - - static ZmqRpcServerMgr* get_zmqrpc_server_mgr() { - static ZmqRpcServerMgr mgr; - return &mgr; - } - -private: - std::unordered_map m_addr2server; - std::mutex m_mtx; +struct ServerInfo { + std::unique_ptr server; }; -/*! see definition : src/cpp/megbrain_config.h. - * Create mm server. port 0 is permitted, leave zmqrpc to decide which port - * should be used. - */ -int _config::create_mm_server(const std::string& server_addr, int port) { - return ZmqRpcServerMgr::get_zmqrpc_server_mgr()->create_zmqrpc_server( - server_addr, port, std::make_unique()); -} - -/* ======================== Group Barrier ========================== */ - -/*! see definition : src/cpp/megbrain_config.h. - * Block until all ranks in the group reach this barrier - */ -void _config::group_barrier(const std::string& server_addr, - int port, uint32_t size, uint32_t rank) { - mgb_assert(rank < size, "invalid rank %d", rank); - auto group_mgr = std::make_shared( - ssprintf("%s:%d", server_addr.c_str(), port)); - uint32_t rsp = group_mgr->group_barrier(size, rank); - mgb_assert(rsp != 0, "rank already registered: %d", rank); - mgb_assert(size == rsp, "inconsistent size: %d, expect %d", size, rsp); -} - -#else - -int _config::create_mm_server(const std::string& server_addr, int port) { - mgb_throw(mgb::MegBrainError, "distributed mode disabled at compile time"); - return 0; -} - -void _config::group_barrier(const std::string& server_addr, - int port, uint32_t size, uint32_t rank) { - mgb_throw(mgb::MegBrainError, "distributed mode disabled at compile time"); +int create_zmqrpc_server(const std::string& server_addr, int port) { + static std::unordered_map addr2server; + static std::mutex mtx; + MGB_LOCK_GUARD(mtx); + auto service = std::make_unique(); + auto server = + std::make_unique("tcp://" + server_addr, port, + std::move(service)); + port = server->port(); + auto full_srv_addr = ssprintf("%s:%d", server_addr.c_str(), port); + server->run(); + auto ins = addr2server.emplace( + full_srv_addr, ServerInfo{std::move(server)}); + mgb_assert(ins.second); + + return port; } #endif diff --git a/python_module/src/cpp/zmq_rpc.cpp b/src/opr-mm/impl/zmq_rpc.cpp similarity index 98% rename from python_module/src/cpp/zmq_rpc.cpp rename to src/opr-mm/impl/zmq_rpc.cpp index 56c6a4411..c3fc4798b 100644 --- a/python_module/src/cpp/zmq_rpc.cpp +++ b/src/opr-mm/impl/zmq_rpc.cpp @@ -1,6 +1,6 @@ -#include "zmq_rpc.h" +#include "megbrain/opr/zmq_rpc.h" #include "megbrain/exception.h" -#include "megbrain_config.h" +#include "megbrain_build_config.h" #if MGB_CUDA #include diff --git a/python_module/src/cpp/mm_handler.h b/src/opr-mm/include/megbrain/opr/mm_handler.h similarity index 58% rename from python_module/src/cpp/mm_handler.h rename to src/opr-mm/include/megbrain/opr/mm_handler.h index 338ea36c1..fe80fb81f 100644 --- a/python_module/src/cpp/mm_handler.h +++ b/src/opr-mm/include/megbrain/opr/mm_handler.h @@ -13,10 +13,7 @@ #if MGB_ENABLE_OPR_MM -#include "zmq_rpc.h" - #include "megbrain/opr/collective_comm.h" -#include "mm_handler.pb.h" using namespace mgb; using namespace opr; @@ -31,10 +28,7 @@ class GroupClientProxy public: virtual ~GroupClientProxy() = default; - GroupClientProxy(const std::string& server_addr) - : m_addr(server_addr), - m_stub{ZmqRpc::ZmqRpcClient::get_client("tcp://" + server_addr)} { - } + GroupClientProxy(const std::string& server_addr); //! graph registration, assign graph_id to worker. uint64_t opr_register(const std::string& key, size_t nr_devices, uint32_t rank, @@ -50,33 +44,20 @@ public: uint32_t group_barrier(uint32_t size, uint32_t rank) override; - //! thread safe to create handler with address - static GroupClientProxy* get_handler(const std::string& addr) { - static std::unordered_map> - addr2handler; - static std::mutex mtx; - MGB_LOCK_GUARD(mtx); - auto it = addr2handler.emplace(addr, nullptr); - if (!it.second) { - mgb_assert(it.first->second->m_addr == addr); - return it.first->second.get(); - } else { - auto handler = std::make_unique(addr); - auto handler_ptr = handler.get(); - it.first->second = std::move(handler); - return handler_ptr; - } - } - const std::string& get_addr() const { return m_addr; } private: const std::string m_addr; - ZmqRpc::ZmqRpcClient* m_stub; + void* m_stub; }; + + +/* ======================== ZmqRpcServerMgr ========================== */ + +int create_zmqrpc_server(const std::string& server_addr, int port); + #endif // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/python_module/src/cpp/zmq_rpc.h b/src/opr-mm/include/megbrain/opr/zmq_rpc.h similarity index 99% rename from python_module/src/cpp/zmq_rpc.h rename to src/opr-mm/include/megbrain/opr/zmq_rpc.h index 8b00cab0d..490485620 100644 --- a/python_module/src/cpp/zmq_rpc.h +++ b/src/opr-mm/include/megbrain/opr/zmq_rpc.h @@ -101,4 +101,4 @@ private: std::vector> m_own_sockets; }; } // namespace ZmqRpc -#endif \ No newline at end of file +#endif diff --git a/python_module/src/proto/mm_handler.proto b/src/opr-mm/proto/mm_handler.proto similarity index 100% rename from python_module/src/proto/mm_handler.proto rename to src/opr-mm/proto/mm_handler.proto -- GitLab