/**
 * \file src/jit/impl/mlir/mlir_gen.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "./mlir_gen.h"
#include "./ir/each_mode.h"
#include "./ir/types.h"

#include "megbrain/jit/mlir/ir/dialect.h"
#include "megbrain/jit/mlir/ir/utils.h"
#include "megbrain/opr/basic_arith.h"
#include "megdnn/dtype.h"

#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/Types.h>
#include <mlir/IR/Value.h>
#include <mlir/IR/Verifier.h>
#include <mlir/Support/LogicalResult.h>

#include <llvm/ADT/ScopedHashTable.h>
#include <llvm/Support/raw_ostream.h>

using namespace mgb;
using namespace jit;

namespace {
class MLIRGenImpl {
public:
    MLIRGenImpl(mlir::MLIRContext& context) : m_builder(&context) {}

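    //! build a module holding one kernel function for the fused subgraph;
    //! returns {kernel_name, module}, or an empty pair if verification of
    //! the generated module fails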
    std::pair<llvm::StringRef, mlir::OwningModuleRef> gen(
            const InternalGraph& internal_graph,
            const JITExecutor::Args& args) {
        mlir::ModuleOp module =
                mlir::ModuleOp::create(m_builder.getUnknownLoc());

        //! Create main routine function
        auto func_op = gen_func_op(internal_graph, args);
        module.push_back(func_op);

        if (mlir::failed(mlir::verify(module))) {
            module.emitError("module verification error");
            return {};
        }

        return {func_op.getName(), module};
    }

private:
    mlir::OpBuilder m_builder;
    llvm::ScopedHashTable<mlir::StringRef, mlir::Value> m_symbol_table;

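    //! generate the kernel function; its arguments are laid out as
    //! [inputs..., outputs..., nr_elements], so the runtime caller must
    //! pass buffers in the same order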
    mlir::FuncOp gen_func_op(const InternalGraph& internal_graph,
                             const JITExecutor::Args& args) {
        llvm::ScopedHashTableScope<llvm::StringRef, mlir::Value> var_scope(
                m_symbol_table);
        std::vector<mlir::Type> func_args;
        for (auto&& arg : args.inputs) {
            func_args.push_back(get_type(arg.from->layout()));
        }
        for (auto&& arg : args.outputs) {
            func_args.push_back(get_type(arg.from->layout()));
        }
        //! the last arg is nr_elements
        func_args.push_back(m_builder.getIndexType());

        auto func_type = m_builder.getFunctionType(func_args, llvm::None);
        //! the function may be renamed in a later pass
        mlir::FuncOp func_op = mlir::FuncOp::create(m_builder.getUnknownLoc(),
                                                    "func", func_type);
        if (!func_op)
            return nullptr;

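        //! ask the LLVM lowering to also emit a C-compatible wrapper
        //! (_mlir_ciface_<name>) so the kernel can be invoked from C code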
        func_op.setAttr("llvm.emit_c_interface",
                        mlir::UnitAttr::get(m_builder.getContext()));
        auto& entry_block = *func_op.addEntryBlock();
        size_t idx = 0;
        for (auto&& input : args.inputs) {
            if (mlir::failed(declare(internal_graph.placeholders()[input.idx]
                                             ->output(0)
                                             ->name(),
                                     entry_block.getArgument(idx)))) {
                return nullptr;
            }
            idx++;
        }
        for (auto&& output : args.outputs) {
            if (mlir::failed(declare(output.from->name(),
                                     entry_block.getArgument(idx)))) {
                return nullptr;
            }
            idx++;
        }

        m_builder.setInsertionPointToStart(&entry_block);

        if (mlir::failed(gen_func_body(internal_graph, args))) {
            func_op.erase();
            return nullptr;
        }

        //! the entry block has no terminator yet, so always append a
        //! ReturnOp to close the function body
        m_builder.create<dialect::ReturnOp>(m_builder.getUnknownLoc());

        std::string op_content = mlir_type_to_string(func_op);
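        //! derive a deterministic, content-based name so that identical
        //! fused kernels hash to the same function name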
        func_op.setName(
                ssprintf("jit_mlir_%" PRIx64,
                         XXHash{}.update(op_content.data(), op_content.size())
                                 .digest()));
        return func_op;
    }

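    //! emit the function body by visiting the internal graph in dependency
    //! order; placeholders are skipped since they are already bound to
    //! block arguments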
    mlir::LogicalResult gen_func_body(const InternalGraph& internal_graph,
                                      const JITExecutor::Args& args) {
        llvm::ScopedHashTableScope<llvm::StringRef, mlir::Value> var_scope(
                m_symbol_table);
        cg::DepOprIter{[&](cg::OperatorNodeBase* opr) {
            if (opr->same_type<JITPlaceholder>()) {
                return;
            } else if (opr->same_type<opr::ImmutableTensor>()) {
                auto imm = SymbolVar{opr->output(0)}.as_immutable_scalar();
                if (imm.valid()) {
                    auto dtype = imm->dtype();
                    float scalar_value;
                    if (dtype == dtype::Float32()) {
                        scalar_value = imm->get<float>();
                    } else {
                        mgb_throw(InternalError,
                                  "mlir backend currently only support f32 "
                                  "dtype, but got %s",
                                  dtype.name());
                    }
                    auto&& out = m_builder.create<dialect::ConstantScalarOp>(
                            m_builder.getUnknownLoc(), m_builder.getF32Type(),
                            m_builder.getF32FloatAttr(scalar_value));
                    mgb_assert(mlir::succeeded(
                            declare(opr->output(0)->name(), out)));
                }
            } else if (opr->same_type<opr::Elemwise>()) {
                auto&& out = gen_elemwise(opr->cast_final<opr::Elemwise>());
                mgb_assert(
                        mlir::succeeded(declare(opr->output(0)->name(), out)));
                return;
            } else if (opr->same_type<opr::TypeCvt>()) {
                auto&& out = gen_typecvt(opr->cast_final<opr::TypeCvt>());
                mgb_assert(
                        mlir::succeeded(declare(opr->output(0)->name(), out)));
            }
        }}
                .add(internal_graph.output());
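        //! assign the value computed for the graph output to the buffer
        //! bound to the first output argument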
        m_builder.create<dialect::AssignOp>(m_builder.getUnknownLoc(),
                                            get(internal_graph.output()),
                                            get(args.outputs[0].from));

        return mlir::success();
    }

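    //! lower an Elemwise opr to a single dialect::Elemwise op whose result
    //! type is deduced from its operands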
    mlir::Value gen_elemwise(const opr::Elemwise& opr) {
        llvm::SmallVector<mlir::Value, 4> operands;
        for (size_t i = 0; i < opr.input().size(); i++) {
            operands.push_back(get(opr.input(i)));
        }
        mlir::Type res_type = deduce_elemwise_res_type(operands);
        return m_builder.create<dialect::Elemwise>(
                m_builder.getUnknownLoc(), res_type, mlir::ValueRange(operands),
                opr.param().mode);
    }

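    //! lower a TypeCvt opr: keep the input shape and swap the element type
    //! for the opr's target dtype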
    mlir::Value gen_typecvt(const opr::TypeCvt& opr) {
        auto memref_type = get(opr.input(0))
                                   .getType()
                                   .dyn_cast_or_null<mlir::MemRefType>();
        mgb_assert(memref_type, "typecvt input should have a memref type");
        auto shape = memref_type.getShape();
        auto res_type = mlir::MemRefType::get(
                shape,
                megdnn_dtype_to_mlir_type(opr.param(), m_builder.getContext()));
        return m_builder.create<dialect::TypeCvt>(
                m_builder.getUnknownLoc(), res_type, get(opr.input(0)),
                opr.input(0)->dtype(), opr.param());
    }

    mlir::Type get_type(const TensorLayout& layout) {
        return layout_to_mlir_type(layout, m_builder);
    }

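    //! look up the SSA value previously declared for a var; a miss means
    //! the graph references a var that was never generated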
    mlir::Value get(const VarNode* var) {
        if (auto ret = m_symbol_table.lookup(var->name())) {
            return ret;
        }
        mgb_throw(InternalError, "Unknown var: %s", var->cname());
    }

    mlir::LogicalResult declare(llvm::StringRef var, mlir::Value value) {
        if (m_symbol_table.count(var)) {
            return mlir::failure();
        }
        m_symbol_table.insert(var, value);
        return mlir::success();
    }
};
}  // namespace

std::pair<llvm::StringRef, mlir::OwningModuleRef> mgb::jit::mlir_gen(
        mlir::MLIRContext& context,
        const mgb::jit::InternalGraph& internal_graph,
        const mgb::jit::JITExecutor::Args& args) {
    return MLIRGenImpl(context).gen(internal_graph, args);
}
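
//! A minimal usage sketch (illustrative only; `internal_graph`, `args` and
//! dialect registration are assumed to be set up by the surrounding JIT
//! pipeline, not shown in this file):
//!
//!     mlir::MLIRContext ctx;  // the mgb dialect must already be registered
//!     auto named_module = mgb::jit::mlir_gen(ctx, internal_graph, args);
//!     if (!named_module.second) {
//!         // module verification failed
//!     } else {
//!         named_module.second->dump();  // print the generated kernel
//!     }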

#endif  // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen