ast_c.cpp 7.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
#include "megbrain/jit/ast_c.h"
#include "megbrain/jit/executor_opr.h"
#include "megbrain/opr/tensor_manip.h"

#if MGB_JIT

using namespace mgb;
using namespace jit;
using namespace ast_c;

namespace {
/*!
 * \brief build the AST that raises \p inp to the constant power \p exp
 *
 * Exponents that almost_equal() matches against a few special values get a
 * dedicated expression instead of a plain powf call:
 *   - |exp| == 0:          the constant 1
 *   - |exp| in {1, 2, 3}:  repeated multiplication, turned into 1 / (...)
 *                          when \p exp is negative
 *   - exp == +-1/3:        cbrtf / rcbrtf
 *   - exp == +-1/2:        sqrtf / rsqrtf
 *   - any other integer:   powf on fabsf(inp); for odd exponents the result
 *                          is additionally passed through copysign with
 *                          \p inp
 * Every other exponent is emitted as powf(inp, exp).
 *
 * \param inp AST of the base
 * \param exp constant exponent
 * \return AST of the power expression
 */
ASTPtr gen_powc(ASTPtr inp, float exp) {
    // x itself for a non-negative exponent, 1 / x for a negative one
    auto recip_if_neg = [exp](ASTPtr x) -> ASTPtr {
        return exp < 0 ? 1.f / x : x;
    };

    const auto exp_abs = std::abs(exp);
    if (almost_equal(exp_abs, 0.f)) {
        return 1.f;
    }
    if (almost_equal(exp_abs, 1.f)) {
        return recip_if_neg(inp);
    }
    if (almost_equal(exp_abs, 2.f)) {
        return recip_if_neg(inp * inp);
    }
    if (almost_equal(exp_abs, 3.f)) {
        return recip_if_neg(inp * inp * inp);
    }

    // fractional exponents that have a dedicated single-argument function
    struct RootFunc {
        float exp;
        const char* name;
    };
    const RootFunc root_funcs[] = {
            {1.f / 3.f, "cbrtf"},
            {-1.f / 3.f, "rcbrtf"},
            {.5f, "sqrtf"},
            {-.5f, "rsqrtf"},
    };
    for (auto&& func : root_funcs) {
        if (almost_equal(exp, func.exp)) {
            return make_call(func.name, {inp});
        }
    }

    // remaining integral exponents: take the power of |inp|; an odd exponent
    // then gets the sign of inp applied through copysign
    const int nearest_int = std::round(exp);
    if (almost_equal(static_cast<float>(nearest_int), exp)) {
        auto magnitude = make_call("powf", {make_call("fabsf", {inp}), exp});
        const bool is_odd = nearest_int & 1;
        if (!is_odd) {
            return magnitude;
        }
        return make_call("copysign", {magnitude, inp});
    }

    return make_call("powf", {inp, exp});
}
}  // anonymous namespace

/*!
 * \brief get the table that maps an elemwise mode to its AST generator
 *
 * Each generator receives the input ASTs of the operator and returns a
 * one-element array holding the AST of the result. The table is a
 * function-local static, so it is built on the first call and the same
 * object is returned afterwards.
 *
 * \return reference to the generator table
 */
const ElemGeneratorMap& ast_c::elem_opr_generator() {
// ENTRY(mode, impl): table entry for ElemMode::mode whose generator evaluates
// the expression impl; the input ASTs are visible inside impl as `inps`
#define ENTRY(_mode, _impl)                                                \
    {                                                                      \
        ElemMode::_mode, {                                                 \
            [](const ASTPtrArray& inps) -> ASTPtrArray { return {_impl}; } \
        }                                                                  \
    }
    static ElemGeneratorMap map = {
            // unary
            ENTRY(RELU, make_call("fmaxf", {inps[0], 0.f})),
            ENTRY(ABS, make_call("fabsf", inps)),
            ENTRY(ACOS, make_call("acosf", inps)),
            ENTRY(ASIN, make_call("asinf", inps)),
            ENTRY(CEIL, make_call("ceilf", inps)),
            ENTRY(COS, make_call("cosf", inps)),
            ENTRY(EXP, make_call("expf", inps)),
            ENTRY(EXPM1, make_call("expm1f", inps)),
            ENTRY(FLOOR, make_call("floorf", inps)),
            ENTRY(LOG, make_call("logf", inps)),
            ENTRY(LOG1P, make_call("log1pf", inps)),
            ENTRY(NEGATE, make_call("-", inps)),
            ENTRY(SIGMOID, 1 / (1 + make_call("expf", {0 - inps[0]}))),
            ENTRY(SIN, make_call("sinf", inps)),
            ENTRY(TANH, make_call("tanhf", inps)),
            ENTRY(ERF, make_call("erff", inps)),
            ENTRY(ERFC, make_call("erfcf", inps)),
            ENTRY(H_SWISH,
                  inps[0] *
                          make_call(
                                  "fmaxf",
                                  {make_call("fminf", {inps[0] + 3.f, 6.f}), 0.f}) /
                          6.f),

            // binary
            ENTRY(ABS_GRAD, ASTPtr::make<Cond3AST>(inps[0] > 0, inps[1], -inps[1])),
            ENTRY(ADD, inps[0] + inps[1]),
            ENTRY(FLOOR_DIV, make_call("floorf", {inps[0] / inps[1]})),
            ENTRY(MAX, make_call("fmaxf", inps)),
            ENTRY(MIN, make_call("fminf", inps)),
            ENTRY(MOD, make_call("fmodf", inps)),
            ENTRY(MUL, inps[0] * inps[1]),
            ENTRY(POW, make_call("powf", inps)),
            ENTRY(SIGMOID_GRAD, inps[0] * (1 - inps[0]) * inps[1]),
            ENTRY(SUB, inps[0] - inps[1]),
            ENTRY(SWITCH_GT0, ASTPtr::make<Cond3AST>(inps[0] > 0, inps[1], 0)),
            ENTRY(TANH_GRAD, (1 - inps[0] * inps[0]) * inps[1]),
            ENTRY(TRUE_DIV, inps[0] / inps[1]),
            ENTRY(LOG_SUM_EXP, make_call("mgb_log_sum_exp", {inps[0], inps[1]})),
            ENTRY(LT, ASTPtr::make<BinaryAST>("<", inps[0], inps[1])),
            ENTRY(LEQ, ASTPtr::make<BinaryAST>("<=", inps[0], inps[1])),
            ENTRY(EQ, ASTPtr::make<BinaryAST>("==", inps[0], inps[1])),
            ENTRY(ATAN2, make_call("atan2f", inps)),
            ENTRY(H_SWISH_GRAD,
                  ASTPtr::make<Cond3AST>(
                          -inps[0] > 3.f, 0.f,
                          ASTPtr::make<Cond3AST>(
                                  inps[0] > 3.f, inps[1],
                                  (2.f * inps[0] + 3.f) * inps[1] / 6.f))),

            // misc
            ENTRY(COND_LEQ_MOV,
                  ASTPtr::make<BinaryAST>("<=", inps[0], inps[1]) * inps[2]),
            ENTRY(COND_LT_MOV,
                  ASTPtr::make<BinaryAST>("<", inps[0], inps[1]) * inps[2]),
            ENTRY(FUSE_MUL_ADD3, inps[0] * inps[1] + inps[2]),
            ENTRY(FUSE_MUL_ADD4, inps[0] * inps[1] + inps[2] * inps[3]),
            ENTRY(FUSE_ADD_RELU, make_call("fmaxf", {inps[0] + inps[1], 0})),
            ENTRY(FUSE_ADD_SIGMOID,
                  1 / (1 + make_call("expf", {-(inps[0] + inps[1])}))),
            ENTRY(FUSE_ADD_TANH, make_call("tanhf", {inps[0] + inps[1]})),
            ENTRY(FUSE_ADD_H_SWISH,
                  (inps[0] + inps[1]) *
                          make_call(
                                  "fmaxf",
                                  {make_call("fminf", {(inps[0] + inps[1]) + 3.f, 6.f}),
                                   0.f}) /
                          6.f),
    };
// the helper macro is only meant for the table above; undefine it under the
// name it was defined with so it does not leak into the rest of the file
#undef ENTRY
    // the 41 accounts for the modes that have no entry in the table; the
    // assert fires when the number of elemwise modes changes without this
    // table (or the constant) being updated
    mgb_assert(map.size() + 41 == opr::Elemwise::Param::MODE_NR_MEMBER);
    // unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH,
    // ERFINV, ERFCINV, NOT, AND, OR, XOR, NEQ, ISNAN, ISINF
    // NOTE(review): the list above names 15 modes while the assert accounts
    // for 41 -- presumably the list is not exhaustive; confirm
    return map;
}

M
Megvii Engine Team 已提交
143
/*!
 * \brief convert an operator into the AST node(s) of its result
 *
 * The following cases are tried in order:
 *   1. Elemwise whose mode passes check_elem_mode(): the generator registered
 *      in elem_opr_generator() is invoked on \p inputs
 *   2. PowC: gen_powc() on the single input
 *   3. operator whose first output is an immutable scalar: an IntAST literal
 *      for Int32, a FloatAST literal for Float32 / Float16
 *   4. TypeCvt: the single input is returned unchanged
 *
 * \param opr operator to be converted
 * \param inputs input ASTs handed to the generator of \p opr
 * \return AST node(s) generated for \p opr
 * \throw InternalError if an immutable scalar has any other dtype, or if
 *      \p opr matches none of the cases above
 */
ASTPtrArray ast_c::opr2AST(cg::OperatorNodeBase* opr, const ASTPtrArray& inputs) {
    using namespace opr;

    // case 1: elemwise with a supported mode
    auto elemwise = gopt::try_cast_as_op<Elemwise>(opr);
    if (elemwise && check_elem_mode(elemwise->param().mode)) {
        auto generator = elem_opr_generator().find(elemwise->param().mode);
        return generator->second(inputs);
    }

    // case 2: power with a constant exponent
    auto pow_const = gopt::try_cast_as_op<PowC>(opr);
    if (pow_const) {
        mgb_assert(inputs.size() == 1);
        return {gen_powc(inputs[0], pow_const->param().exp)};
    }

    // case 3: immutable scalar, emitted as a literal
    auto scalar = SymbolVar{opr->output(0)}.as_immutable_scalar();
    if (scalar.valid()) {
        auto scalar_dtype = scalar->dtype();
        if (scalar_dtype == dtype::Int32{}) {
            return {ASTPtr::make<IntAST>(scalar->get<int>())};
        }

        // both supported float dtypes end up as a float literal
        float literal;
        if (scalar_dtype == dtype::Float16()) {
            literal = scalar->get<dt_float16>();
        } else if (scalar_dtype == dtype::Float32()) {
            literal = scalar->get<float>();
        } else {
            mgb_throw(
                    InternalError, "dtype(%s) is not any of [Float16, Float32, Int32]",
                    scalar_dtype.name());
        }
        return {ASTPtr::make<FloatAST>(literal)};
    }

    // case 4: anything that is not a TypeCvt at this point is unsupported
    if (!opr->same_type<opr::TypeCvt>()) {
        mgb_throw(
                InternalError, "unknown opr %s{%s}", opr->cname(),
                opr->dyn_typeinfo()->name);
    }

    // simply ignore TypeCvt oprs.
    mgb_assert(inputs.size() == 1);
    return inputs;
}

#endif  // MGB_JIT

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}