function_replace.cpp 3.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
/**
 * \file python_module/src/cpp/function_replace.cpp
 *
 * This file is part of MegBrain, a deep learning framework developed by Megvii.
 *
 * \brief replace functions in megbrain core
 *
 * \copyright Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 */

#include "./megbrain_wrap.h"
#include "./python_helper.h"

#include "megbrain/utils/debug.h"
#include "megbrain/common.h"
#include "megbrain/system.h"

#include <stdexcept>
#include <cstring>
#include <cstdarg>

#include <Python.h>
#include <unistd.h>

namespace {

//! Python logger object set by _register_logger(); nullptr until then
PyObject *logger = nullptr;

#if MGB_ENABLE_DEBUG_UTIL
/*!
 * \brief hook installed as mgb::debug::ForkAfterCudaError::throw_
 *
 * Sets the Python error indicator (SystemError) so that, once control gets
 * back to the Python code that called fork(), an exception can be raised
 * there.
 *
 * Expected call chain:
 *   python -> fork() -> pthread_atfork -> CudaCheckOnFork ->
 *   ForkAfterCudaError::throw_
 */
void throw_fork_cuda_exc() {
    mgb_log_warn("try to raise python exception for fork after cuda");
    PyErr_SetString(PyExc_SystemError, "fork after cuda has been initialized");
}
#endif

/*!
 * \brief static-initialization helper: constructing the single static
 *        instance installs the fork-after-CUDA hook (when debug utils are
 *        enabled) at load time
 */
class Init {
    static Init inst;

    Init() {
#if MGB_ENABLE_DEBUG_UTIL
        mgb::debug::ForkAfterCudaError::throw_ = &throw_fork_cuda_exc;
#endif
    }
};
Init Init::inst;

/*!
 * \brief fork a child process that execl()s \p arg0 with
 *        argv = {arg0, arg1, arg2}
 * \return pid of the child (in the parent); the child never returns from
 *         here: it either becomes the new program or calls std::terminate()
 *         after reporting the execl failure on stderr
 *
 * A failed fork() trips mgb_assert in the parent.
 */
int fork_exec_impl(const std::string &arg0, const std::string &arg1,
        const std::string &arg2) {
    const auto child_pid = fork();
    if (child_pid == 0) {
        // child side: execl only returns if it failed
        const char *exe = arg0.c_str();
        execl(exe, exe, arg1.c_str(), arg2.c_str(), nullptr);
        fprintf(stderr, "[megbrain] failed to execl %s [%s, %s]: %s\n",
                exe, arg1.c_str(), arg2.c_str(), std::strerror(errno));
        std::terminate();
    }
    mgb_assert(child_pid > 0, "failed to fork: %s", std::strerror(errno));
    return child_pid;
}

} // anonymous namespace

// called from swig/misc.i
/*!
 * \brief make TimedFuncInvoker launch its worker through fork + execl
 *
 * The installed callback runs fork_exec_impl(arg0, arg1, <arg>), where <arg>
 * is the single argument TimedFuncInvoker passes to the callback, i.e. it
 * execl()s \p arg0 with argv = {arg0, arg1, <arg>}. Both C strings are copied
 * into the callback, so the caller's buffers need not outlive this call.
 */
void _timed_func_set_fork_exec_path(const char *arg0, const char *arg1) {
    std::string exec_path{arg0}, first_arg{arg1};
    mgb::sys::TimedFuncInvoker::ins().set_fork_exec_impl(
            [exec_path, first_arg](const std::string &last_arg) {
                return fork_exec_impl(exec_path, first_arg, last_arg);
            });
}

/*!
 * \brief forward \p user_data to TimedFuncInvoker::fork_exec_impl_mainloop()
 *
 * NOTE(review): presumably invoked in the process started via
 * fork_exec_impl; confirm against the caller in swig/misc.i.
 */
void _timed_func_exec_cb(const char *user_data) {
    auto &&invoker = mgb::sys::TimedFuncInvoker::ins();
    invoker.fork_exec_impl_mainloop(user_data);
}

/*!
 * \brief register the Python logger object that py_log_handler forwards
 *        megbrain log messages to
 *
 * A strong reference to \p l is kept, so the object stays alive for as long as
 * C++ code may log through it. Any previously registered logger is released.
 * Passing nullptr unregisters the logger, and logging then becomes a no-op.
 */
void _register_logger(PyObject *l) {
    // refcount manipulation requires the GIL (acquired the same way as in
    // py_log_handler)
    PYTHON_GIL;
    Py_XINCREF(l);
    PyObject *old = logger;
    logger = l;
    // release the old logger only after the new one is installed
    Py_XDECREF(old);
}

namespace {
/*!
 * \brief megbrain log handler that forwards messages to the Python logger
 *        registered by _register_logger()
 *
 * The formatted message is passed to the logger method that matches \p level
 * (debug / info / warning / error). Messages are dropped once
 * global_finalized() reports true, or while no logger is registered. An
 * unknown level throws std::runtime_error.
 *
 * Python-side failures are reported via PyErr_WriteUnraisable and never leak
 * into the caller's error state. This covers failure to build the message
 * object and exceptions raised by the logger. Any Python exception that was
 * already pending when the handler was entered is preserved.
 */
void py_log_handler(mgb::LogLevel level,
        const char *file, const char *func, int line, const char *fmt,
        va_list ap) {
    if (global_finalized()) {
        return;
    }

    using mgb::LogLevel;

    MGB_MARK_USED_VAR(file);
    MGB_MARK_USED_VAR(func);
    MGB_MARK_USED_VAR(line);

    if (!logger)
        return;

    PYTHON_GIL;

    const char *py_type;
    switch (level) {
        case LogLevel::DEBUG:
            py_type = "debug";
            break;
        case LogLevel::INFO:
            py_type = "info";
            break;
        case LogLevel::WARN:
            py_type = "warning";
            break;
        case LogLevel::ERROR:
            py_type = "error";
            break;
        default:
            throw std::runtime_error("bad log level");
    }

    std::string msg = mgb::svsprintf(fmt, ap);

    // The handler may run while a Python exception is pending. Stash it so
    // that the logger call starts with a clean error indicator, and restore
    // it afterwards so that it is neither lost nor clobbered.
    PyObject *saved_type, *saved_value, *saved_tb;
    PyErr_Fetch(&saved_type, &saved_value, &saved_tb);

    // Decode with "replace" so that a message containing invalid UTF-8 is
    // still logged instead of making the conversion fail (a NULL result).
    PyObject *py_msg = PyUnicode_DecodeUTF8(
            msg.data(), static_cast<Py_ssize_t>(msg.size()), "replace");
    if (py_msg) {
        PyObject *ret = PyObject_CallMethod(logger,
                const_cast<char*>(py_type), const_cast<char*>("O"),
                py_msg);
        // the method's return value is a new reference; discard it
        Py_XDECREF(ret);
        Py_DECREF(py_msg);
    }
    if (PyErr_Occurred()) {
        // report and clear, so a failing logger cannot leave an exception
        // set for unrelated code to trip over
        PyErr_WriteUnraisable(logger);
    }

    PyErr_Restore(saved_type, saved_value, saved_tb);
}

/*!
 * \brief static-initialization helper: constructing the single static
 *        instance installs py_log_handler as megbrain's log handler
 */
class LogHandlerSetter {
    static LogHandlerSetter ins;
    public:
        LogHandlerSetter() {
            mgb::set_log_handler(py_log_handler);
        }
};
LogHandlerSetter LogHandlerSetter::ins;
} // anonymous namespace

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}