diff --git a/imperative/python/src/utils.cpp b/imperative/python/src/utils.cpp index 3852a95cddc99860ca1e58beec68131d1eae104e..e43a8ac57c27be67e182f7c45f09e54198bd995d 100644 --- a/imperative/python/src/utils.cpp +++ b/imperative/python/src/utils.cpp @@ -237,15 +237,67 @@ void init_utils(py::module m) { mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str()); }); using mgb::PersistentCache; - class PyPersistentCache: public mgb::PersistentCache{ + class PyPersistentCache: public mgb::PersistentCache { + private: + using KeyPair = std::pair; + using BlobPtr = std::unique_ptr; + + static size_t hash_key_pair(const KeyPair& kp) { + std::hash hasher; + return hasher(kp.first) ^ hasher(kp.second); + } + + std::string blob_to_str(const Blob& key) { + return std::string(reinterpret_cast(key.ptr), key.size); + } + + BlobPtr copy_blob(const Blob& blob) { + auto blob_deleter = [](Blob* blob){ + if (blob) { + std::free(const_cast(blob->ptr)); + delete blob; + } + }; + auto blob_ptr = BlobPtr{ new Blob(), blob_deleter }; + blob_ptr->ptr = std::malloc(blob.size); + std::memcpy(const_cast(blob_ptr->ptr), blob.ptr, blob.size); + blob_ptr->size = blob.size; + return blob_ptr; + } + + BlobPtr str_to_blob(const std::string& str) { + auto blob = Blob{ str.data(), str.size() }; + return copy_blob(blob); + } + + std::unique_ptr empty_blob() { + return BlobPtr{ nullptr, [](Blob* blob){} }; + } public: mgb::Maybe get(const std::string& category, const Blob& key) override { - PYBIND11_OVERLOAD_PURE(mgb::Maybe, PersistentCache, get, category, key); + thread_local std::unordered_map m_local_cache; + auto py_get = [this](const std::string& category, const Blob& key) -> mgb::Maybe { + PYBIND11_OVERLOAD_PURE(mgb::Maybe, PersistentCache, get, category, key); + }; + KeyPair kp = { category, blob_to_str(key) }; + auto iter = m_local_cache.find(kp); + if (iter == m_local_cache.end()) { + auto py_ret = py_get(category, key); + if (!py_ret.valid()) { + iter = m_local_cache.insert({kp, empty_blob()}).first; + } else { + iter = m_local_cache.insert({kp, copy_blob(py_ret.val())}).first; + } + } + if (iter->second) { + return *iter->second; + } else { + return {}; + } } void put(const std::string& category, const Blob& key, const Blob& value) override { PYBIND11_OVERLOAD_PURE(void, PersistentCache, put, category, key, value); } - }; py::class_>(m, "PersistentCache") .def(py::init<>())