提交 0ac642b5 编写于 作者: M Megvii Engine Team

fix(imperative): persistent cache write through on put

GitOrigin-RevId: f9408ae5046a28b06b6de914f430a15c78ab82d9
上级 47dcdf3e
......@@ -18,6 +18,7 @@
#include <pybind11/operators.h>
#include <atomic>
#include <cstdint>
#include <shared_mutex>
#include "./imperative_rt.h"
#include "megbrain/common.h"
#include "megbrain/comp_node.h"
......@@ -236,12 +237,16 @@ void init_utils(py::module m) {
m.def("_timed_func_exec_cb", [](const std::string& user_data){
mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str());
});
using mgb::PersistentCache;
class PyPersistentCache: public mgb::PersistentCache {
private:
using KeyPair = std::pair<std::string, std::string>;
using BlobPtr = std::unique_ptr<Blob, void(*)(Blob*)>;
std::shared_mutex m_mutex;
std::unordered_map<KeyPair, BlobPtr, mgb::pairhash> m_local_cache;
static size_t hash_key_pair(const KeyPair& kp) {
std::hash<std::string> hasher;
return hasher(kp.first) ^ hasher(kp.second);
......@@ -275,11 +280,11 @@ void init_utils(py::module m) {
}
public:
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
thread_local std::unordered_map<KeyPair, BlobPtr, mgb::pairhash> m_local_cache;
auto py_get = [this](const std::string& category, const Blob& key) -> mgb::Maybe<Blob> {
PYBIND11_OVERLOAD_PURE(mgb::Maybe<Blob>, PersistentCache, get, category, key);
};
KeyPair kp = { category, blob_to_str(key) };
std::shared_lock<decltype(m_mutex)> rlock;
auto iter = m_local_cache.find(kp);
if (iter == m_local_cache.end()) {
auto py_ret = py_get(category, key);
......@@ -296,6 +301,9 @@ void init_utils(py::module m) {
}
}
void put(const std::string& category, const Blob& key, const Blob& value) override {
KeyPair kp = { category, blob_to_str(key) };
std::unique_lock<decltype(m_mutex)> wlock;
m_local_cache.insert_or_assign(kp, copy_blob(value));
PYBIND11_OVERLOAD_PURE(void, PersistentCache, put, category, key, value);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册