// compiler.cpp — JIT compiler selection and compile cache (Megvii Engine Team)
#include "./mlir/compiler.h"
#include "./halide/compiler_cuda.h"
#include "./nvrtc/compiler_cuda.h"

#include "megbrain/jit/compiler.h"
#include "megbrain/utils/hash.h"

#include <cstdlib>
#include <cstring>

#if MGB_JIT

using namespace mgb;
using namespace jit;

namespace {
//! Per-graph container, stored in the graph's user data, for the JIT
//! compilers created so far, keyed by device type.
class CompilerHolder final : public UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;

public:
    //! must be held while reading or inserting into dev2compiler
    std::mutex mtx{};
    //! device type -> lazily created compiler, owned by this holder
    ThinHashMap<CompNode::DeviceType, std::unique_ptr<Compiler>> dev2compiler{};
};
MGB_TYPEINFO_OBJ_IMPL(CompilerHolder);
}  // namespace

//! Sentinel compiler that Compiler::get() returns for the default CPU comp
//! node, i.e. for oprs inside an internal graph (nested JITExecutor). It
//! advertises no flags and no JIT features, and it must never be asked to
//! compile anything.
class Compiler::EmptyCompiler final : public Compiler {
public:
    Property property() const override {
        // NOTE(review): the meaning of the third field (100) is defined by
        // Compiler::Property in megbrain/jit/compiler.h -- confirm there.
        return {Property::Flag::NONE, JITFeatureBits::NONE, 100};
    }

    //! No workspace outputs are ever requested.
    size_t get_nr_workspace_outputs(JITExecutor*) const override { return 0; }

    //! Nothing to set up, since there is no workspace.
    void init_workspace_size_infer(JITExecutor*) override {}

    //! Always throws InternalError; reaching this indicates a logic error.
    std::unique_ptr<Executable> do_compile(
            const InternalGraph&, const JITExecutor::Args&) override {
        mgb_throw(InternalError, "EmptyCompiler should not be used");
    }
};

//! Whether JIT compilation may be used on \p device.
//!
//! CUDA reports true only in builds with MGB_CUDA. CPU always reports true.
//! Every other device type reports false.
//!
//! NOTE(review): CPU reports true even without MGB_JIT_MLIR, although
//! Compiler::get() can only build an MLIR compiler for CPU and throws
//! otherwise. Confirm that callers rely on the backend configuration
//! (config_jit_backends) to rule that case out.
bool Compiler::is_supported_device(CompNode::DeviceType device) {
    switch (device) {
#if MGB_CUDA
        case CompNode::DeviceType::CUDA:
            return true;
#endif
        case CompNode::DeviceType::CPU:
            return true;
        default:
            return false;
    }
}

//! Return the JIT compiler to use for \p comp_node in \p graph.
//!
//! The default CPU comp node always gets a shared EmptyCompiler, because it
//! hosts the oprs of internal graphs (nested JITExecutor). For any other comp
//! node, one compiler per device type is created on first use and stored in
//! the graph's user data. Later calls reuse it. The concrete compiler is
//! chosen from the MGB_JIT_BACKEND environment variable ("HALIDE", "MLIR" or
//! "NVRTC", depending on build flags and device type).
//!
//! \throw InternalError if no compiler matches the device type / backend.
Compiler* Compiler::get(ComputingGraph& graph, CompNode comp_node) {
    static EmptyCompiler empty_compiler;
    if (comp_node == CompNode::default_cpu()) {
        // oprs in the internal graph are on default cpu; this case handles
        // nested JITExecutor
        return &empty_compiler;
    }

    CompilerHolder* holder;
    {
        // serializes lookup/creation of the per-graph holder in user_data
        static std::mutex mtx;
        MGB_LOCK_GUARD(mtx);
        holder = graph.options().user_data.get_user_data_or_create<CompilerHolder>();
    }
    MGB_LOCK_GUARD(holder->mtx);
    auto&& compiler = holder->dev2compiler[comp_node.device_type()];
    if (compiler) {
        // fast path: the compiler for this device type already exists, so
        // the backend env var does not need to be consulted again
        return compiler.get();
    }

    // The env var name is assembled at runtime instead of being written as a
    // literal. NOTE(review): presumably this keeps it out of textual "MGB_"
    // prefix rewriting -- confirm before simplifying.
    auto backend = std::getenv(ssprintf("%c%cB_JIT_BACKEND", 'M', 'G').c_str());
    mgb_assert(
            backend,
            "code issue happened, need call config_jit_backends before get compiler");
    //! please keep logic with JITFusionPass::Impl::config_jit_backends
    switch (comp_node.device_type()) {
#if MGB_CUDA
        case CompNode::DeviceType::CUDA:
#if MGB_JIT_HALIDE
            if (!std::strcmp(backend, "HALIDE")) {
                compiler = std::make_unique<HalideCudaCompiler>();
                break;
            }
#endif
#if MGB_JIT_MLIR
            if (!std::strcmp(backend, "MLIR")) {
                compiler = std::make_unique<MLIRCompiler>(CompNode::DeviceType::CUDA);
                break;
            }
#endif
            if (!std::strcmp(backend, "NVRTC")) {
                compiler = std::make_unique<CudaCompiler>();
                break;
            }
            mgb_throw(
                    InternalError,
                    "No compiler support for cuda, may caused by build not enable "
                    "MLIR/HALIDE module or error config jit backend env");
            break;
#endif
        case CompNode::DeviceType::CPU:
#if MGB_JIT_MLIR
            if (!std::strcmp(backend, "MLIR")) {
                compiler = std::make_unique<MLIRCompiler>(CompNode::DeviceType::CPU);
                break;
            }
#endif
            mgb_throw(
                    InternalError,
                    "No compiler support for cpu, may caused by build not enable "
                    "MLIR module or error config jit backend env");
            break;
        default:
            mgb_throw(
                    InternalError,
                    "unsupported JIT config: "
                    "comp_node=%s backend_setting=%s",
                    comp_node.to_string().c_str(), backend);
    }

    return compiler.get();
}

//! Compile the internal graph of \p opr for its current args, memoized per
//! (internal graph, args) pair in m_expr_cache. The whole lookup/compile is
//! serialized by m_mtx. Returns a non-owning pointer; the cache entry keeps
//! ownership of the Executable.
Executable* Compiler::compile(JITExecutor* opr) {
    MGB_LOCK_GUARD(m_mtx);
    auto&& graph = opr->internal_graph();
    auto&& args = opr->args();
    auto&& args_cache = m_expr_cache[&graph];
    auto q = args_cache.get(args);
    // q.first is presumably true on a cache miss (a fresh, empty entry was
    // inserted). Also compile when the stored executable is empty. Assuming
    // get() inserts on a miss, a previous do_compile() that threw would have
    // left an empty entry behind, and without this check every later call
    // with the same args would return a null pointer.
    if (q.first || !*q.second) {
        *q.second = do_compile(graph, args);
    }
    return q.second->get();
}

#endif  // MGB_JIT

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}