diff --git a/imperative/CMakeLists.txt b/imperative/CMakeLists.txt index d24193cd307afb1ce4b2864f50b845e7feb21418..8561187d53963f3e0da8cceb81f4754b6fb2c087 100644 --- a/imperative/CMakeLists.txt +++ b/imperative/CMakeLists.txt @@ -5,6 +5,7 @@ set(PACKAGE_NAME ${PACKAGE_NAME} PARENT_SCOPE) set(MODULE_NAME _imperative_rt) set(MODULE_NAME ${MODULE_NAME} PARENT_SCOPE) file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/src/*.h) +set(SRCS ${SRCS} ${CPP_REDIS_SRCS}) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1") @@ -42,7 +43,7 @@ target_link_libraries(${MODULE_NAME} PRIVATE range-v3) add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/Json ${PROJECT_BINARY_DIR}/third_party/Json) target_link_libraries(${MODULE_NAME} PRIVATE nlohmann_json::nlohmann_json) -target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR}) +target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME}) target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter) if(CXX_SUPPORT_WCLASS_MEMACCESS) diff --git a/imperative/python/megengine/utils/persistent_cache.py b/imperative/python/megengine/utils/persistent_cache.py index dfe82159524fc09501a5c91578c7f3908907bcd8..3b0f7ae2c07423e98ecc7141334286ea1041dda9 100644 --- a/imperative/python/megengine/utils/persistent_cache.py +++ b/imperative/python/megengine/utils/persistent_cache.py @@ -11,6 +11,7 @@ import argparse import getpass import os import sys +import urllib.parse from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager from ..logger import get_logger @@ -23,8 +24,10 @@ class PersistentCacheManager(_PersistentCacheManager): if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY": get_logger().info("fastrun use in-memory cache") self.open_memory() - else: + elif os.getenv("MGE_FASTRUN_CACHE_TYPE") == "FILE": self.open_file() + else: + self.open_redis() def open_memory(self): pass @@ -51,6 +54,28 @@ class PersistentCacheManager(_PersistentCacheManager): ) self.open_memory() + def open_redis(self): + prefix = "mgbcache:{}:MGB{}:GIT:{}".format( + getpass.getuser(), __version__, git_version + ) + url = os.getenv("MGE_FASTRUN_CACHE_URL") + if url is None: + self.open_file() + try: + assert sys.platform != "win32", "redis cache on windows not tested" + parse_result = urllib.parse.urlparse(url, scheme="redis") + assert parse_result.scheme == "redis", "unsupported scheme" + assert not parse_result.username, "redis conn with username unsupported" + assert self.try_open_redis( + parse_result.hostname, parse_result.port, parse_result.password, prefix + ), "connect failed" + except Exception as exc: + get_logger().error( + "failed to connect to cache server {!r}; try fallback to " + "in-file cache".format(exc) + ) + self.open_file() + _manager = None @@ -60,3 +85,23 @@ def get_manager(): if _manager is None: _manager = PersistentCacheManager() return _manager + + +def _clean(): + nr_del = get_manager().clean() + if nr_del is not None: + print("{} cache entries deleted".format(nr_del)) + + +def main(): + parser = argparse.ArgumentParser(description="manage persistent cache") + subp = parser.add_subparsers(description="action to be performed", dest="cmd") + subp.required = True + subp_clean = subp.add_parser("clean", help="clean all the cache of current user") + subp_clean.set_defaults(action=_clean) + args = parser.parse_args() + args.action() + + +if __name__ == "__main__": + main() diff --git a/imperative/python/src/utils.cpp b/imperative/python/src/utils.cpp index 872a71bb7005b796d7f9631f14d3a7b699615aca..5162f488f31044c4c4fc5b435d63ee70b39b7c13 100644 --- a/imperative/python/src/utils.cpp +++ b/imperative/python/src/utils.cpp @@ -245,6 +245,11 @@ void init_utils(py::module m) { } return false; } + bool open_redis( + std::string ip, size_t port, std::string password, std::string prefix) { + return try_reg(mgb::imperative::persistent_cache::make_redis( + ip, port, password, prefix)); + } bool open_file(std::string path) { return try_reg(mgb::imperative::persistent_cache::make_in_file(path)); } @@ -271,6 +276,7 @@ void init_utils(py::module m) { py::class_(m, "PersistentCacheManager") .def(py::init<>()) + .def("try_open_redis", &PersistentCacheManager::open_redis) .def("try_open_file", &PersistentCacheManager::open_file) .def("clean", &PersistentCacheManager::clean) .def("put", &PersistentCacheManager::put) diff --git a/imperative/src/impl/persistent_cache.cpp b/imperative/src/impl/persistent_cache.cpp index e7f31b6260820876e036fb08f8bfd045209c8c01..ba3809abb04cb6962e2291ccdd04da398dbf49a3 100644 --- a/imperative/src/impl/persistent_cache.cpp +++ b/imperative/src/impl/persistent_cache.cpp @@ -13,12 +13,109 @@ #include #include +#include "cpp_redis/cpp_redis" + #include "megbrain/imperative/persistent_cache.h" #include "megbrain/imperative/utils/base64.h" #include "megbrain/utils/infile_persistent_cache.h" namespace mgb::imperative::persistent_cache { +class RedisCache final : public ExtendedPersistentCache { +public: + RedisCache(std::string prefix, uint64_t timeout) : m_prefix(prefix) { + m_local = std::make_shared(); + } + + bool connect(std::string ip, size_t port, std::string password) { + m_client.auth(password); + m_client.connect( + ip, port, + [](const std::string& host, std::size_t port, + cpp_redis::connect_state status) { + if (status == cpp_redis::connect_state::dropped) { + mgb_log("client disconnected from %s.", host.c_str()); + mgb_log("Redis server connect to %s :%zu failed.", host.c_str(), + port); + } + }, + std::uint32_t(200)); + if (!m_client.is_connected()) { + return false; + } + auto flag = m_client.get("mgb-cache-flag"); + sync(); + return flag.get().ok(); + } + + bool valid() const override { return m_client.is_connected(); } + + mgb::Maybe get(const std::string& category, const Blob& key) override { + MGB_LOCK_GUARD(m_mtx); + auto mem_result = m_local->get(category, key); + if (mem_result.valid()) + return mem_result; + + std::string key_str(static_cast(key.ptr), key.size); + std::string redis_key_str; + encode(category + '@' + key_str, redis_key_str, 24); + auto result = m_client.get(redis_key_str); + sync(); + auto content = result.get(); + if (content.is_null()) + return mgb::None; + std::string decode_content; + decode(content.as_string(), decode_content); + m_local->put(category, key, {decode_content.data(), decode_content.length()}); + + return m_local->get(category, key); + } + + void put(const std::string& category, const Blob& key, const Blob& value) override { + MGB_LOCK_GUARD(m_mtx); + std::string key_str(static_cast(key.ptr), key.size); + std::string redis_key_str; + encode(category + '@' + key_str, redis_key_str); + std::string value_str(static_cast(value.ptr), value.size); + std::string redis_value_str; + encode(value_str, redis_value_str); + + auto result = m_client.set(redis_key_str, redis_value_str); + m_local->put(category, key, value); + sync(); + } + + std::optional clear() override { + size_t cursor = 0, nr_deleted = 0; + std::string pattern = m_prefix + "@*"; + do { + auto reply = m_client.scan(cursor, pattern).share(); + sync(); + auto keys = reply.get().as_array(); + std::vector string_keys; + for (auto&& key : keys) { + string_keys.push_back(key.as_string()); + } + m_client.del(string_keys); + nr_deleted += string_keys.size(); + cursor = reply.get().as_array()[0].as_integer(); + } while (cursor != 0); + return nr_deleted; + } + +private: + std::shared_ptr m_local; + std::mutex m_mtx; + cpp_redis::client m_client; + std::string m_prefix; + uint64_t m_timeout; + + void sync() { + m_client.sync_commit(std::chrono::milliseconds(m_timeout)); + mgb_assert(valid()); + } +}; + class ExtendedInFilePersistentCache final : public ExtendedPersistentCache { private: std::string m_path; @@ -68,6 +165,15 @@ public: bool valid() const override { return m_impl != nullptr; } }; +std::shared_ptr make_redis( + std::string ip, size_t port, std::string password, std::string prefix) { + auto cache = std::make_shared(prefix, 100); + if (!cache->connect(ip, port, password)) { + return nullptr; + } + return cache; +} + std::shared_ptr make_in_file(std::string path) { auto cache = std::make_shared(); if (!cache->open(path)) { diff --git a/imperative/src/include/megbrain/imperative/persistent_cache.h b/imperative/src/include/megbrain/imperative/persistent_cache.h index f1331597ff31f226ad1661c2eb552b064b83d89e..59326bbdedfddcc4ca6eaf2f4ba6caf375cd3d8f 100644 --- a/imperative/src/include/megbrain/imperative/persistent_cache.h +++ b/imperative/src/include/megbrain/imperative/persistent_cache.h @@ -22,6 +22,9 @@ public: virtual std::optional clear() = 0; }; +std::shared_ptr make_redis( + std::string ip, size_t port, std::string password, std::string prefix); + std::shared_ptr make_in_file(std::string path); } // namespace mgb::imperative::persistent_cache diff --git a/imperative/test/CMakeLists.txt b/imperative/test/CMakeLists.txt index 1d472ac6d8129f31307987a48c8128cf57915132..debaa5c93e54958592d6470cfdffe5d675a47e07 100644 --- a/imperative/test/CMakeLists.txt +++ b/imperative/test/CMakeLists.txt @@ -12,7 +12,7 @@ endif() # TODO: turn python binding into a static/object library add_executable(imperative_test ${SOURCES} ${SRCS}) add_dependencies(imperative_test mgb_opdef) -target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) +target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) # Python binding target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) diff --git a/lite/CMakeLists.txt b/lite/CMakeLists.txt index 7210506de963330abca72889a74d4819041b1217..da465388387d19522edad0f02007def1ad131ac3 100644 --- a/lite/CMakeLists.txt +++ b/lite/CMakeLists.txt @@ -35,6 +35,10 @@ configure_file(src/lite_build_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/genfiles/l install(FILES ${CMAKE_CURRENT_BINARY_DIR}/genfiles/lite_build_config.h DESTINATION ${CMAKE_INSTALL_PREFIX}/lite/include) # begin config lite +if(LITE_BUILD_WITH_MGE AND LITE_WITH_CUDA AND NOT WIN32) + # FXIME third_party cpp redis do not support build with clang-cl + list(APPEND SOURCES_LITE ${CPP_REDIS_SRCS}) +endif() add_library(lite_static STATIC ${SOURCES_LITE}) add_dependencies(lite_static lite_fbs_generate) include_directories($) @@ -106,6 +110,14 @@ endif() if(LITE_BUILD_WITH_MGE) target_link_libraries(lite_static_all_in_one PRIVATE megbrain megdnn ${MGE_CUDA_LIBS}) endif() + +if(LITE_BUILD_WITH_MGE AND LITE_WITH_CUDA AND NOT WIN32) + # FXIME third_party cpp redis do not support build with clang-cl + target_include_directories(lite_static PRIVATE ${CPP_REDIS_INCLUDES}) + target_include_directories(lite_shared PRIVATE ${CPP_REDIS_INCLUDES}) + target_include_directories(lite_shared_whl PRIVATE ${CPP_REDIS_INCLUDES}) + target_include_directories(lite_static_all_in_one PRIVATE ${CPP_REDIS_INCLUDES}) +endif() set(LITE_VERSION_SCRIPT ${PROJECT_SOURCE_DIR}/lite/src/version_lite.ld CACHE INTERNAL "Path to linker version script") add_custom_target(_lite_version_ld SOURCES ${LITE_VERSION_SCRIPT}) if(NOT MSVC AND NOT WIN32)