提交 6b28a2f9 编写于 作者: M Megvii Engine Team

fix(copybara): re-open redis cache

GitOrigin-RevId: 055bf6aa4f703943d2bdfc0640cb2ba14488b347
上级 60c14b68
......@@ -5,6 +5,7 @@ set(PACKAGE_NAME ${PACKAGE_NAME} PARENT_SCOPE)
set(MODULE_NAME _imperative_rt)
set(MODULE_NAME ${MODULE_NAME} PARENT_SCOPE)
file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/src/*.h)
set(SRCS ${SRCS} ${CPP_REDIS_SRCS})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1")
......@@ -42,7 +43,7 @@ target_link_libraries(${MODULE_NAME} PRIVATE range-v3)
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/Json ${PROJECT_BINARY_DIR}/third_party/Json)
target_link_libraries(${MODULE_NAME} PRIVATE nlohmann_json::nlohmann_json)
target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR})
target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES})
target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME})
target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter)
if(CXX_SUPPORT_WCLASS_MEMACCESS)
......
......@@ -11,6 +11,7 @@ import argparse
import getpass
import os
import sys
import urllib.parse
from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager
from ..logger import get_logger
......@@ -23,8 +24,10 @@ class PersistentCacheManager(_PersistentCacheManager):
if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY":
get_logger().info("fastrun use in-memory cache")
self.open_memory()
else:
elif os.getenv("MGE_FASTRUN_CACHE_TYPE") == "FILE":
self.open_file()
else:
self.open_redis()
def open_memory(self):
pass
......@@ -51,6 +54,28 @@ class PersistentCacheManager(_PersistentCacheManager):
)
self.open_memory()
def open_redis(self):
prefix = "mgbcache:{}:MGB{}:GIT:{}".format(
getpass.getuser(), __version__, git_version
)
url = os.getenv("MGE_FASTRUN_CACHE_URL")
if url is None:
self.open_file()
try:
assert sys.platform != "win32", "redis cache on windows not tested"
parse_result = urllib.parse.urlparse(url, scheme="redis")
assert parse_result.scheme == "redis", "unsupported scheme"
assert not parse_result.username, "redis conn with username unsupported"
assert self.try_open_redis(
parse_result.hostname, parse_result.port, parse_result.password, prefix
), "connect failed"
except Exception as exc:
get_logger().error(
"failed to connect to cache server {!r}; try fallback to "
"in-file cache".format(exc)
)
self.open_file()
_manager = None
......@@ -60,3 +85,23 @@ def get_manager():
if _manager is None:
_manager = PersistentCacheManager()
return _manager
def _clean():
nr_del = get_manager().clean()
if nr_del is not None:
print("{} cache entries deleted".format(nr_del))
def main():
parser = argparse.ArgumentParser(description="manage persistent cache")
subp = parser.add_subparsers(description="action to be performed", dest="cmd")
subp.required = True
subp_clean = subp.add_parser("clean", help="clean all the cache of current user")
subp_clean.set_defaults(action=_clean)
args = parser.parse_args()
args.action()
if __name__ == "__main__":
main()
......@@ -245,6 +245,11 @@ void init_utils(py::module m) {
}
return false;
}
bool open_redis(
std::string ip, size_t port, std::string password, std::string prefix) {
return try_reg(mgb::imperative::persistent_cache::make_redis(
ip, port, password, prefix));
}
bool open_file(std::string path) {
return try_reg(mgb::imperative::persistent_cache::make_in_file(path));
}
......@@ -271,6 +276,7 @@ void init_utils(py::module m) {
py::class_<PersistentCacheManager>(m, "PersistentCacheManager")
.def(py::init<>())
.def("try_open_redis", &PersistentCacheManager::open_redis)
.def("try_open_file", &PersistentCacheManager::open_file)
.def("clean", &PersistentCacheManager::clean)
.def("put", &PersistentCacheManager::put)
......
......@@ -13,12 +13,109 @@
#include <string>
#include <vector>
#include "cpp_redis/cpp_redis"
#include "megbrain/imperative/persistent_cache.h"
#include "megbrain/imperative/utils/base64.h"
#include "megbrain/utils/infile_persistent_cache.h"
namespace mgb::imperative::persistent_cache {
class RedisCache final : public ExtendedPersistentCache {
public:
RedisCache(std::string prefix, uint64_t timeout) : m_prefix(prefix) {
m_local = std::make_shared<mgb::InMemoryPersistentCache>();
}
bool connect(std::string ip, size_t port, std::string password) {
m_client.auth(password);
m_client.connect(
ip, port,
[](const std::string& host, std::size_t port,
cpp_redis::connect_state status) {
if (status == cpp_redis::connect_state::dropped) {
mgb_log("client disconnected from %s.", host.c_str());
mgb_log("Redis server connect to %s :%zu failed.", host.c_str(),
port);
}
},
std::uint32_t(200));
if (!m_client.is_connected()) {
return false;
}
auto flag = m_client.get("mgb-cache-flag");
sync();
return flag.get().ok();
}
bool valid() const override { return m_client.is_connected(); }
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
MGB_LOCK_GUARD(m_mtx);
auto mem_result = m_local->get(category, key);
if (mem_result.valid())
return mem_result;
std::string key_str(static_cast<const char*>(key.ptr), key.size);
std::string redis_key_str;
encode(category + '@' + key_str, redis_key_str, 24);
auto result = m_client.get(redis_key_str);
sync();
auto content = result.get();
if (content.is_null())
return mgb::None;
std::string decode_content;
decode(content.as_string(), decode_content);
m_local->put(category, key, {decode_content.data(), decode_content.length()});
return m_local->get(category, key);
}
void put(const std::string& category, const Blob& key, const Blob& value) override {
MGB_LOCK_GUARD(m_mtx);
std::string key_str(static_cast<const char*>(key.ptr), key.size);
std::string redis_key_str;
encode(category + '@' + key_str, redis_key_str);
std::string value_str(static_cast<const char*>(value.ptr), value.size);
std::string redis_value_str;
encode(value_str, redis_value_str);
auto result = m_client.set(redis_key_str, redis_value_str);
m_local->put(category, key, value);
sync();
}
std::optional<size_t> clear() override {
size_t cursor = 0, nr_deleted = 0;
std::string pattern = m_prefix + "@*";
do {
auto reply = m_client.scan(cursor, pattern).share();
sync();
auto keys = reply.get().as_array();
std::vector<std::string> string_keys;
for (auto&& key : keys) {
string_keys.push_back(key.as_string());
}
m_client.del(string_keys);
nr_deleted += string_keys.size();
cursor = reply.get().as_array()[0].as_integer();
} while (cursor != 0);
return nr_deleted;
}
private:
std::shared_ptr<mgb::PersistentCache> m_local;
std::mutex m_mtx;
cpp_redis::client m_client;
std::string m_prefix;
uint64_t m_timeout;
void sync() {
m_client.sync_commit<double, std::milli>(std::chrono::milliseconds(m_timeout));
mgb_assert(valid());
}
};
class ExtendedInFilePersistentCache final : public ExtendedPersistentCache {
private:
std::string m_path;
......@@ -68,6 +165,15 @@ public:
bool valid() const override { return m_impl != nullptr; }
};
std::shared_ptr<ExtendedPersistentCache> make_redis(
std::string ip, size_t port, std::string password, std::string prefix) {
auto cache = std::make_shared<RedisCache>(prefix, 100);
if (!cache->connect(ip, port, password)) {
return nullptr;
}
return cache;
}
std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path) {
auto cache = std::make_shared<ExtendedInFilePersistentCache>();
if (!cache->open(path)) {
......
......@@ -22,6 +22,9 @@ public:
virtual std::optional<size_t> clear() = 0;
};
std::shared_ptr<ExtendedPersistentCache> make_redis(
std::string ip, size_t port, std::string password, std::string prefix);
std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path);
} // namespace mgb::imperative::persistent_cache
......
......@@ -12,7 +12,7 @@ endif()
# TODO: turn python binding into a static/object library
add_executable(imperative_test ${SOURCES} ${SRCS})
add_dependencies(imperative_test mgb_opdef)
target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES})
target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES})
# Python binding
target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR})
......
......@@ -35,6 +35,10 @@ configure_file(src/lite_build_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/genfiles/l
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/genfiles/lite_build_config.h DESTINATION ${CMAKE_INSTALL_PREFIX}/lite/include)
# begin config lite
if(LITE_BUILD_WITH_MGE AND LITE_WITH_CUDA AND NOT WIN32)
# FXIME third_party cpp redis do not support build with clang-cl
list(APPEND SOURCES_LITE ${CPP_REDIS_SRCS})
endif()
add_library(lite_static STATIC ${SOURCES_LITE})
add_dependencies(lite_static lite_fbs_generate)
include_directories($<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/genfiles>)
......@@ -106,6 +110,14 @@ endif()
if(LITE_BUILD_WITH_MGE)
target_link_libraries(lite_static_all_in_one PRIVATE megbrain megdnn ${MGE_CUDA_LIBS})
endif()
if(LITE_BUILD_WITH_MGE AND LITE_WITH_CUDA AND NOT WIN32)
# FXIME third_party cpp redis do not support build with clang-cl
target_include_directories(lite_static PRIVATE ${CPP_REDIS_INCLUDES})
target_include_directories(lite_shared PRIVATE ${CPP_REDIS_INCLUDES})
target_include_directories(lite_shared_whl PRIVATE ${CPP_REDIS_INCLUDES})
target_include_directories(lite_static_all_in_one PRIVATE ${CPP_REDIS_INCLUDES})
endif()
set(LITE_VERSION_SCRIPT ${PROJECT_SOURCE_DIR}/lite/src/version_lite.ld CACHE INTERNAL "Path to linker version script")
add_custom_target(_lite_version_ld SOURCES ${LITE_VERSION_SCRIPT})
if(NOT MSVC AND NOT WIN32)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册