// module.cpp
#include <pybind11/eval.h>

#define DO_IMPORT_ARRAY
#include "./numpy_dtypes.h"
#include "./helper.h"

#include "./common.h"
#include "./utils.h"
#include "./imperative_rt.h"
#include "./graph_rt.h"
#include "./ops.h"

namespace py = pybind11;

// Name under which the extension module is defined (it is passed to
// PYBIND11_MODULE below). A definition supplied before this point takes
// precedence; otherwise the name defaults to imperative_rt.
#ifndef MODULE_NAME
#define MODULE_NAME imperative_rt
#endif

// Extension module entry point. In order: loads the NumPy C API, registers
// the module in sys.modules and marks it as its own package, installs an
// atexit hook, creates and initializes the submodules, and finally
// star-imports several of them into the top-level namespace.
PYBIND11_MODULE(MODULE_NAME, m) {
    // import_array1 is invoked inside a lambda so that the lambda's result
    // (1 instead of 0) reports a failed NumPy import to the caller.
    const auto numpy_import_failed = []() {
        import_array1(1);
        return 0;
    };
    if (numpy_import_failed()) {
        throw py::error_already_set();
    }

    // Publish the module under its own name and make it its own package;
    // presumably required so the relative imports executed at the end of this
    // function can be resolved.
    const py::object module_name = m.attr("__name__");
    py::object sys_modules = py::module::import("sys").attr("modules");
    sys_modules[module_name] = m;
    m.attr("__package__") = module_name;
    m.attr("__builtins__") = py::module::import("builtins");

    // Hook handed to atexit.register: calls py_task_q.wait_all_task_finish()
    // with the GIL released for the duration of the call.
    py::cpp_function wait_for_task_queue([]() {
        py::gil_scoped_release release_gil;
        py_task_q.wait_all_task_finish();
    });
    py::module::import("atexit").attr("register")(wait_for_task_queue);

    // All submodules are created first and only then populated, in the same
    // order.
    auto common_mod = submodule(m, "common");
    auto utils_mod = submodule(m, "utils");
    auto imperative_mod = submodule(m, "imperative");
    auto graph_mod = submodule(m, "graph");
    auto ops_mod = submodule(m, "ops");

    init_common(common_mod);
    init_utils(utils_mod);
    init_imperative_rt(imperative_mod);
    init_graph_rt(graph_mod);
    init_ops(ops_mod);

    // Re-export the contents of the submodules at the top level by running
    // the imports inside the module's own namespace. Note that `ops` is not
    // star-imported here.
    py::object module_namespace = py::getattr(m, "__dict__");
    py::exec(R"(
        from .common import *
        from .utils import *
        from .imperative import *
        from .graph import *
        )",
        module_namespace);
}