/** * \file imperative/python/src/utils.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #include "utils.h" #ifdef WIN32 #include #include #endif #include #include #include #include "./imperative_rt.h" #include "megbrain/common.h" #include "megbrain/comp_node.h" #include "megbrain/imperative/blob_manager.h" #include "megbrain/imperative/profiler.h" #include "megbrain/imperative/tensor_sanity_check.h" #include "megbrain/serialization/helper.h" #include "megbrain/utils/persistent_cache.h" #if MGB_ENABLE_OPR_MM #include "megbrain/opr/mm_handler.h" #endif namespace py = pybind11; namespace { bool g_global_finalized = false; class LoggerWrapper { public: using LogLevel = mgb::LogLevel; using LogHandler = mgb::LogHandler; static void set_log_handler(py::object logger_p) { logger = logger_p; mgb::set_log_handler(py_log_handler); } static LogLevel set_log_level(LogLevel log_level) { return mgb::set_log_level(log_level); } private: static py::object logger; static void py_log_handler(mgb::LogLevel level, const char* file, const char* func, int line, const char* fmt, va_list ap) { using mgb::LogLevel; MGB_MARK_USED_VAR(file); MGB_MARK_USED_VAR(func); MGB_MARK_USED_VAR(line); if (g_global_finalized) return; const char* py_type; switch (level) { case LogLevel::DEBUG: py_type = "debug"; break; case LogLevel::INFO: py_type = "info"; break; case LogLevel::WARN: py_type = "warning"; break; case LogLevel::ERROR: py_type = "error"; break; default: throw std::runtime_error("bad log level"); } std::string msg = mgb::svsprintf(fmt, ap); auto do_log = [msg = msg, py_type]() { if (logger.is_none()) return; py::object _call = logger.attr(py_type); _call(py::str(msg)); }; if (PyGILState_Check()) { do_log(); } else { py_task_q.add_task(do_log); } } }; py::object LoggerWrapper::logger = py::none{}; uint32_t _get_dtype_num(py::object dtype) { return static_cast(npy::dtype_np2mgb(dtype.ptr()).enumv()); } py::bytes _get_serialized_dtype(py::object dtype) { std::string sdtype; auto write = [&sdtype](const void* data, size_t size) { auto pos = sdtype.size(); sdtype.resize(pos + size); memcpy(&sdtype[pos], data, size); }; mgb::serialization::serialize_dtype(npy::dtype_np2mgb(dtype.ptr()), write); return py::bytes(sdtype.data(), sdtype.size()); } int fork_exec_impl(const std::string& arg0, const std::string& arg1, const std::string& arg2) { #ifdef WIN32 STARTUPINFO si; PROCESS_INFORMATION pi; ZeroMemory(&si, sizeof(si)); si.cb = sizeof(si); ZeroMemory(&pi, sizeof(pi)); auto args_str = " " + arg1 + " " + arg2; // Start the child process. if (!CreateProcess(arg0.c_str(), // exe name const_cast(args_str.c_str()), // Command line NULL, // Process handle not inheritable NULL, // Thread handle not inheritable FALSE, // Set handle inheritance to FALSE 0, // No creation flags NULL, // Use parent's environment block NULL, // Use parent's starting directory &si, // Pointer to STARTUPINFO structure &pi) // Pointer to PROCESS_INFORMATION structure ) { mgb_log_warn("CreateProcess failed (%lu).\n", GetLastError()); fprintf(stderr, "[megbrain] failed to execl %s [%s, %s]\n", arg0.c_str(), arg1.c_str(), arg2.c_str()); __builtin_trap(); } return pi.dwProcessId; #else auto pid = fork(); if (!pid) { execl(arg0.c_str(), arg0.c_str(), arg1.c_str(), arg2.c_str(), nullptr); fprintf(stderr, "[megbrain] failed to execl %s [%s, %s]: %s\n", arg0.c_str(), arg1.c_str(), arg2.c_str(), std::strerror(errno)); std::terminate(); } mgb_assert(pid > 0, "failed to fork: %s", std::strerror(errno)); return pid; #endif } } // namespace void init_utils(py::module m) { auto atexit = py::module::import("atexit"); atexit.attr("register")(py::cpp_function([]() { g_global_finalized = true; })); py::class_>(m, "AtomicUint64") .def(py::init<>()) .def(py::init()) .def("load", [](const std::atomic& self) { return self.load(); }) .def("store", [](std::atomic& self, uint64_t value) { return self.store(value); }) .def("fetch_add", [](std::atomic& self, uint64_t value) { return self.fetch_add(value); }) .def("fetch_sub", [](std::atomic& self, uint64_t value) { return self.fetch_sub(value); }) .def(py::self += uint64_t()) .def(py::self -= uint64_t()); // FIXME!!! Should add a submodule instead of using a class for logger py::class_ logger(m, "Logger"); logger.def(py::init<>()) .def_static("set_log_level", &LoggerWrapper::set_log_level) .def_static("set_log_handler", &LoggerWrapper::set_log_handler); py::enum_(logger, "LogLevel") .value("Debug", LoggerWrapper::LogLevel::DEBUG) .value("Info", LoggerWrapper::LogLevel::INFO) .value("Warn", LoggerWrapper::LogLevel::WARN) .value("Error", LoggerWrapper::LogLevel::ERROR); m.def("_get_dtype_num", &_get_dtype_num, "Convert numpy dtype to internal dtype"); m.def("_get_serialized_dtype", &_get_serialized_dtype, "Convert numpy dtype to internal dtype for serialization"); m.def("_get_device_count", &mgb::CompNode::get_device_count, "Get total number of specific devices on this system"); using mgb::imperative::TensorSanityCheck; py::class_(m, "TensorSanityCheckImpl") .def(py::init<>()) .def("enable", [](TensorSanityCheck& checker) -> TensorSanityCheck& { checker.enable(); return checker; }) .def("disable", [](TensorSanityCheck& checker) { checker.disable(); }); #if MGB_ENABLE_OPR_MM m.def("create_mm_server", &create_zmqrpc_server, py::arg("addr"), py::arg("port") = 0); #else m.def("create_mm_server", []() {}); #endif // Debug code, internal only m.def("_set_defrag", [](bool enable) { mgb::imperative::BlobManager::inst()->set_enable(enable); }); m.def("_defrag", [](const mgb::CompNode& cn) { mgb::imperative::BlobManager::inst()->defrag(cn); }); m.def("_set_fork_exec_path_for_timed_func", [](const std::string& arg0, const ::std::string arg1) { using namespace std::placeholders; mgb::sys::TimedFuncInvoker::ins().set_fork_exec_impl(std::bind( fork_exec_impl, std::string{arg0}, std::string{arg1}, _1)); }); m.def("_timed_func_exec_cb", [](const std::string& user_data){ mgb::sys::TimedFuncInvoker::ins().fork_exec_impl_mainloop(user_data.c_str()); }); using mgb::PersistentCache; class PyPersistentCache: public mgb::PersistentCache { private: using KeyPair = std::pair; using BlobPtr = std::unique_ptr; static size_t hash_key_pair(const KeyPair& kp) { std::hash hasher; return hasher(kp.first) ^ hasher(kp.second); } std::string blob_to_str(const Blob& key) { return std::string(reinterpret_cast(key.ptr), key.size); } BlobPtr copy_blob(const Blob& blob) { auto blob_deleter = [](Blob* blob){ if (blob) { std::free(const_cast(blob->ptr)); delete blob; } }; auto blob_ptr = BlobPtr{ new Blob(), blob_deleter }; blob_ptr->ptr = std::malloc(blob.size); std::memcpy(const_cast(blob_ptr->ptr), blob.ptr, blob.size); blob_ptr->size = blob.size; return blob_ptr; } BlobPtr str_to_blob(const std::string& str) { auto blob = Blob{ str.data(), str.size() }; return copy_blob(blob); } std::unique_ptr empty_blob() { return BlobPtr{ nullptr, [](Blob* blob){} }; } public: mgb::Maybe get(const std::string& category, const Blob& key) override { thread_local std::unordered_map m_local_cache; auto py_get = [this](const std::string& category, const Blob& key) -> mgb::Maybe { PYBIND11_OVERLOAD_PURE(mgb::Maybe, PersistentCache, get, category, key); }; KeyPair kp = { category, blob_to_str(key) }; auto iter = m_local_cache.find(kp); if (iter == m_local_cache.end()) { auto py_ret = py_get(category, key); if (!py_ret.valid()) { iter = m_local_cache.insert({kp, empty_blob()}).first; } else { iter = m_local_cache.insert({kp, copy_blob(py_ret.val())}).first; } } if (iter->second) { return *iter->second; } else { return {}; } } void put(const std::string& category, const Blob& key, const Blob& value) override { PYBIND11_OVERLOAD_PURE(void, PersistentCache, put, category, key, value); } }; py::class_>(m, "PersistentCache") .def(py::init<>()) .def("get", &PersistentCache::get) .def("put", &PersistentCache::put) .def("reg", &PersistentCache::set_impl); }