diff --git a/imperative/python/megengine/__init__.py b/imperative/python/megengine/__init__.py index 38291bfdea5dd9f975c294944e90bbad1f6665e5..e248be57073a8c226d4dd55e458483f9a116edb8 100644 --- a/imperative/python/megengine/__init__.py +++ b/imperative/python/megengine/__init__.py @@ -84,7 +84,7 @@ from .logger import enable_debug_log, get_logger, set_log_file, set_log_level from .serialization import load, save from .tensor import Parameter, Tensor, tensor from .utils import comp_graph_tools as cgtools -from .utils import persistent_cache +from .utils.persistent_cache import PersistentCacheOnServer as _PersistentCacheOnServer from .version import __version__ _set_fork_exec_path_for_timed_func( @@ -92,15 +92,13 @@ _set_fork_exec_path_for_timed_func( os.path.join(os.path.dirname(__file__), "utils", "_timed_func_fork_exec_entry.py"), ) -atexit.register(_close) - del _set_fork_exec_path_for_timed_func _exit_handlers = [] def _run_exit_handlers(): - for handler in _exit_handlers: + for handler in reversed(_exit_handlers): handler() _exit_handlers.clear() @@ -117,6 +115,13 @@ def _atexit(handler): _exit_handlers.append(handler) +_atexit(_close) + +_persistent_cache = _PersistentCacheOnServer() +_persistent_cache.reg() + +_atexit(_persistent_cache.flush) + # subpackages import megengine.amp import megengine.autodiff @@ -132,5 +137,3 @@ import megengine.quantization import megengine.random import megengine.utils import megengine.traced_module - -persistent_cache.get_manager() diff --git a/imperative/python/megengine/utils/persistent_cache.py b/imperative/python/megengine/utils/persistent_cache.py index 3b0f7ae2c07423e98ecc7141334286ea1041dda9..d517b9226ddcbaf4ff2f94146fc31c233f813639 100644 --- a/imperative/python/megengine/utils/persistent_cache.py +++ b/imperative/python/megengine/utils/persistent_cache.py @@ -8,87 +8,114 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import argparse +import contextlib import getpass import os import sys import urllib.parse -from ..core._imperative_rt import PersistentCacheManager as _PersistentCacheManager +import filelock + +from ..core._imperative_rt import PersistentCache as _PersistentCache from ..logger import get_logger from ..version import __version__, git_version -class PersistentCacheManager(_PersistentCacheManager): +class PersistentCacheOnServer(_PersistentCache): def __init__(self): super().__init__() - if os.getenv("MGE_FASTRUN_CACHE_TYPE") == "MEMORY": - get_logger().info("fastrun use in-memory cache") - self.open_memory() - elif os.getenv("MGE_FASTRUN_CACHE_TYPE") == "FILE": - self.open_file() - else: - self.open_redis() - - def open_memory(self): - pass + cache_type = os.getenv("MGE_FASTRUN_CACHE_TYPE") + if cache_type not in ("FILE", "MEMORY"): + try: + redis_config = self.get_redis_config() + except Exception as exc: + get_logger().error( + "failed to connect to cache server {!r}; try fallback to " + "in-file cache".format(exc) + ) + else: + self.add_config( + "redis", + redis_config, + "fastrun use redis cache", + "failed to connect to cache server", + ) + if cache_type != "MEMORY": + path = self.get_cache_file(self.get_cache_dir()) + self.add_config( + "in-file", + {"path": path}, + "fastrun use in-file cache in {}".format(path), + "failed to create cache file in {}".format(path), + ) + self.add_config( + "in-memory", + {}, + "fastrun use in-memory cache", + "failed to create in-memory cache", + ) - def open_file(self): + def get_cache_dir(self): cache_dir = os.getenv("MGE_FASTRUN_CACHE_DIR") - try: - if not cache_dir: - from ..hub.hub import _get_megengine_home + if not cache_dir: + from ..hub.hub import _get_megengine_home - cache_dir = os.path.expanduser( - os.path.join(_get_megengine_home(), "persistent_cache.bin") - ) - os.makedirs(cache_dir, exist_ok=True) - cache_file = os.path.join(cache_dir, "cache") - with open(cache_file, "a"): - pass - assert self.try_open_file(cache_file), "cannot create file" - get_logger().info("fastrun use in-file cache in {}".format(cache_dir)) - except Exception as exc: - get_logger().error( - "failed to create cache file in {} {!r}; fallback to " - "in-memory cache".format(cache_dir, exc) + cache_dir = os.path.expanduser( + os.path.join(_get_megengine_home(), "persistent_cache") ) - self.open_memory() - - def open_redis(self): + os.makedirs(cache_dir, exist_ok=True) + return cache_dir + + def get_cache_file(self, cache_dir): + cache_file = os.path.join(cache_dir, "cache.bin") + with open(cache_file, "a"): + pass + return cache_file + + @contextlib.contextmanager + def lock_cache_file(self, cache_dir): + lock_file = os.path.join(cache_dir, "cache.lock") + with filelock.FileLock(lock_file): + yield + + def get_redis_config(self): + url = os.getenv("MGE_FASTRUN_CACHE_URL") + if url is None: + return None + assert sys.platform != "win32", "redis cache on windows not tested" prefix = "mgbcache:{}:MGB{}:GIT:{}".format( getpass.getuser(), __version__, git_version ) - url = os.getenv("MGE_FASTRUN_CACHE_URL") - if url is None: - self.open_file() - try: - assert sys.platform != "win32", "redis cache on windows not tested" - parse_result = urllib.parse.urlparse(url, scheme="redis") - assert parse_result.scheme == "redis", "unsupported scheme" - assert not parse_result.username, "redis conn with username unsupported" - assert self.try_open_redis( - parse_result.hostname, parse_result.port, parse_result.password, prefix - ), "connect failed" - except Exception as exc: - get_logger().error( - "failed to connect to cache server {!r}; try fallback to " - "in-file cache".format(exc) - ) - self.open_file() - - -_manager = None - + parse_result = urllib.parse.urlparse(url) + assert not parse_result.username, "redis conn with username unsupported" + if parse_result.scheme == "redis": + assert parse_result.hostname and parse_result.port, "invalid url" + assert not parse_result.path + config = { + "hostname": parse_result.hostname, + "port": str(parse_result.port), + } + elif parse_result.scheme == "redis+socket": + assert not (parse_result.hostname or parse_result.port) + assert parse_result.path + config = { + "unixsocket": parse_result.path, + } + else: + assert False, "unsupported scheme" + if parse_result.password is not None: + config["password"] = parse_result.password + config["prefix"] = prefix + return config -def get_manager(): - global _manager - if _manager is None: - _manager = PersistentCacheManager() - return _manager + def flush(self): + if self.config is not None and self.config.type == "in-file": + with self.lock_cache_file(self.get_cache_dir()): + super().flush() def _clean(): - nr_del = get_manager().clean() + nr_del = PersistentCacheOnServer().clean() if nr_del is not None: print("{} cache entries deleted".format(nr_del)) diff --git a/imperative/python/requires.txt b/imperative/python/requires.txt index 58a806c05713288a481e8482760c4916b4ad8cca..894b332a02b9c87ac9c57085db6712120fc18f6b 100644 --- a/imperative/python/requires.txt +++ b/imperative/python/requires.txt @@ -4,8 +4,8 @@ pyarrow requests tabulate tqdm -redispy deprecated mprop wheel -megfile>=0.0.10 \ No newline at end of file +megfile>=0.0.10 +filelock diff --git a/imperative/python/src/utils.cpp b/imperative/python/src/utils.cpp index 5162f488f31044c4c4fc5b435d63ee70b39b7c13..260c7aa268b14b98fbe4b5c9340b46d47e976137 100644 --- a/imperative/python/src/utils.cpp +++ b/imperative/python/src/utils.cpp @@ -210,7 +210,7 @@ void init_utils(py::module m) { .def("disable", [](TensorSanityCheck& checker) { checker.disable(); }); #if MGB_ENABLE_OPR_MM - m.def("create_mm_server", &create_zmqrpc_server, py::arg("addr"), + m.def("create_mm_server", &mgb::opr::create_zmqrpc_server, py::arg("addr"), py::arg("port") = 0); #else m.def("create_mm_server", []() {}); @@ -234,51 +234,108 @@ void init_utils(py::module m) { using ExtendedPersistentCache = mgb::imperative::persistent_cache::ExtendedPersistentCache; - struct PersistentCacheManager { - std::shared_ptr instance; + struct ConfigurablePersistentCache : mgb::PersistentCache { + struct Config { + std::string type; + std::unordered_map args; + std::string on_success; + std::string on_fail; + }; - bool try_reg(std::shared_ptr cache) { - if (cache) { - instance = cache; - PersistentCache::set_impl(cache); - return true; - } - return false; - } - bool open_redis( - std::string ip, size_t port, std::string password, std::string prefix) { - return try_reg(mgb::imperative::persistent_cache::make_redis( - ip, port, password, prefix)); + std::shared_ptr impl; + std::optional impl_config; + std::vector configs; + + void add_config( + std::string type, std::unordered_map args, + std::string on_success, std::string on_fail) { + configs.push_back({type, args, on_success, on_fail}); } - bool open_file(std::string path) { - return try_reg(mgb::imperative::persistent_cache::make_in_file(path)); + + std::optional clean() { return get_impl()->clear(); } + + void load_config() { + std::optional err_msg; + for (size_t i = 0; i < configs.size(); ++i) { + auto& config = configs[i]; + if (err_msg) { + mgb_log_warn("try fallback to %s cache", config.type.c_str()); + } else { + err_msg.emplace(); + } + auto cache = ExtendedPersistentCache::make_from_config( + config.type, config.args, *err_msg); + if (!cache) { + mgb_log_warn("%s %s", config.on_fail.c_str(), err_msg->c_str()); + } else { + impl = cache; + impl_config = config; + break; + } + } + mgb_assert(impl_config.has_value(), "not valid config"); } - std::optional clean() { - if (instance) { - return instance->clear(); + + std::shared_ptr get_impl() { + if (!impl) { + load_config(); } - return {}; + return impl; } - void put(std::string category, std::string key, std::string value) { - PersistentCache::inst().put( - category, {key.data(), key.size()}, {value.data(), value.size()}); + + virtual mgb::Maybe get(const std::string& category, const Blob& key) { + return get_impl()->get(category, key); + } + + virtual void put( + const std::string& category, const Blob& key, const Blob& value) { + return get_impl()->put(category, key, value); } - py::object get(std::string category, std::string key) { - auto value = - PersistentCache::inst().get(category, {key.data(), key.size()}); + + virtual bool support_dump_cache() { return get_impl()->support_dump_cache(); } + + py::object py_get(std::string category, std::string key) { + auto value = get_impl()->get(category, {key.data(), key.size()}); if (value.valid()) { return py::bytes(std::string((const char*)value->ptr, value->size)); } else { return py::none(); } } + + void py_put(std::string category, std::string key, std::string value) { + get_impl()->put( + category, {key.data(), key.size()}, {value.data(), value.size()}); + } + + void flush() { + if (impl) { + impl->flush(); + } + } }; - py::class_(m, "PersistentCacheManager") - .def(py::init<>()) - .def("try_open_redis", &PersistentCacheManager::open_redis) - .def("try_open_file", &PersistentCacheManager::open_file) - .def("clean", &PersistentCacheManager::clean) - .def("put", &PersistentCacheManager::put) - .def("get", &PersistentCacheManager::get); + auto PyConfigurablePersistentCache = + py::class_< + ConfigurablePersistentCache, + std::shared_ptr>(m, "PersistentCache") + .def(py::init<>()) + .def("add_config", &ConfigurablePersistentCache::add_config) + .def("reg", + [](std::shared_ptr inst) { + PersistentCache::set_impl(inst); + }) + .def("clean", &ConfigurablePersistentCache::clean) + .def("get", &ConfigurablePersistentCache::py_get) + .def("put", &ConfigurablePersistentCache::py_put) + .def_readonly("config", &ConfigurablePersistentCache::impl_config) + .def("flush", &ConfigurablePersistentCache::flush); + + py::class_( + PyConfigurablePersistentCache, "Config") + .def_readwrite("type", &ConfigurablePersistentCache::Config::type) + .def_readwrite("args", &ConfigurablePersistentCache::Config::args) + .def_readwrite("on_fail", &ConfigurablePersistentCache::Config::on_fail) + .def_readwrite( + "on_success", &ConfigurablePersistentCache::Config::on_success); } diff --git a/imperative/src/impl/ops/collective_comm.cpp b/imperative/src/impl/ops/collective_comm.cpp index 6c969a2283b01695cd255b74a04aea9fee4e2e3f..141f7f625d9b09a23d5294c31a2cf60c407e4b2f 100644 --- a/imperative/src/impl/ops/collective_comm.cpp +++ b/imperative/src/impl/ops/collective_comm.cpp @@ -27,7 +27,7 @@ namespace imperative { namespace { cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& comm = def.cast_final_safe(); - auto group_client = std::make_shared( + auto group_client = std::make_shared( ssprintf("%s:%d", comm.addr.data(), comm.port)); SmallVector> dev_buffer_arr(1, nullptr); auto disable = std::make_shared(); diff --git a/imperative/src/impl/ops/io_remote.cpp b/imperative/src/impl/ops/io_remote.cpp index 29b0316ee657d506e2b4daac5aa3591935e48658..03e4d58ab3a0fa6d1776a38660586a075567b794 100644 --- a/imperative/src/impl/ops/io_remote.cpp +++ b/imperative/src/impl/ops/io_remote.cpp @@ -28,7 +28,7 @@ namespace { cg::OperatorNodeBase* apply_on_var_node_remote_send( const OpDef& def, const VarNodeArray& inputs) { auto&& send = def.cast_final_safe(); - auto group_client = std::make_shared( + auto group_client = std::make_shared( ssprintf("%s:%d", send.addr.data(), send.port)); auto&& graph = inputs[0]->owner_graph(); @@ -44,7 +44,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_recv( auto&& recv = def.cast_final_safe(); OperatorNodeConfig config{recv.cn}; config.name(recv.make_name()); - auto group_client = std::make_shared( + auto group_client = std::make_shared( ssprintf("%s:%d", recv.addr.data(), recv.port)); auto&& graph = inputs[0]->owner_graph(); return graph->insert_opr(std::make_unique( diff --git a/imperative/src/impl/persistent_cache.cpp b/imperative/src/impl/persistent_cache.cpp index ba3809abb04cb6962e2291ccdd04da398dbf49a3..5fc291c8a6d8c8a9a56e4a483fb8fa543fec6c9d 100644 --- a/imperative/src/impl/persistent_cache.cpp +++ b/imperative/src/impl/persistent_cache.cpp @@ -27,8 +27,10 @@ public: m_local = std::make_shared(); } - bool connect(std::string ip, size_t port, std::string password) { - m_client.auth(password); + void connect(std::string ip, size_t port, std::optional password) { + if (password) { + m_client.auth(*password); + } m_client.connect( ip, port, [](const std::string& host, std::size_t port, @@ -40,16 +42,32 @@ public: } }, std::uint32_t(200)); - if (!m_client.is_connected()) { - return false; - } + mgb_assert(m_client.is_connected(), "connect failed"); auto flag = m_client.get("mgb-cache-flag"); sync(); - return flag.get().ok(); + auto is_valid = [](const cpp_redis::reply& reply) { + switch (reply.get_type()) { + case cpp_redis::reply::type::error: + case cpp_redis::reply::type::null: + return false; + case cpp_redis::reply::type::integer: + return reply.as_integer() != 0; + case cpp_redis::reply::type::simple_string: + case cpp_redis::reply::type::bulk_string: + return !reply.as_string().empty(); + case cpp_redis::reply::type::array: + return !reply.as_array().empty(); + default: + mgb_assert(false, "unknown reply type %d", (int)reply.get_type()); + } + }; + mgb_assert(is_valid(flag.get()), "invalid mgb-cache-flag"); } bool valid() const override { return m_client.is_connected(); } + void flush() override {} + mgb::Maybe get(const std::string& category, const Blob& key) override { MGB_LOCK_GUARD(m_mtx); auto mem_result = m_local->get(category, key); @@ -75,7 +93,7 @@ public: MGB_LOCK_GUARD(m_mtx); std::string key_str(static_cast(key.ptr), key.size); std::string redis_key_str; - encode(category + '@' + key_str, redis_key_str); + encode(category + '@' + key_str, redis_key_str, 24); std::string value_str(static_cast(value.ptr), value.size); std::string redis_value_str; encode(value_str, redis_value_str); @@ -118,18 +136,16 @@ private: class ExtendedInFilePersistentCache final : public ExtendedPersistentCache { private: - std::string m_path; + std::optional m_path; std::unique_ptr m_impl; public: ExtendedInFilePersistentCache() = default; - bool open(std::string path) { + void open(std::string path) { std::fstream file; file.open(path, std::ios::in | std::ios::binary); - if (!file.is_open()) { - return false; - } + mgb_assert(file.is_open(), "can't open file in %s", path.c_str()); std::vector bytes = { std::istreambuf_iterator(file), std::istreambuf_iterator()}; if (bytes.size()) { @@ -139,14 +155,11 @@ public: m_impl = std::make_unique(); } m_path = path; - return true; } - ~ExtendedInFilePersistentCache() { - if (m_impl) { - m_impl->dump_cache(m_path.c_str()); - } - } + void open() { m_impl = std::make_unique(); } + + ~ExtendedInFilePersistentCache() { flush(); } mgb::Maybe get(const std::string& category, const Blob& key) override { return m_impl->get(category, key); @@ -157,29 +170,64 @@ public: } std::optional clear() override { - m_impl = std::make_unique(); - m_impl->dump_cache(m_path.c_str()); + if (m_impl) { + m_impl = std::make_unique(); + if (m_path) { + m_impl->dump_cache(m_path->c_str()); + } + } return {}; } bool valid() const override { return m_impl != nullptr; } -}; -std::shared_ptr make_redis( - std::string ip, size_t port, std::string password, std::string prefix) { - auto cache = std::make_shared(prefix, 100); - if (!cache->connect(ip, port, password)) { - return nullptr; + void flush() override { + if (m_impl && m_path) { + m_impl->dump_cache(m_path->c_str()); + } } - return cache; -} +}; -std::shared_ptr make_in_file(std::string path) { - auto cache = std::make_shared(); - if (!cache->open(path)) { - return nullptr; +std::shared_ptr ExtendedPersistentCache::make_from_config( + std::string type, std::unordered_map args, + std::string& err_msg) { + try { + if (type == "redis") { + std::string prefix = args.at("prefix"); + std::optional password = args.count("password") + ? args.at("password") + : std::optional(); + auto cache = std::make_shared(prefix, 100); + if (args.count("unixsocket")) { + std::string unixsocket = args.at("unixsocket"); + cache->connect(unixsocket, 0, password); + } else { + std::string ip = args.at("hostname"); + int port = atoi(args.at("port").c_str()); + std::optional password = + args.count("password") ? args.at("password") + : std::optional(); + cache->connect(ip, port, password); + } + return cache; + } else if (type == "in-file") { + std::string path = args.at("path"); + auto cache = std::make_shared(); + cache->open(path); + return cache; + } else if (type == "in-memory") { + auto cache = std::make_shared(); + cache->open(); + return cache; + } else { + mgb_assert(false, "persistent cache type %s unsupported", type.c_str()); + } + } catch (const std::exception& exc) { + err_msg = exc.what(); + } catch (...) { + err_msg = "unknown exception"; } - return cache; + return nullptr; } } // namespace mgb::imperative::persistent_cache diff --git a/imperative/src/include/megbrain/imperative/persistent_cache.h b/imperative/src/include/megbrain/imperative/persistent_cache.h index 59326bbdedfddcc4ca6eaf2f4ba6caf375cd3d8f..4d63eaae7a66010e17fc0c87824c9c3580304c7c 100644 --- a/imperative/src/include/megbrain/imperative/persistent_cache.h +++ b/imperative/src/include/megbrain/imperative/persistent_cache.h @@ -20,12 +20,12 @@ class ExtendedPersistentCache : public mgb::PersistentCache { public: virtual bool valid() const = 0; virtual std::optional clear() = 0; -}; - -std::shared_ptr make_redis( - std::string ip, size_t port, std::string password, std::string prefix); + virtual void flush() = 0; -std::shared_ptr make_in_file(std::string path); + static std::shared_ptr make_from_config( + std::string type, std::unordered_map args, + std::string& err_msg); +}; } // namespace mgb::imperative::persistent_cache // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/imperative/src/test/collective_comm.cpp b/imperative/src/test/collective_comm.cpp index 01c0829dd78b4127e61cc3a6462d317c74bfc89c..4a31c54be31ad62883ec581c5b64d02083b3bbe1 100644 --- a/imperative/src/test/collective_comm.cpp +++ b/imperative/src/test/collective_comm.cpp @@ -20,7 +20,7 @@ TEST(TestImperative, AllReduceBasic) { REQUIRE_GPU(2); const char* server_addr = "127.0.0.1"; uint32_t port = 3456; - mgb_assert(create_zmqrpc_server(server_addr, port) > 0); + mgb_assert(opr::create_zmqrpc_server(server_addr, port) > 0); HostTensorGenerator<> gen; CompNode cn0 = CompNode::load("gpu0"), cn1 = CompNode::load("gpu1"); diff --git a/imperative/src/test/io_remote.cpp b/imperative/src/test/io_remote.cpp index 8e32f7abd4f1c8c9a99583e8af18bc0cc6a4f2bc..97a7b62dddfefe94b5d18acb65b5a4214c2978d0 100644 --- a/imperative/src/test/io_remote.cpp +++ b/imperative/src/test/io_remote.cpp @@ -20,7 +20,7 @@ TEST(TestImperative, IORemote) { REQUIRE_GPU(2); const char* server_addr = "127.0.0.1"; uint32_t port = 4567; - mgb_assert(create_zmqrpc_server(server_addr, port) > 0); + mgb_assert(opr::create_zmqrpc_server(server_addr, port) > 0); HostTensorGenerator<> gen; CompNode cn0 = CompNode::load("gpu0"), cn1 = CompNode::load("gpu1"); diff --git a/src/opr-mm/impl/mm_handler.cpp b/src/opr-mm/impl/mm_handler.cpp index 1f13b6974d7f057d97e8fb17d2c4b5ac8328217b..7cbb8a3ec54489f740d903df19d234d25bcd8f07 100644 --- a/src/opr-mm/impl/mm_handler.cpp +++ b/src/opr-mm/impl/mm_handler.cpp @@ -17,6 +17,9 @@ #include "megbrain/opr/zmq_rpc.h" #include "mm_handler.pb.h" +using namespace mgb; +using namespace opr; + /* ======================== GroupServerProxy ========================== */ /*! * A proxy that receives zmqrpc call, direct call to NCCL Manager @@ -213,7 +216,7 @@ struct ServerInfo { std::unique_ptr server; }; -int create_zmqrpc_server(const std::string& server_addr, int port) { +int mgb::opr::create_zmqrpc_server(const std::string& server_addr, int port) { static std::unordered_map addr2server; static std::mutex mtx; MGB_LOCK_GUARD(mtx); diff --git a/src/opr-mm/include/megbrain/opr/mm_handler.h b/src/opr-mm/include/megbrain/opr/mm_handler.h index 7c03bf961cb7efbc53679d6ff00a5290e4fdd08f..97b829d446b46cda9b47d39856212b963a2dd5cf 100644 --- a/src/opr-mm/include/megbrain/opr/mm_handler.h +++ b/src/opr-mm/include/megbrain/opr/mm_handler.h @@ -16,8 +16,8 @@ #include "megbrain/opr/collective_comm.h" #include "megbrain/opr/group_manager.h" -using namespace mgb; -using namespace opr; +namespace mgb { +namespace opr { /*! * Comm MM Client Proxy. @@ -56,6 +56,9 @@ private: int create_zmqrpc_server(const std::string& server_addr, int port); +} // namespace opr +} // namespace mgb + #endif // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/version.ld b/src/version.ld index db71a72b3a2f8003aef62ce11a39c422dd0221a6..f70a5677511a7c2588a780b2c401f603e95bbde4 100644 --- a/src/version.ld +++ b/src/version.ld @@ -13,8 +13,6 @@ global: base_exceptions*; }; megcore*; - *GroupClientProxy*; - *create_zmqrpc_server*; *custom*;