From 170897f2e1d5cba36e50f761d0c696aa41c704f6 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 22 Oct 2020 08:27:23 +0800 Subject: [PATCH] feat(mgb/jit): add more elemwise modes for mlir backend GitOrigin-RevId: 5883c688041b069675634a00b6be96fd481bd93a --- src/jit/impl/mlir/ir/common.cpp | 11 +- src/jit/impl/mlir/ir/common.h | 2 + src/jit/impl/mlir/ir/each_mode.h | 106 +++++++++++- src/jit/impl/mlir/ir/numerical.cpp | 248 +++++++++++++++++++++++++++++ src/jit/impl/mlir/ir/numerical.h | 46 ++++++ src/jit/impl/mlir/ir/ops.td | 16 +- src/jit/test/codegen.cpp | 76 +++++++-- 7 files changed, 473 insertions(+), 32 deletions(-) create mode 100644 src/jit/impl/mlir/ir/numerical.cpp create mode 100644 src/jit/impl/mlir/ir/numerical.h diff --git a/src/jit/impl/mlir/ir/common.cpp b/src/jit/impl/mlir/ir/common.cpp index 6979a335d..58d8c0a0a 100644 --- a/src/jit/impl/mlir/ir/common.cpp +++ b/src/jit/impl/mlir/ir/common.cpp @@ -29,6 +29,8 @@ cb(sub, SubFOp); cb(mul, MulFOp); cb(div, DivFOp); cb(mod, RemFOp); +cb(bit_and, AndOp); +cb(bit_or, OrOp); #undef cb #define cb(name, mode) \ @@ -72,6 +74,7 @@ cb(exp, ExpOp); cb(exp2, Exp2Op); cb(log10, Log10Op); cb(log2, Log2Op); +cb(log, LogOp); cb(rsqrt, RsqrtOp); cb(sin, SinOp); cb(sqrt, SqrtOp); @@ -79,7 +82,8 @@ cb(tanh, TanhOp); #undef cb mlir::Value ValueBuilderHelper::abs(mlir::Value lhs) { - return max(lhs, const_val(0.f)); + auto zero = const_val(0.f); + return select(ge(lhs, zero), lhs, sub(zero, lhs)); } mlir::Value ValueBuilderHelper::floor(mlir::Value lhs) { @@ -87,11 +91,6 @@ mlir::Value ValueBuilderHelper::floor(mlir::Value lhs) { return neg(ceil(neg(lhs))); } -mlir::Value ValueBuilderHelper::log(mlir::Value lhs) { - // math.log10(math.e) = 0.4342944819032518f - return div(log10(lhs), const_val(0.4342944819032518f)); -} - mlir::Value ValueBuilderHelper::select(mlir::Value cond, mlir::Value true_val, mlir::Value false_val) { return m_builder.create(m_location, cond, true_val, diff --git a/src/jit/impl/mlir/ir/common.h b/src/jit/impl/mlir/ir/common.h index 616ea88b4..7ea03cad6 100644 --- a/src/jit/impl/mlir/ir/common.h +++ b/src/jit/impl/mlir/ir/common.h @@ -47,6 +47,8 @@ public: cb(lt); cb(le); cb(eq); + cb(bit_and); + cb(bit_or); #undef cb mlir::Value const_val(float val); diff --git a/src/jit/impl/mlir/ir/each_mode.h b/src/jit/impl/mlir/ir/each_mode.h index 53f2b5369..b7cab3ef4 100644 --- a/src/jit/impl/mlir/ir/each_mode.h +++ b/src/jit/impl/mlir/ir/each_mode.h @@ -18,6 +18,7 @@ #include "megbrain/jit/mlir/ir/dialect.h" #include "./common.h" +#include "./numerical.h" #include #include @@ -28,6 +29,8 @@ cb(ReluOp, RELU) \ cb(AbsOp, ABS) \ cb(NegOp, NEGATE) \ + cb(AcosOp, ACOS) \ + cb(AsinOp, ASIN) \ cb(CeilOp, CEIL) \ cb(CosOp, COS) \ cb(ExpOp, EXP) \ @@ -40,7 +43,11 @@ cb(FastTanhOp, FAST_TANH) \ cb(HswishOp, H_SWISH) \ cb(ExpM1Op, EXPM1) \ - cb(RoundOp, ROUND) + cb(RoundOp, ROUND) \ + cb(ErfOp, ERF) \ + cb(ErfInvOp, ERFINV) \ + cb(ErfCOp, ERFC) \ + cb(ErfCInvOp, ERFCINV) #define MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb) \ cb(AbsGradOp, ABS_GRAD) \ @@ -52,6 +59,7 @@ cb(SubOp, SUB) \ cb(MulOp, MUL) \ cb(TrueDivOp, TRUE_DIV) \ + cb(PowOp, POW) \ cb(SigmoidGradOp, SIGMOID_GRAD) \ cb(SwishGt0Op, SWITCH_GT0) \ cb(TanhGradOp, TANH_GRAD) \ @@ -64,7 +72,8 @@ cb(FastTanhGradOp, FAST_TANH_GRAD) \ cb(FuseAddSigmoidOp, FUSE_ADD_SIGMOID) \ cb(HswishGradOp, H_SWISH_GRAD) \ - cb(FuseAddHswishOp, FUSE_ADD_H_SWISH) + cb(FuseAddHswishOp, FUSE_ADD_H_SWISH) \ + cb(Atan2Op, ATAN2) #define MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb) \ cb(CondLeqMovOp, COND_LEQ_MOV) \ @@ -197,6 +206,79 @@ struct StandardOp { } }; +//! pi / 2 - arctan2(x, sqrt(1 - x * x)) +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + auto x = operands[0]; + auto one_minus_x_2 = helper.sub(helper.const_val(1.f), helper.mul(x, x)); + auto asin = atan2_approx(helper, x, helper.sqrt(one_minus_x_2)); + auto pi_over_2 = helper.const_val(1.57079637f); + return helper.sub(pi_over_2, asin); + } +}; + +//! arctan2(x, sqrt(1 - x * x)) +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + auto x = operands[0]; + auto one_minus_x_2 = helper.sub(helper.const_val(1.f), helper.mul(x, x)); + return atan2_approx(helper, x, helper.sqrt(one_minus_x_2)); + } +}; + +//! gauss error function +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return erf_approx(helper, operands[0]); + } +}; + +//! inverse of gauss error function +//! https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + auto sqrt2 = helper.const_val(1.4142135623f); + auto x = helper.mul(helper.const_val(0.5f), + helper.add(operands[0], helper.const_val(1.f))); + return helper.div(ndtri_approx(helper, x), sqrt2); + } +}; + +//! complementary error function +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.sub(helper.const_val(1.f), erf_approx(helper, operands[0])); + } +}; + +//! inverse of complementary gauss error function +//! https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + auto minus_sqrt2 = helper.const_val(-1.4142135623f); + auto x = helper.mul(helper.const_val(0.5f), operands[0]); + return helper.div(ndtri_approx(helper, x), minus_sqrt2); + } +}; + /////////////////////////// binary op /////////////////////////// //! binary: x > 0 ? y : -y @@ -210,6 +292,16 @@ struct StandardOp { } }; +//! x^y = exp(y * log(x)) +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return helper.exp(helper.mul(operands[1], helper.log(operands[0]))); + } +}; + //! x * (1 - x) * y template <> struct StandardOp { @@ -382,6 +474,16 @@ struct StandardOp { } }; +//! arctan +template <> +struct StandardOp { + mlir::Value operator()(mlir::OpBuilder& builder, mlir::Location loc, + ValueRange operands) { + ValueBuilderHelper helper(builder, loc); + return atan2_approx(helper, operands[0], operands[1]); + } +}; + /////////////////////////// ternary op /////////////////////////// //! x <= y ? z : ctype(0) template <> diff --git a/src/jit/impl/mlir/ir/numerical.cpp b/src/jit/impl/mlir/ir/numerical.cpp new file mode 100644 index 000000000..ee66cd9b9 --- /dev/null +++ b/src/jit/impl/mlir/ir/numerical.cpp @@ -0,0 +1,248 @@ +/** + * \file src/jit/impl/mlir/ir/numerical.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megbrain_build_config.h" +#if MGB_JIT && MGB_JIT_MLIR + +#include "numerical.h" + +namespace mgb { +namespace jit { + +mlir::Value polynomial(ValueBuilderHelper& helper, mlir::Value x, + std::vector& coeff) { + size_t n = coeff.size(); + if (n == 0) { + return helper.const_val(0); + } + + mlir::Value r = coeff[0]; + for (size_t i = 1; i < n; i++) { + r = helper.add(helper.mul(r, x), coeff[i]); + } + return r; +} + +// polynomial approximation of arctangent +// atan(t) = t + c3 * t^3 + c5 * t^5 + ... + c17 * t^17 +// original paper: +// https://arxiv.org/pdf/1508.03211.pdf +mlir::Value atan2_approx(ValueBuilderHelper& helper, mlir::Value y, + mlir::Value x) { + auto atan_poly = [&](mlir::Value t) { + std::vector coeff = { + helper.const_val(2.90188402868807315826416015625E-3), + helper.const_val(-1.62907354533672332763671875E-2), + helper.const_val(4.3082617223262786865234375E-2), + helper.const_val(-7.5408883392810821533203125E-2), + helper.const_val(0.1066047251224517822265625), + helper.const_val(-0.14209578931331634521484375), + helper.const_val(0.19993579387664794921875), + helper.const_val(-0.3333314359188079833984375)}; + auto t2 = helper.mul(t, t); + auto p = polynomial(helper, t2, coeff); + return helper.add(helper.mul(helper.mul(p, t2), t), t); + }; + + // constants + auto zero = helper.const_val(0); + auto pi = helper.const_val(3.141592653589793); + auto pi_over_2 = helper.const_val(1.570796326794897); + + // transform the angle into interval [0, pi/4] + auto ax = helper.abs(x); + auto ay = helper.abs(y); + auto q = helper.div(helper.min(ax, ay), helper.max(ax, ay)); + + // get approximation for interval [0, pi/4] + auto r = atan_poly(q); + + // [0, pi/4] => [0, pi/2] + r = helper.select(helper.le(ax, ay), helper.sub(pi_over_2, r), r); + + // [0, pi/2] => [0, pi] + r = helper.select(helper.le(x, zero), helper.sub(pi, r), r); + + // [0, pi] => [-pi, pi] + r = helper.select(helper.le(y, zero), helper.sub(zero, r), r); + + return r; +} + +// numerical approximation of gauss error function +// https://en.wikipedia.org/wiki/Error_function#Polynomial +// original book: +// Numerical Recipes in Fortran 77: The Art of Scientific Computing +mlir::Value erf_approx(ValueBuilderHelper& helper, mlir::Value x) { + auto zero = helper.const_val(0); + auto one = helper.const_val(1); + auto half = helper.const_val(0.5); + + auto t = helper.div(one, helper.add(one, helper.mul(half, helper.abs(x)))); + + std::vector coeff = { + helper.const_val(0.17087277), + helper.const_val(-0.82215223), + helper.const_val(1.48851587), + helper.const_val(-1.13520398), + helper.const_val(0.27886807), + helper.const_val(-0.18628806), + helper.const_val(0.09678418), + helper.const_val(0.37409196), + helper.const_val(1.00002368), + helper.const_val(-1.26551223)}; + auto p = polynomial(helper, t, coeff); + + auto r = helper.mul(t, helper.exp(helper.sub(p, helper.mul(x, x)))); + return helper.select(helper.ge(x, zero), + helper.sub(one, r), + helper.sub(r, one)); +} + +// numerical approximation of the inverse of normal distribution function +// original algorithm: +// https://github.com/scipy/scipy/blob/master/scipy/special/cephes/ndtri.c +// case 1: 0 < x < exp(-2) +// z = sqrt(-2 * log(x)) +// t = 1 / z +// res = log(z) / z - z + t * P(t) / Q(t) +// where coefficients of P and Q are different +// for z < 8 and for z >= 8 +// +// case2: exp(-2) <= x <= 1 - exp(-2) +// w = x - 0.5 +// res = sqrt(2pi) * (w + w^3 * R(w^2) / S(w^2)) +// +// case3: 1 - exp(-2) < x < 1 +// 0 < 1 - x < exp(-2) +// ndtri(x) = -ndtri(1 - x) +// fallback to case 1 +mlir::Value ndtri_approx(ValueBuilderHelper& helper, mlir::Value x) { + // polynomial P + auto P = [&](mlir::Value i, mlir::Value cond) { + std::vector coeff0 = { + helper.const_val(4.05544892305962419923E0), + helper.const_val(3.15251094599893866154E1), + helper.const_val(5.71628192246421288162E1), + helper.const_val(4.40805073893200834700E1), + helper.const_val(1.46849561928858024014E1), + helper.const_val(2.18663306850790267539E0), + helper.const_val(-1.40256079171354495875E-1), + helper.const_val(-3.50424626827848203418E-2), + helper.const_val(-8.57456785154685413611E-4)}; + std::vector coeff1 = { + helper.const_val(3.23774891776946035970E0), + helper.const_val(6.91522889068984211695E0), + helper.const_val(3.93881025292474443415E0), + helper.const_val(1.33303460815807542389E0), + helper.const_val(2.01485389549179081538E-1), + helper.const_val(1.23716634817820021358E-2), + helper.const_val(3.01581553508235416007E-4), + helper.const_val(2.65806974686737550832E-6), + helper.const_val(6.23974539184983293730E-9)}; + return helper.select(cond, + polynomial(helper, i, coeff0), + polynomial(helper, i, coeff1)); + }; + + // polynomial Q + auto Q = [&](mlir::Value i, mlir::Value cond) { + std::vector coeff0 = { + helper.const_val(1.f), + helper.const_val(1.57799883256466749731E1), + helper.const_val(4.53907635128879210584E1), + helper.const_val(4.13172038254672030440E1), + helper.const_val(1.50425385692907503408E1), + helper.const_val(2.50464946208309415979E0), + helper.const_val(-1.42182922854787788574E-1), + helper.const_val(-3.80806407691578277194E-2), + helper.const_val(-9.33259480895457427372E-4)}; + std::vector coeff1 = { + helper.const_val(1.f), + helper.const_val(6.02427039364742014255E0), + helper.const_val(3.67983563856160859403E0), + helper.const_val(1.37702099489081330271E0), + helper.const_val(2.16236993594496635890E-1), + helper.const_val(1.34204006088543189037E-2), + helper.const_val(3.28014464682127739104E-4), + helper.const_val(2.89247864745380683936E-6), + helper.const_val(6.79019408009981274425E-9)}; + return helper.select(cond, + polynomial(helper, i, coeff0), + polynomial(helper, i, coeff1)); + }; + + // polynomial R + auto R = [&](mlir::Value i) { + std::vector coeff = { + helper.const_val(-5.99633501014107895267E1), + helper.const_val(9.80010754185999661536E1), + helper.const_val(-5.66762857469070293439E1), + helper.const_val(1.39312609387279679503E1), + helper.const_val(-1.23916583867381258016E0)}; + return polynomial(helper, i, coeff); + }; + + // polynomial S + auto S = [&](mlir::Value i) { + std::vector coeff = { + helper.const_val(1.f), + helper.const_val(1.95448858338141759834E0), + helper.const_val(4.67627912898881538453E0), + helper.const_val(8.63602421390890590575E1), + helper.const_val(-2.25462687854119370527E2), + helper.const_val(2.00260212380060660359E2), + helper.const_val(-8.20372256168333339912E1), + helper.const_val(1.59056225126211695515E1), + helper.const_val(-1.18331621121330003142E0)}; + return polynomial(helper, i, coeff); + }; + + // constants + auto zero = helper.const_val(0); + auto one = helper.const_val(1); + auto half = helper.const_val(0.5); + auto eight = helper.const_val(8); + auto minus_2 = helper.const_val(-2); + auto exp_minus_2 = helper.const_val(0.135335283236); // exp(-2) + auto sqrt_2pi = helper.const_val(2.506628274631); // sqrt(2pi) + + // conditions + auto case1 = helper.lt(x, exp_minus_2); // x < exp(-2) + auto case3 = helper.gt(x, helper.sub(one, exp_minus_2)); // x > 1 - exp(-2) + auto case13 = helper.bit_or(case1, case3); + + // case1 or case3 + auto x13 = helper.select(case1, x, helper.sub(one, x)); // x or (1 - x) + auto z = helper.sqrt(helper.mul(minus_2, helper.log(x13))); + auto z_lt_8 = helper.lt(z, eight); + auto t = helper.div(one, z); + auto res1 = helper.add(helper.sub(helper.div(helper.log(z), z), z), + helper.div(helper.mul(t, P(t, z_lt_8)), Q(t, z_lt_8))); + auto res13 = helper.select(case1, res1, helper.sub(zero, res1)); + + // case2 + auto w = helper.sub(x, half); + auto w2 = helper.mul(w, w); + auto w3 = helper.mul(w, w2); + auto res2 = helper.mul( + sqrt_2pi, helper.add(w, helper.div(helper.mul(w3, R(w2)), S(w2)))); + + return helper.select(case13, res13, res2); +} + +} // namespace jit +} // namespace mgb + +#endif // MGB_JIT && MGB_JIT_MLIR + +// vim: syntax=cpp.doxygen diff --git a/src/jit/impl/mlir/ir/numerical.h b/src/jit/impl/mlir/ir/numerical.h new file mode 100644 index 000000000..9550701e5 --- /dev/null +++ b/src/jit/impl/mlir/ir/numerical.h @@ -0,0 +1,46 @@ +/** + * \file src/jit/impl/mlir/ir/numerical.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#pragma once + +#include "megbrain_build_config.h" +#if MGB_JIT && MGB_JIT_MLIR + +#include + +#include "./common.h" + +namespace mgb { +namespace jit { + +/*! polynomial of degree N: + * C_0 + C_1 * x + C_2 * x^2 + ... + C_N * x^N + * where coeff = [C_N, ..., C_2, C_1, C_0] + */ +mlir::Value polynomial(ValueBuilderHelper& helper, mlir::Value x, + std::vector& coeff); + +//! numerical approximation of arctangent +mlir::Value atan2_approx(ValueBuilderHelper& helper, mlir::Value y, mlir::Value x); + +//! numerical approximation of gauss error function +mlir::Value erf_approx(ValueBuilderHelper& helper, mlir::Value x); + +//! numerical approximation of the inverse of normal distribution function +mlir::Value ndtri_approx(ValueBuilderHelper& helper, mlir::Value x); + +} // namespace jit +} // namespace mgb + +#endif // MGB_JIT && MGB_JIT_MLIR + +// vim: syntax=cpp.doxygen diff --git a/src/jit/impl/mlir/ir/ops.td b/src/jit/impl/mlir/ir/ops.td index f18a5ad0f..960897604 100644 --- a/src/jit/impl/mlir/ir/ops.td +++ b/src/jit/impl/mlir/ir/ops.td @@ -68,8 +68,8 @@ class ElemwiseUnaryOp traits = [NoSideEffect]> : def ReluOp : ElemwiseUnaryOp<"relu", [NoSideEffect]>; def AbsOp : ElemwiseUnaryOp<"abs", [NoSideEffect]>; def NegOp : ElemwiseUnaryOp<"negate", [NoSideEffect]>; -/* ACOS */ -/* ASIN */ +def AcosOp : ElemwiseUnaryOp<"acos", [NoSideEffect]>; +def AsinOp : ElemwiseUnaryOp<"asin", [NoSideEffect]>; def CeilOp : ElemwiseUnaryOp<"ceil", [NoSideEffect]>; def CosOp : ElemwiseUnaryOp<"cos", [NoSideEffect]>; def ExpOp : ElemwiseUnaryOp<"exp", [NoSideEffect]>; @@ -83,10 +83,10 @@ def TanhOp : ElemwiseUnaryOp<"tanh", [NoSideEffect]>; def FastTanhOp : ElemwiseUnaryOp<"fast_tanh", [NoSideEffect]>; def HswishOp : ElemwiseUnaryOp<"hswish", [NoSideEffect]>; def RoundOp : ElemwiseUnaryOp<"round", [NoSideEffect]>; -/* ERF */ -/* ERFINV */ -/* ERFC */ -/* ERFCINV */ +def ErfOp : ElemwiseUnaryOp<"erf", [NoSideEffect]>; +def ErfInvOp : ElemwiseUnaryOp<"erfinv", [NoSideEffect]>; +def ErfCOp : ElemwiseUnaryOp<"erfc", [NoSideEffect]>; +def ErfCInvOp : ElemwiseUnaryOp<"erfcinv", [NoSideEffect]>; class ElemwiseBinaryOp traits = [NoSideEffect]> : ElemwiseOp { @@ -130,14 +130,14 @@ def LeqOp : ElemwiseBinaryOp<"leq", [NoSideEffect]>; def EqOp : ElemwiseBinaryOp<"eq", [Commutative, NoSideEffect]>; def FuseAddReluOp : ElemwiseBinaryOp<"fuse_add_relu", [NoSideEffect]>; def TrueDivOp : ElemwiseBinaryOp<"true_div", [NoSideEffect]>; -/* POW */ +def PowOp : ElemwiseBinaryOp<"pow", [NoSideEffect]>; def LogSumExpOp : ElemwiseBinaryOp<"log_sum_exp", [Commutative, NoSideEffect]>; def FuseAddTanhOp : ElemwiseBinaryOp<"fuse_add_tanh", [NoSideEffect]>; def FastTanhGradOp : ElemwiseBinaryOp<"fast_tanh_grad", [NoSideEffect]>; def FuseAddSigmoidOp : ElemwiseBinaryOp<"fuse_add_sigmoid", [NoSideEffect]>; def HswishGradOp : ElemwiseBinaryOp<"hswish_grad", [NoSideEffect]>; def FuseAddHswishOp : ElemwiseBinaryOp<"fuse_add_hswish", [NoSideEffect]>; -/* ATAN2 */ +def Atan2Op : ElemwiseBinaryOp<"atan2", [NoSideEffect]>; class ElemwiseTernaryOp traits = [NoSideEffect]> : ElemwiseOp { diff --git a/src/jit/test/codegen.cpp b/src/jit/test/codegen.cpp index 3f8a90b2a..0c6768206 100644 --- a/src/jit/test/codegen.cpp +++ b/src/jit/test/codegen.cpp @@ -159,22 +159,48 @@ void run_mlir(CompNode cn) { MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit); } +struct MlirTestOpt { + float low; + float high; + float maxerr; +}; + +struct MlirTestOpt get_mode_opt(opr::Elemwise::Mode mode) { + struct MlirTestOpt opt = {0, 1, 1e-6}; + if (mode == opr::Elemwise::Mode::ABS) { + opt.low = -10; + opt.high = 10; + } else if (mode == opr::Elemwise::Mode::LOG) { + opt.low = 0.1; + opt.high = 4; + } else if (mode == opr::Elemwise::Mode::ERF or + mode == opr::Elemwise::Mode::ERFC) { + opt.low = -5; + opt.high = 5; + } else if (mode == opr::Elemwise::Mode::ERFINV) { + opt.low = -0.999; + opt.high = 0.999; + opt.maxerr = 1e-4; + } else if (mode == opr::Elemwise::Mode::ERFCINV) { + opt.low = 0.001; + opt.high = 1.999; + opt.maxerr = 1e-4; + } + return opt; +} + template void run_mlir_mode(CompNode cn) { set_backend(Backend::MLIR); auto graph = ComputingGraph::make(); - float low = 0.f, high = 1.f; - if (tag::mode == opr::Elemwise::Mode::LOG) { - low = 0.1; - high = 4; - } - HostTensorGenerator gen(low, - high); + auto opt = get_mode_opt(tag::mode); + HostTensorGenerator gen(opt.low, + opt.high); SmallVector> hosts; VarNodeArray input_vars; for (int i = 0; i < arity; i++) { - hosts.push_back(gen({23, 42}, cn)); + hosts.push_back(gen({2323, 4242}, cn)); input_vars.push_back( opr::Host2DeviceCopy::make(*graph, hosts[i]).node()); } @@ -198,7 +224,7 @@ void run_mlir_mode(CompNode cn) { make_callback_copy(y_jit, host_y_jit)}); func->execute(); - MGB_ASSERT_TENSOR_EQ(host_y, host_y_jit); + MGB_ASSERT_TENSOR_NEAR(host_y, host_y_jit, opt.maxerr); } #endif @@ -240,18 +266,25 @@ TEST(TestJITMlirCodeGen, BasicGPU) { cb(RELU) \ cb(ABS) \ cb(NEGATE) \ + cb(ACOS) \ + cb(ASIN) \ cb(CEIL) \ cb(EXP) \ cb(FLOOR) \ cb(LOG) \ cb(LOG1P) \ cb(SIN) \ + cb(COS) \ cb(TANH) \ cb(FAST_TANH) \ cb(H_SWISH) \ cb(SIGMOID) \ cb(EXPM1) \ - cb(ROUND) + cb(ROUND) \ + cb(ERF) \ + cb(ERFINV) \ + cb(ERFC) \ + cb(ERFCINV) // clang-format on template class TestJITMlirUnaryElemwise : public ::testing::Test {}; @@ -268,21 +301,27 @@ FOREACH_UNARY_MODE(def_tag) ::testing::Types; #undef t TYPED_TEST_CASE(TestJITMlirUnaryElemwise, mlir_elemwise_unary_types); -TYPED_TEST(TestJITMlirUnaryElemwise, run) { - auto cn = CompNode::load("cpu0"); - run_mlir_mode(cn); -} #define SKIP_MODE(_mode) \ if (TypeParam::mode == opr::Elemwise::Mode::_mode) { \ printf("skip\n"); \ return; \ } + +TYPED_TEST(TestJITMlirUnaryElemwise, run) { + auto cn = CompNode::load("cpu0"); + + SKIP_MODE(ROUND); + + run_mlir_mode(cn); +} + TYPED_TEST(TestJITMlirUnaryElemwise, runGpu) { REQUIRE_GPU(1); auto cn = CompNode::load("gpu0"); SKIP_MODE(SIN); + SKIP_MODE(ROUND); run_mlir_mode(cn); } @@ -298,6 +337,7 @@ TYPED_TEST(TestJITMlirUnaryElemwise, runGpu) { cb(MOD) \ cb(SUB) \ cb(TRUE_DIV) \ + cb(POW) \ cb(ABS_GRAD) \ cb(SIGMOID_GRAD) \ cb(SWITCH_GT0) \ @@ -311,7 +351,8 @@ TYPED_TEST(TestJITMlirUnaryElemwise, runGpu) { cb(FAST_TANH_GRAD) \ cb(FUSE_ADD_SIGMOID) \ cb(H_SWISH_GRAD) \ - cb(FUSE_ADD_H_SWISH) + cb(FUSE_ADD_H_SWISH) \ + cb(ATAN2) // clang-format on template class TestJITMlirBinaryElemwise : public ::testing::Test {}; @@ -336,6 +377,9 @@ TYPED_TEST(TestJITMlirBinaryElemwise, run) { TYPED_TEST(TestJITMlirBinaryElemwise, runGpu) { REQUIRE_GPU(1); auto cn = CompNode::load("gpu0"); + + SKIP_MODE(MOD); + run_mlir_mode(cn); } @@ -373,7 +417,7 @@ TYPED_TEST(TestJITMlirTernaryElemwise, runGpu) { #undef SKIP_MODE -#endif +#endif // MGB_JIT_MLIR #endif // MGB_JIT -- GitLab