提交 cf1db261 编写于 作者: M Megvii Engine Team

fix(fastrun): replace py_redis with cpp_redis to avoid deadlock

GitOrigin-RevId: 9af7fa5c97d401c51bfd02490887222ec8341098
上级 b75d1009
...@@ -877,6 +877,8 @@ if(MGE_WITH_JIT AND MGE_WITH_HALIDE) ...@@ -877,6 +877,8 @@ if(MGE_WITH_JIT AND MGE_WITH_HALIDE)
include(cmake/Halide.cmake) include(cmake/Halide.cmake)
endif() endif()
include(cmake/cpp_redis.cmake)
# Thread # Thread
IF(APPLE) IF(APPLE)
set(CMAKE_THREAD_LIBS_INIT "-lpthread") set(CMAKE_THREAD_LIBS_INIT "-lpthread")
......
file(GLOB_RECURSE CPP_REDIS_SRCS ${PROJECT_SOURCE_DIR}/third_party/cpp_redis/sources/*.cpp ${PROJECT_SOURCE_DIR}/third_party/tacopie/sources/*.cpp)
set(CPP_REDIS_INCLUDES ${PROJECT_SOURCE_DIR}/third_party/cpp_redis/includes ${PROJECT_SOURCE_DIR}/third_party/tacopie/includes)
\ No newline at end of file
...@@ -5,6 +5,7 @@ set(PACKAGE_NAME ${PACKAGE_NAME} PARENT_SCOPE) ...@@ -5,6 +5,7 @@ set(PACKAGE_NAME ${PACKAGE_NAME} PARENT_SCOPE)
set(MODULE_NAME _imperative_rt) set(MODULE_NAME _imperative_rt)
set(MODULE_NAME ${MODULE_NAME} PARENT_SCOPE) set(MODULE_NAME ${MODULE_NAME} PARENT_SCOPE)
file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/src/*.h) file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/src/*.h)
set(SRCS ${SRCS} ${CPP_REDIS_SRCS})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1")
...@@ -42,7 +43,7 @@ target_link_libraries(${MODULE_NAME} PRIVATE range-v3) ...@@ -42,7 +43,7 @@ target_link_libraries(${MODULE_NAME} PRIVATE range-v3)
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/Json ${PROJECT_BINARY_DIR}/third_party/Json) add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/Json ${PROJECT_BINARY_DIR}/third_party/Json)
target_link_libraries(${MODULE_NAME} PRIVATE nlohmann_json::nlohmann_json) target_link_libraries(${MODULE_NAME} PRIVATE nlohmann_json::nlohmann_json)
target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR}) target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES})
target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME}) target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME})
target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter) target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter)
if(CXX_SUPPORT_WCLASS_MEMACCESS) if(CXX_SUPPORT_WCLASS_MEMACCESS)
......
...@@ -92,9 +92,6 @@ _set_fork_exec_path_for_timed_func( ...@@ -92,9 +92,6 @@ _set_fork_exec_path_for_timed_func(
os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"), os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"),
) )
_persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer()
_persistent_cache_impl_ins.reg()
atexit.register(_close) atexit.register(_close)
del _set_fork_exec_path_for_timed_func del _set_fork_exec_path_for_timed_func
...@@ -135,3 +132,5 @@ import megengine.quantization ...@@ -135,3 +132,5 @@ import megengine.quantization
import megengine.random import megengine.random
import megengine.utils import megengine.utils
import megengine.traced_module import megengine.traced_module
persistent_cache.get_manager()
...@@ -9,108 +9,54 @@ ...@@ -9,108 +9,54 @@
import argparse import argparse
import getpass import getpass
import json
import os import os
import shelve import sys
from ..core._imperative_rt import PersistentCache as _PersistentCache from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager
from ..logger import get_logger from ..logger import get_logger
from ..version import __version__, git_version from ..version import __version__, git_version
class _FakeRedisConn: class PersistentCacheManager(_PersistentCacheManager):
_cache_dir = None
_is_shelve = False
_dict = {}
def __init__(self): def __init__(self):
super().__init__()
if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY": if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY":
self._dict = {}
self._is_shelve = False
get_logger().info("fastrun use in-memory cache") get_logger().info("fastrun use in-memory cache")
self.open_memory()
else: else:
try: self.open_file()
self._cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR")
if not self._cache_dir:
from ..hub.hub import _get_megengine_home
self._cache_dir = os.path.expanduser(
os.path.join(_get_megengine_home(), "persistent_cache")
)
os.makedirs(self._cache_dir, exist_ok=True)
cache_file = os.path.join(self._cache_dir, "cache")
self._dict = shelve.open(cache_file)
self._is_shelve = True
get_logger().info(
"fastrun use in-file cache in {}".format(self._cache_dir)
)
except Exception as exc:
self._dict = {}
self._is_shelve = False
get_logger().error(
"failed to create cache file in {} {!r}; fallback to "
"in-memory cache".format(self._cache_dir, exc)
)
def get(self, key):
if self._is_shelve and isinstance(key, bytes):
key = key.decode("utf-8")
return self._dict.get(key)
def set(self, key, val):
if self._is_shelve and isinstance(key, bytes):
key = key.decode("utf-8")
self._dict[key] = val
def clear(self):
print("{} cache item deleted in {}".format(len(self._dict), self._cache_dir))
self._dict.clear()
def __del__(self): def open_memory(self):
if self._is_shelve: pass
self._dict.close()
def open_file(self):
cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR")
try:
if not cache_dir:
from ..hub.hub import _get_megengine_home
class PersistentCacheOnServer(_PersistentCache): cache_dir = os.path.expanduser(
_cached_conn = None os.path.join(_get_megengine_home(), "persistent_cache.bin")
_prefix = None )
_prev_get_refkeep = None os.makedirs(cache_dir, exist_ok=True)
cache_file = os.path.join(cache_dir, "cache")
@property with open(cache_file, "a"):
def _conn(self): pass
"""get redis connection""" assert self.try_open_file(cache_file), "cannot create file"
if self._cached_conn is None: get_logger().info("fastrun use in-file cache in {}".format(cache_dir))
self._cached_conn = _FakeRedisConn() except Exception as exc:
self._prefix = self.make_user_prefix() get_logger().error(
"failed to create cache file in {} {!r}; fallback to "
return self._cached_conn "in-memory cache".format(cache_dir, exc)
)
@classmethod self.open_memory()
def make_user_prefix(cls):
return "mgbcache:{}".format(getpass.getuser())
_manager = None
def _make_key(self, category, key):
prefix_with_version = "{}:MGB{}:GIT:{}".format(
self._prefix, __version__, git_version def get_manager():
) global _manager
return b"@".join( if _manager is None:
(prefix_with_version.encode("ascii"), category.encode("ascii"), key) _manager = PersistentCacheManager()
) return _manager
def put(self, category, key, value):
conn = self._conn
key = self._make_key(category, key)
conn.set(key, value)
def get(self, category, key):
conn = self._conn
key = self._make_key(category, key)
self._prev_get_refkeep = conn.get(key)
return self._prev_get_refkeep
def clean(self):
conn = self._conn
if isinstance(conn, _FakeRedisConn):
conn.clear()
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "megbrain/common.h" #include "megbrain/common.h"
#include "megbrain/comp_node.h" #include "megbrain/comp_node.h"
#include "megbrain/imperative/blob_manager.h" #include "megbrain/imperative/blob_manager.h"
#include "megbrain/imperative/persistent_cache.h"
#include "megbrain/imperative/profiler.h" #include "megbrain/imperative/profiler.h"
#include "megbrain/imperative/tensor_sanity_check.h" #include "megbrain/imperative/tensor_sanity_check.h"
#include "megbrain/serialization/helper.h" #include "megbrain/serialization/helper.h"
...@@ -229,83 +230,55 @@ void init_utils(py::module m) { ...@@ -229,83 +230,55 @@ void init_utils(py::module m) {
mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str()); mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str());
}); });
using mgb::PersistentCache; using PersistentCache = mgb::PersistentCache;
class PyPersistentCache : public mgb::PersistentCache { using ExtendedPersistentCache =
private: mgb::imperative::persistent_cache::ExtendedPersistentCache;
using KeyPair = std::pair<std::string, std::string>;
using BlobPtr = std::unique_ptr<Blob, void (*)(Blob*)>;
std::shared_mutex m_mutex; struct PersistentCacheManager {
std::unordered_map<KeyPair, BlobPtr, mgb::pairhash> m_local_cache; std::shared_ptr<ExtendedPersistentCache> instance;
static size_t hash_key_pair(const KeyPair& kp) { bool try_reg(std::shared_ptr<ExtendedPersistentCache> cache) {
std::hash<std::string> hasher; if (cache) {
return hasher(kp.first) ^ hasher(kp.second); instance = cache;
PersistentCache::set_impl(cache);
return true;
}
return false;
} }
bool open_redis(
std::string blob_to_str(const Blob& key) { std::string ip, size_t port, std::string password, std::string prefix) {
return std::string(reinterpret_cast<const char*>(key.ptr), key.size); return try_reg(mgb::imperative::persistent_cache::make_redis(
ip, port, password, prefix));
} }
bool open_file(std::string path) {
BlobPtr copy_blob(const Blob& blob) { return try_reg(mgb::imperative::persistent_cache::make_in_file(path));
auto blob_deleter = [](Blob* blob) {
if (blob) {
std::free(const_cast<void*>(blob->ptr));
delete blob;
}
};
auto blob_ptr = BlobPtr{new Blob(), blob_deleter};
blob_ptr->ptr = std::malloc(blob.size);
std::memcpy(const_cast<void*>(blob_ptr->ptr), blob.ptr, blob.size);
blob_ptr->size = blob.size;
return blob_ptr;
} }
std::optional<size_t> clean() {
BlobPtr str_to_blob(const std::string& str) { if (instance) {
auto blob = Blob{str.data(), str.size()}; return instance->clear();
return copy_blob(blob); }
return {};
} }
void put(std::string category, std::string key, std::string value) {
std::unique_ptr<Blob, void (*)(Blob*)> empty_blob() { PersistentCache::inst().put(
return BlobPtr{nullptr, [](Blob* blob) {}}; category, {key.data(), key.size()}, {value.data(), value.size()});
} }
py::object get(std::string category, std::string key) {
public: auto value =
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override { PersistentCache::inst().get(category, {key.data(), key.size()});
auto py_get = [this](const std::string& category, if (value.valid()) {
const Blob& key) -> mgb::Maybe<Blob> { return py::bytes(std::string((const char*)value->ptr, value->size));
PYBIND11_OVERLOAD_PURE(
mgb::Maybe<Blob>, PersistentCache, get, category, key);
};
KeyPair kp = {category, blob_to_str(key)};
std::shared_lock<decltype(m_mutex)> rlock;
auto iter = m_local_cache.find(kp);
if (iter == m_local_cache.end()) {
auto py_ret = py_get(category, key);
if (!py_ret.valid()) {
iter = m_local_cache.insert({kp, empty_blob()}).first;
} else {
iter = m_local_cache.insert({kp, copy_blob(py_ret.val())}).first;
}
}
if (iter->second) {
return *iter->second;
} else { } else {
return {}; return py::none();
} }
} }
void put(const std::string& category, const Blob& key, const Blob& value)
override {
KeyPair kp = {category, blob_to_str(key)};
std::unique_lock<decltype(m_mutex)> wlock;
m_local_cache.insert_or_assign(kp, copy_blob(value));
PYBIND11_OVERLOAD_PURE(void, PersistentCache, put, category, key, value);
}
}; };
py::class_<PersistentCache, PyPersistentCache, std::shared_ptr<PersistentCache>>(
m, "PersistentCache") py::class_<PersistentCacheManager>(m, "PersistentCacheManager")
.def(py::init<>()) .def(py::init<>())
.def("get", &PersistentCache::get) .def("try_open_redis", &PersistentCacheManager::open_redis)
.def("put", &PersistentCache::put) .def("try_open_file", &PersistentCacheManager::open_file)
.def("reg", &PersistentCache::set_impl); .def("clean", &PersistentCacheManager::clean)
.def("put", &PersistentCacheManager::put)
.def("get", &PersistentCacheManager::get);
} }
import pytest import pytest
import megengine from megengine.utils.persistent_cache import _manager
from megengine.utils.persistent_cache import PersistentCacheOnServer
@pytest.mark.skip(reason="fixme: github ci failed") @pytest.mark.skip(reason="fixme: github ci failed")
def test_persistent_cache(): def test_persistent_cache():
pc = PersistentCacheOnServer() pc = _manager
k0 = b"\x00\x00" k0 = b"\x00\x00"
k1 = b"\x00\x01" k1 = b"\x00\x01"
cat = "test" cat = "test"
......
/**
* \file imperative/src/impl/persistent_cache.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include <fstream>
#include <string>
#include <vector>
#include "cpp_redis/cpp_redis"
#include "megbrain/imperative/persistent_cache.h"
#include "megbrain/imperative/utils/base64.h"
#include "megbrain/utils/infile_persistent_cache.h"
namespace mgb::imperative::persistent_cache {
class RedisCache final : public ExtendedPersistentCache {
public:
RedisCache(std::string prefix, uint64_t timeout) : m_prefix(prefix) {
m_local = std::make_shared<mgb::InMemoryPersistentCache>();
}
bool connect(std::string ip, size_t port, std::string password) {
m_client.auth(password);
m_client.connect(
ip, port,
[](const std::string& host, std::size_t port,
cpp_redis::connect_state status) {
if (status == cpp_redis::connect_state::dropped) {
mgb_log("client disconnected from %s.", host.c_str());
mgb_log("Redis server connect to %s :%zu failed.", host.c_str(),
port);
}
},
std::uint32_t(200));
if (!m_client.is_connected()) {
return false;
}
auto flag = m_client.get("mgb-cache-flag");
sync();
return flag.get().ok();
}
bool valid() const override { return m_client.is_connected(); }
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
MGB_LOCK_GUARD(m_mtx);
auto mem_result = m_local->get(category, key);
if (mem_result.valid())
return mem_result;
std::string key_str(static_cast<const char*>(key.ptr), key.size);
std::string redis_key_str;
encode(category + '@' + key_str, redis_key_str, 24);
auto result = m_client.get(redis_key_str);
sync();
auto content = result.get();
if (content.is_null())
return mgb::None;
std::string decode_content;
decode(content.as_string(), decode_content);
m_local->put(category, key, {decode_content.data(), decode_content.length()});
return m_local->get(category, key);
}
void put(const std::string& category, const Blob& key, const Blob& value) override {
MGB_LOCK_GUARD(m_mtx);
std::string key_str(static_cast<const char*>(key.ptr), key.size);
std::string redis_key_str;
encode(category + '@' + key_str, redis_key_str);
std::string value_str(static_cast<const char*>(value.ptr), value.size);
std::string redis_value_str;
encode(value_str, redis_value_str);
auto result = m_client.set(redis_key_str, redis_value_str);
m_local->put(category, key, value);
sync();
}
std::optional<size_t> clear() override {
size_t cursor = 0, nr_deleted = 0;
std::string pattern = m_prefix + "@*";
do {
auto reply = m_client.scan(cursor, pattern).share();
sync();
auto keys = reply.get().as_array();
std::vector<std::string> string_keys;
for (auto&& key : keys) {
string_keys.push_back(key.as_string());
}
m_client.del(string_keys);
nr_deleted += string_keys.size();
cursor = reply.get().as_array()[0].as_integer();
} while (cursor != 0);
return nr_deleted;
}
private:
std::shared_ptr<mgb::PersistentCache> m_local;
std::mutex m_mtx;
cpp_redis::client m_client;
std::string m_prefix;
uint64_t m_timeout;
void sync() {
m_client.sync_commit<double, std::milli>(std::chrono::milliseconds(m_timeout));
mgb_assert(valid());
}
};
class ExtendedInFilePersistentCache final : public ExtendedPersistentCache {
private:
std::string m_path;
std::unique_ptr<mgb::InFilePersistentCache> m_impl;
public:
ExtendedInFilePersistentCache() = default;
bool open(std::string path) {
std::fstream file;
file.open(path, std::ios::in | std::ios::binary);
if (!file.is_open()) {
return false;
}
std::vector<char> bytes = {
std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>()};
if (bytes.size()) {
m_impl = std::make_unique<mgb::InFilePersistentCache>(
(const uint8_t*)bytes.data(), bytes.size());
} else {
m_impl = std::make_unique<mgb::InFilePersistentCache>();
}
m_path = path;
return true;
}
~ExtendedInFilePersistentCache() {
if (m_impl) {
m_impl->dump_cache(m_path.c_str());
}
}
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
return m_impl->get(category, key);
}
void put(const std::string& category, const Blob& key, const Blob& value) override {
return m_impl->put(category, key, value);
}
std::optional<size_t> clear() override {
m_impl = std::make_unique<mgb::InFilePersistentCache>();
m_impl->dump_cache(m_path.c_str());
return {};
}
bool valid() const override { return m_impl != nullptr; }
};
std::shared_ptr<ExtendedPersistentCache> make_redis(
std::string ip, size_t port, std::string password, std::string prefix) {
auto cache = std::make_shared<RedisCache>(prefix, 100);
if (!cache->connect(ip, port, password)) {
return nullptr;
}
return cache;
}
std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path) {
auto cache = std::make_shared<ExtendedInFilePersistentCache>();
if (!cache->open(path)) {
return nullptr;
}
return cache;
}
} // namespace mgb::imperative::persistent_cache
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file imperative/src/impl/base64.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/utils/base64.h"
namespace mgb::imperative {
namespace {
/*
** Translation Table as described in RFC1113
*/
const char cb64[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/*
** Translation Table to decode:
*https://github.com/dgiardini/imgcalkap/blob/master/base64.c
*/
const char cd64[] =
"|$$$}rstuvwxyz{$$$$$$$>?@ABCDEFGHIJKLMNOPQRSTUVW$$$$$$XYZ[\\]^_`"
"abcdefghijklmnopq";
/*
** encodeblock
**
** encode 3 8-bit binary bytes as 4 '6-bit' characters
*/
void encodeblock(unsigned char in[3], unsigned char out[4], int len) {
out[0] = cb64[in[0] >> 2];
out[1] = cb64[((in[0] & 0x03) << 4) | ((in[1] & 0xf0) >> 4)];
out[2] =
(unsigned char)(len > 1 ? cb64[((in[1] & 0x0f) << 2) | ((in[2] & 0xc0) >> 6)] : '=');
out[3] = (unsigned char)(len > 2 ? cb64[in[2] & 0x3f] : '=');
}
/*
** decodeblock
**
** decode 4 '6-bit' characters into 3 8-bit binary bytes
*/
void decodeblock(unsigned char in[4], unsigned char out[3]) {
out[0] = (unsigned char)(in[0] << 2 | in[1] >> 4);
out[1] = (unsigned char)(in[1] << 4 | in[2] >> 2);
out[2] = (unsigned char)(((in[2] << 6) & 0xc0) | in[3]);
}
} // namespace
/**
* Encode string to base64 string
* @param input - source string
* @param outdata - target base64 string
* @param linesize - max size of line
*/
void encode(
const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata,
int linesize) {
outdata.clear();
unsigned char in[3], out[4];
int i, len, blocksout = 0;
size_t j = 0;
auto* indata = reinterpret_cast<const unsigned char*>(input.data());
unsigned int insize = input.size();
while (j <= insize) {
len = 0;
for (i = 0; i < 3; i++) {
in[i] = (unsigned char)indata[j];
j++;
if (j <= insize) {
len++;
} else {
in[i] = 0;
}
}
if (len) {
encodeblock(in, out, len);
for (i = 0; i < 4; i++) {
outdata.push_back(out[i]);
}
blocksout++;
}
if (blocksout >= (linesize / 4) || (j == insize)) {
if (blocksout) {
outdata.push_back('\r');
outdata.push_back('\n');
}
blocksout = 0;
}
}
}
/**
* Decode base64 string ot source
* @param input - base64 string
* @param outdata - source string
*/
void decode(
const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata) {
outdata.clear();
unsigned char in[4], out[3], v;
int i, len;
size_t j = 0;
auto* indata = reinterpret_cast<const unsigned char*>(input.data());
unsigned int insize = input.size();
while (j <= insize) {
for (len = 0, i = 0; i < 4 && (j <= insize); i++) {
v = 0;
while ((j <= insize) && v == 0) {
v = (unsigned char)indata[j++];
v = (unsigned char)((v < 43 || v > 122) ? 0 : cd64[v - 43]);
if (v) {
v = (unsigned char)((v == '$') ? 0 : v - 61);
}
}
if (j <= insize) {
len++;
if (v) {
in[i] = (unsigned char)(v - 1);
}
} else {
in[i] = 0;
}
}
if (len) {
decodeblock(in, out);
for (i = 0; i < len - 1; i++) {
outdata.push_back(out[i]);
}
}
}
}
/**
* Encode binary data to base64 buffer
* @param input - source data
* @param outdata - target base64 buffer
* @param linesize
*/
void encode(const std::string& input, std::string& outdata, int linesize) {
std::vector<std::uint8_t> out;
std::vector<std::uint8_t> in(input.begin(), input.end());
encode(in, out, linesize);
outdata = std::string(out.begin(), out.end());
}
/**
* Decode base64 buffer to source binary data
* @param input - base64 buffer
* @param outdata - source binary data
*/
void decode(const std::string& input, std::string& outdata) {
std::vector<std::uint8_t> in(input.begin(), input.end());
std::vector<std::uint8_t> out;
decode(in, out);
outdata = std::string(out.begin(), out.end());
}
} // namespace mgb::imperative
/**
* \file imperative/src/include/megbrain/imperative/persistent_cache.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <memory>
#include "megbrain/utils/persistent_cache.h"
namespace mgb::imperative::persistent_cache {
class ExtendedPersistentCache : public mgb::PersistentCache {
public:
virtual bool valid() const = 0;
virtual std::optional<size_t> clear() = 0;
};
std::shared_ptr<ExtendedPersistentCache> make_redis(
std::string ip, size_t port, std::string password, std::string prefix);
std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path);
} // namespace mgb::imperative::persistent_cache
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file imperative/src/include/megbrain/imperative/utils/base64.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/common.h"
namespace mgb::imperative {
/**
* Encode string to base64 string
* @param input - source string
* @param outdata - target base64 string
* @param linesize - max size of line
*/
void encode(
const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata,
int linesize = 76);
/**
* Decode base64 string ot source
* @param input - base64 string
* @param outdata - source string
*/
void decode(const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata);
/**
* Encode binary data to base64 buffer
* @param input - source data
* @param outdata - target base64 buffer
* @param linesize
*/
void encode(const std::string& input, std::string& outdata, int linesize = 76);
/**
* Decode base64 buffer to source binary data
* @param input - base64 buffer
* @param outdata - source binary data
*/
void decode(const std::string& input, std::string& outdata);
} // namespace mgb::imperative
...@@ -12,7 +12,7 @@ endif() ...@@ -12,7 +12,7 @@ endif()
# TODO: turn python binding into a static/object library # TODO: turn python binding into a static/object library
add_executable(imperative_test ${SOURCES} ${SRCS}) add_executable(imperative_test ${SOURCES} ${SRCS})
add_dependencies(imperative_test mgb_opdef) add_dependencies(imperative_test mgb_opdef)
target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR}) target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES})
# Python binding # Python binding
target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR}) target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR})
......
...@@ -74,14 +74,19 @@ class InMemoryPersistentCache final : public PersistentCache { ...@@ -74,14 +74,19 @@ class InMemoryPersistentCache final : public PersistentCache {
}; };
}; };
Maybe<Blob> get(const std::string& category, const Blob& key) override; MGE_WIN_DECLSPEC_FUC Maybe<Blob> get(
void put(const std::string& category, const Blob& key, const Blob& value) override; const std::string& category, const Blob& key) override;
MGE_WIN_DECLSPEC_FUC void put(
const std::string& category, const Blob& key, const Blob& value) override;
std::unordered_map< std::unordered_map<
std::string, std::string,
std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>> std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>>
m_cache; m_cache;
MGB_MUTEX m_mtx; MGB_MUTEX m_mtx;
public:
MGE_WIN_DECLSPEC_FUC InMemoryPersistentCache() = default;
}; };
/*! /*!
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册