common.h 2.9 KB
Newer Older
1 2 3 4 5 6 7 8
/**
 * \file src/jit/impl/mlir/ir/common.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#pragma once

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR
17

18
#include "megbrain/tensor.h"
19

20 21
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/OperationSupport.h>
22 23 24 25 26
#include <mlir/IR/Value.h>

namespace mgb {
namespace jit {

27 28 29 30 31 32 33 34 35 36 37
/**
 * \brief Helper class for building common value operations with an MLIR
 * builder at a fixed source location.
 *
 * Every operation below is exposed through two overloads:
 *  - one taking its operand(s) as individual mlir::Value arguments (only
 *    declared here; the definition lives out of line), and
 *  - an inline convenience overload taking an mlir::ValueRange, which
 *    forwards the leading element(s) of the range to the first overload.
 *
 * The helper keeps a reference to the mlir::OpBuilder passed to the
 * constructor, so it must not outlive that builder.
 */
class ValueBuilderHelper {
public:
    ValueBuilderHelper(mlir::OpBuilder& b, mlir::Location location)
            : m_builder{b}, m_location{location} {}

// Declares `name(mlir::Value)` plus an inline `name(mlir::ValueRange)` that
// forwards operands[0]. The range is indexed without a size check, so callers
// must pass at least one operand; any extra operands are ignored.
// A project-unique macro name is used (instead of a generic `cb`) so that it
// cannot clash with a macro of the same name defined by an includer.
#define MGB_JIT_DECL_UNARY(name)                                              \
    mlir::Value name(mlir::ValueRange operands) { return name(operands[0]); } \
    mlir::Value name(mlir::Value lhs)

    // unary functions
    MGB_JIT_DECL_UNARY(abs);
    MGB_JIT_DECL_UNARY(ceil);
    MGB_JIT_DECL_UNARY(cos);
    MGB_JIT_DECL_UNARY(exp);
    MGB_JIT_DECL_UNARY(exp2);
    MGB_JIT_DECL_UNARY(floor);
    MGB_JIT_DECL_UNARY(log);
    MGB_JIT_DECL_UNARY(log10);
    MGB_JIT_DECL_UNARY(log2);
    MGB_JIT_DECL_UNARY(neg);
    MGB_JIT_DECL_UNARY(rsqrt);
    MGB_JIT_DECL_UNARY(sin);
    MGB_JIT_DECL_UNARY(sqrt);
    MGB_JIT_DECL_UNARY(tanh);

#undef MGB_JIT_DECL_UNARY

// Declares `name(mlir::Value, mlir::Value)` plus an inline
// `name(mlir::ValueRange)` that forwards operands[0] and operands[1]. The range
// is indexed without a size check, so callers must pass at least two operands;
// any extra operands are ignored.
#define MGB_JIT_DECL_BINARY(name)                 \
    mlir::Value name(mlir::ValueRange operands) { \
        return name(operands[0], operands[1]);    \
    }                                             \
    mlir::Value name(mlir::Value lhs, mlir::Value rhs)

    // binary functions
    // NOTE(review): divI / modI presumably are the integer variants of
    // div / mod -- confirm against the out-of-line definitions.
    MGB_JIT_DECL_BINARY(add);
    MGB_JIT_DECL_BINARY(bit_and);
    MGB_JIT_DECL_BINARY(bit_or);
    MGB_JIT_DECL_BINARY(div);
    MGB_JIT_DECL_BINARY(divI);
    MGB_JIT_DECL_BINARY(eq);
    MGB_JIT_DECL_BINARY(ge);
    MGB_JIT_DECL_BINARY(gt);
    MGB_JIT_DECL_BINARY(le);
    MGB_JIT_DECL_BINARY(lt);
    MGB_JIT_DECL_BINARY(max);
    MGB_JIT_DECL_BINARY(min);
    MGB_JIT_DECL_BINARY(mod);
    MGB_JIT_DECL_BINARY(modI);
    MGB_JIT_DECL_BINARY(mul);
    MGB_JIT_DECL_BINARY(sub);

#undef MGB_JIT_DECL_BINARY

    // constant functions
    // Presumably materialize a constant of the indicated type (f32 / i32)
    // holding `val`; defined out of line.
    mlir::Value const_f32(float val);
    mlir::Value const_i32(int32_t val);

    // select function
    // Presumably yields `true_val` where `cond` holds and `false_val`
    // otherwise; defined out of line.
    mlir::Value select(mlir::Value cond, mlir::Value true_val,
                       mlir::Value false_val);

private:
    //! Builder the operations are created with (not owned).
    mlir::OpBuilder& m_builder;
    //! Location attached to the created operations.
    mlir::Location m_location;
};
95

96 97 98 99 100 101 102 103 104 105
/**
 * \brief Resolve `val` into a value usable at position `index`.
 *
 * When `val` has a memref type, a new `Op` is created through `builder` with
 * arguments `(loc, val, index)` and its result is returned (presumably a load
 * of the element at `index`). Any other value is passed through untouched.
 *
 * \tparam Op operation type to create for memref-typed inputs
 */
template <typename Op>
mlir::Value get_operand(mlir::OpBuilder& builder, const mlir::Location& loc,
                        const mlir::Value& val, const mlir::ValueRange& index) {
    const bool needs_access_op = val.getType().isa<mlir::MemRefType>();
    // Non-memref inputs already are the operand itself.
    if (!needs_access_op) {
        return val;
    }
    return builder.create<Op>(loc, val, index);
}

106 107 108 109 110 111 112 113 114
/**
 * \brief Return an mlir::AffineMap for `val` derived from `layout`.
 *
 * Defined out of line. NOTE(review): presumably the map translates iteration
 * indices into positions according to the layout's shape/strides -- confirm
 * against the implementation.
 */
mlir::AffineMap get_affinemap(mlir::OpBuilder& builder, const mlir::Value& val,
                              const TensorLayout& layout);

/**
 * \brief Return a value read from `val` at `index`, using `dst` as the layout.
 *
 * Defined out of line. NOTE(review): presumably emits an affine load that uses
 * the map produced by get_affinemap for `dst` -- confirm against the
 * implementation.
 */
mlir::Value get_affine_load_op(mlir::OpBuilder& builder,
                               const mlir::Location& loc,
                               const mlir::Value& val,
                               const mlir::ValueRange& index,
                               const TensorLayout& dst);

115 116 117
}  // namespace jit
}  // namespace mgb

M
Megvii Engine Team 已提交
118
#endif  // MGB_JIT && MGB_JIT_MLIR
119

120
// vim: syntax=cpp.doxygen