提交 1657b8e8 编写于 作者: M Megvii Engine Team

fix(fastrun): fix persistent_cache in redis

GitOrigin-RevId: ada5862b057dd7310e63a535874b00da882b21ba
上级 a404cd7d
......@@ -84,7 +84,7 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level
from .serialization import load, save
from .tensor import Parameter, Tensor, tensor
from .utils import comp_graph_tools as cgtools
from .utils import persistent_cache
from .utils.persistent_cache import PersistentCacheOnServer as _PersistentCacheOnServer
from .version import __version__
_set_fork_exec_path_for_timed_func(
......@@ -92,15 +92,13 @@ _set_fork_exec_path_for_timed_func(
os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"),
)
atexit.register(_close)
del _set_fork_exec_path_for_timed_func
_exit_handlers = []
def _run_exit_handlers():
for handler in _exit_handlers:
for handler in reversed(_exit_handlers):
handler()
_exit_handlers.clear()
......@@ -117,6 +115,13 @@ def _atexit(handler):
_exit_handlers.append(handler)
_atexit(_close)
_persistent_cache = _PersistentCacheOnServer()
_persistent_cache.reg()
_atexit(_persistent_cache.flush)
# subpackages
import megengine.amp
import megengine.autodiff
......@@ -132,5 +137,3 @@ import megengine.quantization
import megengine.random
import megengine.utils
import megengine.traced_module
persistent_cache.get_manager()
......@@ -8,87 +8,114 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse
import contextlib
import getpass
import os
import sys
import urllib.parse
from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager
import filelock
from ..core._imperative_rt import PersistentCache as _PersistentCache
from ..logger import get_logger
from ..version import __version__, git_version
class PersistentCacheManager(_PersistentCacheManager):
class PersistentCacheOnServer(_PersistentCache):
def __init__(self):
super().__init__()
if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY":
get_logger().info("fastrun use in-memory cache")
self.open_memory()
elif os.getenv("MGE_FASTRUN_CACHE_TYPE") == "FILE":
self.open_file()
else:
self.open_redis()
def open_memory(self):
pass
cache_type = os.getenv("MGE_FASTRUN_CACHE_TYPE")
if cache_type not in ("FILE", "MEMORY"):
try:
redis_config = self.get_redis_config()
except Exception as exc:
get_logger().error(
"failed to connect to cache server {!r}; try fallback to "
"in-file cache".format(exc)
)
else:
self.add_config(
"redis",
redis_config,
"fastrun use redis cache",
"failed to connect to cache server",
)
if cache_type != "MEMORY":
path = self.get_cache_file(self.get_cache_dir())
self.add_config(
"in-file",
{"path": path},
"fastrun use in-file cache in {}".format(path),
"failed to create cache file in {}".format(path),
)
self.add_config(
"in-memory",
{},
"fastrun use in-memory cache",
"failed to create in-memory cache",
)
def open_file(self):
def get_cache_dir(self):
cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR")
try:
if not cache_dir:
from ..hub.hub import _get_megengine_home
if not cache_dir:
from ..hub.hub import _get_megengine_home
cache_dir = os.path.expanduser(
os.path.join(_get_megengine_home(), "persistent_cache.bin")
)
os.makedirs(cache_dir, exist_ok=True)
cache_file = os.path.join(cache_dir, "cache")
with open(cache_file, "a"):
pass
assert self.try_open_file(cache_file), "cannot create file"
get_logger().info("fastrun use in-file cache in {}".format(cache_dir))
except Exception as exc:
get_logger().error(
"failed to create cache file in {} {!r}; fallback to "
"in-memory cache".format(cache_dir, exc)
cache_dir = os.path.expanduser(
os.path.join(_get_megengine_home(), "persistent_cache")
)
self.open_memory()
def open_redis(self):
os.makedirs(cache_dir, exist_ok=True)
return cache_dir
def get_cache_file(self, cache_dir):
    """Return the cache file path inside ``cache_dir``, creating the file if absent.

    Args:
        cache_dir: directory that will hold the cache file. It is not created
            here, so it has to exist already.

    Returns:
        ``os.path.join(cache_dir, "cache.bin")``.

    Raises:
        OSError: if the file cannot be opened for appending.
    """
    cache_file = os.path.join(cache_dir, "cache.bin")
    # Append mode creates the file when it is missing and leaves any existing
    # content untouched; nothing is written here.
    with open(cache_file, "a"):
        pass
    return cache_file
@contextlib.contextmanager
def lock_cache_file(self, cache_dir):
    """Hold the ``cache.lock`` file lock in ``cache_dir`` for the ``with`` body.

    Args:
        cache_dir: directory in which the ``cache.lock`` file is placed.

    Yields:
        None. The lock is held while the caller's ``with`` body runs and is
        released when the body exits, including on exceptions.
    """
    lock_file = os.path.join(cache_dir, "cache.lock")
    with filelock.FileLock(lock_file):
        yield
def get_redis_config(self):
url = os.getenv("MGE_FASTRUN_CACHE_URL")
if url is None:
return None
assert sys.platform != "win32", "redis cache on windows not tested"
prefix = "mgbcache:{}:MGB{}:GIT:{}".format(
getpass.getuser(), __version__, git_version
)
url = os.getenv("MGE_FASTRUN_CACHE_URL")
if url is None:
self.open_file()
try:
assert sys.platform != "win32", "redis cache on windows not tested"
parse_result = urllib.parse.urlparse(url, scheme="redis")
assert parse_result.scheme == "redis", "unsupported scheme"
assert not parse_result.username, "redis conn with username unsupported"
assert self.try_open_redis(
parse_result.hostname, parse_result.port, parse_result.password, prefix
), "connect failed"
except Exception as exc:
get_logger().error(
"failed to connect to cache server {!r}; try fallback to "
"in-file cache".format(exc)
)
self.open_file()
_manager = None
parse_result = urllib.parse.urlparse(url)
assert not parse_result.username, "redis conn with username unsupported"
if parse_result.scheme == "redis":
assert parse_result.hostname and parse_result.port, "invalid url"
assert not parse_result.path
config = {
"hostname": parse_result.hostname,
"port": str(parse_result.port),
}
elif parse_result.scheme == "redis+socket":
assert not (parse_result.hostname or parse_result.port)
assert parse_result.path
config = {
"unixsocket": parse_result.path,
}
else:
assert False, "unsupported scheme"
if parse_result.password is not None:
config["password"] = parse_result.password
config["prefix"] = prefix
return config
def get_manager():
global _manager
if _manager is None:
_manager = PersistentCacheManager()
return _manager
def flush(self):
    """Flush the cache to disk when the active backend is the in-file one.

    ``self.config`` is ``None`` until a backend has been selected; in that
    case, and for every backend type other than ``"in-file"``, this method
    does nothing. For the in-file backend the base-class flush runs while the
    cache directory's lock file is held.
    """
    if self.config is not None and self.config.type == "in-file":
        with self.lock_cache_file(self.get_cache_dir()):
            super().flush()
def _clean():
nr_del = get_manager().clean()
nr_del = PersistentCacheOnServer().clean()
if nr_del is not None:
print("{} cache entries deleted".format(nr_del))
......
......@@ -4,8 +4,8 @@ pyarrow
requests
tabulate
tqdm
redispy
deprecated
mprop
wheel
megfile>=0.0.10
\ No newline at end of file
megfile>=0.0.10
filelock
......@@ -210,7 +210,7 @@ void init_utils(py::module m) {
.def("disable", [](TensorSanityCheck& checker) { checker.disable(); });
#if MGB_ENABLE_OPR_MM
m.def("create_mm_server", &create_zmqrpc_server, py::arg("addr"),
m.def("create_mm_server", &mgb::opr::create_zmqrpc_server, py::arg("addr"),
py::arg("port") = 0);
#else
m.def("create_mm_server", []() {});
......@@ -234,51 +234,108 @@ void init_utils(py::module m) {
using ExtendedPersistentCache =
mgb::imperative::persistent_cache::ExtendedPersistentCache;
struct PersistentCacheManager {
std::shared_ptr<ExtendedPersistentCache> instance;
struct ConfigurablePersistentCache : mgb::PersistentCache {
struct Config {
std::string type;
std::unordered_map<std::string, std::string> args;
std::string on_success;
std::string on_fail;
};
bool try_reg(std::shared_ptr<ExtendedPersistentCache> cache) {
if (cache) {
instance = cache;
PersistentCache::set_impl(cache);
return true;
}
return false;
}
bool open_redis(
std::string ip, size_t port, std::string password, std::string prefix) {
return try_reg(mgb::imperative::persistent_cache::make_redis(
ip, port, password, prefix));
std::shared_ptr<ExtendedPersistentCache> impl;
std::optional<Config> impl_config;
std::vector<Config> configs;
void add_config(
std::string type, std::unordered_map<std::string, std::string> args,
std::string on_success, std::string on_fail) {
configs.push_back({type, args, on_success, on_fail});
}
bool open_file(std::string path) {
return try_reg(mgb::imperative::persistent_cache::make_in_file(path));
std::optional<size_t> clean() { return get_impl()->clear(); }
// Try the registered configs in insertion order and adopt the first backend
// that make_from_config() manages to build. Every failed attempt is reported
// with that config's on_fail text plus the error message; each attempt after
// the first is announced as a fallback. Asserts when no config produced a cache.
// NOTE(review): Config::on_success is never read here — confirm whether the
// success message is meant to be logged.
void load_config() {
    // Empty before the first attempt; once engaged it both marks later attempts
    // as fallbacks and receives the error text from make_from_config().
    std::optional<std::string> last_error;
    for (auto& candidate : configs) {
        if (!last_error) {
            last_error.emplace();
        } else {
            mgb_log_warn("try fallback to %s cache", candidate.type.c_str());
        }
        auto built = ExtendedPersistentCache::make_from_config(
                candidate.type, candidate.args, *last_error);
        if (built) {
            impl = built;
            impl_config = candidate;
            break;
        }
        mgb_log_warn("%s %s", candidate.on_fail.c_str(), last_error->c_str());
    }
    mgb_assert(impl_config.has_value(), "not valid config");
}
std::optional<size_t> clean() {
if (instance) {
return instance->clear();
std::shared_ptr<ExtendedPersistentCache> get_impl() {
if (!impl) {
load_config();
}
return {};
return impl;
}
void put(std::string category, std::string key, std::string value) {
PersistentCache::inst().put(
category, {key.data(), key.size()}, {value.data(), value.size()});
// Look up `key` under `category` in the selected backend; the backend is
// created on first use through get_impl().
// Marked `override` so the compiler checks the signature against the base class.
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
    return get_impl()->get(category, key);
}
virtual void put(
const std::string& category, const Blob& key, const Blob& value) {
return get_impl()->put(category, key, value);
}
py::object get(std::string category, std::string key) {
auto value =
PersistentCache::inst().get(category, {key.data(), key.size()});
virtual bool support_dump_cache() { return get_impl()->support_dump_cache(); }
// Python-facing lookup: returns the stored value as `bytes`, or `None` when the
// backend has no valid entry for (category, key).
py::object py_get(std::string category, std::string key) {
    auto found = get_impl()->get(category, {key.data(), key.size()});
    if (!found.valid()) {
        return py::none();
    }
    // Copy the blob into an owning string before handing it to Python.
    std::string payload((const char*)found->ptr, found->size);
    return py::bytes(payload);
}
// Python-facing store: wraps the key and value strings as blobs (pointer +
// size) and forwards them to the selected backend.
void py_put(std::string category, std::string key, std::string value) {
    auto backend = get_impl();
    backend->put(category, {key.data(), key.size()}, {value.data(), value.size()});
}
// Flush the backend if one has already been created. Unlike get/put this does
// not go through get_impl(), so it never triggers backend creation.
void flush() {
    if (!impl) {
        return;
    }
    impl->flush();
}
};
py::class_<PersistentCacheManager>(m, "PersistentCacheManager")
.def(py::init<>())
.def("try_open_redis", &PersistentCacheManager::open_redis)
.def("try_open_file", &PersistentCacheManager::open_file)
.def("clean", &PersistentCacheManager::clean)
.def("put", &PersistentCacheManager::put)
.def("get", &PersistentCacheManager::get);
auto PyConfigurablePersistentCache =
py::class_<
ConfigurablePersistentCache,
std::shared_ptr<ConfigurablePersistentCache>>(m, "PersistentCache")
.def(py::init<>())
.def("add_config", &ConfigurablePersistentCache::add_config)
.def("reg",
[](std::shared_ptr<ConfigurablePersistentCache> inst) {
PersistentCache::set_impl(inst);
})
.def("clean", &ConfigurablePersistentCache::clean)
.def("get", &ConfigurablePersistentCache::py_get)
.def("put", &ConfigurablePersistentCache::py_put)
.def_readonly("config", &ConfigurablePersistentCache::impl_config)
.def("flush", &ConfigurablePersistentCache::flush);
py::class_<ConfigurablePersistentCache::Config>(
PyConfigurablePersistentCache, "Config")
.def_readwrite("type", &ConfigurablePersistentCache::Config::type)
.def_readwrite("args", &ConfigurablePersistentCache::Config::args)
.def_readwrite("on_fail", &ConfigurablePersistentCache::Config::on_fail)
.def_readwrite(
"on_success", &ConfigurablePersistentCache::Config::on_success);
}
......@@ -27,7 +27,7 @@ namespace imperative {
namespace {
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& comm = def.cast_final_safe<CollectiveComm>();
auto group_client = std::make_shared<GroupClientProxy>(
auto group_client = std::make_shared<opr::GroupClientProxy>(
ssprintf("%s:%d", comm.addr.data(), comm.port));
SmallVector<std::shared_ptr<mgb::DeviceTensorND>> dev_buffer_arr(1, nullptr);
auto disable = std::make_shared<DTypeScalar>();
......
......@@ -28,7 +28,7 @@ namespace {
cg::OperatorNodeBase* apply_on_var_node_remote_send(
const OpDef& def, const VarNodeArray& inputs) {
auto&& send = def.cast_final_safe<RemoteSend>();
auto group_client = std::make_shared<GroupClientProxy>(
auto group_client = std::make_shared<opr::GroupClientProxy>(
ssprintf("%s:%d", send.addr.data(), send.port));
auto&& graph = inputs[0]->owner_graph();
......@@ -44,7 +44,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv(
auto&& recv = def.cast_final_safe<RemoteRecv>();
OperatorNodeConfig config{recv.cn};
config.name(recv.make_name());
auto group_client = std::make_shared<GroupClientProxy>(
auto group_client = std::make_shared<opr::GroupClientProxy>(
ssprintf("%s:%d", recv.addr.data(), recv.port));
auto&& graph = inputs[0]->owner_graph();
return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>(
......
......@@ -27,8 +27,10 @@ public:
m_local = std::make_shared<mgb::InMemoryPersistentCache>();
}
bool connect(std::string ip, size_t port, std::string password) {
m_client.auth(password);
void connect(std::string ip, size_t port, std::optional<std::string> password) {
if (password) {
m_client.auth(*password);
}
m_client.connect(
ip, port,
[](const std::string& host, std::size_t port,
......@@ -40,16 +42,32 @@ public:
}
},
std::uint32_t(200));
if (!m_client.is_connected()) {
return false;
}
mgb_assert(m_client.is_connected(), "connect failed");
auto flag = m_client.get("mgb-cache-flag");
sync();
return flag.get().ok();
auto is_valid = [](const cpp_redis::reply& reply) {
switch (reply.get_type()) {
case cpp_redis::reply::type::error:
case cpp_redis::reply::type::null:
return false;
case cpp_redis::reply::type::integer:
return reply.as_integer() != 0;
case cpp_redis::reply::type::simple_string:
case cpp_redis::reply::type::bulk_string:
return !reply.as_string().empty();
case cpp_redis::reply::type::array:
return !reply.as_array().empty();
default:
mgb_assert(false, "unknown reply type %d", (int)reply.get_type());
}
};
mgb_assert(is_valid(flag.get()), "invalid mgb-cache-flag");
}
bool valid() const override { return m_client.is_connected(); }
void flush() override {}
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
MGB_LOCK_GUARD(m_mtx);
auto mem_result = m_local->get(category, key);
......@@ -75,7 +93,7 @@ public:
MGB_LOCK_GUARD(m_mtx);
std::string key_str(static_cast<const char*>(key.ptr), key.size);
std::string redis_key_str;
encode(category + '@' + key_str, redis_key_str);
encode(category + '@' + key_str, redis_key_str, 24);
std::string value_str(static_cast<const char*>(value.ptr), value.size);
std::string redis_value_str;
encode(value_str, redis_value_str);
......@@ -118,18 +136,16 @@ private:
class ExtendedInFilePersistentCache final : public ExtendedPersistentCache {
private:
std::string m_path;
std::optional<std::string> m_path;
std::unique_ptr<mgb::InFilePersistentCache> m_impl;
public:
ExtendedInFilePersistentCache() = default;
bool open(std::string path) {
void open(std::string path) {
std::fstream file;
file.open(path, std::ios::in | std::ios::binary);
if (!file.is_open()) {
return false;
}
mgb_assert(file.is_open(), "can't open file in %s", path.c_str());
std::vector<char> bytes = {
std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>()};
if (bytes.size()) {
......@@ -139,14 +155,11 @@ public:
m_impl = std::make_unique<mgb::InFilePersistentCache>();
}
m_path = path;
return true;
}
~ExtendedInFilePersistentCache() {
if (m_impl) {
m_impl->dump_cache(m_path.c_str());
}
}
void open() { m_impl = std::make_unique<mgb::InFilePersistentCache>(); }
~ExtendedInFilePersistentCache() { flush(); }
mgb::Maybe<Blob> get(const std::string& category, const Blob& key) override {
return m_impl->get(category, key);
......@@ -157,29 +170,64 @@ public:
}
std::optional<size_t> clear() override {
m_impl = std::make_unique<mgb::InFilePersistentCache>();
m_impl->dump_cache(m_path.c_str());
if (m_impl) {
m_impl = std::make_unique<mgb::InFilePersistentCache>();
if (m_path) {
m_impl->dump_cache(m_path->c_str());
}
}
return {};
}
bool valid() const override { return m_impl != nullptr; }
};
std::shared_ptr<ExtendedPersistentCache> make_redis(
std::string ip, size_t port, std::string password, std::string prefix) {
auto cache = std::make_shared<RedisCache>(prefix, 100);
if (!cache->connect(ip, port, password)) {
return nullptr;
void flush() override {
if (m_impl && m_path) {
m_impl->dump_cache(m_path->c_str());
}
}
return cache;
}
};
std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path) {
auto cache = std::make_shared<ExtendedInFilePersistentCache>();
if (!cache->open(path)) {
return nullptr;
// Build the cache backend selected by `type` from the string options in `args`.
//
// Types dispatched below:
//   "redis"     - requires "prefix"; connects through "unixsocket" when that key
//                 is present, otherwise through "hostname" + "port";
//                 "password" is optional.
//   "in-file"   - requires "path".
//   "in-memory" - takes no options.
//
// Returns the constructed cache. If anything thrown while building it is caught,
// the failure text is written to `err_msg` and nullptr is returned.
std::shared_ptr<ExtendedPersistentCache> ExtendedPersistentCache::make_from_config(
        std::string type, std::unordered_map<std::string, std::string> args,
        std::string& err_msg) {
    try {
        if (type == "redis") {
            std::string prefix = args.at("prefix");
            // Resolved once and shared by both connection flavours.
            std::optional<std::string> password;
            auto password_iter = args.find("password");
            if (password_iter != args.end()) {
                password = password_iter->second;
            }
            auto cache = std::make_shared<RedisCache>(prefix, 100);
            if (args.count("unixsocket")) {
                std::string unixsocket = args.at("unixsocket");
                cache->connect(unixsocket, 0, password);
            } else {
                std::string ip = args.at("hostname");
                const std::string& port_str = args.at("port");
                // std::stoi throws on non-numeric input (reported through
                // err_msg below) instead of silently yielding 0 like atoi.
                size_t parsed_len = 0;
                int port = std::stoi(port_str, &parsed_len);
                mgb_assert(
                        parsed_len == port_str.size() && port > 0 && port <= 65535,
                        "invalid redis port %s", port_str.c_str());
                cache->connect(ip, port, password);
            }
            return cache;
        } else if (type == "in-file") {
            std::string path = args.at("path");
            auto cache = std::make_shared<ExtendedInFilePersistentCache>();
            cache->open(path);
            return cache;
        } else if (type == "in-memory") {
            // Same cache class as "in-file", opened without a backing path.
            auto cache = std::make_shared<ExtendedInFilePersistentCache>();
            cache->open();
            return cache;
        } else {
            mgb_assert(false, "persistent cache type %s unsupported", type.c_str());
        }
    } catch (const std::exception& exc) {
        err_msg = exc.what();
    } catch (...) {
        err_msg = "unknown exception";
    }
    return nullptr;
}
} // namespace mgb::imperative::persistent_cache
......
......@@ -20,12 +20,12 @@ class ExtendedPersistentCache : public mgb::PersistentCache {
public:
virtual bool valid() const = 0;
virtual std::optional<size_t> clear() = 0;
};
std::shared_ptr<ExtendedPersistentCache> make_redis(
std::string ip, size_t port, std::string password, std::string prefix);
virtual void flush() = 0;
std::shared_ptr<ExtendedPersistentCache> make_in_file(std::string path);
static std::shared_ptr<ExtendedPersistentCache> make_from_config(
std::string type, std::unordered_map<std::string, std::string> args,
std::string& err_msg);
};
} // namespace mgb::imperative::persistent_cache
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -20,7 +20,7 @@ TEST(TestImperative, AllReduceBasic) {
REQUIRE_GPU(2);
const char* server_addr = "127.0.0.1";
uint32_t port = 3456;
mgb_assert(create_zmqrpc_server(server_addr, port) > 0);
mgb_assert(opr::create_zmqrpc_server(server_addr, port) > 0);
HostTensorGenerator<> gen;
CompNode cn0 = CompNode::load("gpu0"), cn1 = CompNode::load("gpu1");
......
......@@ -20,7 +20,7 @@ TEST(TestImperative, IORemote) {
REQUIRE_GPU(2);
const char* server_addr = "127.0.0.1";
uint32_t port = 4567;
mgb_assert(create_zmqrpc_server(server_addr, port) > 0);
mgb_assert(opr::create_zmqrpc_server(server_addr, port) > 0);
HostTensorGenerator<> gen;
CompNode cn0 = CompNode::load("gpu0"), cn1 = CompNode::load("gpu1");
......
......@@ -17,6 +17,9 @@
#include "megbrain/opr/zmq_rpc.h"
#include "mm_handler.pb.h"
using namespace mgb;
using namespace opr;
/* ======================== GroupServerProxy ========================== */
/*!
* A proxy that receives zmqrpc call, direct call to NCCL Manager
......@@ -213,7 +216,7 @@ struct ServerInfo {
std::unique_ptr<ZmqRpc::ZmqRpcServer> server;
};
int create_zmqrpc_server(const std::string& server_addr, int port) {
int mgb::opr::create_zmqrpc_server(const std::string& server_addr, int port) {
static std::unordered_map<std::string, ServerInfo> addr2server;
static std::mutex mtx;
MGB_LOCK_GUARD(mtx);
......
......@@ -16,8 +16,8 @@
#include "megbrain/opr/collective_comm.h"
#include "megbrain/opr/group_manager.h"
using namespace mgb;
using namespace opr;
namespace mgb {
namespace opr {
/*!
* Comm MM Client Proxy.
......@@ -56,6 +56,9 @@ private:
int create_zmqrpc_server(const std::string& server_addr, int port);
} // namespace opr
} // namespace mgb
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -13,8 +13,6 @@ global:
base_exceptions*;
};
megcore*;
*GroupClientProxy*;
*create_zmqrpc_server*;
*custom*;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册