From 1657b8e88102be79f2309b3a5d1455ab528cf3e7 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 30 Dec 2021 17:20:06 +0800 Subject: [PATCH] fix(fastrun): fix persistent_cache in redis GitOrigin-RevId: ada5862b057dd7310e63a535874b00da882b21ba --- imperative/python/megengine/__init__.py | 15 +- .../megengine/utils/persistent_cache.py | 145 +++++++++++------- imperative/python/requires.txt | 4 +- imperative/python/src/utils.cpp | 125 +++++++++++---- imperative/src/impl/ops/collective_comm.cpp | 2 +- imperative/src/impl/ops/io_remote.cpp | 4 +- imperative/src/impl/persistent_cache.cpp | 114 ++++++++++---- .../megbrain/imperative/persistent_cache.h | 10 +- imperative/src/test/collective_comm.cpp | 2 +- imperative/src/test/io_remote.cpp | 2 +- src/opr-mm/impl/mm_handler.cpp | 5 +- src/opr-mm/include/megbrain/opr/mm_handler.h | 7 +- src/version.ld | 2 - 13 files changed, 288 insertions(+), 149 deletions(-) diff --git a/imperative/python/megengine/__init__.py b/imperative/python/megengine/__init__.py index 38291bfde..e248be570 100644 --- a/imperative/python/megengine/__init__.py +++ b/imperative/python/megengine/__init__.py @@ -84,7 +84,7 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level from .serialization import load, save from .tensor import Parameter, Tensor, tensor from .utils import comp_graph_tools as cgtools -from .utils import persistent_cache +from .utils.persistent_cache import PersistentCacheOnServer as _PersistentCacheOnServer from .version import __version__ _set_fork_exec_path_for_timed_func( @@ -92,15 +92,13 @@ _set_fork_exec_path_for_timed_func( os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"), ) -atexit.register(_close) - del _set_fork_exec_path_for_timed_func _exit_handlers = [] def _run_exit_handlers(): - for handler in _exit_handlers: + for handler in reversed(_exit_handlers): handler() _exit_handlers.clear() @@ -117,6 +115,13 @@ def _atexit(handler): _exit_handlers.append(handler) +_atexit(_close) + +_persistent_cache = _PersistentCacheOnServer() +_persistent_cache.reg() + +_atexit(_persistent_cache.flush) + # subpackages import megengine.amp import megengine.autodiff @@ -132,5 +137,3 @@ import megengine.quantization import megengine.random import megengine.utils import megengine.traced_module - -persistent_cache.get_manager() diff --git a/imperative/python/megengine/utils/persistent_cache.py b/imperative/python/megengine/utils/persistent_cache.py index 3b0f7ae2c..d517b9226 100644 --- a/imperative/python/megengine/utils/persistent_cache.py +++ b/imperative/python/megengine/utils/persistent_cache.py @@ -8,87 +8,114 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import argparse +import contextlib import getpass import os import sys import urllib.parse -from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager +import filelock + +from ..core._imperative_rt import PersistentCache as _PersistentCache from ..logger import get_logger from ..version import __version__, git_version -class PersistentCacheManager(_PersistentCacheManager): +class PersistentCacheOnServer(_PersistentCache): def __init__(self): super().__init__() - if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY": - get_logger().info("fastrun use in-memory cache") - self.open_memory() - elif os.getenv("MGE_FASTRUN_CACHE_TYPE") == "FILE": - self.open_file() - else: - self.open_redis() - - def open_memory(self): - pass + cache_type = os.getenv("MGE_FASTRUN_CACHE_TYPE") + if cache_type not in ("FILE", "MEMORY"): + try: + redis_config = self.get_redis_config() + except Exception as exc: + get_logger().error( + "failed to connect to cache server {!r}; try fallback to " + "in-file cache".format(exc) + ) + else: + self.add_config( + "redis", + redis_config, + "fastrun use redis cache", + "failed to connect to cache server", + ) + if cache_type != "MEMORY": + path = self.get_cache_file(self.get_cache_dir()) + self.add_config( + "in-file", + {"path": path}, + "fastrun use in-file cache in {}".format(path), + "failed to create cache file in {}".format(path), + ) + self.add_config( + "in-memory", + {}, + "fastrun use in-memory cache", + "failed to create in-memory cache", + ) - def open_file(self): + def get_cache_dir(self): cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR") - try: - if not cache_dir: - from ..hub.hub import _get_megengine_home + if not cache_dir: + from ..hub.hub import _get_megengine_home - cache_dir = os.path.expanduser( - os.path.join(_get_megengine_home(), "persistent_cache.bin") - ) - os.makedirs(cache_dir, exist_ok=True) - cache_file = os.path.join(cache_dir, "cache") - with open(cache_file, "a"): - pass - assert self.try_open_file(cache_file), "cannot create file" - get_logger().info("fastrun use in-file cache in {}".format(cache_dir)) - except Exception as exc: - get_logger().error( - "failed to create cache file in {} {!r}; fallback to " - "in-memory cache".format(cache_dir, exc) + cache_dir = os.path.expanduser( + os.path.join(_get_megengine_home(), "persistent_cache") ) - self.open_memory() - - def open_redis(self): + os.makedirs(cache_dir, exist_ok=True) + return cache_dir + + def get_cache_file(self, cache_dir): + cache_file = os.path.join(cache_dir, "cache.bin") + with open(cache_file, "a"): + pass + return cache_file + + @contextlib.contextmanager + def lock_cache_file(self, cache_dir): + lock_file = os.path.join(cache_dir, "cache.lock") + with filelock.FileLock(lock_file): + yield + + def get_redis_config(self): + url = os.getenv("MGE_FASTRUN_CACHE_URL") + if url is None: + return None + assert sys.platform != "win32", "redis cache on windows not tested" prefix = "mgbcache:{}:MGB{}:GIT:{}".format( getpass.getuser(), __version__, git_version ) - url = os.getenv("MGE_FASTRUN_CACHE_URL") - if url is None: - self.open_file() - try: - assert sys.platform != "win32", "redis cache on windows not tested" - parse_result = urllib.parse.urlparse(url, scheme="redis") - assert parse_result.scheme == "redis", "unsupported scheme" - assert not parse_result.username, "redis conn with username unsupported" - assert self.try_open_redis( - parse_result.hostname, parse_result.port, parse_result.password, prefix - ), "connect failed" - except Exception as exc: - get_logger().error( - "failed to connect to cache server {!r}; try fallback to " - "in-file cache".format(exc) - ) - self.open_file() - - -_manager = None - + parse_result = urllib.parse.urlparse(url) + assert not parse_result.username, "redis conn with username unsupported" + if parse_result.scheme == "redis": + assert parse_result.hostname and parse_result.port, "invalid url" + assert not parse_result.path + config = { + "hostname": parse_result.hostname, + "port": str(parse_result.port), + } + elif parse_result.scheme == "redis+socket": + assert not (parse_result.hostname or parse_result.port) + assert parse_result.path + config = { + "unixsocket": parse_result.path, + } + else: + assert False, "unsupported scheme" + if parse_result.password is not None: + config["password"] = parse_result.password + config["prefix"] = prefix + return config -def get_manager(): - global _manager - if _manager is None: - _manager = PersistentCacheManager() - return _manager + def flush(self): + if self.config is not None and self.config.type == "in-file": + with self.lock_cache_file(self.get_cache_dir()): + super().flush() def _clean(): - nr_del = get_manager().clean() + nr_del = PersistentCacheOnServer().clean() if nr_del is not None: print("{} cache entries deleted".format(nr_del)) diff --git a/imperative/python/requires.txt b/imperative/python/requires.txt index 58a806c05..894b332a0 100644 --- a/imperative/python/requires.txt +++ b/imperative/python/requires.txt @@ -4,8 +4,8 @@ pyarrow requests tabulate tqdm -redispy deprecated mprop wheel -megfile>=0.0.10 \ No newline at end of file +megfile>=0.0.10 +filelock diff --git a/imperative/python/src/utils.cpp b/imperative/python/src/utils.cpp index 5162f488f..260c7aa26 100644 --- a/imperative/python/src/utils.cpp +++ b/imperative/python/src/utils.cpp @@ -210,7 +210,7 @@ void init_utils(py::module m) { .def("disable", [](TensorSanityCheck& checker) { checker.disable(); }); #if MGB_ENABLE_OPR_MM - m.def("create_mm_server", &create_zmqrpc_server, py::arg("addr"), + m.def("create_mm_server", &mgb::opr::create_zmqrpc_server, py::arg("addr"), py::arg("port") = 0); #else m.def("create_mm_server", []() {}); @@ -234,51 +234,108 @@ void init_utils(py::module m) { using ExtendedPersistentCache = mgb::imperative::persistent_cache::ExtendedPersistentCache; - struct PersistentCacheManager { - std::shared_ptr instance; + struct ConfigurablePersistentCache : mgb::PersistentCache { + struct Config { + std::string type; + std::unordered_map args; + std::string on_success; + std::string on_fail; + }; - bool try_reg(std::shared_ptr cache) { - if (cache) { - instance = cache; - PersistentCache::set_impl(cache); - return true; - } - return false; - } - bool open_redis( - std::string ip, size_t port, std::string password, std::string prefix) { - return try_reg(mgb::imperative::persistent_cache::make_redis( - ip, port, password, prefix)); + std::shared_ptr impl; + std::optional impl_config; + std::vector configs; + + void add_config( + std::string type, std::unordered_map args, + std::string on_success, std::string on_fail) { + configs.push_back({type, args, on_success, on_fail}); } - bool open_file(std::string path) { - return try_reg(mgb::imperative::persistent_cache::make_in_file(path)); + + std::optional clean() { return get_impl()->clear(); } + + void load_config() { + std::optional err_msg; + for (size_t i = 0; i < configs.size(); ++i) { + auto& config = configs[i]; + if (err_msg) { + mgb_log_warn("try fallback to %s cache", config.type.c_str()); + } else { + err_msg.emplace(); + } + auto cache = ExtendedPersistentCache::make_from_config( + config.type, config.args, *err_msg); + if (!cache) { + mgb_log_warn("%s %s", config.on_fail.c_str(), err_msg->c_str()); + } else { + impl = cache; + impl_config = config; + break; + } + } + mgb_assert(impl_config.has_value(), "not valid config"); } - std::optional clean() { - if (instance) { - return instance->clear(); + + std::shared_ptr get_impl() { + if (!impl) { + load_config(); } - return {}; + return impl; } - void put(std::string category, std::string key, std::string value) { - PersistentCache::inst().put( - category, {key.data(), key.size()}, {value.data(), value.size()}); + + virtual mgb::Maybe get(const std::string& category, const Blob& key) { + return get_impl()->get(category, key); + } + + virtual void put( + const std::string& category, const Blob& key, const Blob& value) { + return get_impl()->put(category, key, value); } - py::object get(std::string category, std::string key) { - auto value = - PersistentCache::inst().get(category, {key.data(), key.size()}); + + virtual bool support_dump_cache() { return get_impl()->support_dump_cache(); } + + py::object py_get(std::string category, std::string key) { + auto value = get_impl()->get(category, {key.data(), key.size()}); if (value.valid()) { return py::bytes(std::string((const char*)value->ptr, value->size)); } else { return py::none(); } } + + void py_put(std::string category, std::string key, std::string value) { + get_impl()->put( + category, {key.data(), key.size()}, {value.data(), value.size()}); + } + + void flush() { + if (impl) { + impl->flush(); + } + } }; - py::class_(m, "PersistentCacheManager") - .def(py::init<>()) - .def("try_open_redis", &PersistentCacheManager::open_redis) - .def("try_open_file", &PersistentCacheManager::open_file) - .def("clean", &PersistentCacheManager::clean) - .def("put", &PersistentCacheManager::put) - .def("get", &PersistentCacheManager::get); + auto PyConfigurablePersistentCache = + py::class_< + ConfigurablePersistentCache, + std::shared_ptr>(m, "PersistentCache") + .def(py::init<>()) + .def("add_config", &ConfigurablePersistentCache::add_config) + .def("reg", + [](std::shared_ptr inst) { + PersistentCache::set_impl(inst); + }) + .def("clean", &ConfigurablePersistentCache::clean) + .def("get", &ConfigurablePersistentCache::py_get) + .def("put", &ConfigurablePersistentCache::py_put) + .def_readonly("config", &ConfigurablePersistentCache::impl_config) + .def("flush", &ConfigurablePersistentCache::flush); + + py::class_( + PyConfigurablePersistentCache, "Config") + .def_readwrite("type", &ConfigurablePersistentCache::Config::type) + .def_readwrite("args", &ConfigurablePersistentCache::Config::args) + .def_readwrite("on_fail", &ConfigurablePersistentCache::Config::on_fail) + .def_readwrite( + "on_success", &ConfigurablePersistentCache::Config::on_success); } diff --git a/imperative/src/impl/ops/collective_comm.cpp b/imperative/src/impl/ops/collective_comm.cpp index 6c969a228..141f7f625 100644 --- a/imperative/src/impl/ops/collective_comm.cpp +++ b/imperative/src/impl/ops/collective_comm.cpp @@ -27,7 +27,7 @@ namespace imperative { namespace { cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& comm = def.cast_final_safe(); - auto group_client = std::make_shared( + auto group_client = std::make_shared( ssprintf("%s:%d", comm.addr.data(), comm.port)); SmallVector> dev_buffer_arr(1, nullptr); auto disable = std::make_shared(); diff --git a/imperative/src/impl/ops/io_remote.cpp b/imperative/src/impl/ops/io_remote.cpp index 29b0316ee..03e4d58ab 100644 --- a/imperative/src/impl/ops/io_remote.cpp +++ b/imperative/src/impl/ops/io_remote.cpp @@ -28,7 +28,7 @@ namespace { cg::OperatorNodeBase* apply_on_var_node_remote_send( const OpDef& def, const VarNodeArray& inputs) { auto&& send = def.cast_final_safe(); - auto group_client = std::make_shared( + auto group_client = std::make_shared( ssprintf("%s:%d", send.addr.data(), send.port)); auto&& graph = inputs[0]->owner_graph(); @@ -44,7 +44,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv( auto&& recv = def.cast_final_safe(); OperatorNodeConfig config{recv.cn}; config.name(recv.make_name()); - auto group_client = std::make_shared( + auto group_client = std::make_shared( ssprintf("%s:%d", recv.addr.data(), recv.port)); auto&& graph = inputs[0]->owner_graph(); return graph->insert_opr(std::make_unique( diff --git a/imperative/src/impl/persistent_cache.cpp b/imperative/src/impl/persistent_cache.cpp index ba3809abb..5fc291c8a 100644 --- a/imperative/src/impl/persistent_cache.cpp +++ b/imperative/src/impl/persistent_cache.cpp @@ -27,8 +27,10 @@ public: m_local = std::make_shared(); } - bool connect(std::string ip, size_t port, std::string password) { - m_client.auth(password); + void connect(std::string ip, size_t port, std::optional password) { + if (password) { + m_client.auth(*password); + } m_client.connect( ip, port, [](const std::string& host, std::size_t port, @@ -40,16 +42,32 @@ public: } }, std::uint32_t(200)); - if (!m_client.is_connected()) { - return false; - } + mgb_assert(m_client.is_connected(), "connect failed"); auto flag = m_client.get("mgb-cache-flag"); sync(); - return flag.get().ok(); + auto is_valid = [](const cpp_redis::reply& reply) { + switch (reply.get_type()) { + case cpp_redis::reply::type::error: + case cpp_redis::reply::type::null: + return false; + case cpp_redis::reply::type::integer: + return reply.as_integer() != 0; + case cpp_redis::reply::type::simple_string: + case cpp_redis::reply::type::bulk_string: + return !reply.as_string().empty(); + case cpp_redis::reply::type::array: + return !reply.as_array().empty(); + default: + mgb_assert(false, "unknown reply type %d", (int)reply.get_type()); + } + }; + mgb_assert(is_valid(flag.get()), "invalid mgb-cache-flag"); } bool valid() const override { return m_client.is_connected(); } + void flush() override {} + mgb::Maybe get(const std::string& category, const Blob& key) override { MGB_LOCK_GUARD(m_mtx); auto mem_result = m_local->get(category, key); @@ -75,7 +93,7 @@ public: MGB_LOCK_GUARD(m_mtx); std::string key_str(static_cast(key.ptr), key.size); std::string redis_key_str; - encode(category + '@' + key_str, redis_key_str); + encode(category + '@' + key_str, redis_key_str, 24); std::string value_str(static_cast(value.ptr), value.size); std::string redis_value_str; encode(value_str, redis_value_str); @@ -118,18 +136,16 @@ private: class ExtendedInFilePersistentCache final : public ExtendedPersistentCache { private: - std::string m_path; + std::optional m_path; std::unique_ptr m_impl; public: ExtendedInFilePersistentCache() = default; - bool open(std::string path) { + void open(std::string path) { std::fstream file; file.open(path, std::ios::in | std::ios::binary); - if (!file.is_open()) { - return false; - } + mgb_assert(file.is_open(), "can't open file in %s", path.c_str()); std::vector bytes = { std::istreambuf_iterator(file), std::istreambuf_iterator()}; if (bytes.size()) { @@ -139,14 +155,11 @@ public: m_impl = std::make_unique(); } m_path = path; - return true; } - ~ExtendedInFilePersistentCache() { - if (m_impl) { - m_impl->dump_cache(m_path.c_str()); - } - } + void open() { m_impl = std::make_unique(); } + + ~ExtendedInFilePersistentCache() { flush(); } mgb::Maybe get(const std::string& category, const Blob& key) override { return m_impl->get(category, key); @@ -157,29 +170,64 @@ public: } std::optional clear() override { - m_impl = std::make_unique(); - m_impl->dump_cache(m_path.c_str()); + if (m_impl) { + m_impl = std::make_unique(); + if (m_path) { + m_impl->dump_cache(m_path->c_str()); + } + } return {}; } bool valid() const override { return m_impl != nullptr; } -}; -std::shared_ptr make_redis( - std::string ip, size_t port, std::string password, std::string prefix) { - auto cache = std::make_shared(prefix, 100); - if (!cache->connect(ip, port, password)) { - return nullptr; + void flush() override { + if (m_impl && m_path) { + m_impl->dump_cache(m_path->c_str()); + } } - return cache; -} +}; -std::shared_ptr make_in_file(std::string path) { - auto cache = std::make_shared(); - if (!cache->open(path)) { - return nullptr; +std::shared_ptr ExtendedPersistentCache::make_from_config( + std::string type, std::unordered_map args, + std::string& err_msg) { + try { + if (type == "redis") { + std::string prefix = args.at("prefix"); + std::optional password = args.count("password") + ? args.at("password") + : std::optional(); + auto cache = std::make_shared(prefix, 100); + if (args.count("unixsocket")) { + std::string unixsocket = args.at("unixsocket"); + cache->connect(unixsocket, 0, password); + } else { + std::string ip = args.at("hostname"); + int port = atoi(args.at("port").c_str()); + std::optional password = + args.count("password") ? args.at("password") + : std::optional(); + cache->connect(ip, port, password); + } + return cache; + } else if (type == "in-file") { + std::string path = args.at("path"); + auto cache = std::make_shared(); + cache->open(path); + return cache; + } else if (type == "in-memory") { + auto cache = std::make_shared(); + cache->open(); + return cache; + } else { + mgb_assert(false, "persistent cache type %s unsupported", type.c_str()); + } + } catch (const std::exception& exc) { + err_msg = exc.what(); + } catch (...) { + err_msg = "unknown exception"; } - return cache; + return nullptr; } } // namespace mgb::imperative::persistent_cache diff --git a/imperative/src/include/megbrain/imperative/persistent_cache.h b/imperative/src/include/megbrain/imperative/persistent_cache.h index 59326bbde..4d63eaae7 100644 --- a/imperative/src/include/megbrain/imperative/persistent_cache.h +++ b/imperative/src/include/megbrain/imperative/persistent_cache.h @@ -20,12 +20,12 @@ class ExtendedPersistentCache : public mgb::PersistentCache { public: virtual bool valid() const = 0; virtual std::optional clear() = 0; -}; - -std::shared_ptr make_redis( - std::string ip, size_t port, std::string password, std::string prefix); + virtual void flush() = 0; -std::shared_ptr make_in_file(std::string path); + static std::shared_ptr make_from_config( + std::string type, std::unordered_map args, + std::string& err_msg); +}; } // namespace mgb::imperative::persistent_cache // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/imperative/src/test/collective_comm.cpp b/imperative/src/test/collective_comm.cpp index 01c0829dd..4a31c54be 100644 --- a/imperative/src/test/collective_comm.cpp +++ b/imperative/src/test/collective_comm.cpp @@ -20,7 +20,7 @@ TEST(TestImperative, AllReduceBasic) { REQUIRE_GPU(2); const char* server_addr = "127.0.0.1"; uint32_t port = 3456; - mgb_assert(create_zmqrpc_server(server_addr, port) > 0); + mgb_assert(opr::create_zmqrpc_server(server_addr, port) > 0); HostTensorGenerator<> gen; CompNode cn0 = CompNode::load("gpu0"), cn1 = CompNode::load("gpu1"); diff --git a/imperative/src/test/io_remote.cpp b/imperative/src/test/io_remote.cpp index 8e32f7abd..97a7b62dd 100644 --- a/imperative/src/test/io_remote.cpp +++ b/imperative/src/test/io_remote.cpp @@ -20,7 +20,7 @@ TEST(TestImperative, IORemote) { REQUIRE_GPU(2); const char* server_addr = "127.0.0.1"; uint32_t port = 4567; - mgb_assert(create_zmqrpc_server(server_addr, port) > 0); + mgb_assert(opr::create_zmqrpc_server(server_addr, port) > 0); HostTensorGenerator<> gen; CompNode cn0 = CompNode::load("gpu0"), cn1 = CompNode::load("gpu1"); diff --git a/src/opr-mm/impl/mm_handler.cpp b/src/opr-mm/impl/mm_handler.cpp index 1f13b6974..7cbb8a3ec 100644 --- a/src/opr-mm/impl/mm_handler.cpp +++ b/src/opr-mm/impl/mm_handler.cpp @@ -17,6 +17,9 @@ #include "megbrain/opr/zmq_rpc.h" #include "mm_handler.pb.h" +using namespace mgb; +using namespace opr; + /* ======================== GroupServerProxy ========================== */ /*! * A proxy that receives zmqrpc call, direct call to NCCL Manager @@ -213,7 +216,7 @@ struct ServerInfo { std::unique_ptr server; }; -int create_zmqrpc_server(const std::string& server_addr, int port) { +int mgb::opr::create_zmqrpc_server(const std::string& server_addr, int port) { static std::unordered_map addr2server; static std::mutex mtx; MGB_LOCK_GUARD(mtx); diff --git a/src/opr-mm/include/megbrain/opr/mm_handler.h b/src/opr-mm/include/megbrain/opr/mm_handler.h index 7c03bf961..97b829d44 100644 --- a/src/opr-mm/include/megbrain/opr/mm_handler.h +++ b/src/opr-mm/include/megbrain/opr/mm_handler.h @@ -16,8 +16,8 @@ #include "megbrain/opr/collective_comm.h" #include "megbrain/opr/group_manager.h" -using namespace mgb; -using namespace opr; +namespace mgb { +namespace opr { /*! * Comm MM Client Proxy. @@ -56,6 +56,9 @@ private: int create_zmqrpc_server(const std::string& server_addr, int port); +} // namespace opr +} // namespace mgb + #endif // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/version.ld b/src/version.ld index db71a72b3..f70a56775 100644 --- a/src/version.ld +++ b/src/version.ld @@ -13,8 +13,6 @@ global: base_exceptions*; }; megcore*; - *GroupClientProxy*; - *create_zmqrpc_server*; *custom*; -- GitLab