/** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "poly/scop_builder.h" #include #include #include "pass/utils.h" #include "construct_poly_accesses.h" namespace akg { namespace ir { namespace poly { // Note: We can only handle empty param_space for now. isl::space CreateParamsSpace(const isl::ctx &ctx) { return isl::space(ctx, 0); } isl::space CreateParamsSpace(const isl::ctx &ctx, const std::unordered_map ¶ms) { auto space = isl::space(ctx, 0); // set parameter names for (auto it = params.begin(); it != params.end(); ++it) { space = space.add_param(isl::id(ctx, it->second->name_hint)); } return space; } isl::aff Int2Aff(const isl::space &s, int64_t v) { return isl::aff(isl::local_space(s), isl::val(s.ctx(), v)); } template inline std::vector ConcatAffs(const isl::space &space, T *op, bool allow_min, bool allow_max) { std::vector result; for (const auto &aff : Expr2AffBounds(space, op->a, allow_min, allow_max)) { result.push_back(aff); } for (const auto &aff : Expr2AffBounds(space, op->b, allow_min, allow_max)) { result.push_back(aff); } return result; } template inline std::vector UniteAffs(const isl::space &space, T *op, isl::aff (isl::aff::*unite)(isl::aff) const) { std::vector bounds_l = Expr2AffBounds(space, op->a, false, false); std::vector bounds_r = Expr2AffBounds(space, op->b, false, false); CHECK_LE(bounds_l.size(), 1u); CHECK_LE(bounds_r.size(), 1u); if (bounds_l.size() > 0 && bounds_r.size() > 0) { return {(bounds_l[0].*unite)(bounds_r[0])}; } return {}; } template bool ExprType(const Expr &e) { return (e.as() != nullptr); } std::vector Variable2AffBounds(const isl::space &space, const Variable *v, bool ignore_error) { isl::id id(space.ctx(), v->name_hint); if (space.has_param(id)) { return {isl::aff::param_on_domain(space, id)}; } CHECK(ignore_error) << "Can not find var: " << v->name_hint << " in isl::space: " << space << '\n'; return {}; } std::vector FloorDiv2AffBounds(const isl::space &space, const FloorDiv *f_div) { if (f_div->type.is_int() || f_div->type.is_uint()) { auto left = Expr2AffBounds(space, f_div->a, false, false); auto right = Expr2AffBounds(space, f_div->b, false, false); if (left.size() == 0 || right.size() == 0) { return {}; } return {(left[0].div)(right[0]).floor()}; } return UniteAffs(space, f_div, &isl::aff::div); } std::vector Div2AffBounds(const isl::space &space, const Div *div) { if (div->type.is_int() || div->type.is_uint()) { auto left = Expr2AffBounds(space, div->a, false, false); auto right = Expr2AffBounds(space, div->b, false, false); if (left.size() == 0 || right.size() == 0) { return {}; } return {(left[0].div)(right[0]).floor()}; } return UniteAffs(space, div, &isl::aff::div); } std::vector FloorMod2AffBounds(const isl::space &space, const FloorMod *f_mod, bool ignore_error) { auto left = Expr2AffBounds(space, f_mod->a, false, false); Expr right = f_mod->b; if (const int64_t *val = as_const_int(right)) { isl::val v = isl::val(space.ctx(), *val); return {left[0].mod(v)}; } CHECK(ignore_error) << "Mod's denominator is not a const_int\n"; return {}; } std::vector Mod2AffBounds(const isl::space &space, const Mod *mod, bool ignore_error) { auto left = Expr2AffBounds(space, mod->a, false, false); Expr right = mod->b; if (const int64_t *val = as_const_int(right)) { isl::val v = isl::val(space.ctx(), *val); return {left[0].mod(v)}; } CHECK(ignore_error) << "Mod's denominator is not a const_int \n"; return {}; } std::vector Select2AffBounds(const isl::space &space, const Select *sel) { /************************************** * Support Select expression aff bounds computation * select((15 < int32(ceil((float32(w)*5.000000f)))), 15, int32(ceil((float32(w)*5.000000f)))) **************************************/ auto true_aff_bounds = Expr2AffBounds(space, sel->true_value, false, false); auto false_aff_bounds = Expr2AffBounds(space, sel->false_value, false, false); if (true_aff_bounds.size() == 0 || false_aff_bounds.size() == 0) { return {}; } /******************************************************** * temp method just add true_value aff and false_value aff *******************************************************/ return {(true_aff_bounds[0].add)(false_aff_bounds[0])}; } std::vector Expr2AffBounds(const isl::space &space, const Expr &e, bool allow_min, bool allow_max, bool ignore_error) { CHECK(!(allow_min && allow_max)); if (ExprType(e)) { return Variable2AffBounds(space, e.as(), ignore_error); } else if (const int64_t *i = as_const_int(e)) { return {Int2Aff(space, *i)}; } else if (ExprType(e)) { return {Int2Aff(space, int64_t(e.as()->value))}; } else if (ExprType(e)) { return Expr2AffBounds(space, e.as()->value, false, false); } else if (const auto call = e.as()) { if ((call->name == "floor" || call->name == "ceil") && call->args.size() == 1) { return Expr2AffBounds(space, call->args[0], false, false); } LOG(INFO) << "not parse call type: " << call->name << " with expr :" << e; } else if (ExprType(e)) { if (!allow_min) return {}; return ConcatAffs(space, e.as(), allow_min, allow_max); } else if (ExprType(e)) { if (!allow_max) return {}; return ConcatAffs(space, e.as(), allow_min, allow_max); } else if (ExprType(e)) { return UniteAffs(space, e.as(), &isl::aff::add); } else if (ExprType(e)) { return UniteAffs(space, e.as(), &isl::aff::sub); } else if (ExprType(e)) { return UniteAffs(space, e.as(), &isl::aff::mul); } else if (ExprType(e)) { return FloorDiv2AffBounds(space, e.as()); } else if (ExprType
(e)) { return Div2AffBounds(space, e.as
()); } else if (ExprType(e)) { return FloorMod2AffBounds(space, e.as(), ignore_error); } else if (ExprType(e)) { return Mod2AffBounds(space, e.as(), ignore_error); } else if (ExprType()); } CHECK(ignore_error) << "Expr2AffBounds " << e << "\n"; return {}; } std::vector Expr2AffChecked(const isl::space &space, const Expr &e, bool allow_min, bool allow_max) { bool ignore_error = false; return Expr2AffBounds(space, e, allow_min, allow_max, ignore_error); } isl::aff Expr2Aff(const isl::space &space, const Expr &e) { auto list = Expr2AffChecked(space, e, false, false); return list.empty() ? isl::aff() : list[0]; } isl::multi_id CollectTensorCoordinate(const isl::space &pspace, const isl::id &id, size_t dim) { isl::id_list args(pspace.ctx(), 0); for (size_t i = 0; i < dim; ++i) { auto name = std::string("arg") + std::to_string(i); args = args.add(isl::id(pspace.ctx(), name)); } return isl::multi_id(pspace.add_named_tuple_id_ui(id, static_cast(dim)), args); } isl::map AddSuffix4Accesses(AccessMap &accesses, const isl::map &in_map, const Node *op, const isl::ctx &ctx) { auto tensor_map = in_map; // Based on different condition, add suffix to the domain space std::string suffix; if (accesses.count(op) > 0) { // reuse existing tag if the op is accessed previously suffix = accesses[op].to_str(); } else { // create a new tag with unique name suffix = "__poly_ref_" + std::to_string(accesses.size()); } isl::id suffix_id(ctx, suffix); if (accesses.count(op) == 0) { // only insert, not replace accesses.emplace(op, suffix_id); } auto domain_space = tensor_map.get_space().domain(); auto tag_space = domain_space.params().add_named_tuple_id_ui(suffix_id, 0); domain_space = domain_space.product(tag_space).unwrap(); tensor_map = tensor_map.preimage_domain(isl::multi_aff::domain_map(domain_space)); return tensor_map; } isl::union_pw_aff GetUnionPwAffAtDomain(const isl::aff &f, const isl::union_set &domain, const OperatorDomainMap &map) { auto upa = isl::union_pw_aff::empty(domain.space()); for (auto set : domain.get_set_list()) { upa = upa.union_add(isl::union_pw_aff(f.unbind_params_insert_domain(map.at(set.tuple_id()).tuple))); } return upa; } static const char kStatementLabel[] = "S_"; bool ParseWithStmt(const Expr &s, const AnalysisResult &result) { class ParseWith final : public IRVisitor { public: void Visit_(const Call *op) final { if (!find_tensor && (0 != writes.size())) { if (op->call_type == Call::Halide) { if (writes.find(op->name) != writes.end()) { find_tensor = true; } } } IRVisitor::Visit_(op); } bool find_tensor{false}; std::unordered_set writes; bool GetResult() const { return find_tensor; } ParseWith(const Expr &stmt, const AnalysisResult &result) { result.GetWrites().foreach_map([&, this](const isl::map m) -> void { writes.insert(m.get_tuple_id(isl_dim_out).get_name()); return; }); IRVisitor::Visit(stmt); } ~ParseWith() override = default; } paserWith(s, result); return paserWith.GetResult(); } std::map call_op_ = { {"log", PolyOpType::elewise_single_log}, {"exp", PolyOpType::elewise_single_exp}, {"sqrt", PolyOpType::elewise_single_sqrt}, {"rsqrt", PolyOpType::elewise_single_rsqrt}, {"fabs", PolyOpType::elewise_single_fabs}, {"rec", PolyOpType::elewise_single_rec}, {"floor", PolyOpType::vec_single_floor}, {"round", PolyOpType::vec_single_round}, {"ceil", PolyOpType::elewise_single_ceil}, {"trunc", PolyOpType::vec_single_trunc}, {"not", PolyOpType::elewise_single_not}, {"relu", PolyOpType::elewise_single_relu}, {"EQ", PolyOpType::elewise_binary_EQ}, {"NE", PolyOpType::elewise_binary_NE}, {"GT", PolyOpType::elewise_binary_GT}, {"GE", PolyOpType::elewise_binary_GE}, {"LT", PolyOpType::elewise_binary_LT}, {"LE", PolyOpType::elewise_binary_LE}, {"fargmax", PolyOpType::vec_argmax}, {"fargmin", PolyOpType::vec_argmin}, {"four2five_nchw", PolyOpType::four2five_nchw}, {"vand", PolyOpType::elewise_binary_and}, {"bitwise_and", PolyOpType::elewise_binary_bitwise_and}, {"bitwise_or", PolyOpType::elewise_binary_bitwise_or}, {"bitwise_not", PolyOpType::elewise_single_bitwise_not}, {"proposal_sort", PolyOpType::elewise_binary_proposal_sort}, {"topk_sort", PolyOpType::elewise_binary_topk_sort}, {"nms", PolyOpType::elewise_binary_nms}, {"dropout", PolyOpType::elewise_binary_dropout}, {"iou", PolyOpType::elewise_binary_iou}, {"vmadd", PolyOpType::vmadd}, {"vmaddrelu", PolyOpType::vmaddrelu}, {"vaxpy", PolyOpType::vaxpy}, {"vmla", PolyOpType::vmla}, }; void ParseStmtOpCall(const isl::id &id, const Call *call, AnalysisResult &result, const FunctionRef &func) { CHECK(call); if (call->call_type == Call::PureIntrinsic) { if (call_op_.count(call->name) > 0) { result.GetStmtOpInfoMap().at(id).ops.push_back(call_op_[call->name]); } else if (0 == strcmp(call->name.c_str(), "with")) { result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::with); if (!result.GetStmtOpInfoMap().at(id).isWith) { for (unsigned i = 0; i < call->args.size(); ++i) { if (ParseWithStmt(call->args[i], result)) { result.GetStmtOpInfoMap().at(id).isWith = true; break; } } } } else if (0 == strcmp(call->name.c_str(), "reshape")) { // do nothing } else if (0 == strcmp(call->name.c_str(), "transpose")) { // do nothing } else if (0 == strcmp(call->name.c_str(), "divide_var")) { // do nothing } else if (0 == strcmp(call->name.c_str(), "sub_relu")) { // do nothing } else if (0 == strcmp(call->name.c_str(), "load3d_l1_ub")) { result.GetStmtOpInfoMap().at(id).isLoad3d = true; ParseStmtOps(id, call->args[0], result, func); } else if (0 == strcmp(call->name.c_str(), "mad")) { result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::mad); result.GetStmtOpInfoMap().at(id).isCube = true; // assign + mad std::string name = id.get_name(); size_t index = static_cast(WrappedStrtol(name.substr(name.length() - 1))); std::string tmp = name.substr(0, name.length() - 1); std::stringstream ss; ss << tmp << index - 1; if (result.GetStmtOpInfoMap().count(isl::id(id.ctx(), ss.str())) > 0 && result.GetStmtOpInfoMap().at(isl::id(id.ctx(), ss.str())).ops[0] == PolyOpType::broadcast) result.GetStmtOpInfoMap().at(isl::id(id.ctx(), ss.str())).isCubeAssign = true; // end result.GetStmtOpInfoMap().at(id).C_ = func->func_name(); CHECK(call->args.size() == 2) << "invalid args of mad! "; auto mul_arg = call->args[0].as() ? call->args[0].as() : call->args[1].as(); if (call->args[1].as()) { CHECK(call->args[1].as()->value.as()); mul_arg = call->args[1].as()->value.as(); } CHECK(mul_arg); auto a = mul_arg->a.as(); auto b = mul_arg->b.as(); // in gemm case, C = mad(C, A * B) if (a && b) { result.GetStmtOpInfoMap().at(id).A_ = a->name; result.GetStmtOpInfoMap().at(id).B_ = b->name; } // in conv case, reassign A&B by attr if (func.as() != nullptr) { result.GetStmtOpInfoMap().at(id).MadType_ = call->args[1].as() ? call->args[1].as()->type : Float(16); for (auto i : func.as()->attrs) { if ("feature" == i.first) { result.GetStmtOpInfoMap().at(id).A_ = i.second.as()->value; } if ("filter" == i.first) { result.GetStmtOpInfoMap().at(id).B_ = i.second.as()->value; } } } } else { LOG(FATAL) << "Unknown pure intrinsic: " << call->name.c_str() << std::endl; } } } void ParseStmtOps(const isl::id &id, const Expr &val, AnalysisResult &result, const FunctionRef &func) { result.GetStmtOpInfoMap().at(id).isCube = false; result.GetStmtOpInfoMap().at(id).isCubeAssign = false; if (auto add = val.as()) { if (isImm(add->a) || isImm(add->b)) { if (!isImm(add->a)) { // if add->a is not a scalar, then put it into recursion ParseStmtOps(id, add->a, result, func); } else if (!isImm(add->b)) { // if add->b is not a scalar, then put it into recursion ParseStmtOps(id, add->b, result, func); } else { // if add->a and add->b are both scalar, then report error LOG(FATAL) << "Error: Scalar + Scalar, Please Check."; } result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_single_VS_add); } else { ParseStmtOps(id, add->a, result, func); ParseStmtOps(id, add->b, result, func); result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_binary_add); } } else if (auto sub = val.as()) { ParseStmtOps(id, sub->a, result, func); ParseStmtOps(id, sub->b, result, func); result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_binary_sub); } else if (auto mul = val.as()) { if (isImm(mul->a) || isImm(mul->b)) { // if mul->a is not a scalar, then put it into recursion if (!isImm(mul->a)) { ParseStmtOps(id, mul->a, result, func); } else if (!isImm(mul->b)) { // if mul->b is not a scalar, then put it into recursion ParseStmtOps(id, mul->b, result, func); } else { // if mul->a and mul->b are both scalar, then report error LOG(FATAL) << "Error: Scalar + Scalar, Please Check."; } if (isZero(mul->b) || isZero(mul->a)) { result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::broadcast); } else { result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_single_VS_mul); } } else { ParseStmtOps(id, mul->a, result, func); ParseStmtOps(id, mul->b, result, func); result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_binary_mul); } } else if (auto f_div = val.as()) { ParseStmtOps(id, f_div->a, result, func); ParseStmtOps(id, f_div->b, result, func); result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_binary_div); } else if (auto f_mod = val.as()) { ParseStmtOps(id, f_mod->a, result, func); ParseStmtOps(id, f_mod->b, result, func); result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_binary_mod); } else if (auto div = val.as
()) { ParseStmtOps(id, div->a, result, func); ParseStmtOps(id, div->b, result, func); result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_binary_div); } else if (auto mod = val.as()) { ParseStmtOps(id, mod->a, result, func); ParseStmtOps(id, mod->b, result, func); result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_binary_mod); } else if (auto and_op = val.as()) { ParseStmtOps(id, and_op->a, result, func); ParseStmtOps(id, and_op->b, result, func); result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_binary_and); } else if (auto or_op = val.as()) { ParseStmtOps(id, or_op->a, result, func); ParseStmtOps(id, or_op->b, result, func); result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_binary_or); } else if (auto min = val.as()) { ParseStmtOps(id, min->a, result, func); ParseStmtOps(id, min->b, result, func); result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_binary_min); } else if (auto max = val.as()) { ParseStmtOps(id, max->a, result, func); ParseStmtOps(id, max->b, result, func); result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::elewise_binary_max); } else if (auto ge = val.as()) { ParseStmtOps(id, ge->a, result, func); ParseStmtOps(id, ge->b, result, func); result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::pandora_cmp); } else if (auto gt = val.as()) { ParseStmtOps(id, gt->a, result, func); ParseStmtOps(id, gt->b, result, func); result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::pandora_cmp); } else if (auto le = val.as()) { ParseStmtOps(id, le->a, result, func); ParseStmtOps(id, le->b, result, func); result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::pandora_cmp); } else if (auto lt = val.as()) { ParseStmtOps(id, lt->a, result, func); ParseStmtOps(id, lt->b, result, func); result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::pandora_cmp); } else if (auto eq = val.as()) { ParseStmtOps(id, eq->a, result, func); ParseStmtOps(id, eq->b, result, func); result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::pandora_cmp); } else if (auto ne = val.as()) { ParseStmtOps(id, ne->a, result, func); ParseStmtOps(id, ne->b, result, func); result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::pandora_cmp); } else if ((isImm(val) || val.type().is_int()) && val.as() == nullptr) { result.GetStmtOpInfoMap().at(id).ops.push_back(PolyOpType::broadcast); } else if (auto sel = val.as