module.cpp 1.8 KB
Newer Older
M
Megvii Engine Team 已提交
1 2 3 4 5 6 7 8 9 10 11
/**
 * \file imperative/python/src/module.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

12 13 14 15 16 17 18 19 20 21 22 23
#include <pybind11/eval.h>

#define DO_IMPORT_ARRAY
#include "./numpy_dtypes.h"
#include "./helper.h"

#include "./common.h"
#include "./utils.h"
#include "./imperative_rt.h"
#include "./graph_rt.h"
#include "./ops.h"

24 25
#include "./dispatcher.h"

26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
// Name of the generated Python extension module. Falls back to
// `imperative_rt` when MODULE_NAME has not already been defined
// (for example by the build system).
#if !defined(MODULE_NAME)
#define MODULE_NAME imperative_rt
#endif

// Short alias for the pybind11 namespace used throughout this file.
namespace py = pybind11;

/**
 * Entry point of the Python extension module named by MODULE_NAME.
 *
 * Steps, in order:
 *  1. Initialise the NumPy C API.
 *  2. Register the module in sys.modules and make it its own package, so the
 *     relative imports executed at the end of this function can resolve it.
 *  3. Register an atexit hook that waits for py_task_q to finish its tasks.
 *  4. Create the "common", "utils", "imperative", "graph" and "ops"
 *     submodules and populate each through its init_* function.
 *  5. Star-import common/utils/imperative/graph into the top-level namespace.
 *  6. Create and populate the "dispatcher" submodule.
 *
 * Throws py::error_already_set if the NumPy C API cannot be imported.
 */
PYBIND11_MODULE(MODULE_NAME, m) {
    // import_array1(ret) expands to code that *returns* `ret` from the
    // enclosing function when the NumPy C API fails to import. The lambda
    // confines that early return, so a non-zero result means the import
    // failed and a Python exception is set, which error_already_set rethrows.
    if ([]() {import_array1(1); return 0;}()) {
        throw py::error_already_set();
    }

    // Make this module importable under its own name and treat it as a
    // package. This lets `from .common import *` in the py::exec call below
    // resolve against it.
    // NOTE(review): this relies on submodule() (declared elsewhere) making
    // each submodule importable as "<name>.<sub>". Confirm in its definition.
    py::module::import("sys").attr("modules")[m.attr("__name__")] = m;

    m.attr("__package__") = m.attr("__name__");
    // The module __dict__ is used as the globals of the py::exec call below.
    // Presumably builtins (e.g. __import__) must be reachable from it for the
    // import statements to run. TODO confirm.
    m.attr("__builtins__") = py::module::import("builtins");

    // At interpreter exit, block until py_task_q has finished all tasks.
    // The GIL is released while waiting, presumably so that queued tasks
    // which need the GIL can complete instead of deadlocking. TODO confirm.
    auto atexit = py::module::import("atexit");
    atexit.attr("register")(py::cpp_function([]() {
        py::gil_scoped_release _;
        py_task_q.wait_all_task_finish();
    }));

    // Create all submodules first, then fill each one with its bindings.
    auto common = submodule(m, "common");
    auto utils = submodule(m, "utils");
    auto imperative = submodule(m, "imperative");
    auto graph = submodule(m, "graph");
    auto ops = submodule(m, "ops");

    init_common(common);
    init_utils(utils);
    init_imperative_rt(imperative);
    init_graph_rt(graph);
    init_ops(ops);

    // Star-import these submodules into the top-level module namespace.
    // "ops" and "dispatcher" are not star-imported here.
    py::exec(R"(
        from .common import *
        from .utils import *
        from .imperative import *
        from .graph import *
        )",
        py::getattr(m, "__dict__"));

    // The dispatcher submodule is created and populated last.
    init_dispatcher(submodule(m, "dispatcher"));
}