#include "./helper.h"

#if MGB_JIT

#include "megbrain/gopt/framework.h"
#include "megbrain/jit/executor_opr.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/rand.h"
#include "megbrain/opr/tensor_manip.h"

using namespace mgb;
using namespace jit;

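// select the JIT backend for the following tests by overriding the
// MGB_JIT_BACKEND environment variable; Backend::NONE installs a name that no
// backend registers under, so JIT fusion stays effectively disabled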
void jit::set_backend(Backend backend) {
    switch (backend) {
        case Backend::NONE:
            setenv("MGB_JIT_BACKEND", "non_exist", 1);
            return;
        case Backend::HALIDE:
            setenv("MGB_JIT_BACKEND", "HALIDE", 1);
            return;
        case Backend::NVRTC:
            setenv("MGB_JIT_BACKEND", "NVRTC", 1);
            return;
        case Backend::MLIR:
            setenv("MGB_JIT_BACKEND", "MLIR", 1);
            return;
        case Backend::TINYOPENCL:
            setenv("MGB_JIT_BACKEND", "TINYOPENCL", 1);
            return;
        default:
            mgb_assert(0);
    }
}

std::vector<cg::OperatorNodeBase*> jit::get_rev_topo_order(
        SymbolVar nd, ThinHashSet<VarNode*> endpoints_set) {
    std::vector<cg::OperatorNodeBase*> topo;
    thin_function<void(VarNode*)> dfs;
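    // post-order DFS from the endpoint var towards the inputs; endpoints_set
    // doubles as the visited set, so endpoint vars (and vars already seen)
    // terminate the recursion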
    dfs = [&](VarNode* p) {
        if (endpoints_set.count(p)) {
            return;
        }
        endpoints_set.insert(p);
        for (auto i : p->owner_opr()->input()) {
            dfs(i);
        }
        topo.push_back(p->owner_opr());
    };
    dfs(nd.node());
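    // dfs() appended the operators in topological order (producers first);
    // reverse it so the endpoint opr comes first, which is the order
    // InternalGraphGenerator::add_opr is fed in ensure_init_graph() below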
    std::reverse(topo.begin(), topo.end());
    return topo;
}

FusionChecker& FusionChecker::disable_inp_grad() {
    for (size_t i = 0; i < m_nr_input; ++i) {
        m_disable_inp_grad.insert(i);
    }
    return *this;
}

void FusionChecker::ensure_init_graph() {
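    // build the graphs only once: m_jit_y is assigned by both branches below,
    // so subsequent calls return immediately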
    if (m_jit_y.node())
        return;

    m_graph = ComputingGraph::make();
    SymbolVarArray inputs(m_nr_input);
    for (size_t i = 0; i < m_nr_input; ++i) {
        inputs[i] = opr::Host2DeviceCopy::make(*m_graph, m_inputs_val[i])
                            .rename(ssprintf("inp%zu", i));

        auto dt = m_idx2dtype.find(i);
        if (dt != m_idx2dtype.end()) {
            inputs[i] = opr::TypeCvt::make(inputs[i], dt->second);
        }
    }
    m_truth_y = m_exp_func(inputs);

    if (m_direct_build) {
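        // build the JIT opr by hand: collect every opr between the inputs and
        // m_truth_y into an InternalGraphGenerator and materialize the whole
        // expression as a single JITExecutor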
        auto ig_gen =
                std::make_unique<InternalGraphGenerator>(m_truth_y.node()->owner_opr());
        ThinHashSet<VarNode*> endpoints_set;
        for (size_t i = 0; i < m_nr_input; ++i) {
            endpoints_set.insert(inputs[i].node());
        }
        for (auto&& opr : get_rev_topo_order(m_truth_y, endpoints_set))
            ig_gen->add_opr(opr);
        m_jit_y = JITExecutor::make(ig_gen->generate(), cg::to_var_node_array(inputs));
    } else {
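        // let the graph optimizer do the fusion: run the preset passes at
        // graph_opt_level 3 with the requested JIT level and take the single
        // rewritten endpoint as m_jit_y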
        ComputingGraph::Options opt;
        opt.graph_opt_level = 3;
        opt.graph_opt.jit = m_jit_level;
        unpack_vector(
                gopt::GraphOptimizer{}
                        .add_preset_passes(true, nullptr, &opt)
                        .apply({{m_truth_y}})
                        .endpoint_vars(),
                m_jit_y);

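        // verify that fusion actually happened: the optimized graph must
        // contain exactly one JITExecutor, and (when m_check_opr_type is set)
        // nothing else besides input copies and shape queries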
        size_t nr_jit_opr = 0;
        cg::DepOprIter{[&nr_jit_opr, this](cg::OperatorNodeBase* opr) {
            if (opr->same_type<JITExecutor>()) {
                ++nr_jit_opr;
            } else {
                static const ThinHashSet<Typeinfo*> allowed_types{
                        opr::Host2DeviceCopy::typeinfo(), opr::GetVarShape::typeinfo()};
                mgb_throw_if(
                        m_check_opr_type && !allowed_types.count(opr->dyn_typeinfo()),
                        InternalError, "encountered non-JIT opr after fusion: %s{%s}",
                        opr->cname(), opr->dyn_typeinfo()->name);
            }
        }}.add(m_jit_y.node());
        mgb_assert(nr_jit_opr == 1);
    }

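    // also check input gradients: project both outputs onto the same random
    // vector to obtain scalar losses, then differentiate the truth loss and
    // the JIT loss w.r.t. every input whose gradient is not disabled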
    SymbolVar loss_var0, loss_var1;
    SmallVector<std::tuple<size_t, SymbolVar, SymbolVar>> grad_vars;
    for (size_t i = 0; i < m_nr_input; ++i) {
        if (!m_disable_inp_grad.count(i)) {
            if (!loss_var1.node()) {
            auto y0 = m_truth_y.flatten(), y1 = m_jit_y.flatten(),
                     coeff = opr::TypeCvt::make(
                             opr::UniformRNG::make(y0.symshape()), y0.dtype());
                loss_var0 = opr::Dot::make(y0, coeff);
                loss_var1 = opr::Dot::make(y1, coeff);
            }
            grad_vars.emplace_back(
                    i, cg::grad(loss_var0, inputs[i]), cg::grad(loss_var1, inputs[i]));
        }
    }

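    // slot 0 holds the forward outputs (marked with input index -1), followed
    // by one (truth, jit) gradient pair per checked input; each tensor is
    // copied back to host through a callback when the function runs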
    m_outputs_val.resize(grad_vars.size() + 1);

    ComputingGraph::OutputSpec outspec(m_outputs_val.size() * 2);
    std::get<0>(m_outputs_val[0]) = -1;
    outspec[0] = make_callback_copy(m_truth_y, std::get<1>(m_outputs_val[0]), false);
    outspec[1] = make_callback_copy(m_jit_y, std::get<2>(m_outputs_val[0]), false);

    for (size_t i = 0; i < grad_vars.size(); ++i) {
        auto&& dst = m_outputs_val[i + 1];
        auto&& src = grad_vars[i];
        std::get<0>(dst) = std::get<0>(src);
        outspec[i * 2 + 2] =
                make_callback_copy(std::get<1>(src), std::get<1>(dst), false);
        outspec[i * 2 + 3] =
                make_callback_copy(std::get<2>(src), std::get<2>(dst), false);
    }

    m_func = m_graph->compile(outspec);
}

FusionChecker& FusionChecker::run(const TensorShapeArray& input_shapes) {
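    // stop early once a gtest failure has been recorded, so later shapes do
    // not pile further errors on top of the first failure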
    if (::testing::Test::HasFailure()) {
        return *this;
    }
    mgb_assert(input_shapes.size() == m_nr_input);
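    // first run: allocate fresh host tensors; later runs only overwrite their
    // contents so the Host2DeviceCopy oprs keep referring to the same storage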
    if (m_inputs_val.empty()) {
        m_inputs_val.resize(m_nr_input);
        for (size_t i = 0; i < m_nr_input; ++i) {
            m_inputs_val[i] = m_input_gen(input_shapes[i]);
        }
    } else {
        for (size_t i = 0; i < m_nr_input; ++i) {
            *m_inputs_val[i] = *m_input_gen(input_shapes[i]);
        }
    }

    ensure_init_graph();
    m_func->execute().wait();
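    // compare every (truth, jit) pair; the stored index tells which input's
    // gradient mismatched, with -1 denoting the forward output itself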
    auto chk = [this]() {
        for (auto&& i : m_outputs_val) {
            MGB_ASSERT_TENSOR_NEAR(std::get<1>(i), std::get<2>(i), 1e-5)
                    << ssprintf("failed for input %zd", std::get<0>(i));
        }
    };
    chk();
    return *this;
}

#endif  // MGB_JIT

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}