From 6b28a2f9ecdd42edc7666b1959cde53914a3e786 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 14 Dec 2021 15:07:01 +0800 Subject: [PATCH] fix(copybara): re-open redis cache GitOrigin-RevId: 055bf6aa4f703943d2bdfc0640cb2ba14488b347 --- imperative/CMakeLists.txt | 3 +- .../megengine/utils/persistent_cache.py | 47 +++++++- imperative/python/src/utils.cpp | 6 + imperative/src/impl/persistent_cache.cpp | 106 ++++++++++++++++++ .../megbrain/imperative/persistent_cache.h | 3 + imperative/test/CMakeLists.txt | 2 +- lite/CMakeLists.txt | 12 ++ 7 files changed, 176 insertions(+), 3 deletions(-) diff --git a/imperative/CMakeLists.txt b/imperative/CMakeLists.txt index d24193cd3..8561187d5 100644 --- a/imperative/CMakeLists.txt +++ b/imperative/CMakeLists.txt @@ -5,6 +5,7 @@ set(PACKAGE_NAME ${PACKAGE_NAME} PARENT_SCOPE) set(MODULE_NAME _imperative_rt) set(MODULE_NAME ${MODULE_NAME} PARENT_SCOPE) file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/src/*.h) +set(SRCS ${SRCS} ${CPP_REDIS_SRCS}) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1") @@ -42,7 +43,7 @@ target_link_libraries(${MODULE_NAME} PRIVATE range-v3) add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/Json ${PROJECT_BINARY_DIR}/third_party/Json) target_link_libraries(${MODULE_NAME} PRIVATE nlohmann_json::nlohmann_json) -target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR}) +target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME}) target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter) if(CXX_SUPPORT_WCLASS_MEMACCESS) diff --git a/imperative/python/megengine/utils/persistent_cache.py b/imperative/python/megengine/utils/persistent_cache.py index dfe821595..3b0f7ae2c 100644 --- a/imperative/python/megengine/utils/persistent_cache.py +++ b/imperative/python/megengine/utils/persistent_cache.py @@ -11,6 +11,7 @@ import argparse import getpass import os import sys +import urllib.parse from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager from ..logger import get_logger @@ -23,8 +24,10 @@ class PersistentCacheManager(_PersistentCacheManager): if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY": get_logger().info("fastrun use in-memory cache") self.open_memory() - else: + elif os.getenv("MGE_FASTRUN_CACHE_TYPE") == "FILE": self.open_file() + else: + self.open_redis() def open_memory(self): pass @@ -51,6 +54,28 @@ class PersistentCacheManager(_PersistentCacheManager): ) self.open_memory() + def open_redis(self): + prefix = "mgbcache:{}:MGB{}:GIT:{}".format( + getpass.getuser(), __version__, git_version + ) + url = os.getenv("MGE_FASTRUN_CACHE_URL") + if url is None: + self.open_file() + try: + assert sys.platform != "win32", "redis cache on windows not tested" + parse_result = urllib.parse.urlparse(url, scheme="redis") + assert parse_result.scheme == "redis", "unsupported scheme" + assert not parse_result.username, "redis conn with username unsupported" + assert self.try_open_redis( + parse_result.hostname, parse_result.port, parse_result.password, prefix + ), "connect failed" + except Exception as exc: + get_logger().error( + "failed to connect to cache server {!r}; try fallback to " + "in-file cache".format(exc) + ) + self.open_file() + _manager = None @@ -60,3 +85,23 @@ def get_manager(): if _manager is None: _manager = PersistentCacheManager() return _manager + + +def _clean(): + nr_del = get_manager().clean() + if nr_del is not None: + print("{} cache entries deleted".format(nr_del)) + + +def main(): + parser = argparse.ArgumentParser(description="manage persistent cache") + subp = parser.add_subparsers(description="action to be performed", dest="cmd") + subp.required = True + subp_clean = subp.add_parser("clean", help="clean all the cache of current user") + subp_clean.set_defaults(action=_clean) + args = parser.parse_args() + args.action() + + +if __name__ == "__main__": + main() diff --git a/imperative/python/src/utils.cpp b/imperative/python/src/utils.cpp index 872a71bb7..5162f488f 100644 --- a/imperative/python/src/utils.cpp +++ b/imperative/python/src/utils.cpp @@ -245,6 +245,11 @@ void init_utils(py::module m) { } return false; } + bool open_redis( + std::string ip, size_t port, std::string password, std::string prefix) { + return try_reg(mgb::imperative::persistent_cache::make_redis( + ip, port, password, prefix)); + } bool open_file(std::string path) { return try_reg(mgb::imperative::persistent_cache::make_in_file(path)); } @@ -271,6 +276,7 @@ void init_utils(py::module m) { py::class_(m, "PersistentCacheManager") .def(py::init<>()) + .def("try_open_redis", &PersistentCacheManager::open_redis) .def("try_open_file", &PersistentCacheManager::open_file) .def("clean", &PersistentCacheManager::clean) .def("put", &PersistentCacheManager::put) diff --git a/imperative/src/impl/persistent_cache.cpp b/imperative/src/impl/persistent_cache.cpp index e7f31b626..ba3809abb 100644 --- a/imperative/src/impl/persistent_cache.cpp +++ b/imperative/src/impl/persistent_cache.cpp @@ -13,12 +13,109 @@ #include #include +#include "cpp_redis/cpp_redis" + #include "megbrain/imperative/persistent_cache.h" #include "megbrain/imperative/utils/base64.h" #include "megbrain/utils/infile_persistent_cache.h" namespace mgb::imperative::persistent_cache { +class RedisCache final : public ExtendedPersistentCache { +public: + RedisCache(std::string prefix, uint64_t timeout) : m_prefix(prefix) { + m_local = std::make_shared(); + } + + bool connect(std::string ip, size_t port, std::string password) { + m_client.auth(password); + m_client.connect( + ip, port, + [](const std::string& host, std::size_t port, + cpp_redis::connect_state status) { + if (status == cpp_redis::connect_state::dropped) { + mgb_log("client disconnected from %s.", host.c_str()); + mgb_log("Redis server connect to %s :%zu failed.", host.c_str(), + port); + } + }, + std::uint32_t(200)); + if (!m_client.is_connected()) { + return false; + } + auto flag = m_client.get("mgb-cache-flag"); + sync(); + return flag.get().ok(); + } + + bool valid() const override { return m_client.is_connected(); } + + mgb::Maybe get(const std::string& category, const Blob& key) override { + MGB_LOCK_GUARD(m_mtx); + auto mem_result = m_local->get(category, key); + if (mem_result.valid()) + return mem_result; + + std::string key_str(static_cast(key.ptr), key.size); + std::string redis_key_str; + encode(category + '@' + key_str, redis_key_str, 24); + auto result = m_client.get(redis_key_str); + sync(); + auto content = result.get(); + if (content.is_null()) + return mgb::None; + std::string decode_content; + decode(content.as_string(), decode_content); + m_local->put(category, key, {decode_content.data(), decode_content.length()}); + + return m_local->get(category, key); + } + + void put(const std::string& category, const Blob& key, const Blob& value) override { + MGB_LOCK_GUARD(m_mtx); + std::string key_str(static_cast(key.ptr), key.size); + std::string redis_key_str; + encode(category + '@' + key_str, redis_key_str); + std::string value_str(static_cast(value.ptr), value.size); + std::string redis_value_str; + encode(value_str, redis_value_str); + + auto result = m_client.set(redis_key_str, redis_value_str); + m_local->put(category, key, value); + sync(); + } + + std::optional clear() override { + size_t cursor = 0, nr_deleted = 0; + std::string pattern = m_prefix + "@*"; + do { + auto reply = m_client.scan(cursor, pattern).share(); + sync(); + auto keys = reply.get().as_array(); + std::vector string_keys; + for (auto&& key : keys) { + string_keys.push_back(key.as_string()); + } + m_client.del(string_keys); + nr_deleted += string_keys.size(); + cursor = reply.get().as_array()[0].as_integer(); + } while (cursor != 0); + return nr_deleted; + } + +private: + std::shared_ptr m_local; + std::mutex m_mtx; + cpp_redis::client m_client; + std::string m_prefix; + uint64_t m_timeout; + + void sync() { + m_client.sync_commit(std::chrono::milliseconds(m_timeout)); + mgb_assert(valid()); + } +}; + class ExtendedInFilePersistentCache final : public ExtendedPersistentCache { private: std::string m_path; @@ -68,6 +165,15 @@ public: bool valid() const override { return m_impl != nullptr; } }; +std::shared_ptr make_redis( + std::string ip, size_t port, std::string password, std::string prefix) { + auto cache = std::make_shared(prefix, 100); + if (!cache->connect(ip, port, password)) { + return nullptr; + } + return cache; +} + std::shared_ptr make_in_file(std::string path) { auto cache = std::make_shared(); if (!cache->open(path)) { diff --git a/imperative/src/include/megbrain/imperative/persistent_cache.h b/imperative/src/include/megbrain/imperative/persistent_cache.h index f1331597f..59326bbde 100644 --- a/imperative/src/include/megbrain/imperative/persistent_cache.h +++ b/imperative/src/include/megbrain/imperative/persistent_cache.h @@ -22,6 +22,9 @@ public: virtual std::optional clear() = 0; }; +std::shared_ptr make_redis( + std::string ip, size_t port, std::string password, std::string prefix); + std::shared_ptr make_in_file(std::string path); } // namespace mgb::imperative::persistent_cache diff --git a/imperative/test/CMakeLists.txt b/imperative/test/CMakeLists.txt index 1d472ac6d..debaa5c93 100644 --- a/imperative/test/CMakeLists.txt +++ b/imperative/test/CMakeLists.txt @@ -12,7 +12,7 @@ endif() # TODO: turn python binding into a static/object library add_executable(imperative_test ${SOURCES} ${SRCS}) add_dependencies(imperative_test mgb_opdef) -target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) +target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) # Python binding target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) diff --git a/lite/CMakeLists.txt b/lite/CMakeLists.txt index 7210506de..da4653883 100644 --- a/lite/CMakeLists.txt +++ b/lite/CMakeLists.txt @@ -35,6 +35,10 @@ configure_file(src/lite_build_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/genfiles/l install(FILES ${CMAKE_CURRENT_BINARY_DIR}/genfiles/lite_build_config.h DESTINATION ${CMAKE_INSTALL_PREFIX}/lite/include) # begin config lite +if(LITE_BUILD_WITH_MGE AND LITE_WITH_CUDA AND NOT WIN32) + # FXIME third_party cpp redis do not support build with clang-cl + list(APPEND SOURCES_LITE ${CPP_REDIS_SRCS}) +endif() add_library(lite_static STATIC ${SOURCES_LITE}) add_dependencies(lite_static lite_fbs_generate) include_directories($) @@ -106,6 +110,14 @@ endif() if(LITE_BUILD_WITH_MGE) target_link_libraries(lite_static_all_in_one PRIVATE megbrain megdnn ${MGE_CUDA_LIBS}) endif() + +if(LITE_BUILD_WITH_MGE AND LITE_WITH_CUDA AND NOT WIN32) + # FXIME third_party cpp redis do not support build with clang-cl + target_include_directories(lite_static PRIVATE ${CPP_REDIS_INCLUDES}) + target_include_directories(lite_shared PRIVATE ${CPP_REDIS_INCLUDES}) + target_include_directories(lite_shared_whl PRIVATE ${CPP_REDIS_INCLUDES}) + target_include_directories(lite_static_all_in_one PRIVATE ${CPP_REDIS_INCLUDES}) +endif() set(LITE_VERSION_SCRIPT ${PROJECT_SOURCE_DIR}/lite/src/version_lite.ld CACHE INTERNAL "Path to linker version script") add_custom_target(_lite_version_ld SOURCES ${LITE_VERSION_SCRIPT}) if(NOT MSVC AND NOT WIN32) -- GitLab