function_replace.cpp 4.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
/**
 * \file python_module/src/cpp/function_replace.cpp
 *
 * This file is part of MegBrain, a deep learning framework developed by Megvii.
 *
 * \brief replace functions in megbrain core
 *
 * \copyright Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 */

#include "./megbrain_wrap.h"
#include "./python_helper.h"

#include "megbrain/utils/debug.h"
#include "megbrain/common.h"
#include "megbrain/system.h"

#include <stdexcept>
#include <cstring>
#include <cstdarg>

#include <Python.h>
24 25 26 27 28

#ifdef WIN32
#include <windows.h>
#include <stdio.h>
#else
29
#include <unistd.h>
30
#endif
31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58

namespace {

//! Python logger object that receives megbrain log messages; stays null
//! until _register_logger() is called
PyObject* logger{nullptr};

#if MGB_ENABLE_DEBUG_UTIL
/*!
 * \brief hook installed as mgb::debug::ForkAfterCudaError::throw_
 *
 * Reached through the chain
 *     python -> fork() -> pthread_atfork -> CudaCheckOnFork ->
 *     ForkAfterCudaError::throw_
 * Despite the hook's name this does not throw: it logs a warning and then
 * sets the Python error indicator, so that an exception can be raised once
 * control returns to the Python code in the parent that called fork().
 */
void throw_fork_cuda_exc() {
    mgb_log_warn("try to raise python exception for fork after cuda");

    const char* const reason = "fork after cuda has been initialized";
    PyErr_SetString(PyExc_SystemError, reason);
}
#endif

/*!
 * \brief static registrar for the fork-after-CUDA hook
 *
 * Its single static instance is constructed during static initialization of
 * this file; when MGB_ENABLE_DEBUG_UTIL is enabled the constructor points
 * mgb::debug::ForkAfterCudaError::throw_ at throw_fork_cuda_exc, otherwise it
 * does nothing.
 */
class Init {
    Init() {
#if MGB_ENABLE_DEBUG_UTIL
        mgb::debug::ForkAfterCudaError::throw_ = throw_fork_cuda_exc;
#endif
    }

    //! the sole instance; constructing it performs the registration
    static Init inst;
};
Init Init::inst;

59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
/*!
 * \brief start a child process running \p arg0 with arguments \p arg1 and
 *      \p arg2, without waiting for it
 *
 * POSIX: fork() + execl(arg0, arg0, arg1, arg2). The child never returns
 * from this function: if execl() fails it prints a diagnostic to stderr and
 * calls std::terminate(). A failed fork() trips mgb_assert in the parent.
 *
 * Windows: CreateProcessA() with \p arg0 as the application name and
 * " arg1 arg2" as the command line. On failure a diagnostic is printed and
 * the process traps.
 *
 * \return process id of the child
 */
int fork_exec_impl(const std::string& arg0, const std::string& arg1,
                   const std::string& arg2) {
#ifdef WIN32
    // Narrow strings are passed, so use the ANSI API explicitly instead of
    // the TCHAR macros, which would map to the wide variant under UNICODE.
    STARTUPINFOA si;
    PROCESS_INFORMATION pi;
    ZeroMemory(&si, sizeof(si));
    si.cb = sizeof(si);
    ZeroMemory(&pi, sizeof(pi));

    // lpCommandLine must point to writable memory, so hand CreateProcess the
    // string's own (contiguous, NUL-terminated, non-empty) buffer rather than
    // a const_cast of c_str().
    std::string cmdline = " " + arg1 + " " + arg2;

    if (!CreateProcessA(arg0.c_str(),  // exe name
                        &cmdline[0],   // command line (writable)
                        nullptr,       // process handle not inheritable
                        nullptr,       // thread handle not inheritable
                        FALSE,         // set handle inheritance to FALSE
                        0,             // no creation flags
                        nullptr,       // use parent's environment block
                        nullptr,       // use parent's starting directory
                        &si,           // pointer to STARTUPINFO structure
                        &pi)) {        // pointer to PROCESS_INFORMATION
        mgb_log_warn("CreateProcess failed (%lu).\n", GetLastError());
        fprintf(stderr, "[megbrain] failed to execl %s [%s, %s]\n",
                arg0.c_str(), arg1.c_str(), arg2.c_str());
        __builtin_trap();
    }

    // Only the pid is handed back to the caller; CreateProcess documents
    // that both returned handles must be closed with CloseHandle, otherwise
    // two kernel handles leak for every spawned child.
    CloseHandle(pi.hThread);
    CloseHandle(pi.hProcess);
    return static_cast<int>(pi.dwProcessId);
#else
    auto pid = fork();
    if (pid == 0) {
        // child: replace the process image; reaching the code below means
        // execl() failed
        execl(arg0.c_str(), arg0.c_str(), arg1.c_str(), arg2.c_str(), nullptr);
        fprintf(stderr, "[megbrain] failed to execl %s [%s, %s]: %s\n",
                arg0.c_str(), arg1.c_str(), arg2.c_str(),
                std::strerror(errno));
        std::terminate();
    }
    // parent: a negative pid means fork() itself failed
    mgb_assert(pid > 0, "failed to fork: %s", std::strerror(errno));
    return pid;
#endif
}

} // anonymous namespace

// called from swig/misc.i
/*!
 * \brief make TimedFuncInvoker launch its worker via fork_exec_impl, with
 *      the executable fixed to \p arg0 and the first argument to \p arg1
 *
 * Both C strings are copied into the stored callable, so the caller's
 * buffers need not outlive this call; the remaining argument is supplied by
 * TimedFuncInvoker at launch time.
 */
void _timed_func_set_fork_exec_path(const char *arg0, const char *arg1) {
    std::string exe_path{arg0};
    std::string first_arg{arg1};
    auto launcher = [exe_path, first_arg](const std::string& last_arg) {
        return fork_exec_impl(exe_path, first_arg, last_arg);
    };
    mgb::sys::TimedFuncInvoker::ins().set_fork_exec_impl(launcher);
}

/*!
 * \brief enter TimedFuncInvoker's fork-exec main loop
 * \param user_data opaque string passed unchanged to
 *      fork_exec_impl_mainloop()
 */
void _timed_func_exec_cb(const char *user_data) {
    auto&& invoker = mgb::sys::TimedFuncInvoker::ins();
    invoker.fork_exec_impl_mainloop(user_data);
}

/*!
 * \brief register the Python object that py_log_handler forwards megbrain
 *      log messages to
 *
 * A strong reference to \p l is held while it stays registered, so the
 * logger cannot be freed while the log handler may still call it; the
 * previously registered object (if any) is released.
 */
void _register_logger(PyObject *l) {
    // Reference counts may only be touched while holding the GIL.
    // PYTHON_GIL must be reentrant (py_log_handler uses it on paths that may
    // already hold the GIL), so this is safe when called from Python code.
    PYTHON_GIL;
    Py_XINCREF(l);
    PyObject *old = logger;
    logger = l;
    // drop the old reference only after the new logger is installed, in
    // case destroying the old object ends up emitting log messages
    Py_XDECREF(old);
}

namespace {
/*!
 * \brief megbrain log handler that forwards messages to the registered
 *      Python logger
 *
 * The message is formatted from \p fmt / \p ap and passed as the single
 * argument of the logger method matching \p level (debug / info / warning /
 * error). Nothing is done after global finalization or while no logger is
 * registered. \p file, \p func and \p line are ignored.
 *
 * \throw std::runtime_error on an unknown log level
 */
void py_log_handler(mgb::LogLevel level,
        const char *file, const char *func, int line, const char *fmt,
        va_list ap) {
    if (global_finalized()) {
        return;
    }

    using mgb::LogLevel;

    MGB_MARK_USED_VAR(file);
    MGB_MARK_USED_VAR(func);
    MGB_MARK_USED_VAR(line);

    // cheap early exit before paying for the GIL
    if (!logger)
        return;

    const char *py_type;
    switch (level) {
        case LogLevel::DEBUG:
            py_type = "debug";
            break;
        case LogLevel::INFO:
            py_type = "info";
            break;
        case LogLevel::WARN:
            py_type = "warning";
            break;
        case LogLevel::ERROR:
            py_type = "error";
            break;
        default:
            throw std::runtime_error("bad log level");
    }

    // formatting needs no Python state, so do it before taking the GIL
    std::string msg = mgb::svsprintf(fmt, ap);

    PYTHON_GIL;

    // logger may have been changed while we were waiting for the GIL
    if (!logger)
        return;

    // Preserve any exception the caller already has pending: Python must not
    // be called with the error indicator set, and the caller's error has to
    // survive this log call.
    PyObject *exc_type, *exc_value, *exc_tb;
    PyErr_Fetch(&exc_type, &exc_value, &exc_tb);

    // "s" converts msg into the single str argument of the call; if that
    // conversion fails (e.g. invalid UTF-8) the call returns nullptr with an
    // exception set, instead of a null object ever being dereferenced.
    PyObject *ret = PyObject_CallMethod(logger,
            const_cast<char*>(py_type), const_cast<char*>("s"),
            msg.c_str());
    if (ret) {
        // the call result is a new reference; release it to avoid a leak on
        // every log message
        Py_DECREF(ret);
    } else {
        // a failing logger must not leave a stray Python exception behind:
        // report it on stderr and clear it
        PyErr_WriteUnraisable(logger);
    }

    PyErr_Restore(exc_type, exc_value, exc_tb);
}

/*!
 * \brief static registrar: constructing its single static instance, during
 *      static initialization of this file, installs py_log_handler via
 *      mgb::set_log_handler()
 */
class LogHandlerSetter {
public:
    LogHandlerSetter() { mgb::set_log_handler(py_log_handler); }

private:
    //! the sole instance; its construction performs the registration
    static LogHandlerSetter ins;
};
LogHandlerSetter LogHandlerSetter::ins;
} // anonymous namespace

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}