提交 cf1db261 编写于 作者: M Megvii Engine Team

fix(fastrun): replace py_redis with cpp_redis to avoid deadlock

GitOrigin-RevId: 9af7fa5c97d401c51bfd02490887222ec8341098
上级 b75d1009
......@@ -877,6 +877,8 @@ if(MGE_WITH_JIT AND MGE_WITH_HALIDE)
include(cmake/Halide.cmake)
endif()
include(cmake/cpp_redis.cmake)
# Thread
IF(APPLE)
set(CMAKE_THREAD_LIBS_INIT "-lpthread")
......
file(GLOB_RECURSE CPP_REDIS_SRCS ${PROJECT_SOURCE_DIR}/third_party/cpp_redis/sources/*.cpp ${PROJECT_SOURCE_DIR}/third_party/tacopie/sources/*.cpp)
set(CPP_REDIS_INCLUDES ${PROJECT_SOURCE_DIR}/third_party/cpp_redis/includes ${PROJECT_SOURCE_DIR}/third_party/tacopie/includes)
\ No newline at end of file
......@@ -5,6 +5,7 @@ set(PACKAGE_NAME ${PACKAGE_NAME} PARENT_SCOPE)
set(MODULE_NAME _imperative_rt)
set(MODULE_NAME ${MODULE_NAME} PARENT_SCOPE)
file(GLOB_RECURSE SRCS src/impl/*.cpp src/include/*.h python/src/*.cpp python/src/*.h)
set(SRCS ${SRCS} ${CPP_REDIS_SRCS})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_WITH_IMPERATIVE=1")
......@@ -42,7 +43,7 @@ target_link_libraries(${MODULE_NAME} PRIVATE range-v3)
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/Json ${PROJECT_BINARY_DIR}/third_party/Json)
target_link_libraries(${MODULE_NAME} PRIVATE nlohmann_json::nlohmann_json)
target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR})
target_include_directories(${MODULE_NAME} PUBLIC src/include PRIVATE ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR} ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES})
target_compile_definitions(${MODULE_NAME} PRIVATE MODULE_NAME=${MODULE_NAME})
target_compile_options(${MODULE_NAME} PRIVATE -Wno-unused-parameter)
if(CXX_SUPPORT_WCLASS_MEMACCESS)
......
......@@ -92,9 +92,6 @@ _set_fork_exec_path_for_timed_func(
os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"),
)
_persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer()
_persistent_cache_impl_ins.reg()
atexit.register(_close)
del _set_fork_exec_path_for_timed_func
......@@ -135,3 +132,5 @@ import megengine.quantization
import megengine.random
import megengine.utils
import megengine.traced_module
persistent_cache.get_manager()
......@@ -9,108 +9,54 @@
import argparse
import getpass
import json
import os
import shelve
import sys
from ..core._imperative_rt import PersistentCache as _PersistentCache
from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager
from ..logger import get_logger
from ..version import __version__, git_version
class _FakeRedisConn:
_cache_dir = None
_is_shelve = False
_dict = {}
class PersistentCacheManager(_PersistentCacheManager):
def __init__(self):
super().__init__()
if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY":
self._dict = {}
self._is_shelve = False
get_logger().info("fastrun use in-memory cache")
self.open_memory()
else:
try:
self._cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR")
if not self._cache_dir:
from ..hub.hub import _get_megengine_home
self._cache_dir = os.path.expanduser(
os.path.join(_get_megengine_home(), "persistent_cache")
)
os.makedirs(self._cache_dir, exist_ok=True)
cache_file = os.path.join(self._cache_dir, "cache")
self._dict = shelve.open(cache_file)
self._is_shelve = True
get_logger().info(
"fastrun use in-file cache in {}".format(self._cache_dir)
)
except Exception as exc:
self._dict = {}
self._is_shelve = False
get_logger().error(
"failed to create cache file in {} {!r}; fallback to "
"in-memory cache".format(self._cache_dir, exc)
)
def get(self, key):
if self._is_shelve and isinstance(key, bytes):
key = key.decode("utf-8")
return self._dict.get(key)
def set(self, key, val):
if self._is_shelve and isinstance(key, bytes):
key = key.decode("utf-8")
self._dict[key] = val
def clear(self):
print("{} cache item deleted in {}".format(len(self._dict), self._cache_dir))
self._dict.clear()
self.open_file()
def __del__(self):
if self._is_shelve:
self._dict.close()
def open_memory(self):
pass
def open_file(self):
cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR")
try:
if not cache_dir:
from ..hub.hub import _get_megengine_home
class PersistentCacheOnServer(_PersistentCache):
_cached_conn = None
_prefix = None
_prev_get_refkeep = None
@property
def _conn(self):
"""get redis connection"""
if self._cached_conn is None:
self._cached_conn = _FakeRedisConn()
self._prefix = self.make_user_prefix()
return self._cached_conn
@classmethod
def make_user_prefix(cls):
return "mgbcache:{}".format(getpass.getuser())
def _make_key(self, category, key):
prefix_with_version = "{}:MGB{}:GIT:{}".format(
self._prefix, __version__, git_version
)
return b"@".join(
(prefix_with_version.encode("ascii"), category.encode("ascii"), key)
)
def put(self, category, key, value):
conn = self._conn
key = self._make_key(category, key)
conn.set(key, value)
def get(self, category, key):
conn = self._conn
key = self._make_key(category, key)
self._prev_get_refkeep = conn.get(key)
return self._prev_get_refkeep
def clean(self):
conn = self._conn
if isinstance(conn, _FakeRedisConn):
conn.clear()
cache_dir = os.path.expanduser(
os.path.join(_get_megengine_home(), "persistent_cache.bin")
)
os.makedirs(cache_dir, exist_ok=True)
cache_file = os.path.join(cache_dir, "cache")
with open(cache_file, "a"):
pass
assert self.try_open_file(cache_file), "cannot create file"
get_logger().info("fastrun use in-file cache in {}".format(cache_dir))
except Exception as exc:
get_logger().error(
"failed to create cache file in {} {!r}; fallback to "
"in-memory cache".format(cache_dir, exc)
)
self.open_memory()
_manager = None
def get_manager():
global _manager
if _manager is None:
_manager = PersistentCacheManager()
return _manager
......@@ -23,6 +23,7 @@
#include "megbrain/common.h"
#include "megbrain/comp_node.h"
#include "megbrain/imperative/blob_manager.h"
#include "megbrain/imperative/persistent_cache.h"
#include "megbrain/imperative/profiler.h"
#include "megbrain/imperative/tensor_sanity_check.h"
#include "megbrain/serialization/helper.h"
......@@ -229,83 +230,55 @@ void init_utils(py::module m) {
mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str());
});
using mgb::PersistentCache;
class PyPersistentCache : public mgb::PersistentCache {
private:
using KeyPair = std::pair<std::string, std::string>;
using BlobPtr = std::unique_ptr<Blob, void (*)(Blob*)>;
using PersistentCache = mgb::PersistentCache;
using ExtendedPersistentCache =
mgb::imperative::persistent_cache::ExtendedPersistentCache;
std::shared_mutex m_mutex;
std::unordered_map<KeyPair, BlobPtr, mgb::pairhash> m_local_cache;
struct PersistentCacheManager {
std::shared_ptr<ExtendedPersistentCache> instance;
static size_t hash_key_pair(const KeyPair& kp) {
std::hash<std::string> hasher;
return hasher(kp.first) ^ hasher(kp.second);
bool try_reg(std::shared_ptr<ExtendedPersistentCache> cache) {
if (cache) {
instance = cache;
PersistentCache::set_impl(cache);
return true;
}
return false;
}
std::string blob_to_str(const Blob& key) {
return std::string(reinterpret_cast<const char*>(key.ptr), key.size);
bool open_redis(
std::string ip, size_t port, std::string password, std::string prefix) {
return try_reg(mgb::imperative::persistent_cache::make_redis(
ip, port, password, prefix));
}
BlobPtr copy_blob(const Blob& blob) {
auto blob_deleter = [](Blob* blob) {
if (blob) {
std::free(const_cast<void*>(blob->ptr));
delete blob;
}
};
auto blob_ptr = BlobPtr{new Blob(), blob_deleter};
blob_ptr->ptr = std::malloc(blob.size);
std::memcpy(const_cast<void*>(blob_ptr->ptr), blob.ptr, blob.size);
blob_ptr->size = blob.size;
return blob_ptr;
bool open_file(std::string path) {
return try_reg(mgb::imperative::persistent_cache::make_in_file(path));
}
BlobPtr str_to_blob(const std::string& str) {
auto blob = Blob{str.data(), str.size()};
return copy_blob(blob);
std::optional<size_t> clean() {
if (instance) {
return instance->clear();
}
return {};
}
std::unique_ptr<Blob, void (*)(Blob*)> empty_blob() {
return BlobPtr{nullptr, [](Blob* blob) {}};
void put(std::string category, std::string key, std::string value) {
PersistentCache::inst().put(
category, {key.data(), key.size()}, {value.data(), value.size()});
}
public:
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
auto py_get = [this](const std::string& category,
const Blob& key) -> mgb::Maybe<Blob> {
PYBIND11_OVERLOAD_PURE(
mgb::Maybe<Blob>, PersistentCache, get, category, key);
};
KeyPair kp = {category, blob_to_str(key)};
std::shared_lock<decltype(m_mutex)> rlock;
auto iter = m_local_cache.find(kp);
if (iter == m_local_cache.end()) {
auto py_ret = py_get(category, key);
if (!py_ret.valid()) {
iter = m_local_cache.insert({kp, empty_blob()}).first;
} else {
iter = m_local_cache.insert({kp, copy_blob(py_ret.val())}).first;
}
}
if (iter->second) {
return *iter->second;
py::object get(std::string category, std::string key) {
auto value =
PersistentCache::inst().get(category, {key.data(), key.size()});
if (value.valid()) {
return py::bytes(std::string((const char*)value->ptr, value->size));
} else {
return {};
return py::none();
}
}
void put(const std::string& category, const Blob& key, const Blob& value)
override {
KeyPair kp = {category, blob_to_str(key)};
std::unique_lock<decltype(m_mutex)> wlock;
m_local_cache.insert_or_assign(kp, copy_blob(value));
PYBIND11_OVERLOAD_PURE(void, PersistentCache, put, category, key, value);
}
};
py::class_<PersistentCache, PyPersistentCache, std::shared_ptr<PersistentCache>>(
m, "PersistentCache")
py::class_<PersistentCacheManager>(m, "PersistentCacheManager")
.def(py::init<>())
.def("get", &PersistentCache::get)
.def("put", &PersistentCache::put)
.def("reg", &PersistentCache::set_impl);
.def("try_open_redis", &PersistentCacheManager::open_redis)
.def("try_open_file", &PersistentCacheManager::open_file)
.def("clean", &PersistentCacheManager::clean)
.def("put", &PersistentCacheManager::put)
.def("get", &PersistentCacheManager::get);
}
import pytest
import megengine
from megengine.utils.persistent_cache import PersistentCacheOnServer
from megengine.utils.persistent_cache import _manager
@pytest.mark.skip(reason="fixme: github ci failed")
def test_persistent_cache():
pc = PersistentCacheOnServer()
pc = _manager
k0 = b"\x00\x00"
k1 = b"\x00\x01"
cat = "test"
......
/**
* \file imperative/src/impl/persistent_cache.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include <fstream>
#include <string>
#include <vector>
#include "cpp_redis/cpp_redis"
#include "megbrain/imperative/persistent_cache.h"
#include "megbrain/imperative/utils/base64.h"
#include "megbrain/utils/infile_persistent_cache.h"
namespace mgb::imperative::persistent_cache {
class RedisCache final : public ExtendedPersistentCache {
public:
RedisCache(std::string prefix, uint64_t timeout) : m_prefix(prefix) {
m_local = std::make_shared<mgb::InMemoryPersistentCache>();
}
bool connect(std::string ip, size_t port, std::string password) {
m_client.auth(password);
m_client.connect(
ip, port,
[](const std::string& host, std::size_t port,
cpp_redis::connect_state status) {
if (status == cpp_redis::connect_state::dropped) {
mgb_log("client disconnected from %s.", host.c_str());
mgb_log("Redis server connect to %s :%zu failed.", host.c_str(),
port);
}
},
std::uint32_t(200));
if (!m_client.is_connected()) {
return false;
}
auto flag = m_client.get("mgb-cache-flag");
sync();
return flag.get().ok();
}
bool valid() const override { return m_client.is_connected(); }
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
MGB_LOCK_GUARD(m_mtx);
auto mem_result = m_local->get(category, key);
if (mem_result.valid())
return mem_result;
std::string key_str(static_cast<const char*>(key.ptr), key.size);
std::string redis_key_str;
encode(category + '@' + key_str, redis_key_str, 24);
auto result = m_client.get(redis_key_str);
sync();
auto content = result.get();
if (content.is_null())
return mgb::None;
std::string decode_content;
decode(content.as_string(), decode_content);
m_local->put(category, key, {decode_content.data(), decode_content.length()});
return m_local->get(category, key);
}
void put(const std::string& category, const Blob& key, const Blob& value) override {
MGB_LOCK_GUARD(m_mtx);
std::string key_str(static_cast<const char*>(key.ptr), key.size);
std::string redis_key_str;
encode(category + '@' + key_str, redis_key_str);
std::string value_str(static_cast<const char*>(value.ptr), value.size);
std::string redis_value_str;
encode(value_str, redis_value_str);
auto result = m_client.set(redis_key_str, redis_value_str);
m_local->put(category, key, value);
sync();
}
std::optional<size_t> clear() override {
size_t cursor = 0, nr_deleted = 0;
std::string pattern = m_prefix + "@*";
do {
auto reply = m_client.scan(cursor, pattern).share();
sync();
auto keys = reply.get().as_array();
std::vector<std::string> string_keys;
for (auto&& key : keys) {
string_keys.push_back(key.as_string());
}
m_client.del(string_keys);
nr_deleted += string_keys.size();
cursor = reply.get().as_array()[0].as_integer();
} while (cursor != 0);
return nr_deleted;
}
private:
std::shared_ptr<mgb::PersistentCache> m_local;
std::mutex m_mtx;
cpp_redis::client m_client;
std::string m_prefix;
uint64_t m_timeout;
void sync() {
m_client.sync_commit<double, std::milli>(std::chrono::milliseconds(m_timeout));
mgb_assert(valid());
}
};
class ExtendedInFilePersistentCache final : public ExtendedPersistentCache {
private:
std::string m_path;
std::unique_ptr<mgb::InFilePersistentCache> m_impl;
public:
ExtendedInFilePersistentCache() = default;
bool open(std::string path) {
std::fstream file;
file.open(path, std::ios::in | std::ios::binary);
if (!file.is_open()) {
return false;
}
std::vector<char> bytes = {
std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>()};
if (bytes.size()) {
m_impl = std::make_unique<mgb::InFilePersistentCache>(
(const uint8_t*)bytes.data(), bytes.size());
} else {
m_impl = std::make_unique<mgb::InFilePersistentCache>();
}
m_path = path;
return true;
}
~ExtendedInFilePersistentCache() {
if (m_impl) {
m_impl->dump_cache(m_path.c_str());
}
}
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
return m_impl->get(category, key);
}
void put(const std::string& category, const Blob& key, const Blob& value) override {
return m_impl->put(category, key, value);
}
std::optional<size_t> clear() override {
m_impl = std::make_unique<mgb::InFilePersistentCache>();
m_impl->dump_cache(m_path.c_str());
return {};
}
bool valid() const override { return m_impl != nullptr; }
};
std::shared_ptr<ExtendedPersistentCache> make_redis(
std::string ip, size_t port, std::string password, std::string prefix) {
auto cache = std::make_shared<RedisCache>(prefix, 100);
if (!cache->connect(ip, port, password)) {
return nullptr;
}
return cache;
}
std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path) {
auto cache = std::make_shared<ExtendedInFilePersistentCache>();
if (!cache->open(path)) {
return nullptr;
}
return cache;
}
} // namespace mgb::imperative::persistent_cache
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file imperative/src/impl/base64.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/utils/base64.h"
namespace mgb::imperative {
namespace {
/*
** Translation Table as described in RFC1113
*/
const char cb64[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/*
** Translation Table to decode:
*https://github.com/dgiardini/imgcalkap/blob/master/base64.c
*/
const char cd64[] =
"|$$$}rstuvwxyz{$$$$$$$>?@ABCDEFGHIJKLMNOPQRSTUVW$$$$$$XYZ[\\]^_`"
"abcdefghijklmnopq";
/*
** encodeblock
**
** encode 3 8-bit binary bytes as 4 '6-bit' characters
*/
void encodeblock(unsigned char in[3], unsigned char out[4], int len) {
out[0] = cb64[in[0] >> 2];
out[1] = cb64[((in[0] & 0x03) << 4) | ((in[1] & 0xf0) >> 4)];
out[2] =
(unsigned char)(len > 1 ? cb64[((in[1] & 0x0f) << 2) | ((in[2] & 0xc0) >> 6)] : '=');
out[3] = (unsigned char)(len > 2 ? cb64[in[2] & 0x3f] : '=');
}
/*
** decodeblock
**
** decode 4 '6-bit' characters into 3 8-bit binary bytes
*/
void decodeblock(unsigned char in[4], unsigned char out[3]) {
out[0] = (unsigned char)(in[0] << 2 | in[1] >> 4);
out[1] = (unsigned char)(in[1] << 4 | in[2] >> 2);
out[2] = (unsigned char)(((in[2] << 6) & 0xc0) | in[3]);
}
} // namespace
/**
* Encode string to base64 string
* @param input - source string
* @param outdata - target base64 string
* @param linesize - max size of line
*/
void encode(
const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata,
int linesize) {
outdata.clear();
unsigned char in[3], out[4];
int i, len, blocksout = 0;
size_t j = 0;
auto* indata = reinterpret_cast<const unsigned char*>(input.data());
unsigned int insize = input.size();
while (j <= insize) {
len = 0;
for (i = 0; i < 3; i++) {
in[i] = (unsigned char)indata[j];
j++;
if (j <= insize) {
len++;
} else {
in[i] = 0;
}
}
if (len) {
encodeblock(in, out, len);
for (i = 0; i < 4; i++) {
outdata.push_back(out[i]);
}
blocksout++;
}
if (blocksout >= (linesize / 4) || (j == insize)) {
if (blocksout) {
outdata.push_back('\r');
outdata.push_back('\n');
}
blocksout = 0;
}
}
}
/**
* Decode base64 string ot source
* @param input - base64 string
* @param outdata - source string
*/
void decode(
const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata) {
outdata.clear();
unsigned char in[4], out[3], v;
int i, len;
size_t j = 0;
auto* indata = reinterpret_cast<const unsigned char*>(input.data());
unsigned int insize = input.size();
while (j <= insize) {
for (len = 0, i = 0; i < 4 && (j <= insize); i++) {
v = 0;
while ((j <= insize) && v == 0) {
v = (unsigned char)indata[j++];
v = (unsigned char)((v < 43 || v > 122) ? 0 : cd64[v - 43]);
if (v) {
v = (unsigned char)((v == '$') ? 0 : v - 61);
}
}
if (j <= insize) {
len++;
if (v) {
in[i] = (unsigned char)(v - 1);
}
} else {
in[i] = 0;
}
}
if (len) {
decodeblock(in, out);
for (i = 0; i < len - 1; i++) {
outdata.push_back(out[i]);
}
}
}
}
/**
* Encode binary data to base64 buffer
* @param input - source data
* @param outdata - target base64 buffer
* @param linesize
*/
void encode(const std::string& input, std::string& outdata, int linesize) {
std::vector<std::uint8_t> out;
std::vector<std::uint8_t> in(input.begin(), input.end());
encode(in, out, linesize);
outdata = std::string(out.begin(), out.end());
}
/**
* Decode base64 buffer to source binary data
* @param input - base64 buffer
* @param outdata - source binary data
*/
void decode(const std::string& input, std::string& outdata) {
std::vector<std::uint8_t> in(input.begin(), input.end());
std::vector<std::uint8_t> out;
decode(in, out);
outdata = std::string(out.begin(), out.end());
}
} // namespace mgb::imperative
/**
* \file imperative/src/include/megbrain/imperative/persistent_cache.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <memory>
#include "megbrain/utils/persistent_cache.h"
namespace mgb::imperative::persistent_cache {
class ExtendedPersistentCache : public mgb::PersistentCache {
public:
virtual bool valid() const = 0;
virtual std::optional<size_t> clear() = 0;
};
std::shared_ptr<ExtendedPersistentCache> make_redis(
std::string ip, size_t port, std::string password, std::string prefix);
std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path);
} // namespace mgb::imperative::persistent_cache
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file imperative/src/include/megbrain/imperative/utils/base64.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/common.h"
namespace mgb::imperative {
/**
* Encode string to base64 string
* @param input - source string
* @param outdata - target base64 string
* @param linesize - max size of line
*/
void encode(
const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata,
int linesize = 76);
/**
* Decode base64 string ot source
* @param input - base64 string
* @param outdata - source string
*/
void decode(const std::vector<std::uint8_t>& input, std::vector<std::uint8_t>& outdata);
/**
* Encode binary data to base64 buffer
* @param input - source data
* @param outdata - target base64 buffer
* @param linesize
*/
void encode(const std::string& input, std::string& outdata, int linesize = 76);
/**
* Decode base64 buffer to source binary data
* @param input - base64 buffer
* @param outdata - source binary data
*/
void decode(const std::string& input, std::string& outdata);
} // namespace mgb::imperative
......@@ -12,7 +12,7 @@ endif()
# TODO: turn python binding into a static/object library
add_executable(imperative_test ${SOURCES} ${SRCS})
add_dependencies(imperative_test mgb_opdef)
target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR})
target_include_directories(imperative_test PRIVATE ${MGB_TEST_DIR}/include ../src/include ${MGB_OPDEF_OUT_DIR} ${CPP_REDIS_INCLUDES})
# Python binding
target_include_directories(imperative_test PRIVATE ${MODULE_SRC_INCLUDE} ${PYTHON_INCLUDE_DIRS} ${NUMPY_INCLUDE_DIR})
......
......@@ -74,14 +74,19 @@ class InMemoryPersistentCache final : public PersistentCache {
};
};
Maybe<Blob> get(const std::string& category, const Blob& key) override;
void put(const std::string& category, const Blob& key, const Blob& value) override;
MGE_WIN_DECLSPEC_FUC Maybe<Blob> get(
const std::string& category, const Blob& key) override;
MGE_WIN_DECLSPEC_FUC void put(
const std::string& category, const Blob& key, const Blob& value) override;
std::unordered_map<
std::string,
std::unordered_map<BlobStorage, BlobStorage, BlobStorage::Hash>>
m_cache;
MGB_MUTEX m_mtx;
public:
MGE_WIN_DECLSPEC_FUC InMemoryPersistentCache() = default;
};
/*!
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册