提交 ab6328c5 编写于 作者: M Megvii Engine Team

feat(imperative): port persistent cache

GitOrigin-RevId: 8ca24a37cc28a0be3f659e0e8863fee1beac3a38
上级 60c6d59f
......@@ -78,13 +78,18 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
from .serialization import load, save
from .tensor import Parameter, Tensor, tensor
from .version import __version__
from .utils import persistent_cache, comp_graph_tools as cgtools
_set_fork_exec_path_for_timed_func(
sys.executable,
os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"),
)
_persistent_cache_impl_ins = persistent_cache.PersistentCacheOnServer()
_persistent_cache_impl_ins.reg()
atexit.register(sync)
del sync
del _set_fork_exec_path_for_timed_func
del _persistent_cache_impl_ins
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse
import getpass
import json
import os
import shelve
from ..core._imperative_rt import PersistentCache as _PersistentCache
from ..logger import get_logger
from ..version import __version__
class _FakeRedisConn:
def __init__(self):
try:
from ..hub.hub import _get_megengine_home
cache_dir = os.path.expanduser(
os.path.join(_get_megengine_home(), "persistent_cache")
)
os.makedirs(cache_dir, exist_ok=True)
cache_file = os.path.join(cache_dir, "cache")
self._dict = shelve.open(cache_file)
self._is_shelve = True
except:
self._dict = {}
self._is_shelve = False
def get(self, key):
if self._is_shelve and isinstance(key, bytes):
key = key.decode("utf-8")
return self._dict.get(key)
def set(self, key, val):
if self._is_shelve and isinstance(key, bytes):
key = key.decode("utf-8")
self._dict[key] = val
def __del__(self):
if self._is_shelve:
self._dict.close()
class PersistentCacheOnServer(_PersistentCache):
_cached_conn = None
_prefix = None
_prev_get_refkeep = None
@property
def _conn(self):
"""get redis connection"""
if self._cached_conn is None:
self._cached_conn = _FakeRedisConn()
self._prefix = self.make_user_prefix()
return self._cached_conn
@classmethod
def make_user_prefix(cls):
return "mgbcache:{}".format(getpass.getuser())
def _make_key(self, category, key):
prefix_with_version = "{}:MGB{}".format(self._prefix, __version__)
return b"@".join(
(prefix_with_version.encode("ascii"), category.encode("ascii"), key)
)
def put(self, category, key, value):
conn = self._conn
key = self._make_key(category, key)
conn.set(key, value)
def get(self, category, key):
conn = self._conn
key = self._make_key(category, key)
self._prev_get_refkeep = conn.get(key)
return self._prev_get_refkeep
......@@ -12,6 +12,7 @@
#pragma once
#include "megbrain/graph.h"
#include "megbrain/utils/persistent_cache.h"
#include <Python.h>
#include <string>
......@@ -328,6 +329,49 @@ namespace detail {
template<> struct type_caster<mgb::CompNode> : public from_none_caster<mgb::CompNode> {};
template <> struct type_caster<mgb::PersistentCache::Blob> {
PYBIND11_TYPE_CASTER(mgb::PersistentCache::Blob, _("Blob"));
public:
bool load(handle src, bool convert) {
if (!isinstance<bytes>(src)) {
return false;
}
value.ptr = PYBIND11_BYTES_AS_STRING(src.ptr());
value.size = PYBIND11_BYTES_SIZE(src.ptr());
return true;
}
static handle cast(mgb::PersistentCache::Blob blob, return_value_policy /* policy */, handle /* parent */) {
return bytes((const char*)blob.ptr, blob.size);
}
};
template <typename T> struct type_caster<mgb::Maybe<T>> {
using value_conv = make_caster<T>;
PYBIND11_TYPE_CASTER(mgb::Maybe<T>, _("Optional[") + value_conv::name + _("]"));
public:
bool load(handle src, bool convert) {
if(!src) {
return false;
}
if (src.is_none()) {
return true;
}
value_conv inner_caster;
if (!inner_caster.load(src, convert)) {
return false;
}
value.emplace(cast_op<T&&>(std::move(inner_caster)));
return true;
}
static handle cast(mgb::Maybe<T> src, return_value_policy policy, handle parent) {
if(!src.valid()) {
return none().inc_ref();
}
return pybind11::cast(src.val(), policy, parent);
}
};
} // detail
} // PYBIND11_NAMESPACE
......
......@@ -25,6 +25,7 @@
#include "megbrain/imperative/profiler.h"
#include "megbrain/imperative/tensor_sanity_check.h"
#include "megbrain/serialization/helper.h"
#include "megbrain/utils/persistent_cache.h"
#if MGB_ENABLE_OPR_MM
#include "megbrain/opr/mm_handler.h"
......@@ -262,4 +263,20 @@ void init_utils(py::module m) {
m.def("_timed_func_exec_cb", [](const std::string& user_data){
mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str());
});
using mgb::PersistentCache;
class PyPersistentCache: public mgb::PersistentCache{
public:
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
PYBIND11_OVERLOAD_PURE(mgb::Maybe<Blob>, PersistentCache, get, category, key);
}
void put(const std::string& category, const Blob& key, const Blob& value) override {
PYBIND11_OVERLOAD_PURE(void, PersistentCache, put, category, key, value);
}
};
py::class_<PersistentCache, PyPersistentCache, std::shared_ptr<PersistentCache>>(m, "PersistentCache")
.def(py::init<>())
.def("get", &PersistentCache::get)
.def("put", &PersistentCache::put)
.def("reg", &PersistentCache::set_impl);
}
import pytest
import megengine
from megengine.utils.persistent_cache import PersistentCacheOnServer
def test_persistent_cache():
pc = PersistentCacheOnServer()
k0 = b"\x00\x00"
k1 = b"\x00\x01"
cat = "test"
pc.put(cat, k0, k1)
pc.put(cat, k1, k0)
assert k1 == pc.get(cat, k0)
assert k0 == pc.get(cat, k1)
assert pc.get("test1", k0) == None
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册