diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index e18b91e5f13f47f8e9e39a0809c5cc423768c006..70e36467df34e559bfe5ac11b4bc5b62cf4718d8 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -115,6 +115,9 @@ REGISTER_PASS(AnalyzeMinAlignStatic); REGISTER_PASS(AnalyzeMinAlignDynamic); REGISTER_PASS(RewriteBroadcastVector); REGISTER_PASS(OptimizePragma); +REGISTER_PASS(PackStore); +REGISTER_PASS(RecoverStore); +REGISTER_PASS(MergeLoops); REGISTER_PASS(ExpandC0); REGISTER_PASS(ForEliminate); REGISTER_PASS(FixLoopExtent); diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 9097835b1df3d6b52b56e9717811e07d799d4af3..15030ab426a94c898c2b3ffcbd4920cc7309e624 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -738,16 +738,19 @@ NodeRef Lower(Schedule sch, const Array &in_args, const Array if (global_attrs.GetBoolAttr(kDeadCodeElim, false)) { stmt = NEXT_PASS(DeadCodeElim, stmt); } - if (!is_dynamic) { - stmt = NEXT_PASS(RewriteBroadcastVector, stmt); - stmt = NEXT_PASS(OptimizePragma, stmt); - } + if (is_dynamic) { stmt = NEXT_PASS(AnalyzeMinAlignDynamic, stmt, global_attrs.GetIntAttr(kEnableConvAnalyzeAlign, true), - global_attrs.GetIntAttr(kEnableScalarAlign, false)); + global_attrs.GetIntAttr(kEnableScalarAlign, false)); } else { + stmt = NEXT_PASS(RewriteBroadcastVector, stmt); + stmt = NEXT_PASS(OptimizePragma, stmt); + stmt = NEXT_PASS(MergeLoops, stmt, false); + stmt = NEXT_PASS(PackStore, stmt); stmt = NEXT_PASS(AnalyzeMinAlignStatic, stmt); + stmt = NEXT_PASS(RecoverStore, stmt); } + stmt = NEXT_PASS(MultiLastAxisReductions, stmt, is_dynamic); stmt = NEXT_PASS(AutoReorder, stmt); if (enable_multicore != 0) { diff --git a/src/emit_insn/insn_builder.h b/src/emit_insn/insn_builder.h index 6dd8ff7cdff77c851e9e18c96f1ba69e1d84f29d..a1dd84369ea7f1b01f87f26d96bee13ff34c0d6f 100644 --- a/src/emit_insn/insn_builder.h +++ b/src/emit_insn/insn_builder.h @@ -25,6 +25,9 @@ #include "insn_info.h" #include 
"cce_params.h" namespace akg { + +enum SingleType {SIMD, Tensor_Scalar, Vector_Dump}; + struct MutableMaskParams { Var mask_var_; Expr loop_var_; @@ -239,8 +242,11 @@ class VectorInsnBuilder : public InsnBuilder { class SingleVecInsnBuilder : public VectorInsnBuilder { public: SingleVecInsnBuilder(const StmtStoreInfo &dst, const StmtStoreInfo &src, const ArgInfo &args, - const std::string &intrin_name, const Buffer &tmp_buf = Buffer()) - : VectorInsnBuilder(dst, {src}, args, intrin_name), src_info_(src_info_list_[0]), tmp_buffer_(tmp_buf) { + const std::string &intrin_name, const Expr &scalar_src = Expr(), + const SingleType insn_type = SingleType::SIMD) + : VectorInsnBuilder(dst, {src}, args, intrin_name), + src_info_(src_info_list_[0]), + scalar_src_(scalar_src), insn_type_(insn_type) { CHECK(src_info_.defined()); } ~SingleVecInsnBuilder() override = default; @@ -254,8 +260,10 @@ class SingleVecInsnBuilder : public VectorInsnBuilder { Stmt CreateBroadcast(const VectorArgInfo &arg_info, const Var &local_var, Stmt stmt); StmtStoreInfo src_info_; - Buffer tmp_buffer_; Buffer broadcast_buffer_; + Expr scalar_src_; + SingleType insn_type_; // 0 simd : 1 vector_scalar : 2 vector_dup + }; class MultiVecInsnBuilder : public VectorInsnBuilder { diff --git a/src/emit_insn/insn_builder_vector.cc b/src/emit_insn/insn_builder_vector.cc index 5d653a03cf02c318f59faff6dd88b10a163ca306..fecc561d46067f9d0fd716c2efbf28b4ca55c0bb 100644 --- a/src/emit_insn/insn_builder_vector.cc +++ b/src/emit_insn/insn_builder_vector.cc @@ -92,9 +92,6 @@ Stmt SingleVecInsnBuilder::EmitExpandedIntrin(const VectorArgInfo &arg_info) { Expr dst_offset = dst_info_->insn_offset_; Expr src_offset = src_info_->insn_offset_; - Var local_var = Var("broadcast_for_vec_local_UB", Handle()); - stmt = CreateBroadcast(arg_info, local_var, stmt); - // Handle stride_m1 loop of single vector intrin, if stride_m1 > 255, it will be separated if (dst_stride_m1 >= MAX_STRIDE_M1 || src_stride_m1 >= MAX_STRIDE_M1) { auto 
var = Var("repeatStrideM1Idx"); @@ -112,14 +109,6 @@ Stmt SingleVecInsnBuilder::EmitExpandedIntrin(const VectorArgInfo &arg_info) { } } - if (!dst_info_->var_.empty() && src_info_->var_.empty() && intrin_name_ != INTRIN_NAME_VECTOR_DUP) { - // need to broadcast src first - stmt = Allocate::make(local_var, src_info_->dtype_, {Expr(src_block_size * FULL_BLOCK_NUM)}, const_true(), stmt); - if (!src_info_->scope_.empty()) { - stmt = AttrStmt::make(local_var, STORAGE_SCOPE, StringImm::make(src_info_->scope_), stmt); - } - } - CHECK(stmt.defined()) << "Error: Stmt is undefined!"; return stmt; @@ -131,70 +120,36 @@ Stmt SingleVecInsnBuilder::EmitExpandedIntrin(const VectorArgInfo &arg_info) { /// \return Stmt SingleVecInsnBuilder::EmitIntrinBody(const VectorArgInfo &arg_info, const Map &args) { Stmt body; - CHECK(!arg_info->src_stride_m0_list_.empty()); CHECK(!arg_info->src_stride_m1_list_.empty()); - - auto dst_buffer_id = GenBufferId(dst_info_); - auto src_buffer_id = GenBufferId(src_info_); - Expr repeat = args["repeat"]; + auto dst_buffer_id = GenBufferId(dst_info_); Expr dst_offset = Sub::make(args["dstOffset"], arg_info->block_offset_); - Expr src_offset = args["srcOffset"]; - Expr src_stride_m1 = arg_info->src_stride_m1_list_[0]; - auto dst = GetAccessPtr(dst_buffer_id, "w", dst_offset); - auto src = GetAccessPtr(src_buffer_id, "r", src_offset); - if (broadcast_buffer_.defined()) { - src_stride_m1 = 0; - src = GetAccessPtr(broadcast_buffer_, "r", Expr(0)); + Array insn_args {}; + if (insn_type_ == SingleType::Vector_Dump) { + insn_args = {dst, scalar_src_, repeat}; + } else { + auto src_buffer_id = GenBufferId(src_info_); + Expr src_offset = args["srcOffset"]; + auto src = GetAccessPtr(src_buffer_id, "r", src_offset); + if (insn_type_ == SingleType::SIMD) { + insn_args = {dst, src, repeat}; + } else if (insn_type_ == SingleType::Tensor_Scalar) { + insn_args = {dst, src, scalar_src_, repeat}; + } else { + CHECK(0) << "\nUnknown insn_type_\n"; + } } Array stride_args 
= {arg_info->dst_stride_m0_, arg_info->src_stride_m0_list_[0], arg_info->dst_stride_m1_, - src_stride_m1}; - Array insn_args = {dst, src, repeat}; - if (arg_info->scalar_.defined()) { - auto scalar = arg_info->scalar_; - if (tmp_buffer_.defined()) { - dst = GetAccessPtr(tmp_buffer_, "w", dst_offset); - } - - insn_args = {dst, scalar, repeat}; - - if (intrin_name_ != INTRIN_NAME_VECTOR_DUP) { - Insert(insn_args, 1, src); - } - } + arg_info->src_stride_m1_list_[0]}; insn_args = MergeTwo(insn_args, stride_args); body = EmitCceIntrinTemplate(Stmt(), dst.type(), insn_args, intrin_name_); - return body; } -/// Create broadcast intrin if src is scalar -/// \param arg_info -/// \param local_var -/// \param stmt -/// \return -Stmt SingleVecInsnBuilder::CreateBroadcast(const VectorArgInfo &arg_info, const Var &local_var, Stmt stmt) { - if (!dst_info_->var_.empty() && src_info_->var_.empty() && intrin_name_ != INTRIN_NAME_VECTOR_DUP) { - // need to broadcast src first - auto src_block_size = GetUbBlkSize(src_info_->dtype_); - broadcast_buffer_ = BufferNode::make(local_var, src_info_->dtype_, {Expr(src_block_size * FULL_BLOCK_NUM)}, {}, - src_info_->elem_offset_, "broadcast_for_vec_local_UB", src_info_->scope_, - src_info_->data_alignment_, 1, BufferType::kDefault); - auto broad_dst = GetAccessPtr(broadcast_buffer_, "w", 0); - Array args = { - broad_dst, GenBufferId(src_info_).vload({Expr(0)}, src_info_->dtype_), Expr(1), Expr(1), Expr(1), Expr(0), - Expr(0)}; - stmt = EmitSetVecMaskIntrin(stmt, src_info_->dtype_, GetAllMask(src_info_->dtype_)); - stmt = InsertBody(stmt, EmitCceIntrinTemplate(Stmt(), src_info_->dtype_, args, INTRIN_NAME_VECTOR_DUP)); - stmt = EmitSetVecMaskIntrin(stmt, dst_info_->dtype_, arg_info->vec_mask_); - } - - return stmt; -} /// if repeat-size > cce_max_repeat, then split it into loop as "Davinci ISA User Guide t6.3 (8.2.2)" mentioned /// max_cce_repeat = 255, considering params are about 2 cycles, set it to be 255 // 2 = 127 @@ -1250,8 +1205,10 @@ 
Stmt EmitCceBinaryVectorToReduceLastAxis(const StmtStoreInfo &dst_info, const St auto vec_dup_arg_info = GenReduceHelperArgInfo(vec_dup_dst_info, for_extent, scalar, "VecDup"); + vec_dup_dst_info.GetNode()->data_ = final_var; + vec_dup_dst_info.GetNode()->name_ = final_var->name_hint; SingleVecInsnBuilder single_vec_builder = SingleVecInsnBuilder(vec_dup_dst_info, vec_dup_dst_info, vec_dup_arg_info, - INTRIN_NAME_VECTOR_DUP, final_dst_buffer); + INTRIN_NAME_VECTOR_DUP, scalar, SingleType::Vector_Dump); auto insn_list = single_vec_builder.EmitIntrin(); auto stmt = std::accumulate(insn_list.begin(), insn_list.end(), Stmt(), [](const Stmt &s0, const Stmt &s1) { return InsertBody(s0, s1); }); diff --git a/src/emit_insn/insn_emitter.cc b/src/emit_insn/insn_emitter.cc index 89684be511c1e289f846d38527d1ad22ed5aa7bf..7c0067bedddaf9c64a9a3cea6c71fe0faa1e208b 100644 --- a/src/emit_insn/insn_emitter.cc +++ b/src/emit_insn/insn_emitter.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "emit_insn/insn_emitter.h" +#include "insn_emitter.h" #include #include @@ -53,145 +53,68 @@ std::vector SortIndexes(const std::vector &v) { /// \param intrin_name - The CCE intrin name /// \param broadcast_last_axis - Tag of broadcast_last_axis mode /// \return Stmt of emitted CCE intrin -Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name, bool broadcast_last_axis = false) { +Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) { CHECK(op); Stmt result; - // optimization of copy_ubuf_to_ubuf - bool is_dma_opt = false; - if (intrin_name == INTRIN_NAME_COPY_UB_TO_UB) { - CommentManager::GetInstance().AddComment("Insn_type", "dma_copy"); - CommentManager::GetInstance().AddComment("Insn_name", INTRIN_NAME_COPY_UB_TO_UB); - CommentManager::GetInstance().AddComment("Vadds_replace_copy", "enable"); - intrin_name = "vadds"; - is_dma_opt = true; - } else { - CommentManager::GetInstance().AddComment("Insn_type", "single_vector"); - 
CommentManager::GetInstance().AddComment("Insn_name", intrin_name); - } + + CommentManager::GetInstance().AddComment("Insn_type", "single_vector"); + CommentManager::GetInstance().AddComment("Insn_name", intrin_name); StmtInfoList dst_info_list; StmtInfoList src_info_list; - StmtStoreInfo scalar_info; StmtInfo for_info; StmtInfo if_info; - std::string mode = GetSingleVecComputationInfo(op, intrin_name, dst_info_list, src_info_list, if_info, for_info); + + bool same_dtype = intrin_name.find("vconv_") == std::string::npos; + GetCompactComputationInfo(op, dst_info_list, src_info_list, if_info, for_info, same_dtype, true); CHECK(!dst_info_list.empty()); - if (broadcast_last_axis) { - mode = "broadcast_last_axis"; - // In this case, must come from binary vec, so must have two src - CHECK(src_info_list.size() >= 2) << "Broadcast last axis mode must have at least two srcs."; - if (!IsTwoItemEqual(src_info_list[0]->var_, dst_info_list[0]->var_, -1)) { - scalar_info = src_info_list[0]; - src_info_list.Set(0, src_info_list[1]); - } else if (!IsTwoItemEqual(src_info_list[1]->var_, dst_info_list[0]->var_, -1)) { - scalar_info = src_info_list[1]; - } - } else { - if (mode == "broadcast" && !src_info_list.empty() && dst_info_list.size() == 1) { - if (!IsTwoItemEqual(src_info_list[0]->var_, dst_info_list[0]->var_, -1)) { - mode = "broadcast_last_axis"; + Array call_args; + int call_cnt = 0; + if (intrin_name == "vector_dup" || intrin_name == "vadds" || + intrin_name == "vmuls" || intrin_name == "vaxpy") { + auto GetCallInfo = [&intrin_name, &call_args, &call_cnt](const NodeRef &op) { + if (op.as() && op.as()->name == intrin_name) { + call_args = op.as()->args; + call_cnt = call_cnt + 1; } - if (src_info_list.size() > 1) { - if (!IsTwoItemEqual(src_info_list[1]->var_, dst_info_list[0]->var_, -1)) { - mode = "broadcast_last_axis"; - } else { - scalar_info = src_info_list[0]; - src_info_list.Set(0, src_info_list[1]); - } - } - } - } - - if (broadcast_last_axis) { - mode = 
"broadcast_last_axis"; + }; + PostOrderVisit(op, GetCallInfo); + CHECK_EQ(call_cnt, 1); } - - if (intrin_name == INTRIN_NAME_VECTOR_DUP) { - auto dst_info = dst_info_list[0]; - if (dst_info->var_.size() > 1 && - GetIntConst(GetItem(dst_info->strides_, -1)) == GetIntConst(GetItem(dst_info->shape_, -1)) + 1) { - // diagnoal broadcast case - return op; - } - dst_info.CleanFlexVar(); + SingleType insn_type {SingleType::SIMD}; + Expr scalar_src {}; + if (intrin_name == "vector_dup") { + insn_type = SingleType::Vector_Dump; + src_info_list = {}; + scalar_src = call_args[0]; + } else if (intrin_name == "vadds" || intrin_name == "vmuls" || intrin_name == "vaxpy") { + insn_type = SingleType::Tensor_Scalar; + src_info_list = {src_info_list[0]}; + scalar_src = call_args[1]; } // check is single vector broadcast reduce mode exist - SingleVecPatternGenerator generator = SingleVecPatternGenerator(dst_info_list, src_info_list, for_info, mode); + SingleVecPatternGenerator generator = SingleVecPatternGenerator(dst_info_list, src_info_list, for_info); auto params = generator.GetInsnArgs(); dst_info_list = params.dst_info_list; src_info_list = params.src_info_list; for_info = params.for_info; ArgInfo arg_info = params.arg_info; - CommentManager::GetInstance().AddComment("Compute_type", mode); + CommentManager::GetInstance().AddComment("Compute_type", intrin_name); CommentManager::GetInstance().AddComment("Pattern", arg_info.GetPattern()); - if (intrin_name == "vadds" || intrin_name == "vmuls" || intrin_name == INTRIN_NAME_VECTOR_DUP) { - auto stores = GetStores(op); - auto store = stores[0].as(); - auto scalar = Expr(0); - if (intrin_name == "vadds" || intrin_name == "vmuls") { - if (!dst_info_list.empty()) { - scalar = FloatImm::make(dst_info_list[0]->dtype_, 0.000000); - } - if (!dst_info_list[0]->dtype_.is_float()) { - return op; - } - if (!is_dma_opt) { - if (!scalar_info.defined()) { - auto children = GetBinaryOpExprChildren(store->value); - if (children.empty()) { - LOG(FATAL) 
<< store->value << " is not binary op."; - } - scalar = children[1]; - } else { - scalar = Load::make(scalar_info->dtype_, scalar_info->data_, scalar_info->index_, Expr(1)); - } - } - } else if (intrin_name == INTRIN_NAME_VECTOR_DUP) { - if (store->value->IsInstance()) { - // scale is load - scalar = - Load::make(src_info_list[0]->dtype_, store->value.as()->buffer_var, src_info_list[0]->index_, Expr(1)); - } else { - // scale is imm - scalar = store->value; - } - } - - if (arg_info->body_arg_info_.defined()) { - arg_info->body_arg_info_.GetNode()->scalar_ = scalar; - } - if (arg_info->tail_arg_info_.defined()) { - arg_info->tail_arg_info_.GetNode()->scalar_ = scalar; - } - } - if (intrin_name == "vconv_deq") { result = InsertBody( result, Evaluate::make(Call::make(Float(16), "set_deqscale", {FloatImm::make(Float(16), 1.0)}, Call::Extern))); } SingleVecInsnBuilder single_vec_builder = - SingleVecInsnBuilder(dst_info_list[0], src_info_list[0], arg_info, intrin_name); + SingleVecInsnBuilder(dst_info_list[0], src_info_list[0], arg_info, intrin_name, scalar_src, insn_type); auto insn_list = single_vec_builder.EmitIntrin(); - - if (intrin_name == INTRIN_NAME_VECTOR_DUP && dst_info_list[0]->var_.empty()) { - Stmt store; - auto ScanStore = [&store](const NodeRef &op) { - const auto e = op.as(); - if (e != nullptr) { - store = Store::make(e->buffer_var, e->value, e->index, e->predicate); - } - }; - air::ir::PostOrderVisit(op, ScanStore); - store = EmitSetVecMaskIntrin(store, dst_info_list[0]->dtype_); - insn_list = {store}; - } - - return FoldInsnWithForInfo(insn_list, if_info, for_info, result); + auto ret = FoldInsnWithForInfo(insn_list, if_info, for_info, result); + return ret; } /// Function to emit binary vector intrin @@ -211,11 +134,6 @@ Stmt BinaryVecEmitter(const Stmt &op, std::string intrin_name, bool enable_bisec CommentManager::GetInstance().AddComment("Insn_name", intrin_name); switch (arg_info->arg_type_) { - case ARG_VECTOR_BROADCAST_LAST_AXIS: { - 
CommentManager::GetInstance().CleanComments(); - intrin_name += "s"; - return SingleVecEmitter(op, intrin_name, true); - } case ARG_VECTOR_REDUCTION_LAST_AXIS: { CommentManager::GetInstance().AddComment("Compute_type", "reduce_last_axis"); auto dst_info = dst_info_list[0]; @@ -928,83 +846,8 @@ Stmt DmaMovEmitter(const Stmt &op, bool enable_cover_protect) { StmtInfo for_info; GetDmaComputationInfo(op, dst_info_list, src_info_list, if_info, for_info, dma_mode, intrin_name); - auto check_alignment = [](const Expr &align, const Array &shape) { - if (GetIntConst(align) == 1 || shape.size() == 1u) { - return true; - } - - if (shape.empty()) { - return false; - } - Expr sz = 1; - for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { - sz = sz * shape[i]; - if (GetIntConst(align) == GetIntConst(sz)) { - return true; - } - } - return false; - }; - const auto &dst_info = dst_info_list[0]; const auto &src_info = src_info_list[0]; - int block_size = GetUbBlkSize(dst_info->dtype_); - - // check scalar to scalar - // check if dst is considered as scalar - // check if src is considered as scalar - bool is_broadcast = - (dst_info->var_.empty() || (!dst_info->strides_.empty() && GetIntConst(GetItem(dst_info->strides_, -1)) != 1)) && - (src_info->var_.empty() || (!src_info->strides_.empty() && GetIntConst(GetItem(src_info->strides_, -1)) != 1)); - // check vector to vector, but in scalar dma mode - bool last_dim_equal = !dst_info->var_.empty() && !src_info->var_.empty() && !dst_info->strides_.empty() && - !src_info->strides_.empty() && - GetItem(dst_info->var_, -1).get() == GetItem(src_info->var_, -1).get() && - GetIntConst(GetItem(dst_info->strides_, -1)) != GetIntConst(GetItem(src_info->strides_, -1)); - bool broadcast_scalar = intrin_name == "broadcast" && is_broadcast; - bool ubuf_scalar = intrin_name == INTRIN_NAME_COPY_UB_TO_UB && (is_broadcast || last_dim_equal); - - if (broadcast_scalar || ubuf_scalar) { - int shape1 = GetInt32Const(GetItem(dst_info->shape_, -1)); - int 
stride1 = GetInt32Const(GetItem(dst_info->strides_, -1)); - if (ubuf_scalar && shape1 < block_size && stride1 == block_size && - IsTwoItemEqual(dst_info->strides_, src_info->strides_, -1, true) && src_info->dtype_.bits() != 64) { - // if last dim small than blocksize, then use vadds - return SingleVecEmitter(op, intrin_name); - } - CommentManager::GetInstance().AddComment("Insn_type", "dma_copy"); - CommentManager::GetInstance().AddComment("Insn_name", "scalar"); - if (src_info->var_.empty() && dst_info->var_.empty()) { - return op; - } else { - // check align - if (!check_alignment(dst_info->data_alignment_, dst_info->shape_)) { - return op; - } - Stmt base_stmt = EmitScalarDmaIntrinTemplate(op, src_info, dst_info); - return GenIfAndFor(base_stmt, if_info, for_info, false); - } - } - - if (intrin_name == "broadcast") { - return SingleVecEmitter(op, INTRIN_NAME_VECTOR_DUP); - } else if (intrin_name == INTRIN_NAME_COPY_UB_TO_UB) { - // Use vadds to optimize dma copy - if (if_info.vars_.empty() && dst_info->dtype_.is_float() && src_info->dtype_.is_float()) { - if ((dst_info->dtype_.bits() == 32 && src_info->dtype_.bits() == 32) || - (dst_info->dtype_.bits() == 16 && src_info->dtype_.bits() == 16)) { - int repeat_len = block_size * FULL_BLOCK_NUM; - CHECK_NE(block_size, 0); - int shape1 = GetInt32Const(GetItem(dst_info->shape_, -1)); - if ((shape1 >= repeat_len / 2 && shape1 <= repeat_len) || - (dst_info->shape_.size() >= 3 && shape1 <= block_size) || - (dst_info->shape_.size() >= 2 && shape1 % block_size == 0)) { - // if last dim shape is too small, there is no need to opt - return SingleVecEmitter(op, intrin_name); - } - } - } - } CommentManager::GetInstance().AddComment("Insn_type", "dma_copy"); @@ -1014,31 +857,10 @@ Stmt DmaMovEmitter(const Stmt &op, bool enable_cover_protect) { Map ub_copy_post; auto arg_info_map = GetDmaCopyInsnArgs(intrin_name, dst_info_list, src_info_list, for_info, ub_copy_pre, ub_copy_post); - if (intrin_name == "vtranspose_scalar") { - 
base_stmt = EmitScalarDmaIntrinTemplate(op, src_info, dst_info); - CommentManager::GetInstance().AddComment("Insn_name", "scalar"); - } else if (intrin_name == "vtranspose") { - Array args = {arg_info_map["loop_width"], arg_info_map["loop_height"], arg_info_map["shape_width"]}; - Array pre_ub_copy_args; - if (!ub_copy_pre.empty()) { - pre_ub_copy_args = Array( - {ub_copy_pre["nBurst"], ub_copy_pre["lenBurst"], ub_copy_pre["srcStride"], ub_copy_pre["dstStride"]}); - } - Array post_ub_copy_args; - if (!ub_copy_post.empty()) { - post_ub_copy_args = Array( - {ub_copy_post["nBurst"], ub_copy_post["lenBurst"], ub_copy_post["srcStride"], ub_copy_post["dstStride"]}); - } - TransposeInsnBuilder builder = - TransposeInsnBuilder(dst_info, src_info, args, pre_ub_copy_args, post_ub_copy_args); - base_stmt = builder.EmitSingleIntrin(); - CommentManager::GetInstance().AddComment("Insn_name", intrin_name); - } else { - DmaInsnBuilder dma_builder = - DmaInsnBuilder(dst_info, src_info, intrin_name, arg_info_map, false, false, enable_cover_protect); - base_stmt = dma_builder.EmitSingleIntrin(); - CommentManager::GetInstance().AddComment("Insn_name", intrin_name); - } + DmaInsnBuilder dma_builder = + DmaInsnBuilder(dst_info, src_info, intrin_name, arg_info_map, false, false, enable_cover_protect); + base_stmt = dma_builder.EmitSingleIntrin(); + CommentManager::GetInstance().AddComment("Insn_name", intrin_name); } else if (dma_mode == "cce_load") { auto arg_info_map = GetDmaLoad2DInsnArgs(intrin_name, dst_info_list, src_info_list, for_info); DmaInsnBuilder builder = DmaInsnBuilder(dst_info, src_info, intrin_name, arg_info_map, true); @@ -1104,6 +926,19 @@ Stmt DmaAtomicAddEmitter(const Stmt &op) { return stmt; } +Stmt VTransposeEmitter(const Stmt &op) { + StmtInfoList dst_info_list; + StmtInfoList src_info_list; + StmtInfo for_info; + StmtInfo if_info; + GetCompactComputationInfo(op, dst_info_list, src_info_list, if_info, for_info, true, true); + auto dst_buffer_id = 
GenBufferId(dst_info_list[0]); + auto src_buffer_id = GenBufferId(src_info_list[0]); + auto dst = GetAccessPtr(dst_buffer_id, "w", 0); + auto src = GetAccessPtr(src_buffer_id, "r", 0); + return Evaluate::make(Call::make(Float(16), "vtranspose", {dst, src}, Call::Extern)); +} + /// Function to emit dropout intrin /// \param op - The input stmt to be emitted as intrin /// \return Stmt of emitted CCE intrin @@ -1913,97 +1748,6 @@ Stmt ReduceCombineEmitter(const Stmt &op, bool enable_bisect) { Stmt InsnEmit(std::string insn_name, const Stmt &op, bool enable_bisect, bool enable_cover_protect, int comment_level) { CHECK(op.defined()); - static const std::map ReplaceAttrPragmaMap = { - // vector binary - {"binary_vcadd", "vec_binary_add"}, - {"vaxpy", "vec_binary_axpy"}, - // vector single - {"vec_single_fabs", "vec_single_abs"}, - {"broadcast", "vec_broadcast"}, - // cube - {"mad", "cube_mad"}, - {"ub2gm", "cube_ub2gm"}, - {"im2col", "cube_img2col"}, - // special attrs - {"vec_binary_proposal_sort", "vec_proposal_sort"}, - {"vec_binary_topk_sort", "vec_topk_sort"}, - {"vec_binary_dropout", "vec_dropout"}, - {"vec_binary_fargmax", "vec_argmax"}, - {"vec_binary_fargmin", "vec_argmin"}, - {"vec_binary_iou", "vec_iou"}, - {"vec_binary_nms", "vec_nms"}, - {"mask_broadcast", "vec_broadcast"}, - }; - - static const std::map BinaryVecInsnMap = { - // vadd.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 - // vadd.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 - // vadd.f32 support target:mini_v100 cloud_v100 - // vadd contains two situations: - // 1. normal elewise vector add - // - all src[i].shape = dst.shape - // 2. 
reductive vector add - // - exist src[i].shape != dst.shape - {"vec_binary_add", "vadd"}, - // vsub.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 - // vsub.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 - // vsub.f32 support target:mini_v100 cloud_v100 - {"vec_binary_sub", "vsub"}, - // vmul.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 - // vmul.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 - // vmul.f32 support target:mini_v100 cloud_v100 - {"vec_binary_mul", "vmul"}, - // vmin.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 - // vmin.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 - // vmin.f32 support target:mini_v100 cloud_v100 - {"vec_binary_min", "vmin"}, - // vmax.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 - // vmax.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 - // vmax.f32 support target:mini_v100 cloud_v100 - {"vec_binary_max", "vmax"}, - {"vec_binary_div", "vdiv"}, - {"vec_binary_and", "vand"}, - {"vec_binary_bitwise_and", "vand"}, - {"vec_binary_or", "vor"}, - {"vec_binary_bitwise_or", "vor"}, - {"vec_binary_vmadd", "vmadd"}, - {"vec_binary_vmaddrelu", "vmaddrelu"}, - {"vec_binary_vmla", "vmla"}}; - - static const std::map SingleVecInsnMap = { - // vmuls.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 - // vmuls.f32 supporttarget:mini_v100 cloud_v100 - {"vec_single_muls", "vmuls"}, - // vadds.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 - // vadds.f32 support target:mini_v100 cloud_v100 - {"vec_single_adds", "vadds"}, - // vrelu.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 - {"vec_single_relu", "vrelu"}, - // vabs.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 - // vabs.f32 support target:mini_v100 cloud_v100 - {"vec_single_abs", "vabs"}, - // vln.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 - // vln.f32 support target:cloud_v100 - {"vec_single_log", "vln"}, - // vexp.f16 
support target:mini_v100 tiny_v100 lite_v100 cloud_v100 - // vexp.f32 support target:cloud_v100 - {"vec_single_exp", "vexp"}, - // vrec.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 - // vrec.f32 support target:mini_v100 cloud_v100 - {"vec_single_rec", "vrec"}, - // vnot support target:mini_v100 tiny_v100 lite_v100 cloud_v100 - {"vec_single_not", "vnot"}, - {"vec_single_bitwise_not", "vnot"}, - // vsqrt support target:cloud_v100 - {"vec_single_sqrt", "vsqrt"}, - {"vec_single_rsqrt", "vrsqrt"}, - {"vec_broadcast", "vector_dup"}}; - - static const std::map SingleCastInsnMap = { - {"vec_single_floor", "f"}, {"vec_single_round", "r"}, {"vec_single_ceil", "c"}, {"vec_single_trunc", "z"}}; - - static const std::set ReturnOpInsnSet = {"scalar_dma", "scatter", "vec_binary_select_loop_var"}; - static const std::map> InsnFunctorMap = { {"dma_atomic_add", DmaAtomicAddEmitter}, {"vec_single_cast", SingleCastEmitter}, @@ -2017,9 +1761,9 @@ Stmt InsnEmit(std::string insn_name, const Stmt &op, bool enable_bisect, bool en {"vec_dropout", BinaryDropoutEmitter}, {"cube_mad", MadEmitter}, {"vec_select_scalar", SelectWithScalarEmitter}, - {"vec_binary_axpy", VaxpyEmitter}, {"opt_broadcast", MultiMaskEmitter}, - {"vec_single_four2five_nchw", VnchwconvEmitter}}; + {"vec_single_four2five_nchw", VnchwconvEmitter}, + {"vtranspose", VTransposeEmitter}}; if (ReplaceAttrPragmaMap.count(insn_name) != 0) { insn_name = ReplaceAttrPragmaMap.find(insn_name)->second; diff --git a/src/emit_insn/insn_emitter.h b/src/emit_insn/insn_emitter.h index 91dd5c7ea029d21d728d178ee14ffe098083b429..8e49d0e0194d4a80bb5a063e4f6ad5031709aa69 100644 --- a/src/emit_insn/insn_emitter.h +++ b/src/emit_insn/insn_emitter.h @@ -30,6 +30,100 @@ namespace akg { namespace ir { + static const std::map ReplaceAttrPragmaMap = { + // vector binary + {"binary_vcadd", "vec_binary_add"}, + // vector single + {"vec_single_fabs", "vec_single_abs"}, + {"broadcast", "vec_broadcast"}, + // cube + {"mad", "cube_mad"}, + 
{"ub2gm", "cube_ub2gm"}, + {"im2col", "cube_img2col"}, + // special attrs + {"vec_binary_proposal_sort", "vec_proposal_sort"}, + {"vec_binary_topk_sort", "vec_topk_sort"}, + {"vec_binary_dropout", "vec_dropout"}, + {"vec_binary_fargmax", "vec_argmax"}, + {"vec_binary_fargmin", "vec_argmin"}, + {"vec_binary_iou", "vec_iou"}, + {"vec_binary_nms", "vec_nms"}, + {"mask_broadcast", "vec_broadcast"}, + }; + + static const std::map BinaryVecInsnMap = { + // vadd.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 + // vadd.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 + // vadd.f32 support target:mini_v100 cloud_v100 + // vadd contains two situations: + // 1. normal elewise vector add + // - all src[i].shape = dst.shape + // 2. reductive vector add + // - exist src[i].shape != dst.shape + {"vec_binary_add", "vadd"}, + // vsub.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 + // vsub.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 + // vsub.f32 support target:mini_v100 cloud_v100 + {"vec_binary_sub", "vsub"}, + // vmul.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 + // vmul.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 + // vmul.f32 support target:mini_v100 cloud_v100 + {"vec_binary_mul", "vmul"}, + // vmin.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 + // vmin.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 + // vmin.f32 support target:mini_v100 cloud_v100 + {"vec_binary_min", "vmin"}, + // vmax.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 + // vmax.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 + // vmax.f32 support target:mini_v100 cloud_v100 + {"vec_binary_max", "vmax"}, + {"vec_binary_div", "vdiv"}, + {"vec_binary_and", "vand"}, + {"vec_binary_bitwise_and", "vand"}, + {"vec_binary_or", "vor"}, + {"vec_binary_bitwise_or", "vor"}, + {"vec_binary_vmadd", "vmadd"}, + {"vec_binary_vmaddrelu", "vmaddrelu"}, + {"vec_binary_vmla", "vmla"}}; + + 
static const std::map SingleVecInsnMap = { + // vmuls.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 + // vmuls.f32 supporttarget:mini_v100 cloud_v100 + {"vec_single_muls", "vmuls"}, + // vadds.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 + // vadds.f32 support target:mini_v100 cloud_v100 + {"vec_single_adds", "vadds"}, + // vrelu.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 + {"vec_single_relu", "vrelu"}, + // vabs.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 + // vabs.f32 support target:mini_v100 cloud_v100 + {"vec_single_abs", "vabs"}, + // vln.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 + // vln.f32 support target:cloud_v100 + {"vec_single_log", "vln"}, + // vexp.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 + // vexp.f32 support target:cloud_v100 + {"vec_single_exp", "vexp"}, + // vrec.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100 + // vrec.f32 support target:mini_v100 cloud_v100 + {"vec_single_rec", "vrec"}, + // vnot support target:mini_v100 tiny_v100 lite_v100 cloud_v100 + {"vec_single_not", "vnot"}, + {"vec_single_bitwise_not", "vnot"}, + // vsqrt support target:cloud_v100 + {"vec_single_sqrt", "vsqrt"}, + {"vec_single_rsqrt", "vrsqrt"}, + {"vaxpy", "vaxpy"}, + {"vec_broadcast", "vector_dup"}, + {"vadds", "vadds"}, + {"vmuls", "vmuls"}, + {"vector_dup", "vector_dup"}, + }; + + static const std::map SingleCastInsnMap = { + {"vec_single_floor", "f"}, {"vec_single_round", "r"}, {"vec_single_ceil", "c"}, {"vec_single_trunc", "z"}}; + + static const std::set ReturnOpInsnSet = {"scalar_calc", "scalar_dma", "scatter", "vec_binary_select_loop_var"}; Stmt EmitInsnWithDynamicShapes(const Stmt &s, const Map &extern_buffer); diff --git a/src/emit_insn/insn_info.cc b/src/emit_insn/insn_info.cc index 4d17e7dfc13b9acf07fc41e382b1ee4443609388..c11470fad6a41f309f449b1ac47ca5a64d97e350 100644 --- a/src/emit_insn/insn_info.cc +++ b/src/emit_insn/insn_info.cc @@ -935,7 +935,7 
@@ void GetCompactComputationInfo(const Stmt &stmt, StmtInfoList &dst_info_list, St /// \param if_info - The if-condition as input /// \param for_info - The for-loop info to be modified void CompactComputationInfoList(StmtInfoList &dst_info_list, StmtInfoList &src_info_list, const StmtInfo &if_info, - StmtInfo &for_info) { + StmtInfo &for_info) { auto MergeTwoVar = [](const Var &keep_var, const Var &delete_var, StmtInfoList &dst_info_list, StmtInfoList &src_info_list, StmtInfo &for_info) { for (auto info : dst_info_list) { @@ -1059,8 +1059,7 @@ void CompactComputationInfoList(StmtInfoList &dst_info_list, StmtInfoList &src_i bool find_merge = false; for (size_t i = 0; (i < var_cnt - 1) && (!find_merge); i++) { for (size_t j = i + 1; j < var_cnt; j++) { - if (CanMergeTwoVar(for_info.vars_[i], for_info.vars_[j], dst_info_list, src_info_list, - for_info)) { + if (CanMergeTwoVar(for_info.vars_[i], for_info.vars_[j], dst_info_list, src_info_list, for_info)) { find_merge = true; break; } @@ -1075,7 +1074,6 @@ void CompactComputationInfoList(StmtInfoList &dst_info_list, StmtInfoList &src_i } } - /// A helper function for single dst_info's compact /// \param dst_info /// \param src_info_list @@ -1357,6 +1355,43 @@ int GetVectorizedVarPosition(const Expr &index, Array &loop_vars) { return pos; } +std::string GetOpType(const Expr &value) { + if (value.as()) { + return value.as()->_type_key; + } + if (value.as()) { + return value.as()->_type_key; + } + if (value.as()) { + return value.as()->_type_key; + } + if (value.as
()) { + return value.as
()->_type_key; + } + if (value.as()) { + return value.as()->_type_key; + } + if (value.as()) { + return value.as()->_type_key; + } + if (value.as()) { + return value.as()->_type_key; + } + if (value.as()) { + return value.as()->_type_key; + } + if (value.as()) { + return value.as()->_type_key; + } + if (value.as()) { + return value.as()->name; + } + if (value.as() || value.as() || value.as()) { + return "DMACopy"; + } + return "undefined"; +} + /// TVM Function Register, enable python code to call these cpp function. TVM_REGISTER_API("cce_util.GetCceAxis").set_body([](TVMArgs args, TVMRetValue *ret) { *ret = GetCceAxis(); }); diff --git a/src/emit_insn/insn_info.h b/src/emit_insn/insn_info.h index ccb299ecf0ab60b5b6c5075e9dfda6a1265c069d..5eca1814edecec281003899bf1998d8611490a74 100644 --- a/src/emit_insn/insn_info.h +++ b/src/emit_insn/insn_info.h @@ -49,13 +49,7 @@ enum ArgType { ARG_NOT_DEFINE }; -enum PatternType { - PATTERN_3D = 1, - PATTERN_PARTIAL_3D, - PATTERN_2D, - PATTERN_2D_BLOCK, - PATTERN_1D -}; +enum PatternType { PATTERN_3D = 1, PATTERN_PARTIAL_3D, PATTERN_2D, PATTERN_2D_BLOCK, PATTERN_1D }; class StmtStoreInfoNode : public Node { public: @@ -98,13 +92,9 @@ class StmtStoreInfo : public NodeRef { explicit StmtStoreInfo(const ObjectPtr &n) : NodeRef(n), node_(n) {} ~StmtStoreInfo() = default; - inline StmtStoreInfoNode *GetNode() const { - return static_cast(node_.get()); - } + inline StmtStoreInfoNode *GetNode() const { return static_cast(node_.get()); } - inline const StmtStoreInfoNode *operator->() const { - return static_cast(node_.get()); - } + inline const StmtStoreInfoNode *operator->() const { return static_cast(node_.get()); } void CleanFlexVar(); @@ -188,13 +178,9 @@ class VectorArgInfo : public NodeRef { explicit VectorArgInfo(const ObjectPtr &n) : NodeRef(n), node_(n) {} ~VectorArgInfo() = default; - inline VectorArgInfoNode *GetNode() const { - return static_cast(node_.get()); - } + inline VectorArgInfoNode *GetNode() const { return 
static_cast(node_.get()); } - inline const VectorArgInfoNode *operator->() const { - return static_cast(node_.get()); - } + inline const VectorArgInfoNode *operator->() const { return static_cast(node_.get()); } void Print() const { LOG(DEBUG) << "[ body_num: " << GetNode()->body_num_ << ", body_offset: " << GetNode()->body_offset_ @@ -235,13 +221,9 @@ class ArgInfo : public NodeRef { explicit ArgInfo(const ObjectPtr &n) : NodeRef(n), node_(n) {} ~ArgInfo() = default; - inline ArgInfoNode *GetNode() const { - return static_cast(node_.get()); - } + inline ArgInfoNode *GetNode() const { return static_cast(node_.get()); } - inline const ArgInfoNode *operator->() const { - return static_cast(node_.get()); - } + inline const ArgInfoNode *operator->() const { return static_cast(node_.get()); } inline std::string GetPattern() const { switch (GetNode()->pattern_) { @@ -373,6 +355,8 @@ bool IsBisectionReduction(const StmtInfoList &dst_info_list, const StmtInfoList bool HasVars(const Expr &index, const Var &vec_var); int GetVectorizedVarPosition(const Expr &index, Array &loop_vars); + +std::string GetOpType(const Expr &value); } // namespace akg namespace air { diff --git a/src/emit_insn/insn_pattern.h b/src/emit_insn/insn_pattern.h index dec4102e5ec3012b4ba4e5eb7e24d1d1445241f9..66b6c9ffe3c004c4fe5f6123cc5787f0f8f5edec 100644 --- a/src/emit_insn/insn_pattern.h +++ b/src/emit_insn/insn_pattern.h @@ -77,7 +77,7 @@ class PatternGenerator { class SingleVecPatternGenerator : public PatternGenerator { public: SingleVecPatternGenerator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list, - const StmtInfo &for_info, const std::string &mode) + const StmtInfo &for_info, const std::string &mode = "elewise") : PatternGenerator(dst_info_list, for_info), arg_info(ArgInfo(make_node())), body_args(VectorArgInfo()), diff --git a/src/emit_insn/insn_with_variable.cc b/src/emit_insn/insn_with_variable.cc index 
163cb29dbaa72564d5590e525a362133357b4b3c..01150ff3e3663bd4016ed177caf820effd210c70 100644 --- a/src/emit_insn/insn_with_variable.cc +++ b/src/emit_insn/insn_with_variable.cc @@ -33,9 +33,11 @@ #include "insn_info.h" #include "insn_pattern.h" #include "insn_emitter.h" +#include "ir_transform.h" namespace akg { namespace ir { + Expr GetVarCoefExpr(const Expr &index, const Var &loop_var) { Expr ret = Expr(); Array coefs = air::arith::DetectLinearEquation(index, {loop_var}); @@ -203,7 +205,7 @@ class HasScalarVarValue : public IRVisitor { class AdjustPragma : public IRMutator { public: Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { - if (air::ir::attr::IsPragmaKey(op->attr_key) && op->attr_key == "pragma_emit_insn" && op->value.as()) { + if (op->attr_key == "pragma_emit_insn" && op->value.as()) { is_candidate_ = true; loop_vars_ = {}; loop_extends_ = {}; @@ -295,7 +297,7 @@ class AdjustPragma : public IRMutator { Array srcs = call_ptr->args; CHECK_EQ(srcs.size(), 2); is_argmax_min_ = true; - reduce_type_ = (op->value.as()->name == "fargmin") ? "arg_min" : "arg_max"; + reduce_type_ = (op->value.as()->name == "fargmin") ? 
"reduce_fargmin" : "reduce_fargmax"; return Store::make(op->buffer_var, Call::make(call_ptr->type, reduce_type_, {srcs[1]}, Call::CallType::Extern), op->index, op->predicate); } else if ((op->value.as() || op->value.as() || op->value.as()) && @@ -484,353 +486,6 @@ class AdjustPragma : public IRMutator { Array transpose_vars_; }; -class TransposeTransform : public IRMutator { - public: - Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { - if (air::ir::attr::IsPragmaKey(op->attr_key) && op->attr_key == "pragma_emit_insn" && op->value.as() && - op->value.as()->value == "dma_copy") { - pre_transpose_buffer = Var("srcTranspose_local_UB"); - post_transpose_buffer = Var("dstTranspose_local_UB"); - loop_vars_ = {}; - loop_extends_ = {}; - is_candidate_ = true; - is_block_transpose_ = false; - auto body = this->Mutate(op->body); - is_candidate_ = false; - if (is_block_transpose_) { - is_block_transpose_ = false; - auto allocate_pre_buffer = Allocate::make(pre_transpose_buffer, t_type, {TransTotalSize}, const_true(1), body); - auto attr_pre_buffer = - AttrStmt::make(pre_transpose_buffer, "storage_scope", Expr("local.UB"), allocate_pre_buffer); - auto allocate_post_buffer = - Allocate::make(post_transpose_buffer, t_type, {TransTotalSize}, const_true(1), attr_pre_buffer); - auto attr_post_buffer = - AttrStmt::make(post_transpose_buffer, "storage_scope", Expr("local.UB"), allocate_post_buffer); - return attr_post_buffer; - } else { - return AttrStmt::make(op->node, op->attr_key, op->value, body); - } - } else { - return IRMutator::Mutate_(op, s); - } - } - - Stmt Mutate_(const For *op, const Stmt &s) final { - if (is_candidate_) { - loop_vars_.push_back(op->loop_var); - loop_extends_.push_back(op->extent); - Stmt body = this->Mutate(op->body); - if (is_block_transpose_ && IsInArray(trans_vars_, op->loop_var)) { - return body; - } else { - return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body); - } - } - return IRMutator::Mutate_(op, 
s); - } - - Stmt Mutate_(const Store *op, const Stmt &s) final { - if (is_candidate_) { - auto value = op->value; - if (auto cast = op->value.as()) { - value = cast->value; - } - CHECK(value.as()); - auto src_ptr = value.as(); - if (GetBufferType(op->buffer_var) == SCOPE_UBUF && GetBufferType(src_ptr->buffer_var) == SCOPE_UBUF) { - int dst_pos = GetVectorizedVarPosition(op->index, loop_vars_); - int src_pos = GetVectorizedVarPosition(src_ptr->index, loop_vars_); - if (dst_pos != -1 && src_pos != -1 && dst_pos != src_pos && - floormod(loop_extends_[dst_pos], TransAxisLen).as() && - floormod(loop_extends_[dst_pos], TransAxisLen).as()->value == 0 && - Equal(GetVarCoefExpr(op->index, loop_vars_[src_pos]), loop_extends_[dst_pos])) { - if (loop_extends_[dst_pos].as() && loop_extends_[dst_pos].as()->value == TransAxisLen && - loop_extends_[src_pos].as() && loop_extends_[src_pos].as()->value == TransAxisLen) { - return s; - } else { - is_block_transpose_ = true; - t_type = src_ptr->type; - trans_vars_ = {}; - trans_vars_.push_back(loop_vars_[src_pos]); - trans_vars_.push_back(loop_vars_[dst_pos]); - Expr ori_w = GetVarCoefExpr(src_ptr->index, loop_vars_[dst_pos]); - Expr ori_h = loop_extends_[dst_pos]; - Expr ori_block_w = floordiv(ori_w, TransAxisLen); - Expr ori_block_h = floordiv(ori_h, TransAxisLen); - Var loop_w = Var("block_w"); - Var loop_h = Var("block_h"); - Expr src_base_index = EliminateVarInExpr(src_ptr->index, trans_vars_); - Expr dst_base_index = EliminateVarInExpr(op->index, trans_vars_); - - Var tt0 = Var("tt0"); - Var tt1 = Var("tt1"); - auto pre_copy = Store::make( - pre_transpose_buffer, - Load::make(t_type, src_ptr->buffer_var, - src_base_index + loop_h * TransAxisLen * ori_w + loop_w * TransAxisLen + tt1 * ori_w + tt0, 1), - tt1 * TransAxisLen + tt0, 1); - auto pre_l0 = For::make(tt0, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, pre_copy); - auto pre_l1 = For::make(tt1, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, pre_l0); - auto pre_attr = 
AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy"), pre_l1); - - auto transpose = - Store::make(post_transpose_buffer, Load::make(t_type, pre_transpose_buffer, tt1 * TransAxisLen + tt0, 1), - tt0 * 16 + tt1, 1); - auto trans_l0 = For::make(tt0, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, transpose); - auto trans_l1 = For::make(tt1, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, trans_l0); - auto trans_attr = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy"), trans_l1); - - auto post_copy = Store::make( - op->buffer_var, Load::make(t_type, post_transpose_buffer, tt1 * TransAxisLen + tt0, 1), - dst_base_index + loop_w * TransAxisLen * ori_h + loop_h * TransAxisLen + tt1 * ori_h + tt0, 1); - auto post_l0 = For::make(tt0, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, post_copy); - auto post_l1 = For::make(tt1, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, post_l0); - auto post_attr = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy"), post_l1); - - auto full_inner = Block::make(Block::make(pre_attr, trans_attr), post_attr); - auto inner_w = For::make(loop_w, 0, ori_block_w, ForType::Serial, DeviceAPI::None, full_inner); - auto inner_h = For::make(loop_h, 0, ori_block_h, ForType::Serial, DeviceAPI::None, inner_w); - return inner_h; - } - } - } - } - return s; - } - - bool is_candidate_{false}; - bool is_block_transpose_{false}; - Array trans_vars_; - Array loop_vars_; - Array loop_extends_; - Type t_type; - Var pre_transpose_buffer; - Var post_transpose_buffer; -}; - -class IfReorder : public IRMutator { - public: - Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { - if (air::ir::attr::IsPragmaKey(op->attr_key) && op->attr_key == "pragma_emit_insn" && op->value.as() && - op->value.as()->value != "mad") { - in_insn_ = true; - for_vars_.clear(); - if_vars_.clear(); - for_vec_.clear(); - if_vec_.clear(); - auto body = this->Mutate(op->body); - in_insn_ = false; - if 
(!if_vec_.empty()) { - Stmt new_s = AttrStmt::make(op->node, op->attr_key, op->value, body); - for (auto if_op : if_vec_) { - new_s = IfThenElse::make(if_op->condition, new_s); - } - - for (auto for_op = for_vec_.rbegin(); for_op != for_vec_.rend(); ++for_op) { - bool find_flag = false; - for (auto for_iter = for_vars_.begin(); for_iter != for_vars_.end(); ++for_iter) { - if (Equal((*for_iter), (*for_op)->loop_var)) { - find_flag = true; - break; - } - } - if (find_flag) { - new_s = For::make((*for_op)->loop_var, (*for_op)->min, (*for_op)->extent, ForType::Serial, DeviceAPI::None, - new_s); - } - } - return new_s; - } else { - return s; - } - } - return IRMutator::Mutate_(op, s); - } - - Stmt Mutate_(const For *op, const Stmt &s) final { - if (in_insn_) { - for_vec_.push_back(op); - for_vars_.push_back(op->loop_var); - Stmt body = this->Mutate(op->body); - std::vector::iterator for_iter; - for (for_iter = for_vars_.begin(); for_iter != for_vars_.end(); ++for_iter) { - if (Equal((*for_iter), op->loop_var)) { - break; - } - } - - if (!if_vec_.empty()) { - std::vector::iterator if_iter; - bool find_flag = false; - for (if_iter = if_vars_.begin(); if_iter != if_vars_.end(); ++if_iter) { - if (Equal((*if_iter), op->loop_var)) { - find_flag = true; - break; - } - } - if (find_flag) { - return body; - } else { - for_vars_.erase(for_iter); - return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body); - } - } else { - for_vars_.erase(for_iter); - return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body); - } - } - return IRMutator::Mutate_(op, s); - } - - Stmt Mutate_(const IfThenElse *op, const Stmt &s) final { - if (in_insn_) { - if_vec_.push_back(op); - for (auto loop_var : for_vars_) { - if (HasVars(op->condition, loop_var)) { - if_vars_.push_back(loop_var); - } - } - Stmt body = this->Mutate(op->then_case); - return body; - } - return IRMutator::Mutate_(op, s); - } - - Stmt Mutate_(const Store *op, 
const Stmt &s) final { - if (in_insn_) { - return s; - } - return IRMutator::Mutate_(op, s); - } - - bool in_insn_{false}; - std::vector if_vec_; - std::vector if_vars_; - std::vector for_vars_; - std::vector for_vec_; - std::vector before_if_; -}; - -class LoopReorder : public IRMutator { - Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { - if (air::ir::attr::IsPragmaKey(op->attr_key) && op->attr_key == "pragma_emit_insn" && op->value.as()) { - in_insn_ = true; - pragma = op->value.as()->value; - for_map_.clear(); - ori_vars_ = {}; - var_order_.clear(); - auto ret = this->Mutate(op->body); - in_insn_ = false; - if (!has_changed_) { - return s; - } else { - if (var_order_.empty()) { - ret = AttrStmt::make(op->node, op->attr_key, op->value, ret); - for (size_t i = 0; i < ori_vars_.size(); ++i) { - CHECK_GT(for_map_.count(ori_vars_[i].get()), 0); - auto ptr = for_map_[ori_vars_[i].get()]; - ret = For::make(ptr->loop_var, ptr->min, ptr->extent, ptr->for_type, ptr->device_api, ret); - } - } else { - for (size_t i = 0; i < var_order_.size(); ++i) { - CHECK_GT(for_map_.count(var_order_[i].get()), 0); - auto ptr = for_map_[var_order_[i].get()]; - ret = For::make(ptr->loop_var, ptr->min, ptr->extent, ptr->for_type, ptr->device_api, ret); - } - ret = AttrStmt::make(op->node, op->attr_key, op->value, ret); - } - return ret; - } - } - return IRMutator::Mutate_(op, s); - } - - Stmt Mutate_(const For *op, const Stmt &s) final { - if (in_insn_) { - for_map_[(op->loop_var).get()] = op; - ori_vars_.push_back(op->loop_var); - auto body = this->Mutate(op->body); - return body; - } else { - return IRMutator::Mutate_(op, s); - } - } - - Stmt Mutate_(const Store *op, const Stmt &s) final { - int dst_pos = GetVectorizedVarPosition(op->index, ori_vars_); - int len = static_cast(ori_vars_.size()); - - std::vector srcs; - auto get_loads = [&srcs](const NodeRef &node) { - if (const auto v = node.as()) { - srcs.push_back(v); - } - }; - PostOrderVisit(op->value, get_loads); - - bool 
same_pos = true; - std::vector srcs_pos; - for (int i = 0; i < static_cast(srcs.size()); ++i) { - int temp_pos = GetVectorizedVarPosition(srcs[i]->index, ori_vars_); - srcs_pos.push_back(temp_pos); - if (temp_pos != dst_pos) { - same_pos = false; - } - } - - has_changed_ = false; - if (dst_pos >= 0 && len >= 2 && dst_pos != (len - 1) && (same_pos || pragma == "broadcast")) { - // Src Load empty; all Load and Dst has the same key axis; broadcast - has_changed_ = true; - var_order_.push_back(ori_vars_[dst_pos]); - for (int i = len - 1; i >= 0; i--) { - if (i != dst_pos) { - var_order_.push_back(ori_vars_[i]); - } - } - } else if (pragma.find("reduce") != pragma.npos && len >= 2 && srcs_pos[0] != (len - 1)) { - // based on dst key axis: reduce - has_changed_ = true; - var_order_.push_back(ori_vars_[srcs_pos[0]]); - for (int i = len - 1; i >= 0; i--) { - if (i != srcs_pos[0]) { - var_order_.push_back(ori_vars_[i]); - } - } - } - - return s; - } - - std::unordered_map for_map_; - std::vector var_order_; - Array ori_vars_; - bool has_changed_{false}; - bool in_insn_{false}; - std::string pragma; -}; - -class ForVarUnique : public IRMutator { - public: - Stmt Mutate_(const For *op, const Stmt &s) final { - auto body = this->Mutate(op->body); - if (var_maps_.count(op->loop_var.get())) { - Var new_var = Var("ii" + std::to_string(++index_)); - std::unordered_map value_map; - value_map[op->loop_var.get()] = new_var; - auto new_body = Substitute(body, value_map); - var_maps_[new_var.get()] = 1; - return For::make(new_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, new_body); - } else { - var_maps_[op->loop_var.get()] = 1; - return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body); - } - } - - std::unordered_map var_maps_; - int index_{0}; -}; - class GenSIMD { public: GenSIMD(CCEInfo &t_info, Map &buffer_map, const std::string &pragma) @@ -1520,9 +1175,9 @@ class GenReduce { ~GenReduce() = default; Stmt Run(int pre_index) { - 
is_arg_type_ = (pragma_ == "arg_max" || pragma_ == "arg_min"); + is_arg_type_ = (pragma_ == "reduce_fargmax" || pragma_ == "reduce_fargmin"); RemoveVectorizedIndex(t_info_, 0); - if (pragma_.find("sum") != std::string::npos) { + if (pragma_.find("sum") != std::string::npos || pragma_.find("add") != std::string::npos) { insn_intrinsic_ = "vcadd"; expansion_factor_ = 1; } else if (pragma_.find("max") != std::string::npos) { @@ -1769,7 +1424,7 @@ class EmitVariableInsns : public IRMutator { } Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { - if (air::ir::attr::IsPragmaKey(op->attr_key) && op->attr_key == "pragma_emit_insn") { + if (op->attr_key == "pragma_emit_insn") { CHECK(op->value.as()); pragma = op->value.as()->value; Stmt r; @@ -1791,8 +1446,7 @@ class EmitVariableInsns : public IRMutator { if (!r.same_as(s)) { return r; } - } else if (air::ir::attr::IsPragmaKey(op->attr_key) && - (op->attr_key == "pragma_im2col" || op->attr_key == "pragma_load3d")) { + } else if (op->attr_key == "pragma_im2col" || op->attr_key == "pragma_load3d") { if (paramters_.defined() && Downcast>(paramters_).count("feature")) { auto feature = Downcast>(paramters_)["feature"].as(); CHECK(feature); @@ -1842,13 +1496,13 @@ class EmitVariableInsns : public IRMutator { if (pragma.find("vec_select") != std::string::npos) { EmitSelect(op, t_info); - } else if (pragma.find("dma_copy") == 0) { + } else if (pragma.find("dma_copy") != std::string::npos) { EmitDMA(t_info); - } else if (pragma.find("vec_binary") == 0 || pragma.find("vec_single") == 0) { + } else if (pragma.find("vec_binary") != std::string::npos || pragma.find("vec_single") != std::string::npos) { EmitSIMD(t_info); - } else if (pragma.find("reduce") == 0 || pragma.find("arg_") == 0) { + } else if (pragma.find("reduce") != std::string::npos || pragma.find("arg_") != std::string::npos) { EmitReduce(t_info); - } else if (pragma.find("broadcast") == 0) { + } else if (pragma.find("broadcast") != std::string::npos) { if 
(loops_vars_.empty()) { gen_cce = t_info.ori_stmt; } else { diff --git a/src/emit_insn/insn_with_variable.h b/src/emit_insn/insn_with_variable.h index 6a3c3bf2b444bdc3a448adfa402db39ff511c3bf..55326037f647eaf345a11db36df699933af01d42 100644 --- a/src/emit_insn/insn_with_variable.h +++ b/src/emit_insn/insn_with_variable.h @@ -31,8 +31,7 @@ namespace akg { namespace ir { -const int TransTotalSize = 256; -const int TransAxisLen = 16; + const int64_t FullReduceMaskValue = 6148914691236517205; class CCEInsn { diff --git a/src/emit_insn/ir_transform.h b/src/emit_insn/ir_transform.h new file mode 100644 index 0000000000000000000000000000000000000000..c274c24c6cc8bf3549e2dd6c95021556dcd2db7e --- /dev/null +++ b/src/emit_insn/ir_transform.h @@ -0,0 +1,481 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef IR_TRANSFORM_H_ +#define IR_TRANSFORM_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "ir_pass.h" +#include "common/array_api.h" + +#include "insn_with_variable.h" +#include "insn_builder.h" +#include "insn_info.h" +#include "insn_pattern.h" +#include "../pass/analyze_align.h" + +const int TransTotalSize = 256; +const int TransAxisLen = 16; + +namespace akg { +namespace ir { + +Expr GetVarCoefExpr(const Expr &index, const Var &loop_var); + +std::string GetBufferType(Expr address); + +class TransposeTransform : public IRMutator { + public: + Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { + if (op->attr_key == "pragma_emit_insn" && op->value.as() && + (op->value.as()->value == "dma_copy")) { + pre_transpose_buffer_ = Var("srcTranspose_local_UB"); + post_transpose_buffer_ = Var("dstTranspose_local_UB"); + pre_trans_cast_ = Var("pre_trans_cast__local_UB"); + post_trans_cast_ = Var("post_trans_cast__local_UB"); + loop_vars_ = {}; + loop_extends_ = {}; + is_candidate_ = true; + is_block_transpose_ = false; + is_native_transpose_ = false; + align_value = FREE_ALIGN; + remain_fors_.clear(); + auto body = this->Mutate(op->body); + is_candidate_ = false; + if (is_block_transpose_) { + is_block_transpose_ = false; + if (t_type_ == Float(32)) { // need cast + body = Allocate::make(pre_trans_cast_, Float(16), {TransTotalSize}, const_true(1), body); + body = AttrStmt::make(pre_trans_cast_, "storage_scope", Expr("local.UB"), body); + body = Allocate::make(post_trans_cast_, Float(16), {TransTotalSize}, const_true(1), body); + body = AttrStmt::make(post_trans_cast_, "storage_scope", Expr("local.UB"), body); + } + auto allocate_pre_buffer = + Allocate::make(pre_transpose_buffer_, t_type_, {TransTotalSize}, const_true(1), body); + auto attr_pre_buffer = + AttrStmt::make(pre_transpose_buffer_, "storage_scope", Expr("local.UB"), allocate_pre_buffer); + auto allocate_post_buffer = + 
Allocate::make(post_transpose_buffer_, t_type_, {TransTotalSize}, const_true(1), attr_pre_buffer); + auto attr_post_buffer = + AttrStmt::make(post_transpose_buffer_, "storage_scope", Expr("local.UB"), allocate_post_buffer); + Stmt ret = attr_post_buffer; + if (align_value != FREE_ALIGN) { + ret = AttrStmt::make(align_buffer_, "align_info", Expr(align_value), ret); + } + return ret; + } + if (is_native_transpose_) { + Stmt ret = AttrStmt::make(op->node, op->attr_key, Expr("dma_copy_transpose"), body); + for (int i = 0; i <= static_cast(remain_fors_.size()) - 1; ++i) { + ret = For::make(remain_fors_[i]->loop_var, remain_fors_[i]->min, remain_fors_[i]->extent, ForType::Serial, + DeviceAPI::None, ret); + } + return ret; + } + return AttrStmt::make(op->node, op->attr_key, op->value, body); + } + return IRMutator::Mutate_(op, s); + } + + Stmt Mutate_(const For *op, const Stmt &s) final { + if (is_candidate_) { + loop_vars_.push_back(op->loop_var); + loop_extends_.push_back(op->extent); + Stmt body = this->Mutate(op->body); + if (is_block_transpose_ && IsInArray(trans_vars_, op->loop_var)) { + return body; + } + if (is_native_transpose_) { + if (IsInArray(trans_vars_, op->loop_var)) { + return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body); + } + remain_fors_.push_back(op); + return body; + } + return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body); + } + return IRMutator::Mutate_(op, s); + } + + Stmt Mutate_(const Store *op, const Stmt &s) final { + if (is_candidate_) { + auto value = op->value; + if (auto cast = op->value.as()) { + value = cast->value; + } + CHECK(value.as()); + auto src_ptr = value.as(); + if (GetBufferType(op->buffer_var) == SCOPE_UBUF && GetBufferType(src_ptr->buffer_var) == SCOPE_UBUF && + src_ptr->type == Float(16)) { + int dst_pos = GetVectorizedVarPosition(op->index, loop_vars_); + int src_pos = GetVectorizedVarPosition(src_ptr->index, loop_vars_); + if (dst_pos != -1 && 
src_pos != -1 && dst_pos != src_pos && HasVars(src_ptr->index, loop_vars_[dst_pos]) && + HasVars(op->index, loop_vars_[src_pos]) && floormod(loop_extends_[dst_pos], TransAxisLen).as() && + floormod(loop_extends_[dst_pos], TransAxisLen).as()->value == 0 && + Equal(GetVarCoefExpr(op->index, loop_vars_[src_pos]), loop_extends_[dst_pos])) { + if (loop_extends_[dst_pos].as() && loop_extends_[dst_pos].as()->value == TransAxisLen && + loop_extends_[src_pos].as() && loop_extends_[src_pos].as()->value == TransAxisLen) { + trans_vars_ = {}; + trans_vars_.push_back(loop_vars_[src_pos]); + trans_vars_.push_back(loop_vars_[dst_pos]); + is_native_transpose_ = true; + return s; + } + is_block_transpose_ = true; + if (GetVarCoefExpr(src_ptr->index, loop_vars_[dst_pos]).as()) { + int coef_t = GetVarCoefExpr(src_ptr->index, loop_vars_[dst_pos]).as()->value; + if (coef_t % TransAxisLen != 0) { + align_value = coef_t; + align_buffer_ = src_ptr->buffer_var; + } + } + t_type_ = src_ptr->type; + trans_vars_ = {}; + trans_vars_.push_back(loop_vars_[src_pos]); + trans_vars_.push_back(loop_vars_[dst_pos]); + Expr ori_w = GetVarCoefExpr(src_ptr->index, loop_vars_[dst_pos]); + Expr ori_h = loop_extends_[dst_pos]; + Expr ori_block_w = floordiv(ori_w, TransAxisLen); + // padding the width + Expr unit_width = TransAxisLen; + if (!Equal(floormod(ori_w, TransAxisLen), 0)) { + ori_block_w = ori_block_w + 1; + } + if (ori_w.as() && ori_w.as()->value < TransAxisLen) { + unit_width = ori_w; + } + Expr ori_block_h = floordiv(ori_h, TransAxisLen); + Var loop_w = Var("block_w"); + Var loop_h = Var("block_h"); + Expr src_base_index = EliminateVarInExpr(src_ptr->index, trans_vars_); + Expr dst_base_index = EliminateVarInExpr(op->index, trans_vars_); + Var tt0 = Var("tt0"); + Var tt1 = Var("tt1"); + auto pre_copy = Store::make( + pre_transpose_buffer_, + Load::make(t_type_, src_ptr->buffer_var, + src_base_index + loop_h * TransAxisLen * ori_w + loop_w * TransAxisLen + tt1 * ori_w + tt0, 1), + tt1 * 
TransAxisLen + tt0, 1); + auto pre_l0 = For::make(tt0, 0, unit_width, ForType::Serial, DeviceAPI::None, pre_copy); + auto pre_l1 = For::make(tt1, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, pre_l0); + auto pre_attr = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy"), pre_l1); + Stmt trans_attr = Stmt(); + if (t_type_ == Float(16)) { + auto transpose = + Store::make(post_transpose_buffer_, + Load::make(t_type_, pre_transpose_buffer_, tt1 * TransAxisLen + tt0, 1), tt0 * 16 + tt1, 1); + auto trans_l0 = For::make(tt0, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, transpose); + auto trans_l1 = For::make(tt1, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, trans_l0); + trans_attr = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy_transpose"), trans_l1); + } else { + auto pre_cast_store = Store::make( + pre_trans_cast_, Cast::make(Float(16), Load::make(t_type_, pre_transpose_buffer_, tt0, 1)), tt0, 1); + auto pre_cast_for = For::make(tt0, 0, TransTotalSize, ForType::Serial, DeviceAPI::None, pre_cast_store); + auto pre_cast_attr = + AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("vec_single_cast"), pre_cast_for); + + auto transpose = Store::make( + post_trans_cast_, Load::make(Float(16), pre_trans_cast_, tt1 * TransAxisLen + tt0, 1), tt0 * 16 + tt1, 1); + auto trans_l0 = For::make(tt0, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, transpose); + auto trans_l1 = For::make(tt1, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, trans_l0); + auto trans_block = + AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy_transpose"), trans_l1); + + auto post_cast_store = Store::make( + post_transpose_buffer_, Cast::make(t_type_, Load::make(Float(16), post_trans_cast_, tt0, 1)), tt0, 1); + auto post_cast_for = For::make(tt0, 0, TransTotalSize, ForType::Serial, DeviceAPI::None, post_cast_store); + auto post_cast_attr = + AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", 
Expr("vec_single_cast"), post_cast_for); + + trans_attr = Block::make(Block::make(pre_cast_attr, trans_block), post_cast_attr); + } + auto post_copy = + Store::make(op->buffer_var, Load::make(t_type_, post_transpose_buffer_, tt1 * TransAxisLen + tt0, 1), + dst_base_index + loop_w * TransAxisLen * ori_h + loop_h * TransAxisLen + tt1 * ori_h + tt0, 1); + auto post_l0 = For::make(tt0, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, post_copy); + auto post_l1 = For::make(tt1, 0, unit_width, ForType::Serial, DeviceAPI::None, post_l0); + auto post_attr = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy"), post_l1); + auto full_inner = Block::make(Block::make(pre_attr, trans_attr), post_attr); + auto inner_w = For::make(loop_w, 0, ori_block_w, ForType::Serial, DeviceAPI::None, full_inner); + if (ori_block_w.as() && ori_block_w.as()->value == 1) { + std::unordered_map init; + init[loop_w.get()] = 0; + inner_w = Simplify(Substitute(full_inner, init)); + } + auto inner_h = For::make(loop_h, 0, ori_block_h, ForType::Serial, DeviceAPI::None, inner_w); + if (ori_block_h.as() && ori_block_h.as()->value == 1) { + std::unordered_map init; + init[loop_h.get()] = 0; + inner_h = Simplify(Substitute(inner_w, init)); + } + return inner_h; + } + } + } + return s; + } + + private: + bool is_candidate_{false}; + bool is_native_transpose_{false}; + bool is_block_transpose_{false}; + int align_value{FREE_ALIGN}; + Var align_buffer_; + Array trans_vars_; + Array loop_vars_; + Array loop_extends_; + std::vector remain_fors_; + Type t_type_; + Var pre_transpose_buffer_; + Var pre_trans_cast_; + Var post_trans_cast_; + Var post_transpose_buffer_; +}; + +class ForVarUnique : public IRMutator { + public: + Stmt Mutate_(const For *op, const Stmt &s) final { + auto body = this->Mutate(op->body); + if (var_maps_.count(op->loop_var.get())) { + Var new_var = Var("ii" + std::to_string(++index_)); + std::unordered_map value_map; + value_map[op->loop_var.get()] = new_var; + auto 
new_body = Substitute(body, value_map); + var_maps_[new_var.get()] = 1; + return For::make(new_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, new_body); + } + var_maps_[op->loop_var.get()] = 1; + return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body); + } + + private: + std::unordered_map var_maps_; + int index_{0}; +}; + +class LoopReorder : public IRMutator { + public: + Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { + if (op->attr_key == "pragma_emit_insn" && op->value.as()) { + in_insn_ = true; + pragma_ = op->value.as()->value; + for_map_.clear(); + ori_vars_ = {}; + var_order_.clear(); + auto ret = this->Mutate(op->body); + in_insn_ = false; + if (!has_changed_) { + return s; + } + if (var_order_.empty()) { + ret = AttrStmt::make(op->node, op->attr_key, op->value, ret); + for (size_t i = 0; i < ori_vars_.size(); ++i) { + CHECK_GT(for_map_.count(ori_vars_[i].get()), 0); + auto ptr = for_map_[ori_vars_[i].get()]; + ret = For::make(ptr->loop_var, ptr->min, ptr->extent, ptr->for_type, ptr->device_api, ret); + } + return ret; + } + for (size_t i = 0; i < var_order_.size(); ++i) { + CHECK_GT(for_map_.count(var_order_[i].get()), 0); + auto ptr = for_map_[var_order_[i].get()]; + ret = For::make(ptr->loop_var, ptr->min, ptr->extent, ptr->for_type, ptr->device_api, ret); + } + ret = AttrStmt::make(op->node, op->attr_key, op->value, ret); + return ret; + } + return IRMutator::Mutate_(op, s); + } + + Stmt Mutate_(const For *op, const Stmt &s) final { + if (in_insn_) { + for_map_[(op->loop_var).get()] = op; + ori_vars_.push_back(op->loop_var); + auto body = this->Mutate(op->body); + return body; + } else { + return IRMutator::Mutate_(op, s); + } + } + + Stmt Mutate_(const Store *op, const Stmt &s) final { + int dst_pos = GetVectorizedVarPosition(op->index, ori_vars_); + int len = static_cast(ori_vars_.size()); + + std::vector srcs; + auto get_loads = [&srcs](const NodeRef &node) { + if (const auto v = node.as()) { 
+ srcs.push_back(v); + } + }; + PostOrderVisit(op->value, get_loads); + + bool same_pos = true; + std::vector srcs_pos; + for (int i = 0; i < static_cast(srcs.size()); ++i) { + int temp_pos = GetVectorizedVarPosition(srcs[i]->index, ori_vars_); + srcs_pos.push_back(temp_pos); + if (temp_pos != dst_pos) { + same_pos = false; + } + } + + has_changed_ = false; + if (dst_pos >= 0 && len >= 2 && dst_pos != (len - 1) && (same_pos || pragma_ == "broadcast")) { + // Src Load empty; all Load and Dst has the same key axis; broadcast + has_changed_ = true; + var_order_.push_back(ori_vars_[dst_pos]); + for (int i = len - 1; i >= 0; i--) { + if (i != dst_pos) { + var_order_.push_back(ori_vars_[i]); + } + } + } else if (pragma_.find("reduce") != pragma_.npos && len >= 2 && srcs_pos[0] != (len - 1)) { + // based on dst key axis: reduce + has_changed_ = true; + var_order_.push_back(ori_vars_[srcs_pos[0]]); + for (int i = len - 1; i >= 0; i--) { + if (i != srcs_pos[0]) { + var_order_.push_back(ori_vars_[i]); + } + } + } + return s; + } + + private: + std::unordered_map for_map_; + std::vector var_order_; + Array ori_vars_; + bool has_changed_{false}; + bool in_insn_{false}; + std::string pragma_; +}; + +class IfReorder : public IRMutator { + public: + Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { + if (op->attr_key == "pragma_emit_insn" && op->value.as() && + !exclude_align_analyze_list.count(op->value.as()->value)) { + in_insn_ = true; + for_vars_.clear(); + if_vars_.clear(); + for_vec_.clear(); + if_vec_.clear(); + auto body = this->Mutate(op->body); + in_insn_ = false; + if (!if_vec_.empty()) { + Stmt new_s = AttrStmt::make(op->node, op->attr_key, op->value, body); + for (auto if_op : if_vec_) { + new_s = IfThenElse::make(if_op->condition, new_s); + } + for (auto for_op = for_vec_.rbegin(); for_op != for_vec_.rend(); ++for_op) { + bool find_flag = false; + for (auto for_iter = for_vars_.begin(); for_iter != for_vars_.end(); ++for_iter) { + if (Equal((*for_iter), 
(*for_op)->loop_var)) { + find_flag = true; + break; + } + } + if (find_flag) { + new_s = For::make((*for_op)->loop_var, (*for_op)->min, (*for_op)->extent, ForType::Serial, DeviceAPI::None, + new_s); + } + } + return new_s; + } + return s; + } + return IRMutator::Mutate_(op, s); + } + + Stmt Mutate_(const For *op, const Stmt &s) final { + if (in_insn_) { + for_vec_.push_back(op); + for_vars_.push_back(op->loop_var); + Stmt body = this->Mutate(op->body); + std::vector::iterator for_iter; + for (for_iter = for_vars_.begin(); for_iter != for_vars_.end(); ++for_iter) { + if (Equal((*for_iter), op->loop_var)) { + break; + } + } + + if (!if_vec_.empty()) { + std::vector::iterator if_iter; + bool find_flag = false; + for (if_iter = if_vars_.begin(); if_iter != if_vars_.end(); ++if_iter) { + if (Equal((*if_iter), op->loop_var)) { + find_flag = true; + break; + } + } + if (find_flag) { + return body; + } + for_vars_.erase(for_iter); + return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body); + } + for_vars_.erase(for_iter); + return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body); + } + return IRMutator::Mutate_(op, s); + } + + Stmt Mutate_(const IfThenElse *op, const Stmt &s) final { + if (in_insn_) { + if_vec_.push_back(op); + for (auto loop_var : for_vars_) { + if (HasVars(op->condition, loop_var)) { + if_vars_.push_back(loop_var); + } + } + Stmt body = this->Mutate(op->then_case); + return body; + } + return IRMutator::Mutate_(op, s); + } + + Stmt Mutate_(const Store *op, const Stmt &s) final { + if (in_insn_) { + return s; + } + return IRMutator::Mutate_(op, s); + } + + private: + bool in_insn_{false}; + std::vector if_vec_; + std::vector if_vars_; + std::vector for_vars_; + std::vector for_vec_; + std::vector before_if_; +}; + +} // namespace ir +} // namespace akg + +#endif // IR_TRANSFORM_H_ \ No newline at end of file diff --git a/src/include/ir_pass.h b/src/include/ir_pass.h index 
e3ca5ed3cd24ce4f03348ab5fa60e8ce8af62615..c5bad4a3a1c3762f0ebccbfdc3753fa9d5005f75 100644 --- a/src/include/ir_pass.h +++ b/src/include/ir_pass.h @@ -265,6 +265,10 @@ Stmt RewriteBroadcastVector(Stmt stmt); Stmt OptimizePragma(Stmt stmt); +Stmt PackStore(Stmt stmt); + +Stmt RecoverStore(Stmt stmt); + Stmt RewriteByAlignDynamic(Stmt stmt); Stmt EliminateAtomicDma(Stmt stmt); diff --git a/src/pass/analyze_align.h b/src/pass/analyze_align.h index 372f40aba560c0c8b693a3dd0e7467bcafad863f..65b345c0c5beb0f99f3c4662c1a2d3a45c5b2c02 100644 --- a/src/pass/analyze_align.h +++ b/src/pass/analyze_align.h @@ -21,14 +21,18 @@ #include #include +#include +#include #include "pass/utils.h" #include "arith_expr_simplify.h" #include "expr_alg_simplify.h" +#include "emit_insn/cce_params.h" +#include "common/array_api.h" namespace akg { namespace ir { -const std::set exclude_list = { +const std::set exclude_align_analyze_list = { "mad", "scatter", "vec_binary_proposal_sort", @@ -39,8 +43,15 @@ const std::set exclude_list = { "vec_single_four2five_nchw", "opt_broadcast", "reduce_reorder", + "dma_atomic_add", + "dma_copy_transpose", }; -class IndexOptimizer : public air::ir::IRMutator { + +const std::set exclude_index_fix_list = { + "mad", "vec_binary_proposal_sort", "vec_binary_topk_sort", "vec_binary_nms", "vec_binary_iou", "vec_binary_dropout", +}; + +class IndexOptimizer : public IRMutator { public: explicit IndexOptimizer(bool rm = false) : var2expr(), rm_load_(rm) {} ~IndexOptimizer() override = default; @@ -71,6 +82,745 @@ class IndexOptimizer : public air::ir::IRMutator { private: bool rm_load_; }; + +int GetCommonDivisor(std::vector numbers); + +class IndexInfo { + public: + Array vars; + Array coefs; + Array extents; + int divisor; + int vec_len{-1}; + Var vec_var{}; + Expr offset; + Expr index; + bool is_serial{true}; + bool is_scalar{true}; +}; + +class DstInfo : public IndexInfo { + public: + bool IsGlobal() { return (GetBufScope(p_store->buffer_var->name_hint) == 
DMA_COPY_GLOBAL); } + bool IsUB() { return (GetBufScope(p_store->buffer_var->name_hint) == SCOPE_UBUF); } + const Store *p_store; +}; + +class SrcInfo : public IndexInfo { + public: + bool IsGlobal() { return (GetBufScope(p_load->buffer_var->name_hint) == DMA_COPY_GLOBAL); } + bool IsUB() { return (GetBufScope(p_load->buffer_var->name_hint) == SCOPE_UBUF); } + const Load *p_load; + bool is_imm; + Expr imm; +}; + +class ArithInfo { + public: + Stmt GenIR() { return store; } + + void GetIntrinsicType(Array &for_vars, Array &if_vars) { + if (for_vars.empty()) { + if (TryScalarType()) { + insn_type = "scalar"; + } else { + insn_type = "discrete"; + } + return; + } + if (TryScalarAssignType(if_vars)) { + insn_type = "scalar"; + return; + } + if (TryReduceType()) { + insn_type = "reduce"; + return; + } + auto simd_t = TrySIMDType(); + if (simd_t == 1) { + insn_type = "simd"; + return; + } else if (simd_t == 2) { + insn_type = "simd_split"; + return; + } + if (TryVectorScalarType()) { + insn_type = "vector_scalar"; + return; + } + if (TryVectorDumpType()) { + insn_type = "vector_dump"; + return; + } + if (TryCrossingType()) { + insn_type = "crossing"; + return; + } + if (TryDiscrete()) { + insn_type = "discrete"; + return; + } + if (insn_type == "unknown") { + CHECK(0) << "\nUnknown Intrinsic Type"; + } + } + + // A[0] = B[1] + C[2] + bool TryScalarType() { + if (dst_info.IsUB() && dst_info.p_store->value.as() && src_info[0].IsGlobal()) { + return true; + } + if (!is_const(dst_info.index)) { + return false; + } + for (auto info : src_info) { + if (!is_const(info.index)) { + return false; + } + } + return true; + } + + // for i { for j { A[i] = reduce(C[X*i + j]) } } + bool TryReduceType() { + if (dst_info.p_store->value.as()) { + auto t_call = dst_info.p_store->value.as(); + if (t_call->name.find("reduce_") != std::string::npos) { + return true; + } + } + return false; + } + + // for i { for j { A[X*i + j] = B[X*i + j] + C[j] } } + int TrySIMDType() { + Var cur_var = 
dst_info.vec_var; + int cur_len = dst_info.vec_len; + Expr cur_offset = dst_info.offset; + int block_size = GetUbBlkSize(dst_info.p_store->value.type()); + bool is_simd = (cur_len >= 1) ? true : false; + for (auto info : src_info) { + if (info.vec_len != cur_len || !Equal(info.vec_var, cur_var)) { + is_simd = false; + break; + } + } + + bool need_split = false; + if (is_simd) { + for (auto info : src_info) { + int info_block_size = GetUbBlkSize(info.p_load->type); + if (dst_info.IsUB() && info.IsUB()) { + if (is_const(cur_offset) && is_const(info.offset) && + cur_offset.as()->value % block_size != info.offset.as()->value % info_block_size) { + need_split = true; + break; + } + } + } + } + + if (is_simd && need_split) { + if (src_info.size() == 1) { + if (dst_info.divisor != 0 && src_info[0].divisor != 0 && + dst_info.offset.as()->value % dst_info.divisor != + src_info[0].offset.as()->value % src_info[0].divisor) { + dst_info.divisor = air::ir::gcd(dst_info.divisor, dst_info.offset.as()->value); + src_info[0].divisor = air::ir::gcd(src_info[0].divisor, src_info[0].offset.as()->value); + auto min_dst_src = std::min(dst_info.divisor, src_info[0].divisor); + dst_info.divisor = min_dst_src; + src_info[0].divisor = min_dst_src; + } + } else { + CHECK(0) << "\nNeed to split the vector var to make the offset equal or scalar computing\n"; + } + } + + bool unaligned_divisor = false; + if (is_simd) { + if (dst_info.IsUB()) { + if (dst_info.divisor != 0 && dst_info.divisor < cur_len) { + dst_info.divisor = air::ir::gcd(dst_info.divisor, cur_len); + unaligned_divisor = true; + } + } + for (auto info : src_info) { + if (info.IsUB()) { + if (info.divisor != 0 && info.divisor < cur_len) { + unaligned_divisor = true; + int temp_divisor = air::ir::gcd(info.divisor, cur_len); + dst_info.divisor = air::ir::gcd(dst_info.divisor, temp_divisor); + } + } + } + } + if (is_simd && !need_split && !unaligned_divisor) { + return 1; + } + if (is_simd && (need_split || unaligned_divisor)) { + 
return 2; + } + return 0; + } + + // for i { for j { A[X*i + j] = B[X*i + j] + C[Z*i] } } + bool TryVectorScalarType() { + if (src_info.size() != 2) { + return false; + } + if (dst_info.is_serial && Equal(dst_info.vec_var, src_info[0].vec_var) && + !HasVars(src_info[1].index, dst_info.vec_var) && + (!src_info[1].is_serial || !Equal(dst_info.vec_var, src_info[1].vec_var))) { + scalar_load = src_info[1]; + src_info.pop_back(); + return true; + } + if (dst_info.is_serial && Equal(dst_info.vec_var, src_info[1].vec_var) && + !HasVars(src_info[0].index, dst_info.vec_var) && + (!src_info[0].is_serial || !Equal(dst_info.vec_var, src_info[0].vec_var))) { + scalar_load = src_info[0]; + src_info.erase(src_info.begin()); + return true; + } + return false; + } + + // for i { for j { A[X*i + j] = C[Z*i] } } + bool TryVectorDumpType() { + if (src_info.size() != 1) { + return false; + } + if (GetBufScope(dst_info.p_store->buffer_var->name_hint) == SCOPE_UBUF && + GetBufScope(src_info[0].p_load->buffer_var->name_hint) == SCOPE_UBUF && dst_info.is_serial && + !HasVars(src_info[0].index, dst_info.vec_var) && + (!src_info[0].is_serial || !Equal(dst_info.vec_var, src_info[0].vec_var))) { + scalar_load = src_info[0]; + src_info.pop_back(); + return true; + } + return false; + } + + bool TryScalarAssignType(Array &if_vars) { + if (dst_info.IsUB() && dst_info.is_serial && src_info.size() == 1 && dst_info.p_store->value.as() && + src_info[0].IsUB()) { + bool not_simd_or_dump = HasVars(src_info[0].index, dst_info.vec_var) && + (!src_info[0].is_serial || !Equal(dst_info.vec_var, src_info[0].vec_var)); + bool in_if_vars = !if_vars.empty() && IsInArray(if_vars, dst_info.vec_var); + if (not_simd_or_dump || in_if_vars) { + return true; + } + } + return false; + } + + // for i { for j { A[X*i + j] = C[Y*j + i] } } + // for i { for j { A[X*i + j] = C[Y*j] } } + bool TryCrossingType() { + if (dst_info.is_serial && src_info.size() == 1 && HasVars(src_info[0].index, dst_info.vec_var) && + 
(!src_info[0].is_serial || !Equal(dst_info.vec_var, src_info[0].vec_var))) { + return true; + } + return false; + } + + // for i {for j { A[X*i + Y*j] = ....} } + bool TryDiscrete() { return !(dst_info.is_serial); } + + void GetVectorizedInfo() { + if (insn_type == "scalar") { + is_scalar = true; + return; + } + if (insn_type == "simd" || insn_type == "vector_scalar" || insn_type == "vector_dump") { + vec_len = dst_info.vec_len; + vec_var = dst_info.vec_var; + offset = dst_info.offset; + return; + } + if (insn_type == "simd_split") { + vec_len = dst_info.divisor; + offset = 0; + return; + } + if (insn_type == "reduce") { + vec_len = src_info[0].vec_len; + vec_var = src_info[0].vec_var; + offset = src_info[0].offset; + return; + } + if (insn_type == "crossing" || insn_type == "discrete") { + vec_len = 1; + if (dst_info.is_serial) { + dst_info.vec_len = 1; + dst_info.divisor = 1; + } + for (size_t i = 0; i < src_info.size(); i++) { + if (src_info[i].is_serial) { + src_info[i].divisor = 1; + src_info[i].vec_len = 1; + } + } + return; + } + CHECK(0) << "\ninsn_type is unknown\n"; + } + + DstInfo dst_info; + std::vector src_info; + int vec_len; + Var vec_var; + Expr offset; + bool is_scalar{false}; + Stmt store; + std::string op_type; + std::string insn_type{"unknown"}; + SrcInfo scalar_load; + Expr scalar_imm{Expr()}; + int scalar_imm_num{0}; +}; + +class IRIfInfo { + public: + Array conds; + Array vars; + Array ops; +}; + +class IRForInfo { + public: + Array vars; + std::vector exts; + Array ops; +}; + +class IRInfo { + public: + Stmt GenStmt() { + auto ret = GenIfAndFor(); + return ret; + } + + Stmt GenIfAndFor() { + auto core = arith_info.store; + if (for_info.vars.empty()) { + return core; + } + Stmt ret = core; + for (int i = static_cast(for_info.vars.size()) - 1; i >= 0; --i) { + ret = For::make(for_info.vars[i], 0, for_info.exts[i], ForType::Serial, DeviceAPI::None, ret); + } + return ret; + } + + bool ChangeLastDimReduce() { + if (arith_info.src_info.size() != 
2) { + return false; + } + + size_t i = 0; + for (i = 0; i < arith_info.src_info.size(); ++i) { + if (Equal(arith_info.src_info[i].p_load->buffer_var, arith_info.dst_info.p_store->buffer_var) && + Equal(arith_info.src_info[i].p_load->index, arith_info.dst_info.p_store->index)) { + break; + } + } + + if (i >= 2) { + return false; + } + + size_t index = 0; + if (!Equal(arith_info.src_info[1 - i].vec_var, arith_info.src_info[i].vec_var) && + GetIndexOfElement(for_info.vars, arith_info.src_info[1 - i].vec_var, index) && + !HasVars(arith_info.src_info[i].p_load->index, {arith_info.src_info[1 - i].vec_var})) { + SrcInfo t_src = arith_info.src_info[1 - i]; + arith_info.src_info.clear(); + arith_info.src_info.push_back(t_src); + arith_info.insn_type = "reduce_" + GetReduceType(); + Expr pack_value = + Call::make(t_src.p_load->type, arith_info.insn_type, {GetRef(t_src.p_load)}, Call::Extern); + arith_info.store = Store::make(arith_info.store.as()->buffer_var, pack_value, + arith_info.store.as()->index, arith_info.store.as()->predicate); + return true; + } + + return false; + } + + std::string GetReduceType() { + std::string ret = GetOpType(arith_info.dst_info.p_store->value); + std::transform(ret.begin(), ret.end(), ret.begin(), ::tolower); + return ret; + } + + IRIfInfo if_info; + IRForInfo for_info; + ArithInfo arith_info; +}; + +class ImmOffsetVisitor : public IRVisitor { + public: + int Run(const Expr &e) { + auto temp_index = Simplify(e); + IRVisitor::Visit(temp_index); + return imm_offset_; + } + + void Visit_(const Add *op) { + if (op->a.as()) { + imm_offset_ = op->a.as()->value; + } else if (op->b.as()) { + imm_offset_ = op->b.as()->value; + } else { + IRVisitor::Visit(op->b); + } + } + + bool in_add_flag_{false}; + int imm_offset_{0}; +}; + +class ParserVisitor : public IRVisitor { + public: + ParserVisitor(IRInfo &in, bool flag = false) : info(in), with_align(flag) {} + ~ParserVisitor() override = default; + + void Run(const Stmt &s) { + in_store = false; + 
IRVisitor::Visit(s); + if (with_align) { + GetInsnType(); + info.arith_info.GetVectorizedInfo(); + } + } + + void Visit_(const For *op) { + info.for_info.vars.push_back(op->loop_var); + info.for_info.exts.push_back(op->extent.as()->value); + info.for_info.ops.push_back(op->body); + IRVisitor::Visit(op->body); + } + + void Visit_(const IfThenElse *op) { + CHECK(!op->else_case.defined()); + info.if_info.conds.push_back(op->condition); + auto var_list = GetVarsInExpr(op->condition); + for (auto t_var : var_list) { + if (!IsInArray(info.if_info.vars, t_var)) { + info.if_info.vars.push_back(t_var); + } + } + info.if_info.ops.push_back(op->then_case); + IRVisitor::Visit(op->then_case); + } + + void Visit_(const Load *op) { + SrcInfo src_info; + src_info.index = op->index; + src_info.p_load = op; + GetIndexInfo(op->index, src_info); + info.arith_info.src_info.push_back(src_info); + } + + void Visit_(const FloatImm *op) { + if (in_store) { + info.arith_info.scalar_imm = GetRef(op); + ++info.arith_info.scalar_imm_num; + } + } + + void Visit_(const IntImm *op) { + if (in_store) { + info.arith_info.scalar_imm = GetRef(op); + ++info.arith_info.scalar_imm_num; + } + } + + void Visit_(const Store *op) { + info.arith_info.store = GetRef(op); + info.arith_info.op_type = GetOpType(op->value); + in_store = true; + IRVisitor::Visit(op->value); + in_store = false; + DstInfo dst_info; + dst_info.p_store = op; + dst_info.index = op->index; + GetIndexInfo(op->index, dst_info); + info.arith_info.dst_info = dst_info; + } + + void GetInsnType() { info.arith_info.GetIntrinsicType(info.for_info.vars, info.if_info.vars); } + + template + void GetIndexInfo(const Expr &e, T &t) { + bool is_serial = false; + int imm_offset = ImmOffsetVisitor().Run(e); + t.offset = imm_offset; + + std::vector nums; + bool is_linear_inner_for = true; + if (info.for_info.vars.empty()) { + t.is_scalar = true; + return; + } + for (size_t i = 0; i < info.for_info.vars.size(); i++) { + auto coef = 
air::arith::DetectLinearEquation(e, {info.for_info.vars[i]}); + if (!coef.empty() && !Equal(coef[0], 0)) { + t.vars.push_back(info.for_info.vars[i]); + t.coefs.push_back(coef[0].as()->value); + t.extents.push_back(info.for_info.exts[i]); + if (!Equal(coef[0], 1)) { + nums.push_back(coef[0].as()->value); + } else { + is_serial = true; + t.vec_var = info.for_info.vars[i]; + t.vec_len = info.for_info.exts[i]; + } + } else if (coef.empty()) { + is_linear_inner_for = false; + } + } + + if (is_linear_inner_for) { + if (nums.empty()) { + t.divisor = 0; + } else { + t.divisor = GetCommonDivisor(nums); + } + } else { + if (is_serial) { + Map value_map; + value_map.Set(t.vec_var, 0); + auto new_e = Simplify(Substitute(e, value_map)); + if (Equal(Simplify(Mod::make(new_e, t.vec_len)), 0)) { + t.divisor = t.vec_len; + } else { + t.divisor = 1; + } + } else { + t.divisor = 1; + } + } + t.is_serial = is_serial; + } + + private: + IRInfo &info; + bool with_align{false}; + bool in_store{false}; +}; + +class InsnTensor { + public: + InsnTensor(std::string name, Type type) : m_name(name), m_type(type) {} + virtual ~InsnTensor() {} + + void SetAlignment(int align) { m_alignment = align; } + int GetAlignment() { return m_alignment; } + Type GetType() { return m_type; } + + std::string m_name; + Type m_type; + int m_alignment{FREE_ALIGN}; +}; + +class UnifyAlignInfo { + public: + bool NeedPadding(int align, int block_size) { return (align > 0 && align % block_size != 0); } + + bool UnifyAlign() { + bool need_adjust = false; + int align = observers[0]->m_alignment; + int align_size = 32 / observers[0]->GetType().bytes(); + for (size_t i = 1; i < observers.size(); ++i) { + auto temp_align = observers[i]->m_alignment; + auto temp_block = 32 / observers[i]->GetType().bytes(); + if (align != temp_align && (NeedPadding(align, align_size) || NeedPadding(temp_align, temp_block))) { + need_adjust = true; + align = SpreadAlign(align, observers[i]->m_alignment, align_size, temp_block); + } + } + 
if (need_adjust) { + for (size_t i = 0; i < observers.size(); ++i) { + observers[i]->m_alignment = align; + } + } + return need_adjust; + } + + int SpreadAlign(int left, int right, int left_block, int right_block) { + if (left < 0 || left % left_block == 0) { + return right; + } + if (right < 0 || right % right_block == 0) { + return left; + } + return GetCommonDivisor({left, right}); + } + + std::vector observers; + std::vector divisors; + std::vector offsets; + int vector_len; +}; + +class AlignAttach : public IRMutator { + public: + AlignAttach(std::map &in_map) : m_map_(in_map) {} + + Stmt Mutate_(const Store *op, const Stmt &s) { + auto value = this->Mutate(op->value); + int align = 1; + if (m_map_.count(op->buffer_var.get())) { + align = m_map_[op->buffer_var.get()]->m_alignment; + } + return Store::make(op->buffer_var, value, op->index, align); + } + + Expr Mutate_(const Load *op, const Expr &e) { + int align = 1; + if (m_map_.count(op->buffer_var.get())) { + align = m_map_[op->buffer_var.get()]->m_alignment; + } + return Load::make(op->type, op->buffer_var, op->index, align); + } + + private: + std::map &m_map_; +}; + +class AlignGen : public IRVisitor { + public: + Stmt Run(const Stmt stmt, std::unordered_map &var_info) { + for (auto &item : var_info) { + auto ptr = new InsnTensor(item.first->name_hint, item.second); + observer_dic_[item.first] = ptr; + } + IRVisitor::Visit(stmt); + BroadcastAlign(); + auto ret = AlignAttach(observer_dic_).Mutate(stmt); + return ret; + } + + void Visit_(const AttrStmt *op) final { + if (op->attr_key == "pragma_emit_insn" && exclude_align_analyze_list.count(op->value.as()->value) == 0) { + IRInfo info; + ParserVisitor(info, true).Run(op->body); + AddAlignInfo(info); + } else if (op->attr_key == "align_info" && op->node.as() && observer_dic_[op->node.as()] && + op->value.as()) { + observer_dic_[op->node.as()]->m_alignment = op->value.as()->value; + } else { + IRVisitor::Visit_(op); + } + } + + void AddAlignInfo(IRInfo &info) 
{ + if (info.arith_info.insn_type == "scalar") { + return; + } + bool is_ub_to_gm = (info.arith_info.src_info.size() == 1) && + GetBufScope(info.arith_info.dst_info.p_store->buffer_var->name_hint) == DMA_COPY_GLOBAL; + bool is_gm_to_ub = (info.arith_info.src_info.size() == 1) && + GetBufScope(info.arith_info.src_info[0].p_load->buffer_var->name_hint) == DMA_COPY_GLOBAL; + if (!is_ub_to_gm) { + auto dst_name = info.arith_info.dst_info.p_store->buffer_var.get(); + auto divisor_dst = info.arith_info.dst_info.divisor; + if (!info.arith_info.is_scalar) { + HandleAlignment(observer_dic_[dst_name], divisor_dst, info.arith_info.vec_len); + } + } + + if (!is_gm_to_ub) { + for (size_t i = 0; i < info.arith_info.src_info.size(); i++) { + auto src_name = info.arith_info.src_info[i].p_load->buffer_var.get(); + if (observer_dic_.count(src_name) && !info.arith_info.is_scalar) { + auto src_observer = observer_dic_[src_name]; + auto divisor_src = info.arith_info.src_info[i].divisor; + HandleAlignment(src_observer, divisor_src, info.arith_info.vec_len); + } + } + } + + if (!is_ub_to_gm && !is_gm_to_ub && info.arith_info.insn_type != "reduce" && + info.arith_info.insn_type != "crossing" && info.arith_info.insn_type != "discrete") { + UnifyAlignInfo temp_info; + auto dst_name = info.arith_info.dst_info.p_store->buffer_var.get(); + temp_info.observers.push_back(observer_dic_[dst_name]); + temp_info.divisors.push_back(info.arith_info.dst_info.divisor); + temp_info.offsets.push_back(info.arith_info.dst_info.offset); + temp_info.vector_len = info.arith_info.vec_len; + + for (size_t i = 0; i < info.arith_info.src_info.size(); i++) { + auto src_name = info.arith_info.src_info[i].p_load->buffer_var.get(); + if (observer_dic_.count(src_name)) { + temp_info.observers.push_back(observer_dic_[src_name]); + temp_info.divisors.push_back(info.arith_info.src_info[i].divisor); + temp_info.offsets.push_back(info.arith_info.src_info[i].offset); + } + } + aligns_info_.push_back(temp_info); + } + } + + 
void HandleAlignment(InsnTensor *observer, int divisor, int vector_len) { + auto block_size = GetUbBlkSize(observer->GetType()); + CHECK(divisor % block_size == 0 || divisor >= vector_len); + auto cur_align = observer->GetAlignment(); + int align_temp = 0; + if (cur_align == FREE_ALIGN && divisor % block_size == 0 && divisor >= vector_len) { + return; + } + if (cur_align == FREE_ALIGN && divisor % block_size == 0 && divisor < vector_len) { + return; + } + if (divisor != 0) { + if (cur_align == FREE_ALIGN) { + if (divisor == vector_len) { + align_temp = vector_len; + observer->SetAlignment(align_temp); + return; + } + if (divisor >= vector_len) { + return; + } + CHECK(0) << "Conditions not considered"; + } + if (divisor % cur_align == 0 && vector_len < cur_align) { + return; + } + if (divisor % cur_align != 0) { + if (cur_align % block_size != 0) { + align_temp = air::ir::gcd(divisor, cur_align); + } else { + align_temp = divisor; + } + if (vector_len <= align_temp) { + observer->SetAlignment(align_temp); + } else { + align_temp = air::ir::gcd(vector_len, align_temp); + observer->SetAlignment(align_temp); + } + } + } + } + + void BroadcastAlign() { + bool has_update = true; + while (has_update) { + has_update = false; + for (size_t i = 0; i < aligns_info_.size(); ++i) { + has_update = aligns_info_[i].UnifyAlign() || has_update; + } + } + } + + private: + std::map observer_dic_; + std::vector aligns_info_; +}; + } // namespace ir } // namespace akg #endif // PASS_ANALYZE_ALIGN_H_ diff --git a/src/pass/analyze_align_dynamic.cc b/src/pass/analyze_align_dynamic.cc index 63e718d20ef92bee63f116b345a13b88de795fa9..765ad442be9813a33076d7708ece7c8f57173602 100644 --- a/src/pass/analyze_align_dynamic.cc +++ b/src/pass/analyze_align_dynamic.cc @@ -466,7 +466,7 @@ class AlignVistor : public IRVisitor { // only scan dma insns if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value.as() && op->value.as()->value != "vec_binary_dropout" && - 
exclude_list.count(op->value.as()->value) == 0)) { + exclude_align_analyze_list.count(op->value.as()->value) == 0)) { bool in_dma_copy = false; if (op->value.as() && op->value.as()->value == "dma_copy") { in_dma_copy = true; diff --git a/src/pass/analyze_align_static.cc b/src/pass/analyze_align_static.cc index f7d51ce654bed0bb79489da5b5d969d7b9543ac8..8f179535f48cca3ace6952d7be604963d4d99da1 100644 --- a/src/pass/analyze_align_static.cc +++ b/src/pass/analyze_align_static.cc @@ -26,673 +26,21 @@ namespace akg { namespace ir { -namespace { -using Var2Scope = std::map; - -bool IsInStorageScope(const Var2Scope &table, const Variable *var) { return table.find(var) != table.end(); } - -using AlignModifier = std::function; -using std::placeholders::_1; - -class AlignInfo { - public: - explicit AlignInfo(const Type &t, int64_t off, const AlignModifier func = nullptr, bool spread = false) - - : blk_sz(GetUbBlkSize(t)), base_offset(off), modifiers(), need_spread(spread) { - if (func != nullptr) { - modifiers.push_back(func); - } - } - explicit AlignInfo(const Type &t) : AlignInfo(t, 0, nullptr, false) {} - AlignInfo() : AlignInfo(Handle(1), 0, nullptr, false) { blk_sz = 0; } - ~AlignInfo() = default; - - int64_t blk_sz; - - int64_t base_offset; - - std::vector modifiers; - bool need_spread; -}; - -struct VarComp { - bool operator()(const Var &v0, const Var &v1) const { return v0.get() < v1.get(); } -}; - -using AlignDict = std::map; - -void MergeAlignInfo(AlignInfo &a, const AlignInfo &b) { - CHECK(a.blk_sz != 0 || b.blk_sz != 0); - CHECK(a.blk_sz == 0 || b.blk_sz == 0 || a.blk_sz == b.blk_sz); - if (a.blk_sz == 0) { - a.blk_sz = b.blk_sz; - } - a.need_spread = a.need_spread || b.need_spread; - - a.base_offset = air::ir::gcd(a.base_offset, b.base_offset); - - a.modifiers.insert(a.modifiers.end(), b.modifiers.begin(), b.modifiers.end()); -} - -AlignDict MergeAlignDict(const AlignDict &a, const AlignDict &b) { - AlignDict rst = a; - for (const auto &e : b) { - auto it = 
rst.find(e.first); - if (it != rst.end()) { - MergeAlignInfo(it->second, e.second); - } else { - rst.emplace(e); - } - } - return rst; -} - -AlignDict GenFreeAlignDict(const StmtInfoList &com_info_list) { - AlignDict dict; - for (const auto &com_info : com_info_list) { - dict.emplace(com_info->data_, AlignInfo(com_info->dtype_)); - } - return dict; -} - -AlignDict GenSpecAlignDict(const StmtInfoList &com_info_list, int64_t align, bool is_spread) { - AlignDict dict; - for (const auto &com_info : com_info_list) { - dict.emplace(com_info->data_, AlignInfo(com_info->dtype_, align, nullptr, is_spread)); - } - return dict; -} - -void FixAlignBySize(int64_t &align, int64_t size) { - if (align < size && align != 0 && (size % align) != 0) { - align = air::ir::gcd(align, size); - } -} - -class RegExprSub : public IRMutator { - public: - RegExprSub() {} - ~RegExprSub() override = default; - - Expr run(const Expr &e) { return this->Mutate(e); } - - Expr Mutate_(const Load *op, const Expr &e) final { - if (GetBufScope(op->buffer_var->name_hint) == SCOPE_REG && isImm(op->index)) { - return Variable::make(Int(32), "tmp"); - } - return IRMutator::Mutate_(op, e); - } -}; - -AlignDict GenNormalAlignDict(const StmtInfoList &com_info_list, bool is_spread, bool all_remained_axis = false) { - AlignDict dict; - for (const auto &com_info : com_info_list) { - if (com_info->var_.empty() && !all_remained_axis) { - MergeAlignInfo(dict[com_info->data_], AlignInfo(com_info->dtype_, 0, nullptr, is_spread)); - continue; - } - - bool min_stride_eq1 = !com_info->var_.empty() && GetIntConst(GetItem(com_info->strides_, -1)) == 1; - auto index_expr = IndexOptimizer().Mutate(com_info->index_); - if (min_stride_eq1) { - auto var = GetItem(com_info->var_, -1); - index_expr = Simplify(EliminateVarInExpr(index_expr, {var})); - } - - int64_t offset_gcd = 1; - int64_t continuity_len = min_stride_eq1 ? 
GetIntConst(GetItem(com_info->shape_, -1)) : 1; - - index_expr = RegExprSub().run(index_expr); - - auto vars = GetVarsInExpr(index_expr); - if (vars.empty()) { - CHECK(is_const(index_expr)); - offset_gcd = std::abs(GetIntConst(index_expr)); - } else { - auto strides = air::arith::DetectLinearEquation(index_expr, vars); - if (strides.empty()) { - offset_gcd = -2; // "-2" means no need to consider - } else { - CHECK(!strides.empty()); - offset_gcd = 0; - for (const auto &e : strides) { - offset_gcd = air::ir::gcd(offset_gcd, GetIntConst(e)); - } - } - } - - AlignModifier func = std::bind(FixAlignBySize, _1, continuity_len); - MergeAlignInfo(dict[com_info->data_], AlignInfo(com_info->dtype_, offset_gcd, func, is_spread)); - } - return dict; -} - -bool IsNonLinearScalar(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list) { - if (std::any_of(dst_info_list.begin(), dst_info_list.end(), - [](const StmtStoreInfo &com_info) { return (!com_info->var_.empty()); })) { - return true; - } - if (std::any_of(src_info_list.begin(), src_info_list.end(), - [](const StmtStoreInfo &com_info) { return (!com_info->var_.empty()); })) { - return true; - } - return false; -} - -inline bool IsTranspose(const StmtStoreInfo &dst, const StmtStoreInfo &src) { - return dst->var_.size() > 1 && src->var_.size() > 1 && Equal(GetItem(dst->var_, -2), GetItem(src->var_, -1)) && - Equal(GetItem(dst->var_, -1), GetItem(src->var_, -2)) && - Equal(GetItem(dst->shape_, -1), GetItem(src->shape_, -2)) && - - Equal(GetItem(dst->shape_, -2), GetItem(src->shape_, -1)) && GetIntConst(GetItem(dst->strides_, -1)) == 1 && - GetIntConst(GetItem(src->strides_, -1)) == 1 && Equal(GetItem(dst->strides_, -2), GetItem(src->shape_, -2)) && - - Equal(GetItem(src->strides_, -2), GetItem(dst->shape_, -2)); -} - -void FixAlignByShape(int64_t &align, int64_t shape0, int64_t shape1) { - if (align >= shape0 * shape1) { - return; - } else if (align >= shape0) { - CHECK_NE(shape0, 0); - if (align % shape0 == 0) { 
- auto times = align / shape0; - align = shape0 * air::ir::gcd(times, shape1); - return; - } - } - align = air::ir::gcd(align, shape0); -} - -AlignDict GenTransposeAlign(const StmtStoreInfo &ori_dst, const StmtStoreInfo &ori_src, StmtInfo &if_info, - StmtInfo &for_info) { - auto dst = ori_dst.Copy(); - auto src = ori_src.Copy(); - - auto var_old = GetItem(dst->var_, -1); - auto var_new = GetItem(dst->var_, -2); - dst.GetNode()->var_ = RemoveItemAtIndex(dst->var_, -1); - src.GetNode()->var_ = RemoveItemAtIndex(src->var_, -2); - - int64_t sh0 = GetIntConst(GetItem(dst->shape_, -1)); - int64_t sh1 = GetIntConst(GetItem(dst->shape_, -2)); - auto shape = static_cast(sh0 * sh1); - - dst.GetNode()->shape_ = RemoveItemAtIndex(dst->shape_, -1); - src.GetNode()->shape_ = RemoveItemAtIndex(src->shape_, -1); - SetItem(dst.GetNode()->shape_, -1, Expr(shape)); - SetItem(src.GetNode()->shape_, -1, Expr(shape)); - - dst.GetNode()->strides_ = RemoveItemAtIndex(dst->strides_, -2); - src.GetNode()->strides_ = RemoveItemAtIndex(src->strides_, -2); - - Map map({{var_old, Expr(0)}, {var_new, Expr(0)}}); - dst.GetNode()->index_ = Simplify(Substitute(dst->index_, map) + var_new); - src.GetNode()->index_ = Simplify(Substitute(src->index_, map) + var_new); - - StmtInfoList dst_list({dst}); - StmtInfoList src_list({src}); - CompactComputationInfoList(dst_list, src_list, if_info, for_info); - - auto dict = GenNormalAlignDict(MergeTwo(dst_list, src_list), false); - - dict[dst->data_].modifiers.clear(); - dict[dst->data_].modifiers.push_back(std::bind(FixAlignByShape, _1, sh0, sh1)); - - dict[src->data_].modifiers.clear(); - dict[src->data_].modifiers.push_back(std::bind(FixAlignByShape, _1, sh1, sh0)); - - return dict; -} - -bool IsScalarDMA(const Stmt &op) { - StmtInfo f_info; - StmtInfo i_info; - std::string intrin; - std::string dma; - StmtInfoList src_info_list; - StmtInfoList dst_info_list; - GetDmaComputationInfo(op, dst_info_list, src_info_list, i_info, f_info, dma, intrin); - - const 
auto &d_info = dst_info_list[0]; - const auto &s_info = src_info_list[0]; - - bool last_dim_equal = !d_info->var_.empty() && !s_info->var_.empty() && - GetItem(d_info->var_, -1).get() == GetItem(s_info->var_, -1).get() && - !d_info->strides_.empty() && !s_info->strides_.empty() && - GetIntConst(GetItem(d_info->strides_, -1)) != GetIntConst(GetItem(s_info->strides_, -1)); - - bool is_broadcast = - - ((!s_info->strides_.empty() && GetIntConst(GetItem(s_info->strides_, -1)) != 1) || s_info->var_.empty()) && - ((!d_info->strides_.empty() && GetIntConst(GetItem(d_info->strides_, -1)) != 1) || d_info->var_.empty()); - - bool ubuf_scalar = (is_broadcast || last_dim_equal) && intrin == INTRIN_NAME_COPY_UB_TO_UB; - bool broadcast_scalar = is_broadcast && intrin == "broadcast"; - - if (broadcast_scalar || ubuf_scalar) { - int shape = GetInt32Const(GetItem(d_info->shape_, -1)); - int stride = GetInt32Const(GetItem(d_info->strides_, -1)); - int block_size = GetUbBlkSize(d_info->dtype_); - if (!(ubuf_scalar && shape < block_size && stride == block_size && - IsTwoItemEqual(d_info->strides_, s_info->strides_, -1, true))) { - return true; - } - } - return false; -} - -AlignDict GetDataAlign(const Stmt &op, const bool is_dma_copy, std::vector &info_vec) { - StmtInfo if_info; - StmtInfo for_info; - StmtInfoList dst_info_list; - StmtInfoList src_info_list; - GetCompactComputationInfo(op, dst_info_list, src_info_list, if_info, for_info, false, true); - auto merged_com_list = MergeTwo(dst_info_list, src_info_list); - - info_vec.push_back(merged_com_list); - - Array stores; - Array loads; - GetStoreAndLoads(op, stores, loads); - auto org_dst_info_list = GetComputationInfo(stores, for_info); - auto org_src_info_list = GetComputationInfo(loads, for_info); - - StmtInfoList empty_com_list; - - // check load list - if (src_info_list.empty()) { - // broadcast/scalar mode, such as A[i, j] = 0.0 / A[1] = 2.0 - if (dst_info_list[0]->var_.empty()) { - return GenFreeAlignDict(dst_info_list); - } 
else { - return GenNormalAlignDict(merged_com_list, false); - } - } else if (src_info_list.size() == 1) { - auto dst_info = dst_info_list[0]; - auto src_info = src_info_list[0]; - - if (dst_info->scope_ == SCOPE_UBUF && src_info->scope_ == SCOPE_UBUF) { - if (dst_info->var_.empty() && src_info->var_.empty()) { - if (is_dma_copy) { - if (IsNonLinearScalar(org_dst_info_list, org_src_info_list)) { - // check if it is non-linear index scalar mov, such as - // for (cc2, 0, 4) { - // for (cc3, 0, 6) { - // T_tile_local_UB[((cc2*6) + cc3)] = data_local__ub[(((cc2 % 2)*2) + (cc3 % 2))] - // } - // } - CleanNonLinearVar(org_dst_info_list, empty_com_list, if_info); - auto align_src = GenFreeAlignDict(src_info_list); - auto align_dst = GenNormalAlignDict(org_dst_info_list, false); - return MergeAlignDict(align_src, align_dst); - } - // intrin_name = 'copy_ubuf_to_ubuf' - // scalar op, will not influence the align - return GenFreeAlignDict(merged_com_list); - } - // intrin_name = vadds or vmuls - return GenNormalAlignDict(merged_com_list, false, true); - } else if (src_info->var_.empty()) { - if (GetIntConst(GetItem(dst_info->strides_, -1)) == 1) { - // scalar broadcast - CleanNonLinearVar(org_dst_info_list, empty_com_list, if_info); - auto align_src = GenFreeAlignDict(src_info_list); - auto align_dst = GenNormalAlignDict(org_dst_info_list, false); - return MergeAlignDict(align_src, align_dst); - } - // intrin_name = vector_dup - return GenFreeAlignDict(merged_com_list); - } else if (!(dst_info->var_.empty()) && Equal(GetItem(dst_info->var_, -1), GetItem(src_info->var_, -1))) { - if (GetIntConst(GetItem(dst_info->strides_, -1)) == GetIntConst(GetItem(src_info->strides_, -1)) && - Equal(GetItem(org_dst_info_list[0]->var_, -1), GetItem(org_src_info_list[0]->var_, -1))) { - // elemwise mode, intrin_name = copy_ubuf_to_ubuf - return GenNormalAlignDict(merged_com_list, true); - } - // scalar dma mode - return GenFreeAlignDict(merged_com_list); - } else if (IsTranspose(dst_info, 
src_info)) { - if (is_dma_copy) { - // intrin_name = vtranspose - int block_size = GetUbBlkSize(dst_info->dtype_); - CHECK_NE(block_size, 0); - - int dst_shape = GetInt32Const(GetItem(dst_info->shape_, -1)); - int src_shape = GetInt32Const(GetItem(src_info->shape_, -1)); - if (dst_shape % block_size != 0 || - (src_shape % block_size != 0 && (src_shape > block_size || dst_shape > block_size))) { - return GenTransposeAlign(dst_info, src_info, if_info, for_info); - } else { - // special case optimization - return GenNormalAlignDict(merged_com_list, false); - } - } - // align = 1 - return GenSpecAlignDict(merged_com_list, 1, true); - } else if (dst_info->var_.size() > 1 && src_info->var_.size() > 1 && - !Equal(GetItem(dst_info->var_, -1), GetItem(src_info->var_, -1)) && - Equal(GetItem(dst_info->var_, -2), GetItem(src_info->var_, -2))) { - // intrin_name = broadcast - // special case of last dim axis broadcast issue #675 - CleanNonLinearVar(org_dst_info_list, empty_com_list, if_info); - auto align_src = GenFreeAlignDict(src_info_list); - auto align_dst = GenNormalAlignDict(org_dst_info_list, false); - return MergeAlignDict(align_src, align_dst); - } else if (IsScalarDMA(op)) { - return GenFreeAlignDict(merged_com_list); - } - return GenNormalAlignDict(merged_com_list, false); - } else if (dst_info->scope_ != DMA_COPY_GLOBAL && src_info->scope_ != DMA_COPY_GLOBAL && - dst_info->var_.size() > 1 && src_info->var_.size() > 1 && - Equal(GetItem(dst_info->var_, -1), GetItem(src_info->var_, -2)) && - Equal(GetItem(dst_info->var_, -2), GetItem(src_info->var_, -1))) { - // check transopse cbuf, ca, cb, cc - if (is_dma_copy) { - // intrin_name = vtranspose - int64_t align = GetIntConst(GetItem(dst_info->shape_, -1) * GetItem(src_info->shape_, -1)); - return GenSpecAlignDict(merged_com_list, align, true); - } - // discontinuoust dma mov - return GenSpecAlignDict(merged_com_list, 1, true); - } else if (dst_info->var_.empty() && src_info->var_.empty()) { - // not ub to ub mode, 
discontinuous dma mov - return GenNormalAlignDict(merged_com_list, true, true); - } else if (dst_info->var_.empty()) { - LOG(FATAL) << "Error: Copy Vector into a scalar."; - } else if (src_info->var_.empty()) { - // broadcast between ub and gm - return GenNormalAlignDict(merged_com_list, true, true); - } else if (!Equal(GetItem(dst_info->var_, -1), GetItem(src_info->var_, -1)) || - GetIntConst(GetItem(dst_info->strides_, -1)) != 1 || GetIntConst(GetItem(src_info->strides_, -1)) != 1) { - // discontinuoust dma mov - return GenSpecAlignDict(merged_com_list, 1, true); - } - return GenNormalAlignDict(merged_com_list, true); - } else if (src_info_list.size() < 5) { // src_info_list allowed max value + 1 - if (IsLastAxisReduction(dst_info_list, src_info_list)) { - // reduction mode - - if (GetIntConst(GetItem(dst_info_list[0]->shape_, -1)) == 1) { - // reduce to a scalar - return GenFreeAlignDict(merged_com_list); - } - // last dim is compacted separately - return GenNormalAlignDict(merged_com_list, false); - } else if (IsElementwise(dst_info_list, src_info_list)) { - // elementwise mode - return GenNormalAlignDict(merged_com_list, true, true); - } else if (IsBroadcast(dst_info_list, src_info_list)) { - // broadcast mode - bool need_spread = !IsLastAxisBroadcast(dst_info_list, src_info_list); - return GenNormalAlignDict(merged_com_list, need_spread); - } - return GenNormalAlignDict(merged_com_list, true); - } else { - LOG(FATAL) << "Error: Can not support more than 4 loads."; +int GetCommonDivisor(std::vector numbers) { + CHECK(numbers.size() >= 1); + int divisor = numbers[0]; + for (size_t i = 1; i < numbers.size(); i++) { + divisor = air::ir::gcd(divisor, numbers[i]); } - // error, and return empty map - return AlignDict(); + return divisor; } -class AlignVistor : public IRVisitor { - public: - explicit AlignVistor(const Var2Scope &table) - : min_align(), gbl_storage(), storage_scope_(table), all_aligns_(), spread_vec_(), info_vec_() {} - ~AlignVistor() override = 
default; - - void Run(const Stmt stmt) { - this->Visit(stmt); - UpdateAlign(); - } - - void Visit_(const AttrStmt *op) final { - // nested scop, just return - if (op->attr_key == "isolate_range") return; - - if (auto str_ptr = op->node.as()) { - if (str_ptr->value == "no_align") { - return IRVisitor::Visit_(op); - } - } - - // only scan dma insns - if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value->IsInstance() && - op->value.as()->value != "vec_binary_dropout" && - op->value.as()->value != "mask_broadcast" && - exclude_list.count(op->value.as()->value) == 0)) { - bool in_dma_copy = false; - if (op->value.as() && op->value.as()->value == "dma_copy") { - in_dma_copy = true; - } - - auto dict = GetDataAlign(op->body, in_dma_copy, info_vec_); - for (auto it = dict.begin(); it != dict.end();) { - if (!IsInStorageScope(storage_scope_, it->first.get())) { - gbl_storage.insert(it->first.get()); - it = dict.erase(it); - } else { - ++it; - } - } - - std::vector spread_var; - for (const auto &e : dict) { - if (e.second.need_spread) { - spread_var.push_back(e.first); - } - MergeAlignInfo(all_aligns_[e.first], e.second); - } - if (spread_var.size() > 1) { - spread_vec_.push_back(std::move(spread_var)); - } - } - return IRVisitor::Visit_(op); - } - - std::map min_align; - - std::set gbl_storage; - - private: - void UpdateAlign() { - for (auto e : gbl_storage) { - auto var_ptr = const_cast(e); - all_aligns_.emplace(Var(GetObjectPtr(var_ptr)), AlignInfo(var_ptr->type)); - } - do { - for (auto &e : all_aligns_) { - auto &info = e.second; - auto blk_sz = info.blk_sz; - CHECK_NE(blk_sz, 0); - - if (info.base_offset % blk_sz != 0) { - while (info.base_offset != 1) { - bool done = true; - for (auto func : info.modifiers) { - auto old = info.base_offset; - func(info.base_offset); - - CHECK_LE(info.base_offset, old); - if (info.base_offset < old) { - done = false; - } - } - if (done && FixLoopAxis()) { - break; - } - } - } - } - } while 
(!DealWithSpread()); - for (const auto &e : all_aligns_) { - if (IsInStorageScope(storage_scope_, e.first.get())) { - min_align.emplace(e.first.get(), e.second.base_offset); - } - } - } - - bool FixLoopAxis() { - for (const auto &vec_ele : info_vec_) { - // for_v -> times - std::map, VarComp> coef_table; - // for_v -> [buffer -> times] - std::map, VarComp> buf_table; - - for (const auto &info : vec_ele) { - auto it = all_aligns_.find(info->data_); - CHECK(it != all_aligns_.end()); - - if (it->second.base_offset <= 1) { - continue; - } - for (size_t i = 0; i != info->var_.size(); ++i) { - auto stride = std::abs(GetIntConst(info->strides_[i])); - auto extent = std::abs(GetIntConst(info->shape_[i])); - - auto align = it->second.base_offset; - - if (stride < align && stride * extent > align) { - CHECK_NE(stride, 0); - if (align % stride != 0) { - it->second.base_offset = air::ir::gcd(align, stride); - - return false; - } - - CHECK_NE((align / stride), 0); - if (extent % (align / stride) != 0) { - auto times = align / stride; - auto new_times = air::ir::gcd(extent, times); - it->second.base_offset = it->second.base_offset * new_times / times; - - return false; - } - - auto var = info->var_[i]; - - auto times = align / stride; - - coef_table[var].push_back(times); - - auto ×_record = buf_table[var][it->first]; - - CHECK(times_record == 0 || times_record == times); - - times_record = times; - } - } - } - - for (const auto &i : coef_table) { - auto align = i.second.front(); - bool changed = false; - for (auto ele : i.second) { - changed = changed || (ele != align); - align = air::ir::gcd(align, ele); - } - if (changed) { - for (auto v : buf_table[i.first]) { - all_aligns_[v.first].base_offset *= align; - - CHECK_NE(v.second, 0); - all_aligns_[v.first].base_offset /= v.second; - } - return false; - } - } - } - return true; - } - - bool DealWithSpread() { - for (const auto &vec : spread_vec_) { - auto it = all_aligns_.find(vec.front()); - CHECK(it != all_aligns_.end()); - - 
auto align = it->second.base_offset; - bool changed = false; - for (const auto &e : vec) { - auto it_in = all_aligns_.find(e); - CHECK(it_in != all_aligns_.end()); - - changed = changed || (it_in->second.base_offset != align); - align = air::ir::gcd(align, it_in->second.base_offset); - } - if (changed) { - for (const auto &e : vec) { - auto it_in = all_aligns_.find(e); - CHECK(it_in != all_aligns_.end()); - it_in->second.base_offset = align; - } - return false; - } - } - return true; - } - - // storage scope - const Var2Scope &storage_scope_; - // all align_ info - AlignDict all_aligns_; - std::vector> spread_vec_; - std::vector info_vec_; -}; - -// predicate is for GPU, use it to hold min align -class AlignInsert : public IRMutator { - public: - AlignInsert() : min_align_(), gbl_storage_() {} - ~AlignInsert() override = default; - - Stmt Run(const Stmt stmt, const Var2Scope &storage_scope) { - AlignVistor visitor(storage_scope); - visitor.Run(stmt); - min_align_ = std::move(visitor.min_align); - gbl_storage_ = std::move(visitor.gbl_storage); - - return this->Mutate(stmt); - } - - Stmt Mutate_(const Store *op, const Stmt &s) final { - Expr value = this->Mutate(op->value); - auto index = this->Mutate(op->index); - - int64_t val = gbl_storage_.find(op->buffer_var.get()) == gbl_storage_.end() ? free_align_flag_ : 1; - - auto it = min_align_.find(op->buffer_var.get()); - if (it != min_align_.end()) { - val = GetAlignValue(it->second, op->value.type()); - } - - return Store::make(op->buffer_var, value, index, make_const(Int(32), val)); - } - - Expr Mutate_(const Load *op, const Expr &e) final { - auto index = this->Mutate(op->index); - - int64_t val = gbl_storage_.find(op->buffer_var.get()) == gbl_storage_.end() ? 
free_align_flag_ : 1; - auto it = min_align_.find(op->buffer_var.get()); - if (it != min_align_.end()) { - val = GetAlignValue(it->second, op->type); - } - - return Load::make(op->type, op->buffer_var, index, make_const(Int(32), val)); - } - - private: - static int64_t GetAlignValue(int64_t val, const air::DataType dtype) { - int value = GetUbBlkSize(dtype); - CHECK_NE(value, 0); - return val % value == 0 ? FREE_ALIGN : val; - } - - std::map min_align_; +namespace { - std::set gbl_storage_; +using Var2Scope = std::map; - const int free_align_flag_ = -2; -}; +bool IsInStorageScope(const Var2Scope &table, const Variable *var) { return table.find(var) != table.end(); } class FindSameNameBuf : public IRVisitor { public: @@ -782,16 +130,35 @@ class InsertIsolate : public IRMutator { bool insert_isolate_; }; +class CacheVisiter : public IRVisitor { + public: + CacheVisiter() = default; + ~CacheVisiter() override = default; + + void Visit_(const Allocate *op) final { + var_type_map[op->buffer_var.get()] = op->type; + IRVisitor::Visit_(op); + } + + std::unordered_map var_type_map; +}; + // process each isolate_range once a time class ProcessParts : public IRMutator { public: explicit ProcessParts(const Var2Scope &table) : level_(0), storage_scope_(table) {} ~ProcessParts() override = default; + std::unordered_map var_type_map; + Stmt Run(Stmt stmt) { + CacheVisiter buffer_visitor; + buffer_visitor.Visit(stmt); + var_type_map = buffer_visitor.var_type_map; + stmt = this->Mutate(stmt); if (level_ == 0) { - stmt = AlignInsert().Run(stmt, storage_scope_); + stmt = AlignGen().Run(stmt, var_type_map); } return stmt; } @@ -799,7 +166,7 @@ class ProcessParts : public IRMutator { Stmt Mutate_(const Block *op, const Stmt &s) final { if (!HasIsolate(s)) { Stmt stmt = s; - stmt = AlignInsert().Run(stmt, storage_scope_); + stmt = AlignGen().Run(stmt, var_type_map); level_++; return stmt; } @@ -813,7 +180,7 @@ class ProcessParts : public IRMutator { Stmt stmt = IRMutator::Mutate_(op, 
s); // no isolate_range in this attr if (cur_level == level_) { - stmt = AlignInsert().Run(stmt, storage_scope_); + stmt = AlignGen().Run(stmt, var_type_map); } return stmt; } @@ -841,14 +208,14 @@ class ProcessParts : public IRMutator { Stmt AnalyzeMinAlignStatic(Stmt stmt) { stmt = air::ir::ConvertSSA(stmt); + CacheVisiter buffer_visitor; + buffer_visitor.Visit(stmt); + FindSameNameBuf find_visitor; find_visitor.Visit(stmt); - stmt = MergeLoops(stmt); - stmt = InsertIsolate(find_visitor.storage_scope_).Mutate(stmt); stmt = ProcessParts(find_visitor.storage_scope_).Run(stmt); - stmt = RewriteByAlignStatic(stmt); return stmt; } diff --git a/src/pass/merge_loops.cc b/src/pass/merge_loops.cc index f62f599f3ac1b53e2696331e59ad55a5b7ca6b85..478473f44e3e813dffd64a983df011adfb3b568f 100644 --- a/src/pass/merge_loops.cc +++ b/src/pass/merge_loops.cc @@ -43,7 +43,7 @@ class LoopsCompacter : public IRMutator { Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value->IsInstance() && - !exclude_list.count(op->value.as()->value))) { + !exclude_align_analyze_list.count(op->value.as()->value))) { stores_ = Array(); loads_ = Array(); GetStoreAndLoads(op->body, stores_, loads_); diff --git a/src/pass/multi_last_axis_reduction.cc b/src/pass/multi_last_axis_reduction.cc index d8fc5319d750ce07ee6a575661162056e9380cd3..1f2137df151f084ff1caee98073108439dcc96c4 100644 --- a/src/pass/multi_last_axis_reduction.cc +++ b/src/pass/multi_last_axis_reduction.cc @@ -192,6 +192,7 @@ class MultiLastAxisReduction : public IRMutator { lastResult = loadTmp + storeLeft; } + broadcastNum = Call::make(type_tmp, "vector_dup", {broadcastNum}, Call::PureIntrinsic); Stmt stForOnce = Store::make(tmpBuffer, storeResult, newIdx, storeTmp->predicate); Stmt stForTwice = Store::make(storeTmp->buffer_var, lastResult, storeTmp->index, storeTmp->predicate); Stmt stBroadcast = Store::make(tmpBuffer, broadcastNum, newIdx, 
storeTmp->predicate); @@ -212,7 +213,7 @@ class MultiLastAxisReduction : public IRMutator { stForOnce = AttrStmt::make(VarExpr("0", Int(32)), "pragma_emit_insn", Expr(str), stForOnce); stForTwice = AttrStmt::make(VarExpr("0", Int(32)), "pragma_emit_insn", Expr(str), stForTwice); - stBroadcast = AttrStmt::make(VarExpr("0", Int(32)), "pragma_emit_insn", Expr("broadcast"), stBroadcast); + stBroadcast = AttrStmt::make(VarExpr("0", Int(32)), "pragma_emit_insn", Expr("vector_dup"), stBroadcast); stmt = Block::make({stBroadcast, stForOnce, stForTwice}); stmt = Allocate::make(tmpBuffer, type_tmp, extentsArray, const_true(), stmt); diff --git a/src/pass/optimize_pragma.cc b/src/pass/optimize_pragma.cc index b612038689962ad532d8f216e3cb7381e2317257..8de8d8343a11ef32825bb2ba8a45b9a666030233 100644 --- a/src/pass/optimize_pragma.cc +++ b/src/pass/optimize_pragma.cc @@ -147,7 +147,7 @@ class EstimateAlign : public IRMutator { Stmt Mutate_(const AttrStmt *op, const Stmt &stmt) final { if (air::ir::attr::IsPragmaKey(op->attr_key) && op->value.as()) { - if (exclude_list.count(op->value.as()->value)) { + if (exclude_align_analyze_list.count(op->value.as()->value)) { return stmt; } diff --git a/src/pass/rewrite_by_align_dynamic.cc b/src/pass/rewrite_by_align_dynamic.cc index 0c20ae6b8301cd1f49bd590884597e45ffe8a290..0c71133e794dbc767f5ffb75eeb7cf2cb2fce5d4 100644 --- a/src/pass/rewrite_by_align_dynamic.cc +++ b/src/pass/rewrite_by_align_dynamic.cc @@ -46,7 +46,7 @@ class AxisPartitioner : public IRMutator { Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value->IsInstance() && - exclude_list.count(op->value.as()->value) == 0)) { + exclude_index_fix_list.count(op->value.as()->value) == 0)) { in_insn_ = true; counter_ = 0; auto ret = IRMutator::Mutate_(op, s); @@ -180,7 +180,7 @@ class RewriteAllocateAndIndex : public IRMutator { } } if (op->attr_key == "pragma_ub_gm" || (op->attr_key == 
"pragma_emit_insn" && op->value->IsInstance() && - (exclude_list.count(op->value.as()->value) == 0 || + (exclude_index_fix_list.count(op->value.as()->value) == 0 || op->value.as()->value == "scatter"))) { in_insn_ = true; auto ret = IRMutator::Mutate_(op, s); diff --git a/src/pass/rewrite_by_align_static.cc b/src/pass/rewrite_by_align_static.cc index d477a2ca8331114cd3af9038d06669cb2b4fedd2..e6e001da36f4ad23ba2409cad5a954ed112fb517 100644 --- a/src/pass/rewrite_by_align_static.cc +++ b/src/pass/rewrite_by_align_static.cc @@ -46,7 +46,7 @@ class AxisPartitioner : public IRMutator { Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value->IsInstance() && - exclude_list.count(op->value.as()->value) == 0)) { + exclude_index_fix_list.count(op->value.as()->value) == 0)) { in_insn_ = true; counter_ = 0; auto ret = IRMutator::Mutate_(op, s); @@ -182,7 +182,7 @@ class RewriteAllocateAndIndex : public IRMutator { } } if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value->IsInstance() && - (exclude_list.count(op->value.as()->value) == 0 || + (exclude_index_fix_list.count(op->value.as()->value) == 0 || op->value.as()->value == "scatter"))) { in_insn_ = true; auto ret = IRMutator::Mutate_(op, s); @@ -307,12 +307,7 @@ class RewriteAllocateAndIndex : public IRMutator { CHECK_NE(align, 0); int64_t coef = GetIntConst(strides[0]); if (std::abs(coef) < align) { - auto it = var2ext_.find(v.get()); - if (it != var2ext_.end() && std::abs(coef * it->second) <= align) { - rst += v * strides[0]; - } else { - return SimpleFix(tmp_idx_bk, opt.var2expr, align, times); - } + rst += v * strides[0]; } else if (coef % align == 0) { auto new_coef = coef * times / align; rst += v * Expr(static_cast(new_coef)); @@ -359,7 +354,8 @@ class RewriteAllocateAndIndex : public IRMutator { Stmt RewriteByAlignStatic(Stmt stmt) { stmt = AxisPartitioner().Run(stmt); stmt = 
RewriteAllocateAndIndex().Mutate(stmt); - return MergeLoops(stmt); + stmt = MergeLoops(stmt); + return stmt; } } // namespace ir } // namespace akg diff --git a/src/pass/store_pack.cc b/src/pass/store_pack.cc new file mode 100644 index 0000000000000000000000000000000000000000..99ac30dbbcc509b52272567384ae94416cb1d768 --- /dev/null +++ b/src/pass/store_pack.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include "emit_insn/insn_info.h" +#include "emit_insn/ir_transform.h" +#include "analyze_align.h" + +namespace akg { +namespace ir { + +class ReducePacker : public IRMutator { + public: + ReducePacker() = default; + ~ReducePacker() override = default; + + Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { + if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value->IsInstance() && + !exclude_align_analyze_list.count(op->value.as()->value))) { + IRInfo info; + ParserVisitor(info, false).Run(s); + if (info.ChangeLastDimReduce()) { + auto body = info.GenStmt(); + return AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr(info.arith_info.insn_type), body); + } + return s; + } + return IRMutator::Mutate_(op, s); + } +}; + +Stmt PackStore(Stmt stmt) { + stmt = TransposeTransform().Mutate(stmt); + stmt = ReducePacker().Mutate(stmt); + return stmt; +} +} // namespace ir +} // namespace akg \ No newline at end of file 
diff --git a/src/pass/store_recover.cc b/src/pass/store_recover.cc
new file mode 100644
index 0000000000000000000000000000000000000000..955d5fff87d2be9613a682bb3e305162e5aca4ac
--- /dev/null
+++ b/src/pass/store_recover.cc
@@ -0,0 +1,250 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// NOTE(review): the three system header names below are inferred from what this file uses
+// (IR nodes, IRMutator, std::string); confirm them against the build.
+#include <tvm/ir.h>
+#include <tvm/ir_mutator.h>
+#include <string>
+#include "emit_insn/insn_info.h"
+#include "analyze_align.h"
+#include "emit_insn/ir_transform.h"
+
+namespace akg {
+namespace ir {
+
+/// Rewrites every store under a known "reduce_*" emit-insn pragma to "dst = op(dst, arg0)", where arg0
+/// is the first argument of the stored call, and re-tags the region with the matching "vec_binary_*"
+/// pragma. Also renames "dma_copy_transpose" to "vtranspose" and strips "align_info" attributes.
+class ReduceRecover : public IRMutator {
+ public:
+  ReduceRecover() = default;
+  ~ReduceRecover() override = default;
+
+  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
+    if (op->attr_key == "align_info") {
+      // The attribute itself is dropped; only its (mutated) body is kept.
+      return this->Mutate(op->body);
+    }
+    const auto *pragma = (op->attr_key == "pragma_emit_insn") ? op->value.as<StringImm>() : nullptr;
+    if (pragma == nullptr) {
+      return IRMutator::Mutate_(op, s);
+    }
+    if (pragma->value == "dma_copy_transpose") {
+      return AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("vtranspose"), op->body);
+    }
+    const std::string recovered = RecoveredPragma(pragma->value);
+    if (recovered.empty()) {
+      // Not a packed reduce this pass can undo: leave the pragma alone instead of
+      // re-tagging the statement with an empty or stale pragma name.
+      return IRMutator::Mutate_(op, s);
+    }
+    // Save and restore the state so an enclosing reduce region is not cut short.
+    const std::string outer_pragma = old_pragma_;
+    const bool outer_in_reduce = in_reduce_;
+    old_pragma_ = pragma->value;
+    in_reduce_ = true;
+    Stmt body = this->Mutate(op->body);
+    old_pragma_ = outer_pragma;
+    in_reduce_ = outer_in_reduce;
+    return AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr(recovered), body);
+  }
+
+  Stmt Mutate_(const Store *op, const Stmt &s) final {
+    if (!in_reduce_) {
+      return IRMutator::Mutate_(op, s);
+    }
+    const auto *packed = op->value.as<Call>();
+    CHECK(packed != nullptr && packed->args.size() >= 1)
+      << "store under " << old_pragma_ << " must hold a call whose first argument is the reduce source";
+    const Expr reduce_src = packed->args[0];
+    // The element being stored to is loaded back and used as the first operand.
+    const Expr dst_load = Load::make(op->value.type(), op->buffer_var, op->index, op->predicate);
+    Expr new_value;
+    if (old_pragma_ == "reduce_add") {
+      new_value = Add::make(dst_load, reduce_src);
+    } else if (old_pragma_ == "reduce_max") {
+      new_value = Max::make(dst_load, reduce_src);
+    } else if (old_pragma_ == "reduce_min") {
+      new_value = Min::make(dst_load, reduce_src);
+    } else if (old_pragma_ == "reduce_fargmax") {
+      new_value = Call::make(reduce_src.type(), "fargmax", {dst_load, reduce_src}, Call::CallType::PureIntrinsic);
+    } else if (old_pragma_ == "reduce_fargmin") {
+      new_value = Call::make(reduce_src.type(), "fargmin", {dst_load, reduce_src}, Call::CallType::PureIntrinsic);
+    } else {
+      return s;
+    }
+    return Store::make(op->buffer_var, new_value, op->index, op->predicate);
+  }
+
+ private:
+  /// Returns the "vec_binary_*" pragma replacing a packed reduce pragma, or "" when there is none.
+  static std::string RecoveredPragma(const std::string &reduce_pragma) {
+    if (reduce_pragma == "reduce_add") {
+      return "vec_binary_add";
+    }
+    if (reduce_pragma == "reduce_max") {
+      return "vec_binary_max";
+    }
+    if (reduce_pragma == "reduce_min") {
+      return "vec_binary_min";
+    }
+    if (reduce_pragma == "reduce_fargmax") {
+      return "vec_binary_fargmax";
+    }
+    if (reduce_pragma == "reduce_fargmin") {
+      return "vec_binary_fargmin";
+    }
+    return std::string();
+  }
+
+  std::string old_pragma_;
+  // Initialised in-class so the flag is never read uninitialised, however the object is created.
+  bool in_reduce_{false};
+};
+
+/// Maps the parsed arithmetic op type to the intrinsic name FinetunePragma emits;
+/// returns an empty string for op types it does not rewrite.
+std::string GetOpCode(const std::string &op_type) {
+  std::string op_code{};
+  if (op_type == "Add") {
+    op_code = "vadds";
+  } else if (op_type == "Mul") {
+    op_code = "vmuls";
+  } else if (op_type == "vaxpy") {
+    op_code = "vaxpy";
+  } else if (op_type == "DMACopy") {
+    op_code = "vector_dup";
+  }
+  return op_code;
+}
+
+/// Re-tags eligible "pragma_emit_insn" statements: depending on what ParserVisitor reports, the
+/// statement is rebuilt by GenStore as a vadds / vmuls / vaxpy / vector_dup intrinsic call, or only
+/// its pragma is renamed to "scalar_calc" or "scalar_dma". Everything else is returned unchanged.
+class FinetunePragma : public IRMutator {
+ public:
+  FinetunePragma() = default;
+  ~FinetunePragma() override = default;
+
+  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
+    if ((op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() &&
+         !exclude_align_analyze_list.count(op->value.as<StringImm>()->value))) {
+      IRInfo info;
+      ParserVisitor(info, true).Run(s);
+      std::string op_code = GetOpCode(info.arith_info.op_type);
+      // Leave the statement untouched unless the op maps to an intrinsic and dst / first src report IsUB().
+      if (!info.arith_info.dst_info.IsUB() || op_code.empty() ||
+          (!info.arith_info.src_info.empty() && !info.arith_info.src_info[0].IsUB())) {
+        return s;
+      }
+      // Non-float simd add/mul with a single immediate scalar: only the pragma is renamed.
+      if (info.arith_info.insn_type == "simd" && info.arith_info.scalar_imm_num == 1 &&
+          (op_code == "vmuls" || op_code == "vadds") && !info.arith_info.dst_info.p_store->value.type().is_float()) {
+        return AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("scalar_calc"), op->body);
+      }
+      if (info.arith_info.insn_type == "vector_scalar" || info.arith_info.insn_type == "vector_dump") {
+        return GenStore(info, op_code, 0);
+      } else if (info.arith_info.insn_type == "simd" && info.arith_info.scalar_imm_num > 0) {
+        CHECK_EQ(info.arith_info.scalar_imm_num, 1);
+        return GenStore(info, op_code, 1);
+      } else if (info.arith_info.insn_type == "simd" && info.arith_info.scalar_imm_num == 0 &&
+                 info.arith_info.op_type == "DMACopy" && info.arith_info.dst_info.IsUB() &&
+                 info.arith_info.src_info.size() == 1 && info.arith_info.src_info[0].IsUB() &&
+                 info.arith_info.dst_info.p_store->value.type().is_float()) {
+        /// change copy_ub_to_ub (fp16 or fp32) to adds (scalar = 0)
+        op_code = "vadds";
+        info.arith_info.scalar_imm_num = 1;
+        info.arith_info.scalar_imm = FloatImm::make(info.arith_info.dst_info.p_store->value.type(), 0);
+        return GenStore(info, op_code, 1);
+      } else if (info.arith_info.op_type == "DMACopy" &&
+                 (info.arith_info.insn_type == "scalar" || info.arith_info.insn_type == "discrete") &&
+                 info.arith_info.dst_info.IsUB() &&
+                 (info.arith_info.src_info.size() == 1 && info.arith_info.src_info[0].IsUB())) {
+        return AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("scalar_dma"), op->body);
+      } else if (info.arith_info.op_type == "DMACopy" &&
+                 (info.arith_info.insn_type == "scalar" || info.arith_info.insn_type == "discrete") &&
+                 info.arith_info.dst_info.IsUB() && info.arith_info.scalar_imm_num == 1) {
+        return GenStore(info, op_code, 1);
+      } else if (op->value.as<StringImm>()->value == "vec_single_muls" ||
+                 op->value.as<StringImm>()->value == "vec_single_adds") {
+        if (op->value.as<StringImm>()->value == "vec_single_muls") {
+          op_code = "vmuls";
+        } else if (op->value.as<StringImm>()->value == "vec_single_adds") {
+          op_code = "vadds";
+        }
+        return GenStore(info, op_code, 1);
+      }
+      return s;
+    }
+    return IRMutator::Mutate_(op, s);
+  }
+
+  /// Rebuilds the parsed statement as a store of an `intrin_name` call wrapped in the loops of info.for_info.
+  /// scalar_type == 0: the scalar is info.arith_info.scalar_load; loops over its variables end up outside the
+  /// pragma, all other loops inside. Otherwise the scalar is info.arith_info.scalar_imm and the pragma wraps all.
+  Stmt GenStore(IRInfo &info, const std::string &intrin_name, const int scalar_type = 0) {
+    CHECK(intrin_name == "vector_dup" || intrin_name == "vadds" || intrin_name == "vmuls" || intrin_name == "vaxpy");
+
+    /// scalar value
+    Expr scalar_value =
+      (scalar_type == 0) ? GetRef<Expr>(info.arith_info.scalar_load.p_load) : info.arith_info.scalar_imm;
+    Array<Expr> call_args{};
+    if (intrin_name == "vector_dup") {
+      call_args = {scalar_value};
+    } else {
+      Expr tensor_value = GetRef<Expr>(info.arith_info.src_info[0].p_load);
+      call_args = {tensor_value, scalar_value};
+    }
+    /// set store
+    auto old_ptr = info.arith_info.dst_info.p_store;
+    Expr new_value = Call::make(old_ptr->value.type(), intrin_name, call_args, Call::PureIntrinsic);
+    Stmt ret = Store::make(old_ptr->buffer_var, new_value, old_ptr->index, old_ptr->predicate);
+    if (scalar_type == 0) {
+      auto scalar_vars = info.arith_info.scalar_load.vars;
+      /// set inner for loop
+      for (int i = static_cast<int>(info.for_info.vars.size()) - 1; i >= 0; --i) {
+        if (!IsInArray(scalar_vars, info.for_info.vars[i])) {
+          ret = For::make(info.for_info.vars[i], 0, info.for_info.exts[i], ForType::Serial, DeviceAPI::None, ret);
+        }
+      }
+      /// set attribute
+      ret = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr(intrin_name), ret);
+      /// set outer for loop
+      for (int i = static_cast<int>(info.for_info.vars.size()) - 1; i >= 0; --i) {
+        if (IsInArray(scalar_vars, info.for_info.vars[i])) {
+          ret = For::make(info.for_info.vars[i], 0, info.for_info.exts[i], ForType::Serial, DeviceAPI::None, ret);
+        }
+      }
+      return ret;
+    } else {
+      for (int i = static_cast<int>(info.for_info.vars.size()) - 1; i >= 0; --i) {
+        ret = For::make(info.for_info.vars[i], 0, info.for_info.exts[i], ForType::Serial, DeviceAPI::None, ret);
+      }
+      ret = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr(intrin_name), ret);
+      return ret;
+    }
+  }
+};
+
+/// Pass entry point: applies IfReorder, then FinetunePragma, then ReduceRecover to `stmt`.
+Stmt RecoverStore(Stmt stmt) {
+  stmt = IfReorder().Mutate(stmt);
+  stmt = FinetunePragma().Mutate(stmt);
+  stmt = ReduceRecover().Mutate(stmt);
+  return stmt;
+}
+}  // namespace ir
+}  // namespace akg