From cf1db2616e0208e17b72d2c2bc24886954e86241 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 2 Dec 2021 21:11:08 +0800 Subject: [PATCH] fix(fastrun): replace py_redis with cpp_redis to avoid deadlock GitOrigin-RevId: 9af7fa5c97d401c51bfd02490887222ec8341098 --- CMakeLists.txt | 2 + cmake/cpp_redis.cmake | 2 + imperative/CMakeLists.txt | 3 +- imperative/python/megengine/__init__.py | 5 +- .../megengine/utils/persistent_cache.py | 130 ++++-------- imperative/python/src/utils.cpp | 107 ++++------ .../python/test/unit/utils/test_utils.py | 5 +- imperative/src/impl/persistent_cache.cpp | 186 ++++++++++++++++++ imperative/src/impl/utils/base64.cpp | 172 ++++++++++++++++ .../megbrain/imperative/persistent_cache.h | 31 +++ .../megbrain/imperative/utils/base64.h | 50 +++++ imperative/test/CMakeLists.txt | 2 +- .../include/megbrain/utils/persistent_cache.h | 9 +- 13 files changed, 535 insertions(+), 169 deletions(-) create mode 100644 cmake/cpp_redis.cmake create mode 100644 imperative/src/impl/persistent_cache.cpp create mode 100644 imperative/src/impl/utils/base64.cpp create mode 100644 imperative/src/include/megbrain/imperative/persistent_cache.h create mode 100644 imperative/src/include/megbrain/imperative/utils/base64.h diff --git a/CMakeLists.txt b/CMakeLists.txt index d39d6dce7..2dbd6d281 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -877,6 +877,8 @@ if(MGE_WITH_JIT AND MGE_WITH_HALIDE) include(cmake/Halide.cmake) endif() +include(cmake/cpp_redis.cmake) + # Thread IF(APPLE) set(CMAKE_THREAD_LIBS_INIT "-lpthread") diff --git a/cmake/cpp_redis.cmake b/cmake/cpp_redis.cmake new file mode 100644 index 000000000..d7b642e49 --- /dev/null +++ b/cmake/cpp_redis.cmake @@ -0,0 +1,2 @@ +file(GLOB_RECURSE CPP_REDIS_SRCS ${PROJECT_SOURCE_DIR}/third_party/cpp_redis/sources/*.cpp ${PROJECT_SOURCE_DIR}/third_party/tacopie/sources/*.cpp) +set(CPP_REDIS_INCLUDES ${PROJECT_SOURCE_DIR}/third_party/cpp_redis/includes ${PROJECT_SOURCE_DIR}/third_party/tacopie/includes) \ No newline at end of file diff --git a/imperative/CMakeLists.txt b/imperative/CMakeLists.txt index d24193cd3..8561187d5 100644 --- a/imperative/CMakeLists.txt +++ b/imperative/CMakeLists.txt @@ -5,6 +5,7 @@ set(PACKAGE_NAME ${PACKAGE_NAME} PARENT_SCOPE) set(MODULE_NAME _imperative_rt) set(MODULE_NAME ${MODULE_NAME} PARENT_SCOPE) file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/src/*.h) +set(SRCS ${SRCS} ${CPP_REDIS_SRCS}) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1") @@ -42,7 +43,7 @@ target_link_libraries(${MODULE_NAME} PRIVATE range-v3) add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/Json ${PROJECT_BINARY_DIR}/third_party/Json) target_link_libraries(${MODULE_NAME} PRIVATE nlohmann_json::nlohmann_json) -target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR}) +target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME}) target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter) if(CXX_SUPPORT_WCLASS_MEMACCESS) diff --git a/imperative/python/megengine/__init__.py b/imperative/python/megengine/__init__.py index c525d4c0e..38291bfde 100644 --- a/imperative/python/megengine/__init__.py +++ b/imperative/python/megengine/__init__.py @@ -92,9 +92,6 @@ _set_fork_exec_path_for_timed_func( os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"), ) -_persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer() -_persistent_cache_impl_ins.reg() - atexit.register(_close) del _set_fork_exec_path_for_timed_func @@ -135,3 +132,5 @@ import megengine.quantization import megengine.random import megengine.utils import megengine.traced_module + +persistent_cache.get_manager() diff --git a/imperative/python/megengine/utils/persistent_cache.py b/imperative/python/megengine/utils/persistent_cache.py index f0c409cfb..dfe821595 100644 --- a/imperative/python/megengine/utils/persistent_cache.py +++ b/imperative/python/megengine/utils/persistent_cache.py @@ -9,108 +9,54 @@ import argparse import getpass -import json import os -import shelve +import sys -from ..core._imperative_rt import PersistentCache as _PersistentCache +from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager from ..logger import get_logger from ..version import __version__, git_version -class _FakeRedisConn: - _cache_dir = None - _is_shelve = False - _dict = {} - +class PersistentCacheManager(_PersistentCacheManager): def __init__(self): + super().__init__() if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY": - self._dict = {} - self._is_shelve = False get_logger().info("fastrun use in-memory cache") + self.open_memory() else: - try: - self._cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR") - if not self._cache_dir: - from ..hub.hub import _get_megengine_home - - self._cache_dir = os.path.expanduser( - os.path.join(_get_megengine_home(), "persistent_cache") - ) - os.makedirs(self._cache_dir, exist_ok=True) - cache_file = os.path.join(self._cache_dir, "cache") - self._dict = shelve.open(cache_file) - self._is_shelve = True - get_logger().info( - "fastrun use in-file cache in {}".format(self._cache_dir) - ) - except Exception as exc: - self._dict = {} - self._is_shelve = False - get_logger().error( - "failed to create cache file in {} {!r}; fallback to " - "in-memory cache".format(self._cache_dir, exc) - ) - - def get(self, key): - if self._is_shelve and isinstance(key, bytes): - key = key.decode("utf-8") - - return self._dict.get(key) - - def set(self, key, val): - if self._is_shelve and isinstance(key, bytes): - key = key.decode("utf-8") - - self._dict[key] = val - - def clear(self): - print("{} cache item deleted in {}".format(len(self._dict), self._cache_dir)) - self._dict.clear() + self.open_file() - def __del__(self): - if self._is_shelve: - self._dict.close() + def open_memory(self): + pass + def open_file(self): + cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR") + try: + if not cache_dir: + from ..hub.hub import _get_megengine_home -class PersistentCacheOnServer(_PersistentCache): - _cached_conn = None - _prefix = None - _prev_get_refkeep = None - - @property - def _conn(self): - """get redis connection""" - if self._cached_conn is None: - self._cached_conn = _FakeRedisConn() - self._prefix = self.make_user_prefix() - - return self._cached_conn - - @classmethod - def make_user_prefix(cls): - return "mgbcache:{}".format(getpass.getuser()) - - def _make_key(self, category, key): - prefix_with_version = "{}:MGB{}:GIT:{}".format( - self._prefix, __version__, git_version - ) - return b"@".join( - (prefix_with_version.encode("ascii"), category.encode("ascii"), key) - ) - - def put(self, category, key, value): - conn = self._conn - key = self._make_key(category, key) - conn.set(key, value) - - def get(self, category, key): - conn = self._conn - key = self._make_key(category, key) - self._prev_get_refkeep = conn.get(key) - return self._prev_get_refkeep - - def clean(self): - conn = self._conn - if isinstance(conn, _FakeRedisConn): - conn.clear() + cache_dir = os.path.expanduser( + os.path.join(_get_megengine_home(), "persistent_cache.bin") + ) + os.makedirs(cache_dir, exist_ok=True) + cache_file = os.path.join(cache_dir, "cache") + with open(cache_file, "a"): + pass + assert self.try_open_file(cache_file), "cannot create file" + get_logger().info("fastrun use in-file cache in {}".format(cache_dir)) + except Exception as exc: + get_logger().error( + "failed to create cache file in {} {!r}; fallback to " + "in-memory cache".format(cache_dir, exc) + ) + self.open_memory() + + +_manager = None + + +def get_manager(): + global _manager + if _manager is None: + _manager = PersistentCacheManager() + return _manager diff --git a/imperative/python/src/utils.cpp b/imperative/python/src/utils.cpp index 0bf3cab03..5162f488f 100644 --- a/imperative/python/src/utils.cpp +++ b/imperative/python/src/utils.cpp @@ -23,6 +23,7 @@ #include "megbrain/common.h" #include "megbrain/comp_node.h" #include "megbrain/imperative/blob_manager.h" +#include "megbrain/imperative/persistent_cache.h" #include "megbrain/imperative/profiler.h" #include "megbrain/imperative/tensor_sanity_check.h" #include "megbrain/serialization/helper.h" @@ -229,83 +230,55 @@ void init_utils(py::module m) { mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str()); }); - using mgb::PersistentCache; - class PyPersistentCache : public mgb::PersistentCache { - private: - using KeyPair = std::pair; - using BlobPtr = std::unique_ptr; + using PersistentCache = mgb::PersistentCache; + using ExtendedPersistentCache = + mgb::imperative::persistent_cache::ExtendedPersistentCache; - std::shared_mutex m_mutex; - std::unordered_map m_local_cache; + struct PersistentCacheManager { + std::shared_ptr instance; - static size_t hash_key_pair(const KeyPair& kp) { - std::hash hasher; - return hasher(kp.first) ^ hasher(kp.second); + bool try_reg(std::shared_ptr cache) { + if (cache) { + instance = cache; + PersistentCache::set_impl(cache); + return true; + } + return false; } - - std::string blob_to_str(const Blob& key) { - return std::string(reinterpret_cast(key.ptr), key.size); + bool open_redis( + std::string ip, size_t port, std::string password, std::string prefix) { + return try_reg(mgb::imperative::persistent_cache::make_redis( + ip, port, password, prefix)); } - - BlobPtr copy_blob(const Blob& blob) { - auto blob_deleter = [](Blob* blob) { - if (blob) { - std::free(const_cast(blob->ptr)); - delete blob; - } - }; - auto blob_ptr = BlobPtr{new Blob(), blob_deleter}; - blob_ptr->ptr = std::malloc(blob.size); - std::memcpy(const_cast(blob_ptr->ptr), blob.ptr, blob.size); - blob_ptr->size = blob.size; - return blob_ptr; + bool open_file(std::string path) { + return try_reg(mgb::imperative::persistent_cache::make_in_file(path)); } - - BlobPtr str_to_blob(const std::string& str) { - auto blob = Blob{str.data(), str.size()}; - return copy_blob(blob); + std::optional clean() { + if (instance) { + return instance->clear(); + } + return {}; } - - std::unique_ptr empty_blob() { - return BlobPtr{nullptr, [](Blob* blob) {}}; + void put(std::string category, std::string key, std::string value) { + PersistentCache::inst().put( + category, {key.data(), key.size()}, {value.data(), value.size()}); } - - public: - mgb::Maybe get(const std::string& category, const Blob& key) override { - auto py_get = [this](const std::string& category, - const Blob& key) -> mgb::Maybe { - PYBIND11_OVERLOAD_PURE( - mgb::Maybe, PersistentCache, get, category, key); - }; - KeyPair kp = {category, blob_to_str(key)}; - std::shared_lock rlock; - auto iter = m_local_cache.find(kp); - if (iter == m_local_cache.end()) { - auto py_ret = py_get(category, key); - if (!py_ret.valid()) { - iter = m_local_cache.insert({kp, empty_blob()}).first; - } else { - iter = m_local_cache.insert({kp, copy_blob(py_ret.val())}).first; - } - } - if (iter->second) { - return *iter->second; + py::object get(std::string category, std::string key) { + auto value = + PersistentCache::inst().get(category, {key.data(), key.size()}); + if (value.valid()) { + return py::bytes(std::string((const char*)value->ptr, value->size)); } else { - return {}; + return py::none(); } } - void put(const std::string& category, const Blob& key, const Blob& value) - override { - KeyPair kp = {category, blob_to_str(key)}; - std::unique_lock wlock; - m_local_cache.insert_or_assign(kp, copy_blob(value)); - PYBIND11_OVERLOAD_PURE(void, PersistentCache, put, category, key, value); - } }; - py::class_>( - m, "PersistentCache") + + py::class_(m, "PersistentCacheManager") .def(py::init<>()) - .def("get", &PersistentCache::get) - .def("put", &PersistentCache::put) - .def("reg", &PersistentCache::set_impl); + .def("try_open_redis", &PersistentCacheManager::open_redis) + .def("try_open_file", &PersistentCacheManager::open_file) + .def("clean", &PersistentCacheManager::clean) + .def("put", &PersistentCacheManager::put) + .def("get", &PersistentCacheManager::get); } diff --git a/imperative/python/test/unit/utils/test_utils.py b/imperative/python/test/unit/utils/test_utils.py index 10b462344..91f517806 100644 --- a/imperative/python/test/unit/utils/test_utils.py +++ b/imperative/python/test/unit/utils/test_utils.py @@ -1,12 +1,11 @@ import pytest -import megengine -from megengine.utils.persistent_cache import PersistentCacheOnServer +from megengine.utils.persistent_cache import _manager @pytest.mark.skip(reason="fixme: github ci failed") def test_persistent_cache(): - pc = PersistentCacheOnServer() + pc = _manager k0 = b"\x00\x00" k1 = b"\x00\x01" cat = "test" diff --git a/imperative/src/impl/persistent_cache.cpp b/imperative/src/impl/persistent_cache.cpp new file mode 100644 index 000000000..dd6d65853 --- /dev/null +++ b/imperative/src/impl/persistent_cache.cpp @@ -0,0 +1,186 @@ +/** + * \file imperative/src/impl/persistent_cache.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include +#include +#include + +#include "cpp_redis/cpp_redis" + +#include "megbrain/imperative/persistent_cache.h" +#include "megbrain/imperative/utils/base64.h" +#include "megbrain/utils/infile_persistent_cache.h" + +namespace mgb::imperative::persistent_cache { + +class RedisCache final : public ExtendedPersistentCache { +public: + RedisCache(std::string prefix, uint64_t timeout) : m_prefix(prefix) { + m_local = std::make_shared(); + } + + bool connect(std::string ip, size_t port, std::string password) { + m_client.auth(password); + m_client.connect( + ip, port, + [](const std::string& host, std::size_t port, + cpp_redis::connect_state status) { + if (status == cpp_redis::connect_state::dropped) { + mgb_log("client disconnected from %s.", host.c_str()); + mgb_log("Redis server connect to %s :%zu failed.", host.c_str(), + port); + } + }, + std::uint32_t(200)); + if (!m_client.is_connected()) { + return false; + } + auto flag = m_client.get("mgb-cache-flag"); + sync(); + return flag.get().ok(); + } + + bool valid() const override { return m_client.is_connected(); } + + mgb::Maybe get(const std::string& category, const Blob& key) override { + MGB_LOCK_GUARD(m_mtx); + auto mem_result = m_local->get(category, key); + if (mem_result.valid()) + return mem_result; + + std::string key_str(static_cast(key.ptr), key.size); + std::string redis_key_str; + encode(category + '@' + key_str, redis_key_str, 24); + auto result = m_client.get(redis_key_str); + sync(); + auto content = result.get(); + if (content.is_null()) + return mgb::None; + std::string decode_content; + decode(content.as_string(), decode_content); + m_local->put(category, key, {decode_content.data(), decode_content.length()}); + + return m_local->get(category, key); + } + + void put(const std::string& category, const Blob& key, const Blob& value) override { + MGB_LOCK_GUARD(m_mtx); + std::string key_str(static_cast(key.ptr), key.size); + std::string redis_key_str; + encode(category + '@' + key_str, redis_key_str); + std::string value_str(static_cast(value.ptr), value.size); + std::string redis_value_str; + encode(value_str, redis_value_str); + + auto result = m_client.set(redis_key_str, redis_value_str); + m_local->put(category, key, value); + sync(); + } + + std::optional clear() override { + size_t cursor = 0, nr_deleted = 0; + std::string pattern = m_prefix + "@*"; + do { + auto reply = m_client.scan(cursor, pattern).share(); + sync(); + auto keys = reply.get().as_array(); + std::vector string_keys; + for (auto&& key : keys) { + string_keys.push_back(key.as_string()); + } + m_client.del(string_keys); + nr_deleted += string_keys.size(); + cursor = reply.get().as_array()[0].as_integer(); + } while (cursor != 0); + return nr_deleted; + } + +private: + std::shared_ptr m_local; + std::mutex m_mtx; + cpp_redis::client m_client; + std::string m_prefix; + uint64_t m_timeout; + + void sync() { + m_client.sync_commit(std::chrono::milliseconds(m_timeout)); + mgb_assert(valid()); + } +}; + +class ExtendedInFilePersistentCache final : public ExtendedPersistentCache { +private: + std::string m_path; + std::unique_ptr m_impl; + +public: + ExtendedInFilePersistentCache() = default; + + bool open(std::string path) { + std::fstream file; + file.open(path, std::ios::in | std::ios::binary); + if (!file.is_open()) { + return false; + } + std::vector bytes = { + std::istreambuf_iterator(file), std::istreambuf_iterator()}; + if (bytes.size()) { + m_impl = std::make_unique( + (const uint8_t*)bytes.data(), bytes.size()); + } else { + m_impl = std::make_unique(); + } + m_path = path; + return true; + } + + ~ExtendedInFilePersistentCache() { + if (m_impl) { + m_impl->dump_cache(m_path.c_str()); + } + } + + mgb::Maybe get(const std::string& category, const Blob& key) override { + return m_impl->get(category, key); + } + + void put(const std::string& category, const Blob& key, const Blob& value) override { + return m_impl->put(category, key, value); + } + + std::optional clear() override { + m_impl = std::make_unique(); + m_impl->dump_cache(m_path.c_str()); + return {}; + } + + bool valid() const override { return m_impl != nullptr; } +}; + +std::shared_ptr make_redis( + std::string ip, size_t port, std::string password, std::string prefix) { + auto cache = std::make_shared(prefix, 100); + if (!cache->connect(ip, port, password)) { + return nullptr; + } + return cache; +} +std::shared_ptr make_in_file(std::string path) { + auto cache = std::make_shared(); + if (!cache->open(path)) { + return nullptr; + } + return cache; +} + +} // namespace mgb::imperative::persistent_cache + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/imperative/src/impl/utils/base64.cpp b/imperative/src/impl/utils/base64.cpp new file mode 100644 index 000000000..b09a8eeaa --- /dev/null +++ b/imperative/src/impl/utils/base64.cpp @@ -0,0 +1,172 @@ +/** + * \file imperative/src/impl/base64.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "megbrain/imperative/utils/base64.h" + +namespace mgb::imperative { + +namespace { + +/* +** Translation Table as described in RFC1113 +*/ +const char cb64[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + +/* +** Translation Table to decode: +*https://github.com/dgiardini/imgcalkap/blob/master/base64.c +*/ +const char cd64[] = + "|$$$}rstuvwxyz{$$$$$$$>?@ABCDEFGHIJKLMNOPQRSTUVW$$$$$$XYZ[\\]^_`" + "abcdefghijklmnopq"; + +/* +** encodeblock +** +** encode 3 8-bit binary bytes as 4 '6-bit' characters +*/ +void encodeblock(unsigned char in[3], unsigned char out[4], int len) { + out[0] = cb64[in[0] >> 2]; + out[1] = cb64[((in[0] & 0x03) << 4) | ((in[1] & 0xf0) >> 4)]; + out[2] = + (unsigned char)(len > 1 ? cb64[((in[1] & 0x0f) << 2) | ((in[2] & 0xc0) >> 6)] : '='); + out[3] = (unsigned char)(len > 2 ? cb64[in[2] & 0x3f] : '='); +} + +/* +** decodeblock +** +** decode 4 '6-bit' characters into 3 8-bit binary bytes +*/ +void decodeblock(unsigned char in[4], unsigned char out[3]) { + out[0] = (unsigned char)(in[0] << 2 | in[1] >> 4); + out[1] = (unsigned char)(in[1] << 4 | in[2] >> 2); + out[2] = (unsigned char)(((in[2] << 6) & 0xc0) | in[3]); +} + +} // namespace + +/** + * Encode string to base64 string + * @param input - source string + * @param outdata - target base64 string + * @param linesize - max size of line + */ +void encode( + const std::vector& input, std::vector& outdata, + int linesize) { + outdata.clear(); + + unsigned char in[3], out[4]; + int i, len, blocksout = 0; + size_t j = 0; + + auto* indata = reinterpret_cast(input.data()); + unsigned int insize = input.size(); + + while (j <= insize) { + len = 0; + for (i = 0; i < 3; i++) { + in[i] = (unsigned char)indata[j]; + j++; + if (j <= insize) { + len++; + } else { + in[i] = 0; + } + } + if (len) { + encodeblock(in, out, len); + for (i = 0; i < 4; i++) { + outdata.push_back(out[i]); + } + blocksout++; + } + if (blocksout >= (linesize / 4) || (j == insize)) { + if (blocksout) { + outdata.push_back('\r'); + outdata.push_back('\n'); + } + blocksout = 0; + } + } +} + +/** + * Decode base64 string ot source + * @param input - base64 string + * @param outdata - source string + */ +void decode( + const std::vector& input, std::vector& outdata) { + outdata.clear(); + + unsigned char in[4], out[3], v; + int i, len; + size_t j = 0; + + auto* indata = reinterpret_cast(input.data()); + unsigned int insize = input.size(); + + while (j <= insize) { + for (len = 0, i = 0; i < 4 && (j <= insize); i++) { + v = 0; + while ((j <= insize) && v == 0) { + v = (unsigned char)indata[j++]; + v = (unsigned char)((v < 43 || v > 122) ? 0 : cd64[v - 43]); + if (v) { + v = (unsigned char)((v == '$') ? 0 : v - 61); + } + } + if (j <= insize) { + len++; + if (v) { + in[i] = (unsigned char)(v - 1); + } + } else { + in[i] = 0; + } + } + if (len) { + decodeblock(in, out); + for (i = 0; i < len - 1; i++) { + outdata.push_back(out[i]); + } + } + } +} + +/** + * Encode binary data to base64 buffer + * @param input - source data + * @param outdata - target base64 buffer + * @param linesize + */ +void encode(const std::string& input, std::string& outdata, int linesize) { + std::vector out; + std::vector in(input.begin(), input.end()); + encode(in, out, linesize); + outdata = std::string(out.begin(), out.end()); +} + +/** + * Decode base64 buffer to source binary data + * @param input - base64 buffer + * @param outdata - source binary data + */ +void decode(const std::string& input, std::string& outdata) { + std::vector in(input.begin(), input.end()); + std::vector out; + decode(in, out); + outdata = std::string(out.begin(), out.end()); +} + +} // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/persistent_cache.h b/imperative/src/include/megbrain/imperative/persistent_cache.h new file mode 100644 index 000000000..59326bbde --- /dev/null +++ b/imperative/src/include/megbrain/imperative/persistent_cache.h @@ -0,0 +1,31 @@ +/** + * \file imperative/src/include/megbrain/imperative/persistent_cache.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include +#include "megbrain/utils/persistent_cache.h" + +namespace mgb::imperative::persistent_cache { + +class ExtendedPersistentCache : public mgb::PersistentCache { +public: + virtual bool valid() const = 0; + virtual std::optional clear() = 0; +}; + +std::shared_ptr make_redis( + std::string ip, size_t port, std::string password, std::string prefix); + +std::shared_ptr make_in_file(std::string path); + +} // namespace mgb::imperative::persistent_cache +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/imperative/src/include/megbrain/imperative/utils/base64.h b/imperative/src/include/megbrain/imperative/utils/base64.h new file mode 100644 index 000000000..540131512 --- /dev/null +++ b/imperative/src/include/megbrain/imperative/utils/base64.h @@ -0,0 +1,50 @@ +/** + * \file imperative/src/include/megbrain/imperative/utils/base64.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "megbrain/common.h" + +namespace mgb::imperative { + +/** + * Encode string to base64 string + * @param input - source string + * @param outdata - target base64 string + * @param linesize - max size of line + */ +void encode( + const std::vector& input, std::vector& outdata, + int linesize = 76); + +/** + * Decode base64 string ot source + * @param input - base64 string + * @param outdata - source string + */ +void decode(const std::vector& input, std::vector& outdata); + +/** + * Encode binary data to base64 buffer + * @param input - source data + * @param outdata - target base64 buffer + * @param linesize + */ +void encode(const std::string& input, std::string& outdata, int linesize = 76); + +/** + * Decode base64 buffer to source binary data + * @param input - base64 buffer + * @param outdata - source binary data + */ +void decode(const std::string& input, std::string& outdata); + +} // namespace mgb::imperative diff --git a/imperative/test/CMakeLists.txt b/imperative/test/CMakeLists.txt index 32a0dc698..1d472ac6d 100644 --- a/imperative/test/CMakeLists.txt +++ b/imperative/test/CMakeLists.txt @@ -12,7 +12,7 @@ endif() # TODO: turn python binding into a static/object library add_executable(imperative_test ${SOURCES} ${SRCS}) add_dependencies(imperative_test mgb_opdef) -target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR}) +target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES}) # Python binding target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) diff --git a/src/core/include/megbrain/utils/persistent_cache.h b/src/core/include/megbrain/utils/persistent_cache.h index 2a9bab611..384deae22 100644 --- a/src/core/include/megbrain/utils/persistent_cache.h +++ b/src/core/include/megbrain/utils/persistent_cache.h @@ -74,14 +74,19 @@ class InMemoryPersistentCache final : public PersistentCache { }; }; - Maybe get(const std::string& category, const Blob& key) override; - void put(const std::string& category, const Blob& key, const Blob& value) override; + MGE_WIN_DECLSPEC_FUC Maybe get( + const std::string& category, const Blob& key) override; + MGE_WIN_DECLSPEC_FUC void put( + const std::string& category, const Blob& key, const Blob& value) override; std::unordered_map< std::string, std::unordered_map> m_cache; MGB_MUTEX m_mtx; + +public: + MGE_WIN_DECLSPEC_FUC InMemoryPersistentCache() = default; }; /*! -- GitLab