diff --git a/src/emit_insn/insn_args_calculator.cc b/src/emit_insn/insn_args_calculator.cc new file mode 100644 index 0000000000000000000000000000000000000000..2e6a7144cd5a8c20b59b7b37826bbd6827e42581 --- /dev/null +++ b/src/emit_insn/insn_args_calculator.cc @@ -0,0 +1,1089 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "common/array_api.h" +#include "pass/expr_alg_simplify.h" +#include "insn_pattern.h" +#include "insn_args_calculator.h" + +namespace akg { + +InsnAxis::InsnAxis(const For *for_stmt, const Array &info_list) { + this->var = for_stmt->loop_var; + this->extent = GetInt32Const(for_stmt->extent); + this->min = GetInt32Const(for_stmt->min); + int index = 0; + for (auto it : info_list) { + auto stride = GetInt32Const(GetStrideByAxis(it->var_, it->strides_, this->var)); + this->stride_list.push_back(stride); + if (index == 0) { + this->dst_stride = stride; + } else { + this->src_stride_list.push_back(stride); + } + index++; + } +} + +Expr InsnAxis::GetStrideByAxis(const Array &vars, const Array &strides, Var obj_var) { + int index = 0; + for (auto var_it : vars) { + if (Equal(var_it, obj_var)) { + return strides[index]; + } + index++; + } + return Expr(0); +}; + +bool InsnAxis::IsValid() { return this->is_valid; } + +void InsnAxis::Print(const std::string &name) { + if (!name.empty()) { + LOG(DEBUG) << "********** " << name << " ************"; + } + auto 
r_stride = this->src_stride_list.size() > 1 ? src_stride_list[1] : 99999; + LOG(DEBUG) << "var:" << this->var << " extent:" << this->extent << " min:" << this->min + << " dst_stride:" << this->dst_stride << " src_stride_l:" << this->src_stride_list.front() + << "src_stride_r:" << r_stride; +} + +Array GetInfoList(const StmtStoreInfo &dst_info, const Array &src_info_list) { + Array res; + res.push_back(dst_info.Copy()); + for (auto it : src_info_list) { + res.push_back(it.Copy()); + } + return res; +}; + +std::list GetAxisList(const StmtInfo &for_info, const Array &info_list) { + std::list axis_list; + for (auto it : for_info.ops_) { + auto for_stmt = it.as(); + CHECK(for_stmt); + auto axis = InsnAxis(for_stmt, info_list); + axis_list.push_back(axis); + } + return axis_list; +} + +void Print(std::list &axis_list) { + LOG(DEBUG) << "+++++++++++++++++++ AXIS_LIST +++++++++++++++++++"; + int index = 0; + for (auto it : axis_list) { + LOG(DEBUG) << "================== INDEX " << index << " ================="; + it.Print(); + index++; + } + LOG(DEBUG) << "------------------ END ---------------------"; +} + +InsnArgsCalculator::InsnArgsCalculator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list, + const StmtInfo &for_info, const std::string &intrin_name) + : dst_info_list_(dst_info_list), src_info_list_(src_info_list), for_info_(for_info), intrin_name_(intrin_name) { + InitArg(); + CalAxis(); +} + +void InsnArgsCalculator::CalAxis() { + CHECK(!dst_info_list_.empty()); + dst_info_ = dst_info_list_[0]; + if (src_info_list_.empty()) { + src_info_list_ = {dst_info_.Copy()}; + } + auto src_info = src_info_list_[0]; + dst_info_.Print(); + for (auto src_info_it : src_info_list_) { + src_info_it.Print(); + if (src_info_it->name_ == dst_info_->name_) { + meta_.same_dst_src = true; + } + } + meta_.dst_block_size = GetUbBlkSize(dst_info_->dtype_); + meta_.src_block_size = GetUbBlkSize(src_info->dtype_); + meta_.cast = meta_.dst_block_size != meta_.src_block_size; 
+ meta_.block_size = meta_.dst_block_size <= meta_.src_block_size ? meta_.dst_block_size : meta_.src_block_size; + meta_.src_dtype = src_info->dtype_; + meta_.dst_dtype = dst_info_->dtype_; + meta_.dtype = meta_.dst_dtype.bits() >= meta_.src_dtype.bits() ? meta_.dst_dtype : meta_.src_dtype; + auto elem_offset_mod = ir::ExprSimplifier().Simplify(Mod::make(dst_info_->elem_offset_, meta_.block_size)); + if (elem_offset_mod.as()) { + meta_.block_offset = elem_offset_mod.as()->value; + } + axis_list_ = GetAxisList(for_info_, GetInfoList(dst_info_, src_info_list_)); +} // namespace akg + +void InsnArgsCalculator::InitArg() { + arg_.src_m1_list = {0, 0}; + arg_.src_m0_list = {1, 1}; +} + +std::function InsnArgsCalculator::GetStrideLambda() { + return [&](const InsnAxis &axis) { + auto is_stride = [&](int stride) { return stride % meta_.block_size == 0; }; + auto zero_stride = [&](int stride) { return stride == 0; }; + return std::all_of(axis.stride_list.begin(), axis.stride_list.end(), is_stride) && + !std::all_of(axis.stride_list.begin(), axis.stride_list.end(), zero_stride); + }; +} + +std::function InsnArgsCalculator::GetM0LimitLambda() { + return [&](const InsnAxis &axis) { + auto is_limit = [&](int stride) { return stride / meta_.block_size < MAX_STRIDE_M0_SINGLE; }; + return std::all_of(axis.stride_list.begin(), axis.stride_list.end(), is_limit); + }; +} +std::function InsnArgsCalculator::GetBlockStrideLimitLambda() { + return [&](const InsnAxis &axis) { + auto is_limit = [&](int stride) { return stride / meta_.block_size <= max_block_stride_; }; + return std::all_of(axis.stride_list.begin(), axis.stride_list.end(), is_limit); + }; +} +std::function InsnArgsCalculator::GetM1LimitLambda() { + return [&](const InsnAxis &axis) { + auto is_limit = [&](int stride) { return stride / meta_.src_block_size < MAX_STRIDE_M1; }; + return axis.dst_stride / meta_.dst_block_size < MAX_STRIDE_M1 && + std::all_of(axis.src_stride_list.begin(), axis.src_stride_list.end(), is_limit); + 
}; +} + +std::function And(const std::list> &lambda_list) { + return [&lambda_list](const InsnAxis &axis) { + bool res = true; + for (auto lambda_it : lambda_list) { + res = res && lambda_it(axis); + } + return res; + }; +} + +AxisIt InsnArgsCalculator::GetAxisByLambda(const std::function &lambda) { + for (auto axis_it = axis_list_.begin(); axis_it != axis_list_.end(); axis_it++) { + if (lambda(*axis_it)) { + return axis_it; + } + } + return axis_list_.end(); +} + +InsnAxis InsnArgsCalculator::ExtractAxis(AxisIt &it) { + InsnAxis res = *it; + axis_list_.erase(it); + return res; +} + +bool InsnArgsCalculator::IsValid(AxisIt &it) { return it != axis_list_.end(); } + +void AxisSort(std::list &axis_arr, bool order = true) { + auto up_compare = [&](InsnAxis &a, InsnAxis &b) { return a.extent < b.extent; }; + auto down_compare = [&](InsnAxis &a, InsnAxis &b) { return a.extent > b.extent; }; + + if (order) { + axis_arr.sort(up_compare); + } else { + axis_arr.sort(down_compare); + } +} + +AxisIt InsnArgsCalculator::GetVecAxisIt() { + axis_list_.reverse(); + auto IsVecAxis = [&](const InsnAxis &axis) { + return !(std::any_of(axis.stride_list.begin(), axis.stride_list.end(), [](int stride) { return stride > 1; }) || + std::all_of(axis.stride_list.begin(), axis.stride_list.end(), [](int stride) { return stride == 0; })); + }; + return GetAxisByLambda(IsVecAxis); +} + +SplitStat InsnArgsCalculator::SplitAxis(int extent, InsnAxis &axis) { + if (axis.extent <= extent) { + return NO_SPLIT; + } + if (axis.extent % extent != 0) { + return TAIL; + } + InsnAxis new_axis; + new_axis.extent = axis.extent / extent; + for (auto stride : axis.stride_list) { + new_axis.stride_list.push_back(stride * extent); + } + auto temp_stride_list = new_axis.stride_list; + CHECK(!temp_stride_list.empty()); + new_axis.dst_stride = temp_stride_list.front(); + temp_stride_list.erase(temp_stride_list.begin()); + new_axis.src_stride_list = temp_stride_list; + new_axis.var = Var(axis.var->name_hint); + 
axis_list_.push_back(new_axis); + axis.extent = extent; + return SUCCESS; +} + +AxisIt InsnArgsCalculator::GetBlockAxis() { + AxisSort(axis_list_); + auto stride_lambda = GetStrideLambda(); + auto m0_limit_lambda = GetM0LimitLambda(); + auto block_stride_limit_lambda = GetBlockStrideLimitLambda(); + auto axis_it = + GetAxisByLambda(And({stride_lambda, m0_limit_lambda, block_stride_limit_lambda, [&](const InsnAxis &axis) { + return axis.extent >= FULL_BLOCK_NUM && axis.extent % FULL_BLOCK_NUM == 0; + }})); + if (IsValid(axis_it)) { + return axis_it; + } + axis_it = GetAxisByLambda(And({stride_lambda, m0_limit_lambda, block_stride_limit_lambda, + [&](const InsnAxis &axis) { return axis.extent >= FULL_BLOCK_NUM; }})); + if (IsValid(axis_it) && axis_list_.size() == 1) { + return axis_it; + } + axis_list_.reverse(); + axis_it = GetAxisByLambda(And({stride_lambda, m0_limit_lambda, block_stride_limit_lambda, + [&](const InsnAxis &axis) { return axis.extent < FULL_BLOCK_NUM; }})); + if (IsValid(axis_it)) { + return axis_it; + } + return GetAxisByLambda(And({stride_lambda, m0_limit_lambda, [&](const InsnAxis &axis) { + return axis.extent <= FULL_BLOCK_NUM || axis.extent % FULL_BLOCK_NUM == 0; + }})); +} + +AxisIt InsnArgsCalculator::GetRepeatAxisIt() { + AxisSort(axis_list_); + auto stride_lambda = GetStrideLambda(); + auto m1_limit_lambda = GetM1LimitLambda(); + auto axis_it = GetAxisByLambda( + And({stride_lambda, m1_limit_lambda, [&](const InsnAxis &axis) { return axis.extent >= MAX_REPEAT - 1; }})); + if (IsValid(axis_it)) { + return axis_it; + } + axis_list_.reverse(); + return GetAxisByLambda(And({stride_lambda, m1_limit_lambda})); +} + +void InsnArgsCalculator::SetArgMask(int len) { + SetArgBlockNum(1); + SetArgBlockLen(len); +} + +void InsnArgsCalculator::SetArgBlockNum(int data_num) { arg_.block_num = data_num; } +void InsnArgsCalculator::SetArgBlockLen(int data_len) { arg_.block_len = data_len; } + +void InsnArgsCalculator::SetArgM0(int dst_m0, int lsrc_m0, int 
rsrc_m0 = 0) { + arg_.dst_m0 = dst_m0; + arg_.src_m0_list = {lsrc_m0, rsrc_m0}; +} + +void InsnArgsCalculator::SetArgM1(int dst_m1, int lsrc_m1, int rsrc_m1 = 0) { + arg_.dst_m1 = dst_m1; + arg_.src_m1_list = {lsrc_m1, rsrc_m1}; +} + +void InsnArgsCalculator::SetArgRepeat(int repeat) { arg_.repeat = repeat; } + +void InsnArgsCalculator::BlockAxisReduction() { + Print(axis_list_); + auto block_axis_it = GetBlockAxis(); + if (IsValid(block_axis_it)) { + auto origin_block_axis = *block_axis_it; + InsnAxis block_axis = ExtractAxis(block_axis_it); + if (block_axis.extent % FULL_BLOCK_NUM != 0 && block_axis.extent > FULL_BLOCK_NUM) { + arg_.tail_len = block_axis.extent % FULL_BLOCK_NUM; + block_axis.extent = FloorTo(block_axis.extent, FULL_BLOCK_NUM); + arg_.dst_tail_offset = block_axis.dst_stride * block_axis.extent; + for (auto stride : block_axis.src_stride_list) { + arg_.src_tail_offset_list.push_back(stride * block_axis.extent); + } + SplitAxis(FULL_BLOCK_NUM, block_axis); + auto repeat_axis_it = GetRepeatAxisIt(); + if (!IsValid(repeat_axis_it) && axis_list_.size() > 0) { + for (auto it = axis_list_.begin(); it != axis_list_.end(); it++) { + if (it->var->name_hint == block_axis.var->name_hint) { + axis_list_.erase(it); + break; + } + } + axis_list_.push_back(origin_block_axis); + return; + } + } else { + SplitAxis(FULL_BLOCK_NUM, block_axis); + } + + block_axis.Print("BLOCK_AXIS"); + SetArgM0(block_axis.dst_stride / meta_.block_size, block_axis.src_stride_list.front() / meta_.block_size, + block_axis.src_stride_list.back() / meta_.block_size); + SetArgBlockNum(block_axis.extent); + } +} + +void InsnArgsCalculator::RepeatAxisReduction() { + Print(axis_list_); + auto repeat_axis = GetRepeatAxis(); + if (repeat_axis.IsValid()) { + repeat_axis.Print("REPEAT_AXIS"); + SetArgM1(repeat_axis.dst_stride / meta_.dst_block_size, repeat_axis.src_stride_list.front() / meta_.src_block_size, + repeat_axis.src_stride_list.back() / meta_.src_block_size); + 
SetArgRepeat(repeat_axis.extent); + } +} + +InsnAxis InsnArgsCalculator::GetInvalidAxis() { + InsnAxis res; + res.is_valid = false; + return res; +} + +InsnAxis InsnArgsCalculator::GetRepeatAxis() { + auto repeat_axis_it = GetRepeatAxisIt(); + if (IsValid(repeat_axis_it)) { + InsnAxis repeat_axis = ExtractAxis(repeat_axis_it); + SplitAxis(MAX_REPEAT - 1, repeat_axis); + return repeat_axis; + } + return GetInvalidAxis(); +} + +void InsnArgsCalculator::CastCaseReduction() { + if (axis_list_.empty()) { + return; + } + Print(axis_list_); + int cast_block_size = meta_.dst_block_size < meta_.src_block_size ? meta_.dst_block_size : meta_.src_block_size; + auto vec_axis_it = GetVecAxisIt(); + if (IsValid(vec_axis_it)) { + InsnAxis vec_axis = ExtractAxis(vec_axis_it); + int max_vec_len = cast_block_size * FULL_BLOCK_NUM; + if (vec_axis.extent > cast_block_size && vec_axis.extent < max_vec_len) { + SetArgMask(DivFloor(vec_axis.extent, cast_block_size) * cast_block_size); + SetArgM0(1, 1, 1); + } else if (vec_axis.extent >= max_vec_len) { + SplitAxis(max_vec_len, vec_axis); + SetArgMask(DivFloor(vec_axis.extent, cast_block_size) * cast_block_size); + SetArgM0(1, 1, 1); + } else { + SetArgBlockLen(cast_block_size); + } + } + RepeatAxisReduction(); +} + +int DivFloor(int a, int b) { + if (a % b == 0) { + return a / b; + } else { + return a / b + 1; + } +} + +void InsnArgsCalculator::InsnReduction() { + if (axis_list_.empty()) { + return; + } + Print(axis_list_); + auto vec_axis_it = GetVecAxisIt(); + meta_.scalar = !IsValid(vec_axis_it); + if (!meta_.scalar) { + InsnAxis vec_axis = ExtractAxis(vec_axis_it); + int max_vec_len = meta_.block_size * FULL_BLOCK_NUM; + if (vec_axis.extent > meta_.block_size && vec_axis.extent < max_vec_len && + (vec_axis.extent % meta_.block_size != 0 || vec_axis.extent > max_vec_len * meta_.vec_rate)) { + vec_axis.Print("VEC_BLOCK_AXIS"); + SetArgMask(DivFloor(vec_axis.extent, meta_.block_size) * meta_.block_size); + SetArgM0(1, 1, 1); + } else { + 
SplitAxis(meta_.block_size, vec_axis); + vec_axis.Print("VEC_AXIS"); + SetArgBlockLen(meta_.block_size); + BlockAxisReduction(); + } + RepeatAxisReduction(); + } else { + BlockAxisReduction(); + RepeatAxisReduction(); + } + Print(axis_list_); +} + +Expr InsnArgsCalculator::GetOffset(int stride_index) { + Expr res = Expr(0); + for (auto axis_it : axis_list_) { + auto stride = axis_it.stride_list[stride_index]; + auto mul_expr = Mul::make(stride, axis_it.var); + res = Add::make(mul_expr, res); + } + return Simplify(res); +} + +StmtInfo InsnArgsCalculator::ExportForInfo() { + if (for_info_.ops_.empty()) { + return for_info_; + } + int last_index = for_info_.ops_.size() - 1; + auto last_for = for_info_.ops_[last_index].as(); + auto store_stmt = last_for->body; + Stmt for_stmt = store_stmt; + StmtInfo result; + for (auto axis_it : axis_list_) { + for_stmt = For::make(axis_it.var, axis_it.min, axis_it.extent, last_for->for_type, last_for->device_api, for_stmt); + result.ops_.push_back(for_stmt); + result.vars_.push_back(axis_it.var); + } + return result; +} + +PatternResult InsnArgsCalculator::ExportResult() { + PatternResult res; + auto arg_info = ArgInfo(make_node()); + auto body_args = VectorArgInfo(make_node()); + body_args.GetNode()->body_num_ = arg_.body_num; + body_args.GetNode()->body_offset_ = meta_.block_size * FULL_BLOCK_NUM; + body_args.GetNode()->repeat_ = Expr(arg_.repeat); + body_args.GetNode()->dst_stride_m0_ = Expr(arg_.dst_m0); + body_args.GetNode()->dst_stride_m1_ = Expr(arg_.dst_m1); + body_args.GetNode()->src_stride_m0_list_ = arg_.src_m0_list; + body_args.GetNode()->src_stride_m1_list_ = arg_.src_m1_list; + body_args.GetNode()->vec_mask_ = GetVecMask(arg_.block_len, arg_.block_num, meta_.dtype, meta_.block_offset); + body_args.GetNode()->block_offset_ = make_const(Int(32), meta_.block_offset); + arg_info.GetNode()->body_arg_info_ = body_args; + if (arg_.tail_len > 0) { + auto tail_args = VectorArgInfo(make_node()); + tail_args.GetNode()->dst_head_ = 
Expr(arg_.dst_tail_offset); + tail_args.GetNode()->dst_stride_m1_ = Expr(arg_.dst_m1); + tail_args.GetNode()->src_stride_m1_list_ = arg_.src_m1_list; + tail_args.GetNode()->repeat_ = Expr(1); + tail_args.GetNode()->src_head_list_ = arg_.src_tail_offset_list; + tail_args.GetNode()->body_offset_ = meta_.block_size * FULL_BLOCK_NUM; + tail_args.GetNode()->dst_stride_m0_ = Expr(arg_.dst_m0); + tail_args.GetNode()->src_stride_m0_list_ = arg_.src_m0_list; + tail_args.GetNode()->vec_mask_ = GetVecMask(arg_.block_len, arg_.tail_len, meta_.dtype, meta_.block_offset); + tail_args.GetNode()->block_offset_ = make_const(Int(32), meta_.block_offset); + arg_info.GetNode()->tail_arg_info_ = tail_args; + } + StmtInfoList info_list = GetInfoList(dst_info_, src_info_list_); + CleanZeroStrides(info_list); + for (size_t i = 0; i < info_list.size(); i++) { + info_list[i].GetNode()->insn_offset_ = GetOffset(i); + } + info_list[1].Print(); + res.for_info = ExportForInfo(); + res.arg_info = arg_info; + res.dst_info_list = {info_list[0]}; + if (info_list.size() > 2) { + res.src_info_list = {info_list[1], info_list[2]}; + } else { + res.src_info_list = {info_list[1]}; + } + body_args.Print(); + if (arg_info->tail_arg_info_.defined()) { + arg_info->tail_arg_info_.Print(); + } + return res; +} + +SingleVecInsnArgsCalculator::SingleVecInsnArgsCalculator(const StmtInfoList &dst_info_list, + const StmtInfoList &src_info_list, const StmtInfo &for_info, + const std::string &intrin_name) + : InsnArgsCalculator(dst_info_list, src_info_list, for_info, intrin_name) {} + +PatternResult SingleVecInsnArgsCalculator::GetInsnArgs() { + if (meta_.cast) { + CastCaseReduction(); + } else { + InsnReduction(); + } + return ExportResult(); +} + +BinaryVecInsnArgsCalculator::BinaryVecInsnArgsCalculator(const StmtInfoList &dst_info_list, + const StmtInfoList &src_info_list, const StmtInfo &for_info, + const std::string &mode, const std::string &intrin_name, + bool expand_mask) + : InsnArgsCalculator(dst_info_list, 
src_info_list, for_info, intrin_name), mode_{mode}, expand_mask_{expand_mask} { + if (mode_ == "reduction" && src_info_list_.size() == 2 && src_info_list_[0]->name_ == dst_info_list[0]->name_) { + auto temp = src_info_list_[0].Copy(); + src_info_list_.Set(0, src_info_list_[1].Copy()); + src_info_list_.Set(1, temp); + CalAxis(); + } +} + +PatternResult BinaryVecInsnArgsCalculator::GetInsnArgs() { + LOG(DEBUG) << "Binary vec Insn reduction"; + InsnReduction(); + return ExportResult(); +} + +std::function BinaryVecInsnArgsCalculator::GetM0LimitLambda() { + return [&](const InsnAxis &axis) { + auto is_limit = [&](int stride) { return stride / meta_.block_size < MAX_STRIDE_M0; }; + return std::all_of(axis.stride_list.begin(), axis.stride_list.end(), is_limit) && axis.dst_stride != 0; + }; +} +std::function BinaryVecInsnArgsCalculator::GetM1LimitLambda() { + return [&](const InsnAxis &axis) { + auto is_limit = [&](int stride) { return stride / meta_.src_block_size < MAX_STRIDE_M1; }; + return axis.dst_stride / meta_.dst_block_size < MAX_STRIDE_M1 && + std::all_of(axis.src_stride_list.begin(), axis.src_stride_list.end(), is_limit); + }; +} + +void BinaryVecInsnArgsCalculator::InsnReduction() { + if (axis_list_.empty()) { + return; + } + Print(axis_list_); + + auto vec_axis_it = GetVecAxisIt(); + meta_.scalar = !IsValid(vec_axis_it); + if (!meta_.scalar) { + vec_axis_ = *vec_axis_it; + InsnAxis vec_axis = ExtractAxis(vec_axis_it); + auto bad_axis_lambda = [&](const InsnAxis &axis) { + int min_stride = vec_axis_it->extent; + auto dst_name = dst_info_list_[0]->name_; + if (meta_.same_dst_src && axis.dst_stride < min_stride && axis.dst_stride != 0) { + return true; + } + return false; + }; + auto bad_axis_it = GetAxisByLambda(bad_axis_lambda); + InsnAxis bad_axis; + bad_axis.is_valid = false; + if (IsValid(bad_axis_it)) { + bad_axis = ExtractAxis(bad_axis_it); + } + int max_vec_len = meta_.block_size * FULL_BLOCK_NUM; + if (vec_axis.extent > meta_.block_size && 
vec_axis.extent < max_vec_len && + (vec_axis.extent % meta_.block_size != 0 || vec_axis.extent > max_vec_len * meta_.vec_rate)) { + vec_axis.Print("VEC_BLOCK_AXIS"); + if (expand_mask_) { + SetArgMask(DivFloor(vec_axis.extent, meta_.block_size) * meta_.block_size); + } else { + SetArgMask(vec_axis.extent); + } + SetArgM0(1, 1, 1); + } else { + SplitAxis(meta_.block_size, vec_axis); + vec_axis.Print("VEC_AXIS"); + if (expand_mask_ && mode_ != "reduction") { + SetArgBlockLen(meta_.block_size); + } else { + SetArgBlockLen(vec_axis.extent); + } + BlockAxisReduction(); + } + RepeatAxisReduction(); + if (bad_axis.IsValid()) { + axis_list_.push_back(bad_axis); + } + } else { + BlockAxisReduction(); + RepeatAxisReduction(); + } + Print(axis_list_); +} + +PatternResult LastAxisReduceInsnArgsCalculator::GetInsnArgs() { + CalcParams(); + Array elim_var; + elim_var = GetPattern(); + arg_info.GetNode()->pattern_ = PATTERN_1D; + return GenResult(elim_var); +} + +Array LastAxisReduceInsnArgsCalculator::GetPattern() { + int body_len = params.last_dim_shape / params.vec_max_len * params.vec_max_len; + int tail_len = params.last_dim_shape % params.vec_max_len; + int cmd_body_len = 0; + bool is_vadd = intrin_name == "vadd"; + int repeat_stride = FULL_BLOCK_NUM; + if (is_vadd) { + repeat_stride = 1; + } + const int fp16_block_size = 16; + + if (body_len > 0) { + body_args = VectorArgInfo(make_node()); + body_args.GetNode()->body_num_ = 1; + body_args.GetNode()->body_offset_ = params.vec_max_len; + body_args.GetNode()->repeat_ = Expr(body_len / params.vec_max_len); + // Here use dst_stride_m1 as dst_stride + body_args.GetNode()->dst_stride_m1_ = Expr(1); + body_args.GetNode()->src_stride_m0_list_ = {Expr(1)}; + body_args.GetNode()->src_stride_m1_list_ = {Expr(FULL_BLOCK_NUM)}; + body_args.GetNode()->vec_mask_ = GetVecMask(params.vec_max_len, 1, dst_info->dtype_); + cmd_body_len += GetInt32Const(body_args->repeat_) * repeat_stride; + } + if (tail_len > 0) { + tail_args = 
VectorArgInfo(make_node()); + tail_args.GetNode()->body_offset_ = params.vec_max_len; + tail_args.GetNode()->dst_head_ = Expr(cmd_body_len); + tail_args.GetNode()->src_head_list_ = {Expr(body_len)}; + tail_args.GetNode()->repeat_ = Expr(1); + tail_args.GetNode()->dst_stride_m1_ = Expr(1); + tail_args.GetNode()->src_stride_m0_list_ = {Expr(1)}; + tail_args.GetNode()->src_stride_m1_list_ = {Expr(0)}; + tail_args.GetNode()->vec_mask_ = GetVecMask(tail_len, 1, dst_info->dtype_); + if (is_vadd) { + cmd_body_len += 1; + } else { + cmd_body_len += tail_len / fp16_block_size; + if (tail_len % fp16_block_size != 0) { + cmd_body_len += 1; + } + } + } + // cmd_body_len > 1 means vcadd size greater than 128, need to use vcadd again to compute final result + // if cmd_body_len > 128, then need to recursively emit vcadd + while (cmd_body_len > 1) { + int cmd_tail_len = cmd_body_len % params.vec_max_len; + cmd_body_len = cmd_body_len / params.vec_max_len; + if (cmd_body_len > 0) { + VectorArgInfo mix_vec_args = VectorArgInfo(make_node()); + mix_vec_args.GetNode()->repeat_ = Expr(cmd_body_len); + mix_vec_args.GetNode()->dst_head_ = Expr(0); + mix_vec_args.GetNode()->src_head_list_ = {Expr(0)}; + mix_vec_args.GetNode()->dst_stride_m1_ = Expr(1); + mix_vec_args.GetNode()->src_stride_m0_list_ = {Expr(1)}; + mix_vec_args.GetNode()->src_stride_m1_list_ = {Expr(FULL_BLOCK_NUM)}; + mix_vec_args.GetNode()->vec_mask_ = GetVecMask(params.vec_max_len, 1, dst_info->dtype_); + mix_vec_arg_list.push_back(mix_vec_args); + if (!is_vadd) { + cmd_body_len *= FULL_BLOCK_NUM; + } + } + if (cmd_tail_len > 0) { + VectorArgInfo mix_vec_args = VectorArgInfo(make_node()); + mix_vec_args.GetNode()->repeat_ = Expr(1); + mix_vec_args.GetNode()->dst_head_ = Expr(cmd_body_len); + if (is_vadd) { + mix_vec_args.GetNode()->src_head_list_ = {Expr(cmd_body_len * params.vec_max_len)}; + } else { + mix_vec_args.GetNode()->src_head_list_ = {Expr(cmd_body_len / FULL_BLOCK_NUM * params.vec_max_len)}; + } + 
mix_vec_args.GetNode()->dst_stride_m1_ = Expr(1); + mix_vec_args.GetNode()->src_stride_m0_list_ = {Expr(1)}; + mix_vec_args.GetNode()->src_stride_m1_list_ = {Expr(FULL_BLOCK_NUM)}; + mix_vec_args.GetNode()->vec_mask_ = GetVecMask(cmd_tail_len, 1, dst_info->dtype_); + if (is_vadd) { + cmd_body_len += 1; + } else { + cmd_body_len += cmd_tail_len / fp16_block_size; + if (cmd_tail_len % fp16_block_size != 0) { + cmd_body_len += 1; + } + } + mix_vec_arg_list.push_back(mix_vec_args); + } + } + + params.insn_offset_scale_factor = Expr(params.block_size); + int max_num = body_len / params.vec_max_len; + if (intrin_name == "vmax" || intrin_name == "vmin") { + max_num *= FULL_BLOCK_NUM; + } + if (max_num >= params.block_size) { + params.insn_offset_scale_factor = max_num + params.block_size - 1; + if (tail_len > 0) { + params.insn_offset_scale_factor += 1; + } + params.insn_offset_scale_factor = truncdiv(params.insn_offset_scale_factor, params.block_size) * params.block_size; + } + + if (!params.src_var.empty()) { + return GetRange(params.src_var, -1, 1); + } + + return {}; +} + +PatternResult LastAxisReduceInsnArgsCalculator::GenResult(const Array &elim_var) { + dst_info.GetNode()->insn_offset_ = GetInsnOffset(dst_info, elim_var) * params.insn_offset_scale_factor; + src_info.GetNode()->insn_offset_ = GetInsnOffset(src_info, elim_var); + + if (body_args.defined()) { + body_args.GetNode()->insn_offset_scale_factor_ = params.insn_offset_scale_factor; + } + if (tail_args.defined()) { + tail_args.GetNode()->insn_offset_scale_factor_ = params.insn_offset_scale_factor; + } + for (auto &arg : mix_vec_arg_list) { + arg.GetNode()->insn_offset_scale_factor_ = params.insn_offset_scale_factor; + } + + arg_info.GetNode()->body_arg_info_ = body_args; + arg_info.GetNode()->tail_arg_info_ = tail_args; + arg_info.GetNode()->reduction_tail_args_ = mix_vec_arg_list; + + CleanForInfoVars(for_info, elim_var); + arg_info.GetNode()->arg_type_ = ARG_VECTOR_REDUCTION_LAST_AXIS; + + PatternResult 
result; + result.dst_info_list = {dst_info}; + result.src_info_list = {src_info}; + result.for_info = for_info; + result.arg_info = arg_info; + + return result; +} + +void LastAxisReduceInsnArgsCalculator::CalcParams() { + // check shape len + if (dst_info->shape_.empty() || src_info->shape_.empty()) { + LOG(FATAL) << "CCE Vector Insn Error: dst_buffer and src_buffer can not be scalar, should keep len(shape) > 0."; + } + // check data type + if (dst_info->dtype_ != src_info->dtype_) { + LOG(FATAL) << "CCE Vector Insn Error: dst_buffer and src_buffer can not be different data type."; + } + + params.src_var = src_info->var_; + params.block_size = GetUbBlkSize(dst_info->dtype_); + params.last_dim_shape = GetInt32Const(GetItem(src_info->shape_, -1)); + params.vec_max_len = GetVecMaxLen(dst_info->dtype_); + CHECK_NE(params.block_size, 0); + CHECK_NE(params.vec_max_len, 0); +} +/// Generete info list for bisection intrin +/// \param dst_info_list +/// \param src_info_list +/// \param for_info +/// \param if_info +/// \param last_axis +/// \param postfix +/// \return +BisectionInfoWrapper SeparateComInfoToBisectionInfoList(const StmtInfoList &dst_info_list, + const StmtInfoList &src_info_list, const StmtInfo &for_info, + StmtInfo &if_info, bool last_axis, int postfix = 0) { + CHECK_EQ(dst_info_list.size(), 1); + CHECK_EQ(src_info_list.size(), 2); + BisectionInfoWrapper wrapper; + // Separate com_info and for_info + int compare_idx = 1; + int var_idx = -1; + var_idx = GetBisectionReductionIdx(dst_info_list, src_info_list, compare_idx); + StmtStoreInfo dst_info = dst_info_list[0]; + CHECK_GE(compare_idx, 0); + StmtStoreInfo src_info1 = src_info_list[compare_idx]; + + Var reduce_var = GetItem(src_info1->var_, var_idx); + int stride_len = GetInt32Const(GetItem(src_info1->strides_, var_idx)); + size_t for_idx = 0; + bool suc = GetIndexOfElement(for_info.vars_, VarExpr(reduce_var), for_idx); + CHECK(suc); + auto exist_for = GetItem(for_info.ops_, for_idx).as(); + 
CHECK(exist_for); + int extent = GetInt32Const(exist_for->extent); + + int simd_len = 1; + const std::string un_def_var = "un_def_var"; + Var simd_var = Var("un_def_var"); + CHECK_GT(src_info1->strides_.size(), 0); + CHECK_EQ(src_info1->var_.size(), src_info1->strides_.size()); + for (size_t i = 0; i <= src_info1->strides_.size() - 1; i++) { + if (GetInt32Const(src_info1->strides_[i]) == 1) { + simd_var = src_info1->var_[i]; + size_t simd_for_idx = 0; + bool suc = GetIndexOfElement(for_info.vars_, VarExpr(simd_var), simd_for_idx); + CHECK(suc); + auto simd_for = GetItem(for_info.ops_, simd_for_idx).as(); + CHECK(simd_for); + simd_len = GetInt32Const(simd_for->extent); + } + } + + int block_unit = GetUbBlkSize(src_info1->dtype_); + int last_dim_len = ((simd_len - 1) / block_unit + 1) * block_unit; + + Var bisec_var; + Buffer bisec_buffer; + std::string bisec_pre_header = "bisec"; + std::string bisec_name = bisec_pre_header + "_local_UB"; + if (postfix > 0) { + bisec_name = bisec_name + "_" + std::to_string(postfix); + } + + int vec_max_len = GetVecMaxLen(dst_info->dtype_); + CHECK_NE(vec_max_len, 0); + std::vector pow2_list = {0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536}; + int origin_len = extent; + for (int i : pow2_list) { + if (extent <= i) { + extent = i / 2; + break; + } + } + int prolog_len = origin_len - extent; + + src_info1.Print(); + auto src_vars = src_info1->var_; + auto src_strides = src_info1->strides_; + auto src_dims = src_info1->shape_; + auto new_vars = src_info1->var_; + auto new_strides = src_info1->strides_; + auto new_dims = src_info1->shape_; + LOG(DEBUG) << "\nvar_idx:" << var_idx << "\n"; + var_idx = var_idx + src_info1->var_.size(); + if (var_idx != static_cast(src_info1->var_.size()) - 1) { + new_dims.Set(var_idx, extent); + new_dims.Set(new_dims.size() - 1, last_dim_len); + CHECK_GT(new_dims.size(), 1); + new_strides.Set(new_strides.size() - 1, 1); + for (int i = static_cast(new_dims.size()) - 2; 
i >= 0; i--) { + new_strides.Set(i, new_strides[i + 1] * new_dims[i + 1]); + } + new_dims.Set(new_dims.size() - 1, simd_len); + } else { + new_dims = {extent}; + } + + // copy data from origin buffer to new temp buffer + Array shape = new_dims; + wrapper.original_shape_ = new_dims; + bisec_var = Var(bisec_name, Handle()); + bisec_buffer = BufferNode::make(bisec_var, dst_info->dtype_, shape, Array(), Expr(), bisec_name, SCOPE_UBUF, 0, + 0, BufferType::kDefault); + // Need to copy input to bisect buffer + StmtStoreInfo copy_dst_info{src_info1.Copy()}; + StmtStoreInfo copy_src_info{src_info1.Copy()}; + StmtInfoList src_list = {copy_src_info}; + + auto for_tmp_info = for_info.Copy(); + auto new_for = GetItem(for_tmp_info.ops_, for_idx).as(); + CHECK(new_for); + SetItem(for_tmp_info.ops_, static_cast(for_idx), + For::make(new_for->loop_var, new_for->min, extent, new_for->for_type, new_for->device_api, new_for->body)); + ReplaceVarWithNewForInfo(copy_dst_info, for_info, for_tmp_info); + ReplaceVarWithNewForInfo(copy_src_info, for_info, for_tmp_info); + SetItem(copy_src_info.GetNode()->shape_, var_idx, Expr(extent)); + SetItem(copy_dst_info.GetNode()->shape_, var_idx, Expr(extent)); + SetItem(copy_dst_info.GetNode()->strides_, var_idx, Expr(last_dim_len)); + if (simd_var->name_hint != un_def_var) { + copy_dst_info.GetNode()->index_ = 0; + for (size_t i = 0; i <= new_vars.size() - 1; i++) { + copy_dst_info.GetNode()->index_ += new_vars[i] * new_strides[i]; + } + } else { + copy_dst_info.GetNode()->index_ = last_dim_len * reduce_var; + } + copy_dst_info.GetNode()->elem_offset_ = 0; + copy_dst_info.GetNode()->name_ = bisec_name; + copy_dst_info.GetNode()->buffer_ = bisec_buffer; + copy_dst_info.GetNode()->data_ = bisec_var; + copy_dst_info.GetNode()->strides_ = new_strides; + + CompactComputationInfoList(copy_dst_info, src_list, if_info, for_tmp_info); + wrapper.bisec_info_list_.emplace_back(StmtInfoList{copy_dst_info, copy_src_info}); + 
wrapper.for_info_list_.push_back(for_tmp_info); + + // Generate the vadd wrapper + while (extent >= 0) { + StmtStoreInfo dst_tmp_info = dst_info.Copy(); + StmtStoreInfo src_tmp_info0{src_info1.Copy()}; + StmtStoreInfo src_tmp_info1{src_info1.Copy()}; + auto for_tmp_info = for_info.Copy(); + int vadd_length = (prolog_len != 0) ? prolog_len : extent; + + if (extent > 0) { + dst_tmp_info = src_info1.Copy(); + dst_tmp_info.GetNode()->data_alignment_ = simd_len; + dst_tmp_info.GetNode()->name_ = bisec_name; + dst_tmp_info.GetNode()->buffer_ = bisec_buffer; + dst_tmp_info.GetNode()->data_ = bisec_var; + dst_tmp_info.GetNode()->shape_ = new_dims; + SetItem(dst_tmp_info.GetNode()->shape_, var_idx, Expr(vadd_length)); + dst_tmp_info.GetNode()->strides_ = new_strides; + dst_tmp_info.GetNode()->var_ = new_vars; + dst_tmp_info.GetNode()->index_ = 0; + for (size_t i = 0; i <= new_vars.size() - 1; i++) { + dst_tmp_info.GetNode()->index_ += new_vars[i] * new_strides[i]; + } + if (prolog_len == 0) { + src_tmp_info1 = dst_tmp_info.Copy(); + src_tmp_info1.GetNode()->index_ = dst_tmp_info.GetNode()->index_ + extent * last_dim_len; + } else { + SetItem(src_tmp_info1.GetNode()->shape_, var_idx, Expr(vadd_length)); + src_tmp_info1.GetNode()->index_ += extent * stride_len; + } + } + + src_tmp_info0 = dst_tmp_info.Copy(); + auto new_for = GetItem(for_tmp_info.ops_, for_idx).as(); + CHECK(new_for); + int temp_for_len = (vadd_length != 0) ? 
vadd_length : 1; + SetItem( + for_tmp_info.ops_, static_cast(for_idx), + For::make(new_for->loop_var, new_for->min, temp_for_len, new_for->for_type, new_for->device_api, new_for->body)); + + if (extent == 0) { + src_tmp_info1.GetNode()->name_ = bisec_name; + src_tmp_info1.GetNode()->buffer_ = bisec_buffer; + src_tmp_info1.GetNode()->data_ = bisec_var; + if (simd_var->name_hint != un_def_var) { + src_tmp_info1.GetNode()->shape_ = RemoveItemAtIndex(new_dims, var_idx); + src_tmp_info1.GetNode()->strides_ = RemoveItemAtIndex(new_strides, var_idx); + src_tmp_info1.GetNode()->var_ = RemoveItemAtIndex(new_vars, var_idx); + src_tmp_info1.GetNode()->index_ = 0; + for (size_t i = 0; i <= src_tmp_info1->var_.size() - 1; i++) { + src_tmp_info1.GetNode()->index_ += src_tmp_info1->var_[i] * src_tmp_info1->strides_[i]; + } + } else { + src_tmp_info1.GetNode()->shape_ = dst_tmp_info->shape_; + src_tmp_info1.GetNode()->strides_ = dst_tmp_info->strides_; + src_tmp_info1.GetNode()->var_ = dst_tmp_info->var_; + src_tmp_info1.GetNode()->index_ = dst_tmp_info->index_; + } + } + + ReplaceVarWithNewForInfo(dst_tmp_info, for_info, for_tmp_info); + ReplaceVarWithNewForInfo(src_tmp_info0, for_info, for_tmp_info); + ReplaceVarWithNewForInfo(src_tmp_info1, for_info, for_tmp_info); + StmtInfoList src_list = {src_tmp_info0, src_tmp_info1}; + CompactComputationInfoList(dst_tmp_info, src_list, if_info, for_tmp_info); + wrapper.for_info_list_.emplace_back(for_tmp_info); + if (extent == 0) { + // normally is bisect_tmp = bisect_tmp + bisect_tmp/src_tmp + wrapper.bisec_info_list_.emplace_back(StmtInfoList{dst_tmp_info, dst_tmp_info, src_tmp_info1}); + } else { + // normally is dst_tmp = dst_tmp + bisect_tmp + wrapper.bisec_info_list_.emplace_back(StmtInfoList{dst_tmp_info, src_tmp_info0, src_tmp_info1}); + } + + if (extent == 0) { + break; + } else { + extent = extent / 2; + } + prolog_len = 0; + } + // Generate arg_info + for (size_t i = 0; i < wrapper.bisec_info_list_.size(); ++i) { + auto 
info_list = wrapper.bisec_info_list_[i];
+    auto new_for_info = wrapper.for_info_list_[i];
+
+    ArgInfo arg_info;
+    auto dst_list = GetRange(info_list, 0, 1);
+    auto src_list = GetRange(info_list, 1, info_list.size() - 1);
+    if (info_list.size() == 2) {
+      std::string dma_intrin = INTRIN_NAME_COPY_UB_TO_UB;
+      wrapper.dma_arg_info_map_ = GetDmaCopyInsnArgs(dma_intrin, dst_list, src_list, new_for_info);
+    } else {
+      // Bisect can't expand mask because it has inplace operation
+      if (i != wrapper.bisec_info_list_.size() - 1) {
+        // Last round dont need to add
+        FillLastDim(dst_list, src_list, new_for_info);
+      }
+      std::string mode = GetBinaryVecMode(dst_list, src_list, "vadd", false);
+
+      BinaryVecInsnArgsCalculator args_calculator =
+        BinaryVecInsnArgsCalculator(dst_list, src_list, new_for_info, mode, "", false);
+      PatternResult params = args_calculator.GetInsnArgs();
+
+      arg_info = params.arg_info;
+      dst_list = params.dst_info_list;
+      src_list = params.src_info_list;
+      new_for_info = params.for_info;
+      wrapper.bisec_info_list_[i] = {dst_list[0], src_list[0], src_list[1]};
+    }
+    wrapper.arg_info_list_.push_back(arg_info);
+    wrapper.for_info_list_[i] = new_for_info;
+  }
+  return wrapper;
+}
+
+/// Get CCE Binary Vector Insn Computation Info
+///
+/// Steps, in order:
+///  1. Reject any intrin_name that is not in the supported list (LOG(FATAL)).
+///  2. Hand stmt and the four info containers to GetCompactComputationInfo, then require exactly
+///     one dst and at least two srcs; only the first two srcs are kept.
+///  3. Ask GetBinaryVecMode for the mode and branch on it:
+///     - "reduce_last_axis": the larger src var count must exceed the dst var count by exactly one;
+///       only arg_type_ is set.
+///     - "reduce_bisection": only arg_type_ is set.
+///     - anything else: FillLastDim is applied unless the mode is "reduction" or "broadcast", then
+///       BinaryVecInsnArgsCalculator computes the args and its results overwrite dst_info_list,
+///       src_info_list and for_info.
+///  4. For "broadcast" with a last-axis source and intrin "vadd"/"vmul", the body args are switched to
+///     the "<intrin>s" name and given the matching load as src_op_.
+///
+/// \param stmt - operand stmt
+/// \param intrin_name - vector intrin name
+/// \param dst_info_list - dst computation info list (in/out)
+/// \param src_info_list - src computation info list (in/out, truncated to two entries)
+/// \param if_info - if info list
+/// \param for_info - for info list (in/out)
+/// \param enable_bisect - forwarded unchanged to GetBinaryVecMode
+/// \return intrin args
+ArgInfo GetBinaryVecInsnArgs(const Stmt &stmt, std::string intrin_name, StmtInfoList &dst_info_list,
+                             StmtInfoList &src_info_list, StmtInfo &if_info, StmtInfo &for_info, bool enable_bisect) {
+  // check intrin_name
+  std::set intrin_name_list = {"vadd", "vmax", "vmin", "vmul", "vdiv", "vsel", "vsub", "vand",
+                                            "vor", "vaxpy", "argmax", "argmin", "vmadd", "vmaddrelu", "vmla"};
+  if (intrin_name_list.count(intrin_name) == 0) {
+    LOG(FATAL) << "Error: CCE Binary Vector Insn doesn't support the given intrin_name.";
+  }
+
+  // get and check dst and src
+  GetCompactComputationInfo(stmt, dst_info_list, src_info_list, if_info, for_info, true);
+  // For vmadd/vmaddrelu/vmla we only need first two src
+  if (dst_info_list.size() != 1 || src_info_list.size() < 2) {
+    LOG(FATAL) << "CCE Binary Vector Insn only support ONE dst and TWO srcs.";
+  }
+  src_info_list = GetRange(src_info_list, 0, 2);
+  ArgInfo arg_info = ArgInfo(make_node());
+
+  // detect vector op mode
+  std::string mode = GetBinaryVecMode(dst_info_list, src_info_list, intrin_name, enable_bisect);
+  if (mode == "reduce_last_axis") {
+    // Use the larger of the two src var counts.
+    size_t src_var_list_size = src_info_list[1]->var_.size();
+    if (src_info_list[0]->var_.size() > src_info_list[1]->var_.size()) {
+      src_var_list_size = src_info_list[0]->var_.size();
+    }
+
+    CHECK(src_var_list_size > 0) << "Error: src can not be a scalar.";
+    if (src_var_list_size - dst_info_list[0]->var_.size() == 1) {
+      arg_info.GetNode()->arg_type_ = ARG_VECTOR_REDUCTION_LAST_AXIS;
+    } else {
+      LOG(FATAL) << "Error: cannot support multi-last-axis reduction.";
+    }
+  } else if (mode == "reduce_bisection") {
+    arg_info.GetNode()->arg_type_ = ARG_VECTOR_REDUCTION_BISECTION;
+  } else {
+    if (mode != "reduction" && mode != "broadcast") {
+      FillLastDim(dst_info_list, src_info_list, for_info);
+    }
+
+    // vmax/vmin can't expand mask because it may introduce dirty data
+    bool can_expand_mask = intrin_name != "vmax" && intrin_name != "vmin";
+
+    BinaryVecInsnArgsCalculator args_calculator =
+      BinaryVecInsnArgsCalculator(dst_info_list, src_info_list, for_info, mode, intrin_name, can_expand_mask);
+    PatternResult params = args_calculator.GetInsnArgs();
+
+    // The calculator may have rewritten the operands and loops; publish its versions to the caller.
+    arg_info = params.arg_info;
+    dst_info_list = params.dst_info_list;
+    src_info_list = params.src_info_list;
+    for_info = params.for_info;
+    if (mode == "broadcast") {
+      bool has_last_axis = false;
+      if ((arg_info->body_arg_info_.defined() && arg_info->body_arg_info_->last_axis_info_.src_index_ != -1) ||
+          (arg_info->tail_arg_info_.defined() && arg_info->tail_arg_info_->last_axis_info_.src_index_ != -1)) {
+        has_last_axis = true;
+      }
+
+      if (has_last_axis && (intrin_name == "vadd" || intrin_name == "vmul")) {
+        Array stores;
+        Array loads;
+        GetStoreAndLoads(stmt, stores, loads);
+        intrin_name = intrin_name + "s";
+        // has_last_axis can be true because of the tail args alone, in which case the body's
+        // src_index_ is still -1 (the value the test above treats as "no last-axis source").
+        // Indexing loads with it would read out of range, so only rewrite a body that names a load.
+        // NOTE(review): the tail args are not rewritten here, same as before - confirm that is intended.
+        if (arg_info->body_arg_info_.defined() && arg_info->body_arg_info_->last_axis_info_.src_index_ != -1) {
+          const auto src_index = arg_info->body_arg_info_->last_axis_info_.src_index_;
+          CHECK(src_index >= 0 && static_cast<size_t>(src_index) < loads.size())
+            << "Error: last axis src_index " << src_index << " is out of range of the loads in stmt.";
+          arg_info.GetNode()->body_arg_info_.GetNode()->last_axis_info_.intrin_name_ = intrin_name;
+          arg_info.GetNode()->body_arg_info_.GetNode()->last_axis_info_.src_op_ = Downcast(loads[src_index]);
+        }
+      }
+    }
+  }
+  return arg_info;
+}
+}  // namespace akg
\ No newline at end of file
diff --git a/src/emit_insn/insn_args_calculator.h b/src/emit_insn/insn_args_calculator.h
new file mode 100644
index 0000000000000000000000000000000000000000..a4a25f9b80d9b2dfd09d79f0ae5c1f4bcb2235a7
--- /dev/null
+++ b/src/emit_insn/insn_args_calculator.h
@@ -0,0 +1,199 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef EMIT_INSN_ARGS_CALCULATOR_H_
+#define EMIT_INSN_ARGS_CALCULATOR_H_
+
+// Standard headers for the std:: types used below. The project types (Array, Expr, Var, For, Type,
+// StmtInfo, StmtInfoList, StmtStoreInfo, ArgInfo, VectorArgInfo, PatternResult, BisectionInfoWrapper)
+// are not included here and must already be declared by the including file.
+#include <functional>
+#include <list>
+#include <string>
+#include <vector>
+
+namespace akg {
+/// Integer arguments gathered for one instruction. Every scalar field starts at the default shown.
+/// NOTE(review): m0/m1 presumably are the two stride levels written by SetArgM0/SetArgM1 - confirm.
+struct InsnArg {
+  int dst_m0{1};
+  int dst_m1{0};
+  std::vector src_m0_list;
+  std::vector src_m1_list;
+  int repeat{1};
+  int block_len{1};
+  int block_num{1};
+  int body_num{1};
+  int tail_len{0};
+  int dst_tail_offset{0};
+  std::vector src_tail_offset_list;
+};
+
+/// Bookkeeping about the operands of the instruction being built.
+/// InsnArgsCalculator::CalAxis fills dst_block_size / src_block_size from GetUbBlkSize of the dst and
+/// first src dtype, sets cast when the two differ, and sets same_dst_src when a src shares the dst name.
+struct Meta {
+  int block_size{0};
+  int src_block_size{0};
+  int dst_block_size{0};
+  int block_offset{0};
+  const float vec_rate{0.6};  // const member: Meta objects cannot be copy-assigned
+  Type src_dtype;
+  Type dst_dtype;
+  Type dtype;
+  bool cast{false};
+  bool tail{false};
+  bool scalar{false};
+  bool liner{false};
+  bool same_dst_src{false};
+};
+
+/// Result codes returned by InsnArgsCalculator::SplitAxis.
+enum SplitStat { SUCCESS, NO_SPLIT, TAIL };
+
+/// One loop axis together with the stride each operand has along it.
+class InsnAxis {
+ public:
+  InsnAxis() = default;
+  /// Reads var / extent / min from for_stmt and, for every entry of info_list, the stride that entry
+  /// has for this loop var (0 when the entry does not use the var). The first entry's stride becomes
+  /// dst_stride, the remaining ones go to src_stride_list, and all of them go to stride_list.
+  InsnAxis(const For *for_stmt, const Array &info_list);
+  virtual ~InsnAxis() = default;
+  /// \return the is_valid flag
+  bool IsValid();
+  /// Logs the axis at DEBUG level, preceded by a banner when name is not empty.
+  void Print(const std::string &name = "");
+  int min{0};
+  int extent{0};
+  Var var;
+  int dst_stride{0};
+  int src_stride{0};
+  std::vector src_stride_list;
+  std::vector stride_list;
+  bool is_valid{true};
+
+ private:
+  /// \return the entry of strides at the position where vars equals obj_var, or Expr(0) if none does
+  Expr GetStrideByAxis(const Array &vars, const Array &strides, Var obj_var);
+};
+
+using AxisIt = std::list::iterator;
+
+/// Builds one InsnAxis per For statement found in for_info.ops_.
+std::list GetAxisList(const StmtInfo &for_info, const Array &info_list);
+/// \return a copy of dst_info followed by a copy of every entry of src_info_list
+Array GetInfoList(const StmtStoreInfo &dst_info, const Array &src_info_list);
+int DivFloor(int a, int b);
+/// Logs every axis of axis_list at DEBUG level.
+void Print(std::list &axis_list);
+
+/// Common base of the instruction argument calculators.
+/// The constructor stores its arguments and then runs InitArg() followed by CalAxis().
+class InsnArgsCalculator {
+ public:
+  InsnArgsCalculator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list, const StmtInfo &for_info,
+                     const std::string &intrin_name);
+  virtual ~InsnArgsCalculator() = default;
+
+  PatternResult ExportResult();
+  void CalAxis();
+  void InitArg();
+
+  // Axis selection predicates; the virtual ones are customisation points for subclasses.
+  virtual std::function GetStrideLambda();
+  virtual std::function GetM0LimitLambda();
+  virtual std::function GetM1LimitLambda();
+  std::function GetBlockStrideLimitLambda();
+  AxisIt GetAxisByLambda(const std::function &lambda);
+  InsnAxis ExtractAxis(AxisIt &it);
+  bool IsValid(AxisIt &it);
+  AxisIt GetVecAxisIt();
+  AxisIt GetBlockAxis();
+  AxisIt GetRepeatAxisIt();
+  InsnAxis GetRepeatAxis();
+
+  void SetArgMask(int len);
+  void SetArgBlockNum(int data_num);
+  void SetArgBlockLen(int data_len);
+  void SetArgM0(int dst_m0, int lsrc_m0, int rsrc_m0);
+  void SetArgM1(int dst_m1, int lsrc_m1, int rsrc_m1);
+  void SetArgRepeat(int repeat);
+  void BlockAxisReduction();
+  void RepeatAxisReduction();
+  void CastCaseReduction();
+  virtual void InsnReduction();
+
+  StmtInfo ExportForInfo();
+  Expr GetOffset(int stride_index);
+  InsnAxis GetInvalidAxis();
+  SplitStat SplitAxis(int extent, InsnAxis &axis);
+  std::list axis_list_;
+
+ protected:
+  InsnArg arg_;
+  Meta meta_;
+  StmtInfoList dst_info_list_;
+  StmtInfoList src_info_list_;
+  StmtStoreInfo dst_info_;
+  StmtInfo for_info_;
+  const std::string intrin_name_;
+  const int max_block_stride_{4};
+};
+
+/// Calculator for single-source vector instructions.
+class SingleVecInsnArgsCalculator : public InsnArgsCalculator {
+ public:
+  SingleVecInsnArgsCalculator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list,
+                              const StmtInfo &for_info, const std::string &intrin_name = "");
+  virtual ~SingleVecInsnArgsCalculator() override = default;
+  PatternResult GetInsnArgs();
+};
+
+/// Calculator for two-source vector instructions; replaces the M0/M1 limits and InsnReduction of the base.
+class BinaryVecInsnArgsCalculator : public InsnArgsCalculator {
+ public:
+  BinaryVecInsnArgsCalculator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list,
+                              const StmtInfo &for_info, const std::string &mode, const std::string &intrin_name = "",
+                              bool expand_mask = true);
+  virtual ~BinaryVecInsnArgsCalculator() override = default;
+  PatternResult GetInsnArgs();
+  // These three already overrode the base virtuals; `override` makes the compiler check that.
+  std::function GetM0LimitLambda() override;
+  std::function GetM1LimitLambda() override;
+  void InsnReduction() override;
+
+ private:
+  std::string mode_;
+  bool expand_mask_;
+  InsnAxis vec_axis_;
+};
+
+/// Calculator for reductions over the last axis.
+/// It forwards {dst_info} / {src_info} to the base and also keeps its own copies of the operands.
+class LastAxisReduceInsnArgsCalculator : public InsnArgsCalculator {
+ public:
+  LastAxisReduceInsnArgsCalculator(const StmtStoreInfo &dst_info, const StmtStoreInfo &src_info,
+                                   const StmtInfo &for_info, const std::string &intrin_name)
+      : InsnArgsCalculator({dst_info}, {src_info}, for_info, intrin_name),
+        dst_info(dst_info),
+        src_info(src_info),
+        for_info(for_info),
+        arg_info(ArgInfo(make_node())),
+        body_args(VectorArgInfo()),
+        tail_args(VectorArgInfo()),
+        intrin_name(intrin_name) {}
+  PatternResult GetInsnArgs();
+  ~LastAxisReduceInsnArgsCalculator() = default;
+
+ protected:
+  Array GetPattern();
+  PatternResult GenResult(const Array &elim_var);
+
+ private:
+  void CalcParams();
+
+  // Values derived from the operands by CalcParams().
+  struct Params {
+    Array src_var;
+    int block_size = 0;
+    int vec_max_len = 0;
+    int last_dim_shape = 0;
+    Expr insn_offset_scale_factor;
+  };
+  StmtStoreInfo dst_info;
+  StmtStoreInfo src_info;
+  StmtInfo for_info;
+  ArgInfo arg_info;
+  VectorArgInfo body_args;
+  VectorArgInfo tail_args;
+  Array mix_vec_arg_list;
+  std::string intrin_name;
+  Params params;
+};
+
+BisectionInfoWrapper SeparateComInfoToBisectionInfoList(const StmtInfoList &dst_info_list,
+                                                        const StmtInfoList &src_info_list, const StmtInfo &for_info,
+                                                        StmtInfo &if_info, bool last_axis, int postfix);
+
+/// Computes the intrin args for a binary vector instruction; see the definition in insn_args_calculator.cc.
+ArgInfo GetBinaryVecInsnArgs(const Stmt &stmt, std::string intrin_name, StmtInfoList &dst_info_list,
+                             StmtInfoList &src_info_list, StmtInfo &if_info, StmtInfo &for_info,
+                             bool enable_bisect = true);
+}  // namespace akg
+#endif  // EMIT_INSN_ARGS_CALCULATOR_H_
\ No newline at end of file
diff --git a/src/emit_insn/insn_binary_vec_pattern.cc b/src/emit_insn/insn_binary_vec_pattern.cc
deleted file mode 100644
index a920d52e6b36c7068857299c942ea1e55359524c..0000000000000000000000000000000000000000
--- a/src/emit_insn/insn_binary_vec_pattern.cc
+++ /dev/null
@@ -1,1335 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -#include "ir_pass.h" -#include "contrib/cce_parm/cceconf.h" -#include "tvm.h" -#include "common/array_api.h" -#include "insn_pattern.h" -#include "insn_builder.h" - -namespace akg { -std::string GetBinaryVecMode(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list, - const std::string &intrin_name, bool enable_bisect = true) { - std::set reduce_bisect_list = {"vadd", "vsub", "vmul", "vmax"}; - std::string mode = "reduction"; - if (IsElementwise(dst_info_list, src_info_list)) { - mode = "elewise"; - } else if (IsBroadcast(dst_info_list, src_info_list)) { - mode = "broadcast"; - } else if (IsLastAxisReduction(dst_info_list, src_info_list)) { - mode = "reduce_last_axis"; - } else if (enable_bisect && reduce_bisect_list.count(intrin_name) != 0 && - IsBisectionReduction(dst_info_list, src_info_list)) { - mode = "reduce_bisection"; - } - - return mode; -} - -PatternResult ReduceLastAxisPatternGenerator::GetInsnArgs() { - CalcParams(); - Array elim_var; - - float rate2d = Compute2DBlockPatternMaskRate(); - if (rate2d > 1.0f) { - elim_var = Get2DBlockPattern(); - arg_info.GetNode()->pattern_ = PATTERN_2D_BLOCK; - } else { - elim_var = Get1DPattern(); - arg_info.GetNode()->pattern_ = PATTERN_1D; - } - - return GenResult(elim_var); -} - -float ReduceLastAxisPatternGenerator::Compute2DBlockPatternMaskRate() { - const float is2_dpattern = 1.0f; - if (intrin_name == "vadd" || intrin_name == "argmax" || intrin_name == "argmin") { - return not_this_pattern; - } - - // src var size must larger than 2 - if 
(params.src_var.size() < 2) { - return not_this_pattern; - } - - int body_len = params.last_dim_shape / params.vec_max_len * params.vec_max_len; - int tail_len = params.last_dim_shape % params.vec_max_len; - - // there is no body in this mode - if (body_len > 0 || tail_len > params.block_size) { - return not_this_pattern; - } - - return is2_dpattern; -} - -Array ReduceLastAxisPatternGenerator::Get2DBlockPattern() { - int sec_last_dim_shape = GetInt32Const(GetItem(src_info->shape_, -2)); - int body_len = sec_last_dim_shape / FULL_BLOCK_NUM * FULL_BLOCK_NUM; - int tail_len = sec_last_dim_shape % FULL_BLOCK_NUM; - int cmd_body_len = 0; - - if (body_len > 0) { - body_args = VectorArgInfo(make_node()); - body_args.GetNode()->body_num_ = 1; - body_args.GetNode()->repeat_ = Expr(body_len / FULL_BLOCK_NUM); - // Here use dst_stride_m1 as dst_stride - body_args.GetNode()->dst_stride_m1_ = Expr(1); - body_args.GetNode()->src_stride_m0_list_ = {Expr(1)}; - body_args.GetNode()->src_stride_m1_list_ = {Expr(FULL_BLOCK_NUM)}; - body_args.GetNode()->vec_mask_ = GetVecMask(params.last_dim_shape, FULL_BLOCK_NUM, dst_info->dtype_); - cmd_body_len += GetInt32Const(body_args->repeat_) * FULL_BLOCK_NUM; - } - if (tail_len > 0) { - tail_args = VectorArgInfo(make_node()); - tail_args.GetNode()->dst_head_ = Expr(cmd_body_len); - tail_args.GetNode()->src_head_list_ = {Expr(cmd_body_len / FULL_BLOCK_NUM * params.vec_max_len)}; - tail_args.GetNode()->repeat_ = Expr(1); - tail_args.GetNode()->dst_stride_m1_ = Expr(1); - tail_args.GetNode()->src_stride_m0_list_ = {Expr(1)}; - tail_args.GetNode()->src_stride_m1_list_ = {Expr(0)}; - tail_args.GetNode()->vec_mask_ = GetVecMask(params.last_dim_shape, tail_len, dst_info->dtype_); - } - - params.insn_offset_scale_factor = 1; - return GetRange(params.src_var, -2, 2); -} - -Array ReduceLastAxisPatternGenerator::Get1DPattern() { - int body_len = params.last_dim_shape / params.vec_max_len * params.vec_max_len; - int tail_len = params.last_dim_shape % 
params.vec_max_len; - int cmd_body_len = 0; - bool is_vadd = intrin_name == "vadd"; - int repeat_stride = FULL_BLOCK_NUM; - if (is_vadd) { - repeat_stride = 1; - } - const int fp16_block_size = 16; - - if (body_len > 0) { - body_args = VectorArgInfo(make_node()); - body_args.GetNode()->body_num_ = 1; - body_args.GetNode()->body_offset_ = params.vec_max_len; - body_args.GetNode()->repeat_ = Expr(body_len / params.vec_max_len); - // Here use dst_stride_m1 as dst_stride - body_args.GetNode()->dst_stride_m1_ = Expr(1); - body_args.GetNode()->src_stride_m0_list_ = {Expr(1)}; - body_args.GetNode()->src_stride_m1_list_ = {Expr(FULL_BLOCK_NUM)}; - body_args.GetNode()->vec_mask_ = GetVecMask(params.vec_max_len, 1, dst_info->dtype_); - cmd_body_len += GetInt32Const(body_args->repeat_) * repeat_stride; - } - if (tail_len > 0) { - tail_args = VectorArgInfo(make_node()); - tail_args.GetNode()->body_offset_ = params.vec_max_len; - tail_args.GetNode()->dst_head_ = Expr(cmd_body_len); - tail_args.GetNode()->src_head_list_ = {Expr(body_len)}; - tail_args.GetNode()->repeat_ = Expr(1); - tail_args.GetNode()->dst_stride_m1_ = Expr(1); - tail_args.GetNode()->src_stride_m0_list_ = {Expr(1)}; - tail_args.GetNode()->src_stride_m1_list_ = {Expr(0)}; - tail_args.GetNode()->vec_mask_ = GetVecMask(tail_len, 1, dst_info->dtype_); - if (is_vadd) { - cmd_body_len += 1; - } else { - cmd_body_len += tail_len / fp16_block_size; - if (tail_len % fp16_block_size != 0) { - cmd_body_len += 1; - } - } - } - // cmd_body_len > 1 means vcadd size greater than 128, need to use vcadd again to compute final result - // if cmd_body_len > 128, then need to recursively emit vcadd - while (cmd_body_len > 1) { - int cmd_tail_len = cmd_body_len % params.vec_max_len; - cmd_body_len = cmd_body_len / params.vec_max_len; - if (cmd_body_len > 0) { - VectorArgInfo mix_vec_args = VectorArgInfo(make_node()); - mix_vec_args.GetNode()->repeat_ = Expr(cmd_body_len); - mix_vec_args.GetNode()->dst_head_ = Expr(0); - 
mix_vec_args.GetNode()->src_head_list_ = {Expr(0)}; - mix_vec_args.GetNode()->dst_stride_m1_ = Expr(1); - mix_vec_args.GetNode()->src_stride_m0_list_ = {Expr(1)}; - mix_vec_args.GetNode()->src_stride_m1_list_ = {Expr(FULL_BLOCK_NUM)}; - mix_vec_args.GetNode()->vec_mask_ = GetVecMask(params.vec_max_len, 1, dst_info->dtype_); - mix_vec_arg_list.push_back(mix_vec_args); - if (!is_vadd) { - cmd_body_len *= FULL_BLOCK_NUM; - } - } - if (cmd_tail_len > 0) { - VectorArgInfo mix_vec_args = VectorArgInfo(make_node()); - mix_vec_args.GetNode()->repeat_ = Expr(1); - mix_vec_args.GetNode()->dst_head_ = Expr(cmd_body_len); - if (is_vadd) { - mix_vec_args.GetNode()->src_head_list_ = {Expr(cmd_body_len * params.vec_max_len)}; - } else { - mix_vec_args.GetNode()->src_head_list_ = {Expr(cmd_body_len / FULL_BLOCK_NUM * params.vec_max_len)}; - } - mix_vec_args.GetNode()->dst_stride_m1_ = Expr(1); - mix_vec_args.GetNode()->src_stride_m0_list_ = {Expr(1)}; - mix_vec_args.GetNode()->src_stride_m1_list_ = {Expr(FULL_BLOCK_NUM)}; - mix_vec_args.GetNode()->vec_mask_ = GetVecMask(cmd_tail_len, 1, dst_info->dtype_); - if (is_vadd) { - cmd_body_len += 1; - } else { - cmd_body_len += cmd_tail_len / fp16_block_size; - if (cmd_tail_len % fp16_block_size != 0) { - cmd_body_len += 1; - } - } - mix_vec_arg_list.push_back(mix_vec_args); - } - } - - params.insn_offset_scale_factor = Expr(params.block_size); - int max_num = body_len / params.vec_max_len; - if (intrin_name == "vmax" || intrin_name == "vmin") { - max_num *= FULL_BLOCK_NUM; - } - if (max_num >= params.block_size) { - params.insn_offset_scale_factor = max_num + params.block_size - 1; - if (tail_len > 0) { - params.insn_offset_scale_factor += 1; - } - params.insn_offset_scale_factor = truncdiv(params.insn_offset_scale_factor, params.block_size) * params.block_size; - } - - if (!params.src_var.empty()) { - return GetRange(params.src_var, -1, 1); - } - - return {}; -} - -PatternResult ReduceLastAxisPatternGenerator::GenResult(const Array 
&elim_var) { - dst_info.GetNode()->insn_offset_ = GetInsnOffset(dst_info, elim_var) * params.insn_offset_scale_factor; - src_info.GetNode()->insn_offset_ = GetInsnOffset(src_info, elim_var); - - if (body_args.defined()) { - body_args.GetNode()->insn_offset_scale_factor_ = params.insn_offset_scale_factor; - } - if (tail_args.defined()) { - tail_args.GetNode()->insn_offset_scale_factor_ = params.insn_offset_scale_factor; - } - for (auto &arg : mix_vec_arg_list) { - arg.GetNode()->insn_offset_scale_factor_ = params.insn_offset_scale_factor; - } - - arg_info.GetNode()->body_arg_info_ = body_args; - arg_info.GetNode()->tail_arg_info_ = tail_args; - arg_info.GetNode()->reduction_tail_args_ = mix_vec_arg_list; - - CleanForInfoVars(for_info, elim_var); - arg_info.GetNode()->arg_type_ = ARG_VECTOR_REDUCTION_LAST_AXIS; - - PatternResult result; - result.dst_info_list = {dst_info}; - result.src_info_list = {src_info}; - result.for_info = for_info; - result.arg_info = arg_info; - - return result; -} - -void ReduceLastAxisPatternGenerator::CalcParams() { - // check shape len - if (dst_info->shape_.empty() || src_info->shape_.empty()) { - LOG(FATAL) << "CCE Vector Insn Error: dst_buffer and src_buffer can not be scalar, should keep len(shape) > 0."; - } - // check data type - if (dst_info->dtype_ != src_info->dtype_) { - LOG(FATAL) << "CCE Vector Insn Error: dst_buffer and src_buffer can not be different data type."; - } - - params.src_var = src_info->var_; - params.block_size = GetUbBlkSize(dst_info->dtype_); - params.last_dim_shape = GetInt32Const(GetItem(src_info->shape_, -1)); - params.vec_max_len = GetVecMaxLen(dst_info->dtype_); - CHECK_NE(params.block_size, 0); - CHECK_NE(params.vec_max_len, 0); -} - -/// Get CCE Binary Vector instructions args -/// \return -PatternResult BinaryVecPatternGenerator::GetInsnArgs() { - CalcParams(); - if (arg_info->arg_type_ == ARG_VECTOR_BROADCAST_LAST_AXIS) { - PatternResult result; - result.dst_info_list = {dst_info}; - 
result.src_info_list = src_info_list; - result.for_info = for_info; - result.arg_info = arg_info; - return result; - } - - Array elim_var = {}; - - float rate3d = Compute3DPatternMaskRate(); - float rate2db = Compute2DBlockPatternMaskRate(); - float rate2d = Compute2DPatternMaskRate(); - float rate1d = Compute1DPatternMaskRate(); - - if (rate3d >= rate2db && rate3d > 0) { - elim_var = Get3DPattern(); - arg_info.GetNode()->pattern_ = PATTERN_3D; - } else if (rate2db >= rate2d && rate2db >= rate1d && rate2db > 0) { - elim_var = Get2DBlockPattern(); - arg_info.GetNode()->pattern_ = PATTERN_PARTIAL_3D; - } else if (rate2d > rate1d && rate2d > 0) { - elim_var = Get2DPattern(); - arg_info.GetNode()->pattern_ = PATTERN_2D; - } else if (rate1d > 0) { - elim_var = Get1DPattern(); - arg_info.GetNode()->pattern_ = PATTERN_1D; - } else { - LOG(FATAL) << "Error: Cannot emit Binary-Vector-Insn with any pattern!"; - } - - std::string mask_rate = "rate3d[" + std::to_string(rate3d) + "], rate2db[" + std::to_string(rate2db) + "], rate2d[" + - std::to_string(rate2d) + "], rate1d[" + std::to_string(rate1d) + "]"; - CommentManager::GetInstance().AddComment("Mask_rate", mask_rate); - if (tail_args.defined()) { - CommentManager::GetInstance().AddComment("Contain_tail", "true"); - } else { - CommentManager::GetInstance().AddComment("Contain_tail", "false"); - } - - return GenResult(elim_var); -} - -float BinaryVecPatternGenerator::Compute3DPatternMaskRate() { - if (params.non_zero_shape3 == 1 || params.non_zero_shape2 == 1) { - return not_this_pattern; - } - // in elemwise mode, the var is already checked to be equal, no need to check - if (params.dst_var.size() < 3 || GetIntConst(GetItem(params.dst_shape, -1)) > params.block_size || - GetIntConst(GetItem(params.dst_strides, -2)) % params.block_size != 0 || - GetIntConst(GetItem(params.dst_strides, -3)) % params.block_size != 0 || - (GetIntConst(GetItem(params.dst_strides, -2)) > 0 && GetIntConst(GetItem(params.dst_shape, -1)) > 0 && - 
GetIntConst(GetItem(params.dst_strides, -2)) < GetIntConst(GetItem(params.dst_shape, -1))) || - (GetIntConst(GetItem(params.dst_strides, -3)) > 0 && GetIntConst(GetItem(params.dst_shape, -2)) > 0 && - GetIntConst(GetItem(params.dst_strides, -3)) < GetIntConst(GetItem(params.dst_shape, -2)))) { - return not_this_pattern; - } - // check dst_stride_m0 - // As described in ISL User Guide t6.3, - // dst_stride_m0 = 0 is treated as 1 - auto JudgeNot3D = [this](const StmtStoreInfo &info) { - auto last_shape1 = GetIntConst(GetItem(info->shape_, -1)); - if (info->var_.size() < 3 || last_shape1 > params.block_size) { - return true; - } - - auto last_shape2 = GetIntConst(GetItem(info->shape_, -2)); - auto last_stride2 = GetIntConst(GetItem(info->strides_, -2)); - auto last_stride3 = GetIntConst(GetItem(info->strides_, -3)); - - return last_stride2 % params.block_size != 0 || last_stride3 % params.block_size != 0 || - (last_stride2 > 0 && last_shape1 > 0 && last_stride2 < last_shape1) || - (last_stride3 > 0 && last_shape2 > 0 && last_stride3 < last_shape2); - }; - if (std::any_of(src_info_list.begin(), src_info_list.end(), JudgeNot3D)) { - return not_this_pattern; - } - - if (mode == "reduction") { - // check same alignment - Array shape_list = {GetItem(params.dst_shape, -1)}; - shape_list.push_back(GetItem(params.src_shape0, -1)); - shape_list.push_back(GetItem(params.src_shape1, -1)); - if (!IsNonZeroShapeEqual(shape_list)) { - return not_this_pattern; - } - } - - // repeat axis is shape [-3], repeat once, has 8 loops - bool is3_d = true; - float rate3d_mode1 = not_this_pattern; - float rate3d_mode2 = not_this_pattern; - int repeat_num; - float repeat_latency; - auto info_list = src_info_list; - Insert(info_list, 0, dst_info); - for (auto info : info_list) { - if (GetInt32Const(GetItem(info->shape_, -2)) > FULL_BLOCK_NUM || - GetInt32Const(GetItem(info->strides_, -2)) / params.block_size >= MAX_STRIDE_M0 || - GetInt32Const(GetItem(info->strides_, -3)) / params.block_size >= 
MAX_STRIDE_M0) { - is3_d = false; - break; - } - } - if (is3_d) { - if (GetInt32Const(GetItem(params.dst_strides, -2)) == 0) { - return not_this_pattern; - } - repeat_num = params.non_zero_shape3; - repeat_latency = ((repeat_num - 1) / MAX_REPEAT) * repeat_latency_coef; - rate3d_mode1 = static_cast(params.all_points) / params.vec_max_len / (repeat_num + repeat_latency); - } - - is3_d = true; - // repeat axis is shape[-2] - for (auto info : info_list) { - // stride_m0 should be less than 256 - if (GetIntConst(GetItem(info->shape_, -3)) % FULL_BLOCK_NUM != 0 || - GetIntConst(GetItem(info->strides_, -3)) / params.block_size >= MAX_STRIDE_M0) { - is3_d = false; - break; - } - } - if (is3_d) { - if (GetIntConst(GetItem(params.dst_strides, -3)) == 0) { - return not_this_pattern; - } - repeat_num = params.non_zero_shape2 * (params.non_zero_shape3 / FULL_BLOCK_NUM); - repeat_latency = ((repeat_num - 1) / MAX_REPEAT) * repeat_latency_coef; - float offset_latency = - params.non_zero_shape3 / FULL_BLOCK_NUM > 1 ? params.non_zero_shape3 * offset_latency_coef : 0; - rate3d_mode2 = - static_cast(params.all_points) / params.vec_max_len / (repeat_num + repeat_latency + offset_latency); - } - - return rate3d_mode1 > rate3d_mode2 ? 
rate3d_mode1 : rate3d_mode2; -} - -float BinaryVecPatternGenerator::Compute2DBlockPatternMaskRate() { - if (params.non_zero_shape2 == 1 || GetIntConst(GetItem(params.dst_strides, -1)) != 1) { - return not_this_pattern; - } - if (params.dst_var.size() < 2 || GetIntConst(GetItem(params.dst_shape, -1)) > params.block_size || - GetIntConst(GetItem(params.dst_strides, -2)) % params.block_size != 0 || - GetIntConst(GetItem(params.dst_strides, -2)) / params.block_size >= MAX_STRIDE_M0 || - (GetIntConst(GetItem(params.dst_strides, -2)) > 0 && GetIntConst(GetItem(params.dst_shape, -1)) > 0 && - GetIntConst(GetItem(params.dst_strides, -2)) < GetIntConst(GetItem(params.dst_shape, -1)))) { - return not_this_pattern; - } - - for (auto info : src_info_list) { - if (info->var_.size() < 2 || GetIntConst(GetItem(info->shape_, -1)) > params.block_size || - GetIntConst(GetItem(info->strides_, -2)) % params.block_size != 0 || - GetIntConst(GetItem(info->strides_, -2)) / params.block_size >= MAX_STRIDE_M0 || - (GetIntConst(GetItem(info->strides_, -2)) > 0 && GetIntConst(GetItem(info->shape_, -1)) > 0 && - GetIntConst(GetItem(info->strides_, -2)) < GetIntConst(GetItem(info->shape_, -1)))) { - return not_this_pattern; - } - } - - if (GetIntConst(GetItem(params.dst_strides, -2)) == 0) { - return not_this_pattern; - } - - if (mode == "reduction") { - if (params.dst_var.size() > 2) { - // if not elewise mode, then can not use partial 3D mode - if (GetIntConst(GetItem(params.dst_shape, -3)) == 0) { - return not_this_pattern; - } - - for (auto info : src_info_list) { - if (GetIntConst(GetItem(info->shape_, -3)) == 0) { - return not_this_pattern; - } - } - } - // check same alignment - Array shape_list = {GetItem(params.dst_shape, -1)}; - shape_list.push_back(GetItem(params.src_shape0, -1)); - shape_list.push_back(GetItem(params.src_shape1, -1)); - // check dst_stride_m0 - // As described in ISL User Guide t6.3, - // dst_stride_m0 = 0 is treated as 1 - if (!IsNonZeroShapeEqual(shape_list)) { - 
return not_this_pattern; - } - } - - int repeat_body_num = params.non_zero_shape2 / FULL_BLOCK_NUM; - int repeat_tail_num = (params.non_zero_shape2 % FULL_BLOCK_NUM + FULL_BLOCK_NUM - 1) / FULL_BLOCK_NUM; - int repeat_num = (repeat_body_num + repeat_tail_num) * params.non_zero_shape3; - float repeat_latency = - (std::max(repeat_body_num - 1, 0) / MAX_REPEAT + std::max(repeat_tail_num - 1, 0) / MAX_REPEAT) * - repeat_latency_coef; - float offset_latency = params.non_zero_shape3 > 1 ? params.non_zero_shape3 * offset_latency_coef : 0; - float split_latency = (repeat_body_num > 0 && repeat_tail_num > 0) ? split_latency_coef : 0; - float rate2db = static_cast(params.all_points) / params.vec_max_len / - (repeat_num + repeat_latency + offset_latency + split_latency); - - return rate2db; -} - -float BinaryVecPatternGenerator::Compute2DPatternMaskRate() { - if (params.non_zero_shape2 == 1) { - return not_this_pattern; - } - if (params.dst_var.size() < 2 || GetIntConst(GetItem(params.dst_strides, -2)) % params.block_size != 0 || - (GetIntConst(GetItem(params.dst_strides, -2)) < GetIntConst(GetItem(params.dst_shape, -1)) && - GetIntConst(GetItem(params.dst_strides, -2) > 0))) { - return not_this_pattern; - } - - for (auto info : src_info_list) { - if (info->var_.size() < 2 || GetIntConst(GetItem(info->strides_, -2)) % params.block_size != 0 || - (GetIntConst(GetItem(info->strides_, -2)) < GetIntConst(GetItem(info->shape_, -1)) && - GetIntConst(GetItem(info->strides_, -2) > 0))) { - return not_this_pattern; - } - } - - // check num of insns, select 1D pattern or 2D pattern - int tail_factor = 0; - if (mode == "reduction") { - Array shape_list = {GetItem(params.dst_shape, -1)}; - shape_list.push_back(GetItem(params.src_shape0, -1)); - shape_list.push_back(GetItem(params.src_shape1, -1)); - if (!IsNonZeroShapeEqual(shape_list)) { - return not_this_pattern; - } - } - - // only cloud allow dst_stride_m1 = 0 - cceconf::CceConf *conf = cceconf::CceConf::getInstance(); - const 
std::string product_name = conf->getProductName(); - if (GetIntConst(GetItem(params.dst_strides, -2)) == 0 && product_name != "cloud") { - return not_this_pattern; - } - - CHECK_NE(params.vec_max_len, 0); - if (params.non_zero_shape1 / params.vec_max_len > 0 && params.non_zero_shape1 % params.vec_max_len > 0) { - tail_factor = 1; - } - - if (GetIntConst(GetItem(dst_info->strides_, -2)) / params.block_size >= MAX_STRIDE_M0) { - return not_this_pattern; - } - for (auto info : src_info_list) { - if (GetIntConst(GetItem(info->strides_, -2)) / params.block_size >= MAX_STRIDE_M0) { - return not_this_pattern; - } - } - - int shape1 = (params.non_zero_shape1 + params.vec_max_len - 1) / params.vec_max_len; - int repeat_num = shape1 * params.non_zero_shape2 * params.non_zero_shape3; - float repeat_latency = - (std::max(params.non_zero_shape2 - 1, 0) / MAX_REPEAT) * params.non_zero_shape3 * shape1 * repeat_latency_coef; - float offset_latency = - shape1 * params.non_zero_shape3 > 1 ? shape1 * params.non_zero_shape3 * offset_latency_coef : 0; - float split_latency = tail_factor * split_latency_coef; - float rate2d = static_cast(params.all_points) / params.vec_max_len / - (repeat_num + repeat_latency + offset_latency + split_latency); - - return rate2d; -} - -float BinaryVecPatternGenerator::Compute1DPatternMaskRate() { - int tail_factor = 0; - if (params.non_zero_shape1 / params.vec_max_len > 0 && params.non_zero_shape1 % params.vec_max_len > 0) { - tail_factor = 1; - } - - int shape1 = (params.non_zero_shape1 + params.vec_max_len - 1) / params.vec_max_len; - int repeat_num = shape1 * params.non_zero_shape2 * params.non_zero_shape3; - float repeat_latency = - std::max((shape1 - 1) / MAX_REPEAT, 0) * params.non_zero_shape2 * params.non_zero_shape3 * repeat_latency_coef; - float offset_latency = params.non_zero_shape2 * params.non_zero_shape3 > 1 - ? 
params.non_zero_shape2 * params.non_zero_shape3 * offset_latency_coef - : 0; - float split_latency = tail_factor * split_latency_coef; - float rate1d = static_cast(params.all_points) / params.vec_max_len / - (repeat_num + repeat_latency + offset_latency + split_latency); - - return rate1d; -} - -Array BinaryVecPatternGenerator::Get3DPattern() { - // repeat axis is shape [-2] - int second_last_shape = GetInt32Const( - GetNonZeroShape(GetItem(params.dst_shape, -2), GetItem(params.src_shape0, -2), GetItem(params.src_shape1, -2))); - int third_last_shape = GetInt32Const( - GetNonZeroShape(GetItem(params.dst_shape, -3), GetItem(params.src_shape0, -3), GetItem(params.src_shape1, -3))); - if (second_last_shape > 8) { - // split shape[-3] - if (third_last_shape > 8) { - auto info_list = src_info_list; - Insert(info_list, 0, dst_info); - SplitAxis(info_list, for_info, GetItem(params.dst_var, -3), FULL_BLOCK_NUM); - FillEmptyVar(info_list); - - params.dst_var = info_list[0]->var_; - params.dst_shape = info_list[0]->shape_; - params.dst_strides = info_list[0]->strides_; - params.src_var0 = info_list[1]->var_; - params.src_shape0 = info_list[1]->shape_; - params.src_strides0 = info_list[1]->strides_; - params.src_var1 = info_list[2]->var_; - params.src_shape1 = info_list[2]->shape_; - params.src_strides1 = info_list[2]->strides_; - } - // consider original shape[-2] as repeat axis - GetShapeInfoAndSwap(params.dst_var, params.dst_shape, params.dst_strides, -2, -3); - GetShapeInfoAndSwap(params.src_var0, params.src_shape0, params.src_strides0, -2, -3); - GetShapeInfoAndSwap(params.src_var1, params.src_shape1, params.src_strides1, -2, -3); - } - - body_args = VectorArgInfo(make_node()); - CHECK(body_args.GetNode()); - body_args.GetNode()->body_num_ = 1; - body_args.GetNode()->repeat_ = GetItem(params.dst_shape, -3); - - body_args.GetNode()->dst_stride_m0_ = truncdiv(GetItem(params.dst_strides, -2), params.block_size); - body_args.GetNode()->dst_stride_m1_ = 
truncdiv(GetItem(params.dst_strides, -3), params.block_size); - body_args.GetNode()->src_stride_m0_list_ = {truncdiv(GetItem(params.src_strides0, -2), params.block_size), - truncdiv(GetItem(params.src_strides1, -2), params.block_size)}; - body_args.GetNode()->src_stride_m1_list_ = {truncdiv(GetItem(params.src_strides0, -3), params.block_size), - truncdiv(GetItem(params.src_strides1, -3), params.block_size)}; - - int data_num = GetInt32Const(GetItem(params.dst_shape, -2)); - if (mode == "reduction") { - body_args.GetNode()->repeat_ = Expr( - GetNonZeroShape(GetItem(params.dst_shape, -3), GetItem(params.src_shape0, -3), GetItem(params.src_shape1, -3))); - data_num = - GetNonZeroShape(GetItem(params.dst_shape, -2), GetItem(params.src_shape0, -2), GetItem(params.src_shape1, -2)); - } - int data_len = expand_mask ? CeilTo(params.last_dim_shape, params.block_size) : params.last_dim_shape; - body_args.GetNode()->vec_mask_ = GetVecMask(data_len, data_num, dst_info->dtype_); - - return GetRange(params.dst_var, -3, 3); -} - -Array BinaryVecPatternGenerator::Get2DBlockPattern() { - int repeat_len = GetInt32Const(GetItem(params.dst_shape, -2)); - if (mode == "reduction") { - params.last_dim_shape = - GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape0, -1), GetItem(params.src_shape1, -1)); - repeat_len = - GetNonZeroShape(GetItem(params.dst_shape, -2), GetItem(params.src_shape0, -2), GetItem(params.src_shape1, -2)); - } - int repeat_body = repeat_len / FULL_BLOCK_NUM; - int repeat_tail = (repeat_len % FULL_BLOCK_NUM + FULL_BLOCK_NUM - 1) / FULL_BLOCK_NUM; - - if (repeat_body > 0) { - body_args = VectorArgInfo(make_node()); - CHECK(body_args.GetNode() != nullptr); - body_args.GetNode()->body_num_ = 1; - body_args.GetNode()->repeat_ = Expr(repeat_body); - body_args.GetNode()->dst_stride_m0_ = truncdiv(GetItem(params.dst_strides, -2), params.block_size); - body_args.GetNode()->dst_stride_m1_ = body_args->dst_stride_m0_ * FULL_BLOCK_NUM; - Expr src0_stride_m0 
= truncdiv(GetItem(params.src_strides0, -2), params.block_size); - Expr src1_stride_m0 = truncdiv(GetItem(params.src_strides1, -2), params.block_size); - body_args.GetNode()->src_stride_m0_list_ = {src0_stride_m0, src1_stride_m0}; - body_args.GetNode()->src_stride_m1_list_ = {src0_stride_m0 * FULL_BLOCK_NUM, src1_stride_m0 * FULL_BLOCK_NUM}; - int data_len = expand_mask ? CeilTo(params.last_dim_shape, params.block_size) : params.last_dim_shape; - body_args.GetNode()->vec_mask_ = GetVecMask(data_len, FULL_BLOCK_NUM, dst_info->dtype_); - } - if (repeat_tail > 0) { - tail_args = VectorArgInfo(make_node()); - CHECK(tail_args.GetNode() != nullptr); - tail_args.GetNode()->dst_head_ = GetItem(params.dst_strides, -2) * repeat_body * FULL_BLOCK_NUM; - tail_args.GetNode()->src_head_list_ = {GetItem(params.src_strides0, -2) * repeat_body * FULL_BLOCK_NUM, - GetItem(params.src_strides1, -2) * repeat_body * FULL_BLOCK_NUM}; - tail_args.GetNode()->repeat_ = Expr(1); - tail_args.GetNode()->dst_stride_m0_ = truncdiv(GetItem(params.dst_strides, -2), params.block_size); - tail_args.GetNode()->dst_stride_m1_ = Expr(0); - tail_args.GetNode()->src_stride_m0_list_ = {truncdiv(GetItem(params.src_strides0, -2), params.block_size), - truncdiv(GetItem(params.src_strides1, -2), params.block_size)}; - tail_args.GetNode()->src_stride_m1_list_ = {Expr(0), Expr(0)}; - int data_len = expand_mask ? 
CeilTo(params.last_dim_shape, params.block_size) : params.last_dim_shape; - tail_args.GetNode()->vec_mask_ = GetVecMask(data_len, repeat_len % FULL_BLOCK_NUM, dst_info->dtype_); - } - return GetRange(params.dst_var, -2, 2); -} - -Array BinaryVecPatternGenerator::Get2DPattern() { - if (mode == "reduction") { - params.last_dim_shape = - GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape0, -1), GetItem(params.src_shape1, -1)); - } - - int body_len = FloorTo(params.last_dim_shape, params.vec_max_len); - int tail_len = params.last_dim_shape % params.vec_max_len; - - if (body_len > 0) { - body_args = VectorArgInfo(make_node()); - CHECK(body_args.GetNode() != nullptr); - body_args.GetNode()->body_num_ = body_len / params.vec_max_len; - body_args.GetNode()->body_offset_ = params.vec_max_len; - body_args.GetNode()->repeat_ = GetItem(params.dst_shape, -2); - if (mode == "reduction") { - body_args.GetNode()->repeat_ = - GetNonZeroShape(GetItem(params.dst_shape, -2), GetItem(params.src_shape0, -2), GetItem(params.src_shape1, -2)); - } - body_args.GetNode()->dst_stride_m0_ = Expr(1); - body_args.GetNode()->dst_stride_m1_ = truncdiv(GetItem(params.dst_strides, -2), params.block_size); - body_args.GetNode()->src_stride_m0_list_ = {Expr(1), Expr(1)}; - body_args.GetNode()->src_stride_m1_list_ = {truncdiv(GetItem(params.src_strides0, -2), params.block_size), - truncdiv(GetItem(params.src_strides1, -2), params.block_size)}; - body_args.GetNode()->vec_mask_ = GetVecMask(params.vec_max_len, 1, dst_info->dtype_); - } - if (tail_len > 0) { - tail_args = VectorArgInfo(make_node()); - CHECK(tail_args.GetNode() != nullptr); - tail_args.GetNode()->dst_head_ = Expr(body_len); - tail_args.GetNode()->src_head_list_ = {Expr(body_len), Expr(body_len)}; - tail_args.GetNode()->repeat_ = GetItem(params.dst_shape, -2); - if (mode == "reduction") { - tail_args.GetNode()->repeat_ = - GetNonZeroShape(GetItem(params.dst_shape, -2), GetItem(params.src_shape0, -2), 
GetItem(params.src_shape1, -2)); - } - tail_args.GetNode()->dst_stride_m0_ = Expr(1); - tail_args.GetNode()->dst_stride_m1_ = truncdiv(GetItem(params.dst_strides, -2), params.block_size); - tail_args.GetNode()->src_stride_m0_list_ = {Expr(1), Expr(1)}; - tail_args.GetNode()->src_stride_m1_list_ = {truncdiv(GetItem(params.src_strides0, -2), params.block_size), - truncdiv(GetItem(params.src_strides1, -2), params.block_size)}; - tail_args.GetNode()->vec_mask_ = GetVecMask(tail_len, 1, dst_info->dtype_); - } - return GetRange(params.dst_var, -2, 2); -} - -Array BinaryVecPatternGenerator::Get1DPattern() { - auto info_list = src_info_list; - Insert(info_list, 0, dst_info); - bool is_scalar_mode = IsScalarMode(info_list); - if (is_scalar_mode) { - params.last_dim_shape = 1; - } - - if (mode == "reduction") { - params.last_dim_shape = - GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape0, -1), GetItem(params.src_shape1, -1)); - } - int body_len = FloorTo(params.last_dim_shape, params.vec_max_len); - int tail_len = params.last_dim_shape % params.vec_max_len; - - int last_axis = -1; - if (mode == "broadcast") { - if (GetIntConst(GetItem(params.src_strides0, -1)) == 0) { - last_axis = 0; - } - if (GetIntConst(GetItem(params.src_strides1, -1)) == 0) { - last_axis = 1; - } - } - - if (body_len > 0) { - body_args = VectorArgInfo(make_node()); - CHECK(body_args.GetNode() != nullptr); - body_args.GetNode()->last_axis_info_.src_index_ = last_axis; - body_args.GetNode()->body_num_ = 1; - body_args.GetNode()->repeat_ = body_len / params.vec_max_len; - body_args.GetNode()->dst_stride_m0_ = Expr(1); - body_args.GetNode()->dst_stride_m1_ = Expr(FULL_BLOCK_NUM); - body_args.GetNode()->src_stride_m0_list_ = {Expr(1), Expr(1)}; - body_args.GetNode()->src_stride_m1_list_ = {Expr(FULL_BLOCK_NUM), Expr(FULL_BLOCK_NUM)}; - body_args.GetNode()->vec_mask_ = GetVecMask(params.vec_max_len, 1, dst_info->dtype_); - } - if (tail_len > 0) { - tail_args = 
VectorArgInfo(make_node()); - CHECK(tail_args.GetNode() != nullptr); - tail_args.GetNode()->last_axis_info_.src_index_ = last_axis; - tail_args.GetNode()->dst_head_ = Expr(body_len); - tail_args.GetNode()->src_head_list_ = {Expr(body_len), Expr(body_len)}; - tail_args.GetNode()->repeat_ = Expr(1); - tail_args.GetNode()->dst_stride_m0_ = Expr(1); - tail_args.GetNode()->dst_stride_m1_ = Expr(0); - tail_args.GetNode()->src_stride_m0_list_ = {Expr(1), Expr(1)}; - tail_args.GetNode()->src_stride_m1_list_ = {Expr(0), Expr(0)}; - int data_len = expand_mask ? CeilTo(tail_len, params.block_size) : tail_len; - tail_args.GetNode()->vec_mask_ = GetVecMask(data_len, 1, dst_info->dtype_); - } - - // compute offset for cce instructions - Array elim_var = {}; - if (mode == "elewise" && params.dst_var.size() >= 2 && params.dst_strides.size() >= 2 && - params.last_dim_shape <= params.vec_max_len && for_info.ops_.size() >= 2 && - params.last_dim_shape >= params.vec_max_len - params.block_size && - GetIntConst(GetItem(params.dst_strides, -2)) == params.last_dim_shape) { - // in this case we can merge second last for extent to repeat - size_t idx = 0; - bool suc = GetIndexOfElement(for_info.vars_, GetItem(params.dst_var, -2), idx); - CHECK(suc); - auto latest_for = GetItem(for_info.ops_, idx).as(); - // there should not be if_op between for loop and compute stmt - if (latest_for && !latest_for->body->IsInstance()) { - if (!params.dst_var.empty() && !is_scalar_mode) { - if (body_args.defined()) { - // last_dim_shape = vec_max_len - body_args.GetNode()->repeat_ = body_args->repeat_ * latest_for->extent; - } else if (tail_args.defined()) { - // last_dim_shape < vec_max_len - tail_args.GetNode()->repeat_ = tail_args->repeat_ * latest_for->extent; - } - return elim_var = GetRange(params.dst_var, -2, 2); - } - } - } - - if (!params.dst_var.empty() && !is_scalar_mode) { - elim_var = GetRange(params.dst_var, -1, 1); - } - - return elim_var; -} - -PatternResult 
BinaryVecPatternGenerator::GenResult(const Array &elim_var) { - arg_info.GetNode()->body_arg_info_ = body_args; - arg_info.GetNode()->tail_arg_info_ = tail_args; - - auto real_elim_var = elim_var; - if (!empty_var->name_hint.empty()) { - bool need_elim = true; - for (auto e : elim_var) { - if (e->name_hint == empty_var->name_hint) { - need_elim = false; - break; - } - } - if (need_elim) { - real_elim_var.push_back(empty_var); - } - } - - dst_info.GetNode()->insn_offset_ = GetInsnOffset(dst_info, real_elim_var); - for (auto &info : src_info_list) { - info.GetNode()->insn_offset_ = GetInsnOffset(info, real_elim_var); - } - - CleanForInfoVars(for_info, real_elim_var); - CleanZeroStrides(dst_info); - CleanZeroStrides(src_info_list); - - if (mode == "elewise") { - arg_info.GetNode()->arg_type_ = ARG_VECTOR_ELEWISE; - } else if (mode == "broadcast") { - arg_info.GetNode()->arg_type_ = ARG_VECTOR_BROADCAST; - } else if (mode == "reduction") { - arg_info.GetNode()->arg_type_ = ARG_VECTOR_REDUCTION; - } - - PatternResult result; - result.dst_info_list = {dst_info}; - result.src_info_list = src_info_list; - result.for_info = for_info; - result.arg_info = arg_info; - - return result; -} - -void BinaryVecPatternGenerator::CalcParams() { - CHECK_GE(src_info_list.size(), 2); - StmtStoreInfo src_info0 = src_info_list[0]; - StmtStoreInfo src_info1 = src_info_list[1]; - - StmtInfoList info_list = {dst_info, src_info0, src_info1}; - - // check shape len - for (auto info : info_list) { - if (info->shape_.empty()) { - LOG(FATAL) << "CCE Vector Insn Error: dst_buffer and src_buffer can not be scalar, should keep len(shape) > 0."; - } - } - - // check data type - for (auto src_info : src_info_list) { - if (dst_info->dtype_ != src_info->dtype_) { - LOG(FATAL) << "CCE Vector Insn Error: dst_buffer and src_buffer can not be different data type."; - } - } - - params.last_dim_shape = GetInt32Const(GetItem(dst_info->shape_, -1)); - AppendEmptyVar(info_list); - if (arg_info->arg_type_ == 
ARG_VECTOR_BROADCAST_LAST_AXIS) { - return; - } - - if (mode == "reduction" || mode == "broadcast") { - FillEmptyVar(info_list); - } - CHECK_EQ(info_list.size(), 3); - dst_info = info_list[0]; - src_info0 = info_list[1]; - src_info1 = info_list[2]; - - params.vec_max_len = GetVecMaxLen(dst_info->dtype_); - params.block_size = GetUbBlkSize(dst_info->dtype_); - CHECK_NE(params.vec_max_len, 0); - CHECK_NE(params.block_size, 0); - - params.dst_var = dst_info->var_; - params.dst_shape = dst_info->shape_; - params.dst_strides = dst_info->strides_; - params.src_var0 = src_info0->var_; - params.src_var1 = src_info1->var_; - params.src_shape0 = src_info0->shape_; - params.src_shape1 = src_info1->shape_; - params.src_strides0 = src_info0->strides_; - params.src_strides1 = src_info1->strides_; - - auto GetNonZeroShapeByIdx = [this](int index) -> int { - if (index <= static_cast(params.dst_var.size())) { - if (Equal(GetItem(params.dst_var, -index), GetItem(params.src_var0, -index)) && - Equal(GetItem(params.dst_var, -index), GetItem(params.src_var1, -index))) { - return GetNonZeroShape(GetItem(params.dst_shape, -index), GetItem(params.src_shape0, -index), - GetItem(params.src_shape1, -index)); - } - } - return 1; - }; - - params.non_zero_shape1 = GetNonZeroShapeByIdx(1); - params.non_zero_shape2 = GetNonZeroShapeByIdx(2); - params.non_zero_shape3 = GetNonZeroShapeByIdx(3); - params.all_points = params.non_zero_shape1 * params.non_zero_shape2 * params.non_zero_shape3; -} - -bool BinaryVecPatternGenerator::IsSamePatternComInfo(const StmtStoreInfo &info_a, const StmtStoreInfo &info_b) { - if (IsSame(info_a->var_, info_b->var_)) { - if (info_a->shape_.size() != info_b->shape_.size()) { - return false; - } - for (size_t i = 0; i < info_a->shape_.size(); ++i) { - if (!IsTwoItemEqual(info_a->shape_, info_b->shape_, static_cast(i), true)) { - return false; - } - } - if (info_a->strides_.size() != info_b->strides_.size()) { - return false; - } - for (size_t i = 0; i < 
info_a->strides_.size(); ++i) { - if (!IsTwoItemEqual(info_a->strides_, info_b->strides_, static_cast(i), true)) { - return false; - } - } - return true; - } - return false; -} - -bool BinaryVecPatternGenerator::IsNonZeroShapeEqual(const Array &shape_list) { - Array non_zero_list; - for (auto shape : shape_list) { - if (GetIntConst(shape) != 0) { - non_zero_list.push_back(shape); - } - } - if (non_zero_list.empty()) { - LOG(FATAL) << "Error: all shapes are equal to 0."; - } - for (auto shape : non_zero_list) { - if (GetIntConst(shape) != GetIntConst(non_zero_list[0])) { - return false; - } - } - return true; -} - -void BinaryVecPatternGenerator::AppendEmptyVar(StmtInfoList &info_list) { - auto FillEmptyVarToLast = [](const StmtStoreInfo com_info, const Var &var) -> void { - com_info.GetNode()->var_.push_back(var); - com_info.GetNode()->shape_.push_back(Expr(1)); - com_info.GetNode()->strides_.push_back(Expr(1)); - com_info.GetNode()->index_ = com_info->index_ + GetItem(com_info->var_, -1); - }; - - auto src_info0 = src_info_list[0]; - auto src_info1 = src_info_list[1]; - - if (mode == "reduction" || mode == "broadcast") { - // ISA 8.1.2, strides of Xd must be equal to Xm, [Xd = dst, Xn = src0, Xm = src1] - if (IsSamePatternComInfo(dst_info, src_info0)) { - auto tmp = src_info0; - src_info0 = src_info1; - src_info1 = tmp; - } - - if (mode == "reduction") { - if (src_info0->data_alignment_ == 1) { - empty_var = Var("empty_cc"); - FillEmptyVarToLast(src_info0, empty_var); - } - } else if (mode == "broadcast") { - // last dim broadcast, should use VS insn, such as vadds and vmuls - bool less_var = - !dst_info->var_.empty() && !src_info0->var_.empty() && !src_info1->var_.empty() && - (!IsTwoItemEqual(dst_info->var_, src_info0->var_, -1) || !IsTwoItemEqual(dst_info->var_, src_info1->var_, -1)); - bool null_var = src_info0->var_.empty() || src_info1->var_.empty(); - if (less_var || null_var) { - arg_info.GetNode()->arg_type_ = ARG_VECTOR_BROADCAST_LAST_AXIS; - return; - } 
else if (dst_info->data_alignment_ == 1 && src_info0->data_alignment_ == 1) { - empty_var = Var("empty_cc"); - FillEmptyVarToLast(dst_info, empty_var); - FillEmptyVarToLast(src_info0, empty_var); - FillEmptyVarToLast(src_info1, empty_var); - params.last_dim_shape = 1; - } - } - src_info_list = {src_info0, src_info1}; - info_list = {dst_info, src_info0, src_info1}; - } -} - -/// Get CCE Binary Vector Insn Computation Info -/// \param stmt - operand stmt -/// \param intrin_name - vector intrin name -/// \param dst_info_list - dst computation info list -/// \param src_info_list - src computation info list -/// \param if_info - if info list -/// \param for_info - for info list -/// \return intrin args -ArgInfo GetBinaryVecInsnArgs(const Stmt &stmt, std::string intrin_name, StmtInfoList &dst_info_list, - StmtInfoList &src_info_list, StmtInfo &if_info, StmtInfo &for_info, bool enable_bisect) { - // check intrin_name - std::set intrin_name_list = {"vadd", "vmax", "vmin", "vmul", "vdiv", "vsel", "vsub", "vand", - "vor", "vaxpy", "argmax", "argmin", "vmadd", "vmaddrelu", "vmla"}; - if (intrin_name_list.count(intrin_name) == 0) { - LOG(FATAL) << "Error: CCE Binary Vector Insn doesn't support the given intrin_name."; - } - - // get and check dst and src - GetCompactComputationInfo(stmt, dst_info_list, src_info_list, if_info, for_info, true); - // For vmadd/vmaddrelu/vmla we only need first two src - if (dst_info_list.size() != 1 || src_info_list.size() < 2) { - LOG(FATAL) << "CCE Binary Vector Insn only support ONE dst and TWO srcs."; - } - src_info_list = GetRange(src_info_list, 0, 2); - ArgInfo arg_info = ArgInfo(make_node()); - - // detect vector op mode - std::string mode = GetBinaryVecMode(dst_info_list, src_info_list, intrin_name, enable_bisect); - if (mode == "reduce_last_axis") { - size_t src_var_list_size = src_info_list[1]->var_.size(); - if (src_info_list[0]->var_.size() > src_info_list[1]->var_.size()) { - src_var_list_size = src_info_list[0]->var_.size(); - } - - 
CHECK(src_var_list_size > 0) << "Error: src can not be a scalar."; - if (src_var_list_size - dst_info_list[0]->var_.size() == 1) { - arg_info.GetNode()->arg_type_ = ARG_VECTOR_REDUCTION_LAST_AXIS; - } else { - LOG(FATAL) << "Error: cannot support multi-last-axis reduction."; - } - } else if (mode == "reduce_bisection") { - arg_info.GetNode()->arg_type_ = ARG_VECTOR_REDUCTION_BISECTION; - } else { - if (mode != "reduction" && mode != "broadcast") { - FillLastDim(dst_info_list, src_info_list, for_info); - } - - // vmax/vmin can't expand mask because it may introduce dirty data - bool can_expand_mask = intrin_name != "vmax" && intrin_name != "vmin"; - BinaryVecPatternGenerator generator = - BinaryVecPatternGenerator(dst_info_list, src_info_list, for_info, mode, can_expand_mask); - auto params = generator.GetInsnArgs(); - arg_info = params.arg_info; - dst_info_list = params.dst_info_list; - src_info_list = params.src_info_list; - for_info = params.for_info; - if (mode == "broadcast") { - bool has_last_axis = false; - if ((arg_info->body_arg_info_.defined() && arg_info->body_arg_info_->last_axis_info_.src_index_ != -1) || - (arg_info->tail_arg_info_.defined() && arg_info->tail_arg_info_->last_axis_info_.src_index_ != -1)) { - has_last_axis = true; - } - - if (has_last_axis && (intrin_name == "vadd" || intrin_name == "vmul")) { - Array stores; - Array loads; - GetStoreAndLoads(stmt, stores, loads); - intrin_name = intrin_name + "s"; - if (arg_info->body_arg_info_.defined()) { - arg_info.GetNode()->body_arg_info_.GetNode()->last_axis_info_.intrin_name_ = intrin_name; - arg_info.GetNode()->body_arg_info_.GetNode()->last_axis_info_.src_op_ = - Downcast(loads[arg_info->body_arg_info_->last_axis_info_.src_index_]); - } - } - } - } - - return arg_info; -} - -/// Replace com_info's var with new for loop's var -/// \param info -/// \param old_for_info -/// \param new_for_info -void ReplaceVarWithNewForInfo(StmtStoreInfo &info, const StmtInfo &old_for_info, const StmtInfo 
&new_for_info) { - for (size_t i = 0; i < new_for_info.vars_.size(); ++i) { - for (size_t j = 0; j < info->var_.size(); ++j) { - if (info->var_[j]->name_hint == new_for_info.vars_[i]->name_hint) { - SetItem(info.GetNode()->var_, static_cast(j), new_for_info.vars_[i]); - } - } - info.GetNode()->index_ = substitute(old_for_info.vars_[i], new_for_info.vars_[i], info->index_); - } -} - -/// Generete info list for bisection intrin -/// \param dst_info_list -/// \param src_info_list -/// \param for_info -/// \param if_info -/// \param last_axis -/// \param postfix -/// \return -BisectionInfoWrapper SeparateComInfoToBisectionInfoList(const StmtInfoList &dst_info_list, - const StmtInfoList &src_info_list, const StmtInfo &for_info, - StmtInfo &if_info, bool last_axis, int postfix = 0) { - CHECK_EQ(dst_info_list.size(), 1); - CHECK_EQ(src_info_list.size(), 2); - - BisectionInfoWrapper wrapper; - // Separate com_info and for_info - int compare_idx = 1; - int var_idx = -1; - if (last_axis) { - compare_idx = GetLastAxisReductionIdx(dst_info_list, src_info_list); - } else { - var_idx = GetBisectionReductionIdx(dst_info_list, src_info_list, compare_idx); - } - StmtStoreInfo dst_info = dst_info_list[0]; - CHECK_GE(compare_idx, 0); - StmtStoreInfo src_info1 = src_info_list[compare_idx]; - - Var reduce_var = GetItem(src_info1->var_, var_idx); - size_t for_idx = 0; - bool suc = GetIndexOfElement(for_info.vars_, VarExpr(reduce_var), for_idx); - CHECK(suc); - auto exist_for = GetItem(for_info.ops_, for_idx).as(); - CHECK(exist_for); - int extent = GetInt32Const(exist_for->extent); - - std::string prev_name = src_info1->name_; - Var prev_var = src_info1->data_; - Buffer prev_buffer = src_info1->buffer_; - Var bisec_var; - Buffer bisec_buffer; - std::string bisec_pre_header = last_axis ? 
"bisec_last_axis" : "bisec"; - std::string bisec_name = bisec_pre_header + "_local_UB"; - if (postfix > 0) { - bisec_name = bisec_name + "_" + std::to_string(postfix); - } - bool first_round = true; - - int vec_max_len = GetVecMaxLen(dst_info->dtype_); - int remain_extent = extent; - int left_extent = 0; - - CHECK_NE(vec_max_len, 0); - std::vector pow2_list = {0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536}; - while (extent > 0) { - int for_extent; - if (last_axis) { - left_extent = remain_extent / 2 + remain_extent % 2; - for (int i : pow2_list) { - if (left_extent == i) { - break; - } else if (left_extent < i) { - left_extent = i; - break; - } - } - if (left_extent < vec_max_len) { - // When left_extent < vec_max_len, stop bisect and generate normal reduce intrin - left_extent = remain_extent; - } - extent = remain_extent - left_extent; - remain_extent = left_extent; - for_extent = extent == 0 ? vec_max_len : extent; - } else { - for_extent = extent == 1 ? extent : extent / 2; - extent = extent % 2 == 0 || extent == 1 ? 
extent / 2 : (extent + 1) / 2; - - for (int i : pow2_list) { - if (extent == i) { - break; - } else if (extent < i) { - int gap = i - extent; - extent = i; - for_extent -= gap; - break; - } - } - } - - StmtStoreInfo dst_tmp_info = dst_info.Copy(); - StmtStoreInfo src_tmp_info0{src_info1.Copy()}; - StmtStoreInfo src_tmp_info1{src_info1.Copy()}; - - if (first_round) { - auto shape = src_tmp_info1->shape_; - if (last_axis) { - int block_size = GetUbBlkSize(dst_info->dtype_); - SetItem(shape, -1, Expr(CeilTo(GetIntConst(GetItem(shape, -1)), block_size))); - } - wrapper.original_shape_ = shape; - bisec_var = Var(bisec_name, Handle()); - bisec_buffer = BufferNode::make(bisec_var, dst_tmp_info->dtype_, shape, Array(), Expr(), bisec_name, - SCOPE_UBUF, 0, 0, BufferType::kDefault); - - if ((last_axis && extent != left_extent) || (!last_axis && extent != for_extent)) { - // Need to copy input to bisect buffer - StmtStoreInfo copy_dst_info{src_info1.Copy()}; - StmtStoreInfo copy_src_info{src_info1.Copy()}; - StmtInfoList src_list = {copy_src_info}; - - auto for_tmp_info = for_info.Copy(); - auto new_for = GetItem(for_tmp_info.ops_, for_idx).as(); - CHECK(new_for); - SetItem(for_tmp_info.ops_, static_cast(for_idx), - For::make(new_for->loop_var, new_for->min, last_axis ? left_extent : extent, new_for->for_type, - new_for->device_api, new_for->body)); - - ReplaceVarWithNewForInfo(copy_dst_info, for_info, for_tmp_info); - ReplaceVarWithNewForInfo(copy_src_info, for_info, for_tmp_info); - - SetItem(copy_dst_info.GetNode()->shape_, var_idx, Expr(last_axis ? left_extent : extent)); - SetItem(copy_src_info.GetNode()->shape_, var_idx, Expr(last_axis ? 
left_extent : extent)); - - CompactComputationInfoList(copy_dst_info, src_list, if_info, for_tmp_info); - - copy_dst_info.GetNode()->name_ = bisec_name; - copy_dst_info.GetNode()->buffer_ = bisec_buffer; - copy_dst_info.GetNode()->data_ = bisec_var; - // Replace outside for variable in index - auto vars = GetVarsInExpr(copy_dst_info->index_); - for (auto var : vars) { - if (!IsInArray(copy_dst_info->var_, var)) { - copy_dst_info.GetNode()->index_ = substitute(var, Expr(0), copy_dst_info->index_); - } - } - wrapper.bisec_info_list_.emplace_back(StmtInfoList{copy_dst_info, copy_src_info}); - wrapper.for_info_list_.push_back(for_tmp_info); - } - } - - auto for_tmp_info = for_info.Copy(); - auto new_for = GetItem(for_tmp_info.ops_, for_idx).as(); - CHECK(new_for); - SetItem( - for_tmp_info.ops_, static_cast(for_idx), - For::make(new_for->loop_var, new_for->min, for_extent, new_for->for_type, new_for->device_api, new_for->body)); - - ReplaceVarWithNewForInfo(dst_tmp_info, for_info, for_tmp_info); - ReplaceVarWithNewForInfo(src_tmp_info0, for_info, for_tmp_info); - ReplaceVarWithNewForInfo(src_tmp_info1, for_info, for_tmp_info); - - SetItem(src_tmp_info0.GetNode()->shape_, var_idx, Expr(for_extent)); - SetItem(src_tmp_info1.GetNode()->shape_, var_idx, Expr(for_extent)); - - if (extent > 0) { - dst_tmp_info.GetNode()->shape_ = src_tmp_info1->shape_; - dst_tmp_info.GetNode()->strides_ = src_tmp_info1->strides_; - dst_tmp_info.GetNode()->var_ = src_tmp_info1->var_; - dst_tmp_info.GetNode()->index_ = src_tmp_info1->index_; - dst_tmp_info.GetNode()->data_alignment_ = src_tmp_info1->data_alignment_; - dst_tmp_info.GetNode()->name_ = bisec_name; - dst_tmp_info.GetNode()->buffer_ = bisec_buffer; - dst_tmp_info.GetNode()->data_ = bisec_var; - auto src_extent = Expr(left_extent); - if (!last_axis) { - src_extent = GetItem(src_tmp_info1->strides_, var_idx) * extent; - } - src_tmp_info1.GetNode()->index_ = src_tmp_info1->index_ + src_extent; - } - - src_tmp_info0.GetNode()->name_ = 
prev_name; - src_tmp_info1.GetNode()->name_ = prev_name; - src_tmp_info0.GetNode()->buffer_ = prev_buffer; - src_tmp_info1.GetNode()->buffer_ = prev_buffer; - src_tmp_info0.GetNode()->data_ = prev_var; - src_tmp_info1.GetNode()->data_ = prev_var; - - // Replace outside for variable in index - for (auto &info : {dst_tmp_info, src_tmp_info0, src_tmp_info1}) { - if (info->name_.find(bisec_pre_header) == std::string::npos) { - continue; - } - auto vars = GetVarsInExpr(info->index_); - for (auto var : vars) { - if (!IsInArray(info->var_, var)) { - info.GetNode()->index_ = substitute(var, Expr(0), info->index_); - } - } - } - prev_name = bisec_name; - prev_var = bisec_var; - prev_buffer = bisec_buffer; - - StmtInfoList src_list = {src_tmp_info0, src_tmp_info1}; - CompactComputationInfoList(dst_tmp_info, src_list, if_info, for_tmp_info); - wrapper.for_info_list_.emplace_back(for_tmp_info); - - if (extent == 0) { - // last round should be dst = dst + src_tmp1 - wrapper.bisec_info_list_.emplace_back(StmtInfoList{dst_tmp_info, dst_tmp_info, src_tmp_info1}); - } else { - // normally is dst_tmp = src_tmp0 + src_tmp1 - wrapper.bisec_info_list_.emplace_back(StmtInfoList{dst_tmp_info, src_tmp_info0, src_tmp_info1}); - } - - first_round = false; - } - - // Generate arg_info - for (size_t i = 0; i < wrapper.bisec_info_list_.size(); ++i) { - auto info_list = wrapper.bisec_info_list_[i]; - auto new_for_info = wrapper.for_info_list_[i]; - - ArgInfo arg_info; - auto dst_list = GetRange(info_list, 0, 1); - auto src_list = GetRange(info_list, 1, info_list.size() - 1); - if (info_list.size() == 2) { - std::string dma_intrin = INTRIN_NAME_COPY_UB_TO_UB; - wrapper.dma_arg_info_map_ = GetDmaCopyInsnArgs(dma_intrin, dst_list, src_list, new_for_info); - } else if (last_axis && i == wrapper.bisec_info_list_.size() - 1) { - auto dst_tmp_info = dst_list[0]; - auto src_tmp_info = src_list[1]; - ReduceLastAxisPatternGenerator generator = - ReduceLastAxisPatternGenerator(dst_tmp_info, src_tmp_info, 
new_for_info, "vadd"); - auto result = generator.GetInsnArgs(); - arg_info = result.arg_info; - dst_tmp_info = result.dst_info_list[0]; - src_tmp_info = result.src_info_list[0]; - new_for_info = result.for_info; - wrapper.bisec_info_list_[i] = {dst_tmp_info, dst_tmp_info, src_tmp_info}; - } else { - // Bisect can't expand mask because it has inplace operation - if (i != wrapper.bisec_info_list_.size() - 1) { - // Last round dont need to add - FillLastDim(dst_list, src_list, new_for_info); - } - std::string mode = GetBinaryVecMode(dst_list, src_list, "vadd", false); - BinaryVecPatternGenerator generator = BinaryVecPatternGenerator(dst_list, src_list, new_for_info, mode, false); - auto params = generator.GetInsnArgs(); - arg_info = params.arg_info; - dst_list = params.dst_info_list; - src_list = params.src_info_list; - new_for_info = params.for_info; - wrapper.bisec_info_list_[i] = {dst_list[0], src_list[0], src_list[1]}; - } - wrapper.arg_info_list_.push_back(arg_info); - wrapper.for_info_list_[i] = new_for_info; - } - - return wrapper; -} -} // namespace akg diff --git a/src/emit_insn/insn_emitter.cc b/src/emit_insn/insn_emitter.cc index 7c0067bedddaf9c64a9a3cea6c71fe0faa1e208b..4cfebced51fe136e76b3dc4ab35d83aeb7ae2446 100644 --- a/src/emit_insn/insn_emitter.cc +++ b/src/emit_insn/insn_emitter.cc @@ -35,7 +35,7 @@ #include "insn_info.h" #include "insn_pattern.h" #include "insn_emitter_multimask.h" - +#include "insn_args_calculator.h" namespace akg { namespace ir { /// Sort indexes @@ -71,8 +71,7 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) { Array call_args; int call_cnt = 0; - if (intrin_name == "vector_dup" || intrin_name == "vadds" || - intrin_name == "vmuls" || intrin_name == "vaxpy") { + if (intrin_name == "vector_dup" || intrin_name == "vadds" || intrin_name == "vmuls" || intrin_name == "vaxpy") { auto GetCallInfo = [&intrin_name, &call_args, &call_cnt](const NodeRef &op) { if (op.as() && op.as()->name == intrin_name) { call_args = 
op.as()->args; @@ -82,8 +81,8 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) { PostOrderVisit(op, GetCallInfo); CHECK_EQ(call_cnt, 1); } - SingleType insn_type {SingleType::SIMD}; - Expr scalar_src {}; + SingleType insn_type{SingleType::SIMD}; + Expr scalar_src{}; if (intrin_name == "vector_dup") { insn_type = SingleType::Vector_Dump; src_info_list = {}; @@ -93,10 +92,11 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) { src_info_list = {src_info_list[0]}; scalar_src = call_args[1]; } - // check is single vector broadcast reduce mode exist - SingleVecPatternGenerator generator = SingleVecPatternGenerator(dst_info_list, src_info_list, for_info); - auto params = generator.GetInsnArgs(); + + SingleVecInsnArgsCalculator args_calculator = SingleVecInsnArgsCalculator(dst_info_list, src_info_list, for_info, intrin_name); + PatternResult params = args_calculator.GetInsnArgs(); + dst_info_list = params.dst_info_list; src_info_list = params.src_info_list; for_info = params.for_info; @@ -141,23 +141,16 @@ Stmt BinaryVecEmitter(const Stmt &op, std::string intrin_name, bool enable_bisec if (src_info_list[0]->var_.size() > src_info_list[1]->var_.size()) { src_info = src_info_list[0]; } - const int vec_max_len = GetVecMaxLen(dst_info->dtype_); - if (enable_bisect && GetIntConst(GetItem(src_info->shape_, -1)) > vec_max_len) { - CommentManager::GetInstance().AddComment("Bisect_optimize", "enabled"); - auto wrapper = - SeparateComInfoToBisectionInfoList(dst_info_list, src_info_list, for_info, if_info, true, postfix); - return EmitCceBinaryVectorToBisectionReduction(wrapper, if_info, intrin_name); - } else { - CommentManager::GetInstance().AddComment("Pattern", arg_info.GetPattern()); - ReduceLastAxisPatternGenerator generator = - ReduceLastAxisPatternGenerator(dst_info, src_info, for_info, intrin_name); - auto result = generator.GetInsnArgs(); - arg_info = result.arg_info; - dst_info = result.dst_info_list[0]; - src_info = result.src_info_list[0]; 
- for_info = result.for_info; - return EmitCceBinaryVectorToReduceLastAxis(dst_info, src_info, if_info, for_info, arg_info, intrin_name); - } + CommentManager::GetInstance().AddComment("Pattern", arg_info.GetPattern()); + + LastAxisReduceInsnArgsCalculator args_calculator = LastAxisReduceInsnArgsCalculator(dst_info, src_info, for_info, intrin_name); + PatternResult result = args_calculator.GetInsnArgs(); + + arg_info = result.arg_info; + dst_info = result.dst_info_list[0]; + src_info = result.src_info_list[0]; + for_info = result.for_info; + return EmitCceBinaryVectorToReduceLastAxis(dst_info, src_info, if_info, for_info, arg_info, intrin_name); } case ARG_VECTOR_REDUCTION_BISECTION: { CommentManager::GetInstance().AddComment("Compute_type", "reduction"); @@ -192,7 +185,7 @@ Stmt BinaryVecEmitter(const Stmt &op, std::string intrin_name, bool enable_bisec return FoldInsnWithForInfo(insn_list, if_info, for_info, stmt); } } -} +} // namespace ir /// Function to emit scalar intrin /// \param op - The input stmt to be emitted as intrin @@ -984,8 +977,9 @@ Stmt BinaryDropoutEmitter(const Stmt &op) { src1.GetNode()->data_ = mask->buffer_var; src1.GetNode()->data_alignment_ = GetInt32Const(mask->predicate); - SingleVecPatternGenerator generator = SingleVecPatternGenerator(dst_info_list, src_info_list, for_info, "elewise"); - auto params = generator.GetInsnArgs(); + SingleVecInsnArgsCalculator args_calculator = SingleVecInsnArgsCalculator(dst_info_list, src_info_list, for_info); + PatternResult params = args_calculator.GetInsnArgs(); + dst_info_list = params.dst_info_list; src_info_list = params.src_info_list; for_info = params.for_info; @@ -1484,8 +1478,10 @@ Stmt BinaryArgOpEmitter(const Stmt &op, const std::string &intrin_name) { if (src_info_list[0]->var_.size() > src_info_list[1]->var_.size()) { src_info = src_info_list[0]; } - ReduceLastAxisPatternGenerator generator = ReduceLastAxisPatternGenerator(dst_info, src_info, for_info, intrin_name); - auto result = 
generator.GetInsnArgs(); + + LastAxisReduceInsnArgsCalculator args_calculator = LastAxisReduceInsnArgsCalculator(dst_info, src_info, for_info, intrin_name); + PatternResult result = args_calculator.GetInsnArgs(); + arg_info = result.arg_info; dst_info = result.dst_info_list[0]; src_info = result.src_info_list[0]; diff --git a/src/emit_insn/insn_info.cc b/src/emit_insn/insn_info.cc index c11470fad6a41f309f449b1ac47ca5a64d97e350..a6b203f69d453df83efdabdcd3835e892ad2ec51 100644 --- a/src/emit_insn/insn_info.cc +++ b/src/emit_insn/insn_info.cc @@ -104,10 +104,7 @@ StmtStoreInfo StmtStoreInfo::Copy() const { StmtInfo StmtInfo::Copy() const { auto stmt_info = StmtInfo(); stmt_info.ops_ = ops_; - for (auto var : vars_) { - auto new_var = Variable::make(var->type, var->name_hint); - stmt_info.vars_.push_back(new_var); - } + stmt_info.vars_ = vars_; for (size_t i = 0; i < vars_.size(); ++i) { for (size_t j = 0; j < stmt_info.ops_.size(); ++j) { diff --git a/src/emit_insn/insn_info.h b/src/emit_insn/insn_info.h index 5eca1814edecec281003899bf1998d8611490a74..6ac5cfa1cfa364d6642fe17df30fe4e2be1719e9 100644 --- a/src/emit_insn/insn_info.h +++ b/src/emit_insn/insn_info.h @@ -276,15 +276,7 @@ struct BisectionInfoWrapper { Map dma_arg_info_map_; }; -struct InsnAxis { - int min{0}; - int extent{0}; - Var var; - int dst_stride{0}; - int src_stride{0}; - std::list src_stride_list; - std::list stride_list; -}; + IterVar GetCceAxis(); diff --git a/src/emit_insn/insn_pattern.cc b/src/emit_insn/insn_pattern.cc index 3c3c7d0e4f8fcb4889a981c60144a9cc4bee5400..46f2ffc643ea378bbef2dd8b8fb9c9c0672b1623 100644 --- a/src/emit_insn/insn_pattern.cc +++ b/src/emit_insn/insn_pattern.cc @@ -15,7 +15,6 @@ */ #include "insn_pattern.h" - #include #include #include @@ -200,28 +199,6 @@ ArgInfo GetMultiVecInsnArgs(StmtInfoList &dst_info_list, StmtInfoList &src_info_ return arg_info; } -/// Get first non zero shape from input shapes -/// \param dst_shape -/// \param src0_shape -/// \param src1_shape -/// 
\return -int PatternGenerator::GetNonZeroShape(const Expr &dst_shape, const Expr &src0_shape, const Expr &src1_shape) { - int shape = 0; - for (int val : - {GetInt32Const(dst_shape), GetInt32Const(src0_shape), src1_shape.defined() ? GetInt32Const(src1_shape) : 0}) { - if (val == 0) { - continue; - } - if (shape != 0 && val != shape) { - LOG(FATAL) << "Error: same var has different shapes. " << GetIntConst(dst_shape) << " " - << GetIntConst(src0_shape); - } - shape = val; - } - CHECK(shape != 0) << "Error: all shapes are equal to 0."; - return shape; -} - /// In case /// for (cc3) { /// A[(cc3*16)] = (B[(cc3*16)] - C[(cc3*16)]) @@ -432,25 +409,6 @@ void CleanZeroStrides(Array &info_list) { } } -/// Swap axis in Array -/// \param var -/// \param shape -/// \param strides -/// \param idx1 -/// \param idx2 -void PatternGenerator::GetShapeInfoAndSwap(Array &var, Array &shape, Array &strides, int idx1, - int idx2) { - auto tmp_var = GetItem(var, idx1); - SetItem(var, idx1, GetItem(var, idx2)); - SetItem(var, idx2, tmp_var); - auto tmp_shape = GetItem(shape, idx1); - SetItem(shape, idx1, GetItem(shape, idx2)); - SetItem(shape, idx2, tmp_shape); - auto tmp_stride = GetItem(strides, idx1); - SetItem(strides, idx1, GetItem(strides, idx2)); - SetItem(strides, idx2, tmp_stride); -} - /// Get insn args of load 2D intrin /// \param intrin_name /// \param dst_info_list @@ -856,6 +814,38 @@ Map GetDmaCopyInsnArgs(std::string &intrin_name, const StmtIn return arg_info_map; } +/// Replace com_info's var with new for loop's var +/// \param info +/// \param old_for_info +/// \param new_for_info +void ReplaceVarWithNewForInfo(StmtStoreInfo &info, const StmtInfo &old_for_info, const StmtInfo &new_for_info) { + for (size_t i = 0; i < new_for_info.vars_.size(); ++i) { + for (size_t j = 0; j < info->var_.size(); ++j) { + if (info->var_[j]->name_hint == new_for_info.vars_[i]->name_hint) { + SetItem(info.GetNode()->var_, static_cast(j), new_for_info.vars_[i]); + } + } + 
info.GetNode()->index_ = substitute(old_for_info.vars_[i], new_for_info.vars_[i], info->index_); + } +} + +std::string GetBinaryVecMode(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list, + const std::string &intrin_name, bool enable_bisect) { + std::set reduce_bisect_list = {"vadd", "vsub", "vmul", "vmax"}; + std::string mode = "reduction"; + if (IsElementwise(dst_info_list, src_info_list)) { + mode = "elewise"; + } else if (IsBroadcast(dst_info_list, src_info_list)) { + mode = "broadcast"; + } else if (IsLastAxisReduction(dst_info_list, src_info_list)) { + mode = "reduce_last_axis"; + } else if (enable_bisect && reduce_bisect_list.count(intrin_name) != 0 && + IsBisectionReduction(dst_info_list, src_info_list)) { + mode = "reduce_bisection"; + } + + return mode; +} const char *const DummyLastVar = "cc_last"; TVM_REGISTER_API("cce_util.GetVecMask").set_body([](const TVMArgs args, TVMRetValue *ret) { diff --git a/src/emit_insn/insn_pattern.h b/src/emit_insn/insn_pattern.h index 66b6c9ffe3c004c4fe5f6123cc5787f0f8f5edec..36bf634a493e3bf051903e5448b8960fe0ad5620 100644 --- a/src/emit_insn/insn_pattern.h +++ b/src/emit_insn/insn_pattern.h @@ -37,220 +37,12 @@ struct PatternResult { StmtInfo for_info; }; -class PatternGenerator { - public: - PatternGenerator(const StmtInfoList &dst_info_list, const StmtInfo &for_info) - : for_info(for_info), - not_this_pattern(-1.0f), - split_latency_coef(10.0f), - repeat_latency_coef(3.0f), - offset_latency_coef(0.1f) { - CHECK(!dst_info_list.empty()); - dst_info = dst_info_list[0]; - } - virtual ~PatternGenerator() = default; - virtual PatternResult GetInsnArgs() = 0; - - protected: - int GetNonZeroShape(const Expr &dst_shape, const Expr &src0_shape, const Expr &src1_shape = Expr()); - void GetShapeInfoAndSwap(Array &var, Array &shape, Array &strides, int idx1, int idx2); - - virtual float Compute3DPatternMaskRate() { return not_this_pattern; } - virtual float Compute2DBlockPatternMaskRate() { return not_this_pattern; 
} - virtual float Compute2DPatternMaskRate() { return not_this_pattern; } - virtual float Compute1DPatternMaskRate() { return not_this_pattern; } - virtual Array Get3DPattern() { return {}; } - virtual Array Get2DBlockPattern() { return {}; } - virtual Array Get2DPattern() { return {}; } - virtual Array Get1DPattern() { return {}; } - virtual PatternResult GenResult(const Array &elim_var) = 0; - - StmtStoreInfo dst_info; - StmtInfo for_info; - - const float not_this_pattern; - const float split_latency_coef; - const float repeat_latency_coef; - const float offset_latency_coef; -}; - -class SingleVecPatternGenerator : public PatternGenerator { - public: - SingleVecPatternGenerator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list, - const StmtInfo &for_info, const std::string &mode = "elewise") - : PatternGenerator(dst_info_list, for_info), - arg_info(ArgInfo(make_node())), - body_args(VectorArgInfo()), - tail_args(VectorArgInfo()), - mode(mode) { - if (src_info_list.empty()) { - src_info = dst_info.Copy(); - } else { - CHECK(!src_info_list.empty()); - src_info = src_info_list[0]; - } - } - ~SingleVecPatternGenerator() override = default; - PatternResult GetInsnArgs() final; - - protected: - float Compute3DPatternMaskRate() final; - float Compute2DBlockPatternMaskRate() final; - float Compute2DPatternMaskRate() final; - float Compute1DPatternMaskRate() final; - float Compute3DsPatternMaskRate(); - float Compute2DRepeatPatternMaskRate(); - Array Get3DPattern() final; - Array Get2DBlockPattern() final; - Array Get2DPattern() final; - Array Get1DPattern() final; - Array Get3DsPattern(); - Array Get2DRepeatPattern(); - PatternResult GenResult(const Array &elim_var) final; - - private: - void CalcParams(); - int GetLastDimShape(const Expr &dst_shape, const Expr &src_shape); - - struct Params { - Array dst_var; - Array src_var; - Array dst_shape; - Array src_shape; - Array dst_strides; - Array src_strides; - int non_zero_shape1 = 0; - int 
non_zero_shape2 = 0; - int non_zero_shape3 = 0; - int all_points = 0; - int dst_block_size = 0; - int src_block_size = 0; - int mask_block_size = 0; - int dst_bits = 0; - int src_bits = 0; - int max_bits = 0; - int dst_vec_max_len = 0; - int vec_max_len = 0; - int block_offset = 0; - }; - - StmtStoreInfo src_info; - Params params; - ArgInfo arg_info; - VectorArgInfo body_args; - VectorArgInfo tail_args; - std::string mode; - Type data_type; -}; - -class BinaryVecPatternGenerator : public PatternGenerator { - public: - BinaryVecPatternGenerator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list, - const StmtInfo &for_info, const std::string &mode, bool expand_mask = true) - : PatternGenerator(dst_info_list, for_info), - src_info_list(src_info_list), - arg_info(ArgInfo(make_node())), - body_args(VectorArgInfo()), - tail_args(VectorArgInfo()), - empty_var(Var("")), - mode(mode), - expand_mask(expand_mask) {} - ~BinaryVecPatternGenerator() override = default; - - PatternResult GetInsnArgs() final; - - protected: - float Compute3DPatternMaskRate() final; - float Compute2DBlockPatternMaskRate() final; - float Compute2DPatternMaskRate() final; - float Compute1DPatternMaskRate() final; - Array Get3DPattern() final; - Array Get2DBlockPattern() final; - Array Get2DPattern() final; - Array Get1DPattern() final; - PatternResult GenResult(const Array &elim_var) final; - - private: - void CalcParams(); - bool IsSamePatternComInfo(const StmtStoreInfo &info_a, const StmtStoreInfo &info_b); - bool IsNonZeroShapeEqual(const Array &shape_list); - void AppendEmptyVar(StmtInfoList &info_list); - - struct Params { - Array dst_var; - Array dst_shape; - Array dst_strides; - Array src_var0; - Array src_shape0; - Array src_strides0; - Array src_var1; - Array src_shape1; - Array src_strides1; - int non_zero_shape1 = 0; - int non_zero_shape2 = 0; - int non_zero_shape3 = 0; - int all_points = 0; - int block_size = 0; - int last_dim_shape = 0; - int vec_max_len = 0; - }; - - 
StmtInfoList src_info_list; - ArgInfo arg_info; - VectorArgInfo body_args; - VectorArgInfo tail_args; - Params params; - Var empty_var; - std::string mode; - bool expand_mask; -}; - -class ReduceLastAxisPatternGenerator : public PatternGenerator { - public: - ReduceLastAxisPatternGenerator(const StmtStoreInfo &dst_info, const StmtStoreInfo &src_info, const StmtInfo &for_info, - const std::string &intrin_name) - : PatternGenerator({dst_info}, for_info), - src_info(src_info), - arg_info(ArgInfo(make_node())), - body_args(VectorArgInfo()), - tail_args(VectorArgInfo()), - intrin_name(intrin_name) {} - PatternResult GetInsnArgs() final; - ~ReduceLastAxisPatternGenerator() override = default; - - protected: - float Compute2DBlockPatternMaskRate() final; - Array Get2DBlockPattern() final; - Array Get1DPattern() final; - PatternResult GenResult(const Array &elim_var) final; - - private: - void CalcParams(); - - struct Params { - Array src_var; - int block_size = 0; - int vec_max_len = 0; - int last_dim_shape = 0; - Expr insn_offset_scale_factor; - }; - - StmtStoreInfo src_info; - ArgInfo arg_info; - VectorArgInfo body_args; - VectorArgInfo tail_args; - Array mix_vec_arg_list; - std::string intrin_name; - Params params; -}; - std::string GetSingleVecComputationInfo(const Stmt &stmt, const std::string &intrin_name, Array &dst_info_list, Array &src_info_list, StmtInfo &if_info, StmtInfo &for_info, bool need_compact = true); - -ArgInfo GetBinaryVecInsnArgs(const Stmt &stmt, std::string intrin_name, StmtInfoList &dst_info_list, - StmtInfoList &src_info_list, StmtInfo &if_info, StmtInfo &for_info, - bool enable_bisect = true); + +std::string GetBinaryVecMode(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list, + const std::string &intrin_name, bool enable_bisect = true); ArgInfo GetMultiVecInsnArgs(StmtInfoList &dst_info_list, StmtInfoList &src_info_list, StmtInfo &for_info); @@ -277,10 +69,7 @@ Map GetDmaCopyInsnArgs(std::string &intrin_name, const StmtIn const 
StmtInfoList &src_info_list, StmtInfo &for_info, Map &ub_copy_pre, Map &ub_copy_post); -BisectionInfoWrapper SeparateComInfoToBisectionInfoList(const StmtInfoList &dst_info_list, - const StmtInfoList &src_info_list, const StmtInfo &for_info, - StmtInfo &if_info, bool last_axis, int postfix); - +void ReplaceVarWithNewForInfo(StmtStoreInfo &info, const StmtInfo &old_for_info, const StmtInfo &new_for_info); extern const char *const DummyLastVar; } // namespace akg #endif // EMIT_INSN_INSN_PATTERN_H_ diff --git a/src/emit_insn/insn_single_vec_pattern.cc b/src/emit_insn/insn_single_vec_pattern.cc deleted file mode 100644 index 6e55bf519028a34af1863f1896447913ec79cf95..0000000000000000000000000000000000000000 --- a/src/emit_insn/insn_single_vec_pattern.cc +++ /dev/null @@ -1,802 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include -#include - -#include -#include - -#include "insn_builder.h" -#include "insn_pattern.h" -#include "common/array_api.h" -#include "pass/expr_alg_simplify.h" - -namespace akg { -/// Get CCE Single Vector Insn mode -/// \param dst_info_list -/// \param src_info_list -/// \return -std::string GetSingleVecMode(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list) { - CHECK(!dst_info_list.empty()); - auto dst_var_list = dst_info_list[0]->var_; - Array src_var_list; - if (!src_info_list.empty()) { - src_var_list = src_info_list[0]->var_; - } - - if (IsSame(dst_var_list, src_var_list)) { - return "elewise"; - } else if (dst_var_list.size() >= src_var_list.size()) { - return "broadcast"; - } - - return "reduction"; -} - -/// Get Single Vector Computation Info -/// \param stmt -/// \param intrin_name -/// \param dst_info_list -/// \param src_info_list -/// \param if_info -/// \param for_info -/// \param need_compact -/// \return -std::string GetSingleVecComputationInfo(const Stmt &stmt, const std::string &intrin_name, StmtInfoList &dst_info_list, - StmtInfoList &src_info_list, StmtInfo &if_info, StmtInfo &for_info, - bool need_compact) { - std::set intrin_name_list = {"vadds", "vmuls", "vrelu", "vabs", "vln", "vexp", - "vrec", "vector_dup", "vnot", "vsqrt", "vrsqrt"}; - if (intrin_name_list.count(intrin_name) == 0 && intrin_name.find("vconv_") == std::string::npos) { - LOG(FATAL) << "Error: CCE Single Vector Insn unsupported the given intrin_name. " << intrin_name; - return ""; - } - - bool same_dtype = intrin_name.find("vconv_") == std::string::npos; - GetCompactComputationInfo(stmt, dst_info_list, src_info_list, if_info, for_info, same_dtype, need_compact); - std::string mode = GetSingleVecMode(dst_info_list, src_info_list); - - CHECK(dst_info_list.size() == 1) << "CCE Single Vector only support ONE dst."; - - return mode; -} - -/// Get CCE Single vector instructions args. 
-/// \param dst_info_list -/// \param src_info_list -/// \param for_info -/// \param mode -/// \return -PatternResult SingleVecPatternGenerator::GetInsnArgs() { - CalcParams(); - Array elim_var = {}; - float rate3d = Compute3DPatternMaskRate(); - float rate2db = Compute2DBlockPatternMaskRate(); - float rate2d = Compute2DPatternMaskRate(); - float rate1d = Compute1DPatternMaskRate(); - float rate3ds = Compute3DsPatternMaskRate(); - float rate2ds = Compute2DRepeatPatternMaskRate(); - if (mode == "broadcast_last_axis") { - elim_var = Get1DPattern(); - } else if (rate2ds > 0) { - elim_var = Get2DRepeatPattern(); - } else if (rate3ds > 0) { - elim_var = Get3DsPattern(); - arg_info.GetNode()->pattern_ = PATTERN_2D; - } else if (rate3d >= rate2db && rate3d > 0) { - elim_var = Get3DPattern(); - arg_info.GetNode()->pattern_ = PATTERN_3D; - } else if (rate2db >= rate2d && rate2db >= rate1d && rate2db > 0) { - elim_var = Get2DBlockPattern(); - arg_info.GetNode()->pattern_ = PATTERN_PARTIAL_3D; - } else if (rate2d > rate1d && rate2d > 0) { - elim_var = Get2DPattern(); - arg_info.GetNode()->pattern_ = PATTERN_2D; - } else if (rate1d > 0) { - elim_var = Get1DPattern(); - arg_info.GetNode()->pattern_ = PATTERN_1D; - } else { - LOG(FATAL) << "Error: Cannot emit Single-Vector-Insn with any pattern!"; - } - - std::string mask_rate = "rate3d[" + std::to_string(rate3d) + "], rate2db[" + std::to_string(rate2db) + "], rate2d[" + - std::to_string(rate2d) + "], rate1d[" + std::to_string(rate1d) + "]"; - CommentManager::GetInstance().AddComment("Mask_rate", mask_rate); - if (arg_info->tail_arg_info_.defined()) { - CommentManager::GetInstance().AddComment("Contain_tail", "true"); - } else { - CommentManager::GetInstance().AddComment("Contain_tail", "false"); - } - - return GenResult(elim_var); -} - -/// Calc params for pattern match -void SingleVecPatternGenerator::CalcParams() { - Array info_list = {dst_info, src_info}; - // check shape len - for (auto info : info_list) { - 
CHECK(!info->shape_.empty()) - << "CCE Vector Insn Error: dst_buffer and src_buffer can not be scalar, should keep len(shape) > 0."; - } - - FillEmptyVar(info_list); - dst_info = info_list[0]; - src_info = info_list[1]; - - int dst_bits = dst_info->dtype_.bits(); - int src_bits = src_info->dtype_.bits(); - CHECK_NE(dst_bits, 0); - CHECK_NE(src_bits, 0); - int dst_block_size = GetUbBlkSize(dst_info->dtype_); - int src_block_size = GetUbBlkSize(src_info->dtype_); - CHECK_NE(dst_block_size, 0); - CHECK_NE(src_block_size, 0); - - data_type = src_bits > dst_bits ? src_info->dtype_ : dst_info->dtype_; - - params.dst_var = dst_info->var_; - params.src_var = src_info->var_; - params.dst_shape = dst_info->shape_; - params.src_shape = src_info->shape_; - params.dst_strides = dst_info->strides_; - params.src_strides = src_info->strides_; - params.dst_block_size = dst_block_size; - params.src_block_size = src_block_size; - params.mask_block_size = src_bits > dst_bits ? src_block_size : dst_block_size; - params.dst_bits = dst_bits; - params.src_bits = src_bits; - params.max_bits = FULL_BLOCK_NUM * std::min(dst_bits, src_bits); - params.dst_vec_max_len = GetVecMaxLen(dst_info->dtype_); - params.vec_max_len = src_bits > dst_bits ? 
GetVecMaxLen(src_info->dtype_) : GetVecMaxLen(dst_info->dtype_); - CHECK_NE(params.dst_vec_max_len, 0); - CHECK_NE(params.vec_max_len, 0); - - auto GetNonZeroShapeByIdx = [this](int index) -> int { - if (index <= static_cast(params.dst_var.size())) { - if (Equal(GetItem(params.dst_var, -index), GetItem(params.src_var, -index))) { - return GetNonZeroShape(GetItem(params.dst_shape, -index), GetItem(params.src_shape, -index)); - } - } - return 1; - }; - - params.non_zero_shape1 = GetNonZeroShapeByIdx(1); - params.non_zero_shape2 = GetNonZeroShapeByIdx(2); - params.non_zero_shape3 = GetNonZeroShapeByIdx(3); - params.all_points = params.non_zero_shape1 * params.non_zero_shape2 * params.non_zero_shape3; - - auto elem_offset_mod = ir::ExprSimplifier().Simplify(Mod::make(dst_info->elem_offset_, dst_block_size)); - if (elem_offset_mod.as()) { - params.block_offset = elem_offset_mod.as()->value; - } -} - -int SingleVecPatternGenerator::GetLastDimShape(const Expr &dst_shape, const Expr &src_shape) { - int dst_last_dim = GetInt32Const(dst_shape); - int src_last_dim = GetInt32Const(src_shape); - - CHECK(dst_last_dim != 0 || src_last_dim != 0); - if (dst_last_dim == 0) { - return src_last_dim; - } - if (src_last_dim == 0) { - return dst_last_dim; - } - return std::min(dst_last_dim, src_last_dim); -} - -bool FindInShape(Array &shape, const Expr &target) { - for (int i = -1; i >= -3; --i) { - if (Equal(GetItem(shape, i), target)) { - return true; - } - } - return false; -} - -float SingleVecPatternGenerator::Compute2DRepeatPatternMaskRate() { - if (params.dst_var.size() < 3) { - return not_this_pattern; - } - - for (int i = -1; i >= -3; --i) { - if (!FindInShape(params.src_shape, GetItem(params.dst_shape, i))) { - return not_this_pattern; - } - } - - if (GetInt32Const(GetItem(params.dst_strides, -2)) % params.dst_block_size != 0 || - GetInt32Const(GetItem(params.src_strides, -2)) % params.src_block_size != 0) { - return not_this_pattern; - } - - if 
(GetInt32Const(GetItem(params.dst_strides, -3)) % params.dst_block_size != 0 || - GetInt32Const(GetItem(params.src_strides, -3)) % params.src_block_size != 0) { - return not_this_pattern; - } - - if (GetInt32Const(GetItem(params.dst_strides, -2)) == 0 && GetInt32Const(GetItem(params.src_strides, -2)) == 0) { - return not_this_pattern; - } - - if (!Equal(GetItem(params.dst_shape, -3), GetItem(params.src_shape, -2)) || - !Equal(GetItem(params.dst_shape, -2), GetItem(params.src_shape, -3))) { - return not_this_pattern; - } - if (GetIntConst(GetItem(params.dst_shape, -2)) > FULL_BLOCK_NUM && - GetIntConst(GetItem(params.dst_shape, -2)) % FULL_BLOCK_NUM != 0) { - return not_this_pattern; - } - if (params.dst_block_size == params.src_block_size) { - return not_this_pattern; - } - if (GetInt32Const(GetItem(params.dst_shape, -1)) <= params.dst_block_size && - GetInt32Const(GetItem(params.src_shape, -1)) <= params.src_block_size) { - return not_this_pattern; - } - return 1.0; -} - -float SingleVecPatternGenerator::Compute3DsPatternMaskRate() { - if (params.dst_var.size() < 3) { - return not_this_pattern; - } - if (params.dst_block_size != params.src_block_size) { - return not_this_pattern; - } - for (int i = -1; i >= -3; --i) { - if (!FindInShape(params.src_shape, GetItem(params.dst_shape, i))) { - return not_this_pattern; - } - } - if (GetInt32Const(GetItem(params.dst_shape, -1)) > params.dst_block_size || - GetInt32Const(GetItem(params.src_shape, -1)) > params.src_block_size) { - return not_this_pattern; - } - - if (GetInt32Const(GetItem(params.dst_strides, -2)) % params.dst_block_size != 0 || - GetInt32Const(GetItem(params.src_strides, -2)) % params.src_block_size != 0) { - return not_this_pattern; - } - - if (GetInt32Const(GetItem(params.dst_strides, -3)) % params.dst_block_size != 0 || - GetInt32Const(GetItem(params.src_strides, -3)) % params.src_block_size != 0) { - return not_this_pattern; - } - - if (GetInt32Const(GetItem(params.dst_strides, -2)) == 0 && 
GetInt32Const(GetItem(params.src_strides, -2)) == 0) { - return not_this_pattern; - } - - if (!Equal(GetItem(params.dst_shape, -3), GetItem(params.src_shape, -2)) || - !Equal(GetItem(params.dst_shape, -2), GetItem(params.src_shape, -3))) { - return not_this_pattern; - } - if (GetIntConst(GetItem(params.dst_shape, -2)) > FULL_BLOCK_NUM && - GetIntConst(GetItem(params.dst_shape, -2)) % FULL_BLOCK_NUM != 0) { - return not_this_pattern; - } - return 1.0; -} - -float SingleVecPatternGenerator::Compute3DPatternMaskRate() { - // in elemwise mode, the var is already checked to be equal, no need to check - if (params.dst_var.size() < 3) { - return not_this_pattern; - } - - // do not support cast op in 3D pattern - if (params.dst_block_size != params.src_block_size) { - return not_this_pattern; - } - - for (int i = -1; i >= -3; --i) { - if (!IsTwoItemEqual(params.dst_var, params.src_var, i)) { - return not_this_pattern; - } - } - - if (GetInt32Const(GetItem(params.dst_shape, -1)) > params.dst_block_size || - GetInt32Const(GetItem(params.src_shape, -1)) > params.src_block_size) { - return not_this_pattern; - } - - if (GetInt32Const(GetItem(params.dst_strides, -2)) % params.dst_block_size != 0 || - GetInt32Const(GetItem(params.src_strides, -2)) % params.src_block_size != 0) { - return not_this_pattern; - } - - if (GetInt32Const(GetItem(params.dst_strides, -3)) % params.dst_block_size != 0 || - GetInt32Const(GetItem(params.src_strides, -3)) % params.src_block_size != 0) { - return not_this_pattern; - } - - if (GetInt32Const(GetItem(params.dst_strides, -2)) == 0 && GetInt32Const(GetItem(params.src_strides, -2)) == 0) { - return not_this_pattern; - } - - // repeat axis is shape [-3], repeat once, has 8 loops - bool is3_d = true; - float rate3d_mode1 = not_this_pattern; - float rate3d_mode2 = not_this_pattern; - int repeat_num; - float repeat_latency; - StmtInfoList info_list = {dst_info, src_info}; - for (auto info : info_list) { - if (GetInt32Const(GetItem(info->shape_, -2)) > 
FULL_BLOCK_NUM || - GetInt32Const(GetItem(info->strides_, -2)) / params.dst_block_size >= MAX_STRIDE_M0_SINGLE || - GetInt32Const(GetItem(info->strides_, -3)) / params.dst_block_size >= MAX_STRIDE_M1) { - is3_d = false; - break; - } - } - if (is3_d) { - if (GetIntConst(GetItem(params.dst_strides, -2)) == 0) { - return not_this_pattern; - } - repeat_num = params.non_zero_shape3; - repeat_latency = ((repeat_num - 1) / MAX_REPEAT) * repeat_latency_coef; - rate3d_mode1 = static_cast(params.all_points) / params.dst_vec_max_len / (repeat_num + repeat_latency); - } - - is3_d = true; - // repeat axis is shape[-2] - for (auto info : info_list) { - // stride_m0 should less than 65536 - if (GetInt32Const(GetItem(info->shape_, -3)) % FULL_BLOCK_NUM != 0 || - GetInt32Const(GetItem(info->strides_, -3)) / params.dst_block_size >= MAX_STRIDE_M0_SINGLE) { - is3_d = false; - break; - } - } - if (is3_d) { - if (GetIntConst(GetItem(params.dst_strides, -3)) == 0) { - return not_this_pattern; - } - repeat_num = params.non_zero_shape2 * (params.non_zero_shape3 / FULL_BLOCK_NUM); - repeat_latency = ((repeat_num - 1) / MAX_REPEAT) * repeat_latency_coef; - float offset_latency = - params.non_zero_shape3 / FULL_BLOCK_NUM > 1 ? params.non_zero_shape3 * offset_latency_coef : 0; - rate3d_mode2 = - static_cast(params.all_points) / params.dst_vec_max_len / (repeat_num + repeat_latency + offset_latency); - } - - return rate3d_mode1 > rate3d_mode2 ? 
rate3d_mode1 : rate3d_mode2; -} - -// Partial 3D Pattern -float SingleVecPatternGenerator::Compute2DBlockPatternMaskRate() { - // in elemwise mode, the var is already checked to be equal, no need to check - if (params.dst_var.size() < 2 || params.src_var.size() < 2 || GetInt32Const(GetItem(params.dst_strides, -1)) != 1) { - return not_this_pattern; - } - - // do not support cast op in Partial3D pattern - if (params.dst_block_size != params.src_block_size) { - return not_this_pattern; - } - - for (int i = -1; i >= -2; --i) { - if (!Equal(GetItem(params.dst_var, i), GetItem(params.src_var, i))) { - return not_this_pattern; - } - } - - if (GetInt32Const(GetItem(params.dst_shape, -1)) > params.dst_block_size || - GetInt32Const(GetItem(params.src_shape, -1)) > params.src_block_size) { - return not_this_pattern; - } - - if (GetInt32Const(GetItem(params.dst_strides, -2)) % params.dst_block_size != 0 || - GetInt32Const(GetItem(params.src_strides, -2)) % params.src_block_size != 0) { - return not_this_pattern; - } - - if (GetInt32Const(GetItem(params.dst_strides, -2)) == 0 && GetInt32Const(GetItem(params.src_strides, -2)) == 0) { - return not_this_pattern; - } - - if (GetInt32Const(GetItem(params.dst_strides, -2)) / params.dst_block_size >= MAX_STRIDE_M0_SINGLE || - GetInt32Const(GetItem(params.src_strides, -2)) / params.src_block_size >= MAX_STRIDE_M0_SINGLE) { - return not_this_pattern; - } - - if (GetInt32Const(GetItem(params.dst_strides, -2)) == 0) { - return not_this_pattern; - } - - int repeat_body_num = params.non_zero_shape2 / FULL_BLOCK_NUM; - int repeat_tail_num = (params.non_zero_shape2 % FULL_BLOCK_NUM + FULL_BLOCK_NUM - 1) / FULL_BLOCK_NUM; - int repeat_num = (repeat_body_num + repeat_tail_num) * params.non_zero_shape3; - float repeat_latency = - (std::max(repeat_body_num - 1, 0) / MAX_REPEAT + std::max(repeat_tail_num - 1, 0) / MAX_REPEAT) * - repeat_latency_coef; - float offset_latency = params.non_zero_shape3 > 1 ? 
params.non_zero_shape3 * offset_latency_coef : 0; - float split_latency = (repeat_body_num > 0 && repeat_tail_num > 0) ? split_latency_coef : 0; - float rate2db = static_cast(params.all_points) / params.dst_vec_max_len / - (repeat_num + repeat_latency + offset_latency + split_latency); - - return rate2db; -} - -float SingleVecPatternGenerator::Compute2DPatternMaskRate() { - // in elemwise mode, the var is already checked to be equal, no need to check - if (params.dst_var.size() < 2 || params.src_var.size() < 2) { - return not_this_pattern; - } - - if (src_info->data_alignment_ == 1 && GetInt32Const(GetItem(src_info->strides_, -1)) != params.dst_block_size) { - return not_this_pattern; - } - - for (int i = -1; i >= -2; --i) { - if (!Equal(GetItem(params.dst_var, i), GetItem(params.src_var, i))) { - return not_this_pattern; - } - } - - if (GetInt32Const(GetItem(params.dst_strides, -2)) % params.dst_block_size != 0 || - GetInt32Const(GetItem(params.src_strides, -2)) % params.src_block_size != 0) { - return not_this_pattern; - } - - if (GetInt32Const(GetItem(params.dst_strides, -2)) / params.dst_block_size >= MAX_STRIDE_M1 || - GetInt32Const(GetItem(params.src_strides, -2)) / params.src_block_size >= MAX_STRIDE_M1) { - return not_this_pattern; - } - - // check num of insns, select 1D pattern or 2D pattern - int tail_factor = 0; - if (params.non_zero_shape1 / params.dst_vec_max_len > 0 && params.non_zero_shape1 % params.dst_vec_max_len > 0) { - tail_factor = 1; - } - - int offset_num = - (params.non_zero_shape1 + params.dst_vec_max_len - 1) / params.dst_vec_max_len * params.non_zero_shape3; - int repeat_num = offset_num * params.non_zero_shape2; - float repeat_latency = (std::max(params.non_zero_shape2 - 1, 0) / MAX_REPEAT) * offset_num * repeat_latency_coef; - float offset_latency = offset_num > 1 ? 
offset_num * offset_latency_coef : 0; - float split_latency = tail_factor * split_latency_coef; - float rate2d = static_cast(params.all_points) / params.dst_vec_max_len / - (repeat_num + repeat_latency + offset_latency + split_latency); - - return rate2d; -} - -float SingleVecPatternGenerator::Compute1DPatternMaskRate() { - int tail_factor = 0; - if (params.non_zero_shape1 / params.dst_vec_max_len > 0 && params.non_zero_shape1 % params.dst_vec_max_len > 0) { - tail_factor = 1; - } - - int shape1 = (params.non_zero_shape1 + params.dst_vec_max_len - 1) / params.dst_vec_max_len; - int repeat_num = shape1 * params.non_zero_shape2 * params.non_zero_shape3; - float repeat_latency = - std::max((shape1 - 1) / MAX_REPEAT, 0) * params.non_zero_shape2 * params.non_zero_shape3 * repeat_latency_coef; - float offset_latency = params.non_zero_shape2 * params.non_zero_shape3 > 1 - ? params.non_zero_shape2 * params.non_zero_shape3 * offset_latency_coef - : 0; - float split_latency = tail_factor * split_latency_coef; - float rate1d = static_cast(params.all_points) / params.dst_vec_max_len / - (repeat_num + repeat_latency + offset_latency + split_latency); - - return rate1d; -} - -Array SingleVecPatternGenerator::Get2DRepeatPattern() { - GetShapeInfoAndSwap(params.src_var, params.src_shape, params.src_strides, -2, -3); - int last_dim_shape = GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape, -1)); - body_args = VectorArgInfo(make_node()); - body_args.GetNode()->body_num_ = 1; - body_args.GetNode()->dst_stride_m0_ = 1; - body_args.GetNode()->src_stride_m0_list_ = {1}; - body_args.GetNode()->dst_stride_m1_ = truncdiv(GetItem(params.dst_strides, -2), params.dst_block_size); - body_args.GetNode()->src_stride_m1_list_ = {truncdiv(GetItem(params.src_strides, -2), params.src_block_size)}; - body_args.GetNode()->repeat_ = GetItem(params.dst_shape, -2); - int data_len = CeilTo(last_dim_shape, params.dst_block_size); - body_args.GetNode()->vec_mask_ = GetVecMask(data_len, 
1, dst_info->dtype_); - return GetRange(params.dst_var, -2, 2); -} - -Array SingleVecPatternGenerator::Get3DsPattern() { - GetShapeInfoAndSwap(params.src_var, params.src_shape, params.src_strides, -2, -3); - int last_dim_shape = GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape, -1)); - body_args = VectorArgInfo(make_node()); - body_args.GetNode()->body_num_ = 1; - - Expr dst_stride_m0 = truncdiv(GetItem(params.dst_strides, -2), params.dst_block_size); - Expr src_stride_m0 = truncdiv(GetItem(params.src_strides, -2), params.src_block_size); - body_args.GetNode()->dst_stride_m0_ = dst_stride_m0; - body_args.GetNode()->src_stride_m0_list_ = {src_stride_m0}; - - int block_num = 0; - int data_len = CeilTo(last_dim_shape, params.mask_block_size); - - if (GetIntConst(GetItem(params.dst_shape, -2)) <= FULL_BLOCK_NUM) { - block_num = GetIntConst(GetItem(params.dst_shape, -2)); - body_args.GetNode()->dst_stride_m1_ = truncdiv(GetItem(params.dst_strides, -3), params.dst_block_size); - body_args.GetNode()->src_stride_m1_list_ = {truncdiv(GetItem(params.src_strides, -3), params.src_block_size)}; - body_args.GetNode()->vec_mask_ = GetVecMask(data_len, block_num, dst_info->dtype_); - auto repeat = GetItem(params.dst_shape, -3); - if (GetIntConst(repeat) < MAX_STRIDE_M1) { - body_args.GetNode()->repeat_ = repeat; - return GetRange(params.dst_var, -3, 3); - } else { - body_args.GetNode()->repeat_ = 1; - return GetRange(params.dst_var, -2, 2); - } - } else { - block_num = FULL_BLOCK_NUM; - body_args.GetNode()->dst_stride_m1_ = dst_stride_m0 * block_num; - body_args.GetNode()->src_stride_m1_list_ = {src_stride_m0 * block_num}; - auto repeat = truncdiv(GetItem(params.dst_shape, -2), FULL_BLOCK_NUM); - body_args.GetNode()->vec_mask_ = GetVecMask(data_len, block_num, dst_info->dtype_); - if (GetIntConst(repeat) < MAX_STRIDE_M1) { - body_args.GetNode()->repeat_ = repeat; - return GetRange(params.dst_var, -2, 2); - } else { - return Get1DPattern(); - } - } -} - 
-Array SingleVecPatternGenerator::Get3DPattern() { - if (GetIntConst(GetNonZeroShape(GetItem(params.dst_shape, -2), GetItem(params.src_shape, -2))) > FULL_BLOCK_NUM) { - // split shape[-3] - if (GetIntConst(GetNonZeroShape(GetItem(params.dst_shape, -3), GetItem(params.src_shape, -3))) > FULL_BLOCK_NUM) { - StmtInfoList info_list = {dst_info, src_info}; - SplitAxis(info_list, for_info, GetItem(params.dst_var, -3), FULL_BLOCK_NUM); - FillEmptyVar(info_list); - dst_info = info_list[0]; - src_info = info_list[1]; - - params.dst_var = dst_info->var_; - params.dst_shape = dst_info->shape_; - params.dst_strides = dst_info->strides_; - params.src_var = src_info->var_; - params.src_shape = src_info->shape_; - params.src_strides = src_info->strides_; - } - // consider original shape[-2] as repeat axis - GetShapeInfoAndSwap(params.dst_var, params.dst_shape, params.dst_strides, -2, -3); - GetShapeInfoAndSwap(params.src_var, params.src_shape, params.src_strides, -2, -3); - } - - int last_dim_shape = GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape, -1)); - body_args = VectorArgInfo(make_node()); - body_args.GetNode()->body_num_ = 1; - body_args.GetNode()->repeat_ = - make_const(Int(32), GetNonZeroShape(GetItem(params.dst_shape, -3), GetItem(params.src_shape, -3))); - body_args.GetNode()->dst_stride_m0_ = truncdiv(GetItem(params.dst_strides, -2), params.dst_block_size); - body_args.GetNode()->dst_stride_m1_ = truncdiv(GetItem(params.dst_strides, -3), params.dst_block_size); - body_args.GetNode()->src_stride_m0_list_ = {truncdiv(GetItem(params.src_strides, -2), params.src_block_size)}; - body_args.GetNode()->src_stride_m1_list_ = {truncdiv(GetItem(params.src_strides, -3), params.src_block_size)}; - body_args.GetNode()->block_offset_ = make_const(Int(32), params.block_offset); - - int data_len = CeilTo(last_dim_shape, params.mask_block_size); - int data_num = GetInt32Const(GetItem(params.dst_shape, -2)); - body_args.GetNode()->vec_mask_ = 
GetVecMask(data_len, data_num, dst_info->dtype_, params.block_offset); - - return GetRange(params.dst_var, -3, 3); -} - -Array SingleVecPatternGenerator::Get2DBlockPattern() { - int last_dim_shape = GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape, -1)); - int repeat_len = GetNonZeroShape(GetItem(params.dst_shape, -2), GetItem(params.src_shape, -2)); - int repeat_body = repeat_len / FULL_BLOCK_NUM; - int repeat_tail = (repeat_len % FULL_BLOCK_NUM + FULL_BLOCK_NUM - 1) / FULL_BLOCK_NUM; - int data_len = CeilTo(last_dim_shape, params.dst_block_size); - - if (repeat_body > 0) { - body_args = VectorArgInfo(make_node()); - body_args.GetNode()->body_num_ = 1; - body_args.GetNode()->repeat_ = make_const(Int(32), repeat_body); - auto dst_stride_m0 = truncdiv(GetItem(params.dst_strides, -2), params.dst_block_size); - body_args.GetNode()->dst_stride_m0_ = dst_stride_m0; - body_args.GetNode()->dst_stride_m1_ = dst_stride_m0 * (params.max_bits / params.src_bits); - auto src_stride_m0 = truncdiv(GetItem(params.src_strides, -2), params.src_block_size); - body_args.GetNode()->src_stride_m0_list_ = {src_stride_m0}; - body_args.GetNode()->src_stride_m1_list_ = {src_stride_m0 * (params.max_bits / params.dst_bits)}; - body_args.GetNode()->block_offset_ = make_const(Int(32), params.block_offset); - - int data_num = FULL_BLOCK_NUM; - body_args.GetNode()->vec_mask_ = GetVecMask(data_len, data_num, dst_info->dtype_, params.block_offset); - } - - if (repeat_tail > 0) { - tail_args = VectorArgInfo(make_node()); - tail_args.GetNode()->dst_head_ = GetItem(params.dst_strides, -2) * repeat_body * FULL_BLOCK_NUM; - tail_args.GetNode()->src_head_list_ = {GetItem(params.src_strides, -2) * repeat_body * FULL_BLOCK_NUM}; - tail_args.GetNode()->repeat_ = Expr(1); - tail_args.GetNode()->dst_stride_m0_ = truncdiv(GetItem(params.dst_strides, -2), params.dst_block_size); - tail_args.GetNode()->dst_stride_m1_ = Expr(0); - tail_args.GetNode()->src_stride_m0_list_ = 
{truncdiv(GetItem(params.src_strides, -2), params.src_block_size)}; - tail_args.GetNode()->src_stride_m1_list_ = {Expr(0)}; - tail_args.GetNode()->block_offset_ = make_const(Int(32), params.block_offset); - - int data_num = repeat_len % FULL_BLOCK_NUM; - tail_args.GetNode()->vec_mask_ = GetVecMask(data_len, data_num, dst_info->dtype_, params.block_offset); - } - - return GetRange(params.dst_var, -2, 2); -} - -Array SingleVecPatternGenerator::Get2DPattern() { - const int data_num = 1; - int last_dim_shape = GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape, -1)); - if (GetInt32Const(GetItem(dst_info->strides_, -1)) == params.dst_block_size && - IsTwoItemEqual(dst_info->strides_, src_info->strides_, -1, true)) { - last_dim_shape *= params.dst_block_size; - } - int body_len = FloorTo(last_dim_shape, params.vec_max_len); - int tail_len = last_dim_shape % params.vec_max_len; - - if (body_len > 0) { - body_args = VectorArgInfo(make_node()); - body_args.GetNode()->body_num_ = body_len / params.vec_max_len; - body_args.GetNode()->body_offset_ = params.vec_max_len; - body_args.GetNode()->repeat_ = GetItem(params.dst_shape, -2); - body_args.GetNode()->dst_stride_m0_ = Expr(1); - body_args.GetNode()->dst_stride_m1_ = truncdiv(GetItem(params.dst_strides, -2), params.dst_block_size); - body_args.GetNode()->src_stride_m0_list_ = {Expr(1)}; - body_args.GetNode()->src_stride_m1_list_ = {truncdiv(GetItem(params.src_strides, -2), params.src_block_size)}; - body_args.GetNode()->block_offset_ = make_const(Int(32), params.block_offset); - body_args.GetNode()->vec_mask_ = GetVecMask(params.vec_max_len, data_num, data_type, params.block_offset); - } - - // get tail params - if (tail_len > 0) { - tail_args = VectorArgInfo(make_node()); - tail_args.GetNode()->dst_head_ = Expr(body_len); - tail_args.GetNode()->src_head_list_ = {Expr(body_len)}; - tail_args.GetNode()->repeat_ = GetItem(params.dst_shape, -2); - tail_args.GetNode()->dst_stride_m0_ = Expr(1); - 
tail_args.GetNode()->dst_stride_m1_ = truncdiv(GetItem(params.dst_strides, -2), params.dst_block_size); - tail_args.GetNode()->src_stride_m0_list_ = {Expr(1)}; - tail_args.GetNode()->src_stride_m1_list_ = {truncdiv(GetItem(params.src_strides, -2), params.src_block_size)}; - tail_args.GetNode()->block_offset_ = make_const(Int(32), params.block_offset); - - int data_len = CeilTo(tail_len, params.mask_block_size); - tail_args.GetNode()->vec_mask_ = GetVecMask(data_len, data_num, data_type, params.block_offset); - } - - return GetRange(params.dst_var, -2, 2); -} - -Array SingleVecPatternGenerator::Get1DPattern() { - int last_dim_shape; - bool linear_mode = false; - if ((params.dst_shape.empty() && params.src_shape.empty()) || GetIntConst(GetItem(params.dst_shape, -1)) == 0) { - last_dim_shape = 1; - } else if (!IsTwoItemEqual(params.dst_var, params.src_var, -1)) { - last_dim_shape = 1; - } else { - last_dim_shape = GetLastDimShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape, -1)); - linear_mode = params.dst_bits == params.src_bits; - } - - bool is_scalar_mode = IsScalarMode({dst_info, src_info}); - if (is_scalar_mode && params.dst_bits != params.src_bits) { - last_dim_shape = 1; - } - int vec_max_len = is_scalar_mode ? FULL_BLOCK_NUM : params.vec_max_len; - int body_len = FloorTo(last_dim_shape, vec_max_len); - int tail_len = last_dim_shape % vec_max_len; - - auto dst_stride_m0 = - is_scalar_mode && linear_mode ? truncdiv(GetItem(params.dst_strides, -1), params.dst_block_size) : Expr(1); - auto src_stride_m0 = - is_scalar_mode && linear_mode ? truncdiv(GetItem(params.src_strides, -1), params.src_block_size) : Expr(1); - if (body_len > 0) { - body_args = VectorArgInfo(make_node()); - body_args.GetNode()->body_num_ = 1; - body_args.GetNode()->body_offset_ = vec_max_len; - body_args.GetNode()->repeat_ = Expr(body_len / vec_max_len); - body_args.GetNode()->dst_stride_m0_ = dst_stride_m0; - auto dst_block_num = is_scalar_mode ? 
FULL_BLOCK_NUM : (params.max_bits / params.src_bits); - body_args.GetNode()->dst_stride_m1_ = dst_stride_m0 * dst_block_num; - body_args.GetNode()->src_stride_m0_list_ = {src_stride_m0}; - auto src_block_num = is_scalar_mode ? FULL_BLOCK_NUM : (params.max_bits / params.dst_bits); - body_args.GetNode()->src_stride_m1_list_ = {src_stride_m0 * src_block_num}; - body_args.GetNode()->block_offset_ = make_const(Int(32), params.block_offset); - - // in cast case, data_num should be 1 because dst and src bit is not equal - int data_len = is_scalar_mode ? 1 : vec_max_len; - int data_num = is_scalar_mode ? FULL_BLOCK_NUM : 1; - body_args.GetNode()->vec_mask_ = GetVecMask(data_len, data_num, data_type, params.block_offset); - } - - // get tail params - if (tail_len > 0) { - tail_args = VectorArgInfo(make_node()); - tail_args.GetNode()->body_offset_ = vec_max_len; - tail_args.GetNode()->body_num_ = 1; - tail_args.GetNode()->dst_head_ = - Expr(body_len * (is_scalar_mode ? dst_stride_m0 * params.dst_block_size : Expr(1))); - tail_args.GetNode()->src_head_list_ = { - Expr(body_len * (is_scalar_mode ? src_stride_m0 * params.src_block_size : Expr(1)))}; - tail_args.GetNode()->repeat_ = Expr(1); - tail_args.GetNode()->dst_stride_m0_ = dst_stride_m0; - tail_args.GetNode()->dst_stride_m1_ = Expr(0); - tail_args.GetNode()->src_stride_m0_list_ = {src_stride_m0}; - tail_args.GetNode()->src_stride_m1_list_ = {Expr(0)}; - tail_args.GetNode()->block_offset_ = make_const(Int(32), params.block_offset); - - int data_len = is_scalar_mode && linear_mode ? 1 : CeilTo(tail_len, params.mask_block_size); - int data_num = is_scalar_mode && linear_mode ? tail_len : 1; - data_num = data_num == 0 ? 
1 : data_num; - tail_args.GetNode()->vec_mask_ = GetVecMask(data_len, data_num, data_type, params.block_offset); - } - - // compute offset for cce instructions - Array elim_var = {}; - if (mode == "elewise" && params.dst_var.size() >= 2 && params.dst_strides.size() >= 2 && for_info.ops_.size() >= 2 && - last_dim_shape <= vec_max_len && last_dim_shape >= vec_max_len - params.dst_block_size && - GetIntConst(GetItem(params.dst_strides, -2)) == last_dim_shape) { - // in this case we can merge second last for extent to repeat - size_t index = 0; - bool suc = GetIndexOfElement(for_info.vars_, GetItem(params.dst_var, -2), index); - CHECK(suc); - auto latest_for = GetItem(for_info.ops_, index).as(); - // there should not be if_op between for loop and compute stmt - if (latest_for && !latest_for->body->IsInstance()) { - if (!params.dst_var.empty() && (!is_scalar_mode || last_dim_shape != 1)) { - if (body_args.defined()) { - // last_dim_shape = vec_max_len - body_args.GetNode()->repeat_ = body_args->repeat_ * latest_for->extent; - } else if (tail_args.defined()) { - // last_dim_shape < vec_max_len - tail_args.GetNode()->repeat_ = tail_args->repeat_ * latest_for->extent; - } - - return GetRange(params.dst_var, -2, 2); - } - } - } - - if (!params.dst_var.empty() && (!is_scalar_mode || last_dim_shape != 1 || linear_mode) && - GetIntConst(GetItem(params.dst_strides, -1)) > 0 && - (params.src_var.empty() || IsTwoItemEqual(params.dst_var, params.src_var, -1))) { - elim_var = GetRange(params.dst_var, -1, 1); - } - - return elim_var; -} - -PatternResult SingleVecPatternGenerator::GenResult(const Array &elim_var) { - arg_info.GetNode()->body_arg_info_ = body_args; - arg_info.GetNode()->tail_arg_info_ = tail_args; - - dst_info.GetNode()->insn_offset_ = GetInsnOffset(dst_info, elim_var); - src_info.GetNode()->insn_offset_ = GetInsnOffset(src_info, elim_var); - - CleanForInfoVars(for_info, elim_var); - - StmtInfoList info_list = {dst_info, src_info}; - CleanZeroStrides(info_list); - 
dst_info = info_list[0]; - src_info = info_list[1]; - - PatternResult result; - result.dst_info_list = {dst_info}; - result.src_info_list = {src_info}; - result.for_info = for_info; - result.arg_info = arg_info; - - return result; -} -} // namespace akg diff --git a/src/pass/emit_insn.cc b/src/pass/emit_insn.cc index b2a7620e39e3a8157a5348eb6f82d8aa67ef1cda..38b90a85e12739fd2f459deca7c43225b63119b5 100644 --- a/src/pass/emit_insn.cc +++ b/src/pass/emit_insn.cc @@ -21,6 +21,7 @@ #include "pass/ir_util.h" #include "poly/poly_util.h" #include "emit_insn/insn_emitter.h" +#include "emit_insn/ir_transform.h" namespace akg { namespace ir { @@ -475,6 +476,7 @@ Stmt EmitInsn(Stmt stmt, bool enable_bisect, bool enable_cover_protect, const Ma } stmt = UnalignedMad().Mutate(stmt); stmt = RegCondition().Mutate(stmt); + stmt = ForVarUnique().Mutate(stmt); return stmt; } } // namespace ir diff --git a/src/pass/multi_last_axis_reduction.cc b/src/pass/multi_last_axis_reduction.cc index 1f2137df151f084ff1caee98073108439dcc96c4..4e882b20b43f9412ee0a226777dca097e664c955 100644 --- a/src/pass/multi_last_axis_reduction.cc +++ b/src/pass/multi_last_axis_reduction.cc @@ -343,8 +343,12 @@ class BroadcastCalculate : public IRMutator { }; Stmt MultiLastAxisReductions(Stmt stmt, bool is_dynamic = false) { + auto ori_stmt = stmt; stmt = MultiLastAxisReduction().Mutate(stmt); stmt = BroadcastCalculate(is_dynamic).Mutate(stmt); + if (!is_dynamic && !Equal(ori_stmt, stmt)) { + stmt = MergeLoops(stmt); + } return stmt; } } // namespace ir diff --git a/src/pass/split_tail_block.cc b/src/pass/split_tail_block.cc index 52a0beb7856642c4a2b8bb899c2168175c2d41bc..e486308f8fb4fc3c5270a8368d3bf6c92626c5a0 100644 --- a/src/pass/split_tail_block.cc +++ b/src/pass/split_tail_block.cc @@ -21,7 +21,7 @@ #include #include "emit_insn/insn_info.h" #include "emit_insn/insn_pattern.h" - +#include "emit_insn/insn_args_calculator.h" namespace akg { namespace ir { @@ -48,85 +48,63 @@ class TailSpliter : public 
IRMutator { if (src_info_list.empty()) { src_info_list = {dst_info.Copy()}; } - auto get_info_list = [](const StmtStoreInfo &dst_info, const Array &src_info_list) { - Array res; - res.push_back(dst_info.Copy()); - for (auto it : src_info_list) { - res.push_back(it.Copy()); - } - return res; - }; - auto info_list = get_info_list(dst_info, src_info_list); - FillEmptyVar(info_list); - auto axis_list = GetAixsList(for_info, info_list); - auto get_last_axis_it = [](const std::list &axis_list) { - for (auto it = axis_list.begin(); it != axis_list.end(); it++) { - auto stride_list = it->stride_list; - if (!(std::any_of(stride_list.begin(), stride_list.end(), [](int stride) { return stride > 1; }) || - std::all_of(stride_list.begin(), stride_list.end(), [](int stride) { return stride == 0; }))) { - return it; - } - } - return axis_list.end(); - }; - auto last_axis_it = get_last_axis_it(axis_list); - if (last_axis_it == axis_list.end()) { - return s; - } - auto last_axis = *last_axis_it; - auto last_axis_shape = last_axis.extent; + auto info_list = GetInfoList(dst_info, src_info_list); + FillEmptyVar(info_list); int dst_block_size = GetUbBlkSize(dst_info->dtype_); int src_block_size = GetUbBlkSize(src_info_list[0]->dtype_); - int block_size = dst_block_size > src_block_size ? dst_block_size : src_block_size; + int block_size = dst_block_size < src_block_size ? dst_block_size : src_block_size; + int cast_block_size = dst_block_size > src_block_size ? 
dst_block_size : src_block_size; int vec_max_len = block_size * FULL_BLOCK_NUM; - - if (last_axis_shape > vec_max_len && last_axis_shape % vec_max_len != 0) { - return Block::make(TailMake(s, last_axis, vec_max_len, false), TailMake(s, last_axis, vec_max_len, true)); - } - if (last_axis_shape < vec_max_len * tail_rate_ && last_axis_shape > block_size && - last_axis_shape % block_size != 0 && axis_list.size() > 1) { - return Block::make(TailMake(s, last_axis, block_size, false), TailMake(s, last_axis, block_size, true)); - } - } - return IRMutator::Mutate_(op, s); - } - - std::list GetAixsList(const StmtInfo &for_info, const Array &info_list) { - std::list axis_list; - auto GetStrideByAxis = [](const Array &vars, const Array &strides, Var obj_var) { - int index = 0; - for (auto var_it : vars) { - if (Equal(var_it, obj_var)) { - return strides[index]; + auto args_calculator = InsnArgsCalculator(dst_info_list, src_info_list, for_info, ""); + auto vec_axis_it = args_calculator.GetVecAxisIt(); + bool cast = dst_block_size != src_block_size; + if (args_calculator.IsValid(vec_axis_it)) { + auto vec_axis = *vec_axis_it; + auto vec_axis_shape = vec_axis.extent; + if (vec_axis_shape >= vec_max_len) { + if (vec_axis_shape % vec_max_len != 0) { + return TailBlock(s, vec_axis, vec_max_len); + } + } else { + if (vec_axis_shape < vec_max_len * tail_rate_ && vec_axis_shape > cast_block_size && + vec_axis_shape % cast_block_size != 0 && args_calculator.axis_list_.size() > 1) { + return TailBlock(s, vec_axis, cast_block_size); + } } - index++; } - return Expr(0); - }; - for (auto it : for_info.ops_) { - InsnAxis axis; - auto for_stmt = it.as(); - CHECK(for_stmt); - axis.var = for_stmt->loop_var; - axis.extent = GetInt32Const(for_stmt->extent); - axis.min = GetInt32Const(for_stmt->min); - int index = 0; - for (auto it : info_list) { - auto stride = GetInt32Const(GetStrideByAxis(it->var_, it->strides_, axis.var)); - axis.stride_list.push_back(stride); - if (index == 0) { - 
axis.dst_stride = stride; - } else { - axis.src_stride_list.push_back(stride); + if (!cast && (!args_calculator.IsValid(vec_axis_it) || vec_axis_it->extent <= cast_block_size * tail_rate_)) { + auto get_block_axis = [&](std::list &axis_list) { + InsnAxis block_axis; + block_axis.is_valid = false; + std::vector temp_axis_set; + auto block_stride_lambda = [&](int stride) { return stride % block_size == 0 && stride / block_size <= 4; }; + for (auto axis : axis_list) { + if (std::all_of(axis.stride_list.begin(), axis.stride_list.end(), block_stride_lambda) && + axis.dst_stride != 0 && axis.extent != 0 && axis.extent > FULL_BLOCK_NUM && + axis.extent % FULL_BLOCK_NUM != 0) { + temp_axis_set.push_back(axis); + } + } + if (!temp_axis_set.empty()) { + return temp_axis_set[0]; + } else { + return block_axis; + } + }; + auto block_axis = get_block_axis(args_calculator.axis_list_); + if (block_axis.IsValid()) { + return TailBlock(s, block_axis, FULL_BLOCK_NUM); } - index++; } - axis_list.push_back(axis); + return s; } - return axis_list; + return IRMutator::Mutate_(op, s); } + Stmt TailBlock(const Stmt &s, const InsnAxis &tail_axis, int body_size) { + return Block::make(TailMake(s, tail_axis, body_size, false), TailMake(s, tail_axis, body_size, true)); + } Stmt TailMake(const Stmt &s, const InsnAxis &tail_axis, int body_size, bool is_tail) { if (auto attr_stmt = s.as()) { return AttrStmt::make(attr_stmt->node, attr_stmt->attr_key, attr_stmt->value, @@ -145,8 +123,7 @@ class TailSpliter : public IRMutator { } return For::make(for_stmt->loop_var, for_stmt->min, for_stmt->extent, for_stmt->for_type, for_stmt->device_api, TailMake(for_stmt->body, tail_axis, body_size, is_tail)); - - } + } if (s.as() && is_tail) { return substitute(tail_axis.var, Add::make(Expr(tail_axis.extent / body_size * body_size), tail_axis.var), s); } @@ -156,6 +133,20 @@ class TailSpliter : public IRMutator { private: const float tail_rate_{0.6}; const std::set include_intrin_list_ = { + // binary vec + 
"vec_binary_add", + "vec_binary_sub", + "vec_binary_mul", + "vec_binary_min", + "vec_binary_max", + "vec_binary_div", + "vec_binary_and", + "vec_binary_or", + "vec_binary_vmadd", + "vec_binary_vmaddrelu", + "vec_binary_vmla", + + // single vec "vec_single_fabs", "vec_single_log", "vec_single_exp", @@ -165,20 +156,28 @@ class TailSpliter : public IRMutator { "vec_single_rsqrt", "vec_single_relu", "vec_single_not", - // vector_scalar - "vec_single_muls", - "vec_single_adds", // Mov "broadcast", + "mask_broadcast", // vector_cast "vec_single_cast", "vec_single_floor", "vec_single_round", "vec_single_ceil", "vec_single_trunc", + // scalar case + "vector_dup", + "vmuls", + "vadds", + "vaxpy", }; }; -Stmt SplitTail(Stmt stmt) { return TailSpliter().Mutate(stmt); } +Stmt SplitTail(Stmt stmt) { + auto tail_spliter = TailSpliter(); + auto first_round = tail_spliter.Mutate(stmt); + auto second_round = tail_spliter.Mutate(stmt); + return second_round; +} } // namespace ir } // namespace akg \ No newline at end of file