提交 aec205ff 编写于 作者: C cy 提交者: wYann

rewrite insn pattern generator in EmitInsn

上级 11ed37cc
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <tvm/base.h>
#include <tvm/ir_pass.h>
#include "common/array_api.h"
#include "pass/expr_alg_simplify.h"
#include "insn_pattern.h"
#include "insn_args_calculator.h"
namespace akg {
InsnAxis::InsnAxis(const For *for_stmt, const Array<StmtStoreInfo> &info_list) {
this->var = for_stmt->loop_var;
this->extent = GetInt32Const(for_stmt->extent);
this->min = GetInt32Const(for_stmt->min);
int index = 0;
for (auto it : info_list) {
auto stride = GetInt32Const(GetStrideByAxis(it->var_, it->strides_, this->var));
this->stride_list.push_back(stride);
if (index == 0) {
this->dst_stride = stride;
} else {
this->src_stride_list.push_back(stride);
}
index++;
}
}
Expr InsnAxis::GetStrideByAxis(const Array<Var> &vars, const Array<Expr> &strides, Var obj_var) {
int index = 0;
for (auto var_it : vars) {
if (Equal(var_it, obj_var)) {
return strides[index];
}
index++;
}
return Expr(0);
};
bool InsnAxis::IsValid() { return this->is_valid; }
void InsnAxis::Print(const std::string &name) {
if (!name.empty()) {
LOG(DEBUG) << "********** " << name << " ************";
}
auto r_stride = this->src_stride_list.size() > 1 ? src_stride_list[1] : 99999;
LOG(DEBUG) << "var:" << this->var << " extent:" << this->extent << " min:" << this->min
<< " dst_stride:" << this->dst_stride << " src_stride_l:" << this->src_stride_list.front()
<< "src_stride_r:" << r_stride;
}
Array<StmtStoreInfo> GetInfoList(const StmtStoreInfo &dst_info, const Array<StmtStoreInfo> &src_info_list) {
Array<StmtStoreInfo> res;
res.push_back(dst_info.Copy());
for (auto it : src_info_list) {
res.push_back(it.Copy());
}
return res;
};
std::list<InsnAxis> GetAxisList(const StmtInfo &for_info, const Array<StmtStoreInfo> &info_list) {
std::list<InsnAxis> axis_list;
for (auto it : for_info.ops_) {
auto for_stmt = it.as<For>();
CHECK(for_stmt);
auto axis = InsnAxis(for_stmt, info_list);
axis_list.push_back(axis);
}
return axis_list;
}
void Print(std::list<InsnAxis> &axis_list) {
LOG(DEBUG) << "+++++++++++++++++++ AXIS_LIST +++++++++++++++++++";
int index = 0;
for (auto it : axis_list) {
LOG(DEBUG) << "================== INDEX " << index << " =================";
it.Print();
index++;
}
LOG(DEBUG) << "------------------ END ---------------------";
}
InsnArgsCalculator::InsnArgsCalculator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list,
const StmtInfo &for_info, const std::string &intrin_name)
: dst_info_list_(dst_info_list), src_info_list_(src_info_list), for_info_(for_info), intrin_name_(intrin_name) {
InitArg();
CalAxis();
}
void InsnArgsCalculator::CalAxis() {
CHECK(!dst_info_list_.empty());
dst_info_ = dst_info_list_[0];
if (src_info_list_.empty()) {
src_info_list_ = {dst_info_.Copy()};
}
auto src_info = src_info_list_[0];
dst_info_.Print();
for (auto src_info_it : src_info_list_) {
src_info_it.Print();
if (src_info_it->name_ == dst_info_->name_) {
meta_.same_dst_src = true;
}
}
meta_.dst_block_size = GetUbBlkSize(dst_info_->dtype_);
meta_.src_block_size = GetUbBlkSize(src_info->dtype_);
meta_.cast = meta_.dst_block_size != meta_.src_block_size;
meta_.block_size = meta_.dst_block_size <= meta_.src_block_size ? meta_.dst_block_size : meta_.src_block_size;
meta_.src_dtype = src_info->dtype_;
meta_.dst_dtype = dst_info_->dtype_;
meta_.dtype = meta_.dst_dtype.bits() >= meta_.src_dtype.bits() ? meta_.dst_dtype : meta_.src_dtype;
auto elem_offset_mod = ir::ExprSimplifier().Simplify(Mod::make(dst_info_->elem_offset_, meta_.block_size));
if (elem_offset_mod.as<IntImm>()) {
meta_.block_offset = elem_offset_mod.as<IntImm>()->value;
}
axis_list_ = GetAxisList(for_info_, GetInfoList(dst_info_, src_info_list_));
} // namespace akg
void InsnArgsCalculator::InitArg() {
arg_.src_m1_list = {0, 0};
arg_.src_m0_list = {1, 1};
}
std::function<bool(const InsnAxis &)> InsnArgsCalculator::GetStrideLambda() {
return [&](const InsnAxis &axis) {
auto is_stride = [&](int stride) { return stride % meta_.block_size == 0; };
auto zero_stride = [&](int stride) { return stride == 0; };
return std::all_of(axis.stride_list.begin(), axis.stride_list.end(), is_stride) &&
!std::all_of(axis.stride_list.begin(), axis.stride_list.end(), zero_stride);
};
}
std::function<bool(const InsnAxis &)> InsnArgsCalculator::GetM0LimitLambda() {
return [&](const InsnAxis &axis) {
auto is_limit = [&](int stride) { return stride / meta_.block_size < MAX_STRIDE_M0_SINGLE; };
return std::all_of(axis.stride_list.begin(), axis.stride_list.end(), is_limit);
};
}
std::function<bool(const InsnAxis &)> InsnArgsCalculator::GetBlockStrideLimitLambda() {
return [&](const InsnAxis &axis) {
auto is_limit = [&](int stride) { return stride / meta_.block_size <= max_block_stride_; };
return std::all_of(axis.stride_list.begin(), axis.stride_list.end(), is_limit);
};
}
std::function<bool(const InsnAxis &)> InsnArgsCalculator::GetM1LimitLambda() {
return [&](const InsnAxis &axis) {
auto is_limit = [&](int stride) { return stride / meta_.src_block_size < MAX_STRIDE_M1; };
return axis.dst_stride / meta_.dst_block_size < MAX_STRIDE_M1 &&
std::all_of(axis.src_stride_list.begin(), axis.src_stride_list.end(), is_limit);
};
}
std::function<bool(const InsnAxis &)> And(const std::list<std::function<bool(const InsnAxis &)>> &lambda_list) {
return [&lambda_list](const InsnAxis &axis) {
bool res = true;
for (auto lambda_it : lambda_list) {
res = res && lambda_it(axis);
}
return res;
};
}
AxisIt InsnArgsCalculator::GetAxisByLambda(const std::function<bool(const InsnAxis &)> &lambda) {
for (auto axis_it = axis_list_.begin(); axis_it != axis_list_.end(); axis_it++) {
if (lambda(*axis_it)) {
return axis_it;
}
}
return axis_list_.end();
}
InsnAxis InsnArgsCalculator::ExtractAxis(AxisIt &it) {
InsnAxis res = *it;
axis_list_.erase(it);
return res;
}
bool InsnArgsCalculator::IsValid(AxisIt &it) { return it != axis_list_.end(); }
void AxisSort(std::list<InsnAxis> &axis_arr, bool order = true) {
auto up_compare = [&](InsnAxis &a, InsnAxis &b) { return a.extent < b.extent; };
auto down_compare = [&](InsnAxis &a, InsnAxis &b) { return a.extent > b.extent; };
if (order) {
axis_arr.sort(up_compare);
} else {
axis_arr.sort(down_compare);
}
}
AxisIt InsnArgsCalculator::GetVecAxisIt() {
axis_list_.reverse();
auto IsVecAxis = [&](const InsnAxis &axis) {
return !(std::any_of(axis.stride_list.begin(), axis.stride_list.end(), [](int stride) { return stride > 1; }) ||
std::all_of(axis.stride_list.begin(), axis.stride_list.end(), [](int stride) { return stride == 0; }));
};
return GetAxisByLambda(IsVecAxis);
}
SplitStat InsnArgsCalculator::SplitAxis(int extent, InsnAxis &axis) {
if (axis.extent <= extent) {
return NO_SPLIT;
}
if (axis.extent % extent != 0) {
return TAIL;
}
InsnAxis new_axis;
new_axis.extent = axis.extent / extent;
for (auto stride : axis.stride_list) {
new_axis.stride_list.push_back(stride * extent);
}
auto temp_stride_list = new_axis.stride_list;
CHECK(!temp_stride_list.empty());
new_axis.dst_stride = temp_stride_list.front();
temp_stride_list.erase(temp_stride_list.begin());
new_axis.src_stride_list = temp_stride_list;
new_axis.var = Var(axis.var->name_hint);
axis_list_.push_back(new_axis);
axis.extent = extent;
return SUCCESS;
}
AxisIt InsnArgsCalculator::GetBlockAxis() {
AxisSort(axis_list_);
auto stride_lambda = GetStrideLambda();
auto m0_limit_lambda = GetM0LimitLambda();
auto block_stride_limit_lambda = GetBlockStrideLimitLambda();
auto axis_it =
GetAxisByLambda(And({stride_lambda, m0_limit_lambda, block_stride_limit_lambda, [&](const InsnAxis &axis) {
return axis.extent >= FULL_BLOCK_NUM && axis.extent % FULL_BLOCK_NUM == 0;
}}));
if (IsValid(axis_it)) {
return axis_it;
}
axis_it = GetAxisByLambda(And({stride_lambda, m0_limit_lambda, block_stride_limit_lambda,
[&](const InsnAxis &axis) { return axis.extent >= FULL_BLOCK_NUM; }}));
if (IsValid(axis_it) && axis_list_.size() == 1) {
return axis_it;
}
axis_list_.reverse();
axis_it = GetAxisByLambda(And({stride_lambda, m0_limit_lambda, block_stride_limit_lambda,
[&](const InsnAxis &axis) { return axis.extent < FULL_BLOCK_NUM; }}));
if (IsValid(axis_it)) {
return axis_it;
}
return GetAxisByLambda(And({stride_lambda, m0_limit_lambda, [&](const InsnAxis &axis) {
return axis.extent <= FULL_BLOCK_NUM || axis.extent % FULL_BLOCK_NUM == 0;
}}));
}
AxisIt InsnArgsCalculator::GetRepeatAxisIt() {
AxisSort(axis_list_);
auto stride_lambda = GetStrideLambda();
auto m1_limit_lambda = GetM1LimitLambda();
auto axis_it = GetAxisByLambda(
And({stride_lambda, m1_limit_lambda, [&](const InsnAxis &axis) { return axis.extent >= MAX_REPEAT - 1; }}));
if (IsValid(axis_it)) {
return axis_it;
}
axis_list_.reverse();
return GetAxisByLambda(And({stride_lambda, m1_limit_lambda}));
}
void InsnArgsCalculator::SetArgMask(int len) {
SetArgBlockNum(1);
SetArgBlockLen(len);
}
void InsnArgsCalculator::SetArgBlockNum(int data_num) { arg_.block_num = data_num; }
void InsnArgsCalculator::SetArgBlockLen(int data_len) { arg_.block_len = data_len; }
void InsnArgsCalculator::SetArgM0(int dst_m0, int lsrc_m0, int rsrc_m0 = 0) {
arg_.dst_m0 = dst_m0;
arg_.src_m0_list = {lsrc_m0, rsrc_m0};
}
void InsnArgsCalculator::SetArgM1(int dst_m1, int lsrc_m1, int rsrc_m1 = 0) {
arg_.dst_m1 = dst_m1;
arg_.src_m1_list = {lsrc_m1, rsrc_m1};
}
void InsnArgsCalculator::SetArgRepeat(int repeat) { arg_.repeat = repeat; }
void InsnArgsCalculator::BlockAxisReduction() {
Print(axis_list_);
auto block_axis_it = GetBlockAxis();
if (IsValid(block_axis_it)) {
auto origin_block_axis = *block_axis_it;
InsnAxis block_axis = ExtractAxis(block_axis_it);
if (block_axis.extent % FULL_BLOCK_NUM != 0 && block_axis.extent > FULL_BLOCK_NUM) {
arg_.tail_len = block_axis.extent % FULL_BLOCK_NUM;
block_axis.extent = FloorTo(block_axis.extent, FULL_BLOCK_NUM);
arg_.dst_tail_offset = block_axis.dst_stride * block_axis.extent;
for (auto stride : block_axis.src_stride_list) {
arg_.src_tail_offset_list.push_back(stride * block_axis.extent);
}
SplitAxis(FULL_BLOCK_NUM, block_axis);
auto repeat_axis_it = GetRepeatAxisIt();
if (!IsValid(repeat_axis_it) && axis_list_.size() > 0) {
for (auto it = axis_list_.begin(); it != axis_list_.end(); it++) {
if (it->var->name_hint == block_axis.var->name_hint) {
axis_list_.erase(it);
break;
}
}
axis_list_.push_back(origin_block_axis);
return;
}
} else {
SplitAxis(FULL_BLOCK_NUM, block_axis);
}
block_axis.Print("BLOCK_AXIS");
SetArgM0(block_axis.dst_stride / meta_.block_size, block_axis.src_stride_list.front() / meta_.block_size,
block_axis.src_stride_list.back() / meta_.block_size);
SetArgBlockNum(block_axis.extent);
}
}
void InsnArgsCalculator::RepeatAxisReduction() {
Print(axis_list_);
auto repeat_axis = GetRepeatAxis();
if (repeat_axis.IsValid()) {
repeat_axis.Print("REPEAT_AXIS");
SetArgM1(repeat_axis.dst_stride / meta_.dst_block_size, repeat_axis.src_stride_list.front() / meta_.src_block_size,
repeat_axis.src_stride_list.back() / meta_.src_block_size);
SetArgRepeat(repeat_axis.extent);
}
}
InsnAxis InsnArgsCalculator::GetInvalidAxis() {
InsnAxis res;
res.is_valid = false;
return res;
}
InsnAxis InsnArgsCalculator::GetRepeatAxis() {
auto repeat_axis_it = GetRepeatAxisIt();
if (IsValid(repeat_axis_it)) {
InsnAxis repeat_axis = ExtractAxis(repeat_axis_it);
SplitAxis(MAX_REPEAT - 1, repeat_axis);
return repeat_axis;
}
return GetInvalidAxis();
}
void InsnArgsCalculator::CastCaseReduction() {
if (axis_list_.empty()) {
return;
}
Print(axis_list_);
int cast_block_size = meta_.dst_block_size < meta_.src_block_size ? meta_.dst_block_size : meta_.src_block_size;
auto vec_axis_it = GetVecAxisIt();
if (IsValid(vec_axis_it)) {
InsnAxis vec_axis = ExtractAxis(vec_axis_it);
int max_vec_len = cast_block_size * FULL_BLOCK_NUM;
if (vec_axis.extent > cast_block_size && vec_axis.extent < max_vec_len) {
SetArgMask(DivFloor(vec_axis.extent, cast_block_size) * cast_block_size);
SetArgM0(1, 1, 1);
} else if (vec_axis.extent >= max_vec_len) {
SplitAxis(max_vec_len, vec_axis);
SetArgMask(DivFloor(vec_axis.extent, cast_block_size) * cast_block_size);
SetArgM0(1, 1, 1);
} else {
SetArgBlockLen(cast_block_size);
}
}
RepeatAxisReduction();
}
int DivFloor(int a, int b) {
if (a % b == 0) {
return a / b;
} else {
return a / b + 1;
}
}
void InsnArgsCalculator::InsnReduction() {
if (axis_list_.empty()) {
return;
}
Print(axis_list_);
auto vec_axis_it = GetVecAxisIt();
meta_.scalar = !IsValid(vec_axis_it);
if (!meta_.scalar) {
InsnAxis vec_axis = ExtractAxis(vec_axis_it);
int max_vec_len = meta_.block_size * FULL_BLOCK_NUM;
if (vec_axis.extent > meta_.block_size && vec_axis.extent < max_vec_len &&
(vec_axis.extent % meta_.block_size != 0 || vec_axis.extent > max_vec_len * meta_.vec_rate)) {
vec_axis.Print("VEC_BLOCK_AXIS");
SetArgMask(DivFloor(vec_axis.extent, meta_.block_size) * meta_.block_size);
SetArgM0(1, 1, 1);
} else {
SplitAxis(meta_.block_size, vec_axis);
vec_axis.Print("VEC_AXIS");
SetArgBlockLen(meta_.block_size);
BlockAxisReduction();
}
RepeatAxisReduction();
} else {
BlockAxisReduction();
RepeatAxisReduction();
}
Print(axis_list_);
}
Expr InsnArgsCalculator::GetOffset(int stride_index) {
Expr res = Expr(0);
for (auto axis_it : axis_list_) {
auto stride = axis_it.stride_list[stride_index];
auto mul_expr = Mul::make(stride, axis_it.var);
res = Add::make(mul_expr, res);
}
return Simplify(res);
}
StmtInfo InsnArgsCalculator::ExportForInfo() {
if (for_info_.ops_.empty()) {
return for_info_;
}
int last_index = for_info_.ops_.size() - 1;
auto last_for = for_info_.ops_[last_index].as<For>();
auto store_stmt = last_for->body;
Stmt for_stmt = store_stmt;
StmtInfo result;
for (auto axis_it : axis_list_) {
for_stmt = For::make(axis_it.var, axis_it.min, axis_it.extent, last_for->for_type, last_for->device_api, for_stmt);
result.ops_.push_back(for_stmt);
result.vars_.push_back(axis_it.var);
}
return result;
}
PatternResult InsnArgsCalculator::ExportResult() {
PatternResult res;
auto arg_info = ArgInfo(make_node<ArgInfoNode>());
auto body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
body_args.GetNode()->body_num_ = arg_.body_num;
body_args.GetNode()->body_offset_ = meta_.block_size * FULL_BLOCK_NUM;
body_args.GetNode()->repeat_ = Expr(arg_.repeat);
body_args.GetNode()->dst_stride_m0_ = Expr(arg_.dst_m0);
body_args.GetNode()->dst_stride_m1_ = Expr(arg_.dst_m1);
body_args.GetNode()->src_stride_m0_list_ = arg_.src_m0_list;
body_args.GetNode()->src_stride_m1_list_ = arg_.src_m1_list;
body_args.GetNode()->vec_mask_ = GetVecMask(arg_.block_len, arg_.block_num, meta_.dtype, meta_.block_offset);
body_args.GetNode()->block_offset_ = make_const(Int(32), meta_.block_offset);
arg_info.GetNode()->body_arg_info_ = body_args;
if (arg_.tail_len > 0) {
auto tail_args = VectorArgInfo(make_node<VectorArgInfoNode>());
tail_args.GetNode()->dst_head_ = Expr(arg_.dst_tail_offset);
tail_args.GetNode()->dst_stride_m1_ = Expr(arg_.dst_m1);
tail_args.GetNode()->src_stride_m1_list_ = arg_.src_m1_list;
tail_args.GetNode()->repeat_ = Expr(1);
tail_args.GetNode()->src_head_list_ = arg_.src_tail_offset_list;
tail_args.GetNode()->body_offset_ = meta_.block_size * FULL_BLOCK_NUM;
tail_args.GetNode()->dst_stride_m0_ = Expr(arg_.dst_m0);
tail_args.GetNode()->src_stride_m0_list_ = arg_.src_m0_list;
tail_args.GetNode()->vec_mask_ = GetVecMask(arg_.block_len, arg_.tail_len, meta_.dtype, meta_.block_offset);
tail_args.GetNode()->block_offset_ = make_const(Int(32), meta_.block_offset);
arg_info.GetNode()->tail_arg_info_ = tail_args;
}
StmtInfoList info_list = GetInfoList(dst_info_, src_info_list_);
CleanZeroStrides(info_list);
for (size_t i = 0; i < info_list.size(); i++) {
info_list[i].GetNode()->insn_offset_ = GetOffset(i);
}
info_list[1].Print();
res.for_info = ExportForInfo();
res.arg_info = arg_info;
res.dst_info_list = {info_list[0]};
if (info_list.size() > 2) {
res.src_info_list = {info_list[1], info_list[2]};
} else {
res.src_info_list = {info_list[1]};
}
body_args.Print();
if (arg_info->tail_arg_info_.defined()) {
arg_info->tail_arg_info_.Print();
}
return res;
}
SingleVecInsnArgsCalculator::SingleVecInsnArgsCalculator(const StmtInfoList &dst_info_list,
const StmtInfoList &src_info_list, const StmtInfo &for_info,
const std::string &intrin_name)
: InsnArgsCalculator(dst_info_list, src_info_list, for_info, intrin_name) {}
PatternResult SingleVecInsnArgsCalculator::GetInsnArgs() {
if (meta_.cast) {
CastCaseReduction();
} else {
InsnReduction();
}
return ExportResult();
}
BinaryVecInsnArgsCalculator::BinaryVecInsnArgsCalculator(const StmtInfoList &dst_info_list,
const StmtInfoList &src_info_list, const StmtInfo &for_info,
const std::string &mode, const std::string &intrin_name,
bool expand_mask)
: InsnArgsCalculator(dst_info_list, src_info_list, for_info, intrin_name), mode_{mode}, expand_mask_{expand_mask} {
if (mode_ == "reduction" && src_info_list_.size() == 2 && src_info_list_[0]->name_ == dst_info_list[0]->name_) {
auto temp = src_info_list_[0].Copy();
src_info_list_.Set(0, src_info_list_[1].Copy());
src_info_list_.Set(1, temp);
CalAxis();
}
}
PatternResult BinaryVecInsnArgsCalculator::GetInsnArgs() {
LOG(DEBUG) << "Binary vec Insn reduction";
InsnReduction();
return ExportResult();
}
std::function<bool(const InsnAxis &)> BinaryVecInsnArgsCalculator::GetM0LimitLambda() {
return [&](const InsnAxis &axis) {
auto is_limit = [&](int stride) { return stride / meta_.block_size < MAX_STRIDE_M0; };
return std::all_of(axis.stride_list.begin(), axis.stride_list.end(), is_limit) && axis.dst_stride != 0;
};
}
std::function<bool(const InsnAxis &)> BinaryVecInsnArgsCalculator::GetM1LimitLambda() {
return [&](const InsnAxis &axis) {
auto is_limit = [&](int stride) { return stride / meta_.src_block_size < MAX_STRIDE_M1; };
return axis.dst_stride / meta_.dst_block_size < MAX_STRIDE_M1 &&
std::all_of(axis.src_stride_list.begin(), axis.src_stride_list.end(), is_limit);
};
}
void BinaryVecInsnArgsCalculator::InsnReduction() {
if (axis_list_.empty()) {
return;
}
Print(axis_list_);
auto vec_axis_it = GetVecAxisIt();
meta_.scalar = !IsValid(vec_axis_it);
if (!meta_.scalar) {
vec_axis_ = *vec_axis_it;
InsnAxis vec_axis = ExtractAxis(vec_axis_it);
auto bad_axis_lambda = [&](const InsnAxis &axis) {
int min_stride = vec_axis_it->extent;
auto dst_name = dst_info_list_[0]->name_;
if (meta_.same_dst_src && axis.dst_stride < min_stride && axis.dst_stride != 0) {
return true;
}
return false;
};
auto bad_axis_it = GetAxisByLambda(bad_axis_lambda);
InsnAxis bad_axis;
bad_axis.is_valid = false;
if (IsValid(bad_axis_it)) {
bad_axis = ExtractAxis(bad_axis_it);
}
int max_vec_len = meta_.block_size * FULL_BLOCK_NUM;
if (vec_axis.extent > meta_.block_size && vec_axis.extent < max_vec_len &&
(vec_axis.extent % meta_.block_size != 0 || vec_axis.extent > max_vec_len * meta_.vec_rate)) {
vec_axis.Print("VEC_BLOCK_AXIS");
if (expand_mask_) {
SetArgMask(DivFloor(vec_axis.extent, meta_.block_size) * meta_.block_size);
} else {
SetArgMask(vec_axis.extent);
}
SetArgM0(1, 1, 1);
} else {
SplitAxis(meta_.block_size, vec_axis);
vec_axis.Print("VEC_AXIS");
if (expand_mask_ && mode_ != "reduction") {
SetArgBlockLen(meta_.block_size);
} else {
SetArgBlockLen(vec_axis.extent);
}
BlockAxisReduction();
}
RepeatAxisReduction();
if (bad_axis.IsValid()) {
axis_list_.push_back(bad_axis);
}
} else {
BlockAxisReduction();
RepeatAxisReduction();
}
Print(axis_list_);
}
PatternResult LastAxisReduceInsnArgsCalculator::GetInsnArgs() {
CalcParams();
Array<Var> elim_var;
elim_var = GetPattern();
arg_info.GetNode()->pattern_ = PATTERN_1D;
return GenResult(elim_var);
}
Array<Var> LastAxisReduceInsnArgsCalculator::GetPattern() {
int body_len = params.last_dim_shape / params.vec_max_len * params.vec_max_len;
int tail_len = params.last_dim_shape % params.vec_max_len;
int cmd_body_len = 0;
bool is_vadd = intrin_name == "vadd";
int repeat_stride = FULL_BLOCK_NUM;
if (is_vadd) {
repeat_stride = 1;
}
const int fp16_block_size = 16;
if (body_len > 0) {
body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
body_args.GetNode()->body_num_ = 1;
body_args.GetNode()->body_offset_ = params.vec_max_len;
body_args.GetNode()->repeat_ = Expr(body_len / params.vec_max_len);
// Here use dst_stride_m1 as dst_stride
body_args.GetNode()->dst_stride_m1_ = Expr(1);
body_args.GetNode()->src_stride_m0_list_ = {Expr(1)};
body_args.GetNode()->src_stride_m1_list_ = {Expr(FULL_BLOCK_NUM)};
body_args.GetNode()->vec_mask_ = GetVecMask(params.vec_max_len, 1, dst_info->dtype_);
cmd_body_len += GetInt32Const(body_args->repeat_) * repeat_stride;
}
if (tail_len > 0) {
tail_args = VectorArgInfo(make_node<VectorArgInfoNode>());
tail_args.GetNode()->body_offset_ = params.vec_max_len;
tail_args.GetNode()->dst_head_ = Expr(cmd_body_len);
tail_args.GetNode()->src_head_list_ = {Expr(body_len)};
tail_args.GetNode()->repeat_ = Expr(1);
tail_args.GetNode()->dst_stride_m1_ = Expr(1);
tail_args.GetNode()->src_stride_m0_list_ = {Expr(1)};
tail_args.GetNode()->src_stride_m1_list_ = {Expr(0)};
tail_args.GetNode()->vec_mask_ = GetVecMask(tail_len, 1, dst_info->dtype_);
if (is_vadd) {
cmd_body_len += 1;
} else {
cmd_body_len += tail_len / fp16_block_size;
if (tail_len % fp16_block_size != 0) {
cmd_body_len += 1;
}
}
}
// cmd_body_len > 1 means vcadd size greater than 128, need to use vcadd again to compute final result
// if cmd_body_len > 128, then need to recursively emit vcadd
while (cmd_body_len > 1) {
int cmd_tail_len = cmd_body_len % params.vec_max_len;
cmd_body_len = cmd_body_len / params.vec_max_len;
if (cmd_body_len > 0) {
VectorArgInfo mix_vec_args = VectorArgInfo(make_node<VectorArgInfoNode>());
mix_vec_args.GetNode()->repeat_ = Expr(cmd_body_len);
mix_vec_args.GetNode()->dst_head_ = Expr(0);
mix_vec_args.GetNode()->src_head_list_ = {Expr(0)};
mix_vec_args.GetNode()->dst_stride_m1_ = Expr(1);
mix_vec_args.GetNode()->src_stride_m0_list_ = {Expr(1)};
mix_vec_args.GetNode()->src_stride_m1_list_ = {Expr(FULL_BLOCK_NUM)};
mix_vec_args.GetNode()->vec_mask_ = GetVecMask(params.vec_max_len, 1, dst_info->dtype_);
mix_vec_arg_list.push_back(mix_vec_args);
if (!is_vadd) {
cmd_body_len *= FULL_BLOCK_NUM;
}
}
if (cmd_tail_len > 0) {
VectorArgInfo mix_vec_args = VectorArgInfo(make_node<VectorArgInfoNode>());
mix_vec_args.GetNode()->repeat_ = Expr(1);
mix_vec_args.GetNode()->dst_head_ = Expr(cmd_body_len);
if (is_vadd) {
mix_vec_args.GetNode()->src_head_list_ = {Expr(cmd_body_len * params.vec_max_len)};
} else {
mix_vec_args.GetNode()->src_head_list_ = {Expr(cmd_body_len / FULL_BLOCK_NUM * params.vec_max_len)};
}
mix_vec_args.GetNode()->dst_stride_m1_ = Expr(1);
mix_vec_args.GetNode()->src_stride_m0_list_ = {Expr(1)};
mix_vec_args.GetNode()->src_stride_m1_list_ = {Expr(FULL_BLOCK_NUM)};
mix_vec_args.GetNode()->vec_mask_ = GetVecMask(cmd_tail_len, 1, dst_info->dtype_);
if (is_vadd) {
cmd_body_len += 1;
} else {
cmd_body_len += cmd_tail_len / fp16_block_size;
if (cmd_tail_len % fp16_block_size != 0) {
cmd_body_len += 1;
}
}
mix_vec_arg_list.push_back(mix_vec_args);
}
}
params.insn_offset_scale_factor = Expr(params.block_size);
int max_num = body_len / params.vec_max_len;
if (intrin_name == "vmax" || intrin_name == "vmin") {
max_num *= FULL_BLOCK_NUM;
}
if (max_num >= params.block_size) {
params.insn_offset_scale_factor = max_num + params.block_size - 1;
if (tail_len > 0) {
params.insn_offset_scale_factor += 1;
}
params.insn_offset_scale_factor = truncdiv(params.insn_offset_scale_factor, params.block_size) * params.block_size;
}
if (!params.src_var.empty()) {
return GetRange(params.src_var, -1, 1);
}
return {};
}
PatternResult LastAxisReduceInsnArgsCalculator::GenResult(const Array<Var> &elim_var) {
dst_info.GetNode()->insn_offset_ = GetInsnOffset(dst_info, elim_var) * params.insn_offset_scale_factor;
src_info.GetNode()->insn_offset_ = GetInsnOffset(src_info, elim_var);
if (body_args.defined()) {
body_args.GetNode()->insn_offset_scale_factor_ = params.insn_offset_scale_factor;
}
if (tail_args.defined()) {
tail_args.GetNode()->insn_offset_scale_factor_ = params.insn_offset_scale_factor;
}
for (auto &arg : mix_vec_arg_list) {
arg.GetNode()->insn_offset_scale_factor_ = params.insn_offset_scale_factor;
}
arg_info.GetNode()->body_arg_info_ = body_args;
arg_info.GetNode()->tail_arg_info_ = tail_args;
arg_info.GetNode()->reduction_tail_args_ = mix_vec_arg_list;
CleanForInfoVars(for_info, elim_var);
arg_info.GetNode()->arg_type_ = ARG_VECTOR_REDUCTION_LAST_AXIS;
PatternResult result;
result.dst_info_list = {dst_info};
result.src_info_list = {src_info};
result.for_info = for_info;
result.arg_info = arg_info;
return result;
}
void LastAxisReduceInsnArgsCalculator::CalcParams() {
// check shape len
if (dst_info->shape_.empty() || src_info->shape_.empty()) {
LOG(FATAL) << "CCE Vector Insn Error: dst_buffer and src_buffer can not be scalar, should keep len(shape) > 0.";
}
// check data type
if (dst_info->dtype_ != src_info->dtype_) {
LOG(FATAL) << "CCE Vector Insn Error: dst_buffer and src_buffer can not be different data type.";
}
params.src_var = src_info->var_;
params.block_size = GetUbBlkSize(dst_info->dtype_);
params.last_dim_shape = GetInt32Const(GetItem(src_info->shape_, -1));
params.vec_max_len = GetVecMaxLen(dst_info->dtype_);
CHECK_NE(params.block_size, 0);
CHECK_NE(params.vec_max_len, 0);
}
/// Generete info list for bisection intrin
/// \param dst_info_list
/// \param src_info_list
/// \param for_info
/// \param if_info
/// \param last_axis
/// \param postfix
/// \return
BisectionInfoWrapper SeparateComInfoToBisectionInfoList(const StmtInfoList &dst_info_list,
const StmtInfoList &src_info_list, const StmtInfo &for_info,
StmtInfo &if_info, bool last_axis, int postfix = 0) {
CHECK_EQ(dst_info_list.size(), 1);
CHECK_EQ(src_info_list.size(), 2);
BisectionInfoWrapper wrapper;
// Separate com_info and for_info
int compare_idx = 1;
int var_idx = -1;
var_idx = GetBisectionReductionIdx(dst_info_list, src_info_list, compare_idx);
StmtStoreInfo dst_info = dst_info_list[0];
CHECK_GE(compare_idx, 0);
StmtStoreInfo src_info1 = src_info_list[compare_idx];
Var reduce_var = GetItem(src_info1->var_, var_idx);
int stride_len = GetInt32Const(GetItem(src_info1->strides_, var_idx));
size_t for_idx = 0;
bool suc = GetIndexOfElement(for_info.vars_, VarExpr(reduce_var), for_idx);
CHECK(suc);
auto exist_for = GetItem(for_info.ops_, for_idx).as<For>();
CHECK(exist_for);
int extent = GetInt32Const(exist_for->extent);
int simd_len = 1;
const std::string un_def_var = "un_def_var";
Var simd_var = Var("un_def_var");
CHECK_GT(src_info1->strides_.size(), 0);
CHECK_EQ(src_info1->var_.size(), src_info1->strides_.size());
for (size_t i = 0; i <= src_info1->strides_.size() - 1; i++) {
if (GetInt32Const(src_info1->strides_[i]) == 1) {
simd_var = src_info1->var_[i];
size_t simd_for_idx = 0;
bool suc = GetIndexOfElement(for_info.vars_, VarExpr(simd_var), simd_for_idx);
CHECK(suc);
auto simd_for = GetItem(for_info.ops_, simd_for_idx).as<For>();
CHECK(simd_for);
simd_len = GetInt32Const(simd_for->extent);
}
}
int block_unit = GetUbBlkSize(src_info1->dtype_);
int last_dim_len = ((simd_len - 1) / block_unit + 1) * block_unit;
Var bisec_var;
Buffer bisec_buffer;
std::string bisec_pre_header = "bisec";
std::string bisec_name = bisec_pre_header + "_local_UB";
if (postfix > 0) {
bisec_name = bisec_name + "_" + std::to_string(postfix);
}
int vec_max_len = GetVecMaxLen(dst_info->dtype_);
CHECK_NE(vec_max_len, 0);
std::vector<int> pow2_list = {0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536};
int origin_len = extent;
for (int i : pow2_list) {
if (extent <= i) {
extent = i / 2;
break;
}
}
int prolog_len = origin_len - extent;
src_info1.Print();
auto src_vars = src_info1->var_;
auto src_strides = src_info1->strides_;
auto src_dims = src_info1->shape_;
auto new_vars = src_info1->var_;
auto new_strides = src_info1->strides_;
auto new_dims = src_info1->shape_;
LOG(DEBUG) << "\nvar_idx:" << var_idx << "\n";
var_idx = var_idx + src_info1->var_.size();
if (var_idx != static_cast<int>(src_info1->var_.size()) - 1) {
new_dims.Set(var_idx, extent);
new_dims.Set(new_dims.size() - 1, last_dim_len);
CHECK_GT(new_dims.size(), 1);
new_strides.Set(new_strides.size() - 1, 1);
for (int i = static_cast<int>(new_dims.size()) - 2; i >= 0; i--) {
new_strides.Set(i, new_strides[i + 1] * new_dims[i + 1]);
}
new_dims.Set(new_dims.size() - 1, simd_len);
} else {
new_dims = {extent};
}
// copy data from origin buffer to new temp buffer
Array<Expr> shape = new_dims;
wrapper.original_shape_ = new_dims;
bisec_var = Var(bisec_name, Handle());
bisec_buffer = BufferNode::make(bisec_var, dst_info->dtype_, shape, Array<Expr>(), Expr(), bisec_name, SCOPE_UBUF, 0,
0, BufferType::kDefault);
// Need to copy input to bisect buffer
StmtStoreInfo copy_dst_info{src_info1.Copy()};
StmtStoreInfo copy_src_info{src_info1.Copy()};
StmtInfoList src_list = {copy_src_info};
auto for_tmp_info = for_info.Copy();
auto new_for = GetItem(for_tmp_info.ops_, for_idx).as<For>();
CHECK(new_for);
SetItem(for_tmp_info.ops_, static_cast<int>(for_idx),
For::make(new_for->loop_var, new_for->min, extent, new_for->for_type, new_for->device_api, new_for->body));
ReplaceVarWithNewForInfo(copy_dst_info, for_info, for_tmp_info);
ReplaceVarWithNewForInfo(copy_src_info, for_info, for_tmp_info);
SetItem(copy_src_info.GetNode()->shape_, var_idx, Expr(extent));
SetItem(copy_dst_info.GetNode()->shape_, var_idx, Expr(extent));
SetItem(copy_dst_info.GetNode()->strides_, var_idx, Expr(last_dim_len));
if (simd_var->name_hint != un_def_var) {
copy_dst_info.GetNode()->index_ = 0;
for (size_t i = 0; i <= new_vars.size() - 1; i++) {
copy_dst_info.GetNode()->index_ += new_vars[i] * new_strides[i];
}
} else {
copy_dst_info.GetNode()->index_ = last_dim_len * reduce_var;
}
copy_dst_info.GetNode()->elem_offset_ = 0;
copy_dst_info.GetNode()->name_ = bisec_name;
copy_dst_info.GetNode()->buffer_ = bisec_buffer;
copy_dst_info.GetNode()->data_ = bisec_var;
copy_dst_info.GetNode()->strides_ = new_strides;
CompactComputationInfoList(copy_dst_info, src_list, if_info, for_tmp_info);
wrapper.bisec_info_list_.emplace_back(StmtInfoList{copy_dst_info, copy_src_info});
wrapper.for_info_list_.push_back(for_tmp_info);
// Generate the vadd wrapper
while (extent >= 0) {
StmtStoreInfo dst_tmp_info = dst_info.Copy();
StmtStoreInfo src_tmp_info0{src_info1.Copy()};
StmtStoreInfo src_tmp_info1{src_info1.Copy()};
auto for_tmp_info = for_info.Copy();
int vadd_length = (prolog_len != 0) ? prolog_len : extent;
if (extent > 0) {
dst_tmp_info = src_info1.Copy();
dst_tmp_info.GetNode()->data_alignment_ = simd_len;
dst_tmp_info.GetNode()->name_ = bisec_name;
dst_tmp_info.GetNode()->buffer_ = bisec_buffer;
dst_tmp_info.GetNode()->data_ = bisec_var;
dst_tmp_info.GetNode()->shape_ = new_dims;
SetItem(dst_tmp_info.GetNode()->shape_, var_idx, Expr(vadd_length));
dst_tmp_info.GetNode()->strides_ = new_strides;
dst_tmp_info.GetNode()->var_ = new_vars;
dst_tmp_info.GetNode()->index_ = 0;
for (size_t i = 0; i <= new_vars.size() - 1; i++) {
dst_tmp_info.GetNode()->index_ += new_vars[i] * new_strides[i];
}
if (prolog_len == 0) {
src_tmp_info1 = dst_tmp_info.Copy();
src_tmp_info1.GetNode()->index_ = dst_tmp_info.GetNode()->index_ + extent * last_dim_len;
} else {
SetItem(src_tmp_info1.GetNode()->shape_, var_idx, Expr(vadd_length));
src_tmp_info1.GetNode()->index_ += extent * stride_len;
}
}
src_tmp_info0 = dst_tmp_info.Copy();
auto new_for = GetItem(for_tmp_info.ops_, for_idx).as<For>();
CHECK(new_for);
int temp_for_len = (vadd_length != 0) ? vadd_length : 1;
SetItem(
for_tmp_info.ops_, static_cast<int>(for_idx),
For::make(new_for->loop_var, new_for->min, temp_for_len, new_for->for_type, new_for->device_api, new_for->body));
if (extent == 0) {
src_tmp_info1.GetNode()->name_ = bisec_name;
src_tmp_info1.GetNode()->buffer_ = bisec_buffer;
src_tmp_info1.GetNode()->data_ = bisec_var;
if (simd_var->name_hint != un_def_var) {
src_tmp_info1.GetNode()->shape_ = RemoveItemAtIndex(new_dims, var_idx);
src_tmp_info1.GetNode()->strides_ = RemoveItemAtIndex(new_strides, var_idx);
src_tmp_info1.GetNode()->var_ = RemoveItemAtIndex(new_vars, var_idx);
src_tmp_info1.GetNode()->index_ = 0;
for (size_t i = 0; i <= src_tmp_info1->var_.size() - 1; i++) {
src_tmp_info1.GetNode()->index_ += src_tmp_info1->var_[i] * src_tmp_info1->strides_[i];
}
} else {
src_tmp_info1.GetNode()->shape_ = dst_tmp_info->shape_;
src_tmp_info1.GetNode()->strides_ = dst_tmp_info->strides_;
src_tmp_info1.GetNode()->var_ = dst_tmp_info->var_;
src_tmp_info1.GetNode()->index_ = dst_tmp_info->index_;
}
}
ReplaceVarWithNewForInfo(dst_tmp_info, for_info, for_tmp_info);
ReplaceVarWithNewForInfo(src_tmp_info0, for_info, for_tmp_info);
ReplaceVarWithNewForInfo(src_tmp_info1, for_info, for_tmp_info);
StmtInfoList src_list = {src_tmp_info0, src_tmp_info1};
CompactComputationInfoList(dst_tmp_info, src_list, if_info, for_tmp_info);
wrapper.for_info_list_.emplace_back(for_tmp_info);
if (extent == 0) {
// normally is bisect_tmp = bisect_tmp + bisect_tmp/src_tmp
wrapper.bisec_info_list_.emplace_back(StmtInfoList{dst_tmp_info, dst_tmp_info, src_tmp_info1});
} else {
// normally is dst_tmp = dst_tmp + bisect_tmp
wrapper.bisec_info_list_.emplace_back(StmtInfoList{dst_tmp_info, src_tmp_info0, src_tmp_info1});
}
if (extent == 0) {
break;
} else {
extent = extent / 2;
}
prolog_len = 0;
}
// Generate arg_info
for (size_t i = 0; i < wrapper.bisec_info_list_.size(); ++i) {
auto info_list = wrapper.bisec_info_list_[i];
auto new_for_info = wrapper.for_info_list_[i];
ArgInfo arg_info;
auto dst_list = GetRange(info_list, 0, 1);
auto src_list = GetRange(info_list, 1, info_list.size() - 1);
if (info_list.size() == 2) {
std::string dma_intrin = INTRIN_NAME_COPY_UB_TO_UB;
wrapper.dma_arg_info_map_ = GetDmaCopyInsnArgs(dma_intrin, dst_list, src_list, new_for_info);
} else {
// Bisect can't expand mask because it has inplace operation
if (i != wrapper.bisec_info_list_.size() - 1) {
// Last round dont need to add
FillLastDim(dst_list, src_list, new_for_info);
}
std::string mode = GetBinaryVecMode(dst_list, src_list, "vadd", false);
BinaryVecInsnArgsCalculator args_calculator =
BinaryVecInsnArgsCalculator(dst_list, src_list, new_for_info, mode, "", false);
PatternResult params = args_calculator.GetInsnArgs();
arg_info = params.arg_info;
dst_list = params.dst_info_list;
src_list = params.src_info_list;
new_for_info = params.for_info;
wrapper.bisec_info_list_[i] = {dst_list[0], src_list[0], src_list[1]};
}
wrapper.arg_info_list_.push_back(arg_info);
wrapper.for_info_list_[i] = new_for_info;
}
return wrapper;
}
/// Get CCE Binary Vector Insn Computation Info
/// \param stmt - operand stmt
/// \param intrin_name - vector intrin name
/// \param dst_info_list - dst computation info list
/// \param src_info_list - src computation info list
/// \param if_info - if info list
/// \param for_info - for info list
/// \return intrin args
ArgInfo GetBinaryVecInsnArgs(const Stmt &stmt, std::string intrin_name, StmtInfoList &dst_info_list,
StmtInfoList &src_info_list, StmtInfo &if_info, StmtInfo &for_info, bool enable_bisect) {
// check intrin_name
std::set<std::string> intrin_name_list = {"vadd", "vmax", "vmin", "vmul", "vdiv", "vsel", "vsub", "vand",
"vor", "vaxpy", "argmax", "argmin", "vmadd", "vmaddrelu", "vmla"};
if (intrin_name_list.count(intrin_name) == 0) {
LOG(FATAL) << "Error: CCE Binary Vector Insn doesn't support the given intrin_name.";
}
// get and check dst and src
GetCompactComputationInfo(stmt, dst_info_list, src_info_list, if_info, for_info, true);
// For vmadd/vmaddrelu/vmla we only need first two src
if (dst_info_list.size() != 1 || src_info_list.size() < 2) {
LOG(FATAL) << "CCE Binary Vector Insn only support ONE dst and TWO srcs.";
}
src_info_list = GetRange(src_info_list, 0, 2);
ArgInfo arg_info = ArgInfo(make_node<ArgInfoNode>());
// detect vector op mode
std::string mode = GetBinaryVecMode(dst_info_list, src_info_list, intrin_name, enable_bisect);
if (mode == "reduce_last_axis") {
size_t src_var_list_size = src_info_list[1]->var_.size();
if (src_info_list[0]->var_.size() > src_info_list[1]->var_.size()) {
src_var_list_size = src_info_list[0]->var_.size();
}
CHECK(src_var_list_size > 0) << "Error: src can not be a scalar.";
if (src_var_list_size - dst_info_list[0]->var_.size() == 1) {
arg_info.GetNode()->arg_type_ = ARG_VECTOR_REDUCTION_LAST_AXIS;
} else {
LOG(FATAL) << "Error: cannot support multi-last-axis reduction.";
}
} else if (mode == "reduce_bisection") {
arg_info.GetNode()->arg_type_ = ARG_VECTOR_REDUCTION_BISECTION;
} else {
if (mode != "reduction" && mode != "broadcast") {
FillLastDim(dst_info_list, src_info_list, for_info);
}
// vmax/vmin can't expand mask because it may introduce dirty data
bool can_expand_mask = intrin_name != "vmax" && intrin_name != "vmin";
BinaryVecInsnArgsCalculator args_calculator =
BinaryVecInsnArgsCalculator(dst_info_list, src_info_list, for_info, mode, intrin_name, can_expand_mask);
PatternResult params = args_calculator.GetInsnArgs();
arg_info = params.arg_info;
dst_info_list = params.dst_info_list;
src_info_list = params.src_info_list;
for_info = params.for_info;
if (mode == "broadcast") {
bool has_last_axis = false;
if ((arg_info->body_arg_info_.defined() && arg_info->body_arg_info_->last_axis_info_.src_index_ != -1) ||
(arg_info->tail_arg_info_.defined() && arg_info->tail_arg_info_->last_axis_info_.src_index_ != -1)) {
has_last_axis = true;
}
if (has_last_axis && (intrin_name == "vadd" || intrin_name == "vmul")) {
Array<NodeRef> stores;
Array<NodeRef> loads;
GetStoreAndLoads(stmt, stores, loads);
intrin_name = intrin_name + "s";
if (arg_info->body_arg_info_.defined()) {
arg_info.GetNode()->body_arg_info_.GetNode()->last_axis_info_.intrin_name_ = intrin_name;
arg_info.GetNode()->body_arg_info_.GetNode()->last_axis_info_.src_op_ =
Downcast<Expr>(loads[arg_info->body_arg_info_->last_axis_info_.src_index_]);
}
}
}
}
return arg_info;
}
} // namespace akg
\ No newline at end of file
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef EMIT_INSN_ARGS_CALCULATOR_H_
#define EMIT_INSN_ARGS_CALCULATOR_H_
namespace akg {
struct InsnArg {
int dst_m0{1};
int dst_m1{0};
std::vector<Expr> src_m0_list;
std::vector<Expr> src_m1_list;
int repeat{1};
int block_len{1};
int block_num{1};
int body_num{1};
int tail_len{0};
int dst_tail_offset{0};
std::vector<Expr> src_tail_offset_list;
};
struct Meta {
int block_size{0};
int src_block_size{0};
int dst_block_size{0};
int block_offset{0};
const float vec_rate{0.6};
Type src_dtype;
Type dst_dtype;
Type dtype;
bool cast{false};
bool tail{false};
bool scalar{false};
bool liner{false};
bool same_dst_src{false};
};
enum SplitStat { SUCCESS, NO_SPLIT, TAIL };
class InsnAxis {
public:
InsnAxis() = default;
InsnAxis(const For *for_stmt, const Array<StmtStoreInfo> &info_list);
virtual ~InsnAxis() = default;
bool IsValid();
void Print(const std::string &name = "");
int min{0};
int extent{0};
Var var;
int dst_stride{0};
int src_stride{0};
std::vector<int> src_stride_list;
std::vector<int> stride_list;
bool is_valid{true};
private:
Expr GetStrideByAxis(const Array<Var> &vars, const Array<Expr> &strides, Var obj_var);
};
using AxisIt = std::list<InsnAxis>::iterator;
std::list<InsnAxis> GetAxisList(const StmtInfo &for_info, const Array<StmtStoreInfo> &info_list);
Array<StmtStoreInfo> GetInfoList(const StmtStoreInfo &dst_info, const Array<StmtStoreInfo> &src_info_list);
int DivFloor(int a, int b);
void Print(std::list<InsnAxis> &axis_list);
class InsnArgsCalculator {
public:
InsnArgsCalculator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list, const StmtInfo &for_info,
const std::string &intrin_name);
virtual ~InsnArgsCalculator() = default;
PatternResult ExportResult();
void CalAxis();
void InitArg();
virtual std::function<bool(const InsnAxis &)> GetStrideLambda();
virtual std::function<bool(const InsnAxis &)> GetM0LimitLambda();
virtual std::function<bool(const InsnAxis &)> GetM1LimitLambda();
std::function<bool(const InsnAxis &)> GetBlockStrideLimitLambda();
AxisIt GetAxisByLambda(const std::function<bool(const InsnAxis &)> &lambda);
InsnAxis ExtractAxis(AxisIt &it);
bool IsValid(AxisIt &it);
AxisIt GetVecAxisIt();
AxisIt GetBlockAxis();
AxisIt GetRepeatAxisIt();
InsnAxis GetRepeatAxis();
void SetArgMask(int len);
void SetArgBlockNum(int data_num);
void SetArgBlockLen(int data_len);
void SetArgM0(int dst_m0, int lsrc_m0, int rsrc_m0);
void SetArgM1(int dst_m1, int lsrc_m1, int rsrc_m1);
void SetArgRepeat(int repeat);
void BlockAxisReduction();
void RepeatAxisReduction();
void CastCaseReduction();
virtual void InsnReduction();
StmtInfo ExportForInfo();
Expr GetOffset(int stride_index);
InsnAxis GetInvalidAxis();
SplitStat SplitAxis(int extent, InsnAxis &axis);
std::list<InsnAxis> axis_list_;
protected:
InsnArg arg_;
Meta meta_;
StmtInfoList dst_info_list_;
StmtInfoList src_info_list_;
StmtStoreInfo dst_info_;
StmtInfo for_info_;
const std::string intrin_name_;
const int max_block_stride_{4};
};
class SingleVecInsnArgsCalculator : public InsnArgsCalculator {
public:
SingleVecInsnArgsCalculator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list, const StmtInfo &for_info,
const std::string &intrin_name = "");
virtual ~SingleVecInsnArgsCalculator() override = default;
PatternResult GetInsnArgs();
};
class BinaryVecInsnArgsCalculator : public InsnArgsCalculator {
public:
BinaryVecInsnArgsCalculator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list, const StmtInfo &for_info,
const std::string &mode, const std::string &intrin_name = "", bool expand_mask = true);
virtual ~BinaryVecInsnArgsCalculator() override = default;
PatternResult GetInsnArgs();
std::function<bool(const InsnAxis &)> GetM0LimitLambda();
std::function<bool(const InsnAxis &)> GetM1LimitLambda();
void InsnReduction();
private:
std::string mode_;
bool expand_mask_;
InsnAxis vec_axis_;
};
class LastAxisReduceInsnArgsCalculator : InsnArgsCalculator{
public:
LastAxisReduceInsnArgsCalculator(const StmtStoreInfo &dst_info, const StmtStoreInfo &src_info, const StmtInfo &for_info,
const std::string &intrin_name)
: InsnArgsCalculator({dst_info}, {src_info}, for_info, intrin_name),
dst_info(dst_info),
src_info(src_info),
for_info(for_info),
arg_info(ArgInfo(make_node<ArgInfoNode>())),
body_args(VectorArgInfo()),
tail_args(VectorArgInfo()),
intrin_name(intrin_name) {}
PatternResult GetInsnArgs();
~LastAxisReduceInsnArgsCalculator() = default;
protected:
Array<Var> GetPattern();
PatternResult GenResult(const Array<Var> &elim_var);
private:
void CalcParams();
struct Params {
Array<Var> src_var;
int block_size = 0;
int vec_max_len = 0;
int last_dim_shape = 0;
Expr insn_offset_scale_factor;
};
StmtStoreInfo dst_info;
StmtStoreInfo src_info;
StmtInfo for_info;
ArgInfo arg_info;
VectorArgInfo body_args;
VectorArgInfo tail_args;
Array<VectorArgInfo> mix_vec_arg_list;
std::string intrin_name;
Params params;
};
BisectionInfoWrapper SeparateComInfoToBisectionInfoList(const StmtInfoList &dst_info_list,
const StmtInfoList &src_info_list, const StmtInfo &for_info,
StmtInfo &if_info, bool last_axis, int postfix);
ArgInfo GetBinaryVecInsnArgs(const Stmt &stmt, std::string intrin_name, StmtInfoList &dst_info_list,
StmtInfoList &src_info_list, StmtInfo &if_info, StmtInfo &for_info,
bool enable_bisect = true);
} // namespace akg
#endif
\ No newline at end of file
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <tvm/base.h>
#include <tvm/ir_pass.h>
#include <set>
#include "ir_pass.h"
#include "contrib/cce_parm/cceconf.h"
#include "tvm.h"
#include "common/array_api.h"
#include "insn_pattern.h"
#include "insn_builder.h"
namespace akg {
std::string GetBinaryVecMode(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list,
const std::string &intrin_name, bool enable_bisect = true) {
std::set<std::string> reduce_bisect_list = {"vadd", "vsub", "vmul", "vmax"};
std::string mode = "reduction";
if (IsElementwise(dst_info_list, src_info_list)) {
mode = "elewise";
} else if (IsBroadcast(dst_info_list, src_info_list)) {
mode = "broadcast";
} else if (IsLastAxisReduction(dst_info_list, src_info_list)) {
mode = "reduce_last_axis";
} else if (enable_bisect && reduce_bisect_list.count(intrin_name) != 0 &&
IsBisectionReduction(dst_info_list, src_info_list)) {
mode = "reduce_bisection";
}
return mode;
}
PatternResult ReduceLastAxisPatternGenerator::GetInsnArgs() {
CalcParams();
Array<Var> elim_var;
float rate2d = Compute2DBlockPatternMaskRate();
if (rate2d > 1.0f) {
elim_var = Get2DBlockPattern();
arg_info.GetNode()->pattern_ = PATTERN_2D_BLOCK;
} else {
elim_var = Get1DPattern();
arg_info.GetNode()->pattern_ = PATTERN_1D;
}
return GenResult(elim_var);
}
float ReduceLastAxisPatternGenerator::Compute2DBlockPatternMaskRate() {
const float is2_dpattern = 1.0f;
if (intrin_name == "vadd" || intrin_name == "argmax" || intrin_name == "argmin") {
return not_this_pattern;
}
// src var size must larger than 2
if (params.src_var.size() < 2) {
return not_this_pattern;
}
int body_len = params.last_dim_shape / params.vec_max_len * params.vec_max_len;
int tail_len = params.last_dim_shape % params.vec_max_len;
// there is no body in this mode
if (body_len > 0 || tail_len > params.block_size) {
return not_this_pattern;
}
return is2_dpattern;
}
Array<Var> ReduceLastAxisPatternGenerator::Get2DBlockPattern() {
int sec_last_dim_shape = GetInt32Const(GetItem(src_info->shape_, -2));
int body_len = sec_last_dim_shape / FULL_BLOCK_NUM * FULL_BLOCK_NUM;
int tail_len = sec_last_dim_shape % FULL_BLOCK_NUM;
int cmd_body_len = 0;
if (body_len > 0) {
body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
body_args.GetNode()->body_num_ = 1;
body_args.GetNode()->repeat_ = Expr(body_len / FULL_BLOCK_NUM);
// Here use dst_stride_m1 as dst_stride
body_args.GetNode()->dst_stride_m1_ = Expr(1);
body_args.GetNode()->src_stride_m0_list_ = {Expr(1)};
body_args.GetNode()->src_stride_m1_list_ = {Expr(FULL_BLOCK_NUM)};
body_args.GetNode()->vec_mask_ = GetVecMask(params.last_dim_shape, FULL_BLOCK_NUM, dst_info->dtype_);
cmd_body_len += GetInt32Const(body_args->repeat_) * FULL_BLOCK_NUM;
}
if (tail_len > 0) {
tail_args = VectorArgInfo(make_node<VectorArgInfoNode>());
tail_args.GetNode()->dst_head_ = Expr(cmd_body_len);
tail_args.GetNode()->src_head_list_ = {Expr(cmd_body_len / FULL_BLOCK_NUM * params.vec_max_len)};
tail_args.GetNode()->repeat_ = Expr(1);
tail_args.GetNode()->dst_stride_m1_ = Expr(1);
tail_args.GetNode()->src_stride_m0_list_ = {Expr(1)};
tail_args.GetNode()->src_stride_m1_list_ = {Expr(0)};
tail_args.GetNode()->vec_mask_ = GetVecMask(params.last_dim_shape, tail_len, dst_info->dtype_);
}
params.insn_offset_scale_factor = 1;
return GetRange(params.src_var, -2, 2);
}
Array<Var> ReduceLastAxisPatternGenerator::Get1DPattern() {
int body_len = params.last_dim_shape / params.vec_max_len * params.vec_max_len;
int tail_len = params.last_dim_shape % params.vec_max_len;
int cmd_body_len = 0;
bool is_vadd = intrin_name == "vadd";
int repeat_stride = FULL_BLOCK_NUM;
if (is_vadd) {
repeat_stride = 1;
}
const int fp16_block_size = 16;
if (body_len > 0) {
body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
body_args.GetNode()->body_num_ = 1;
body_args.GetNode()->body_offset_ = params.vec_max_len;
body_args.GetNode()->repeat_ = Expr(body_len / params.vec_max_len);
// Here use dst_stride_m1 as dst_stride
body_args.GetNode()->dst_stride_m1_ = Expr(1);
body_args.GetNode()->src_stride_m0_list_ = {Expr(1)};
body_args.GetNode()->src_stride_m1_list_ = {Expr(FULL_BLOCK_NUM)};
body_args.GetNode()->vec_mask_ = GetVecMask(params.vec_max_len, 1, dst_info->dtype_);
cmd_body_len += GetInt32Const(body_args->repeat_) * repeat_stride;
}
if (tail_len > 0) {
tail_args = VectorArgInfo(make_node<VectorArgInfoNode>());
tail_args.GetNode()->body_offset_ = params.vec_max_len;
tail_args.GetNode()->dst_head_ = Expr(cmd_body_len);
tail_args.GetNode()->src_head_list_ = {Expr(body_len)};
tail_args.GetNode()->repeat_ = Expr(1);
tail_args.GetNode()->dst_stride_m1_ = Expr(1);
tail_args.GetNode()->src_stride_m0_list_ = {Expr(1)};
tail_args.GetNode()->src_stride_m1_list_ = {Expr(0)};
tail_args.GetNode()->vec_mask_ = GetVecMask(tail_len, 1, dst_info->dtype_);
if (is_vadd) {
cmd_body_len += 1;
} else {
cmd_body_len += tail_len / fp16_block_size;
if (tail_len % fp16_block_size != 0) {
cmd_body_len += 1;
}
}
}
// cmd_body_len > 1 means vcadd size greater than 128, need to use vcadd again to compute final result
// if cmd_body_len > 128, then need to recursively emit vcadd
while (cmd_body_len > 1) {
int cmd_tail_len = cmd_body_len % params.vec_max_len;
cmd_body_len = cmd_body_len / params.vec_max_len;
if (cmd_body_len > 0) {
VectorArgInfo mix_vec_args = VectorArgInfo(make_node<VectorArgInfoNode>());
mix_vec_args.GetNode()->repeat_ = Expr(cmd_body_len);
mix_vec_args.GetNode()->dst_head_ = Expr(0);
mix_vec_args.GetNode()->src_head_list_ = {Expr(0)};
mix_vec_args.GetNode()->dst_stride_m1_ = Expr(1);
mix_vec_args.GetNode()->src_stride_m0_list_ = {Expr(1)};
mix_vec_args.GetNode()->src_stride_m1_list_ = {Expr(FULL_BLOCK_NUM)};
mix_vec_args.GetNode()->vec_mask_ = GetVecMask(params.vec_max_len, 1, dst_info->dtype_);
mix_vec_arg_list.push_back(mix_vec_args);
if (!is_vadd) {
cmd_body_len *= FULL_BLOCK_NUM;
}
}
if (cmd_tail_len > 0) {
VectorArgInfo mix_vec_args = VectorArgInfo(make_node<VectorArgInfoNode>());
mix_vec_args.GetNode()->repeat_ = Expr(1);
mix_vec_args.GetNode()->dst_head_ = Expr(cmd_body_len);
if (is_vadd) {
mix_vec_args.GetNode()->src_head_list_ = {Expr(cmd_body_len * params.vec_max_len)};
} else {
mix_vec_args.GetNode()->src_head_list_ = {Expr(cmd_body_len / FULL_BLOCK_NUM * params.vec_max_len)};
}
mix_vec_args.GetNode()->dst_stride_m1_ = Expr(1);
mix_vec_args.GetNode()->src_stride_m0_list_ = {Expr(1)};
mix_vec_args.GetNode()->src_stride_m1_list_ = {Expr(FULL_BLOCK_NUM)};
mix_vec_args.GetNode()->vec_mask_ = GetVecMask(cmd_tail_len, 1, dst_info->dtype_);
if (is_vadd) {
cmd_body_len += 1;
} else {
cmd_body_len += cmd_tail_len / fp16_block_size;
if (cmd_tail_len % fp16_block_size != 0) {
cmd_body_len += 1;
}
}
mix_vec_arg_list.push_back(mix_vec_args);
}
}
params.insn_offset_scale_factor = Expr(params.block_size);
int max_num = body_len / params.vec_max_len;
if (intrin_name == "vmax" || intrin_name == "vmin") {
max_num *= FULL_BLOCK_NUM;
}
if (max_num >= params.block_size) {
params.insn_offset_scale_factor = max_num + params.block_size - 1;
if (tail_len > 0) {
params.insn_offset_scale_factor += 1;
}
params.insn_offset_scale_factor = truncdiv(params.insn_offset_scale_factor, params.block_size) * params.block_size;
}
if (!params.src_var.empty()) {
return GetRange(params.src_var, -1, 1);
}
return {};
}
PatternResult ReduceLastAxisPatternGenerator::GenResult(const Array<Var> &elim_var) {
dst_info.GetNode()->insn_offset_ = GetInsnOffset(dst_info, elim_var) * params.insn_offset_scale_factor;
src_info.GetNode()->insn_offset_ = GetInsnOffset(src_info, elim_var);
if (body_args.defined()) {
body_args.GetNode()->insn_offset_scale_factor_ = params.insn_offset_scale_factor;
}
if (tail_args.defined()) {
tail_args.GetNode()->insn_offset_scale_factor_ = params.insn_offset_scale_factor;
}
for (auto &arg : mix_vec_arg_list) {
arg.GetNode()->insn_offset_scale_factor_ = params.insn_offset_scale_factor;
}
arg_info.GetNode()->body_arg_info_ = body_args;
arg_info.GetNode()->tail_arg_info_ = tail_args;
arg_info.GetNode()->reduction_tail_args_ = mix_vec_arg_list;
CleanForInfoVars(for_info, elim_var);
arg_info.GetNode()->arg_type_ = ARG_VECTOR_REDUCTION_LAST_AXIS;
PatternResult result;
result.dst_info_list = {dst_info};
result.src_info_list = {src_info};
result.for_info = for_info;
result.arg_info = arg_info;
return result;
}
void ReduceLastAxisPatternGenerator::CalcParams() {
// check shape len
if (dst_info->shape_.empty() || src_info->shape_.empty()) {
LOG(FATAL) << "CCE Vector Insn Error: dst_buffer and src_buffer can not be scalar, should keep len(shape) > 0.";
}
// check data type
if (dst_info->dtype_ != src_info->dtype_) {
LOG(FATAL) << "CCE Vector Insn Error: dst_buffer and src_buffer can not be different data type.";
}
params.src_var = src_info->var_;
params.block_size = GetUbBlkSize(dst_info->dtype_);
params.last_dim_shape = GetInt32Const(GetItem(src_info->shape_, -1));
params.vec_max_len = GetVecMaxLen(dst_info->dtype_);
CHECK_NE(params.block_size, 0);
CHECK_NE(params.vec_max_len, 0);
}
/// Get CCE Binary Vector instructions args
/// \return
PatternResult BinaryVecPatternGenerator::GetInsnArgs() {
CalcParams();
if (arg_info->arg_type_ == ARG_VECTOR_BROADCAST_LAST_AXIS) {
PatternResult result;
result.dst_info_list = {dst_info};
result.src_info_list = src_info_list;
result.for_info = for_info;
result.arg_info = arg_info;
return result;
}
Array<Var> elim_var = {};
float rate3d = Compute3DPatternMaskRate();
float rate2db = Compute2DBlockPatternMaskRate();
float rate2d = Compute2DPatternMaskRate();
float rate1d = Compute1DPatternMaskRate();
if (rate3d >= rate2db && rate3d > 0) {
elim_var = Get3DPattern();
arg_info.GetNode()->pattern_ = PATTERN_3D;
} else if (rate2db >= rate2d && rate2db >= rate1d && rate2db > 0) {
elim_var = Get2DBlockPattern();
arg_info.GetNode()->pattern_ = PATTERN_PARTIAL_3D;
} else if (rate2d > rate1d && rate2d > 0) {
elim_var = Get2DPattern();
arg_info.GetNode()->pattern_ = PATTERN_2D;
} else if (rate1d > 0) {
elim_var = Get1DPattern();
arg_info.GetNode()->pattern_ = PATTERN_1D;
} else {
LOG(FATAL) << "Error: Cannot emit Binary-Vector-Insn with any pattern!";
}
std::string mask_rate = "rate3d[" + std::to_string(rate3d) + "], rate2db[" + std::to_string(rate2db) + "], rate2d[" +
std::to_string(rate2d) + "], rate1d[" + std::to_string(rate1d) + "]";
CommentManager::GetInstance().AddComment("Mask_rate", mask_rate);
if (tail_args.defined()) {
CommentManager::GetInstance().AddComment("Contain_tail", "true");
} else {
CommentManager::GetInstance().AddComment("Contain_tail", "false");
}
return GenResult(elim_var);
}
float BinaryVecPatternGenerator::Compute3DPatternMaskRate() {
if (params.non_zero_shape3 == 1 || params.non_zero_shape2 == 1) {
return not_this_pattern;
}
// in elemwise mode, the var is already checked to be equal, no need to check
if (params.dst_var.size() < 3 || GetIntConst(GetItem(params.dst_shape, -1)) > params.block_size ||
GetIntConst(GetItem(params.dst_strides, -2)) % params.block_size != 0 ||
GetIntConst(GetItem(params.dst_strides, -3)) % params.block_size != 0 ||
(GetIntConst(GetItem(params.dst_strides, -2)) > 0 && GetIntConst(GetItem(params.dst_shape, -1)) > 0 &&
GetIntConst(GetItem(params.dst_strides, -2)) < GetIntConst(GetItem(params.dst_shape, -1))) ||
(GetIntConst(GetItem(params.dst_strides, -3)) > 0 && GetIntConst(GetItem(params.dst_shape, -2)) > 0 &&
GetIntConst(GetItem(params.dst_strides, -3)) < GetIntConst(GetItem(params.dst_shape, -2)))) {
return not_this_pattern;
}
// check dst_stride_m0
// As described in ISL User Guide t6.3,
// dst_stride_m0 = 0 is treated as 1
auto JudgeNot3D = [this](const StmtStoreInfo &info) {
auto last_shape1 = GetIntConst(GetItem(info->shape_, -1));
if (info->var_.size() < 3 || last_shape1 > params.block_size) {
return true;
}
auto last_shape2 = GetIntConst(GetItem(info->shape_, -2));
auto last_stride2 = GetIntConst(GetItem(info->strides_, -2));
auto last_stride3 = GetIntConst(GetItem(info->strides_, -3));
return last_stride2 % params.block_size != 0 || last_stride3 % params.block_size != 0 ||
(last_stride2 > 0 && last_shape1 > 0 && last_stride2 < last_shape1) ||
(last_stride3 > 0 && last_shape2 > 0 && last_stride3 < last_shape2);
};
if (std::any_of(src_info_list.begin(), src_info_list.end(), JudgeNot3D)) {
return not_this_pattern;
}
if (mode == "reduction") {
// check same alignment
Array<Expr> shape_list = {GetItem(params.dst_shape, -1)};
shape_list.push_back(GetItem(params.src_shape0, -1));
shape_list.push_back(GetItem(params.src_shape1, -1));
if (!IsNonZeroShapeEqual(shape_list)) {
return not_this_pattern;
}
}
// repeat axis is shape [-3], repeat once, has 8 loops
bool is3_d = true;
float rate3d_mode1 = not_this_pattern;
float rate3d_mode2 = not_this_pattern;
int repeat_num;
float repeat_latency;
auto info_list = src_info_list;
Insert(info_list, 0, dst_info);
for (auto info : info_list) {
if (GetInt32Const(GetItem(info->shape_, -2)) > FULL_BLOCK_NUM ||
GetInt32Const(GetItem(info->strides_, -2)) / params.block_size >= MAX_STRIDE_M0 ||
GetInt32Const(GetItem(info->strides_, -3)) / params.block_size >= MAX_STRIDE_M0) {
is3_d = false;
break;
}
}
if (is3_d) {
if (GetInt32Const(GetItem(params.dst_strides, -2)) == 0) {
return not_this_pattern;
}
repeat_num = params.non_zero_shape3;
repeat_latency = ((repeat_num - 1) / MAX_REPEAT) * repeat_latency_coef;
rate3d_mode1 = static_cast<float>(params.all_points) / params.vec_max_len / (repeat_num + repeat_latency);
}
is3_d = true;
// repeat axis is shape[-2]
for (auto info : info_list) {
// stride_m0 should be less than 256
if (GetIntConst(GetItem(info->shape_, -3)) % FULL_BLOCK_NUM != 0 ||
GetIntConst(GetItem(info->strides_, -3)) / params.block_size >= MAX_STRIDE_M0) {
is3_d = false;
break;
}
}
if (is3_d) {
if (GetIntConst(GetItem(params.dst_strides, -3)) == 0) {
return not_this_pattern;
}
repeat_num = params.non_zero_shape2 * (params.non_zero_shape3 / FULL_BLOCK_NUM);
repeat_latency = ((repeat_num - 1) / MAX_REPEAT) * repeat_latency_coef;
float offset_latency =
params.non_zero_shape3 / FULL_BLOCK_NUM > 1 ? params.non_zero_shape3 * offset_latency_coef : 0;
rate3d_mode2 =
static_cast<float>(params.all_points) / params.vec_max_len / (repeat_num + repeat_latency + offset_latency);
}
return rate3d_mode1 > rate3d_mode2 ? rate3d_mode1 : rate3d_mode2;
}
float BinaryVecPatternGenerator::Compute2DBlockPatternMaskRate() {
if (params.non_zero_shape2 == 1 || GetIntConst(GetItem(params.dst_strides, -1)) != 1) {
return not_this_pattern;
}
if (params.dst_var.size() < 2 || GetIntConst(GetItem(params.dst_shape, -1)) > params.block_size ||
GetIntConst(GetItem(params.dst_strides, -2)) % params.block_size != 0 ||
GetIntConst(GetItem(params.dst_strides, -2)) / params.block_size >= MAX_STRIDE_M0 ||
(GetIntConst(GetItem(params.dst_strides, -2)) > 0 && GetIntConst(GetItem(params.dst_shape, -1)) > 0 &&
GetIntConst(GetItem(params.dst_strides, -2)) < GetIntConst(GetItem(params.dst_shape, -1)))) {
return not_this_pattern;
}
for (auto info : src_info_list) {
if (info->var_.size() < 2 || GetIntConst(GetItem(info->shape_, -1)) > params.block_size ||
GetIntConst(GetItem(info->strides_, -2)) % params.block_size != 0 ||
GetIntConst(GetItem(info->strides_, -2)) / params.block_size >= MAX_STRIDE_M0 ||
(GetIntConst(GetItem(info->strides_, -2)) > 0 && GetIntConst(GetItem(info->shape_, -1)) > 0 &&
GetIntConst(GetItem(info->strides_, -2)) < GetIntConst(GetItem(info->shape_, -1)))) {
return not_this_pattern;
}
}
if (GetIntConst(GetItem(params.dst_strides, -2)) == 0) {
return not_this_pattern;
}
if (mode == "reduction") {
if (params.dst_var.size() > 2) {
// if not elewise mode, then can not use partial 3D mode
if (GetIntConst(GetItem(params.dst_shape, -3)) == 0) {
return not_this_pattern;
}
for (auto info : src_info_list) {
if (GetIntConst(GetItem(info->shape_, -3)) == 0) {
return not_this_pattern;
}
}
}
// check same alignment
Array<Expr> shape_list = {GetItem(params.dst_shape, -1)};
shape_list.push_back(GetItem(params.src_shape0, -1));
shape_list.push_back(GetItem(params.src_shape1, -1));
// check dst_stride_m0
// As described in ISL User Guide t6.3,
// dst_stride_m0 = 0 is treated as 1
if (!IsNonZeroShapeEqual(shape_list)) {
return not_this_pattern;
}
}
int repeat_body_num = params.non_zero_shape2 / FULL_BLOCK_NUM;
int repeat_tail_num = (params.non_zero_shape2 % FULL_BLOCK_NUM + FULL_BLOCK_NUM - 1) / FULL_BLOCK_NUM;
int repeat_num = (repeat_body_num + repeat_tail_num) * params.non_zero_shape3;
float repeat_latency =
(std::max(repeat_body_num - 1, 0) / MAX_REPEAT + std::max(repeat_tail_num - 1, 0) / MAX_REPEAT) *
repeat_latency_coef;
float offset_latency = params.non_zero_shape3 > 1 ? params.non_zero_shape3 * offset_latency_coef : 0;
float split_latency = (repeat_body_num > 0 && repeat_tail_num > 0) ? split_latency_coef : 0;
float rate2db = static_cast<float>(params.all_points) / params.vec_max_len /
(repeat_num + repeat_latency + offset_latency + split_latency);
return rate2db;
}
float BinaryVecPatternGenerator::Compute2DPatternMaskRate() {
if (params.non_zero_shape2 == 1) {
return not_this_pattern;
}
if (params.dst_var.size() < 2 || GetIntConst(GetItem(params.dst_strides, -2)) % params.block_size != 0 ||
(GetIntConst(GetItem(params.dst_strides, -2)) < GetIntConst(GetItem(params.dst_shape, -1)) &&
GetIntConst(GetItem(params.dst_strides, -2) > 0))) {
return not_this_pattern;
}
for (auto info : src_info_list) {
if (info->var_.size() < 2 || GetIntConst(GetItem(info->strides_, -2)) % params.block_size != 0 ||
(GetIntConst(GetItem(info->strides_, -2)) < GetIntConst(GetItem(info->shape_, -1)) &&
GetIntConst(GetItem(info->strides_, -2) > 0))) {
return not_this_pattern;
}
}
// check num of insns, select 1D pattern or 2D pattern
int tail_factor = 0;
if (mode == "reduction") {
Array<Expr> shape_list = {GetItem(params.dst_shape, -1)};
shape_list.push_back(GetItem(params.src_shape0, -1));
shape_list.push_back(GetItem(params.src_shape1, -1));
if (!IsNonZeroShapeEqual(shape_list)) {
return not_this_pattern;
}
}
// only cloud allow dst_stride_m1 = 0
cceconf::CceConf *conf = cceconf::CceConf::getInstance();
const std::string product_name = conf->getProductName();
if (GetIntConst(GetItem(params.dst_strides, -2)) == 0 && product_name != "cloud") {
return not_this_pattern;
}
CHECK_NE(params.vec_max_len, 0);
if (params.non_zero_shape1 / params.vec_max_len > 0 && params.non_zero_shape1 % params.vec_max_len > 0) {
tail_factor = 1;
}
if (GetIntConst(GetItem(dst_info->strides_, -2)) / params.block_size >= MAX_STRIDE_M0) {
return not_this_pattern;
}
for (auto info : src_info_list) {
if (GetIntConst(GetItem(info->strides_, -2)) / params.block_size >= MAX_STRIDE_M0) {
return not_this_pattern;
}
}
int shape1 = (params.non_zero_shape1 + params.vec_max_len - 1) / params.vec_max_len;
int repeat_num = shape1 * params.non_zero_shape2 * params.non_zero_shape3;
float repeat_latency =
(std::max(params.non_zero_shape2 - 1, 0) / MAX_REPEAT) * params.non_zero_shape3 * shape1 * repeat_latency_coef;
float offset_latency =
shape1 * params.non_zero_shape3 > 1 ? shape1 * params.non_zero_shape3 * offset_latency_coef : 0;
float split_latency = tail_factor * split_latency_coef;
float rate2d = static_cast<float>(params.all_points) / params.vec_max_len /
(repeat_num + repeat_latency + offset_latency + split_latency);
return rate2d;
}
float BinaryVecPatternGenerator::Compute1DPatternMaskRate() {
int tail_factor = 0;
if (params.non_zero_shape1 / params.vec_max_len > 0 && params.non_zero_shape1 % params.vec_max_len > 0) {
tail_factor = 1;
}
int shape1 = (params.non_zero_shape1 + params.vec_max_len - 1) / params.vec_max_len;
int repeat_num = shape1 * params.non_zero_shape2 * params.non_zero_shape3;
float repeat_latency =
std::max((shape1 - 1) / MAX_REPEAT, 0) * params.non_zero_shape2 * params.non_zero_shape3 * repeat_latency_coef;
float offset_latency = params.non_zero_shape2 * params.non_zero_shape3 > 1
? params.non_zero_shape2 * params.non_zero_shape3 * offset_latency_coef
: 0;
float split_latency = tail_factor * split_latency_coef;
float rate1d = static_cast<float>(params.all_points) / params.vec_max_len /
(repeat_num + repeat_latency + offset_latency + split_latency);
return rate1d;
}
Array<Var> BinaryVecPatternGenerator::Get3DPattern() {
// repeat axis is shape [-2]
int second_last_shape = GetInt32Const(
GetNonZeroShape(GetItem(params.dst_shape, -2), GetItem(params.src_shape0, -2), GetItem(params.src_shape1, -2)));
int third_last_shape = GetInt32Const(
GetNonZeroShape(GetItem(params.dst_shape, -3), GetItem(params.src_shape0, -3), GetItem(params.src_shape1, -3)));
if (second_last_shape > 8) {
// split shape[-3]
if (third_last_shape > 8) {
auto info_list = src_info_list;
Insert(info_list, 0, dst_info);
SplitAxis(info_list, for_info, GetItem(params.dst_var, -3), FULL_BLOCK_NUM);
FillEmptyVar(info_list);
params.dst_var = info_list[0]->var_;
params.dst_shape = info_list[0]->shape_;
params.dst_strides = info_list[0]->strides_;
params.src_var0 = info_list[1]->var_;
params.src_shape0 = info_list[1]->shape_;
params.src_strides0 = info_list[1]->strides_;
params.src_var1 = info_list[2]->var_;
params.src_shape1 = info_list[2]->shape_;
params.src_strides1 = info_list[2]->strides_;
}
// consider original shape[-2] as repeat axis
GetShapeInfoAndSwap(params.dst_var, params.dst_shape, params.dst_strides, -2, -3);
GetShapeInfoAndSwap(params.src_var0, params.src_shape0, params.src_strides0, -2, -3);
GetShapeInfoAndSwap(params.src_var1, params.src_shape1, params.src_strides1, -2, -3);
}
body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
CHECK(body_args.GetNode());
body_args.GetNode()->body_num_ = 1;
body_args.GetNode()->repeat_ = GetItem(params.dst_shape, -3);
body_args.GetNode()->dst_stride_m0_ = truncdiv(GetItem(params.dst_strides, -2), params.block_size);
body_args.GetNode()->dst_stride_m1_ = truncdiv(GetItem(params.dst_strides, -3), params.block_size);
body_args.GetNode()->src_stride_m0_list_ = {truncdiv(GetItem(params.src_strides0, -2), params.block_size),
truncdiv(GetItem(params.src_strides1, -2), params.block_size)};
body_args.GetNode()->src_stride_m1_list_ = {truncdiv(GetItem(params.src_strides0, -3), params.block_size),
truncdiv(GetItem(params.src_strides1, -3), params.block_size)};
int data_num = GetInt32Const(GetItem(params.dst_shape, -2));
if (mode == "reduction") {
body_args.GetNode()->repeat_ = Expr(
GetNonZeroShape(GetItem(params.dst_shape, -3), GetItem(params.src_shape0, -3), GetItem(params.src_shape1, -3)));
data_num =
GetNonZeroShape(GetItem(params.dst_shape, -2), GetItem(params.src_shape0, -2), GetItem(params.src_shape1, -2));
}
int data_len = expand_mask ? CeilTo(params.last_dim_shape, params.block_size) : params.last_dim_shape;
body_args.GetNode()->vec_mask_ = GetVecMask(data_len, data_num, dst_info->dtype_);
return GetRange(params.dst_var, -3, 3);
}
Array<Var> BinaryVecPatternGenerator::Get2DBlockPattern() {
int repeat_len = GetInt32Const(GetItem(params.dst_shape, -2));
if (mode == "reduction") {
params.last_dim_shape =
GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape0, -1), GetItem(params.src_shape1, -1));
repeat_len =
GetNonZeroShape(GetItem(params.dst_shape, -2), GetItem(params.src_shape0, -2), GetItem(params.src_shape1, -2));
}
int repeat_body = repeat_len / FULL_BLOCK_NUM;
int repeat_tail = (repeat_len % FULL_BLOCK_NUM + FULL_BLOCK_NUM - 1) / FULL_BLOCK_NUM;
if (repeat_body > 0) {
body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
CHECK(body_args.GetNode() != nullptr);
body_args.GetNode()->body_num_ = 1;
body_args.GetNode()->repeat_ = Expr(repeat_body);
body_args.GetNode()->dst_stride_m0_ = truncdiv(GetItem(params.dst_strides, -2), params.block_size);
body_args.GetNode()->dst_stride_m1_ = body_args->dst_stride_m0_ * FULL_BLOCK_NUM;
Expr src0_stride_m0 = truncdiv(GetItem(params.src_strides0, -2), params.block_size);
Expr src1_stride_m0 = truncdiv(GetItem(params.src_strides1, -2), params.block_size);
body_args.GetNode()->src_stride_m0_list_ = {src0_stride_m0, src1_stride_m0};
body_args.GetNode()->src_stride_m1_list_ = {src0_stride_m0 * FULL_BLOCK_NUM, src1_stride_m0 * FULL_BLOCK_NUM};
int data_len = expand_mask ? CeilTo(params.last_dim_shape, params.block_size) : params.last_dim_shape;
body_args.GetNode()->vec_mask_ = GetVecMask(data_len, FULL_BLOCK_NUM, dst_info->dtype_);
}
if (repeat_tail > 0) {
tail_args = VectorArgInfo(make_node<VectorArgInfoNode>());
CHECK(tail_args.GetNode() != nullptr);
tail_args.GetNode()->dst_head_ = GetItem(params.dst_strides, -2) * repeat_body * FULL_BLOCK_NUM;
tail_args.GetNode()->src_head_list_ = {GetItem(params.src_strides0, -2) * repeat_body * FULL_BLOCK_NUM,
GetItem(params.src_strides1, -2) * repeat_body * FULL_BLOCK_NUM};
tail_args.GetNode()->repeat_ = Expr(1);
tail_args.GetNode()->dst_stride_m0_ = truncdiv(GetItem(params.dst_strides, -2), params.block_size);
tail_args.GetNode()->dst_stride_m1_ = Expr(0);
tail_args.GetNode()->src_stride_m0_list_ = {truncdiv(GetItem(params.src_strides0, -2), params.block_size),
truncdiv(GetItem(params.src_strides1, -2), params.block_size)};
tail_args.GetNode()->src_stride_m1_list_ = {Expr(0), Expr(0)};
int data_len = expand_mask ? CeilTo(params.last_dim_shape, params.block_size) : params.last_dim_shape;
tail_args.GetNode()->vec_mask_ = GetVecMask(data_len, repeat_len % FULL_BLOCK_NUM, dst_info->dtype_);
}
return GetRange(params.dst_var, -2, 2);
}
Array<Var> BinaryVecPatternGenerator::Get2DPattern() {
if (mode == "reduction") {
params.last_dim_shape =
GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape0, -1), GetItem(params.src_shape1, -1));
}
int body_len = FloorTo(params.last_dim_shape, params.vec_max_len);
int tail_len = params.last_dim_shape % params.vec_max_len;
if (body_len > 0) {
body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
CHECK(body_args.GetNode() != nullptr);
body_args.GetNode()->body_num_ = body_len / params.vec_max_len;
body_args.GetNode()->body_offset_ = params.vec_max_len;
body_args.GetNode()->repeat_ = GetItem(params.dst_shape, -2);
if (mode == "reduction") {
body_args.GetNode()->repeat_ =
GetNonZeroShape(GetItem(params.dst_shape, -2), GetItem(params.src_shape0, -2), GetItem(params.src_shape1, -2));
}
body_args.GetNode()->dst_stride_m0_ = Expr(1);
body_args.GetNode()->dst_stride_m1_ = truncdiv(GetItem(params.dst_strides, -2), params.block_size);
body_args.GetNode()->src_stride_m0_list_ = {Expr(1), Expr(1)};
body_args.GetNode()->src_stride_m1_list_ = {truncdiv(GetItem(params.src_strides0, -2), params.block_size),
truncdiv(GetItem(params.src_strides1, -2), params.block_size)};
body_args.GetNode()->vec_mask_ = GetVecMask(params.vec_max_len, 1, dst_info->dtype_);
}
if (tail_len > 0) {
tail_args = VectorArgInfo(make_node<VectorArgInfoNode>());
CHECK(tail_args.GetNode() != nullptr);
tail_args.GetNode()->dst_head_ = Expr(body_len);
tail_args.GetNode()->src_head_list_ = {Expr(body_len), Expr(body_len)};
tail_args.GetNode()->repeat_ = GetItem(params.dst_shape, -2);
if (mode == "reduction") {
tail_args.GetNode()->repeat_ =
GetNonZeroShape(GetItem(params.dst_shape, -2), GetItem(params.src_shape0, -2), GetItem(params.src_shape1, -2));
}
tail_args.GetNode()->dst_stride_m0_ = Expr(1);
tail_args.GetNode()->dst_stride_m1_ = truncdiv(GetItem(params.dst_strides, -2), params.block_size);
tail_args.GetNode()->src_stride_m0_list_ = {Expr(1), Expr(1)};
tail_args.GetNode()->src_stride_m1_list_ = {truncdiv(GetItem(params.src_strides0, -2), params.block_size),
truncdiv(GetItem(params.src_strides1, -2), params.block_size)};
tail_args.GetNode()->vec_mask_ = GetVecMask(tail_len, 1, dst_info->dtype_);
}
return GetRange(params.dst_var, -2, 2);
}
Array<Var> BinaryVecPatternGenerator::Get1DPattern() {
auto info_list = src_info_list;
Insert(info_list, 0, dst_info);
bool is_scalar_mode = IsScalarMode(info_list);
if (is_scalar_mode) {
params.last_dim_shape = 1;
}
if (mode == "reduction") {
params.last_dim_shape =
GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape0, -1), GetItem(params.src_shape1, -1));
}
int body_len = FloorTo(params.last_dim_shape, params.vec_max_len);
int tail_len = params.last_dim_shape % params.vec_max_len;
int last_axis = -1;
if (mode == "broadcast") {
if (GetIntConst(GetItem(params.src_strides0, -1)) == 0) {
last_axis = 0;
}
if (GetIntConst(GetItem(params.src_strides1, -1)) == 0) {
last_axis = 1;
}
}
if (body_len > 0) {
body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
CHECK(body_args.GetNode() != nullptr);
body_args.GetNode()->last_axis_info_.src_index_ = last_axis;
body_args.GetNode()->body_num_ = 1;
body_args.GetNode()->repeat_ = body_len / params.vec_max_len;
body_args.GetNode()->dst_stride_m0_ = Expr(1);
body_args.GetNode()->dst_stride_m1_ = Expr(FULL_BLOCK_NUM);
body_args.GetNode()->src_stride_m0_list_ = {Expr(1), Expr(1)};
body_args.GetNode()->src_stride_m1_list_ = {Expr(FULL_BLOCK_NUM), Expr(FULL_BLOCK_NUM)};
body_args.GetNode()->vec_mask_ = GetVecMask(params.vec_max_len, 1, dst_info->dtype_);
}
if (tail_len > 0) {
tail_args = VectorArgInfo(make_node<VectorArgInfoNode>());
CHECK(tail_args.GetNode() != nullptr);
tail_args.GetNode()->last_axis_info_.src_index_ = last_axis;
tail_args.GetNode()->dst_head_ = Expr(body_len);
tail_args.GetNode()->src_head_list_ = {Expr(body_len), Expr(body_len)};
tail_args.GetNode()->repeat_ = Expr(1);
tail_args.GetNode()->dst_stride_m0_ = Expr(1);
tail_args.GetNode()->dst_stride_m1_ = Expr(0);
tail_args.GetNode()->src_stride_m0_list_ = {Expr(1), Expr(1)};
tail_args.GetNode()->src_stride_m1_list_ = {Expr(0), Expr(0)};
int data_len = expand_mask ? CeilTo(tail_len, params.block_size) : tail_len;
tail_args.GetNode()->vec_mask_ = GetVecMask(data_len, 1, dst_info->dtype_);
}
// compute offset for cce instructions
Array<Var> elim_var = {};
if (mode == "elewise" && params.dst_var.size() >= 2 && params.dst_strides.size() >= 2 &&
params.last_dim_shape <= params.vec_max_len && for_info.ops_.size() >= 2 &&
params.last_dim_shape >= params.vec_max_len - params.block_size &&
GetIntConst(GetItem(params.dst_strides, -2)) == params.last_dim_shape) {
// in this case we can merge second last for extent to repeat
size_t idx = 0;
bool suc = GetIndexOfElement(for_info.vars_, GetItem(params.dst_var, -2), idx);
CHECK(suc);
auto latest_for = GetItem(for_info.ops_, idx).as<For>();
// there should not be if_op between for loop and compute stmt
if (latest_for && !latest_for->body->IsInstance<IfThenElse>()) {
if (!params.dst_var.empty() && !is_scalar_mode) {
if (body_args.defined()) {
// last_dim_shape = vec_max_len
body_args.GetNode()->repeat_ = body_args->repeat_ * latest_for->extent;
} else if (tail_args.defined()) {
// last_dim_shape < vec_max_len
tail_args.GetNode()->repeat_ = tail_args->repeat_ * latest_for->extent;
}
return elim_var = GetRange(params.dst_var, -2, 2);
}
}
}
if (!params.dst_var.empty() && !is_scalar_mode) {
elim_var = GetRange(params.dst_var, -1, 1);
}
return elim_var;
}
PatternResult BinaryVecPatternGenerator::GenResult(const Array<Var> &elim_var) {
arg_info.GetNode()->body_arg_info_ = body_args;
arg_info.GetNode()->tail_arg_info_ = tail_args;
auto real_elim_var = elim_var;
if (!empty_var->name_hint.empty()) {
bool need_elim = true;
for (auto e : elim_var) {
if (e->name_hint == empty_var->name_hint) {
need_elim = false;
break;
}
}
if (need_elim) {
real_elim_var.push_back(empty_var);
}
}
dst_info.GetNode()->insn_offset_ = GetInsnOffset(dst_info, real_elim_var);
for (auto &info : src_info_list) {
info.GetNode()->insn_offset_ = GetInsnOffset(info, real_elim_var);
}
CleanForInfoVars(for_info, real_elim_var);
CleanZeroStrides(dst_info);
CleanZeroStrides(src_info_list);
if (mode == "elewise") {
arg_info.GetNode()->arg_type_ = ARG_VECTOR_ELEWISE;
} else if (mode == "broadcast") {
arg_info.GetNode()->arg_type_ = ARG_VECTOR_BROADCAST;
} else if (mode == "reduction") {
arg_info.GetNode()->arg_type_ = ARG_VECTOR_REDUCTION;
}
PatternResult result;
result.dst_info_list = {dst_info};
result.src_info_list = src_info_list;
result.for_info = for_info;
result.arg_info = arg_info;
return result;
}
void BinaryVecPatternGenerator::CalcParams() {
CHECK_GE(src_info_list.size(), 2);
StmtStoreInfo src_info0 = src_info_list[0];
StmtStoreInfo src_info1 = src_info_list[1];
StmtInfoList info_list = {dst_info, src_info0, src_info1};
// check shape len
for (auto info : info_list) {
if (info->shape_.empty()) {
LOG(FATAL) << "CCE Vector Insn Error: dst_buffer and src_buffer can not be scalar, should keep len(shape) > 0.";
}
}
// check data type
for (auto src_info : src_info_list) {
if (dst_info->dtype_ != src_info->dtype_) {
LOG(FATAL) << "CCE Vector Insn Error: dst_buffer and src_buffer can not be different data type.";
}
}
params.last_dim_shape = GetInt32Const(GetItem(dst_info->shape_, -1));
AppendEmptyVar(info_list);
if (arg_info->arg_type_ == ARG_VECTOR_BROADCAST_LAST_AXIS) {
return;
}
if (mode == "reduction" || mode == "broadcast") {
FillEmptyVar(info_list);
}
CHECK_EQ(info_list.size(), 3);
dst_info = info_list[0];
src_info0 = info_list[1];
src_info1 = info_list[2];
params.vec_max_len = GetVecMaxLen(dst_info->dtype_);
params.block_size = GetUbBlkSize(dst_info->dtype_);
CHECK_NE(params.vec_max_len, 0);
CHECK_NE(params.block_size, 0);
params.dst_var = dst_info->var_;
params.dst_shape = dst_info->shape_;
params.dst_strides = dst_info->strides_;
params.src_var0 = src_info0->var_;
params.src_var1 = src_info1->var_;
params.src_shape0 = src_info0->shape_;
params.src_shape1 = src_info1->shape_;
params.src_strides0 = src_info0->strides_;
params.src_strides1 = src_info1->strides_;
auto GetNonZeroShapeByIdx = [this](int index) -> int {
if (index <= static_cast<int>(params.dst_var.size())) {
if (Equal(GetItem(params.dst_var, -index), GetItem(params.src_var0, -index)) &&
Equal(GetItem(params.dst_var, -index), GetItem(params.src_var1, -index))) {
return GetNonZeroShape(GetItem(params.dst_shape, -index), GetItem(params.src_shape0, -index),
GetItem(params.src_shape1, -index));
}
}
return 1;
};
params.non_zero_shape1 = GetNonZeroShapeByIdx(1);
params.non_zero_shape2 = GetNonZeroShapeByIdx(2);
params.non_zero_shape3 = GetNonZeroShapeByIdx(3);
params.all_points = params.non_zero_shape1 * params.non_zero_shape2 * params.non_zero_shape3;
}
bool BinaryVecPatternGenerator::IsSamePatternComInfo(const StmtStoreInfo &info_a, const StmtStoreInfo &info_b) {
if (IsSame(info_a->var_, info_b->var_)) {
if (info_a->shape_.size() != info_b->shape_.size()) {
return false;
}
for (size_t i = 0; i < info_a->shape_.size(); ++i) {
if (!IsTwoItemEqual(info_a->shape_, info_b->shape_, static_cast<int>(i), true)) {
return false;
}
}
if (info_a->strides_.size() != info_b->strides_.size()) {
return false;
}
for (size_t i = 0; i < info_a->strides_.size(); ++i) {
if (!IsTwoItemEqual(info_a->strides_, info_b->strides_, static_cast<int>(i), true)) {
return false;
}
}
return true;
}
return false;
}
bool BinaryVecPatternGenerator::IsNonZeroShapeEqual(const Array<Expr> &shape_list) {
Array<Expr> non_zero_list;
for (auto shape : shape_list) {
if (GetIntConst(shape) != 0) {
non_zero_list.push_back(shape);
}
}
if (non_zero_list.empty()) {
LOG(FATAL) << "Error: all shapes are equal to 0.";
}
for (auto shape : non_zero_list) {
if (GetIntConst(shape) != GetIntConst(non_zero_list[0])) {
return false;
}
}
return true;
}
void BinaryVecPatternGenerator::AppendEmptyVar(StmtInfoList &info_list) {
auto FillEmptyVarToLast = [](const StmtStoreInfo com_info, const Var &var) -> void {
com_info.GetNode()->var_.push_back(var);
com_info.GetNode()->shape_.push_back(Expr(1));
com_info.GetNode()->strides_.push_back(Expr(1));
com_info.GetNode()->index_ = com_info->index_ + GetItem(com_info->var_, -1);
};
auto src_info0 = src_info_list[0];
auto src_info1 = src_info_list[1];
if (mode == "reduction" || mode == "broadcast") {
// ISA 8.1.2, strides of Xd must be equal to Xm, [Xd = dst, Xn = src0, Xm = src1]
if (IsSamePatternComInfo(dst_info, src_info0)) {
auto tmp = src_info0;
src_info0 = src_info1;
src_info1 = tmp;
}
if (mode == "reduction") {
if (src_info0->data_alignment_ == 1) {
empty_var = Var("empty_cc");
FillEmptyVarToLast(src_info0, empty_var);
}
} else if (mode == "broadcast") {
// last dim broadcast, should use VS insn, such as vadds and vmuls
bool less_var =
!dst_info->var_.empty() && !src_info0->var_.empty() && !src_info1->var_.empty() &&
(!IsTwoItemEqual(dst_info->var_, src_info0->var_, -1) || !IsTwoItemEqual(dst_info->var_, src_info1->var_, -1));
bool null_var = src_info0->var_.empty() || src_info1->var_.empty();
if (less_var || null_var) {
arg_info.GetNode()->arg_type_ = ARG_VECTOR_BROADCAST_LAST_AXIS;
return;
} else if (dst_info->data_alignment_ == 1 && src_info0->data_alignment_ == 1) {
empty_var = Var("empty_cc");
FillEmptyVarToLast(dst_info, empty_var);
FillEmptyVarToLast(src_info0, empty_var);
FillEmptyVarToLast(src_info1, empty_var);
params.last_dim_shape = 1;
}
}
src_info_list = {src_info0, src_info1};
info_list = {dst_info, src_info0, src_info1};
}
}
/// Get CCE Binary Vector Insn Computation Info
/// \param stmt - operand stmt
/// \param intrin_name - vector intrin name
/// \param dst_info_list - dst computation info list
/// \param src_info_list - src computation info list
/// \param if_info - if info list
/// \param for_info - for info list
/// \return intrin args
ArgInfo GetBinaryVecInsnArgs(const Stmt &stmt, std::string intrin_name, StmtInfoList &dst_info_list,
StmtInfoList &src_info_list, StmtInfo &if_info, StmtInfo &for_info, bool enable_bisect) {
// check intrin_name
std::set<std::string> intrin_name_list = {"vadd", "vmax", "vmin", "vmul", "vdiv", "vsel", "vsub", "vand",
"vor", "vaxpy", "argmax", "argmin", "vmadd", "vmaddrelu", "vmla"};
if (intrin_name_list.count(intrin_name) == 0) {
LOG(FATAL) << "Error: CCE Binary Vector Insn doesn't support the given intrin_name.";
}
// get and check dst and src
GetCompactComputationInfo(stmt, dst_info_list, src_info_list, if_info, for_info, true);
// For vmadd/vmaddrelu/vmla we only need first two src
if (dst_info_list.size() != 1 || src_info_list.size() < 2) {
LOG(FATAL) << "CCE Binary Vector Insn only support ONE dst and TWO srcs.";
}
src_info_list = GetRange(src_info_list, 0, 2);
ArgInfo arg_info = ArgInfo(make_node<ArgInfoNode>());
// detect vector op mode
std::string mode = GetBinaryVecMode(dst_info_list, src_info_list, intrin_name, enable_bisect);
if (mode == "reduce_last_axis") {
size_t src_var_list_size = src_info_list[1]->var_.size();
if (src_info_list[0]->var_.size() > src_info_list[1]->var_.size()) {
src_var_list_size = src_info_list[0]->var_.size();
}
CHECK(src_var_list_size > 0) << "Error: src can not be a scalar.";
if (src_var_list_size - dst_info_list[0]->var_.size() == 1) {
arg_info.GetNode()->arg_type_ = ARG_VECTOR_REDUCTION_LAST_AXIS;
} else {
LOG(FATAL) << "Error: cannot support multi-last-axis reduction.";
}
} else if (mode == "reduce_bisection") {
arg_info.GetNode()->arg_type_ = ARG_VECTOR_REDUCTION_BISECTION;
} else {
if (mode != "reduction" && mode != "broadcast") {
FillLastDim(dst_info_list, src_info_list, for_info);
}
// vmax/vmin can't expand mask because it may introduce dirty data
bool can_expand_mask = intrin_name != "vmax" && intrin_name != "vmin";
BinaryVecPatternGenerator generator =
BinaryVecPatternGenerator(dst_info_list, src_info_list, for_info, mode, can_expand_mask);
auto params = generator.GetInsnArgs();
arg_info = params.arg_info;
dst_info_list = params.dst_info_list;
src_info_list = params.src_info_list;
for_info = params.for_info;
if (mode == "broadcast") {
bool has_last_axis = false;
if ((arg_info->body_arg_info_.defined() && arg_info->body_arg_info_->last_axis_info_.src_index_ != -1) ||
(arg_info->tail_arg_info_.defined() && arg_info->tail_arg_info_->last_axis_info_.src_index_ != -1)) {
has_last_axis = true;
}
if (has_last_axis && (intrin_name == "vadd" || intrin_name == "vmul")) {
Array<NodeRef> stores;
Array<NodeRef> loads;
GetStoreAndLoads(stmt, stores, loads);
intrin_name = intrin_name + "s";
if (arg_info->body_arg_info_.defined()) {
arg_info.GetNode()->body_arg_info_.GetNode()->last_axis_info_.intrin_name_ = intrin_name;
arg_info.GetNode()->body_arg_info_.GetNode()->last_axis_info_.src_op_ =
Downcast<Expr>(loads[arg_info->body_arg_info_->last_axis_info_.src_index_]);
}
}
}
}
return arg_info;
}
/// Replace com_info's var with new for loop's var
/// \param info
/// \param old_for_info
/// \param new_for_info
void ReplaceVarWithNewForInfo(StmtStoreInfo &info, const StmtInfo &old_for_info, const StmtInfo &new_for_info) {
for (size_t i = 0; i < new_for_info.vars_.size(); ++i) {
for (size_t j = 0; j < info->var_.size(); ++j) {
if (info->var_[j]->name_hint == new_for_info.vars_[i]->name_hint) {
SetItem(info.GetNode()->var_, static_cast<int>(j), new_for_info.vars_[i]);
}
}
info.GetNode()->index_ = substitute(old_for_info.vars_[i], new_for_info.vars_[i], info->index_);
}
}
/// Generete info list for bisection intrin
/// \param dst_info_list
/// \param src_info_list
/// \param for_info
/// \param if_info
/// \param last_axis
/// \param postfix
/// \return
BisectionInfoWrapper SeparateComInfoToBisectionInfoList(const StmtInfoList &dst_info_list,
const StmtInfoList &src_info_list, const StmtInfo &for_info,
StmtInfo &if_info, bool last_axis, int postfix = 0) {
CHECK_EQ(dst_info_list.size(), 1);
CHECK_EQ(src_info_list.size(), 2);
BisectionInfoWrapper wrapper;
// Separate com_info and for_info
int compare_idx = 1;
int var_idx = -1;
if (last_axis) {
compare_idx = GetLastAxisReductionIdx(dst_info_list, src_info_list);
} else {
var_idx = GetBisectionReductionIdx(dst_info_list, src_info_list, compare_idx);
}
StmtStoreInfo dst_info = dst_info_list[0];
CHECK_GE(compare_idx, 0);
StmtStoreInfo src_info1 = src_info_list[compare_idx];
Var reduce_var = GetItem(src_info1->var_, var_idx);
size_t for_idx = 0;
bool suc = GetIndexOfElement(for_info.vars_, VarExpr(reduce_var), for_idx);
CHECK(suc);
auto exist_for = GetItem(for_info.ops_, for_idx).as<For>();
CHECK(exist_for);
int extent = GetInt32Const(exist_for->extent);
std::string prev_name = src_info1->name_;
Var prev_var = src_info1->data_;
Buffer prev_buffer = src_info1->buffer_;
Var bisec_var;
Buffer bisec_buffer;
std::string bisec_pre_header = last_axis ? "bisec_last_axis" : "bisec";
std::string bisec_name = bisec_pre_header + "_local_UB";
if (postfix > 0) {
bisec_name = bisec_name + "_" + std::to_string(postfix);
}
bool first_round = true;
int vec_max_len = GetVecMaxLen(dst_info->dtype_);
int remain_extent = extent;
int left_extent = 0;
CHECK_NE(vec_max_len, 0);
std::vector<int> pow2_list = {0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536};
while (extent > 0) {
int for_extent;
if (last_axis) {
left_extent = remain_extent / 2 + remain_extent % 2;
for (int i : pow2_list) {
if (left_extent == i) {
break;
} else if (left_extent < i) {
left_extent = i;
break;
}
}
if (left_extent < vec_max_len) {
// When left_extent < vec_max_len, stop bisect and generate normal reduce intrin
left_extent = remain_extent;
}
extent = remain_extent - left_extent;
remain_extent = left_extent;
for_extent = extent == 0 ? vec_max_len : extent;
} else {
for_extent = extent == 1 ? extent : extent / 2;
extent = extent % 2 == 0 || extent == 1 ? extent / 2 : (extent + 1) / 2;
for (int i : pow2_list) {
if (extent == i) {
break;
} else if (extent < i) {
int gap = i - extent;
extent = i;
for_extent -= gap;
break;
}
}
}
StmtStoreInfo dst_tmp_info = dst_info.Copy();
StmtStoreInfo src_tmp_info0{src_info1.Copy()};
StmtStoreInfo src_tmp_info1{src_info1.Copy()};
if (first_round) {
auto shape = src_tmp_info1->shape_;
if (last_axis) {
int block_size = GetUbBlkSize(dst_info->dtype_);
SetItem(shape, -1, Expr(CeilTo(GetIntConst(GetItem(shape, -1)), block_size)));
}
wrapper.original_shape_ = shape;
bisec_var = Var(bisec_name, Handle());
bisec_buffer = BufferNode::make(bisec_var, dst_tmp_info->dtype_, shape, Array<Expr>(), Expr(), bisec_name,
SCOPE_UBUF, 0, 0, BufferType::kDefault);
if ((last_axis && extent != left_extent) || (!last_axis && extent != for_extent)) {
// Need to copy input to bisect buffer
StmtStoreInfo copy_dst_info{src_info1.Copy()};
StmtStoreInfo copy_src_info{src_info1.Copy()};
StmtInfoList src_list = {copy_src_info};
auto for_tmp_info = for_info.Copy();
auto new_for = GetItem(for_tmp_info.ops_, for_idx).as<For>();
CHECK(new_for);
SetItem(for_tmp_info.ops_, static_cast<int>(for_idx),
For::make(new_for->loop_var, new_for->min, last_axis ? left_extent : extent, new_for->for_type,
new_for->device_api, new_for->body));
ReplaceVarWithNewForInfo(copy_dst_info, for_info, for_tmp_info);
ReplaceVarWithNewForInfo(copy_src_info, for_info, for_tmp_info);
SetItem(copy_dst_info.GetNode()->shape_, var_idx, Expr(last_axis ? left_extent : extent));
SetItem(copy_src_info.GetNode()->shape_, var_idx, Expr(last_axis ? left_extent : extent));
CompactComputationInfoList(copy_dst_info, src_list, if_info, for_tmp_info);
copy_dst_info.GetNode()->name_ = bisec_name;
copy_dst_info.GetNode()->buffer_ = bisec_buffer;
copy_dst_info.GetNode()->data_ = bisec_var;
// Replace outside for variable in index
auto vars = GetVarsInExpr(copy_dst_info->index_);
for (auto var : vars) {
if (!IsInArray(copy_dst_info->var_, var)) {
copy_dst_info.GetNode()->index_ = substitute(var, Expr(0), copy_dst_info->index_);
}
}
wrapper.bisec_info_list_.emplace_back(StmtInfoList{copy_dst_info, copy_src_info});
wrapper.for_info_list_.push_back(for_tmp_info);
}
}
auto for_tmp_info = for_info.Copy();
auto new_for = GetItem(for_tmp_info.ops_, for_idx).as<For>();
CHECK(new_for);
SetItem(
for_tmp_info.ops_, static_cast<int>(for_idx),
For::make(new_for->loop_var, new_for->min, for_extent, new_for->for_type, new_for->device_api, new_for->body));
ReplaceVarWithNewForInfo(dst_tmp_info, for_info, for_tmp_info);
ReplaceVarWithNewForInfo(src_tmp_info0, for_info, for_tmp_info);
ReplaceVarWithNewForInfo(src_tmp_info1, for_info, for_tmp_info);
SetItem(src_tmp_info0.GetNode()->shape_, var_idx, Expr(for_extent));
SetItem(src_tmp_info1.GetNode()->shape_, var_idx, Expr(for_extent));
if (extent > 0) {
dst_tmp_info.GetNode()->shape_ = src_tmp_info1->shape_;
dst_tmp_info.GetNode()->strides_ = src_tmp_info1->strides_;
dst_tmp_info.GetNode()->var_ = src_tmp_info1->var_;
dst_tmp_info.GetNode()->index_ = src_tmp_info1->index_;
dst_tmp_info.GetNode()->data_alignment_ = src_tmp_info1->data_alignment_;
dst_tmp_info.GetNode()->name_ = bisec_name;
dst_tmp_info.GetNode()->buffer_ = bisec_buffer;
dst_tmp_info.GetNode()->data_ = bisec_var;
auto src_extent = Expr(left_extent);
if (!last_axis) {
src_extent = GetItem(src_tmp_info1->strides_, var_idx) * extent;
}
src_tmp_info1.GetNode()->index_ = src_tmp_info1->index_ + src_extent;
}
src_tmp_info0.GetNode()->name_ = prev_name;
src_tmp_info1.GetNode()->name_ = prev_name;
src_tmp_info0.GetNode()->buffer_ = prev_buffer;
src_tmp_info1.GetNode()->buffer_ = prev_buffer;
src_tmp_info0.GetNode()->data_ = prev_var;
src_tmp_info1.GetNode()->data_ = prev_var;
// Replace outside for variable in index
for (auto &info : {dst_tmp_info, src_tmp_info0, src_tmp_info1}) {
if (info->name_.find(bisec_pre_header) == std::string::npos) {
continue;
}
auto vars = GetVarsInExpr(info->index_);
for (auto var : vars) {
if (!IsInArray(info->var_, var)) {
info.GetNode()->index_ = substitute(var, Expr(0), info->index_);
}
}
}
prev_name = bisec_name;
prev_var = bisec_var;
prev_buffer = bisec_buffer;
StmtInfoList src_list = {src_tmp_info0, src_tmp_info1};
CompactComputationInfoList(dst_tmp_info, src_list, if_info, for_tmp_info);
wrapper.for_info_list_.emplace_back(for_tmp_info);
if (extent == 0) {
// last round should be dst = dst + src_tmp1
wrapper.bisec_info_list_.emplace_back(StmtInfoList{dst_tmp_info, dst_tmp_info, src_tmp_info1});
} else {
// normally is dst_tmp = src_tmp0 + src_tmp1
wrapper.bisec_info_list_.emplace_back(StmtInfoList{dst_tmp_info, src_tmp_info0, src_tmp_info1});
}
first_round = false;
}
// Generate arg_info
for (size_t i = 0; i < wrapper.bisec_info_list_.size(); ++i) {
auto info_list = wrapper.bisec_info_list_[i];
auto new_for_info = wrapper.for_info_list_[i];
ArgInfo arg_info;
auto dst_list = GetRange(info_list, 0, 1);
auto src_list = GetRange(info_list, 1, info_list.size() - 1);
if (info_list.size() == 2) {
std::string dma_intrin = INTRIN_NAME_COPY_UB_TO_UB;
wrapper.dma_arg_info_map_ = GetDmaCopyInsnArgs(dma_intrin, dst_list, src_list, new_for_info);
} else if (last_axis && i == wrapper.bisec_info_list_.size() - 1) {
auto dst_tmp_info = dst_list[0];
auto src_tmp_info = src_list[1];
ReduceLastAxisPatternGenerator generator =
ReduceLastAxisPatternGenerator(dst_tmp_info, src_tmp_info, new_for_info, "vadd");
auto result = generator.GetInsnArgs();
arg_info = result.arg_info;
dst_tmp_info = result.dst_info_list[0];
src_tmp_info = result.src_info_list[0];
new_for_info = result.for_info;
wrapper.bisec_info_list_[i] = {dst_tmp_info, dst_tmp_info, src_tmp_info};
} else {
// Bisect can't expand mask because it has inplace operation
if (i != wrapper.bisec_info_list_.size() - 1) {
// Last round dont need to add
FillLastDim(dst_list, src_list, new_for_info);
}
std::string mode = GetBinaryVecMode(dst_list, src_list, "vadd", false);
BinaryVecPatternGenerator generator = BinaryVecPatternGenerator(dst_list, src_list, new_for_info, mode, false);
auto params = generator.GetInsnArgs();
arg_info = params.arg_info;
dst_list = params.dst_info_list;
src_list = params.src_info_list;
new_for_info = params.for_info;
wrapper.bisec_info_list_[i] = {dst_list[0], src_list[0], src_list[1]};
}
wrapper.arg_info_list_.push_back(arg_info);
wrapper.for_info_list_[i] = new_for_info;
}
return wrapper;
}
} // namespace akg
...@@ -35,7 +35,7 @@ ...@@ -35,7 +35,7 @@
#include "insn_info.h" #include "insn_info.h"
#include "insn_pattern.h" #include "insn_pattern.h"
#include "insn_emitter_multimask.h" #include "insn_emitter_multimask.h"
#include "insn_args_calculator.h"
namespace akg { namespace akg {
namespace ir { namespace ir {
/// Sort indexes /// Sort indexes
...@@ -71,8 +71,7 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) { ...@@ -71,8 +71,7 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) {
Array<Expr> call_args; Array<Expr> call_args;
int call_cnt = 0; int call_cnt = 0;
if (intrin_name == "vector_dup" || intrin_name == "vadds" || if (intrin_name == "vector_dup" || intrin_name == "vadds" || intrin_name == "vmuls" || intrin_name == "vaxpy") {
intrin_name == "vmuls" || intrin_name == "vaxpy") {
auto GetCallInfo = [&intrin_name, &call_args, &call_cnt](const NodeRef &op) { auto GetCallInfo = [&intrin_name, &call_args, &call_cnt](const NodeRef &op) {
if (op.as<Call>() && op.as<Call>()->name == intrin_name) { if (op.as<Call>() && op.as<Call>()->name == intrin_name) {
call_args = op.as<Call>()->args; call_args = op.as<Call>()->args;
...@@ -82,8 +81,8 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) { ...@@ -82,8 +81,8 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) {
PostOrderVisit(op, GetCallInfo); PostOrderVisit(op, GetCallInfo);
CHECK_EQ(call_cnt, 1); CHECK_EQ(call_cnt, 1);
} }
SingleType insn_type {SingleType::SIMD}; SingleType insn_type{SingleType::SIMD};
Expr scalar_src {}; Expr scalar_src{};
if (intrin_name == "vector_dup") { if (intrin_name == "vector_dup") {
insn_type = SingleType::Vector_Dump; insn_type = SingleType::Vector_Dump;
src_info_list = {}; src_info_list = {};
...@@ -93,10 +92,11 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) { ...@@ -93,10 +92,11 @@ Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) {
src_info_list = {src_info_list[0]}; src_info_list = {src_info_list[0]};
scalar_src = call_args[1]; scalar_src = call_args[1];
} }
// check is single vector broadcast reduce mode exist // check is single vector broadcast reduce mode exist
SingleVecPatternGenerator generator = SingleVecPatternGenerator(dst_info_list, src_info_list, for_info);
auto params = generator.GetInsnArgs(); SingleVecInsnArgsCalculator args_calculator = SingleVecInsnArgsCalculator(dst_info_list, src_info_list, for_info, intrin_name);
PatternResult params = args_calculator.GetInsnArgs();
dst_info_list = params.dst_info_list; dst_info_list = params.dst_info_list;
src_info_list = params.src_info_list; src_info_list = params.src_info_list;
for_info = params.for_info; for_info = params.for_info;
...@@ -141,24 +141,17 @@ Stmt BinaryVecEmitter(const Stmt &op, std::string intrin_name, bool enable_bisec ...@@ -141,24 +141,17 @@ Stmt BinaryVecEmitter(const Stmt &op, std::string intrin_name, bool enable_bisec
if (src_info_list[0]->var_.size() > src_info_list[1]->var_.size()) { if (src_info_list[0]->var_.size() > src_info_list[1]->var_.size()) {
src_info = src_info_list[0]; src_info = src_info_list[0];
} }
const int vec_max_len = GetVecMaxLen(dst_info->dtype_);
if (enable_bisect && GetIntConst(GetItem(src_info->shape_, -1)) > vec_max_len) {
CommentManager::GetInstance().AddComment("Bisect_optimize", "enabled");
auto wrapper =
SeparateComInfoToBisectionInfoList(dst_info_list, src_info_list, for_info, if_info, true, postfix);
return EmitCceBinaryVectorToBisectionReduction(wrapper, if_info, intrin_name);
} else {
CommentManager::GetInstance().AddComment("Pattern", arg_info.GetPattern()); CommentManager::GetInstance().AddComment("Pattern", arg_info.GetPattern());
ReduceLastAxisPatternGenerator generator =
ReduceLastAxisPatternGenerator(dst_info, src_info, for_info, intrin_name); LastAxisReduceInsnArgsCalculator args_calculator = LastAxisReduceInsnArgsCalculator(dst_info, src_info, for_info, intrin_name);
auto result = generator.GetInsnArgs(); PatternResult result = args_calculator.GetInsnArgs();
arg_info = result.arg_info; arg_info = result.arg_info;
dst_info = result.dst_info_list[0]; dst_info = result.dst_info_list[0];
src_info = result.src_info_list[0]; src_info = result.src_info_list[0];
for_info = result.for_info; for_info = result.for_info;
return EmitCceBinaryVectorToReduceLastAxis(dst_info, src_info, if_info, for_info, arg_info, intrin_name); return EmitCceBinaryVectorToReduceLastAxis(dst_info, src_info, if_info, for_info, arg_info, intrin_name);
} }
}
case ARG_VECTOR_REDUCTION_BISECTION: { case ARG_VECTOR_REDUCTION_BISECTION: {
CommentManager::GetInstance().AddComment("Compute_type", "reduction"); CommentManager::GetInstance().AddComment("Compute_type", "reduction");
CommentManager::GetInstance().AddComment("Bisect_optimize", "enabled"); CommentManager::GetInstance().AddComment("Bisect_optimize", "enabled");
...@@ -192,7 +185,7 @@ Stmt BinaryVecEmitter(const Stmt &op, std::string intrin_name, bool enable_bisec ...@@ -192,7 +185,7 @@ Stmt BinaryVecEmitter(const Stmt &op, std::string intrin_name, bool enable_bisec
return FoldInsnWithForInfo(insn_list, if_info, for_info, stmt); return FoldInsnWithForInfo(insn_list, if_info, for_info, stmt);
} }
} }
} } // namespace ir
/// Function to emit scalar intrin /// Function to emit scalar intrin
/// \param op - The input stmt to be emitted as intrin /// \param op - The input stmt to be emitted as intrin
...@@ -984,8 +977,9 @@ Stmt BinaryDropoutEmitter(const Stmt &op) { ...@@ -984,8 +977,9 @@ Stmt BinaryDropoutEmitter(const Stmt &op) {
src1.GetNode()->data_ = mask->buffer_var; src1.GetNode()->data_ = mask->buffer_var;
src1.GetNode()->data_alignment_ = GetInt32Const(mask->predicate); src1.GetNode()->data_alignment_ = GetInt32Const(mask->predicate);
SingleVecPatternGenerator generator = SingleVecPatternGenerator(dst_info_list, src_info_list, for_info, "elewise"); SingleVecInsnArgsCalculator args_calculator = SingleVecInsnArgsCalculator(dst_info_list, src_info_list, for_info);
auto params = generator.GetInsnArgs(); PatternResult params = args_calculator.GetInsnArgs();
dst_info_list = params.dst_info_list; dst_info_list = params.dst_info_list;
src_info_list = params.src_info_list; src_info_list = params.src_info_list;
for_info = params.for_info; for_info = params.for_info;
...@@ -1484,8 +1478,10 @@ Stmt BinaryArgOpEmitter(const Stmt &op, const std::string &intrin_name) { ...@@ -1484,8 +1478,10 @@ Stmt BinaryArgOpEmitter(const Stmt &op, const std::string &intrin_name) {
if (src_info_list[0]->var_.size() > src_info_list[1]->var_.size()) { if (src_info_list[0]->var_.size() > src_info_list[1]->var_.size()) {
src_info = src_info_list[0]; src_info = src_info_list[0];
} }
ReduceLastAxisPatternGenerator generator = ReduceLastAxisPatternGenerator(dst_info, src_info, for_info, intrin_name);
auto result = generator.GetInsnArgs(); LastAxisReduceInsnArgsCalculator args_calculator = LastAxisReduceInsnArgsCalculator(dst_info, src_info, for_info, intrin_name);
PatternResult result = args_calculator.GetInsnArgs();
arg_info = result.arg_info; arg_info = result.arg_info;
dst_info = result.dst_info_list[0]; dst_info = result.dst_info_list[0];
src_info = result.src_info_list[0]; src_info = result.src_info_list[0];
......
...@@ -104,10 +104,7 @@ StmtStoreInfo StmtStoreInfo::Copy() const { ...@@ -104,10 +104,7 @@ StmtStoreInfo StmtStoreInfo::Copy() const {
StmtInfo StmtInfo::Copy() const { StmtInfo StmtInfo::Copy() const {
auto stmt_info = StmtInfo(); auto stmt_info = StmtInfo();
stmt_info.ops_ = ops_; stmt_info.ops_ = ops_;
for (auto var : vars_) { stmt_info.vars_ = vars_;
auto new_var = Variable::make(var->type, var->name_hint);
stmt_info.vars_.push_back(new_var);
}
for (size_t i = 0; i < vars_.size(); ++i) { for (size_t i = 0; i < vars_.size(); ++i) {
for (size_t j = 0; j < stmt_info.ops_.size(); ++j) { for (size_t j = 0; j < stmt_info.ops_.size(); ++j) {
......
...@@ -276,15 +276,7 @@ struct BisectionInfoWrapper { ...@@ -276,15 +276,7 @@ struct BisectionInfoWrapper {
Map<std::string, Expr> dma_arg_info_map_; Map<std::string, Expr> dma_arg_info_map_;
}; };
struct InsnAxis {
int min{0};
int extent{0};
Var var;
int dst_stride{0};
int src_stride{0};
std::list<int> src_stride_list;
std::list<int> stride_list;
};
IterVar GetCceAxis(); IterVar GetCceAxis();
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
*/ */
#include "insn_pattern.h" #include "insn_pattern.h"
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/base.h> #include <tvm/base.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
...@@ -200,28 +199,6 @@ ArgInfo GetMultiVecInsnArgs(StmtInfoList &dst_info_list, StmtInfoList &src_info_ ...@@ -200,28 +199,6 @@ ArgInfo GetMultiVecInsnArgs(StmtInfoList &dst_info_list, StmtInfoList &src_info_
return arg_info; return arg_info;
} }
/// Get first non zero shape from input shapes
/// \param dst_shape
/// \param src0_shape
/// \param src1_shape
/// \return
int PatternGenerator::GetNonZeroShape(const Expr &dst_shape, const Expr &src0_shape, const Expr &src1_shape) {
int shape = 0;
for (int val :
{GetInt32Const(dst_shape), GetInt32Const(src0_shape), src1_shape.defined() ? GetInt32Const(src1_shape) : 0}) {
if (val == 0) {
continue;
}
if (shape != 0 && val != shape) {
LOG(FATAL) << "Error: same var has different shapes. " << GetIntConst(dst_shape) << " "
<< GetIntConst(src0_shape);
}
shape = val;
}
CHECK(shape != 0) << "Error: all shapes are equal to 0.";
return shape;
}
/// In case /// In case
/// for (cc3) { /// for (cc3) {
/// A[(cc3*16)] = (B[(cc3*16)] - C[(cc3*16)]) /// A[(cc3*16)] = (B[(cc3*16)] - C[(cc3*16)])
...@@ -432,25 +409,6 @@ void CleanZeroStrides(Array<StmtStoreInfo> &info_list) { ...@@ -432,25 +409,6 @@ void CleanZeroStrides(Array<StmtStoreInfo> &info_list) {
} }
} }
/// Swap axis in Array
/// \param var
/// \param shape
/// \param strides
/// \param idx1
/// \param idx2
void PatternGenerator::GetShapeInfoAndSwap(Array<Var> &var, Array<Expr> &shape, Array<Expr> &strides, int idx1,
int idx2) {
auto tmp_var = GetItem(var, idx1);
SetItem(var, idx1, GetItem(var, idx2));
SetItem(var, idx2, tmp_var);
auto tmp_shape = GetItem(shape, idx1);
SetItem(shape, idx1, GetItem(shape, idx2));
SetItem(shape, idx2, tmp_shape);
auto tmp_stride = GetItem(strides, idx1);
SetItem(strides, idx1, GetItem(strides, idx2));
SetItem(strides, idx2, tmp_stride);
}
/// Get insn args of load 2D intrin /// Get insn args of load 2D intrin
/// \param intrin_name /// \param intrin_name
/// \param dst_info_list /// \param dst_info_list
...@@ -856,6 +814,38 @@ Map<std::string, Expr> GetDmaCopyInsnArgs(std::string &intrin_name, const StmtIn ...@@ -856,6 +814,38 @@ Map<std::string, Expr> GetDmaCopyInsnArgs(std::string &intrin_name, const StmtIn
return arg_info_map; return arg_info_map;
} }
/// Replace com_info's var with new for loop's var
/// \param info
/// \param old_for_info
/// \param new_for_info
void ReplaceVarWithNewForInfo(StmtStoreInfo &info, const StmtInfo &old_for_info, const StmtInfo &new_for_info) {
for (size_t i = 0; i < new_for_info.vars_.size(); ++i) {
for (size_t j = 0; j < info->var_.size(); ++j) {
if (info->var_[j]->name_hint == new_for_info.vars_[i]->name_hint) {
SetItem(info.GetNode()->var_, static_cast<int>(j), new_for_info.vars_[i]);
}
}
info.GetNode()->index_ = substitute(old_for_info.vars_[i], new_for_info.vars_[i], info->index_);
}
}
std::string GetBinaryVecMode(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list,
const std::string &intrin_name, bool enable_bisect) {
std::set<std::string> reduce_bisect_list = {"vadd", "vsub", "vmul", "vmax"};
std::string mode = "reduction";
if (IsElementwise(dst_info_list, src_info_list)) {
mode = "elewise";
} else if (IsBroadcast(dst_info_list, src_info_list)) {
mode = "broadcast";
} else if (IsLastAxisReduction(dst_info_list, src_info_list)) {
mode = "reduce_last_axis";
} else if (enable_bisect && reduce_bisect_list.count(intrin_name) != 0 &&
IsBisectionReduction(dst_info_list, src_info_list)) {
mode = "reduce_bisection";
}
return mode;
}
const char *const DummyLastVar = "cc_last"; const char *const DummyLastVar = "cc_last";
TVM_REGISTER_API("cce_util.GetVecMask").set_body([](const TVMArgs args, TVMRetValue *ret) { TVM_REGISTER_API("cce_util.GetVecMask").set_body([](const TVMArgs args, TVMRetValue *ret) {
......
...@@ -37,220 +37,12 @@ struct PatternResult { ...@@ -37,220 +37,12 @@ struct PatternResult {
StmtInfo for_info; StmtInfo for_info;
}; };
class PatternGenerator {
public:
PatternGenerator(const StmtInfoList &dst_info_list, const StmtInfo &for_info)
: for_info(for_info),
not_this_pattern(-1.0f),
split_latency_coef(10.0f),
repeat_latency_coef(3.0f),
offset_latency_coef(0.1f) {
CHECK(!dst_info_list.empty());
dst_info = dst_info_list[0];
}
virtual ~PatternGenerator() = default;
virtual PatternResult GetInsnArgs() = 0;
protected:
int GetNonZeroShape(const Expr &dst_shape, const Expr &src0_shape, const Expr &src1_shape = Expr());
void GetShapeInfoAndSwap(Array<Var> &var, Array<Expr> &shape, Array<Expr> &strides, int idx1, int idx2);
virtual float Compute3DPatternMaskRate() { return not_this_pattern; }
virtual float Compute2DBlockPatternMaskRate() { return not_this_pattern; }
virtual float Compute2DPatternMaskRate() { return not_this_pattern; }
virtual float Compute1DPatternMaskRate() { return not_this_pattern; }
virtual Array<Var> Get3DPattern() { return {}; }
virtual Array<Var> Get2DBlockPattern() { return {}; }
virtual Array<Var> Get2DPattern() { return {}; }
virtual Array<Var> Get1DPattern() { return {}; }
virtual PatternResult GenResult(const Array<Var> &elim_var) = 0;
StmtStoreInfo dst_info;
StmtInfo for_info;
const float not_this_pattern;
const float split_latency_coef;
const float repeat_latency_coef;
const float offset_latency_coef;
};
class SingleVecPatternGenerator : public PatternGenerator {
public:
SingleVecPatternGenerator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list,
const StmtInfo &for_info, const std::string &mode = "elewise")
: PatternGenerator(dst_info_list, for_info),
arg_info(ArgInfo(make_node<ArgInfoNode>())),
body_args(VectorArgInfo()),
tail_args(VectorArgInfo()),
mode(mode) {
if (src_info_list.empty()) {
src_info = dst_info.Copy();
} else {
CHECK(!src_info_list.empty());
src_info = src_info_list[0];
}
}
~SingleVecPatternGenerator() override = default;
PatternResult GetInsnArgs() final;
protected:
float Compute3DPatternMaskRate() final;
float Compute2DBlockPatternMaskRate() final;
float Compute2DPatternMaskRate() final;
float Compute1DPatternMaskRate() final;
float Compute3DsPatternMaskRate();
float Compute2DRepeatPatternMaskRate();
Array<Var> Get3DPattern() final;
Array<Var> Get2DBlockPattern() final;
Array<Var> Get2DPattern() final;
Array<Var> Get1DPattern() final;
Array<Var> Get3DsPattern();
Array<Var> Get2DRepeatPattern();
PatternResult GenResult(const Array<Var> &elim_var) final;
private:
void CalcParams();
int GetLastDimShape(const Expr &dst_shape, const Expr &src_shape);
struct Params {
Array<Var> dst_var;
Array<Var> src_var;
Array<Expr> dst_shape;
Array<Expr> src_shape;
Array<Expr> dst_strides;
Array<Expr> src_strides;
int non_zero_shape1 = 0;
int non_zero_shape2 = 0;
int non_zero_shape3 = 0;
int all_points = 0;
int dst_block_size = 0;
int src_block_size = 0;
int mask_block_size = 0;
int dst_bits = 0;
int src_bits = 0;
int max_bits = 0;
int dst_vec_max_len = 0;
int vec_max_len = 0;
int block_offset = 0;
};
StmtStoreInfo src_info;
Params params;
ArgInfo arg_info;
VectorArgInfo body_args;
VectorArgInfo tail_args;
std::string mode;
Type data_type;
};
class BinaryVecPatternGenerator : public PatternGenerator {
public:
BinaryVecPatternGenerator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list,
const StmtInfo &for_info, const std::string &mode, bool expand_mask = true)
: PatternGenerator(dst_info_list, for_info),
src_info_list(src_info_list),
arg_info(ArgInfo(make_node<ArgInfoNode>())),
body_args(VectorArgInfo()),
tail_args(VectorArgInfo()),
empty_var(Var("")),
mode(mode),
expand_mask(expand_mask) {}
~BinaryVecPatternGenerator() override = default;
PatternResult GetInsnArgs() final;
protected:
float Compute3DPatternMaskRate() final;
float Compute2DBlockPatternMaskRate() final;
float Compute2DPatternMaskRate() final;
float Compute1DPatternMaskRate() final;
Array<Var> Get3DPattern() final;
Array<Var> Get2DBlockPattern() final;
Array<Var> Get2DPattern() final;
Array<Var> Get1DPattern() final;
PatternResult GenResult(const Array<Var> &elim_var) final;
private:
void CalcParams();
bool IsSamePatternComInfo(const StmtStoreInfo &info_a, const StmtStoreInfo &info_b);
bool IsNonZeroShapeEqual(const Array<Expr> &shape_list);
void AppendEmptyVar(StmtInfoList &info_list);
struct Params {
Array<Var> dst_var;
Array<Expr> dst_shape;
Array<Expr> dst_strides;
Array<Var> src_var0;
Array<Expr> src_shape0;
Array<Expr> src_strides0;
Array<Var> src_var1;
Array<Expr> src_shape1;
Array<Expr> src_strides1;
int non_zero_shape1 = 0;
int non_zero_shape2 = 0;
int non_zero_shape3 = 0;
int all_points = 0;
int block_size = 0;
int last_dim_shape = 0;
int vec_max_len = 0;
};
StmtInfoList src_info_list;
ArgInfo arg_info;
VectorArgInfo body_args;
VectorArgInfo tail_args;
Params params;
Var empty_var;
std::string mode;
bool expand_mask;
};
class ReduceLastAxisPatternGenerator : public PatternGenerator {
public:
ReduceLastAxisPatternGenerator(const StmtStoreInfo &dst_info, const StmtStoreInfo &src_info, const StmtInfo &for_info,
const std::string &intrin_name)
: PatternGenerator({dst_info}, for_info),
src_info(src_info),
arg_info(ArgInfo(make_node<ArgInfoNode>())),
body_args(VectorArgInfo()),
tail_args(VectorArgInfo()),
intrin_name(intrin_name) {}
PatternResult GetInsnArgs() final;
~ReduceLastAxisPatternGenerator() override = default;
protected:
float Compute2DBlockPatternMaskRate() final;
Array<Var> Get2DBlockPattern() final;
Array<Var> Get1DPattern() final;
PatternResult GenResult(const Array<Var> &elim_var) final;
private:
void CalcParams();
struct Params {
Array<Var> src_var;
int block_size = 0;
int vec_max_len = 0;
int last_dim_shape = 0;
Expr insn_offset_scale_factor;
};
StmtStoreInfo src_info;
ArgInfo arg_info;
VectorArgInfo body_args;
VectorArgInfo tail_args;
Array<VectorArgInfo> mix_vec_arg_list;
std::string intrin_name;
Params params;
};
std::string GetSingleVecComputationInfo(const Stmt &stmt, const std::string &intrin_name, std::string GetSingleVecComputationInfo(const Stmt &stmt, const std::string &intrin_name,
Array<StmtStoreInfo> &dst_info_list, Array<StmtStoreInfo> &src_info_list, Array<StmtStoreInfo> &dst_info_list, Array<StmtStoreInfo> &src_info_list,
StmtInfo &if_info, StmtInfo &for_info, bool need_compact = true); StmtInfo &if_info, StmtInfo &for_info, bool need_compact = true);
ArgInfo GetBinaryVecInsnArgs(const Stmt &stmt, std::string intrin_name, StmtInfoList &dst_info_list, std::string GetBinaryVecMode(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list,
StmtInfoList &src_info_list, StmtInfo &if_info, StmtInfo &for_info, const std::string &intrin_name, bool enable_bisect = true);
bool enable_bisect = true);
ArgInfo GetMultiVecInsnArgs(StmtInfoList &dst_info_list, StmtInfoList &src_info_list, StmtInfo &for_info); ArgInfo GetMultiVecInsnArgs(StmtInfoList &dst_info_list, StmtInfoList &src_info_list, StmtInfo &for_info);
...@@ -277,10 +69,7 @@ Map<std::string, Expr> GetDmaCopyInsnArgs(std::string &intrin_name, const StmtIn ...@@ -277,10 +69,7 @@ Map<std::string, Expr> GetDmaCopyInsnArgs(std::string &intrin_name, const StmtIn
const StmtInfoList &src_info_list, StmtInfo &for_info, const StmtInfoList &src_info_list, StmtInfo &for_info,
Map<std::string, Expr> &ub_copy_pre, Map<std::string, Expr> &ub_copy_post); Map<std::string, Expr> &ub_copy_pre, Map<std::string, Expr> &ub_copy_post);
BisectionInfoWrapper SeparateComInfoToBisectionInfoList(const StmtInfoList &dst_info_list, void ReplaceVarWithNewForInfo(StmtStoreInfo &info, const StmtInfo &old_for_info, const StmtInfo &new_for_info);
const StmtInfoList &src_info_list, const StmtInfo &for_info,
StmtInfo &if_info, bool last_axis, int postfix);
extern const char *const DummyLastVar; extern const char *const DummyLastVar;
} // namespace akg } // namespace akg
#endif // EMIT_INSN_INSN_PATTERN_H_ #endif // EMIT_INSN_INSN_PATTERN_H_
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <tvm/base.h>
#include <tvm/ir_pass.h>
#include <cmath>
#include <set>
#include "insn_builder.h"
#include "insn_pattern.h"
#include "common/array_api.h"
#include "pass/expr_alg_simplify.h"
namespace akg {
/// Get CCE Single Vector Insn mode
/// \param dst_info_list
/// \param src_info_list
/// \return
std::string GetSingleVecMode(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list) {
CHECK(!dst_info_list.empty());
auto dst_var_list = dst_info_list[0]->var_;
Array<Var> src_var_list;
if (!src_info_list.empty()) {
src_var_list = src_info_list[0]->var_;
}
if (IsSame(dst_var_list, src_var_list)) {
return "elewise";
} else if (dst_var_list.size() >= src_var_list.size()) {
return "broadcast";
}
return "reduction";
}
/// Get Single Vector Computation Info
/// \param stmt
/// \param intrin_name
/// \param dst_info_list
/// \param src_info_list
/// \param if_info
/// \param for_info
/// \param need_compact
/// \return
std::string GetSingleVecComputationInfo(const Stmt &stmt, const std::string &intrin_name, StmtInfoList &dst_info_list,
StmtInfoList &src_info_list, StmtInfo &if_info, StmtInfo &for_info,
bool need_compact) {
std::set<std::string> intrin_name_list = {"vadds", "vmuls", "vrelu", "vabs", "vln", "vexp",
"vrec", "vector_dup", "vnot", "vsqrt", "vrsqrt"};
if (intrin_name_list.count(intrin_name) == 0 && intrin_name.find("vconv_") == std::string::npos) {
LOG(FATAL) << "Error: CCE Single Vector Insn unsupported the given intrin_name. " << intrin_name;
return "";
}
bool same_dtype = intrin_name.find("vconv_") == std::string::npos;
GetCompactComputationInfo(stmt, dst_info_list, src_info_list, if_info, for_info, same_dtype, need_compact);
std::string mode = GetSingleVecMode(dst_info_list, src_info_list);
CHECK(dst_info_list.size() == 1) << "CCE Single Vector only support ONE dst.";
return mode;
}
/// Get CCE Single vector instructions args.
/// \param dst_info_list
/// \param src_info_list
/// \param for_info
/// \param mode
/// \return
PatternResult SingleVecPatternGenerator::GetInsnArgs() {
CalcParams();
Array<Var> elim_var = {};
float rate3d = Compute3DPatternMaskRate();
float rate2db = Compute2DBlockPatternMaskRate();
float rate2d = Compute2DPatternMaskRate();
float rate1d = Compute1DPatternMaskRate();
float rate3ds = Compute3DsPatternMaskRate();
float rate2ds = Compute2DRepeatPatternMaskRate();
if (mode == "broadcast_last_axis") {
elim_var = Get1DPattern();
} else if (rate2ds > 0) {
elim_var = Get2DRepeatPattern();
} else if (rate3ds > 0) {
elim_var = Get3DsPattern();
arg_info.GetNode()->pattern_ = PATTERN_2D;
} else if (rate3d >= rate2db && rate3d > 0) {
elim_var = Get3DPattern();
arg_info.GetNode()->pattern_ = PATTERN_3D;
} else if (rate2db >= rate2d && rate2db >= rate1d && rate2db > 0) {
elim_var = Get2DBlockPattern();
arg_info.GetNode()->pattern_ = PATTERN_PARTIAL_3D;
} else if (rate2d > rate1d && rate2d > 0) {
elim_var = Get2DPattern();
arg_info.GetNode()->pattern_ = PATTERN_2D;
} else if (rate1d > 0) {
elim_var = Get1DPattern();
arg_info.GetNode()->pattern_ = PATTERN_1D;
} else {
LOG(FATAL) << "Error: Cannot emit Single-Vector-Insn with any pattern!";
}
std::string mask_rate = "rate3d[" + std::to_string(rate3d) + "], rate2db[" + std::to_string(rate2db) + "], rate2d[" +
std::to_string(rate2d) + "], rate1d[" + std::to_string(rate1d) + "]";
CommentManager::GetInstance().AddComment("Mask_rate", mask_rate);
if (arg_info->tail_arg_info_.defined()) {
CommentManager::GetInstance().AddComment("Contain_tail", "true");
} else {
CommentManager::GetInstance().AddComment("Contain_tail", "false");
}
return GenResult(elim_var);
}
/// Calc params for pattern match
void SingleVecPatternGenerator::CalcParams() {
Array<StmtStoreInfo> info_list = {dst_info, src_info};
// check shape len
for (auto info : info_list) {
CHECK(!info->shape_.empty())
<< "CCE Vector Insn Error: dst_buffer and src_buffer can not be scalar, should keep len(shape) > 0.";
}
FillEmptyVar(info_list);
dst_info = info_list[0];
src_info = info_list[1];
int dst_bits = dst_info->dtype_.bits();
int src_bits = src_info->dtype_.bits();
CHECK_NE(dst_bits, 0);
CHECK_NE(src_bits, 0);
int dst_block_size = GetUbBlkSize(dst_info->dtype_);
int src_block_size = GetUbBlkSize(src_info->dtype_);
CHECK_NE(dst_block_size, 0);
CHECK_NE(src_block_size, 0);
data_type = src_bits > dst_bits ? src_info->dtype_ : dst_info->dtype_;
params.dst_var = dst_info->var_;
params.src_var = src_info->var_;
params.dst_shape = dst_info->shape_;
params.src_shape = src_info->shape_;
params.dst_strides = dst_info->strides_;
params.src_strides = src_info->strides_;
params.dst_block_size = dst_block_size;
params.src_block_size = src_block_size;
params.mask_block_size = src_bits > dst_bits ? src_block_size : dst_block_size;
params.dst_bits = dst_bits;
params.src_bits = src_bits;
params.max_bits = FULL_BLOCK_NUM * std::min(dst_bits, src_bits);
params.dst_vec_max_len = GetVecMaxLen(dst_info->dtype_);
params.vec_max_len = src_bits > dst_bits ? GetVecMaxLen(src_info->dtype_) : GetVecMaxLen(dst_info->dtype_);
CHECK_NE(params.dst_vec_max_len, 0);
CHECK_NE(params.vec_max_len, 0);
auto GetNonZeroShapeByIdx = [this](int index) -> int {
if (index <= static_cast<int>(params.dst_var.size())) {
if (Equal(GetItem(params.dst_var, -index), GetItem(params.src_var, -index))) {
return GetNonZeroShape(GetItem(params.dst_shape, -index), GetItem(params.src_shape, -index));
}
}
return 1;
};
params.non_zero_shape1 = GetNonZeroShapeByIdx(1);
params.non_zero_shape2 = GetNonZeroShapeByIdx(2);
params.non_zero_shape3 = GetNonZeroShapeByIdx(3);
params.all_points = params.non_zero_shape1 * params.non_zero_shape2 * params.non_zero_shape3;
auto elem_offset_mod = ir::ExprSimplifier().Simplify(Mod::make(dst_info->elem_offset_, dst_block_size));
if (elem_offset_mod.as<IntImm>()) {
params.block_offset = elem_offset_mod.as<IntImm>()->value;
}
}
int SingleVecPatternGenerator::GetLastDimShape(const Expr &dst_shape, const Expr &src_shape) {
int dst_last_dim = GetInt32Const(dst_shape);
int src_last_dim = GetInt32Const(src_shape);
CHECK(dst_last_dim != 0 || src_last_dim != 0);
if (dst_last_dim == 0) {
return src_last_dim;
}
if (src_last_dim == 0) {
return dst_last_dim;
}
return std::min(dst_last_dim, src_last_dim);
}
bool FindInShape(Array<Expr> &shape, const Expr &target) {
for (int i = -1; i >= -3; --i) {
if (Equal(GetItem(shape, i), target)) {
return true;
}
}
return false;
}
float SingleVecPatternGenerator::Compute2DRepeatPatternMaskRate() {
if (params.dst_var.size() < 3) {
return not_this_pattern;
}
for (int i = -1; i >= -3; --i) {
if (!FindInShape(params.src_shape, GetItem(params.dst_shape, i))) {
return not_this_pattern;
}
}
if (GetInt32Const(GetItem(params.dst_strides, -2)) % params.dst_block_size != 0 ||
GetInt32Const(GetItem(params.src_strides, -2)) % params.src_block_size != 0) {
return not_this_pattern;
}
if (GetInt32Const(GetItem(params.dst_strides, -3)) % params.dst_block_size != 0 ||
GetInt32Const(GetItem(params.src_strides, -3)) % params.src_block_size != 0) {
return not_this_pattern;
}
if (GetInt32Const(GetItem(params.dst_strides, -2)) == 0 && GetInt32Const(GetItem(params.src_strides, -2)) == 0) {
return not_this_pattern;
}
if (!Equal(GetItem(params.dst_shape, -3), GetItem(params.src_shape, -2)) ||
!Equal(GetItem(params.dst_shape, -2), GetItem(params.src_shape, -3))) {
return not_this_pattern;
}
if (GetIntConst(GetItem(params.dst_shape, -2)) > FULL_BLOCK_NUM &&
GetIntConst(GetItem(params.dst_shape, -2)) % FULL_BLOCK_NUM != 0) {
return not_this_pattern;
}
if (params.dst_block_size == params.src_block_size) {
return not_this_pattern;
}
if (GetInt32Const(GetItem(params.dst_shape, -1)) <= params.dst_block_size &&
GetInt32Const(GetItem(params.src_shape, -1)) <= params.src_block_size) {
return not_this_pattern;
}
return 1.0;
}
float SingleVecPatternGenerator::Compute3DsPatternMaskRate() {
if (params.dst_var.size() < 3) {
return not_this_pattern;
}
if (params.dst_block_size != params.src_block_size) {
return not_this_pattern;
}
for (int i = -1; i >= -3; --i) {
if (!FindInShape(params.src_shape, GetItem(params.dst_shape, i))) {
return not_this_pattern;
}
}
if (GetInt32Const(GetItem(params.dst_shape, -1)) > params.dst_block_size ||
GetInt32Const(GetItem(params.src_shape, -1)) > params.src_block_size) {
return not_this_pattern;
}
if (GetInt32Const(GetItem(params.dst_strides, -2)) % params.dst_block_size != 0 ||
GetInt32Const(GetItem(params.src_strides, -2)) % params.src_block_size != 0) {
return not_this_pattern;
}
if (GetInt32Const(GetItem(params.dst_strides, -3)) % params.dst_block_size != 0 ||
GetInt32Const(GetItem(params.src_strides, -3)) % params.src_block_size != 0) {
return not_this_pattern;
}
if (GetInt32Const(GetItem(params.dst_strides, -2)) == 0 && GetInt32Const(GetItem(params.src_strides, -2)) == 0) {
return not_this_pattern;
}
if (!Equal(GetItem(params.dst_shape, -3), GetItem(params.src_shape, -2)) ||
!Equal(GetItem(params.dst_shape, -2), GetItem(params.src_shape, -3))) {
return not_this_pattern;
}
if (GetIntConst(GetItem(params.dst_shape, -2)) > FULL_BLOCK_NUM &&
GetIntConst(GetItem(params.dst_shape, -2)) % FULL_BLOCK_NUM != 0) {
return not_this_pattern;
}
return 1.0;
}
float SingleVecPatternGenerator::Compute3DPatternMaskRate() {
// in elemwise mode, the var is already checked to be equal, no need to check
if (params.dst_var.size() < 3) {
return not_this_pattern;
}
// do not support cast op in 3D pattern
if (params.dst_block_size != params.src_block_size) {
return not_this_pattern;
}
for (int i = -1; i >= -3; --i) {
if (!IsTwoItemEqual(params.dst_var, params.src_var, i)) {
return not_this_pattern;
}
}
if (GetInt32Const(GetItem(params.dst_shape, -1)) > params.dst_block_size ||
GetInt32Const(GetItem(params.src_shape, -1)) > params.src_block_size) {
return not_this_pattern;
}
if (GetInt32Const(GetItem(params.dst_strides, -2)) % params.dst_block_size != 0 ||
GetInt32Const(GetItem(params.src_strides, -2)) % params.src_block_size != 0) {
return not_this_pattern;
}
if (GetInt32Const(GetItem(params.dst_strides, -3)) % params.dst_block_size != 0 ||
GetInt32Const(GetItem(params.src_strides, -3)) % params.src_block_size != 0) {
return not_this_pattern;
}
if (GetInt32Const(GetItem(params.dst_strides, -2)) == 0 && GetInt32Const(GetItem(params.src_strides, -2)) == 0) {
return not_this_pattern;
}
// repeat axis is shape [-3], repeat once, has 8 loops
bool is3_d = true;
float rate3d_mode1 = not_this_pattern;
float rate3d_mode2 = not_this_pattern;
int repeat_num;
float repeat_latency;
StmtInfoList info_list = {dst_info, src_info};
for (auto info : info_list) {
if (GetInt32Const(GetItem(info->shape_, -2)) > FULL_BLOCK_NUM ||
GetInt32Const(GetItem(info->strides_, -2)) / params.dst_block_size >= MAX_STRIDE_M0_SINGLE ||
GetInt32Const(GetItem(info->strides_, -3)) / params.dst_block_size >= MAX_STRIDE_M1) {
is3_d = false;
break;
}
}
if (is3_d) {
if (GetIntConst(GetItem(params.dst_strides, -2)) == 0) {
return not_this_pattern;
}
repeat_num = params.non_zero_shape3;
repeat_latency = ((repeat_num - 1) / MAX_REPEAT) * repeat_latency_coef;
rate3d_mode1 = static_cast<float>(params.all_points) / params.dst_vec_max_len / (repeat_num + repeat_latency);
}
is3_d = true;
// repeat axis is shape[-2]
for (auto info : info_list) {
// stride_m0 should less than 65536
if (GetInt32Const(GetItem(info->shape_, -3)) % FULL_BLOCK_NUM != 0 ||
GetInt32Const(GetItem(info->strides_, -3)) / params.dst_block_size >= MAX_STRIDE_M0_SINGLE) {
is3_d = false;
break;
}
}
if (is3_d) {
if (GetIntConst(GetItem(params.dst_strides, -3)) == 0) {
return not_this_pattern;
}
repeat_num = params.non_zero_shape2 * (params.non_zero_shape3 / FULL_BLOCK_NUM);
repeat_latency = ((repeat_num - 1) / MAX_REPEAT) * repeat_latency_coef;
float offset_latency =
params.non_zero_shape3 / FULL_BLOCK_NUM > 1 ? params.non_zero_shape3 * offset_latency_coef : 0;
rate3d_mode2 =
static_cast<float>(params.all_points) / params.dst_vec_max_len / (repeat_num + repeat_latency + offset_latency);
}
return rate3d_mode1 > rate3d_mode2 ? rate3d_mode1 : rate3d_mode2;
}
// Partial 3D Pattern
float SingleVecPatternGenerator::Compute2DBlockPatternMaskRate() {
// in elemwise mode, the var is already checked to be equal, no need to check
if (params.dst_var.size() < 2 || params.src_var.size() < 2 || GetInt32Const(GetItem(params.dst_strides, -1)) != 1) {
return not_this_pattern;
}
// do not support cast op in Partial3D pattern
if (params.dst_block_size != params.src_block_size) {
return not_this_pattern;
}
for (int i = -1; i >= -2; --i) {
if (!Equal(GetItem(params.dst_var, i), GetItem(params.src_var, i))) {
return not_this_pattern;
}
}
if (GetInt32Const(GetItem(params.dst_shape, -1)) > params.dst_block_size ||
GetInt32Const(GetItem(params.src_shape, -1)) > params.src_block_size) {
return not_this_pattern;
}
if (GetInt32Const(GetItem(params.dst_strides, -2)) % params.dst_block_size != 0 ||
GetInt32Const(GetItem(params.src_strides, -2)) % params.src_block_size != 0) {
return not_this_pattern;
}
if (GetInt32Const(GetItem(params.dst_strides, -2)) == 0 && GetInt32Const(GetItem(params.src_strides, -2)) == 0) {
return not_this_pattern;
}
if (GetInt32Const(GetItem(params.dst_strides, -2)) / params.dst_block_size >= MAX_STRIDE_M0_SINGLE ||
GetInt32Const(GetItem(params.src_strides, -2)) / params.src_block_size >= MAX_STRIDE_M0_SINGLE) {
return not_this_pattern;
}
if (GetInt32Const(GetItem(params.dst_strides, -2)) == 0) {
return not_this_pattern;
}
int repeat_body_num = params.non_zero_shape2 / FULL_BLOCK_NUM;
int repeat_tail_num = (params.non_zero_shape2 % FULL_BLOCK_NUM + FULL_BLOCK_NUM - 1) / FULL_BLOCK_NUM;
int repeat_num = (repeat_body_num + repeat_tail_num) * params.non_zero_shape3;
float repeat_latency =
(std::max(repeat_body_num - 1, 0) / MAX_REPEAT + std::max(repeat_tail_num - 1, 0) / MAX_REPEAT) *
repeat_latency_coef;
float offset_latency = params.non_zero_shape3 > 1 ? params.non_zero_shape3 * offset_latency_coef : 0;
float split_latency = (repeat_body_num > 0 && repeat_tail_num > 0) ? split_latency_coef : 0;
float rate2db = static_cast<float>(params.all_points) / params.dst_vec_max_len /
(repeat_num + repeat_latency + offset_latency + split_latency);
return rate2db;
}
float SingleVecPatternGenerator::Compute2DPatternMaskRate() {
// in elemwise mode, the var is already checked to be equal, no need to check
if (params.dst_var.size() < 2 || params.src_var.size() < 2) {
return not_this_pattern;
}
if (src_info->data_alignment_ == 1 && GetInt32Const(GetItem(src_info->strides_, -1)) != params.dst_block_size) {
return not_this_pattern;
}
for (int i = -1; i >= -2; --i) {
if (!Equal(GetItem(params.dst_var, i), GetItem(params.src_var, i))) {
return not_this_pattern;
}
}
if (GetInt32Const(GetItem(params.dst_strides, -2)) % params.dst_block_size != 0 ||
GetInt32Const(GetItem(params.src_strides, -2)) % params.src_block_size != 0) {
return not_this_pattern;
}
if (GetInt32Const(GetItem(params.dst_strides, -2)) / params.dst_block_size >= MAX_STRIDE_M1 ||
GetInt32Const(GetItem(params.src_strides, -2)) / params.src_block_size >= MAX_STRIDE_M1) {
return not_this_pattern;
}
// check num of insns, select 1D pattern or 2D pattern
int tail_factor = 0;
if (params.non_zero_shape1 / params.dst_vec_max_len > 0 && params.non_zero_shape1 % params.dst_vec_max_len > 0) {
tail_factor = 1;
}
int offset_num =
(params.non_zero_shape1 + params.dst_vec_max_len - 1) / params.dst_vec_max_len * params.non_zero_shape3;
int repeat_num = offset_num * params.non_zero_shape2;
float repeat_latency = (std::max(params.non_zero_shape2 - 1, 0) / MAX_REPEAT) * offset_num * repeat_latency_coef;
float offset_latency = offset_num > 1 ? offset_num * offset_latency_coef : 0;
float split_latency = tail_factor * split_latency_coef;
float rate2d = static_cast<float>(params.all_points) / params.dst_vec_max_len /
(repeat_num + repeat_latency + offset_latency + split_latency);
return rate2d;
}
float SingleVecPatternGenerator::Compute1DPatternMaskRate() {
int tail_factor = 0;
if (params.non_zero_shape1 / params.dst_vec_max_len > 0 && params.non_zero_shape1 % params.dst_vec_max_len > 0) {
tail_factor = 1;
}
int shape1 = (params.non_zero_shape1 + params.dst_vec_max_len - 1) / params.dst_vec_max_len;
int repeat_num = shape1 * params.non_zero_shape2 * params.non_zero_shape3;
float repeat_latency =
std::max((shape1 - 1) / MAX_REPEAT, 0) * params.non_zero_shape2 * params.non_zero_shape3 * repeat_latency_coef;
float offset_latency = params.non_zero_shape2 * params.non_zero_shape3 > 1
? params.non_zero_shape2 * params.non_zero_shape3 * offset_latency_coef
: 0;
float split_latency = tail_factor * split_latency_coef;
float rate1d = static_cast<float>(params.all_points) / params.dst_vec_max_len /
(repeat_num + repeat_latency + offset_latency + split_latency);
return rate1d;
}
Array<Var> SingleVecPatternGenerator::Get2DRepeatPattern() {
GetShapeInfoAndSwap(params.src_var, params.src_shape, params.src_strides, -2, -3);
int last_dim_shape = GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape, -1));
body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
body_args.GetNode()->body_num_ = 1;
body_args.GetNode()->dst_stride_m0_ = 1;
body_args.GetNode()->src_stride_m0_list_ = {1};
body_args.GetNode()->dst_stride_m1_ = truncdiv(GetItem(params.dst_strides, -2), params.dst_block_size);
body_args.GetNode()->src_stride_m1_list_ = {truncdiv(GetItem(params.src_strides, -2), params.src_block_size)};
body_args.GetNode()->repeat_ = GetItem(params.dst_shape, -2);
int data_len = CeilTo(last_dim_shape, params.dst_block_size);
body_args.GetNode()->vec_mask_ = GetVecMask(data_len, 1, dst_info->dtype_);
return GetRange(params.dst_var, -2, 2);
}
Array<Var> SingleVecPatternGenerator::Get3DsPattern() {
GetShapeInfoAndSwap(params.src_var, params.src_shape, params.src_strides, -2, -3);
int last_dim_shape = GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape, -1));
body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
body_args.GetNode()->body_num_ = 1;
Expr dst_stride_m0 = truncdiv(GetItem(params.dst_strides, -2), params.dst_block_size);
Expr src_stride_m0 = truncdiv(GetItem(params.src_strides, -2), params.src_block_size);
body_args.GetNode()->dst_stride_m0_ = dst_stride_m0;
body_args.GetNode()->src_stride_m0_list_ = {src_stride_m0};
int block_num = 0;
int data_len = CeilTo(last_dim_shape, params.mask_block_size);
if (GetIntConst(GetItem(params.dst_shape, -2)) <= FULL_BLOCK_NUM) {
block_num = GetIntConst(GetItem(params.dst_shape, -2));
body_args.GetNode()->dst_stride_m1_ = truncdiv(GetItem(params.dst_strides, -3), params.dst_block_size);
body_args.GetNode()->src_stride_m1_list_ = {truncdiv(GetItem(params.src_strides, -3), params.src_block_size)};
body_args.GetNode()->vec_mask_ = GetVecMask(data_len, block_num, dst_info->dtype_);
auto repeat = GetItem(params.dst_shape, -3);
if (GetIntConst(repeat) < MAX_STRIDE_M1) {
body_args.GetNode()->repeat_ = repeat;
return GetRange(params.dst_var, -3, 3);
} else {
body_args.GetNode()->repeat_ = 1;
return GetRange(params.dst_var, -2, 2);
}
} else {
block_num = FULL_BLOCK_NUM;
body_args.GetNode()->dst_stride_m1_ = dst_stride_m0 * block_num;
body_args.GetNode()->src_stride_m1_list_ = {src_stride_m0 * block_num};
auto repeat = truncdiv(GetItem(params.dst_shape, -2), FULL_BLOCK_NUM);
body_args.GetNode()->vec_mask_ = GetVecMask(data_len, block_num, dst_info->dtype_);
if (GetIntConst(repeat) < MAX_STRIDE_M1) {
body_args.GetNode()->repeat_ = repeat;
return GetRange(params.dst_var, -2, 2);
} else {
return Get1DPattern();
}
}
}
Array<Var> SingleVecPatternGenerator::Get3DPattern() {
if (GetIntConst(GetNonZeroShape(GetItem(params.dst_shape, -2), GetItem(params.src_shape, -2))) > FULL_BLOCK_NUM) {
// split shape[-3]
if (GetIntConst(GetNonZeroShape(GetItem(params.dst_shape, -3), GetItem(params.src_shape, -3))) > FULL_BLOCK_NUM) {
StmtInfoList info_list = {dst_info, src_info};
SplitAxis(info_list, for_info, GetItem(params.dst_var, -3), FULL_BLOCK_NUM);
FillEmptyVar(info_list);
dst_info = info_list[0];
src_info = info_list[1];
params.dst_var = dst_info->var_;
params.dst_shape = dst_info->shape_;
params.dst_strides = dst_info->strides_;
params.src_var = src_info->var_;
params.src_shape = src_info->shape_;
params.src_strides = src_info->strides_;
}
// consider original shape[-2] as repeat axis
GetShapeInfoAndSwap(params.dst_var, params.dst_shape, params.dst_strides, -2, -3);
GetShapeInfoAndSwap(params.src_var, params.src_shape, params.src_strides, -2, -3);
}
int last_dim_shape = GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape, -1));
body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
body_args.GetNode()->body_num_ = 1;
body_args.GetNode()->repeat_ =
make_const(Int(32), GetNonZeroShape(GetItem(params.dst_shape, -3), GetItem(params.src_shape, -3)));
body_args.GetNode()->dst_stride_m0_ = truncdiv(GetItem(params.dst_strides, -2), params.dst_block_size);
body_args.GetNode()->dst_stride_m1_ = truncdiv(GetItem(params.dst_strides, -3), params.dst_block_size);
body_args.GetNode()->src_stride_m0_list_ = {truncdiv(GetItem(params.src_strides, -2), params.src_block_size)};
body_args.GetNode()->src_stride_m1_list_ = {truncdiv(GetItem(params.src_strides, -3), params.src_block_size)};
body_args.GetNode()->block_offset_ = make_const(Int(32), params.block_offset);
int data_len = CeilTo(last_dim_shape, params.mask_block_size);
int data_num = GetInt32Const(GetItem(params.dst_shape, -2));
body_args.GetNode()->vec_mask_ = GetVecMask(data_len, data_num, dst_info->dtype_, params.block_offset);
return GetRange(params.dst_var, -3, 3);
}
Array<Var> SingleVecPatternGenerator::Get2DBlockPattern() {
int last_dim_shape = GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape, -1));
int repeat_len = GetNonZeroShape(GetItem(params.dst_shape, -2), GetItem(params.src_shape, -2));
int repeat_body = repeat_len / FULL_BLOCK_NUM;
int repeat_tail = (repeat_len % FULL_BLOCK_NUM + FULL_BLOCK_NUM - 1) / FULL_BLOCK_NUM;
int data_len = CeilTo(last_dim_shape, params.dst_block_size);
if (repeat_body > 0) {
body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
body_args.GetNode()->body_num_ = 1;
body_args.GetNode()->repeat_ = make_const(Int(32), repeat_body);
auto dst_stride_m0 = truncdiv(GetItem(params.dst_strides, -2), params.dst_block_size);
body_args.GetNode()->dst_stride_m0_ = dst_stride_m0;
body_args.GetNode()->dst_stride_m1_ = dst_stride_m0 * (params.max_bits / params.src_bits);
auto src_stride_m0 = truncdiv(GetItem(params.src_strides, -2), params.src_block_size);
body_args.GetNode()->src_stride_m0_list_ = {src_stride_m0};
body_args.GetNode()->src_stride_m1_list_ = {src_stride_m0 * (params.max_bits / params.dst_bits)};
body_args.GetNode()->block_offset_ = make_const(Int(32), params.block_offset);
int data_num = FULL_BLOCK_NUM;
body_args.GetNode()->vec_mask_ = GetVecMask(data_len, data_num, dst_info->dtype_, params.block_offset);
}
if (repeat_tail > 0) {
tail_args = VectorArgInfo(make_node<VectorArgInfoNode>());
tail_args.GetNode()->dst_head_ = GetItem(params.dst_strides, -2) * repeat_body * FULL_BLOCK_NUM;
tail_args.GetNode()->src_head_list_ = {GetItem(params.src_strides, -2) * repeat_body * FULL_BLOCK_NUM};
tail_args.GetNode()->repeat_ = Expr(1);
tail_args.GetNode()->dst_stride_m0_ = truncdiv(GetItem(params.dst_strides, -2), params.dst_block_size);
tail_args.GetNode()->dst_stride_m1_ = Expr(0);
tail_args.GetNode()->src_stride_m0_list_ = {truncdiv(GetItem(params.src_strides, -2), params.src_block_size)};
tail_args.GetNode()->src_stride_m1_list_ = {Expr(0)};
tail_args.GetNode()->block_offset_ = make_const(Int(32), params.block_offset);
int data_num = repeat_len % FULL_BLOCK_NUM;
tail_args.GetNode()->vec_mask_ = GetVecMask(data_len, data_num, dst_info->dtype_, params.block_offset);
}
return GetRange(params.dst_var, -2, 2);
}
Array<Var> SingleVecPatternGenerator::Get2DPattern() {
const int data_num = 1;
int last_dim_shape = GetNonZeroShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape, -1));
if (GetInt32Const(GetItem(dst_info->strides_, -1)) == params.dst_block_size &&
IsTwoItemEqual(dst_info->strides_, src_info->strides_, -1, true)) {
last_dim_shape *= params.dst_block_size;
}
int body_len = FloorTo(last_dim_shape, params.vec_max_len);
int tail_len = last_dim_shape % params.vec_max_len;
if (body_len > 0) {
body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
body_args.GetNode()->body_num_ = body_len / params.vec_max_len;
body_args.GetNode()->body_offset_ = params.vec_max_len;
body_args.GetNode()->repeat_ = GetItem(params.dst_shape, -2);
body_args.GetNode()->dst_stride_m0_ = Expr(1);
body_args.GetNode()->dst_stride_m1_ = truncdiv(GetItem(params.dst_strides, -2), params.dst_block_size);
body_args.GetNode()->src_stride_m0_list_ = {Expr(1)};
body_args.GetNode()->src_stride_m1_list_ = {truncdiv(GetItem(params.src_strides, -2), params.src_block_size)};
body_args.GetNode()->block_offset_ = make_const(Int(32), params.block_offset);
body_args.GetNode()->vec_mask_ = GetVecMask(params.vec_max_len, data_num, data_type, params.block_offset);
}
// get tail params
if (tail_len > 0) {
tail_args = VectorArgInfo(make_node<VectorArgInfoNode>());
tail_args.GetNode()->dst_head_ = Expr(body_len);
tail_args.GetNode()->src_head_list_ = {Expr(body_len)};
tail_args.GetNode()->repeat_ = GetItem(params.dst_shape, -2);
tail_args.GetNode()->dst_stride_m0_ = Expr(1);
tail_args.GetNode()->dst_stride_m1_ = truncdiv(GetItem(params.dst_strides, -2), params.dst_block_size);
tail_args.GetNode()->src_stride_m0_list_ = {Expr(1)};
tail_args.GetNode()->src_stride_m1_list_ = {truncdiv(GetItem(params.src_strides, -2), params.src_block_size)};
tail_args.GetNode()->block_offset_ = make_const(Int(32), params.block_offset);
int data_len = CeilTo(tail_len, params.mask_block_size);
tail_args.GetNode()->vec_mask_ = GetVecMask(data_len, data_num, data_type, params.block_offset);
}
return GetRange(params.dst_var, -2, 2);
}
Array<Var> SingleVecPatternGenerator::Get1DPattern() {
int last_dim_shape;
bool linear_mode = false;
if ((params.dst_shape.empty() && params.src_shape.empty()) || GetIntConst(GetItem(params.dst_shape, -1)) == 0) {
last_dim_shape = 1;
} else if (!IsTwoItemEqual(params.dst_var, params.src_var, -1)) {
last_dim_shape = 1;
} else {
last_dim_shape = GetLastDimShape(GetItem(params.dst_shape, -1), GetItem(params.src_shape, -1));
linear_mode = params.dst_bits == params.src_bits;
}
bool is_scalar_mode = IsScalarMode({dst_info, src_info});
if (is_scalar_mode && params.dst_bits != params.src_bits) {
last_dim_shape = 1;
}
int vec_max_len = is_scalar_mode ? FULL_BLOCK_NUM : params.vec_max_len;
int body_len = FloorTo(last_dim_shape, vec_max_len);
int tail_len = last_dim_shape % vec_max_len;
auto dst_stride_m0 =
is_scalar_mode && linear_mode ? truncdiv(GetItem(params.dst_strides, -1), params.dst_block_size) : Expr(1);
auto src_stride_m0 =
is_scalar_mode && linear_mode ? truncdiv(GetItem(params.src_strides, -1), params.src_block_size) : Expr(1);
if (body_len > 0) {
body_args = VectorArgInfo(make_node<VectorArgInfoNode>());
body_args.GetNode()->body_num_ = 1;
body_args.GetNode()->body_offset_ = vec_max_len;
body_args.GetNode()->repeat_ = Expr(body_len / vec_max_len);
body_args.GetNode()->dst_stride_m0_ = dst_stride_m0;
auto dst_block_num = is_scalar_mode ? FULL_BLOCK_NUM : (params.max_bits / params.src_bits);
body_args.GetNode()->dst_stride_m1_ = dst_stride_m0 * dst_block_num;
body_args.GetNode()->src_stride_m0_list_ = {src_stride_m0};
auto src_block_num = is_scalar_mode ? FULL_BLOCK_NUM : (params.max_bits / params.dst_bits);
body_args.GetNode()->src_stride_m1_list_ = {src_stride_m0 * src_block_num};
body_args.GetNode()->block_offset_ = make_const(Int(32), params.block_offset);
// in cast case, data_num should be 1 because dst and src bit is not equal
int data_len = is_scalar_mode ? 1 : vec_max_len;
int data_num = is_scalar_mode ? FULL_BLOCK_NUM : 1;
body_args.GetNode()->vec_mask_ = GetVecMask(data_len, data_num, data_type, params.block_offset);
}
// get tail params
if (tail_len > 0) {
tail_args = VectorArgInfo(make_node<VectorArgInfoNode>());
tail_args.GetNode()->body_offset_ = vec_max_len;
tail_args.GetNode()->body_num_ = 1;
tail_args.GetNode()->dst_head_ =
Expr(body_len * (is_scalar_mode ? dst_stride_m0 * params.dst_block_size : Expr(1)));
tail_args.GetNode()->src_head_list_ = {
Expr(body_len * (is_scalar_mode ? src_stride_m0 * params.src_block_size : Expr(1)))};
tail_args.GetNode()->repeat_ = Expr(1);
tail_args.GetNode()->dst_stride_m0_ = dst_stride_m0;
tail_args.GetNode()->dst_stride_m1_ = Expr(0);
tail_args.GetNode()->src_stride_m0_list_ = {src_stride_m0};
tail_args.GetNode()->src_stride_m1_list_ = {Expr(0)};
tail_args.GetNode()->block_offset_ = make_const(Int(32), params.block_offset);
int data_len = is_scalar_mode && linear_mode ? 1 : CeilTo(tail_len, params.mask_block_size);
int data_num = is_scalar_mode && linear_mode ? tail_len : 1;
data_num = data_num == 0 ? 1 : data_num;
tail_args.GetNode()->vec_mask_ = GetVecMask(data_len, data_num, data_type, params.block_offset);
}
// compute offset for cce instructions
Array<Var> elim_var = {};
if (mode == "elewise" && params.dst_var.size() >= 2 && params.dst_strides.size() >= 2 && for_info.ops_.size() >= 2 &&
last_dim_shape <= vec_max_len && last_dim_shape >= vec_max_len - params.dst_block_size &&
GetIntConst(GetItem(params.dst_strides, -2)) == last_dim_shape) {
// in this case we can merge second last for extent to repeat
size_t index = 0;
bool suc = GetIndexOfElement(for_info.vars_, GetItem(params.dst_var, -2), index);
CHECK(suc);
auto latest_for = GetItem(for_info.ops_, index).as<For>();
// there should not be if_op between for loop and compute stmt
if (latest_for && !latest_for->body->IsInstance<IfThenElse>()) {
if (!params.dst_var.empty() && (!is_scalar_mode || last_dim_shape != 1)) {
if (body_args.defined()) {
// last_dim_shape = vec_max_len
body_args.GetNode()->repeat_ = body_args->repeat_ * latest_for->extent;
} else if (tail_args.defined()) {
// last_dim_shape < vec_max_len
tail_args.GetNode()->repeat_ = tail_args->repeat_ * latest_for->extent;
}
return GetRange(params.dst_var, -2, 2);
}
}
}
if (!params.dst_var.empty() && (!is_scalar_mode || last_dim_shape != 1 || linear_mode) &&
GetIntConst(GetItem(params.dst_strides, -1)) > 0 &&
(params.src_var.empty() || IsTwoItemEqual(params.dst_var, params.src_var, -1))) {
elim_var = GetRange(params.dst_var, -1, 1);
}
return elim_var;
}
PatternResult SingleVecPatternGenerator::GenResult(const Array<Var> &elim_var) {
arg_info.GetNode()->body_arg_info_ = body_args;
arg_info.GetNode()->tail_arg_info_ = tail_args;
dst_info.GetNode()->insn_offset_ = GetInsnOffset(dst_info, elim_var);
src_info.GetNode()->insn_offset_ = GetInsnOffset(src_info, elim_var);
CleanForInfoVars(for_info, elim_var);
StmtInfoList info_list = {dst_info, src_info};
CleanZeroStrides(info_list);
dst_info = info_list[0];
src_info = info_list[1];
PatternResult result;
result.dst_info_list = {dst_info};
result.src_info_list = {src_info};
result.for_info = for_info;
result.arg_info = arg_info;
return result;
}
} // namespace akg
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "pass/ir_util.h" #include "pass/ir_util.h"
#include "poly/poly_util.h" #include "poly/poly_util.h"
#include "emit_insn/insn_emitter.h" #include "emit_insn/insn_emitter.h"
#include "emit_insn/ir_transform.h"
namespace akg { namespace akg {
namespace ir { namespace ir {
...@@ -475,6 +476,7 @@ Stmt EmitInsn(Stmt stmt, bool enable_bisect, bool enable_cover_protect, const Ma ...@@ -475,6 +476,7 @@ Stmt EmitInsn(Stmt stmt, bool enable_bisect, bool enable_cover_protect, const Ma
} }
stmt = UnalignedMad().Mutate(stmt); stmt = UnalignedMad().Mutate(stmt);
stmt = RegCondition().Mutate(stmt); stmt = RegCondition().Mutate(stmt);
stmt = ForVarUnique().Mutate(stmt);
return stmt; return stmt;
} }
} // namespace ir } // namespace ir
......
...@@ -343,8 +343,12 @@ class BroadcastCalculate : public IRMutator { ...@@ -343,8 +343,12 @@ class BroadcastCalculate : public IRMutator {
}; };
Stmt MultiLastAxisReductions(Stmt stmt, bool is_dynamic = false) { Stmt MultiLastAxisReductions(Stmt stmt, bool is_dynamic = false) {
auto ori_stmt = stmt;
stmt = MultiLastAxisReduction().Mutate(stmt); stmt = MultiLastAxisReduction().Mutate(stmt);
stmt = BroadcastCalculate(is_dynamic).Mutate(stmt); stmt = BroadcastCalculate(is_dynamic).Mutate(stmt);
if (!is_dynamic && !Equal(ori_stmt, stmt)) {
stmt = MergeLoops(stmt);
}
return stmt; return stmt;
} }
} // namespace ir } // namespace ir
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include <algorithm> #include <algorithm>
#include "emit_insn/insn_info.h" #include "emit_insn/insn_info.h"
#include "emit_insn/insn_pattern.h" #include "emit_insn/insn_pattern.h"
#include "emit_insn/insn_args_calculator.h"
namespace akg { namespace akg {
namespace ir { namespace ir {
...@@ -48,85 +48,63 @@ class TailSpliter : public IRMutator { ...@@ -48,85 +48,63 @@ class TailSpliter : public IRMutator {
if (src_info_list.empty()) { if (src_info_list.empty()) {
src_info_list = {dst_info.Copy()}; src_info_list = {dst_info.Copy()};
} }
auto get_info_list = [](const StmtStoreInfo &dst_info, const Array<StmtStoreInfo> &src_info_list) {
Array<StmtStoreInfo> res;
res.push_back(dst_info.Copy());
for (auto it : src_info_list) {
res.push_back(it.Copy());
}
return res;
};
auto info_list = get_info_list(dst_info, src_info_list);
FillEmptyVar(info_list);
auto axis_list = GetAixsList(for_info, info_list);
auto get_last_axis_it = [](const std::list<InsnAxis> &axis_list) {
for (auto it = axis_list.begin(); it != axis_list.end(); it++) {
auto stride_list = it->stride_list;
if (!(std::any_of(stride_list.begin(), stride_list.end(), [](int stride) { return stride > 1; }) ||
std::all_of(stride_list.begin(), stride_list.end(), [](int stride) { return stride == 0; }))) {
return it;
}
}
return axis_list.end();
};
auto last_axis_it = get_last_axis_it(axis_list); auto info_list = GetInfoList(dst_info, src_info_list);
if (last_axis_it == axis_list.end()) { FillEmptyVar(info_list);
return s;
}
auto last_axis = *last_axis_it;
auto last_axis_shape = last_axis.extent;
int dst_block_size = GetUbBlkSize(dst_info->dtype_); int dst_block_size = GetUbBlkSize(dst_info->dtype_);
int src_block_size = GetUbBlkSize(src_info_list[0]->dtype_); int src_block_size = GetUbBlkSize(src_info_list[0]->dtype_);
int block_size = dst_block_size > src_block_size ? dst_block_size : src_block_size; int block_size = dst_block_size < src_block_size ? dst_block_size : src_block_size;
int cast_block_size = dst_block_size > src_block_size ? dst_block_size : src_block_size;
int vec_max_len = block_size * FULL_BLOCK_NUM; int vec_max_len = block_size * FULL_BLOCK_NUM;
auto args_calculator = InsnArgsCalculator(dst_info_list, src_info_list, for_info, "");
if (last_axis_shape > vec_max_len && last_axis_shape % vec_max_len != 0) { auto vec_axis_it = args_calculator.GetVecAxisIt();
return Block::make(TailMake(s, last_axis, vec_max_len, false), TailMake(s, last_axis, vec_max_len, true)); bool cast = dst_block_size != src_block_size;
if (args_calculator.IsValid(vec_axis_it)) {
auto vec_axis = *vec_axis_it;
auto vec_axis_shape = vec_axis.extent;
if (vec_axis_shape >= vec_max_len) {
if (vec_axis_shape % vec_max_len != 0) {
return TailBlock(s, vec_axis, vec_max_len);
} }
if (last_axis_shape < vec_max_len * tail_rate_ && last_axis_shape > block_size && } else {
last_axis_shape % block_size != 0 && axis_list.size() > 1) { if (vec_axis_shape < vec_max_len * tail_rate_ && vec_axis_shape > cast_block_size &&
return Block::make(TailMake(s, last_axis, block_size, false), TailMake(s, last_axis, block_size, true)); vec_axis_shape % cast_block_size != 0 && args_calculator.axis_list_.size() > 1) {
return TailBlock(s, vec_axis, cast_block_size);
} }
} }
return IRMutator::Mutate_(op, s);
} }
if (!cast && (!args_calculator.IsValid(vec_axis_it) || vec_axis_it->extent <= cast_block_size * tail_rate_)) {
std::list<InsnAxis> GetAixsList(const StmtInfo &for_info, const Array<StmtStoreInfo> &info_list) { auto get_block_axis = [&](std::list<InsnAxis> &axis_list) {
std::list<InsnAxis> axis_list; InsnAxis block_axis;
auto GetStrideByAxis = [](const Array<Var> &vars, const Array<Expr> &strides, Var obj_var) { block_axis.is_valid = false;
int index = 0; std::vector<InsnAxis> temp_axis_set;
for (auto var_it : vars) { auto block_stride_lambda = [&](int stride) { return stride % block_size == 0 && stride / block_size <= 4; };
if (Equal(var_it, obj_var)) { for (auto axis : axis_list) {
return strides[index]; if (std::all_of(axis.stride_list.begin(), axis.stride_list.end(), block_stride_lambda) &&
axis.dst_stride != 0 && axis.extent != 0 && axis.extent > FULL_BLOCK_NUM &&
axis.extent % FULL_BLOCK_NUM != 0) {
temp_axis_set.push_back(axis);
} }
index++;
} }
return Expr(0); if (!temp_axis_set.empty()) {
}; return temp_axis_set[0];
for (auto it : for_info.ops_) {
InsnAxis axis;
auto for_stmt = it.as<For>();
CHECK(for_stmt);
axis.var = for_stmt->loop_var;
axis.extent = GetInt32Const(for_stmt->extent);
axis.min = GetInt32Const(for_stmt->min);
int index = 0;
for (auto it : info_list) {
auto stride = GetInt32Const(GetStrideByAxis(it->var_, it->strides_, axis.var));
axis.stride_list.push_back(stride);
if (index == 0) {
axis.dst_stride = stride;
} else { } else {
axis.src_stride_list.push_back(stride); return block_axis;
} }
index++; };
auto block_axis = get_block_axis(args_calculator.axis_list_);
if (block_axis.IsValid()) {
return TailBlock(s, block_axis, FULL_BLOCK_NUM);
} }
axis_list.push_back(axis);
} }
return axis_list; return s;
}
return IRMutator::Mutate_(op, s);
} }
Stmt TailBlock(const Stmt &s, const InsnAxis &tail_axis, int body_size) {
return Block::make(TailMake(s, tail_axis, body_size, false), TailMake(s, tail_axis, body_size, true));
}
Stmt TailMake(const Stmt &s, const InsnAxis &tail_axis, int body_size, bool is_tail) { Stmt TailMake(const Stmt &s, const InsnAxis &tail_axis, int body_size, bool is_tail) {
if (auto attr_stmt = s.as<AttrStmt>()) { if (auto attr_stmt = s.as<AttrStmt>()) {
return AttrStmt::make(attr_stmt->node, attr_stmt->attr_key, attr_stmt->value, return AttrStmt::make(attr_stmt->node, attr_stmt->attr_key, attr_stmt->value,
...@@ -145,7 +123,6 @@ class TailSpliter : public IRMutator { ...@@ -145,7 +123,6 @@ class TailSpliter : public IRMutator {
} }
return For::make(for_stmt->loop_var, for_stmt->min, for_stmt->extent, for_stmt->for_type, for_stmt->device_api, return For::make(for_stmt->loop_var, for_stmt->min, for_stmt->extent, for_stmt->for_type, for_stmt->device_api,
TailMake(for_stmt->body, tail_axis, body_size, is_tail)); TailMake(for_stmt->body, tail_axis, body_size, is_tail));
} }
if (s.as<Store>() && is_tail) { if (s.as<Store>() && is_tail) {
return substitute(tail_axis.var, Add::make(Expr(tail_axis.extent / body_size * body_size), tail_axis.var), s); return substitute(tail_axis.var, Add::make(Expr(tail_axis.extent / body_size * body_size), tail_axis.var), s);
...@@ -156,6 +133,20 @@ class TailSpliter : public IRMutator { ...@@ -156,6 +133,20 @@ class TailSpliter : public IRMutator {
private: private:
const float tail_rate_{0.6}; const float tail_rate_{0.6};
const std::set<std::string> include_intrin_list_ = { const std::set<std::string> include_intrin_list_ = {
// binary vec
"vec_binary_add",
"vec_binary_sub",
"vec_binary_mul",
"vec_binary_min",
"vec_binary_max",
"vec_binary_div",
"vec_binary_and",
"vec_binary_or",
"vec_binary_vmadd",
"vec_binary_vmaddrelu",
"vec_binary_vmla",
// single vec
"vec_single_fabs", "vec_single_fabs",
"vec_single_log", "vec_single_log",
"vec_single_exp", "vec_single_exp",
...@@ -165,20 +156,28 @@ class TailSpliter : public IRMutator { ...@@ -165,20 +156,28 @@ class TailSpliter : public IRMutator {
"vec_single_rsqrt", "vec_single_rsqrt",
"vec_single_relu", "vec_single_relu",
"vec_single_not", "vec_single_not",
// vector_scalar
"vec_single_muls",
"vec_single_adds",
// Mov // Mov
"broadcast", "broadcast",
"mask_broadcast",
// vector_cast // vector_cast
"vec_single_cast", "vec_single_cast",
"vec_single_floor", "vec_single_floor",
"vec_single_round", "vec_single_round",
"vec_single_ceil", "vec_single_ceil",
"vec_single_trunc", "vec_single_trunc",
// scalar case
"vector_dup",
"vmuls",
"vadds",
"vaxpy",
}; };
}; };
Stmt SplitTail(Stmt stmt) { return TailSpliter().Mutate(stmt); } Stmt SplitTail(Stmt stmt) {
auto tail_spliter = TailSpliter();
auto first_round = tail_spliter.Mutate(stmt);
auto second_round = tail_spliter.Mutate(stmt);
return second_round;
}
} // namespace ir } // namespace ir
} // namespace akg } // namespace akg
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册