// Copyright (c) 2021 CINN Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/cinn/common/arithmatic.h" #include #include #include #include #include #include "paddle/cinn/common/ir_util.h" #include "paddle/cinn/ir/ir_operators.h" #include "paddle/cinn/ir/ir_printer.h" #include "paddle/cinn/ir/ir_visitor.h" #include "paddle/cinn/utils/string.h" namespace cinn { namespace common { using utils::GetStreamCnt; using utils::Join; using utils::Replace; using utils::Split; using namespace ir; // NOLINT #ifdef As #undef As #endif std::string ExprToGinacConverter::Repr(const ir::Expr& expr) { auto* load_n = expr.As(); auto* var_n = expr.As<_Var_>(); auto* broadcast_n = expr.As(); auto* mod_n = expr.As(); auto* min_n = expr.As(); auto* max_n = expr.As(); auto* div_n = expr.As
(); auto* frac_n = expr.As(); if (load_n || broadcast_n || mod_n || min_n || max_n || div_n || frac_n) { std::string repr = GetStreamCnt(expr); Replace(&repr, "[", "lsq_"); Replace(&repr, "]", "_rsq"); Replace(&repr, "(", "lb_"); Replace(&repr, ")", "_rb"); Replace(&repr, "+", "_add_"); Replace(&repr, "-", "_sub_"); Replace(&repr, ":", "_ref_"); Replace(&repr, "*", "_mul_"); Replace(&repr, "/", "_div_"); // remove the spaces auto fields = utils::Split(repr, " "); repr = utils::Join(fields, "_"); return repr; } else if (var_n) { return utils::GetStreamCnt(expr); } return ""; } void ExprToGinacConverter::RecordExpr(const ir::Expr& expr) { repr_to_expr_[Repr(expr)] = expr; } GiNaC::ex ExprToGinacConverter::BuildHelper(ir::Expr expr) { auto* load_n = expr.As(); auto* var_n = expr.As<_Var_>(); auto* int_n = expr.As(); auto* float_n = expr.As(); auto* add_n = expr.As(); auto* sub_n = expr.As(); auto* mul_n = expr.As(); auto* div_n = expr.As
(); auto* minus_n = expr.As(); auto* broadcast_n = expr.As(); auto* mod_n = expr.As(); auto* frac_n = expr.As(); auto* min_n = expr.As(); auto* max_n = expr.As(); bool is_integer_math = expr.type().is_int(); bool is_invalid_arith = load_n || var_n || broadcast_n || mod_n || min_n || max_n; if (is_integer_math) is_invalid_arith = is_invalid_arith || div_n || frac_n; // GiNac can't deal with integer division. if (is_invalid_arith) { RecordExpr(expr); std::string repr = Repr(expr); return CreateGinacSymbol(repr); } else if (int_n) { return int_n->value; } else if (float_n) { return float_n->value; } else if (add_n) { auto a = BuildHelper(add_n->a()); auto b = BuildHelper(add_n->b()); return (a + b) * 1; } else if (sub_n) { return (BuildHelper(sub_n->a()) - BuildHelper(sub_n->b())); } else if (mul_n) { return (BuildHelper(mul_n->a()) * BuildHelper(mul_n->b())); } else if (div_n) { return (BuildHelper(div_n->a()) / BuildHelper(div_n->b())); } else if (frac_n) { return (BuildHelper(frac_n->a()) / BuildHelper(frac_n->b())); } else if (minus_n) { return -BuildHelper(minus_n->v()); } else { CINN_NOT_IMPLEMENTED } } GiNaC::ex ExprToGinacConverter::operator()(Expr expr) { // TODO(Superjomn) Replace this with common::IsPureMath( auto complex_nodes = CollectIRNodes(expr, [](const Expr* n) { return n->As() || // n->As() || // n->As() || // n->As() || // n->As() || // n->As() || // n->As() || // n->As() || // n->As() || // n->As() || // n->As() || // n->As() || // n->As() || // n->As