ast_c.cpp 7.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
#include "megbrain/jit/ast_c.h"
#include "megbrain/jit/executor_opr.h"
#include "megbrain/opr/tensor_manip.h"

#if MGB_JIT

using namespace mgb;
using namespace jit;
using namespace ast_c;

namespace {
/*!
 * \brief build the AST for pow(inp, exp) with a constant exponent
 *
 * Exponents that are (almost) equal to 0, +-1, +-2, +-3, +-1/3 and +-1/2 are
 * lowered to a constant, multiplications, a reciprocal, or a call to
 * cbrtf / rcbrtf / sqrtf / rsqrtf. Any other integral exponent is evaluated
 * as powf on fabsf(inp), with the sign of \p inp restored through copysign
 * when the exponent is odd. Everything else becomes a plain powf call.
 *
 * \param inp AST of the base
 * \param exp the constant exponent
 * \return AST computing inp raised to the power exp
 */
ASTPtr gen_powc(ASTPtr inp, float exp) {
    // turns x into 1 / x when the exponent is negative
    auto int_neg = [exp](ASTPtr x) {
        if (exp < 0) {
            return 1.f / x;
        }
        return x;
    };

    const float exp_abs = std::abs(exp);
    if (almost_equal(exp_abs, 0.f)) {
        return 1.f;
    }
    if (almost_equal(exp_abs, 1.f)) {
        return int_neg(inp);
    }
    if (almost_equal(exp_abs, 2.f)) {
        return int_neg(inp * inp);
    }
    if (almost_equal(exp_abs, 3.f)) {
        return int_neg(inp * inp * inp);
    }
    if (almost_equal(exp, 1.f / 3.f)) {
        return make_call("cbrtf", {inp});
    }
    if (almost_equal(exp, -1.f / 3.f)) {
        return make_call("rcbrtf", {inp});
    }
    if (almost_equal(exp, .5f)) {
        return make_call("sqrtf", {inp});
    }
    if (almost_equal(exp, -.5f)) {
        return make_call("rsqrtf", {inp});
    }

    // Remaining integral exponents. Integrality and parity are decided in
    // floating point on purpose: converting a float that is NaN or lies
    // outside the range of int to int is undefined behaviour.
    if (std::isfinite(exp)) {
        const float exp_rounded = std::round(exp);
        if (almost_equal(exp_rounded, exp)) {
            auto inp_abs = make_call("fabsf", {inp});
            auto pow_abs = make_call("powf", {inp_abs, exp});
            const bool exp_is_odd = std::fmod(exp_rounded, 2.f) != 0.f;
            if (exp_is_odd) {
                // an odd power keeps the sign of the base
                return make_call("copysign", {pow_abs, inp});
            }
            return pow_abs;
        }
    }

    return make_call("powf", {inp, exp});
}

}  // anonymous namespace

59 60 61 62 63 64 65 66 67
/*!
 * \brief get the table mapping an Elemwise mode to its AST generator
 *
 * Each generator receives the ASTs of the operator inputs (plus an is_half
 * flag that the generators defined here ignore) and returns a one-element
 * array holding the AST of the result.
 *
 * \param device_type target device; the code below does not consult it and
 *      returns the same table for every value
 * \return reference to a function-local static map
 */
const ElemGeneratorMap& ast_c::elem_opr_generator(CompNode::DeviceType device_type) {
//! build one map entry: mode -> lambda returning the AST expression _impl
#define ENTRY(_mode, _impl)                                             \
    {                                                                   \
        ElemMode::_mode, {                                              \
            [=](const ASTPtrArray& inps, bool is_half) -> ASTPtrArray { \
                MGB_MARK_USED_VAR(is_half);                             \
                return {_impl};                                         \
            }                                                           \
        }                                                               \
    }

    //! other backends map
    static ElemGeneratorMap other_map = {
            // unary
            ENTRY(RELU, make_call("fmaxf", {inps[0], 0.f})),
            ENTRY(ABS, make_call("fabsf", inps)),
            ENTRY(ACOS, make_call("acosf", inps)),
            ENTRY(ASIN, make_call("asinf", inps)),
            ENTRY(CEIL, make_call("ceilf", inps)),
            ENTRY(COS, make_call("cosf", inps)),
            ENTRY(EXP, make_call("expf", inps)),
            ENTRY(EXPM1, make_call("expm1f", inps)),
            ENTRY(FLOOR, make_call("floorf", inps)),
            ENTRY(LOG, make_call("logf", inps)),
            ENTRY(LOG1P, make_call("log1pf", inps)),
            ENTRY(NEGATE, make_call("-", inps)),
            ENTRY(SIGMOID, 1 / (1 + make_call("expf", {0 - inps[0]}))),
            ENTRY(SIN, make_call("sinf", inps)),
            ENTRY(TANH, make_call("tanhf", inps)),
            ENTRY(ERF, make_call("erff", inps)),
            ENTRY(ERFC, make_call("erfcf", inps)),
            ENTRY(H_SWISH,
                  inps[0] *
                          make_call(
                                  "fmaxf",
                                  {make_call("fminf", {inps[0] + 3.f, 6.f}), 0.f}) /
                          6.f),

            // binary
            ENTRY(ABS_GRAD, ASTPtr::make<Cond3AST>(inps[0] > 0, inps[1], -inps[1])),
            ENTRY(ADD, inps[0] + inps[1]),
            ENTRY(FLOOR_DIV, make_call("floorf", {inps[0] / inps[1]})),
            ENTRY(MAX, make_call("fmaxf", inps)),
            ENTRY(MIN, make_call("fminf", inps)),
            ENTRY(MOD, make_call("fmodf", inps)),
            ENTRY(MUL, inps[0] * inps[1]),
            ENTRY(POW, make_call("powf", inps)),
            ENTRY(SIGMOID_GRAD, inps[0] * (1 - inps[0]) * inps[1]),
            ENTRY(SUB, inps[0] - inps[1]),
            ENTRY(SWITCH_GT0, ASTPtr::make<Cond3AST>(inps[0] > 0, inps[1], 0)),
            ENTRY(TANH_GRAD, (1 - inps[0] * inps[0]) * inps[1]),
            ENTRY(TRUE_DIV, inps[0] / inps[1]),
            ENTRY(LOG_SUM_EXP, make_call("jit_log_sum_exp", {inps[0], inps[1]})),
            ENTRY(LT, ASTPtr::make<BinaryAST>("<", inps[0], inps[1])),
            ENTRY(LEQ, ASTPtr::make<BinaryAST>("<=", inps[0], inps[1])),
            ENTRY(EQ, ASTPtr::make<BinaryAST>("==", inps[0], inps[1])),
            ENTRY(ATAN2, make_call("atan2f", inps)),
            ENTRY(H_SWISH_GRAD,
                  ASTPtr::make<Cond3AST>(
                          -inps[0] > 3.f, 0.f,
                          ASTPtr::make<Cond3AST>(
                                  inps[0] > 3.f, inps[1],
                                  (2.f * inps[0] + 3.f) * inps[1] / 6.f))),

            // misc
            ENTRY(COND_LEQ_MOV,
                  ASTPtr::make<BinaryAST>("<=", inps[0], inps[1]) * inps[2]),
            ENTRY(COND_LT_MOV,
                  ASTPtr::make<BinaryAST>("<", inps[0], inps[1]) * inps[2]),
            ENTRY(FUSE_MUL_ADD3, inps[0] * inps[1] + inps[2]),
            ENTRY(FUSE_MUL_ADD4, inps[0] * inps[1] + inps[2] * inps[3]),
            ENTRY(FUSE_ADD_RELU, make_call("fmaxf", {inps[0] + inps[1], 0})),
            ENTRY(FUSE_ADD_SIGMOID,
                  1 / (1 + make_call("expf", {-(inps[0] + inps[1])}))),
            ENTRY(FUSE_ADD_TANH, make_call("tanhf", {inps[0] + inps[1]})),
            ENTRY(FUSE_ADD_H_SWISH,
                  (inps[0] + inps[1]) *
                          make_call(
                                  "fmaxf",
                                  {make_call("fminf", {(inps[0] + inps[1]) + 3.f, 6.f}),
                                   0.f}) /
                          6.f),
    };
// ENTRY is only meaningful for the table above; do not leak it any further
#undef ENTRY

    // The table must be revisited whenever an Elemwise mode is added: the
    // number of generators plus the 41 modes without one has to match the
    // total number of modes.
    mgb_assert(other_map.size() + 41 == opr::Elemwise::Param::MODE_NR_MEMBER);
    // unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH,
    // ERFINV, ERFCINV, NOT, AND, OR, XOR, NEQ, ISNAN, ISINF
    // NOTE(review): the list above names fewer modes than the 41 used in the
    // assertion; presumably it is out of date -- confirm against Elemwise::Mode

    return other_map;
}

150 151 152
/*!
 * \brief translate one operator into the AST node(s) computing its output
 *
 * The following cases are tried in order:
 *  1. Elemwise whose mode is accepted by check_elem_mode: the generator
 *     registered for that mode is invoked on \p inputs
 *  2. PowC: lowered by gen_powc on its single input
 *  3. operator whose first output is an immutable scalar: emitted as an
 *     IntAST (Int32) or a FloatAST (Float32 / Float16)
 *  4. TypeCvt: \p inputs are passed through unchanged
 *
 * \param opr the operator to translate
 * \param inputs ASTs of the inputs of \p opr
 * \param device_type target device, forwarded to the mode check, to the
 *      generator lookup and to FloatAST
 * \return the resulting AST nodes
 * \throw InternalError for an immutable scalar of any other dtype, or when
 *      none of the cases above applies
 */
ASTPtrArray ast_c::opr2AST(
        cg::OperatorNodeBase* opr, const ASTPtrArray& inputs,
        CompNode::DeviceType device_type) {
    using namespace opr;

    // case 1: elementwise operator with a registered generator
    auto elemwise = gopt::try_cast_as_op<Elemwise>(opr);
    if (elemwise && check_elem_mode(elemwise->param().mode, device_type)) {
        const auto& generators = elem_opr_generator(device_type);
        auto generator = generators.find(elemwise->param().mode);
        return generator->second(inputs, false);
    }

    // case 2: power with a constant exponent
    if (auto powc_opr = gopt::try_cast_as_op<PowC>(opr)) {
        mgb_assert(inputs.size() == 1);
        return {gen_powc(inputs[0], powc_opr->param().exp)};
    }

    // case 3: immutable scalar, emitted as a literal
    auto scalar = SymbolVar{opr->output(0)}.as_immutable_scalar();
    if (scalar.valid()) {
        auto scalar_dtype = scalar->dtype();
        if (scalar_dtype == dtype::Int32{}) {
            return {ASTPtr::make<IntAST>(scalar->get<int>())};
        }

        // both supported floating point dtypes are emitted as a float literal
        float literal;
        if (scalar_dtype == dtype::Float32()) {
            literal = scalar->get<float>();
        } else if (scalar_dtype == dtype::Float16()) {
            literal = scalar->get<dt_float16>();
        } else {
            mgb_throw(
                    InternalError, "dtype(%s) is not any of [Float16, Float32, Int32]",
                    scalar_dtype.name());
        }
        return {ASTPtr::make<FloatAST>(literal, device_type, false)};
    }

    // case 4: type conversions are ignored, the input AST is reused as is
    if (opr->same_type<opr::TypeCvt>()) {
        mgb_assert(inputs.size() == 1);
        return inputs;
    }

    mgb_throw(
            InternalError, "unknown opr %s{%s}", opr->cname(),
            opr->dyn_typeinfo()->name);
}

#endif  // MGB_JIT

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}