提交 9f5bb42d 编写于 作者: M Megvii Engine Team

perf(fastrun): cache persistent cache

GitOrigin-RevId: 21f7d4c19d5e324555412d2e5c6ea4ac2908acfc
上级 6ab1c55d
......@@ -237,15 +237,67 @@ void init_utils(py::module m) {
mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str());
});
using mgb::PersistentCache;
class PyPersistentCache: public mgb::PersistentCache{
class PyPersistentCache: public mgb::PersistentCache {
private:
using KeyPair = std::pair<std::string, std::string>;
using BlobPtr = std::unique_ptr<Blob, void(*)(Blob*)>;
static size_t hash_key_pair(const KeyPair& kp) {
std::hash<std::string> hasher;
return hasher(kp.first) ^ hasher(kp.second);
}
std::string blob_to_str(const Blob& key) {
return std::string(reinterpret_cast<const char*>(key.ptr), key.size);
}
BlobPtr copy_blob(const Blob& blob) {
auto blob_deleter = [](Blob* blob){
if (blob) {
std::free(const_cast<void*>(blob->ptr));
delete blob;
}
};
auto blob_ptr = BlobPtr{ new Blob(), blob_deleter };
blob_ptr->ptr = std::malloc(blob.size);
std::memcpy(const_cast<void*>(blob_ptr->ptr), blob.ptr, blob.size);
blob_ptr->size = blob.size;
return blob_ptr;
}
BlobPtr str_to_blob(const std::string& str) {
auto blob = Blob{ str.data(), str.size() };
return copy_blob(blob);
}
std::unique_ptr<Blob, void(*)(Blob*)> empty_blob() {
return BlobPtr{ nullptr, [](Blob* blob){} };
}
public:
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
thread_local std::unordered_map<KeyPair, BlobPtr, mgb::pairhash> m_local_cache;
auto py_get = [this](const std::string& category, const Blob& key) -> mgb::Maybe<Blob> {
PYBIND11_OVERLOAD_PURE(mgb::Maybe<Blob>, PersistentCache, get, category, key);
};
KeyPair kp = { category, blob_to_str(key) };
auto iter = m_local_cache.find(kp);
if (iter == m_local_cache.end()) {
auto py_ret = py_get(category, key);
if (!py_ret.valid()) {
iter = m_local_cache.insert({kp, empty_blob()}).first;
} else {
iter = m_local_cache.insert({kp, copy_blob(py_ret.val())}).first;
}
}
if (iter->second) {
return *iter->second;
} else {
return {};
}
}
void put(const std::string& category, const Blob& key, const Blob& value) override {
PYBIND11_OVERLOAD_PURE(void, PersistentCache, put, category, key, value);
}
};
py::class_<PersistentCache, PyPersistentCache, std::shared_ptr<PersistentCache>>(m, "PersistentCache")
.def(py::init<>())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册