提交 f5776fee 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!52 finetune pragma, refactor emitter and alignment

Merge pull request !52 from cyun/710_pragma_insn_emit_fix
...@@ -115,6 +115,9 @@ REGISTER_PASS(AnalyzeMinAlignStatic); ...@@ -115,6 +115,9 @@ REGISTER_PASS(AnalyzeMinAlignStatic);
REGISTER_PASS(AnalyzeMinAlignDynamic); REGISTER_PASS(AnalyzeMinAlignDynamic);
REGISTER_PASS(RewriteBroadcastVector); REGISTER_PASS(RewriteBroadcastVector);
REGISTER_PASS(OptimizePragma); REGISTER_PASS(OptimizePragma);
REGISTER_PASS(PackStore);
REGISTER_PASS(RecoverStore);
REGISTER_PASS(MergeLoops);
REGISTER_PASS(ExpandC0); REGISTER_PASS(ExpandC0);
REGISTER_PASS(ForEliminate); REGISTER_PASS(ForEliminate);
REGISTER_PASS(FixLoopExtent); REGISTER_PASS(FixLoopExtent);
......
...@@ -738,16 +738,19 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef> ...@@ -738,16 +738,19 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
if (global_attrs.GetBoolAttr(kDeadCodeElim, false)) { if (global_attrs.GetBoolAttr(kDeadCodeElim, false)) {
stmt = NEXT_PASS(DeadCodeElim, stmt); stmt = NEXT_PASS(DeadCodeElim, stmt);
} }
if (!is_dynamic) {
stmt = NEXT_PASS(RewriteBroadcastVector, stmt);
stmt = NEXT_PASS(OptimizePragma, stmt);
}
if (is_dynamic) { if (is_dynamic) {
stmt = NEXT_PASS(AnalyzeMinAlignDynamic, stmt, global_attrs.GetIntAttr(kEnableConvAnalyzeAlign, true), stmt = NEXT_PASS(AnalyzeMinAlignDynamic, stmt, global_attrs.GetIntAttr(kEnableConvAnalyzeAlign, true),
global_attrs.GetIntAttr(kEnableScalarAlign, false)); global_attrs.GetIntAttr(kEnableScalarAlign, false));
} else { } else {
stmt = NEXT_PASS(RewriteBroadcastVector, stmt);
stmt = NEXT_PASS(OptimizePragma, stmt);
stmt = NEXT_PASS(MergeLoops, stmt, false);
stmt = NEXT_PASS(PackStore, stmt);
stmt = NEXT_PASS(AnalyzeMinAlignStatic, stmt); stmt = NEXT_PASS(AnalyzeMinAlignStatic, stmt);
stmt = NEXT_PASS(RecoverStore, stmt);
} }
stmt = NEXT_PASS(MultiLastAxisReductions, stmt, is_dynamic); stmt = NEXT_PASS(MultiLastAxisReductions, stmt, is_dynamic);
stmt = NEXT_PASS(AutoReorder, stmt); stmt = NEXT_PASS(AutoReorder, stmt);
if (enable_multicore != 0) { if (enable_multicore != 0) {
......
...@@ -25,6 +25,9 @@ ...@@ -25,6 +25,9 @@
#include "insn_info.h" #include "insn_info.h"
#include "cce_params.h" #include "cce_params.h"
namespace akg { namespace akg {
enum SingleType {SIMD, Tensor_Scalar, Vector_Dump};
struct MutableMaskParams { struct MutableMaskParams {
Var mask_var_; Var mask_var_;
Expr loop_var_; Expr loop_var_;
...@@ -239,8 +242,11 @@ class VectorInsnBuilder : public InsnBuilder { ...@@ -239,8 +242,11 @@ class VectorInsnBuilder : public InsnBuilder {
class SingleVecInsnBuilder : public VectorInsnBuilder { class SingleVecInsnBuilder : public VectorInsnBuilder {
public: public:
SingleVecInsnBuilder(const StmtStoreInfo &dst, const StmtStoreInfo &src, const ArgInfo &args, SingleVecInsnBuilder(const StmtStoreInfo &dst, const StmtStoreInfo &src, const ArgInfo &args,
const std::string &intrin_name, const Buffer &tmp_buf = Buffer()) const std::string &intrin_name, const Expr &scalar_src = Expr(),
: VectorInsnBuilder(dst, {src}, args, intrin_name), src_info_(src_info_list_[0]), tmp_buffer_(tmp_buf) { const SingleType insn_type = SingleType::SIMD)
: VectorInsnBuilder(dst, {src}, args, intrin_name),
src_info_(src_info_list_[0]),
scalar_src_(scalar_src), insn_type_(insn_type) {
CHECK(src_info_.defined()); CHECK(src_info_.defined());
} }
~SingleVecInsnBuilder() override = default; ~SingleVecInsnBuilder() override = default;
...@@ -254,8 +260,10 @@ class SingleVecInsnBuilder : public VectorInsnBuilder { ...@@ -254,8 +260,10 @@ class SingleVecInsnBuilder : public VectorInsnBuilder {
Stmt CreateBroadcast(const VectorArgInfo &arg_info, const Var &local_var, Stmt stmt); Stmt CreateBroadcast(const VectorArgInfo &arg_info, const Var &local_var, Stmt stmt);
StmtStoreInfo src_info_; StmtStoreInfo src_info_;
Buffer tmp_buffer_;
Buffer broadcast_buffer_; Buffer broadcast_buffer_;
Expr scalar_src_;
SingleType insn_type_; // 0 simd : 1 vector_scalar : 2 vector_dup
}; };
class MultiVecInsnBuilder : public VectorInsnBuilder { class MultiVecInsnBuilder : public VectorInsnBuilder {
......
...@@ -92,9 +92,6 @@ Stmt SingleVecInsnBuilder::EmitExpandedIntrin(const VectorArgInfo &arg_info) { ...@@ -92,9 +92,6 @@ Stmt SingleVecInsnBuilder::EmitExpandedIntrin(const VectorArgInfo &arg_info) {
Expr dst_offset = dst_info_->insn_offset_; Expr dst_offset = dst_info_->insn_offset_;
Expr src_offset = src_info_->insn_offset_; Expr src_offset = src_info_->insn_offset_;
Var local_var = Var("broadcast_for_vec_local_UB", Handle());
stmt = CreateBroadcast(arg_info, local_var, stmt);
// Handle stride_m1 loop of single vector intrin, if stride_m1 > 255, it will be separated // Handle stride_m1 loop of single vector intrin, if stride_m1 > 255, it will be separated
if (dst_stride_m1 >= MAX_STRIDE_M1 || src_stride_m1 >= MAX_STRIDE_M1) { if (dst_stride_m1 >= MAX_STRIDE_M1 || src_stride_m1 >= MAX_STRIDE_M1) {
auto var = Var("repeatStrideM1Idx"); auto var = Var("repeatStrideM1Idx");
...@@ -112,14 +109,6 @@ Stmt SingleVecInsnBuilder::EmitExpandedIntrin(const VectorArgInfo &arg_info) { ...@@ -112,14 +109,6 @@ Stmt SingleVecInsnBuilder::EmitExpandedIntrin(const VectorArgInfo &arg_info) {
} }
} }
if (!dst_info_->var_.empty() && src_info_->var_.empty() && intrin_name_ != INTRIN_NAME_VECTOR_DUP) {
// need to broadcast src first
stmt = Allocate::make(local_var, src_info_->dtype_, {Expr(src_block_size * FULL_BLOCK_NUM)}, const_true(), stmt);
if (!src_info_->scope_.empty()) {
stmt = AttrStmt::make(local_var, STORAGE_SCOPE, StringImm::make(src_info_->scope_), stmt);
}
}
CHECK(stmt.defined()) << "Error: Stmt is undefined!"; CHECK(stmt.defined()) << "Error: Stmt is undefined!";
return stmt; return stmt;
...@@ -131,70 +120,36 @@ Stmt SingleVecInsnBuilder::EmitExpandedIntrin(const VectorArgInfo &arg_info) { ...@@ -131,70 +120,36 @@ Stmt SingleVecInsnBuilder::EmitExpandedIntrin(const VectorArgInfo &arg_info) {
/// \return /// \return
Stmt SingleVecInsnBuilder::EmitIntrinBody(const VectorArgInfo &arg_info, const Map<std::string, Expr> &args) { Stmt SingleVecInsnBuilder::EmitIntrinBody(const VectorArgInfo &arg_info, const Map<std::string, Expr> &args) {
Stmt body; Stmt body;
CHECK(!arg_info->src_stride_m0_list_.empty()); CHECK(!arg_info->src_stride_m0_list_.empty());
CHECK(!arg_info->src_stride_m1_list_.empty()); CHECK(!arg_info->src_stride_m1_list_.empty());
auto dst_buffer_id = GenBufferId(dst_info_);
auto src_buffer_id = GenBufferId(src_info_);
Expr repeat = args["repeat"]; Expr repeat = args["repeat"];
auto dst_buffer_id = GenBufferId(dst_info_);
Expr dst_offset = Sub::make(args["dstOffset"], arg_info->block_offset_); Expr dst_offset = Sub::make(args["dstOffset"], arg_info->block_offset_);
Expr src_offset = args["srcOffset"];
Expr src_stride_m1 = arg_info->src_stride_m1_list_[0];
auto dst = GetAccessPtr(dst_buffer_id, "w", dst_offset); auto dst = GetAccessPtr(dst_buffer_id, "w", dst_offset);
auto src = GetAccessPtr(src_buffer_id, "r", src_offset);
if (broadcast_buffer_.defined()) { Array<Expr> insn_args {};
src_stride_m1 = 0; if (insn_type_ == SingleType::Vector_Dump) {
src = GetAccessPtr(broadcast_buffer_, "r", Expr(0)); insn_args = {dst, scalar_src_, repeat};
} else {
auto src_buffer_id = GenBufferId(src_info_);
Expr src_offset = args["srcOffset"];
auto src = GetAccessPtr(src_buffer_id, "r", src_offset);
if (insn_type_ == SingleType::SIMD) {
insn_args = {dst, src, repeat};
} else if (insn_type_ == SingleType::Tensor_Scalar) {
insn_args = {dst, src, scalar_src_, repeat};
} else {
CHECK(0) << "\nUnknown insn_type_\n";
}
} }
Array<Expr> stride_args = {arg_info->dst_stride_m0_, arg_info->src_stride_m0_list_[0], arg_info->dst_stride_m1_, Array<Expr> stride_args = {arg_info->dst_stride_m0_, arg_info->src_stride_m0_list_[0], arg_info->dst_stride_m1_,
src_stride_m1}; arg_info->src_stride_m1_list_[0]};
Array<Expr> insn_args = {dst, src, repeat};
if (arg_info->scalar_.defined()) {
auto scalar = arg_info->scalar_;
if (tmp_buffer_.defined()) {
dst = GetAccessPtr(tmp_buffer_, "w", dst_offset);
}
insn_args = {dst, scalar, repeat};
if (intrin_name_ != INTRIN_NAME_VECTOR_DUP) {
Insert(insn_args, 1, src);
}
}
insn_args = MergeTwo(insn_args, stride_args); insn_args = MergeTwo(insn_args, stride_args);
body = EmitCceIntrinTemplate(Stmt(), dst.type(), insn_args, intrin_name_); body = EmitCceIntrinTemplate(Stmt(), dst.type(), insn_args, intrin_name_);
return body; return body;
} }
/// Create broadcast intrin if src is scalar
/// \param arg_info
/// \param local_var
/// \param stmt
/// \return
Stmt SingleVecInsnBuilder::CreateBroadcast(const VectorArgInfo &arg_info, const Var &local_var, Stmt stmt) {
if (!dst_info_->var_.empty() && src_info_->var_.empty() && intrin_name_ != INTRIN_NAME_VECTOR_DUP) {
// need to broadcast src first
auto src_block_size = GetUbBlkSize(src_info_->dtype_);
broadcast_buffer_ = BufferNode::make(local_var, src_info_->dtype_, {Expr(src_block_size * FULL_BLOCK_NUM)}, {},
src_info_->elem_offset_, "broadcast_for_vec_local_UB", src_info_->scope_,
src_info_->data_alignment_, 1, BufferType::kDefault);
auto broad_dst = GetAccessPtr(broadcast_buffer_, "w", 0);
Array<Expr> args = {
broad_dst, GenBufferId(src_info_).vload({Expr(0)}, src_info_->dtype_), Expr(1), Expr(1), Expr(1), Expr(0),
Expr(0)};
stmt = EmitSetVecMaskIntrin(stmt, src_info_->dtype_, GetAllMask(src_info_->dtype_));
stmt = InsertBody(stmt, EmitCceIntrinTemplate(Stmt(), src_info_->dtype_, args, INTRIN_NAME_VECTOR_DUP));
stmt = EmitSetVecMaskIntrin(stmt, dst_info_->dtype_, arg_info->vec_mask_);
}
return stmt;
}
/// if repeat-size > cce_max_repeat, then split it into loop as "Davinci ISA User Guide t6.3 (8.2.2)" mentioned /// if repeat-size > cce_max_repeat, then split it into loop as "Davinci ISA User Guide t6.3 (8.2.2)" mentioned
/// max_cce_repeat = 255, considering params are about 2 cycles, set it to be 255 // 2 = 127 /// max_cce_repeat = 255, considering params are about 2 cycles, set it to be 255 // 2 = 127
...@@ -1250,8 +1205,10 @@ Stmt EmitCceBinaryVectorToReduceLastAxis(const StmtStoreInfo &dst_info, const St ...@@ -1250,8 +1205,10 @@ Stmt EmitCceBinaryVectorToReduceLastAxis(const StmtStoreInfo &dst_info, const St
auto vec_dup_arg_info = GenReduceHelperArgInfo(vec_dup_dst_info, for_extent, scalar, "VecDup"); auto vec_dup_arg_info = GenReduceHelperArgInfo(vec_dup_dst_info, for_extent, scalar, "VecDup");
vec_dup_dst_info.GetNode()->data_ = final_var;
vec_dup_dst_info.GetNode()->name_ = final_var->name_hint;
SingleVecInsnBuilder single_vec_builder = SingleVecInsnBuilder(vec_dup_dst_info, vec_dup_dst_info, vec_dup_arg_info, SingleVecInsnBuilder single_vec_builder = SingleVecInsnBuilder(vec_dup_dst_info, vec_dup_dst_info, vec_dup_arg_info,
INTRIN_NAME_VECTOR_DUP, final_dst_buffer); INTRIN_NAME_VECTOR_DUP, scalar, SingleType::Vector_Dump);
auto insn_list = single_vec_builder.EmitIntrin(); auto insn_list = single_vec_builder.EmitIntrin();
auto stmt = std::accumulate(insn_list.begin(), insn_list.end(), Stmt(), auto stmt = std::accumulate(insn_list.begin(), insn_list.end(), Stmt(),
[](const Stmt &s0, const Stmt &s1) { return InsertBody(s0, s1); }); [](const Stmt &s0, const Stmt &s1) { return InsertBody(s0, s1); });
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include "emit_insn/insn_emitter.h" #include "insn_emitter.h"
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
...@@ -53,145 +53,68 @@ std::vector<size_t> SortIndexes(const std::vector<int> &v) { ...@@ -53,145 +53,68 @@ std::vector<size_t> SortIndexes(const std::vector<int> &v) {
/// \param intrin_name - The CCE intrin name /// \param intrin_name - The CCE intrin name
/// \param broadcast_last_axis - Tag of broadcast_last_axis mode /// \param broadcast_last_axis - Tag of broadcast_last_axis mode
/// \return Stmt of emitted CCE intrin /// \return Stmt of emitted CCE intrin
Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name, bool broadcast_last_axis = false) { Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) {
CHECK(op); CHECK(op);
Stmt result; Stmt result;
// optimization of copy_ubuf_to_ubuf
bool is_dma_opt = false; CommentManager::GetInstance().AddComment("Insn_type", "single_vector");
if (intrin_name == INTRIN_NAME_COPY_UB_TO_UB) { CommentManager::GetInstance().AddComment("Insn_name", intrin_name);
CommentManager::GetInstance().AddComment("Insn_type", "dma_copy");
CommentManager::GetInstance().AddComment("Insn_name", INTRIN_NAME_COPY_UB_TO_UB);
CommentManager::GetInstance().AddComment("Vadds_replace_copy", "enable");
intrin_name = "vadds";
is_dma_opt = true;
} else {
CommentManager::GetInstance().AddComment("Insn_type", "single_vector");
CommentManager::GetInstance().AddComment("Insn_name", intrin_name);
}
StmtInfoList dst_info_list; StmtInfoList dst_info_list;
StmtInfoList src_info_list; StmtInfoList src_info_list;
StmtStoreInfo scalar_info;
StmtInfo for_info; StmtInfo for_info;
StmtInfo if_info; StmtInfo if_info;
std::string mode = GetSingleVecComputationInfo(op, intrin_name, dst_info_list, src_info_list, if_info, for_info);
bool same_dtype = intrin_name.find("vconv_") == std::string::npos;
GetCompactComputationInfo(op, dst_info_list, src_info_list, if_info, for_info, same_dtype, true);
CHECK(!dst_info_list.empty()); CHECK(!dst_info_list.empty());
if (broadcast_last_axis) { Array<Expr> call_args;
mode = "broadcast_last_axis"; int call_cnt = 0;
// In this case, must come from binary vec, so must have two src if (intrin_name == "vector_dup" || intrin_name == "vadds" ||
CHECK(src_info_list.size() >= 2) << "Broadcast last axis mode must have at least two srcs."; intrin_name == "vmuls" || intrin_name == "vaxpy") {
if (!IsTwoItemEqual(src_info_list[0]->var_, dst_info_list[0]->var_, -1)) { auto GetCallInfo = [&intrin_name, &call_args, &call_cnt](const NodeRef &op) {
scalar_info = src_info_list[0]; if (op.as<Call>() && op.as<Call>()->name == intrin_name) {
src_info_list.Set(0, src_info_list[1]); call_args = op.as<Call>()->args;
} else if (!IsTwoItemEqual(src_info_list[1]->var_, dst_info_list[0]->var_, -1)) { call_cnt = call_cnt + 1;
scalar_info = src_info_list[1];
}
} else {
if (mode == "broadcast" && !src_info_list.empty() && dst_info_list.size() == 1) {
if (!IsTwoItemEqual(src_info_list[0]->var_, dst_info_list[0]->var_, -1)) {
mode = "broadcast_last_axis";
} }
if (src_info_list.size() > 1) { };
if (!IsTwoItemEqual(src_info_list[1]->var_, dst_info_list[0]->var_, -1)) { PostOrderVisit(op, GetCallInfo);
mode = "broadcast_last_axis"; CHECK_EQ(call_cnt, 1);
} else {
scalar_info = src_info_list[0];
src_info_list.Set(0, src_info_list[1]);
}
}
}
}
if (broadcast_last_axis) {
mode = "broadcast_last_axis";
} }
SingleType insn_type {SingleType::SIMD};
if (intrin_name == INTRIN_NAME_VECTOR_DUP) { Expr scalar_src {};
auto dst_info = dst_info_list[0]; if (intrin_name == "vector_dup") {
if (dst_info->var_.size() > 1 && insn_type = SingleType::Vector_Dump;
GetIntConst(GetItem(dst_info->strides_, -1)) == GetIntConst(GetItem(dst_info->shape_, -1)) + 1) { src_info_list = {};
// diagnoal broadcast case scalar_src = call_args[0];
return op; } else if (intrin_name == "vadds" || intrin_name == "vmuls" || intrin_name == "vaxpy") {
} insn_type = SingleType::Tensor_Scalar;
dst_info.CleanFlexVar(); src_info_list = {src_info_list[0]};
scalar_src = call_args[1];
} }
// check is single vector broadcast reduce mode exist // check is single vector broadcast reduce mode exist
SingleVecPatternGenerator generator = SingleVecPatternGenerator(dst_info_list, src_info_list, for_info, mode); SingleVecPatternGenerator generator = SingleVecPatternGenerator(dst_info_list, src_info_list, for_info);
auto params = generator.GetInsnArgs(); auto params = generator.GetInsnArgs();
dst_info_list = params.dst_info_list; dst_info_list = params.dst_info_list;
src_info_list = params.src_info_list; src_info_list = params.src_info_list;
for_info = params.for_info; for_info = params.for_info;
ArgInfo arg_info = params.arg_info; ArgInfo arg_info = params.arg_info;
CommentManager::GetInstance().AddComment("Compute_type", mode); CommentManager::GetInstance().AddComment("Compute_type", intrin_name);
CommentManager::GetInstance().AddComment("Pattern", arg_info.GetPattern()); CommentManager::GetInstance().AddComment("Pattern", arg_info.GetPattern());
if (intrin_name == "vadds" || intrin_name == "vmuls" || intrin_name == INTRIN_NAME_VECTOR_DUP) {
auto stores = GetStores(op);
auto store = stores[0].as<Store>();
auto scalar = Expr(0);
if (intrin_name == "vadds" || intrin_name == "vmuls") {
if (!dst_info_list.empty()) {
scalar = FloatImm::make(dst_info_list[0]->dtype_, 0.000000);
}
if (!dst_info_list[0]->dtype_.is_float()) {
return op;
}
if (!is_dma_opt) {
if (!scalar_info.defined()) {
auto children = GetBinaryOpExprChildren(store->value);
if (children.empty()) {
LOG(FATAL) << store->value << " is not binary op.";
}
scalar = children[1];
} else {
scalar = Load::make(scalar_info->dtype_, scalar_info->data_, scalar_info->index_, Expr(1));
}
}
} else if (intrin_name == INTRIN_NAME_VECTOR_DUP) {
if (store->value->IsInstance<Load>()) {
// scale is load
scalar =
Load::make(src_info_list[0]->dtype_, store->value.as<Load>()->buffer_var, src_info_list[0]->index_, Expr(1));
} else {
// scale is imm
scalar = store->value;
}
}
if (arg_info->body_arg_info_.defined()) {
arg_info->body_arg_info_.GetNode()->scalar_ = scalar;
}
if (arg_info->tail_arg_info_.defined()) {
arg_info->tail_arg_info_.GetNode()->scalar_ = scalar;
}
}
if (intrin_name == "vconv_deq") { if (intrin_name == "vconv_deq") {
result = InsertBody( result = InsertBody(
result, Evaluate::make(Call::make(Float(16), "set_deqscale", {FloatImm::make(Float(16), 1.0)}, Call::Extern))); result, Evaluate::make(Call::make(Float(16), "set_deqscale", {FloatImm::make(Float(16), 1.0)}, Call::Extern)));
} }
SingleVecInsnBuilder single_vec_builder = SingleVecInsnBuilder single_vec_builder =
SingleVecInsnBuilder(dst_info_list[0], src_info_list[0], arg_info, intrin_name); SingleVecInsnBuilder(dst_info_list[0], src_info_list[0], arg_info, intrin_name, scalar_src, insn_type);
auto insn_list = single_vec_builder.EmitIntrin(); auto insn_list = single_vec_builder.EmitIntrin();
auto ret = FoldInsnWithForInfo(insn_list, if_info, for_info, result);
if (intrin_name == INTRIN_NAME_VECTOR_DUP && dst_info_list[0]->var_.empty()) { return ret;
Stmt store;
auto ScanStore = [&store](const NodeRef &op) {
const auto e = op.as<Store>();
if (e != nullptr) {
store = Store::make(e->buffer_var, e->value, e->index, e->predicate);
}
};
air::ir::PostOrderVisit(op, ScanStore);
store = EmitSetVecMaskIntrin(store, dst_info_list[0]->dtype_);
insn_list = {store};
}
return FoldInsnWithForInfo(insn_list, if_info, for_info, result);
} }
/// Function to emit binary vector intrin /// Function to emit binary vector intrin
...@@ -211,11 +134,6 @@ Stmt BinaryVecEmitter(const Stmt &op, std::string intrin_name, bool enable_bisec ...@@ -211,11 +134,6 @@ Stmt BinaryVecEmitter(const Stmt &op, std::string intrin_name, bool enable_bisec
CommentManager::GetInstance().AddComment("Insn_name", intrin_name); CommentManager::GetInstance().AddComment("Insn_name", intrin_name);
switch (arg_info->arg_type_) { switch (arg_info->arg_type_) {
case ARG_VECTOR_BROADCAST_LAST_AXIS: {
CommentManager::GetInstance().CleanComments();
intrin_name += "s";
return SingleVecEmitter(op, intrin_name, true);
}
case ARG_VECTOR_REDUCTION_LAST_AXIS: { case ARG_VECTOR_REDUCTION_LAST_AXIS: {
CommentManager::GetInstance().AddComment("Compute_type", "reduce_last_axis"); CommentManager::GetInstance().AddComment("Compute_type", "reduce_last_axis");
auto dst_info = dst_info_list[0]; auto dst_info = dst_info_list[0];
...@@ -928,83 +846,8 @@ Stmt DmaMovEmitter(const Stmt &op, bool enable_cover_protect) { ...@@ -928,83 +846,8 @@ Stmt DmaMovEmitter(const Stmt &op, bool enable_cover_protect) {
StmtInfo for_info; StmtInfo for_info;
GetDmaComputationInfo(op, dst_info_list, src_info_list, if_info, for_info, dma_mode, intrin_name); GetDmaComputationInfo(op, dst_info_list, src_info_list, if_info, for_info, dma_mode, intrin_name);
auto check_alignment = [](const Expr &align, const Array<Expr> &shape) {
if (GetIntConst(align) == 1 || shape.size() == 1u) {
return true;
}
if (shape.empty()) {
return false;
}
Expr sz = 1;
for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
sz = sz * shape[i];
if (GetIntConst(align) == GetIntConst(sz)) {
return true;
}
}
return false;
};
const auto &dst_info = dst_info_list[0]; const auto &dst_info = dst_info_list[0];
const auto &src_info = src_info_list[0]; const auto &src_info = src_info_list[0];
int block_size = GetUbBlkSize(dst_info->dtype_);
// check scalar to scalar
// check if dst is considered as scalar
// check if src is considered as scalar
bool is_broadcast =
(dst_info->var_.empty() || (!dst_info->strides_.empty() && GetIntConst(GetItem(dst_info->strides_, -1)) != 1)) &&
(src_info->var_.empty() || (!src_info->strides_.empty() && GetIntConst(GetItem(src_info->strides_, -1)) != 1));
// check vector to vector, but in scalar dma mode
bool last_dim_equal = !dst_info->var_.empty() && !src_info->var_.empty() && !dst_info->strides_.empty() &&
!src_info->strides_.empty() &&
GetItem(dst_info->var_, -1).get() == GetItem(src_info->var_, -1).get() &&
GetIntConst(GetItem(dst_info->strides_, -1)) != GetIntConst(GetItem(src_info->strides_, -1));
bool broadcast_scalar = intrin_name == "broadcast" && is_broadcast;
bool ubuf_scalar = intrin_name == INTRIN_NAME_COPY_UB_TO_UB && (is_broadcast || last_dim_equal);
if (broadcast_scalar || ubuf_scalar) {
int shape1 = GetInt32Const(GetItem(dst_info->shape_, -1));
int stride1 = GetInt32Const(GetItem(dst_info->strides_, -1));
if (ubuf_scalar && shape1 < block_size && stride1 == block_size &&
IsTwoItemEqual(dst_info->strides_, src_info->strides_, -1, true) && src_info->dtype_.bits() != 64) {
// if last dim small than blocksize, then use vadds
return SingleVecEmitter(op, intrin_name);
}
CommentManager::GetInstance().AddComment("Insn_type", "dma_copy");
CommentManager::GetInstance().AddComment("Insn_name", "scalar");
if (src_info->var_.empty() && dst_info->var_.empty()) {
return op;
} else {
// check align
if (!check_alignment(dst_info->data_alignment_, dst_info->shape_)) {
return op;
}
Stmt base_stmt = EmitScalarDmaIntrinTemplate(op, src_info, dst_info);
return GenIfAndFor(base_stmt, if_info, for_info, false);
}
}
if (intrin_name == "broadcast") {
return SingleVecEmitter(op, INTRIN_NAME_VECTOR_DUP);
} else if (intrin_name == INTRIN_NAME_COPY_UB_TO_UB) {
// Use vadds to optimize dma copy
if (if_info.vars_.empty() && dst_info->dtype_.is_float() && src_info->dtype_.is_float()) {
if ((dst_info->dtype_.bits() == 32 && src_info->dtype_.bits() == 32) ||
(dst_info->dtype_.bits() == 16 && src_info->dtype_.bits() == 16)) {
int repeat_len = block_size * FULL_BLOCK_NUM;
CHECK_NE(block_size, 0);
int shape1 = GetInt32Const(GetItem(dst_info->shape_, -1));
if ((shape1 >= repeat_len / 2 && shape1 <= repeat_len) ||
(dst_info->shape_.size() >= 3 && shape1 <= block_size) ||
(dst_info->shape_.size() >= 2 && shape1 % block_size == 0)) {
// if last dim shape is too small, there is no need to opt
return SingleVecEmitter(op, intrin_name);
}
}
}
}
CommentManager::GetInstance().AddComment("Insn_type", "dma_copy"); CommentManager::GetInstance().AddComment("Insn_type", "dma_copy");
...@@ -1014,31 +857,10 @@ Stmt DmaMovEmitter(const Stmt &op, bool enable_cover_protect) { ...@@ -1014,31 +857,10 @@ Stmt DmaMovEmitter(const Stmt &op, bool enable_cover_protect) {
Map<std::string, Expr> ub_copy_post; Map<std::string, Expr> ub_copy_post;
auto arg_info_map = auto arg_info_map =
GetDmaCopyInsnArgs(intrin_name, dst_info_list, src_info_list, for_info, ub_copy_pre, ub_copy_post); GetDmaCopyInsnArgs(intrin_name, dst_info_list, src_info_list, for_info, ub_copy_pre, ub_copy_post);
if (intrin_name == "vtranspose_scalar") { DmaInsnBuilder dma_builder =
base_stmt = EmitScalarDmaIntrinTemplate(op, src_info, dst_info); DmaInsnBuilder(dst_info, src_info, intrin_name, arg_info_map, false, false, enable_cover_protect);
CommentManager::GetInstance().AddComment("Insn_name", "scalar"); base_stmt = dma_builder.EmitSingleIntrin();
} else if (intrin_name == "vtranspose") { CommentManager::GetInstance().AddComment("Insn_name", intrin_name);
Array<Expr> args = {arg_info_map["loop_width"], arg_info_map["loop_height"], arg_info_map["shape_width"]};
Array<Expr> pre_ub_copy_args;
if (!ub_copy_pre.empty()) {
pre_ub_copy_args = Array<Expr>(
{ub_copy_pre["nBurst"], ub_copy_pre["lenBurst"], ub_copy_pre["srcStride"], ub_copy_pre["dstStride"]});
}
Array<Expr> post_ub_copy_args;
if (!ub_copy_post.empty()) {
post_ub_copy_args = Array<Expr>(
{ub_copy_post["nBurst"], ub_copy_post["lenBurst"], ub_copy_post["srcStride"], ub_copy_post["dstStride"]});
}
TransposeInsnBuilder builder =
TransposeInsnBuilder(dst_info, src_info, args, pre_ub_copy_args, post_ub_copy_args);
base_stmt = builder.EmitSingleIntrin();
CommentManager::GetInstance().AddComment("Insn_name", intrin_name);
} else {
DmaInsnBuilder dma_builder =
DmaInsnBuilder(dst_info, src_info, intrin_name, arg_info_map, false, false, enable_cover_protect);
base_stmt = dma_builder.EmitSingleIntrin();
CommentManager::GetInstance().AddComment("Insn_name", intrin_name);
}
} else if (dma_mode == "cce_load") { } else if (dma_mode == "cce_load") {
auto arg_info_map = GetDmaLoad2DInsnArgs(intrin_name, dst_info_list, src_info_list, for_info); auto arg_info_map = GetDmaLoad2DInsnArgs(intrin_name, dst_info_list, src_info_list, for_info);
DmaInsnBuilder builder = DmaInsnBuilder(dst_info, src_info, intrin_name, arg_info_map, true); DmaInsnBuilder builder = DmaInsnBuilder(dst_info, src_info, intrin_name, arg_info_map, true);
...@@ -1104,6 +926,19 @@ Stmt DmaAtomicAddEmitter(const Stmt &op) { ...@@ -1104,6 +926,19 @@ Stmt DmaAtomicAddEmitter(const Stmt &op) {
return stmt; return stmt;
} }
Stmt VTransposeEmitter(const Stmt &op) {
StmtInfoList dst_info_list;
StmtInfoList src_info_list;
StmtInfo for_info;
StmtInfo if_info;
GetCompactComputationInfo(op, dst_info_list, src_info_list, if_info, for_info, true, true);
auto dst_buffer_id = GenBufferId(dst_info_list[0]);
auto src_buffer_id = GenBufferId(src_info_list[0]);
auto dst = GetAccessPtr(dst_buffer_id, "w", 0);
auto src = GetAccessPtr(src_buffer_id, "r", 0);
return Evaluate::make(Call::make(Float(16), "vtranspose", {dst, src}, Call::Extern));
}
/// Function to emit dropout intrin /// Function to emit dropout intrin
/// \param op - The input stmt to be emitted as intrin /// \param op - The input stmt to be emitted as intrin
/// \return Stmt of emitted CCE intrin /// \return Stmt of emitted CCE intrin
...@@ -1913,97 +1748,6 @@ Stmt ReduceCombineEmitter(const Stmt &op, bool enable_bisect) { ...@@ -1913,97 +1748,6 @@ Stmt ReduceCombineEmitter(const Stmt &op, bool enable_bisect) {
Stmt InsnEmit(std::string insn_name, const Stmt &op, bool enable_bisect, bool enable_cover_protect, int comment_level) { Stmt InsnEmit(std::string insn_name, const Stmt &op, bool enable_bisect, bool enable_cover_protect, int comment_level) {
CHECK(op.defined()); CHECK(op.defined());
static const std::map<std::string, std::string> ReplaceAttrPragmaMap = {
// vector binary
{"binary_vcadd", "vec_binary_add"},
{"vaxpy", "vec_binary_axpy"},
// vector single
{"vec_single_fabs", "vec_single_abs"},
{"broadcast", "vec_broadcast"},
// cube
{"mad", "cube_mad"},
{"ub2gm", "cube_ub2gm"},
{"im2col", "cube_img2col"},
// special attrs
{"vec_binary_proposal_sort", "vec_proposal_sort"},
{"vec_binary_topk_sort", "vec_topk_sort"},
{"vec_binary_dropout", "vec_dropout"},
{"vec_binary_fargmax", "vec_argmax"},
{"vec_binary_fargmin", "vec_argmin"},
{"vec_binary_iou", "vec_iou"},
{"vec_binary_nms", "vec_nms"},
{"mask_broadcast", "vec_broadcast"},
};
static const std::map<std::string, std::string> BinaryVecInsnMap = {
// vadd.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vadd.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vadd.f32 support target:mini_v100 cloud_v100
// vadd contains two situations:
// 1. normal elewise vector add
// - all src[i].shape = dst.shape
// 2. reductive vector add
// - exist src[i].shape != dst.shape
{"vec_binary_add", "vadd"},
// vsub.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vsub.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vsub.f32 support target:mini_v100 cloud_v100
{"vec_binary_sub", "vsub"},
// vmul.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmul.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmul.f32 support target:mini_v100 cloud_v100
{"vec_binary_mul", "vmul"},
// vmin.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmin.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmin.f32 support target:mini_v100 cloud_v100
{"vec_binary_min", "vmin"},
// vmax.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmax.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmax.f32 support target:mini_v100 cloud_v100
{"vec_binary_max", "vmax"},
{"vec_binary_div", "vdiv"},
{"vec_binary_and", "vand"},
{"vec_binary_bitwise_and", "vand"},
{"vec_binary_or", "vor"},
{"vec_binary_bitwise_or", "vor"},
{"vec_binary_vmadd", "vmadd"},
{"vec_binary_vmaddrelu", "vmaddrelu"},
{"vec_binary_vmla", "vmla"}};
static const std::map<std::string, std::string> SingleVecInsnMap = {
// vmuls.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmuls.f32 supporttarget:mini_v100 cloud_v100
{"vec_single_muls", "vmuls"},
// vadds.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vadds.f32 support target:mini_v100 cloud_v100
{"vec_single_adds", "vadds"},
// vrelu.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
{"vec_single_relu", "vrelu"},
// vabs.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vabs.f32 support target:mini_v100 cloud_v100
{"vec_single_abs", "vabs"},
// vln.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vln.f32 support target:cloud_v100
{"vec_single_log", "vln"},
// vexp.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vexp.f32 support target:cloud_v100
{"vec_single_exp", "vexp"},
// vrec.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vrec.f32 support target:mini_v100 cloud_v100
{"vec_single_rec", "vrec"},
// vnot support target:mini_v100 tiny_v100 lite_v100 cloud_v100
{"vec_single_not", "vnot"},
{"vec_single_bitwise_not", "vnot"},
// vsqrt support target:cloud_v100
{"vec_single_sqrt", "vsqrt"},
{"vec_single_rsqrt", "vrsqrt"},
{"vec_broadcast", "vector_dup"}};
static const std::map<std::string, std::string> SingleCastInsnMap = {
{"vec_single_floor", "f"}, {"vec_single_round", "r"}, {"vec_single_ceil", "c"}, {"vec_single_trunc", "z"}};
static const std::set<std::string> ReturnOpInsnSet = {"scalar_dma", "scatter", "vec_binary_select_loop_var"};
static const std::map<std::string, std::function<Stmt(const Stmt &)>> InsnFunctorMap = { static const std::map<std::string, std::function<Stmt(const Stmt &)>> InsnFunctorMap = {
{"dma_atomic_add", DmaAtomicAddEmitter}, {"dma_atomic_add", DmaAtomicAddEmitter},
{"vec_single_cast", SingleCastEmitter}, {"vec_single_cast", SingleCastEmitter},
...@@ -2017,9 +1761,9 @@ Stmt InsnEmit(std::string insn_name, const Stmt &op, bool enable_bisect, bool en ...@@ -2017,9 +1761,9 @@ Stmt InsnEmit(std::string insn_name, const Stmt &op, bool enable_bisect, bool en
{"vec_dropout", BinaryDropoutEmitter}, {"vec_dropout", BinaryDropoutEmitter},
{"cube_mad", MadEmitter}, {"cube_mad", MadEmitter},
{"vec_select_scalar", SelectWithScalarEmitter}, {"vec_select_scalar", SelectWithScalarEmitter},
{"vec_binary_axpy", VaxpyEmitter},
{"opt_broadcast", MultiMaskEmitter}, {"opt_broadcast", MultiMaskEmitter},
{"vec_single_four2five_nchw", VnchwconvEmitter}}; {"vec_single_four2five_nchw", VnchwconvEmitter},
{"vtranspose", VTransposeEmitter}};
if (ReplaceAttrPragmaMap.count(insn_name) != 0) { if (ReplaceAttrPragmaMap.count(insn_name) != 0) {
insn_name = ReplaceAttrPragmaMap.find(insn_name)->second; insn_name = ReplaceAttrPragmaMap.find(insn_name)->second;
......
...@@ -30,6 +30,100 @@ ...@@ -30,6 +30,100 @@
namespace akg { namespace akg {
namespace ir { namespace ir {
// Translation table: legacy / front-end pragma names -> the canonical
// instruction pragma names understood by the emitter.
static const std::map<std::string, std::string> ReplaceAttrPragmaMap = {
  // binary vector ops
  {"binary_vcadd", "vec_binary_add"},
  // unary vector ops
  {"vec_single_fabs", "vec_single_abs"},
  {"broadcast", "vec_broadcast"},
  {"mask_broadcast", "vec_broadcast"},
  // cube-unit ops
  {"mad", "cube_mad"},
  {"ub2gm", "cube_ub2gm"},
  {"im2col", "cube_img2col"},
  // pragmas that carry special attributes
  {"vec_binary_proposal_sort", "vec_proposal_sort"},
  {"vec_binary_topk_sort", "vec_topk_sort"},
  {"vec_binary_dropout", "vec_dropout"},
  {"vec_binary_fargmax", "vec_argmax"},
  {"vec_binary_fargmin", "vec_argmin"},
  {"vec_binary_iou", "vec_iou"},
  {"vec_binary_nms", "vec_nms"},
};
// Binary vector pragma -> CCE intrinsic mnemonic.
// Target support per the original tables:
//   vadd/vsub/vmul/vmin/vmax: f16 and s32 on mini/tiny/lite/cloud_v100,
//                             f32 on mini_v100 and cloud_v100.
// Note on "vec_binary_add" (vadd): it covers both
//   1. plain element-wise add, where every src[i].shape equals dst.shape, and
//   2. reductive add, where some src[i].shape differs from dst.shape.
static const std::map<std::string, std::string> BinaryVecInsnMap = {
  {"vec_binary_add", "vadd"},
  {"vec_binary_sub", "vsub"},
  {"vec_binary_mul", "vmul"},
  {"vec_binary_min", "vmin"},
  {"vec_binary_max", "vmax"},
  {"vec_binary_div", "vdiv"},
  {"vec_binary_and", "vand"},
  {"vec_binary_bitwise_and", "vand"},
  {"vec_binary_or", "vor"},
  {"vec_binary_bitwise_or", "vor"},
  {"vec_binary_vmadd", "vmadd"},
  {"vec_binary_vmaddrelu", "vmaddrelu"},
  {"vec_binary_vmla", "vmla"},
};
// Unary vector pragma -> CCE intrinsic mnemonic.
// Target support per the original tables:
//   vmuls/vadds/vabs/vrec: f16 on mini/tiny/lite/cloud_v100, f32 on mini/cloud_v100
//   vrelu/vnot:            f16 (vrelu) / all listed v100 targets
//   vln/vexp:              f16 on mini/tiny/lite/cloud_v100, f32 on cloud_v100
//   vsqrt:                 cloud_v100 only
// The identity entries (vaxpy, vadds, vmuls, vector_dup) let pragmas that
// already carry the intrinsic name pass through unchanged.
static const std::map<std::string, std::string> SingleVecInsnMap = {
  {"vec_single_muls", "vmuls"},
  {"vec_single_adds", "vadds"},
  {"vec_single_relu", "vrelu"},
  {"vec_single_abs", "vabs"},
  {"vec_single_log", "vln"},
  {"vec_single_exp", "vexp"},
  {"vec_single_rec", "vrec"},
  {"vec_single_not", "vnot"},
  {"vec_single_bitwise_not", "vnot"},
  {"vec_single_sqrt", "vsqrt"},
  {"vec_single_rsqrt", "vrsqrt"},
  {"vec_broadcast", "vector_dup"},
  {"vaxpy", "vaxpy"},
  {"vadds", "vadds"},
  {"vmuls", "vmuls"},
  {"vector_dup", "vector_dup"},
};
// Cast pragma -> one-letter CCE rounding-mode suffix:
// floor -> "f", round -> "r", ceil -> "c", trunc -> "z".
static const std::map<std::string, std::string> SingleCastInsnMap = {
  {"vec_single_floor", "f"},
  {"vec_single_round", "r"},
  {"vec_single_ceil", "c"},
  {"vec_single_trunc", "z"},
};
// Pragmas whose statements the emitter hands back unchanged.
static const std::set<std::string> ReturnOpInsnSet = {
  "scalar_calc",
  "scalar_dma",
  "scatter",
  "vec_binary_select_loop_var",
};
Stmt EmitInsnWithDynamicShapes(const Stmt &s, const Map<Tensor, Buffer> &extern_buffer); Stmt EmitInsnWithDynamicShapes(const Stmt &s, const Map<Tensor, Buffer> &extern_buffer);
......
...@@ -935,7 +935,7 @@ void GetCompactComputationInfo(const Stmt &stmt, StmtInfoList &dst_info_list, St ...@@ -935,7 +935,7 @@ void GetCompactComputationInfo(const Stmt &stmt, StmtInfoList &dst_info_list, St
/// \param if_info - The if-condition as input /// \param if_info - The if-condition as input
/// \param for_info - The for-loop info to be modified /// \param for_info - The for-loop info to be modified
void CompactComputationInfoList(StmtInfoList &dst_info_list, StmtInfoList &src_info_list, const StmtInfo &if_info, void CompactComputationInfoList(StmtInfoList &dst_info_list, StmtInfoList &src_info_list, const StmtInfo &if_info,
StmtInfo &for_info) { StmtInfo &for_info) {
auto MergeTwoVar = [](const Var &keep_var, const Var &delete_var, StmtInfoList &dst_info_list, auto MergeTwoVar = [](const Var &keep_var, const Var &delete_var, StmtInfoList &dst_info_list,
StmtInfoList &src_info_list, StmtInfo &for_info) { StmtInfoList &src_info_list, StmtInfo &for_info) {
for (auto info : dst_info_list) { for (auto info : dst_info_list) {
...@@ -1059,8 +1059,7 @@ void CompactComputationInfoList(StmtInfoList &dst_info_list, StmtInfoList &src_i ...@@ -1059,8 +1059,7 @@ void CompactComputationInfoList(StmtInfoList &dst_info_list, StmtInfoList &src_i
bool find_merge = false; bool find_merge = false;
for (size_t i = 0; (i < var_cnt - 1) && (!find_merge); i++) { for (size_t i = 0; (i < var_cnt - 1) && (!find_merge); i++) {
for (size_t j = i + 1; j < var_cnt; j++) { for (size_t j = i + 1; j < var_cnt; j++) {
if (CanMergeTwoVar(for_info.vars_[i], for_info.vars_[j], dst_info_list, src_info_list, if (CanMergeTwoVar(for_info.vars_[i], for_info.vars_[j], dst_info_list, src_info_list, for_info)) {
for_info)) {
find_merge = true; find_merge = true;
break; break;
} }
...@@ -1075,7 +1074,6 @@ void CompactComputationInfoList(StmtInfoList &dst_info_list, StmtInfoList &src_i ...@@ -1075,7 +1074,6 @@ void CompactComputationInfoList(StmtInfoList &dst_info_list, StmtInfoList &src_i
} }
} }
/// A helper function for single dst_info's compact /// A helper function for single dst_info's compact
/// \param dst_info /// \param dst_info
/// \param src_info_list /// \param src_info_list
...@@ -1357,6 +1355,43 @@ int GetVectorizedVarPosition(const Expr &index, Array<Var> &loop_vars) { ...@@ -1357,6 +1355,43 @@ int GetVectorizedVarPosition(const Expr &index, Array<Var> &loop_vars) {
return pos; return pos;
} }
std::string GetOpType(const Expr &value) {
if (value.as<Add>()) {
return value.as<Add>()->_type_key;
}
if (value.as<Sub>()) {
return value.as<Sub>()->_type_key;
}
if (value.as<Mul>()) {
return value.as<Mul>()->_type_key;
}
if (value.as<Div>()) {
return value.as<Div>()->_type_key;
}
if (value.as<Mod>()) {
return value.as<Mod>()->_type_key;
}
if (value.as<FloorDiv>()) {
return value.as<FloorDiv>()->_type_key;
}
if (value.as<FloorMod>()) {
return value.as<FloorMod>()->_type_key;
}
if (value.as<Min>()) {
return value.as<Min>()->_type_key;
}
if (value.as<Max>()) {
return value.as<Max>()->_type_key;
}
if (value.as<Call>()) {
return value.as<Call>()->name;
}
if (value.as<Load>() || value.as<IntImm>() || value.as<FloatImm>()) {
return "DMACopy";
}
return "undefined";
}
/// TVM Function Register, enable python code to call these cpp function. /// TVM Function Register, enable python code to call these cpp function.
TVM_REGISTER_API("cce_util.GetCceAxis").set_body([](TVMArgs args, TVMRetValue *ret) { *ret = GetCceAxis(); }); TVM_REGISTER_API("cce_util.GetCceAxis").set_body([](TVMArgs args, TVMRetValue *ret) { *ret = GetCceAxis(); });
......
...@@ -49,13 +49,7 @@ enum ArgType { ...@@ -49,13 +49,7 @@ enum ArgType {
ARG_NOT_DEFINE ARG_NOT_DEFINE
}; };
enum PatternType { enum PatternType { PATTERN_3D = 1, PATTERN_PARTIAL_3D, PATTERN_2D, PATTERN_2D_BLOCK, PATTERN_1D };
PATTERN_3D = 1,
PATTERN_PARTIAL_3D,
PATTERN_2D,
PATTERN_2D_BLOCK,
PATTERN_1D
};
class StmtStoreInfoNode : public Node { class StmtStoreInfoNode : public Node {
public: public:
...@@ -98,13 +92,9 @@ class StmtStoreInfo : public NodeRef { ...@@ -98,13 +92,9 @@ class StmtStoreInfo : public NodeRef {
explicit StmtStoreInfo(const ObjectPtr<Object> &n) : NodeRef(n), node_(n) {} explicit StmtStoreInfo(const ObjectPtr<Object> &n) : NodeRef(n), node_(n) {}
~StmtStoreInfo() = default; ~StmtStoreInfo() = default;
inline StmtStoreInfoNode *GetNode() const { inline StmtStoreInfoNode *GetNode() const { return static_cast<StmtStoreInfoNode *>(node_.get()); }
return static_cast<StmtStoreInfoNode *>(node_.get());
}
inline const StmtStoreInfoNode *operator->() const { inline const StmtStoreInfoNode *operator->() const { return static_cast<const StmtStoreInfoNode *>(node_.get()); }
return static_cast<const StmtStoreInfoNode *>(node_.get());
}
void CleanFlexVar(); void CleanFlexVar();
...@@ -188,13 +178,9 @@ class VectorArgInfo : public NodeRef { ...@@ -188,13 +178,9 @@ class VectorArgInfo : public NodeRef {
explicit VectorArgInfo(const ObjectPtr<Object> &n) : NodeRef(n), node_(n) {} explicit VectorArgInfo(const ObjectPtr<Object> &n) : NodeRef(n), node_(n) {}
~VectorArgInfo() = default; ~VectorArgInfo() = default;
inline VectorArgInfoNode *GetNode() const { inline VectorArgInfoNode *GetNode() const { return static_cast<VectorArgInfoNode *>(node_.get()); }
return static_cast<VectorArgInfoNode *>(node_.get());
}
inline const VectorArgInfoNode *operator->() const { inline const VectorArgInfoNode *operator->() const { return static_cast<const VectorArgInfoNode *>(node_.get()); }
return static_cast<const VectorArgInfoNode *>(node_.get());
}
void Print() const { void Print() const {
LOG(DEBUG) << "[ body_num: " << GetNode()->body_num_ << ", body_offset: " << GetNode()->body_offset_ LOG(DEBUG) << "[ body_num: " << GetNode()->body_num_ << ", body_offset: " << GetNode()->body_offset_
...@@ -235,13 +221,9 @@ class ArgInfo : public NodeRef { ...@@ -235,13 +221,9 @@ class ArgInfo : public NodeRef {
explicit ArgInfo(const ObjectPtr<Object> &n) : NodeRef(n), node_(n) {} explicit ArgInfo(const ObjectPtr<Object> &n) : NodeRef(n), node_(n) {}
~ArgInfo() = default; ~ArgInfo() = default;
inline ArgInfoNode *GetNode() const { inline ArgInfoNode *GetNode() const { return static_cast<ArgInfoNode *>(node_.get()); }
return static_cast<ArgInfoNode *>(node_.get());
}
inline const ArgInfoNode *operator->() const { inline const ArgInfoNode *operator->() const { return static_cast<const ArgInfoNode *>(node_.get()); }
return static_cast<const ArgInfoNode *>(node_.get());
}
inline std::string GetPattern() const { inline std::string GetPattern() const {
switch (GetNode()->pattern_) { switch (GetNode()->pattern_) {
...@@ -373,6 +355,8 @@ bool IsBisectionReduction(const StmtInfoList &dst_info_list, const StmtInfoList ...@@ -373,6 +355,8 @@ bool IsBisectionReduction(const StmtInfoList &dst_info_list, const StmtInfoList
bool HasVars(const Expr &index, const Var &vec_var); bool HasVars(const Expr &index, const Var &vec_var);
int GetVectorizedVarPosition(const Expr &index, Array<Var> &loop_vars); int GetVectorizedVarPosition(const Expr &index, Array<Var> &loop_vars);
std::string GetOpType(const Expr &value);
} // namespace akg } // namespace akg
namespace air { namespace air {
......
...@@ -77,7 +77,7 @@ class PatternGenerator { ...@@ -77,7 +77,7 @@ class PatternGenerator {
class SingleVecPatternGenerator : public PatternGenerator { class SingleVecPatternGenerator : public PatternGenerator {
public: public:
SingleVecPatternGenerator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list, SingleVecPatternGenerator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list,
const StmtInfo &for_info, const std::string &mode) const StmtInfo &for_info, const std::string &mode = "elewise")
: PatternGenerator(dst_info_list, for_info), : PatternGenerator(dst_info_list, for_info),
arg_info(ArgInfo(make_node<ArgInfoNode>())), arg_info(ArgInfo(make_node<ArgInfoNode>())),
body_args(VectorArgInfo()), body_args(VectorArgInfo()),
......
...@@ -33,9 +33,11 @@ ...@@ -33,9 +33,11 @@
#include "insn_info.h" #include "insn_info.h"
#include "insn_pattern.h" #include "insn_pattern.h"
#include "insn_emitter.h" #include "insn_emitter.h"
#include "ir_transform.h"
namespace akg { namespace akg {
namespace ir { namespace ir {
Expr GetVarCoefExpr(const Expr &index, const Var &loop_var) { Expr GetVarCoefExpr(const Expr &index, const Var &loop_var) {
Expr ret = Expr(); Expr ret = Expr();
Array<Expr> coefs = air::arith::DetectLinearEquation(index, {loop_var}); Array<Expr> coefs = air::arith::DetectLinearEquation(index, {loop_var});
...@@ -203,7 +205,7 @@ class HasScalarVarValue : public IRVisitor { ...@@ -203,7 +205,7 @@ class HasScalarVarValue : public IRVisitor {
class AdjustPragma : public IRMutator { class AdjustPragma : public IRMutator {
public: public:
Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
if (air::ir::attr::IsPragmaKey(op->attr_key) && op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>()) { if (op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>()) {
is_candidate_ = true; is_candidate_ = true;
loop_vars_ = {}; loop_vars_ = {};
loop_extends_ = {}; loop_extends_ = {};
...@@ -295,7 +297,7 @@ class AdjustPragma : public IRMutator { ...@@ -295,7 +297,7 @@ class AdjustPragma : public IRMutator {
Array<Expr> srcs = call_ptr->args; Array<Expr> srcs = call_ptr->args;
CHECK_EQ(srcs.size(), 2); CHECK_EQ(srcs.size(), 2);
is_argmax_min_ = true; is_argmax_min_ = true;
reduce_type_ = (op->value.as<Call>()->name == "fargmin") ? "arg_min" : "arg_max"; reduce_type_ = (op->value.as<Call>()->name == "fargmin") ? "reduce_fargmin" : "reduce_fargmax";
return Store::make(op->buffer_var, Call::make(call_ptr->type, reduce_type_, {srcs[1]}, Call::CallType::Extern), return Store::make(op->buffer_var, Call::make(call_ptr->type, reduce_type_, {srcs[1]}, Call::CallType::Extern),
op->index, op->predicate); op->index, op->predicate);
} else if ((op->value.as<FloatImm>() || op->value.as<IntImm>() || op->value.as<UIntImm>()) && } else if ((op->value.as<FloatImm>() || op->value.as<IntImm>() || op->value.as<UIntImm>()) &&
...@@ -484,353 +486,6 @@ class AdjustPragma : public IRMutator { ...@@ -484,353 +486,6 @@ class AdjustPragma : public IRMutator {
Array<Var> transpose_vars_; Array<Var> transpose_vars_;
}; };
/// Detects UB-to-UB "dma_copy" regions whose dst/src indices swap two loop
/// axes (a transpose). A single TransAxisLen x TransAxisLen tile is left
/// unchanged; larger transposes are tiled into copy-in / transpose /
/// copy-out stages through two temporary UB buffers.
class TransposeTransform : public IRMutator {
 public:
  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    // Only "pragma_emit_insn" = "dma_copy" regions are transpose candidates.
    if (air::ir::attr::IsPragmaKey(op->attr_key) && op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>() &&
        op->value.as<StringImm>()->value == "dma_copy") {
      pre_transpose_buffer = Var("srcTranspose_local_UB");
      post_transpose_buffer = Var("dstTranspose_local_UB");
      // Reset per-region state before walking the loop nest.
      loop_vars_ = {};
      loop_extends_ = {};
      is_candidate_ = true;
      is_block_transpose_ = false;
      auto body = this->Mutate(op->body);
      is_candidate_ = false;
      if (is_block_transpose_) {
        // The Store visitor generated a tiled transpose: wrap the body with
        // the two temporary UB buffers (Allocate + storage_scope attrs).
        is_block_transpose_ = false;
        auto allocate_pre_buffer = Allocate::make(pre_transpose_buffer, t_type, {TransTotalSize}, const_true(1), body);
        auto attr_pre_buffer =
          AttrStmt::make(pre_transpose_buffer, "storage_scope", Expr("local.UB"), allocate_pre_buffer);
        auto allocate_post_buffer =
          Allocate::make(post_transpose_buffer, t_type, {TransTotalSize}, const_true(1), attr_pre_buffer);
        auto attr_post_buffer =
          AttrStmt::make(post_transpose_buffer, "storage_scope", Expr("local.UB"), allocate_post_buffer);
        return attr_post_buffer;
      } else {
        // Not a block transpose: keep the pragma, use the (possibly) mutated body.
        return AttrStmt::make(op->node, op->attr_key, op->value, body);
      }
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }

  Stmt Mutate_(const For *op, const Stmt &s) final {
    if (is_candidate_) {
      // Record the loop nest so the Store visitor can locate the two
      // transposed axes by position.
      loop_vars_.push_back(op->loop_var);
      loop_extends_.push_back(op->extent);
      Stmt body = this->Mutate(op->body);
      if (is_block_transpose_ && IsInArray(trans_vars_, op->loop_var)) {
        // The transposed axes were re-materialized inside the generated
        // tiled code; drop the original loop.
        return body;
      } else {
        return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body);
      }
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const Store *op, const Stmt &s) final {
    if (is_candidate_) {
      auto value = op->value;
      // Look through a cast to reach the underlying load.
      if (auto cast = op->value.as<Cast>()) {
        value = cast->value;
      }
      CHECK(value.as<Load>());
      auto src_ptr = value.as<Load>();
      // Only UB-to-UB copies can be turned into transposes.
      if (GetBufferType(op->buffer_var) == SCOPE_UBUF && GetBufferType(src_ptr->buffer_var) == SCOPE_UBUF) {
        // Vectorized (stride-1) axis of dst and src, as positions in loop_vars_.
        int dst_pos = GetVectorizedVarPosition(op->index, loop_vars_);
        int src_pos = GetVectorizedVarPosition(src_ptr->index, loop_vars_);
        // Transpose pattern: dst and src vectorize different axes, the dst
        // axis extent is a multiple of TransAxisLen, and the src axis strides
        // the dst index by exactly that extent.
        if (dst_pos != -1 && src_pos != -1 && dst_pos != src_pos &&
            floormod(loop_extends_[dst_pos], TransAxisLen).as<IntImm>() &&
            floormod(loop_extends_[dst_pos], TransAxisLen).as<IntImm>()->value == 0 &&
            Equal(GetVarCoefExpr(op->index, loop_vars_[src_pos]), loop_extends_[dst_pos])) {
          if (loop_extends_[dst_pos].as<IntImm>() && loop_extends_[dst_pos].as<IntImm>()->value == TransAxisLen &&
              loop_extends_[src_pos].as<IntImm>() && loop_extends_[src_pos].as<IntImm>()->value == TransAxisLen) {
            // Exactly one TransAxisLen x TransAxisLen tile: already directly
            // emittable; leave the statement alone.
            return s;
          } else {
            // Larger transpose: tile it into TransAxisLen x TransAxisLen
            // blocks. For each (block_h, block_w) tile emit three dma_copy
            // stages: gather into pre-buffer, transpose into post-buffer,
            // scatter to the destination.
            is_block_transpose_ = true;
            t_type = src_ptr->type;
            trans_vars_ = {};
            trans_vars_.push_back(loop_vars_[src_pos]);
            trans_vars_.push_back(loop_vars_[dst_pos]);
            // ori_w/ori_h: source row stride and height of the transposed
            // region, in elements; ori_block_*: the same in tiles.
            Expr ori_w = GetVarCoefExpr(src_ptr->index, loop_vars_[dst_pos]);
            Expr ori_h = loop_extends_[dst_pos];
            Expr ori_block_w = floordiv(ori_w, TransAxisLen);
            Expr ori_block_h = floordiv(ori_h, TransAxisLen);
            Var loop_w = Var("block_w");
            Var loop_h = Var("block_h");
            // Index parts that do not involve the two transposed axes.
            Expr src_base_index = EliminateVarInExpr(src_ptr->index, trans_vars_);
            Expr dst_base_index = EliminateVarInExpr(op->index, trans_vars_);
            // tt0/tt1 iterate within one tile.
            Var tt0 = Var("tt0");
            Var tt1 = Var("tt1");
            // Stage 1: copy one tile from the source into pre_transpose_buffer.
            auto pre_copy = Store::make(
              pre_transpose_buffer,
              Load::make(t_type, src_ptr->buffer_var,
                         src_base_index + loop_h * TransAxisLen * ori_w + loop_w * TransAxisLen + tt1 * ori_w + tt0, 1),
              tt1 * TransAxisLen + tt0, 1);
            auto pre_l0 = For::make(tt0, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, pre_copy);
            auto pre_l1 = For::make(tt1, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, pre_l0);
            auto pre_attr = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy"), pre_l1);
            // Stage 2: transpose the tile (swap tt0/tt1) into post_transpose_buffer.
            auto transpose =
              Store::make(post_transpose_buffer, Load::make(t_type, pre_transpose_buffer, tt1 * TransAxisLen + tt0, 1),
                          tt0 * 16 + tt1, 1);
            auto trans_l0 = For::make(tt0, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, transpose);
            auto trans_l1 = For::make(tt1, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, trans_l0);
            auto trans_attr = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy"), trans_l1);
            // Stage 3: copy the transposed tile to its destination position.
            auto post_copy = Store::make(
              op->buffer_var, Load::make(t_type, post_transpose_buffer, tt1 * TransAxisLen + tt0, 1),
              dst_base_index + loop_w * TransAxisLen * ori_h + loop_h * TransAxisLen + tt1 * ori_h + tt0, 1);
            auto post_l0 = For::make(tt0, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, post_copy);
            auto post_l1 = For::make(tt1, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, post_l0);
            auto post_attr = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy"), post_l1);
            // Chain the three stages, then wrap the per-tile loops.
            auto full_inner = Block::make(Block::make(pre_attr, trans_attr), post_attr);
            auto inner_w = For::make(loop_w, 0, ori_block_w, ForType::Serial, DeviceAPI::None, full_inner);
            auto inner_h = For::make(loop_h, 0, ori_block_h, ForType::Serial, DeviceAPI::None, inner_w);
            return inner_h;
          }
        }
      }
    }
    return s;
  }

  bool is_candidate_{false};        // inside a dma_copy pragma region
  bool is_block_transpose_{false};  // a tiled transpose was generated
  Array<Var> trans_vars_;           // the two transposed loop vars (src, dst)
  Array<Var> loop_vars_;            // loop nest of the current region
  Array<Expr> loop_extends_;        // extents matching loop_vars_
  Type t_type;                      // element type of the transposed data
  Var pre_transpose_buffer;
  Var post_transpose_buffer;
};
/// Hoists if-conditions found inside a "pragma_emit_insn" region (except
/// "mad") above the pragma, re-wrapping outside the pragma only the loops
/// whose variables are not referenced by any hoisted condition.
class IfReorder : public IRMutator {
 public:
  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    // Candidate region: any emit_insn pragma other than "mad".
    if (air::ir::attr::IsPragmaKey(op->attr_key) && op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>() &&
        op->value.as<StringImm>()->value != "mad") {
      in_insn_ = true;
      // Reset per-region bookkeeping.
      for_vars_.clear();
      if_vars_.clear();
      for_vec_.clear();
      if_vec_.clear();
      auto body = this->Mutate(op->body);
      in_insn_ = false;
      if (!if_vec_.empty()) {
        // Rebuild: pragma innermost, then the collected if-conditions, then
        // (outermost) the loops that survived in for_vars_.
        Stmt new_s = AttrStmt::make(op->node, op->attr_key, op->value, body);
        for (auto if_op : if_vec_) {
          new_s = IfThenElse::make(if_op->condition, new_s);
        }
        for (auto for_op = for_vec_.rbegin(); for_op != for_vec_.rend(); ++for_op) {
          // Only re-wrap loops still present in for_vars_ (the For visitor
          // erases the ones it already re-created inside the region).
          bool find_flag = false;
          for (auto for_iter = for_vars_.begin(); for_iter != for_vars_.end(); ++for_iter) {
            if (Equal((*for_iter), (*for_op)->loop_var)) {
              find_flag = true;
              break;
            }
          }
          if (find_flag) {
            new_s = For::make((*for_op)->loop_var, (*for_op)->min, (*for_op)->extent, ForType::Serial, DeviceAPI::None,
                              new_s);
          }
        }
        return new_s;
      } else {
        // No if inside the region: nothing to reorder.
        return s;
      }
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const For *op, const Stmt &s) final {
    if (in_insn_) {
      // Remember every loop of the region, then visit its body.
      for_vec_.push_back(op);
      for_vars_.push_back(op->loop_var);
      Stmt body = this->Mutate(op->body);
      // Locate this loop's entry in for_vars_ (for possible erasure below).
      std::vector<Var>::iterator for_iter;
      for (for_iter = for_vars_.begin(); for_iter != for_vars_.end(); ++for_iter) {
        if (Equal((*for_iter), op->loop_var)) {
          break;
        }
      }
      if (!if_vec_.empty()) {
        // If this loop's variable appears in a hoisted condition, drop the
        // loop here — the AttrStmt visitor re-wraps it outside the pragma.
        std::vector<Var>::iterator if_iter;
        bool find_flag = false;
        for (if_iter = if_vars_.begin(); if_iter != if_vars_.end(); ++if_iter) {
          if (Equal((*if_iter), op->loop_var)) {
            find_flag = true;
            break;
          }
        }
        if (find_flag) {
          return body;
        } else {
          // Keep the loop inside the region and remove it from for_vars_ so
          // it is not re-wrapped outside.
          for_vars_.erase(for_iter);
          return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body);
        }
      } else {
        for_vars_.erase(for_iter);
        return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body);
      }
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const IfThenElse *op, const Stmt &s) final {
    if (in_insn_) {
      // Record the condition and every loop variable it references, then
      // strip the if (it is re-added above the pragma). NOTE(review): the
      // else-branch is not visited; this assumes conditions here have no
      // else-case — confirm against the pass that produces them.
      if_vec_.push_back(op);
      for (auto loop_var : for_vars_) {
        if (HasVars(op->condition, loop_var)) {
          if_vars_.push_back(loop_var);
        }
      }
      Stmt body = this->Mutate(op->then_case);
      return body;
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const Store *op, const Stmt &s) final {
    // Stores inside the region are leaves: keep them untouched.
    if (in_insn_) {
      return s;
    }
    return IRMutator::Mutate_(op, s);
  }

  bool in_insn_{false};                   // inside a candidate pragma region
  std::vector<const IfThenElse *> if_vec_;  // conditions stripped from the region
  std::vector<Var> if_vars_;              // loop vars referenced by those conditions
  std::vector<Var> for_vars_;             // loop vars not yet re-created inside
  std::vector<const For *> for_vec_;      // all loops of the region, outer-to-inner
  std::vector<const For *> before_if_;    // unused here; kept for interface stability
};
/// Reorders the loop nest inside a "pragma_emit_insn" region so the
/// vectorized ("key") axis ends up innermost: for element-wise/broadcast
/// patterns the dst key axis, for reduce patterns the src key axis.
/// The Store visitor only analyzes (it always returns the statement
/// unchanged); the AttrStmt visitor rebuilds the nest.
class LoopReorder : public IRMutator {
  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    if (air::ir::attr::IsPragmaKey(op->attr_key) && op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>()) {
      in_insn_ = true;
      pragma = op->value.as<StringImm>()->value;
      // Reset per-region state; the For visitor strips loops and records them.
      for_map_.clear();
      ori_vars_ = {};
      var_order_.clear();
      auto ret = this->Mutate(op->body);
      in_insn_ = false;
      if (!has_changed_) {
        return s;
      } else {
        if (var_order_.empty()) {
          // No new order computed: restore the original nest.
          // NOTE(review): here the pragma is wrapped inside the loops, while
          // the reordered branch below wraps it outside — confirm this
          // asymmetry is intended.
          ret = AttrStmt::make(op->node, op->attr_key, op->value, ret);
          for (size_t i = 0; i < ori_vars_.size(); ++i) {
            CHECK_GT(for_map_.count(ori_vars_[i].get()), 0);
            auto ptr = for_map_[ori_vars_[i].get()];
            ret = For::make(ptr->loop_var, ptr->min, ptr->extent, ptr->for_type, ptr->device_api, ret);
          }
        } else {
          // Rebuild in the computed order; var_order_[0] (the key axis)
          // becomes the innermost loop.
          for (size_t i = 0; i < var_order_.size(); ++i) {
            CHECK_GT(for_map_.count(var_order_[i].get()), 0);
            auto ptr = for_map_[var_order_[i].get()];
            ret = For::make(ptr->loop_var, ptr->min, ptr->extent, ptr->for_type, ptr->device_api, ret);
          }
          ret = AttrStmt::make(op->node, op->attr_key, op->value, ret);
        }
        return ret;
      }
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const For *op, const Stmt &s) final {
    if (in_insn_) {
      // Strip the loop, remembering it for re-creation by position.
      for_map_[(op->loop_var).get()] = op;
      ori_vars_.push_back(op->loop_var);
      auto body = this->Mutate(op->body);
      return body;
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }

  Stmt Mutate_(const Store *op, const Stmt &s) final {
    // Analysis only: decide whether reordering is needed and in what order.
    int dst_pos = GetVectorizedVarPosition(op->index, ori_vars_);
    int len = static_cast<int>(ori_vars_.size());
    // Collect every Load feeding the store value.
    std::vector<const Load *> srcs;
    auto get_loads = [&srcs](const NodeRef &node) {
      if (const auto v = node.as<Load>()) {
        srcs.push_back(v);
      }
    };
    PostOrderVisit(op->value, get_loads);
    // Key-axis position of each source, and whether they all match dst.
    bool same_pos = true;
    std::vector<int> srcs_pos;
    for (int i = 0; i < static_cast<int>(srcs.size()); ++i) {
      int temp_pos = GetVectorizedVarPosition(srcs[i]->index, ori_vars_);
      srcs_pos.push_back(temp_pos);
      if (temp_pos != dst_pos) {
        same_pos = false;
      }
    }
    has_changed_ = false;
    if (dst_pos >= 0 && len >= 2 && dst_pos != (len - 1) && (same_pos || pragma == "broadcast")) {
      // Src Load empty; all Load and Dst has the same key axis; broadcast
      has_changed_ = true;
      var_order_.push_back(ori_vars_[dst_pos]);
      for (int i = len - 1; i >= 0; i--) {
        if (i != dst_pos) {
          var_order_.push_back(ori_vars_[i]);
        }
      }
    } else if (pragma.find("reduce") != pragma.npos && len >= 2 && !srcs_pos.empty() && srcs_pos[0] != (len - 1)) {
      // Reduce: order by the first source's key axis. The !srcs_pos.empty()
      // guard fixes an out-of-bounds read when the value holds no Load.
      has_changed_ = true;
      var_order_.push_back(ori_vars_[srcs_pos[0]]);
      for (int i = len - 1; i >= 0; i--) {
        if (i != srcs_pos[0]) {
          var_order_.push_back(ori_vars_[i]);
        }
      }
    }
    return s;
  }

  std::unordered_map<const Variable *, const For *> for_map_;  // loop var -> original For
  std::vector<Var> var_order_;  // new order, [0] = innermost key axis
  Array<Var> ori_vars_;         // loops in original outer-to-inner order
  bool has_changed_{false};
  bool in_insn_{false};
  std::string pragma;
};
/// Makes every For loop variable in the tree unique: the first occurrence of
/// a variable is kept, later duplicates are replaced by fresh "ii<N>" vars
/// substituted through their loop bodies.
class ForVarUnique : public IRMutator {
 public:
  Stmt Mutate_(const For *op, const Stmt &s) final {
    Stmt mutated_body = this->Mutate(op->body);
    const Variable *raw_var = op->loop_var.get();
    if (var_maps_.count(raw_var) == 0) {
      // First time we see this variable: record it and rebuild the loop.
      var_maps_[raw_var] = 1;
      return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, mutated_body);
    }
    // Duplicate variable: mint a fresh one and substitute it in the body.
    Var fresh_var = Var("ii" + std::to_string(++index_));
    std::unordered_map<const Variable *, Expr> replacement;
    replacement[raw_var] = fresh_var;
    Stmt renamed_body = Substitute(mutated_body, replacement);
    var_maps_[fresh_var.get()] = 1;
    return For::make(fresh_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, renamed_body);
  }

  std::unordered_map<const Variable *, int> var_maps_;  // vars already in use
  int index_{0};                                        // suffix for fresh names
};
class GenSIMD { class GenSIMD {
public: public:
GenSIMD(CCEInfo &t_info, Map<std::string, Buffer> &buffer_map, const std::string &pragma) GenSIMD(CCEInfo &t_info, Map<std::string, Buffer> &buffer_map, const std::string &pragma)
...@@ -1520,9 +1175,9 @@ class GenReduce { ...@@ -1520,9 +1175,9 @@ class GenReduce {
~GenReduce() = default; ~GenReduce() = default;
Stmt Run(int pre_index) { Stmt Run(int pre_index) {
is_arg_type_ = (pragma_ == "arg_max" || pragma_ == "arg_min"); is_arg_type_ = (pragma_ == "reduce_fargmax" || pragma_ == "reduce_fargmin");
RemoveVectorizedIndex(t_info_, 0); RemoveVectorizedIndex(t_info_, 0);
if (pragma_.find("sum") != std::string::npos) { if (pragma_.find("sum") != std::string::npos || pragma_.find("add") != std::string::npos) {
insn_intrinsic_ = "vcadd"; insn_intrinsic_ = "vcadd";
expansion_factor_ = 1; expansion_factor_ = 1;
} else if (pragma_.find("max") != std::string::npos) { } else if (pragma_.find("max") != std::string::npos) {
...@@ -1769,7 +1424,7 @@ class EmitVariableInsns : public IRMutator { ...@@ -1769,7 +1424,7 @@ class EmitVariableInsns : public IRMutator {
} }
Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
if (air::ir::attr::IsPragmaKey(op->attr_key) && op->attr_key == "pragma_emit_insn") { if (op->attr_key == "pragma_emit_insn") {
CHECK(op->value.as<StringImm>()); CHECK(op->value.as<StringImm>());
pragma = op->value.as<StringImm>()->value; pragma = op->value.as<StringImm>()->value;
Stmt r; Stmt r;
...@@ -1791,8 +1446,7 @@ class EmitVariableInsns : public IRMutator { ...@@ -1791,8 +1446,7 @@ class EmitVariableInsns : public IRMutator {
if (!r.same_as(s)) { if (!r.same_as(s)) {
return r; return r;
} }
} else if (air::ir::attr::IsPragmaKey(op->attr_key) && } else if (op->attr_key == "pragma_im2col" || op->attr_key == "pragma_load3d") {
(op->attr_key == "pragma_im2col" || op->attr_key == "pragma_load3d")) {
if (paramters_.defined() && Downcast<Map<std::string, NodeRef>>(paramters_).count("feature")) { if (paramters_.defined() && Downcast<Map<std::string, NodeRef>>(paramters_).count("feature")) {
auto feature = Downcast<Map<std::string, NodeRef>>(paramters_)["feature"].as<StringImm>(); auto feature = Downcast<Map<std::string, NodeRef>>(paramters_)["feature"].as<StringImm>();
CHECK(feature); CHECK(feature);
...@@ -1842,13 +1496,13 @@ class EmitVariableInsns : public IRMutator { ...@@ -1842,13 +1496,13 @@ class EmitVariableInsns : public IRMutator {
if (pragma.find("vec_select") != std::string::npos) { if (pragma.find("vec_select") != std::string::npos) {
EmitSelect(op, t_info); EmitSelect(op, t_info);
} else if (pragma.find("dma_copy") == 0) { } else if (pragma.find("dma_copy") != std::string::npos) {
EmitDMA(t_info); EmitDMA(t_info);
} else if (pragma.find("vec_binary") == 0 || pragma.find("vec_single") == 0) { } else if (pragma.find("vec_binary") != std::string::npos || pragma.find("vec_single") != std::string::npos) {
EmitSIMD(t_info); EmitSIMD(t_info);
} else if (pragma.find("reduce") == 0 || pragma.find("arg_") == 0) { } else if (pragma.find("reduce") != std::string::npos || pragma.find("arg_") != std::string::npos) {
EmitReduce(t_info); EmitReduce(t_info);
} else if (pragma.find("broadcast") == 0) { } else if (pragma.find("broadcast") != std::string::npos) {
if (loops_vars_.empty()) { if (loops_vars_.empty()) {
gen_cce = t_info.ori_stmt; gen_cce = t_info.ori_stmt;
} else { } else {
......
...@@ -31,8 +31,7 @@ ...@@ -31,8 +31,7 @@
namespace akg { namespace akg {
namespace ir { namespace ir {
const int TransTotalSize = 256;
const int TransAxisLen = 16;
const int64_t FullReduceMaskValue = 6148914691236517205; const int64_t FullReduceMaskValue = 6148914691236517205;
class CCEInsn { class CCEInsn {
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef IR_TRANSFORM_H_
#define IR_TRANSFORM_H_
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/registry.h>
#include <unordered_set>
#include <map>
#include <numeric>
#include <set>
#include <algorithm>
#include "ir_pass.h"
#include "common/array_api.h"
#include "insn_with_variable.h"
#include "insn_builder.h"
#include "insn_info.h"
#include "insn_pattern.h"
#include "../pass/analyze_align.h"
// Transpose tile geometry: tiles are TransAxisLen x TransAxisLen
// (= TransTotalSize) elements. constexpr guarantees compile-time constants
// usable in array bounds and avoids per-TU storage ambiguity in this header.
constexpr int TransTotalSize = 256;
constexpr int TransAxisLen = 16;
namespace akg {
namespace ir {
Expr GetVarCoefExpr(const Expr &index, const Var &loop_var);
std::string GetBufferType(Expr address);
// Rewrites transpose patterns found inside "dma_copy" pragma regions into
// hardware-friendly transpose IR.  Two cases are recognized:
//  - native: both transposed axes have extent TransAxisLen (16); the pragma is
//    retagged "dma_copy_transpose" and only the two key axes stay inside it.
//  - block: the transpose is tiled into 16x16 blocks staged through scratch UB
//    buffers (TransTotalSize = 16 * 16 elements); fp32 data is cast to fp16
//    around the transpose step since the transpose copy works on fp16.
class TransposeTransform : public IRMutator {
 public:
  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    if (op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>() &&
        (op->value.as<StringImm>()->value == "dma_copy")) {
      // Fresh scratch buffers and per-region state for this candidate.
      pre_transpose_buffer_ = Var("srcTranspose_local_UB");
      post_transpose_buffer_ = Var("dstTranspose_local_UB");
      pre_trans_cast_ = Var("pre_trans_cast__local_UB");
      post_trans_cast_ = Var("post_trans_cast__local_UB");
      loop_vars_ = {};
      loop_extends_ = {};
      is_candidate_ = true;
      is_block_transpose_ = false;
      is_native_transpose_ = false;
      align_value = FREE_ALIGN;
      remain_fors_.clear();
      auto body = this->Mutate(op->body);
      is_candidate_ = false;
      if (is_block_transpose_) {
        is_block_transpose_ = false;
        if (t_type_ == Float(32)) {  // need cast
          // fp32 input: allocate fp16 staging buffers for the casts around
          // the 16x16 transpose copy.
          body = Allocate::make(pre_trans_cast_, Float(16), {TransTotalSize}, const_true(1), body);
          body = AttrStmt::make(pre_trans_cast_, "storage_scope", Expr("local.UB"), body);
          body = Allocate::make(post_trans_cast_, Float(16), {TransTotalSize}, const_true(1), body);
          body = AttrStmt::make(post_trans_cast_, "storage_scope", Expr("local.UB"), body);
        }
        // Wrap the body with allocations for the pre/post transpose buffers.
        auto allocate_pre_buffer =
          Allocate::make(pre_transpose_buffer_, t_type_, {TransTotalSize}, const_true(1), body);
        auto attr_pre_buffer =
          AttrStmt::make(pre_transpose_buffer_, "storage_scope", Expr("local.UB"), allocate_pre_buffer);
        auto allocate_post_buffer =
          Allocate::make(post_transpose_buffer_, t_type_, {TransTotalSize}, const_true(1), attr_pre_buffer);
        auto attr_post_buffer =
          AttrStmt::make(post_transpose_buffer_, "storage_scope", Expr("local.UB"), allocate_post_buffer);
        Stmt ret = attr_post_buffer;
        if (align_value != FREE_ALIGN) {
          // Publish the forced alignment of the source buffer for later passes.
          ret = AttrStmt::make(align_buffer_, "align_info", Expr(align_value), ret);
        }
        return ret;
      }
      if (is_native_transpose_) {
        // Retag the pragma and rebuild the non-transposed loops outside it.
        Stmt ret = AttrStmt::make(op->node, op->attr_key, Expr("dma_copy_transpose"), body);
        for (int i = 0; i <= static_cast<int>(remain_fors_.size()) - 1; ++i) {
          ret = For::make(remain_fors_[i]->loop_var, remain_fors_[i]->min, remain_fors_[i]->extent, ForType::Serial,
                          DeviceAPI::None, ret);
        }
        return ret;
      }
      return AttrStmt::make(op->node, op->attr_key, op->value, body);
    }
    return IRMutator::Mutate_(op, s);
  }
  Stmt Mutate_(const For *op, const Stmt &s) final {
    if (is_candidate_) {
      // Record the loop nest; the Store visitor below decides which loops
      // belong to the transpose.
      loop_vars_.push_back(op->loop_var);
      loop_extends_.push_back(op->extent);
      Stmt body = this->Mutate(op->body);
      if (is_block_transpose_ && IsInArray(trans_vars_, op->loop_var)) {
        // Transposed axes are replaced by the generated block loops.
        return body;
      }
      if (is_native_transpose_) {
        if (IsInArray(trans_vars_, op->loop_var)) {
          return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body);
        }
        // Non-key loops are hoisted outside the pragma by the AttrStmt visitor.
        remain_fors_.push_back(op);
        return body;
      }
      return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body);
    }
    return IRMutator::Mutate_(op, s);
  }
  Stmt Mutate_(const Store *op, const Stmt &s) final {
    if (is_candidate_) {
      auto value = op->value;
      if (auto cast = op->value.as<Cast>()) {
        value = cast->value;
      }
      CHECK(value.as<Load>());
      auto src_ptr = value.as<Load>();
      if (GetBufferType(op->buffer_var) == SCOPE_UBUF && GetBufferType(src_ptr->buffer_var) == SCOPE_UBUF &&
          src_ptr->type == Float(16)) {
        // Key ("vectorized") axis of each side; a transpose swaps them.
        int dst_pos = GetVectorizedVarPosition(op->index, loop_vars_);
        int src_pos = GetVectorizedVarPosition(src_ptr->index, loop_vars_);
        if (dst_pos != -1 && src_pos != -1 && dst_pos != src_pos && HasVars(src_ptr->index, loop_vars_[dst_pos]) &&
            HasVars(op->index, loop_vars_[src_pos]) && floormod(loop_extends_[dst_pos], TransAxisLen).as<IntImm>() &&
            floormod(loop_extends_[dst_pos], TransAxisLen).as<IntImm>()->value == 0 &&
            Equal(GetVarCoefExpr(op->index, loop_vars_[src_pos]), loop_extends_[dst_pos])) {
          if (loop_extends_[dst_pos].as<IntImm>() && loop_extends_[dst_pos].as<IntImm>()->value == TransAxisLen &&
              loop_extends_[src_pos].as<IntImm>() && loop_extends_[src_pos].as<IntImm>()->value == TransAxisLen) {
            // Exactly one 16x16 tile: native transpose, no staging needed.
            trans_vars_ = {};
            trans_vars_.push_back(loop_vars_[src_pos]);
            trans_vars_.push_back(loop_vars_[dst_pos]);
            is_native_transpose_ = true;
            return s;
          }
          is_block_transpose_ = true;
          if (GetVarCoefExpr(src_ptr->index, loop_vars_[dst_pos]).as<IntImm>()) {
            int coef_t = GetVarCoefExpr(src_ptr->index, loop_vars_[dst_pos]).as<IntImm>()->value;
            if (coef_t % TransAxisLen != 0) {
              // Source row stride is not a multiple of 16: force its alignment.
              align_value = coef_t;
              align_buffer_ = src_ptr->buffer_var;
            }
          }
          t_type_ = src_ptr->type;
          trans_vars_ = {};
          trans_vars_.push_back(loop_vars_[src_pos]);
          trans_vars_.push_back(loop_vars_[dst_pos]);
          // ori_w/ori_h: width (src row stride) and height of the region to
          // transpose, measured in elements; then tiled into 16x16 blocks.
          Expr ori_w = GetVarCoefExpr(src_ptr->index, loop_vars_[dst_pos]);
          Expr ori_h = loop_extends_[dst_pos];
          Expr ori_block_w = floordiv(ori_w, TransAxisLen);
          // padding the width
          Expr unit_width = TransAxisLen;
          if (!Equal(floormod(ori_w, TransAxisLen), 0)) {
            ori_block_w = ori_block_w + 1;
          }
          if (ori_w.as<IntImm>() && ori_w.as<IntImm>()->value < TransAxisLen) {
            unit_width = ori_w;
          }
          Expr ori_block_h = floordiv(ori_h, TransAxisLen);
          Var loop_w = Var("block_w");
          Var loop_h = Var("block_h");
          Expr src_base_index = EliminateVarInExpr(src_ptr->index, trans_vars_);
          Expr dst_base_index = EliminateVarInExpr(op->index, trans_vars_);
          Var tt0 = Var("tt0");
          Var tt1 = Var("tt1");
          // Stage 1: copy one 16 x unit_width tile into the pre buffer.
          auto pre_copy = Store::make(
            pre_transpose_buffer_,
            Load::make(t_type_, src_ptr->buffer_var,
                       src_base_index + loop_h * TransAxisLen * ori_w + loop_w * TransAxisLen + tt1 * ori_w + tt0, 1),
            tt1 * TransAxisLen + tt0, 1);
          auto pre_l0 = For::make(tt0, 0, unit_width, ForType::Serial, DeviceAPI::None, pre_copy);
          auto pre_l1 = For::make(tt1, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, pre_l0);
          auto pre_attr = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy"), pre_l1);
          // Stage 2: transpose the tile (with fp16 casts for fp32 data).
          Stmt trans_attr = Stmt();
          if (t_type_ == Float(16)) {
            auto transpose =
              Store::make(post_transpose_buffer_,
                          Load::make(t_type_, pre_transpose_buffer_, tt1 * TransAxisLen + tt0, 1), tt0 * 16 + tt1, 1);
            auto trans_l0 = For::make(tt0, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, transpose);
            auto trans_l1 = For::make(tt1, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, trans_l0);
            trans_attr = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy_transpose"), trans_l1);
          } else {
            // cast to fp16 -> transpose -> cast back to the original type.
            auto pre_cast_store = Store::make(
              pre_trans_cast_, Cast::make(Float(16), Load::make(t_type_, pre_transpose_buffer_, tt0, 1)), tt0, 1);
            auto pre_cast_for = For::make(tt0, 0, TransTotalSize, ForType::Serial, DeviceAPI::None, pre_cast_store);
            auto pre_cast_attr =
              AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("vec_single_cast"), pre_cast_for);
            auto transpose = Store::make(
              post_trans_cast_, Load::make(Float(16), pre_trans_cast_, tt1 * TransAxisLen + tt0, 1), tt0 * 16 + tt1, 1);
            auto trans_l0 = For::make(tt0, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, transpose);
            auto trans_l1 = For::make(tt1, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, trans_l0);
            auto trans_block =
              AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy_transpose"), trans_l1);
            auto post_cast_store = Store::make(
              post_transpose_buffer_, Cast::make(t_type_, Load::make(Float(16), post_trans_cast_, tt0, 1)), tt0, 1);
            auto post_cast_for = For::make(tt0, 0, TransTotalSize, ForType::Serial, DeviceAPI::None, post_cast_store);
            auto post_cast_attr =
              AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("vec_single_cast"), post_cast_for);
            trans_attr = Block::make(Block::make(pre_cast_attr, trans_block), post_cast_attr);
          }
          // Stage 3: copy the transposed tile to its destination position.
          auto post_copy =
            Store::make(op->buffer_var, Load::make(t_type_, post_transpose_buffer_, tt1 * TransAxisLen + tt0, 1),
                        dst_base_index + loop_w * TransAxisLen * ori_h + loop_h * TransAxisLen + tt1 * ori_h + tt0, 1);
          auto post_l0 = For::make(tt0, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, post_copy);
          auto post_l1 = For::make(tt1, 0, unit_width, ForType::Serial, DeviceAPI::None, post_l0);
          auto post_attr = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy"), post_l1);
          auto full_inner = Block::make(Block::make(pre_attr, trans_attr), post_attr);
          // Iterate tiles; degenerate (extent 1) block loops are substituted away.
          auto inner_w = For::make(loop_w, 0, ori_block_w, ForType::Serial, DeviceAPI::None, full_inner);
          if (ori_block_w.as<IntImm>() && ori_block_w.as<IntImm>()->value == 1) {
            std::unordered_map<const Variable *, Expr> init;
            init[loop_w.get()] = 0;
            inner_w = Simplify(Substitute(full_inner, init));
          }
          auto inner_h = For::make(loop_h, 0, ori_block_h, ForType::Serial, DeviceAPI::None, inner_w);
          if (ori_block_h.as<IntImm>() && ori_block_h.as<IntImm>()->value == 1) {
            std::unordered_map<const Variable *, Expr> init;
            init[loop_h.get()] = 0;
            inner_h = Simplify(Substitute(inner_w, init));
          }
          return inner_h;
        }
      }
    }
    return s;
  }

 private:
  bool is_candidate_{false};              // inside a "dma_copy" pragma region
  bool is_native_transpose_{false};
  bool is_block_transpose_{false};
  int align_value{FREE_ALIGN};            // alignment forced on the source buffer
  Var align_buffer_;
  Array<Var> trans_vars_;                 // the two transposed loop vars (src, dst)
  Array<Var> loop_vars_;                  // loop nest of the current region
  Array<Expr> loop_extends_;
  std::vector<const For *> remain_fors_;  // loops to rebuild outside the pragma
  Type t_type_;                           // element type of the transposed data
  Var pre_transpose_buffer_;
  Var pre_trans_cast_;
  Var post_trans_cast_;
  Var post_transpose_buffer_;
};
// Makes every For loop variable in the tree unique: when a loop var has been
// seen before, the loop body is rewritten to use a fresh "ii<N>" variable.
// All visited loops are re-emitted as serial loops, as before.
class ForVarUnique : public IRMutator {
 public:
  Stmt Mutate_(const For *op, const Stmt &s) final {
    auto mutated_body = this->Mutate(op->body);
    const Variable *raw_var = op->loop_var.get();
    if (var_maps_.count(raw_var) == 0) {
      // First occurrence: keep the variable, just remember it.
      var_maps_[raw_var] = 1;
      return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, mutated_body);
    }
    // Duplicate: substitute a fresh variable throughout the body.
    Var fresh = Var("ii" + std::to_string(++index_));
    std::unordered_map<const Variable *, Expr> replace_map;
    replace_map[raw_var] = fresh;
    auto replaced_body = Substitute(mutated_body, replace_map);
    var_maps_[fresh.get()] = 1;
    return For::make(fresh, op->min, op->extent, ForType::Serial, DeviceAPI::None, replaced_body);
  }

 private:
  std::unordered_map<const Variable *, int> var_maps_;  // loop vars seen so far
  int index_{0};                                        // suffix for fresh names
};
// Reorders the loops of a "pragma_emit_insn" region so that the key
// (vectorizable) axis becomes the innermost loop:
//  - SIMD/broadcast: the destination's key axis goes innermost;
//  - reduce: the first source operand's key axis goes innermost.
// When nothing needs to change, the original statement is returned untouched.
class LoopReorder : public IRMutator {
 public:
  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    if (op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>()) {
      in_insn_ = true;
      pragma_ = op->value.as<StringImm>()->value;
      for_map_.clear();
      ori_vars_ = {};
      var_order_.clear();
      auto ret = this->Mutate(op->body);
      in_insn_ = false;
      if (!has_changed_) {
        return s;
      }
      if (var_order_.empty()) {
        // No new order was decided: rebuild the original nest around the pragma.
        ret = AttrStmt::make(op->node, op->attr_key, op->value, ret);
        for (size_t i = 0; i < ori_vars_.size(); ++i) {
          CHECK_GT(for_map_.count(ori_vars_[i].get()), 0);
          auto ptr = for_map_[ori_vars_[i].get()];
          ret = For::make(ptr->loop_var, ptr->min, ptr->extent, ptr->for_type, ptr->device_api, ret);
        }
        return ret;
      }
      // Rebuild loops in the decided order (var_order_[0] ends up innermost).
      for (size_t i = 0; i < var_order_.size(); ++i) {
        CHECK_GT(for_map_.count(var_order_[i].get()), 0);
        auto ptr = for_map_[var_order_[i].get()];
        ret = For::make(ptr->loop_var, ptr->min, ptr->extent, ptr->for_type, ptr->device_api, ret);
      }
      ret = AttrStmt::make(op->node, op->attr_key, op->value, ret);
      return ret;
    }
    return IRMutator::Mutate_(op, s);
  }
  Stmt Mutate_(const For *op, const Stmt &s) final {
    if (in_insn_) {
      // Strip the loop, remembering it so it can be rebuilt in the new order.
      for_map_[(op->loop_var).get()] = op;
      ori_vars_.push_back(op->loop_var);
      auto body = this->Mutate(op->body);
      return body;
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }
  Stmt Mutate_(const Store *op, const Stmt &s) final {
    int dst_pos = GetVectorizedVarPosition(op->index, ori_vars_);
    int len = static_cast<int>(ori_vars_.size());
    std::vector<const Load *> srcs;
    auto get_loads = [&srcs](const NodeRef &node) {
      if (const auto v = node.as<Load>()) {
        srcs.push_back(v);
      }
    };
    PostOrderVisit(op->value, get_loads);
    bool same_pos = true;
    std::vector<int> srcs_pos;
    for (int i = 0; i < static_cast<int>(srcs.size()); ++i) {
      int temp_pos = GetVectorizedVarPosition(srcs[i]->index, ori_vars_);
      srcs_pos.push_back(temp_pos);
      if (temp_pos != dst_pos) {
        same_pos = false;
      }
    }
    has_changed_ = false;
    if (dst_pos >= 0 && len >= 2 && dst_pos != (len - 1) && (same_pos || pragma_ == "broadcast")) {
      // Src Load empty; all Load and Dst has the same key axis; broadcast
      has_changed_ = true;
      var_order_.push_back(ori_vars_[dst_pos]);
      for (int i = len - 1; i >= 0; i--) {
        if (i != dst_pos) {
          var_order_.push_back(ori_vars_[i]);
        }
      }
    } else if (pragma_.find("reduce") != pragma_.npos && len >= 2 && !srcs_pos.empty() && srcs_pos[0] >= 0 &&
               srcs_pos[0] != (len - 1)) {
      // based on dst key axis: reduce
      // Guards: a reduce store without any Load operand, or whose first load
      // has no key axis (position -1), must not index srcs_pos/ori_vars_
      // out of bounds.
      has_changed_ = true;
      var_order_.push_back(ori_vars_[srcs_pos[0]]);
      for (int i = len - 1; i >= 0; i--) {
        if (i != srcs_pos[0]) {
          var_order_.push_back(ori_vars_[i]);
        }
      }
    }
    return s;
  }

 private:
  std::unordered_map<const Variable *, const For *> for_map_;  // stripped loops by var
  std::vector<Var> var_order_;                                 // new order, innermost first
  Array<Var> ori_vars_;                                        // original order, outermost first
  bool has_changed_{false};
  bool in_insn_{false};
  std::string pragma_;                                         // pragma value of current region
};
// Hoists IfThenElse nodes out of pragma_emit_insn regions: the pragma is
// re-made innermost, the collected If conditions are recreated directly
// around it, and only the loops whose variables still appear in a condition
// are rebuilt outside the Ifs.
class IfReorder : public IRMutator {
 public:
  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    if (op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>() &&
        !exclude_align_analyze_list.count(op->value.as<StringImm>()->value)) {
      in_insn_ = true;
      for_vars_.clear();
      if_vars_.clear();
      for_vec_.clear();
      if_vec_.clear();
      auto body = this->Mutate(op->body);
      in_insn_ = false;
      if (!if_vec_.empty()) {
        // Rebuild: pragma innermost, then the Ifs, then the remaining loops
        // (those whose vars an If condition references) outermost.
        Stmt new_s = AttrStmt::make(op->node, op->attr_key, op->value, body);
        for (auto if_op : if_vec_) {
          new_s = IfThenElse::make(if_op->condition, new_s);
        }
        for (auto for_op = for_vec_.rbegin(); for_op != for_vec_.rend(); ++for_op) {
          bool find_flag = false;
          for (auto for_iter = for_vars_.begin(); for_iter != for_vars_.end(); ++for_iter) {
            if (Equal((*for_iter), (*for_op)->loop_var)) {
              find_flag = true;
              break;
            }
          }
          if (find_flag) {
            new_s = For::make((*for_op)->loop_var, (*for_op)->min, (*for_op)->extent, ForType::Serial, DeviceAPI::None,
                              new_s);
          }
        }
        return new_s;
      }
      return s;
    }
    return IRMutator::Mutate_(op, s);
  }
  Stmt Mutate_(const For *op, const Stmt &s) final {
    if (in_insn_) {
      for_vec_.push_back(op);
      for_vars_.push_back(op->loop_var);
      Stmt body = this->Mutate(op->body);
      std::vector<Var>::iterator for_iter;
      for (for_iter = for_vars_.begin(); for_iter != for_vars_.end(); ++for_iter) {
        if (Equal((*for_iter), op->loop_var)) {
          break;
        }
      }
      if (!if_vec_.empty()) {
        std::vector<Var>::iterator if_iter;
        bool find_flag = false;
        for (if_iter = if_vars_.begin(); if_iter != if_vars_.end(); ++if_iter) {
          if (Equal((*if_iter), op->loop_var)) {
            find_flag = true;
            break;
          }
        }
        if (find_flag) {
          // Loop var is used by an If condition: strip the loop here; it is
          // rebuilt outside the hoisted Ifs (var stays in for_vars_).
          return body;
        }
        for_vars_.erase(for_iter);
        return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body);
      }
      for_vars_.erase(for_iter);
      return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body);
    }
    return IRMutator::Mutate_(op, s);
  }
  Stmt Mutate_(const IfThenElse *op, const Stmt &s) final {
    if (in_insn_) {
      // Record the If (only then-branches occur here) and the loop vars its
      // condition depends on, then drop the If from the body.
      if_vec_.push_back(op);
      for (auto loop_var : for_vars_) {
        if (HasVars(op->condition, loop_var)) {
          if_vars_.push_back(loop_var);
        }
      }
      Stmt body = this->Mutate(op->then_case);
      return body;
    }
    return IRMutator::Mutate_(op, s);
  }
  Stmt Mutate_(const Store *op, const Stmt &s) final {
    // Stores inside the pragma are left untouched.
    if (in_insn_) {
      return s;
    }
    return IRMutator::Mutate_(op, s);
  }

 private:
  bool in_insn_{false};                    // inside a candidate pragma region
  std::vector<const IfThenElse *> if_vec_;  // hoisted Ifs, innermost first
  std::vector<Var> if_vars_;               // loop vars referenced by If conditions
  std::vector<Var> for_vars_;              // loop vars still pending rebuild
  std::vector<const For *> for_vec_;       // all loops seen in the region
  std::vector<const For *> before_if_;     // NOTE(review): currently unused
};
} // namespace ir
} // namespace akg
#endif // IR_TRANSFORM_H_
\ No newline at end of file
...@@ -265,6 +265,10 @@ Stmt RewriteBroadcastVector(Stmt stmt); ...@@ -265,6 +265,10 @@ Stmt RewriteBroadcastVector(Stmt stmt);
Stmt OptimizePragma(Stmt stmt); Stmt OptimizePragma(Stmt stmt);
Stmt PackStore(Stmt stmt);
Stmt RecoverStore(Stmt stmt);
Stmt RewriteByAlignDynamic(Stmt stmt); Stmt RewriteByAlignDynamic(Stmt stmt);
Stmt EliminateAtomicDma(Stmt stmt); Stmt EliminateAtomicDma(Stmt stmt);
......
...@@ -21,14 +21,18 @@ ...@@ -21,14 +21,18 @@
#include <string> #include <string>
#include <set> #include <set>
#include <list>
#include <algorithm>
#include "pass/utils.h" #include "pass/utils.h"
#include "arith_expr_simplify.h" #include "arith_expr_simplify.h"
#include "expr_alg_simplify.h" #include "expr_alg_simplify.h"
#include "emit_insn/cce_params.h"
#include "common/array_api.h"
namespace akg { namespace akg {
namespace ir { namespace ir {
const std::set<std::string> exclude_list = { const std::set<std::string> exclude_align_analyze_list = {
"mad", "mad",
"scatter", "scatter",
"vec_binary_proposal_sort", "vec_binary_proposal_sort",
...@@ -39,8 +43,15 @@ const std::set<std::string> exclude_list = { ...@@ -39,8 +43,15 @@ const std::set<std::string> exclude_list = {
"vec_single_four2five_nchw", "vec_single_four2five_nchw",
"opt_broadcast", "opt_broadcast",
"reduce_reorder", "reduce_reorder",
"dma_atomic_add",
"dma_copy_transpose",
}; };
class IndexOptimizer : public air::ir::IRMutator {
const std::set<std::string> exclude_index_fix_list = {
"mad", "vec_binary_proposal_sort", "vec_binary_topk_sort", "vec_binary_nms", "vec_binary_iou", "vec_binary_dropout",
};
class IndexOptimizer : public IRMutator {
public: public:
explicit IndexOptimizer(bool rm = false) : var2expr(), rm_load_(rm) {} explicit IndexOptimizer(bool rm = false) : var2expr(), rm_load_(rm) {}
~IndexOptimizer() override = default; ~IndexOptimizer() override = default;
...@@ -71,6 +82,745 @@ class IndexOptimizer : public air::ir::IRMutator { ...@@ -71,6 +82,745 @@ class IndexOptimizer : public air::ir::IRMutator {
private: private:
bool rm_load_; bool rm_load_;
}; };
int GetCommonDivisor(std::vector<int> numbers);
// Decomposition of one Load/Store index over the surrounding loop nest
// (filled by ParserVisitor::GetIndexInfo).
class IndexInfo {
 public:
  Array<Var> vars;      // loop vars appearing in the index
  Array<Expr> coefs;    // linear coefficient of each var
  Array<Expr> extents;  // loop extent of each var
  int divisor;          // alignment divisor deduced from the coefficients
  int vec_len{-1};      // extent of the stride-1 axis, -1 when none
  Var vec_var{};        // the stride-1 loop var, if any
  Expr offset;          // constant additive offset of the index
  Expr index;           // the original index expression
  bool is_serial{true};  // a stride-1 axis exists
  bool is_scalar{true};  // no loop var appears in the index
};
// Index info for the Store (destination) side, plus buffer-scope helpers.
class DstInfo : public IndexInfo {
 public:
  bool IsGlobal() { return (GetBufScope(p_store->buffer_var->name_hint) == DMA_COPY_GLOBAL); }
  bool IsUB() { return (GetBufScope(p_store->buffer_var->name_hint) == SCOPE_UBUF); }
  const Store *p_store;  // the store being described
};
// Index info for one Load (source) operand, plus buffer-scope helpers.
class SrcInfo : public IndexInfo {
 public:
  bool IsGlobal() { return (GetBufScope(p_load->buffer_var->name_hint) == DMA_COPY_GLOBAL); }
  bool IsUB() { return (GetBufScope(p_load->buffer_var->name_hint) == SCOPE_UBUF); }
  const Load *p_load;  // the load being described
  bool is_imm;         // NOTE(review): not set in this chunk -- presumably marks an immediate operand; confirm
  Expr imm;            // the immediate value when is_imm holds
};
// Arithmetic description of one pragma region: the Store (dst_info), its Load
// operands (src_info) and the classification of the statement into one of the
// intrinsic categories used by alignment analysis:
//   scalar / simd / simd_split / vector_scalar / vector_dump / reduce /
//   crossing / discrete.
class ArithInfo {
 public:
  Stmt GenIR() { return store; }
  // Classifies the statement; the Try* predicates are tested in priority
  // order and the first match wins.
  void GetIntrinsicType(Array<Var> &for_vars, Array<Var> &if_vars) {
    if (for_vars.empty()) {
      if (TryScalarType()) {
        insn_type = "scalar";
      } else {
        insn_type = "discrete";
      }
      return;
    }
    if (TryScalarAssignType(if_vars)) {
      insn_type = "scalar";
      return;
    }
    if (TryReduceType()) {
      insn_type = "reduce";
      return;
    }
    auto simd_t = TrySIMDType();
    if (simd_t == 1) {
      insn_type = "simd";
      return;
    } else if (simd_t == 2) {
      insn_type = "simd_split";
      return;
    }
    if (TryVectorScalarType()) {
      insn_type = "vector_scalar";
      return;
    }
    if (TryVectorDumpType()) {
      insn_type = "vector_dump";
      return;
    }
    if (TryCrossingType()) {
      insn_type = "crossing";
      return;
    }
    if (TryDiscrete()) {
      insn_type = "discrete";
      return;
    }
    if (insn_type == "unknown") {
      CHECK(0) << "\nUnknown Intrinsic Type";
    }
  }
  // A[0] = B[1] + C[2]
  bool TryScalarType() {
    // A GM->UB single-load copy is also treated as scalar here.
    if (dst_info.IsUB() && dst_info.p_store->value.as<Load>() && src_info[0].IsGlobal()) {
      return true;
    }
    if (!is_const(dst_info.index)) {
      return false;
    }
    for (auto info : src_info) {
      if (!is_const(info.index)) {
        return false;
      }
    }
    return true;
  }
  // for i { for j { A[i] = reduce(C[X*i + j]) } }
  bool TryReduceType() {
    // The store value was already packed into a "reduce_*" extern call.
    if (dst_info.p_store->value.as<Call>()) {
      auto t_call = dst_info.p_store->value.as<Call>();
      if (t_call->name.find("reduce_") != std::string::npos) {
        return true;
      }
    }
    return false;
  }
  // for i { for j { A[X*i + j] = B[X*i + j] + C[j] } }
  // Returns 1 for a plain SIMD, 2 when the vector axis must be split
  // (mismatched block offsets or too-small divisors), 0 otherwise.
  int TrySIMDType() {
    Var cur_var = dst_info.vec_var;
    int cur_len = dst_info.vec_len;
    Expr cur_offset = dst_info.offset;
    int block_size = GetUbBlkSize(dst_info.p_store->value.type());
    bool is_simd = (cur_len >= 1) ? true : false;
    // SIMD requires every operand to share the dst's stride-1 axis and length.
    for (auto info : src_info) {
      if (info.vec_len != cur_len || !Equal(info.vec_var, cur_var)) {
        is_simd = false;
        break;
      }
    }
    bool need_split = false;
    if (is_simd) {
      // UB<->UB operands whose constant offsets disagree modulo the block
      // size cannot be emitted as one SIMD op.
      for (auto info : src_info) {
        int info_block_size = GetUbBlkSize(info.p_load->type);
        if (dst_info.IsUB() && info.IsUB()) {
          if (is_const(cur_offset) && is_const(info.offset) &&
              cur_offset.as<IntImm>()->value % block_size != info.offset.as<IntImm>()->value % info_block_size) {
            need_split = true;
            break;
          }
        }
      }
    }
    if (is_simd && need_split) {
      if (src_info.size() == 1) {
        // Shrink both divisors to a common value compatible with the offsets.
        if (dst_info.divisor != 0 && src_info[0].divisor != 0 &&
            dst_info.offset.as<IntImm>()->value % dst_info.divisor !=
              src_info[0].offset.as<IntImm>()->value % src_info[0].divisor) {
          dst_info.divisor = air::ir::gcd(dst_info.divisor, dst_info.offset.as<IntImm>()->value);
          src_info[0].divisor = air::ir::gcd(src_info[0].divisor, src_info[0].offset.as<IntImm>()->value);
          auto min_dst_src = std::min(dst_info.divisor, src_info[0].divisor);
          dst_info.divisor = min_dst_src;
          src_info[0].divisor = min_dst_src;
        }
      } else {
        CHECK(0) << "\nNeed to split the vector var to make the offset equal or scalar computing\n";
      }
    }
    bool unaligned_divisor = false;
    if (is_simd) {
      // A divisor smaller than the vector length also forces a split.
      if (dst_info.IsUB()) {
        if (dst_info.divisor != 0 && dst_info.divisor < cur_len) {
          dst_info.divisor = air::ir::gcd(dst_info.divisor, cur_len);
          unaligned_divisor = true;
        }
      }
      for (auto info : src_info) {
        if (info.IsUB()) {
          if (info.divisor != 0 && info.divisor < cur_len) {
            unaligned_divisor = true;
            int temp_divisor = air::ir::gcd(info.divisor, cur_len);
            dst_info.divisor = air::ir::gcd(dst_info.divisor, temp_divisor);
          }
        }
      }
    }
    if (is_simd && !need_split && !unaligned_divisor) {
      return 1;
    }
    if (is_simd && (need_split || unaligned_divisor)) {
      return 2;
    }
    return 0;
  }
  // for i { for j { A[X*i + j] = B[X*i + j] + C[Z*i] } }
  bool TryVectorScalarType() {
    if (src_info.size() != 2) {
      return false;
    }
    // Whichever operand does not vary along the dst's stride-1 axis becomes
    // the scalar operand (moved into scalar_load).
    if (dst_info.is_serial && Equal(dst_info.vec_var, src_info[0].vec_var) &&
        !HasVars(src_info[1].index, dst_info.vec_var) &&
        (!src_info[1].is_serial || !Equal(dst_info.vec_var, src_info[1].vec_var))) {
      scalar_load = src_info[1];
      src_info.pop_back();
      return true;
    }
    if (dst_info.is_serial && Equal(dst_info.vec_var, src_info[1].vec_var) &&
        !HasVars(src_info[0].index, dst_info.vec_var) &&
        (!src_info[0].is_serial || !Equal(dst_info.vec_var, src_info[0].vec_var))) {
      scalar_load = src_info[0];
      src_info.erase(src_info.begin());
      return true;
    }
    return false;
  }
  // for i { for j { A[X*i + j] = C[Z*i] } }
  bool TryVectorDumpType() {
    if (src_info.size() != 1) {
      return false;
    }
    if (GetBufScope(dst_info.p_store->buffer_var->name_hint) == SCOPE_UBUF &&
        GetBufScope(src_info[0].p_load->buffer_var->name_hint) == SCOPE_UBUF && dst_info.is_serial &&
        !HasVars(src_info[0].index, dst_info.vec_var) &&
        (!src_info[0].is_serial || !Equal(dst_info.vec_var, src_info[0].vec_var))) {
      scalar_load = src_info[0];
      src_info.pop_back();
      return true;
    }
    return false;
  }
  // UB->UB assignment that can neither be SIMD nor dump, or one guarded by an
  // If condition on the dst's stride-1 axis: emitted as scalar copies.
  bool TryScalarAssignType(Array<Var> &if_vars) {
    if (dst_info.IsUB() && dst_info.is_serial && src_info.size() == 1 && dst_info.p_store->value.as<Load>() &&
        src_info[0].IsUB()) {
      bool not_simd_or_dump = HasVars(src_info[0].index, dst_info.vec_var) &&
                              (!src_info[0].is_serial || !Equal(dst_info.vec_var, src_info[0].vec_var));
      bool in_if_vars = !if_vars.empty() && IsInArray(if_vars, dst_info.vec_var);
      if (not_simd_or_dump || in_if_vars) {
        return true;
      }
    }
    return false;
  }
  // for i { for j { A[X*i + j] = C[Y*j + i] } }
  // for i { for j { A[X*i + j] = C[Y*j] } }
  bool TryCrossingType() {
    if (dst_info.is_serial && src_info.size() == 1 && HasVars(src_info[0].index, dst_info.vec_var) &&
        (!src_info[0].is_serial || !Equal(dst_info.vec_var, src_info[0].vec_var))) {
      return true;
    }
    return false;
  }
  // for i {for j { A[X*i + Y*j] = ....} }
  bool TryDiscrete() { return !(dst_info.is_serial); }
  // Derives the instruction-level vectorization info (vec_len/vec_var/offset)
  // from the classified insn_type; must run after GetIntrinsicType().
  void GetVectorizedInfo() {
    if (insn_type == "scalar") {
      is_scalar = true;
      return;
    }
    if (insn_type == "simd" || insn_type == "vector_scalar" || insn_type == "vector_dump") {
      vec_len = dst_info.vec_len;
      vec_var = dst_info.vec_var;
      offset = dst_info.offset;
      return;
    }
    if (insn_type == "simd_split") {
      // After a split the common divisor becomes the effective vector length.
      vec_len = dst_info.divisor;
      offset = 0;
      return;
    }
    if (insn_type == "reduce") {
      vec_len = src_info[0].vec_len;
      vec_var = src_info[0].vec_var;
      offset = src_info[0].offset;
      return;
    }
    if (insn_type == "crossing" || insn_type == "discrete") {
      // Element-by-element emission: degrade every serial axis to length 1.
      vec_len = 1;
      if (dst_info.is_serial) {
        dst_info.vec_len = 1;
        dst_info.divisor = 1;
      }
      for (size_t i = 0; i < src_info.size(); i++) {
        if (src_info[i].is_serial) {
          src_info[i].divisor = 1;
          src_info[i].vec_len = 1;
        }
      }
      return;
    }
    CHECK(0) << "\ninsn_type is unknown\n";
  }
  DstInfo dst_info;               // the store side
  std::vector<SrcInfo> src_info;  // the load operands
  int vec_len;                    // vector length of the whole instruction
  Var vec_var;                    // vectorized axis of the whole instruction
  Expr offset;                    // constant offset along the vectorized axis
  bool is_scalar{false};
  Stmt store;                     // the (possibly rewritten) Store statement
  std::string op_type;            // op name of the store's value
  std::string insn_type{"unknown"};
  SrcInfo scalar_load;            // scalar operand for vector_scalar / vector_dump
  Expr scalar_imm{Expr()};        // immediate found inside the store value
  int scalar_imm_num{0};          // how many immediates were found
};
// IfThenElse data collected from a pragma region: conditions, the variables
// they mention, and the guarded bodies (parallel arrays).
class IRIfInfo {
 public:
  Array<Expr> conds;
  Array<Var> vars;
  Array<Stmt> ops;
};
// Loop-nest data collected from a pragma region: loop vars, their constant
// extents and their bodies (parallel arrays, outermost first).
class IRForInfo {
 public:
  Array<Var> vars;
  std::vector<int> exts;
  Array<Stmt> ops;
};
// Aggregated parse result of one pragma region, with helpers to regenerate
// the IR and to rewrite last-dimension reductions.
class IRInfo {
 public:
  Stmt GenStmt() {
    auto ret = GenIfAndFor();
    return ret;
  }
  // Wraps the (possibly rewritten) store back in the recorded loop nest.
  // NOTE(review): the recorded If conditions are not re-emitted here.
  Stmt GenIfAndFor() {
    auto core = arith_info.store;
    if (for_info.vars.empty()) {
      return core;
    }
    Stmt ret = core;
    for (int i = static_cast<int>(for_info.vars.size()) - 1; i >= 0; --i) {
      ret = For::make(for_info.vars[i], 0, for_info.exts[i], ForType::Serial, DeviceAPI::None, ret);
    }
    return ret;
  }
  // Detects dst = dst op src along src's own innermost axis and rewrites the
  // store into a "reduce_<op>" extern call carrying the non-accumulator
  // operand.  Returns true when the rewrite happened.
  bool ChangeLastDimReduce() {
    if (arith_info.src_info.size() != 2) {
      return false;
    }
    // Find the operand that aliases the destination (the accumulator).
    size_t i = 0;
    for (i = 0; i < arith_info.src_info.size(); ++i) {
      if (Equal(arith_info.src_info[i].p_load->buffer_var, arith_info.dst_info.p_store->buffer_var) &&
          Equal(arith_info.src_info[i].p_load->index, arith_info.dst_info.p_store->index)) {
        break;
      }
    }
    if (i >= 2) {
      return false;
    }
    size_t index = 0;
    // The other operand must bring a loop axis the accumulator doesn't use.
    if (!Equal(arith_info.src_info[1 - i].vec_var, arith_info.src_info[i].vec_var) &&
        GetIndexOfElement(for_info.vars, arith_info.src_info[1 - i].vec_var, index) &&
        !HasVars(arith_info.src_info[i].p_load->index, {arith_info.src_info[1 - i].vec_var})) {
      SrcInfo t_src = arith_info.src_info[1 - i];
      arith_info.src_info.clear();
      arith_info.src_info.push_back(t_src);
      arith_info.insn_type = "reduce_" + GetReduceType();
      Expr pack_value =
        Call::make(t_src.p_load->type, arith_info.insn_type, {GetRef<Expr>(t_src.p_load)}, Call::Extern);
      arith_info.store = Store::make(arith_info.store.as<Store>()->buffer_var, pack_value,
                                     arith_info.store.as<Store>()->index, arith_info.store.as<Store>()->predicate);
      return true;
    }
    return false;
  }
  // Lower-cased op name of the store's value, e.g. "Add" -> "add".
  std::string GetReduceType() {
    std::string ret = GetOpType(arith_info.dst_info.p_store->value);
    std::transform(ret.begin(), ret.end(), ret.begin(), ::tolower);
    return ret;
  }
  IRIfInfo if_info;
  IRForInfo for_info;
  ArithInfo arith_info;
};
// Extracts the constant additive offset from a (simplified) index expression;
// returns 0 when no immediate term is found.
class ImmOffsetVisitor : public IRVisitor {
 public:
  int Run(const Expr &e) {
    auto temp_index = Simplify(e);
    IRVisitor::Visit(temp_index);
    return imm_offset_;
  }
  void Visit_(const Add *op) {
    if (op->a.as<IntImm>()) {
      imm_offset_ = op->a.as<IntImm>()->value;
    } else if (op->b.as<IntImm>()) {
      imm_offset_ = op->b.as<IntImm>()->value;
    } else {
      // NOTE(review): only the right operand is searched further; this relies
      // on Simplify's canonical form keeping the immediate reachable through
      // the right spine of the Add chain -- confirm.
      IRVisitor::Visit(op->b);
    }
  }
  bool in_add_flag_{false};  // NOTE(review): never read or written -- vestigial
  int imm_offset_{0};        // result: the immediate term, 0 if absent
};
// Parses one pragma_emit_insn region into an IRInfo:
//  - for_info: the loop nest (vars / constant extents / bodies),
//  - if_info:  guarding conditions and the variables they mention,
//  - arith_info: the Store, its Load operands and any immediate operands.
// When `with_align` is true, Run() additionally classifies the intrinsic
// type and derives the vectorization info used by alignment analysis.
class ParserVisitor : public IRVisitor {
 public:
  ParserVisitor(IRInfo &in, bool flag = false) : info(in), with_align(flag) {}
  ~ParserVisitor() override = default;
  void Run(const Stmt &s) {
    in_store = false;
    IRVisitor::Visit(s);
    if (with_align) {
      GetInsnType();
      info.arith_info.GetVectorizedInfo();
    }
  }
  void Visit_(const For *op) {
    // Only constant loop extents are supported by this analysis; fail loudly
    // instead of dereferencing a null IntImm when a dynamic extent appears.
    auto extent_imm = op->extent.as<IntImm>();
    CHECK(extent_imm) << "ParserVisitor expects a constant For extent";
    info.for_info.vars.push_back(op->loop_var);
    info.for_info.exts.push_back(extent_imm->value);
    info.for_info.ops.push_back(op->body);
    IRVisitor::Visit(op->body);
  }
  void Visit_(const IfThenElse *op) {
    CHECK(!op->else_case.defined());
    info.if_info.conds.push_back(op->condition);
    auto var_list = GetVarsInExpr(op->condition);
    for (auto t_var : var_list) {
      if (!IsInArray(info.if_info.vars, t_var)) {
        info.if_info.vars.push_back(t_var);
      }
    }
    info.if_info.ops.push_back(op->then_case);
    IRVisitor::Visit(op->then_case);
  }
  void Visit_(const Load *op) {
    SrcInfo src_info;
    src_info.index = op->index;
    src_info.p_load = op;
    GetIndexInfo(op->index, src_info);
    info.arith_info.src_info.push_back(src_info);
  }
  void Visit_(const FloatImm *op) {
    // Immediates only matter as operands of the store's value expression.
    if (in_store) {
      info.arith_info.scalar_imm = GetRef<Expr>(op);
      ++info.arith_info.scalar_imm_num;
    }
  }
  void Visit_(const IntImm *op) {
    if (in_store) {
      info.arith_info.scalar_imm = GetRef<Expr>(op);
      ++info.arith_info.scalar_imm_num;
    }
  }
  void Visit_(const Store *op) {
    info.arith_info.store = GetRef<Stmt>(op);
    info.arith_info.op_type = GetOpType(op->value);
    in_store = true;
    IRVisitor::Visit(op->value);
    in_store = false;
    DstInfo dst_info;
    dst_info.p_store = op;
    dst_info.index = op->index;
    GetIndexInfo(op->index, dst_info);
    info.arith_info.dst_info = dst_info;
  }
  void GetInsnType() { info.arith_info.GetIntrinsicType(info.for_info.vars, info.if_info.vars); }
  // Decomposes index `e` over the surrounding loop nest and fills `t`:
  // vars/coefs/extents, the stride-1 axis (vec_var/vec_len), the constant
  // offset, and the alignment divisor.
  template <typename T>
  void GetIndexInfo(const Expr &e, T &t) {
    bool is_serial = false;
    int imm_offset = ImmOffsetVisitor().Run(e);
    t.offset = imm_offset;
    std::vector<int> nums;
    bool is_linear_inner_for = true;
    if (info.for_info.vars.empty()) {
      t.is_scalar = true;
      return;
    }
    for (size_t i = 0; i < info.for_info.vars.size(); i++) {
      auto coef = air::arith::DetectLinearEquation(e, {info.for_info.vars[i]});
      if (!coef.empty() && !Equal(coef[0], 0)) {
        // Only constant coefficients are meaningful below; guard the deref so
        // a symbolic coefficient fails with a message instead of a crash.
        auto coef_imm = coef[0].as<IntImm>();
        CHECK(coef_imm) << "ParserVisitor expects a constant index coefficient";
        t.vars.push_back(info.for_info.vars[i]);
        t.coefs.push_back(coef_imm->value);
        t.extents.push_back(info.for_info.exts[i]);
        if (!Equal(coef[0], 1)) {
          nums.push_back(coef_imm->value);
        } else {
          // Stride-1 axis found: this is the vectorizable dimension.
          is_serial = true;
          t.vec_var = info.for_info.vars[i];
          t.vec_len = info.for_info.exts[i];
        }
      } else if (coef.empty()) {
        is_linear_inner_for = false;
      }
    }
    if (is_linear_inner_for) {
      if (nums.empty()) {
        t.divisor = 0;
      } else {
        t.divisor = GetCommonDivisor(nums);
      }
    } else {
      if (is_serial) {
        // Non-linear index with a stride-1 axis: keep the vector length as
        // divisor only if the rest of the index is a multiple of it.
        Map<Var, Expr> value_map;
        value_map.Set(t.vec_var, 0);
        auto new_e = Simplify(Substitute(e, value_map));
        if (Equal(Simplify(Mod::make(new_e, t.vec_len)), 0)) {
          t.divisor = t.vec_len;
        } else {
          t.divisor = 1;
        }
      } else {
        t.divisor = 1;
      }
    }
    t.is_serial = is_serial;
  }

 private:
  IRInfo &info;            // output being filled
  bool with_align{false};  // also classify + derive vectorization info
  bool in_store{false};    // currently visiting a Store's value expression
};
// Per-buffer alignment record used by the alignment analysis: one instance
// per tensor, holding its element type and the alignment decided so far
// (FREE_ALIGN until constrained).
class InsnTensor {
 public:
  InsnTensor(std::string name, Type type) : m_name(name), m_type(type) {}
  virtual ~InsnTensor() {}
  void SetAlignment(int align) { m_alignment = align; }
  int GetAlignment() { return m_alignment; }
  Type GetType() { return m_type; }
  std::string m_name;           // buffer name
  Type m_type;                  // element type
  int m_alignment{FREE_ALIGN};  // decided alignment; FREE_ALIGN = unconstrained
};
// Groups the dst/src alignment observers of one SIMD-like instruction so
// their alignments can be unified when any of them needs padding.
class UnifyAlignInfo {
 public:
  // True when `align` forces a stride that is not a multiple of the 32-byte
  // block (i.e. the buffer needs padding).
  bool NeedPadding(int align, int block_size) { return (align > 0 && align % block_size != 0); }
  // Propagates a single common alignment across all observers; returns true
  // when any observer was adjusted.
  bool UnifyAlign() {
    bool need_adjust = false;
    int align = observers[0]->m_alignment;
    int align_size = 32 / observers[0]->GetType().bytes();  // elements per 32-byte block
    for (size_t i = 1; i < observers.size(); ++i) {
      auto temp_align = observers[i]->m_alignment;
      auto temp_block = 32 / observers[i]->GetType().bytes();
      if (align != temp_align && (NeedPadding(align, align_size) || NeedPadding(temp_align, temp_block))) {
        need_adjust = true;
        align = SpreadAlign(align, observers[i]->m_alignment, align_size, temp_block);
      }
    }
    if (need_adjust) {
      for (size_t i = 0; i < observers.size(); ++i) {
        observers[i]->m_alignment = align;
      }
    }
    return need_adjust;
  }
  // Merges two alignments: an unconstrained or block-aligned side simply
  // adopts the other; otherwise fall back to their common divisor.
  int SpreadAlign(int left, int right, int left_block, int right_block) {
    if (left < 0 || left % left_block == 0) {
      return right;
    }
    if (right < 0 || right % right_block == 0) {
      return left;
    }
    return GetCommonDivisor({left, right});
  }
  std::vector<InsnTensor *> observers;  // dst observer first, then the srcs
  std::vector<int> divisors;
  std::vector<Expr> offsets;
  int vector_len;
};
class AlignAttach : public IRMutator {
public:
AlignAttach(std::map<const Variable *, InsnTensor *> &in_map) : m_map_(in_map) {}
Stmt Mutate_(const Store *op, const Stmt &s) {
auto value = this->Mutate(op->value);
int align = 1;
if (m_map_.count(op->buffer_var.get())) {
align = m_map_[op->buffer_var.get()]->m_alignment;
}
return Store::make(op->buffer_var, value, op->index, align);
}
Expr Mutate_(const Load *op, const Expr &e) {
int align = 1;
if (m_map_.count(op->buffer_var.get())) {
align = m_map_[op->buffer_var.get()]->m_alignment;
}
return Load::make(op->type, op->buffer_var, op->index, align);
}
private:
std::map<const Variable *, InsnTensor *> &m_map_;
};
class AlignGen : public IRVisitor {
public:
Stmt Run(const Stmt stmt, std::unordered_map<const Variable *, Type> &var_info) {
for (auto &item : var_info) {
auto ptr = new InsnTensor(item.first->name_hint, item.second);
observer_dic_[item.first] = ptr;
}
IRVisitor::Visit(stmt);
BroadcastAlign();
auto ret = AlignAttach(observer_dic_).Mutate(stmt);
return ret;
}
void Visit_(const AttrStmt *op) final {
if (op->attr_key == "pragma_emit_insn" && exclude_align_analyze_list.count(op->value.as<StringImm>()->value) == 0) {
IRInfo info;
ParserVisitor(info, true).Run(op->body);
AddAlignInfo(info);
} else if (op->attr_key == "align_info" && op->node.as<Variable>() && observer_dic_[op->node.as<Variable>()] &&
op->value.as<IntImm>()) {
observer_dic_[op->node.as<Variable>()]->m_alignment = op->value.as<IntImm>()->value;
} else {
IRVisitor::Visit_(op);
}
}
void AddAlignInfo(IRInfo &info) {
if (info.arith_info.insn_type == "scalar") {
return;
}
bool is_ub_to_gm = (info.arith_info.src_info.size() == 1) &&
GetBufScope(info.arith_info.dst_info.p_store->buffer_var->name_hint) == DMA_COPY_GLOBAL;
bool is_gm_to_ub = (info.arith_info.src_info.size() == 1) &&
GetBufScope(info.arith_info.src_info[0].p_load->buffer_var->name_hint) == DMA_COPY_GLOBAL;
if (!is_ub_to_gm) {
auto dst_name = info.arith_info.dst_info.p_store->buffer_var.get();
auto divisor_dst = info.arith_info.dst_info.divisor;
if (!info.arith_info.is_scalar) {
HandleAlignment(observer_dic_[dst_name], divisor_dst, info.arith_info.vec_len);
}
}
if (!is_gm_to_ub) {
for (size_t i = 0; i < info.arith_info.src_info.size(); i++) {
auto src_name = info.arith_info.src_info[i].p_load->buffer_var.get();
if (observer_dic_.count(src_name) && !info.arith_info.is_scalar) {
auto src_observer = observer_dic_[src_name];
auto divisor_src = info.arith_info.src_info[i].divisor;
HandleAlignment(src_observer, divisor_src, info.arith_info.vec_len);
}
}
}
if (!is_ub_to_gm && !is_gm_to_ub && info.arith_info.insn_type != "reduce" &&
info.arith_info.insn_type != "crossing" && info.arith_info.insn_type != "discrete") {
UnifyAlignInfo temp_info;
auto dst_name = info.arith_info.dst_info.p_store->buffer_var.get();
temp_info.observers.push_back(observer_dic_[dst_name]);
temp_info.divisors.push_back(info.arith_info.dst_info.divisor);
temp_info.offsets.push_back(info.arith_info.dst_info.offset);
temp_info.vector_len = info.arith_info.vec_len;
for (size_t i = 0; i < info.arith_info.src_info.size(); i++) {
auto src_name = info.arith_info.src_info[i].p_load->buffer_var.get();
if (observer_dic_.count(src_name)) {
temp_info.observers.push_back(observer_dic_[src_name]);
temp_info.divisors.push_back(info.arith_info.src_info[i].divisor);
temp_info.offsets.push_back(info.arith_info.src_info[i].offset);
}
}
aligns_info_.push_back(temp_info);
}
}
void HandleAlignment(InsnTensor *observer, int divisor, int vector_len) {
auto block_size = GetUbBlkSize(observer->GetType());
CHECK(divisor % block_size == 0 || divisor >= vector_len);
auto cur_align = observer->GetAlignment();
int align_temp = 0;
if (cur_align == FREE_ALIGN && divisor % block_size == 0 && divisor >= vector_len) {
return;
}
if (cur_align == FREE_ALIGN && divisor % block_size == 0 && divisor < vector_len) {
return;
}
if (divisor != 0) {
if (cur_align == FREE_ALIGN) {
if (divisor == vector_len) {
align_temp = vector_len;
observer->SetAlignment(align_temp);
return;
}
if (divisor >= vector_len) {
return;
}
CHECK(0) << "Conditions not considered";
}
if (divisor % cur_align == 0 && vector_len < cur_align) {
return;
}
if (divisor % cur_align != 0) {
if (cur_align % block_size != 0) {
align_temp = air::ir::gcd(divisor, cur_align);
} else {
align_temp = divisor;
}
if (vector_len <= align_temp) {
observer->SetAlignment(align_temp);
} else {
align_temp = air::ir::gcd(vector_len, align_temp);
observer->SetAlignment(align_temp);
}
}
}
}
void BroadcastAlign() {
bool has_update = true;
while (has_update) {
has_update = false;
for (size_t i = 0; i < aligns_info_.size(); ++i) {
has_update = aligns_info_[i].UnifyAlign() || has_update;
}
}
}
private:
std::map<const Variable *, InsnTensor *> observer_dic_;
std::vector<UnifyAlignInfo> aligns_info_;
};
} // namespace ir } // namespace ir
} // namespace akg } // namespace akg
#endif // PASS_ANALYZE_ALIGN_H_ #endif // PASS_ANALYZE_ALIGN_H_
...@@ -466,7 +466,7 @@ class AlignVistor : public IRVisitor { ...@@ -466,7 +466,7 @@ class AlignVistor : public IRVisitor {
// only scan dma insns // only scan dma insns
if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>() && if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>() &&
op->value.as<StringImm>()->value != "vec_binary_dropout" && op->value.as<StringImm>()->value != "vec_binary_dropout" &&
exclude_list.count(op->value.as<StringImm>()->value) == 0)) { exclude_align_analyze_list.count(op->value.as<StringImm>()->value) == 0)) {
bool in_dma_copy = false; bool in_dma_copy = false;
if (op->value.as<StringImm>() && op->value.as<StringImm>()->value == "dma_copy") { if (op->value.as<StringImm>() && op->value.as<StringImm>()->value == "dma_copy") {
in_dma_copy = true; in_dma_copy = true;
......
...@@ -26,673 +26,21 @@ ...@@ -26,673 +26,21 @@
namespace akg { namespace akg {
namespace ir { namespace ir {
namespace {
using Var2Scope = std::map<const Variable *, std::string>;
bool IsInStorageScope(const Var2Scope &table, const Variable *var) { return table.find(var) != table.end(); }
using AlignModifier = std::function<void(int64_t &)>;
using std::placeholders::_1;
class AlignInfo {
public:
explicit AlignInfo(const Type &t, int64_t off, const AlignModifier func = nullptr, bool spread = false)
: blk_sz(GetUbBlkSize(t)), base_offset(off), modifiers(), need_spread(spread) {
if (func != nullptr) {
modifiers.push_back(func);
}
}
explicit AlignInfo(const Type &t) : AlignInfo(t, 0, nullptr, false) {}
AlignInfo() : AlignInfo(Handle(1), 0, nullptr, false) { blk_sz = 0; }
~AlignInfo() = default;
int64_t blk_sz;
int64_t base_offset;
std::vector<AlignModifier> modifiers;
bool need_spread;
};
struct VarComp {
bool operator()(const Var &v0, const Var &v1) const { return v0.get() < v1.get(); }
};
using AlignDict = std::map<Var, AlignInfo, VarComp>;
void MergeAlignInfo(AlignInfo &a, const AlignInfo &b) {
CHECK(a.blk_sz != 0 || b.blk_sz != 0);
CHECK(a.blk_sz == 0 || b.blk_sz == 0 || a.blk_sz == b.blk_sz);
if (a.blk_sz == 0) {
a.blk_sz = b.blk_sz;
}
a.need_spread = a.need_spread || b.need_spread;
a.base_offset = air::ir::gcd(a.base_offset, b.base_offset);
a.modifiers.insert(a.modifiers.end(), b.modifiers.begin(), b.modifiers.end());
}
AlignDict MergeAlignDict(const AlignDict &a, const AlignDict &b) {
AlignDict rst = a;
for (const auto &e : b) {
auto it = rst.find(e.first);
if (it != rst.end()) {
MergeAlignInfo(it->second, e.second);
} else {
rst.emplace(e);
}
}
return rst;
}
AlignDict GenFreeAlignDict(const StmtInfoList &com_info_list) {
AlignDict dict;
for (const auto &com_info : com_info_list) {
dict.emplace(com_info->data_, AlignInfo(com_info->dtype_));
}
return dict;
}
AlignDict GenSpecAlignDict(const StmtInfoList &com_info_list, int64_t align, bool is_spread) {
AlignDict dict;
for (const auto &com_info : com_info_list) {
dict.emplace(com_info->data_, AlignInfo(com_info->dtype_, align, nullptr, is_spread));
}
return dict;
}
void FixAlignBySize(int64_t &align, int64_t size) {
if (align < size && align != 0 && (size % align) != 0) {
align = air::ir::gcd(align, size);
}
}
class RegExprSub : public IRMutator {
public:
RegExprSub() {}
~RegExprSub() override = default;
Expr run(const Expr &e) { return this->Mutate(e); }
Expr Mutate_(const Load *op, const Expr &e) final {
if (GetBufScope(op->buffer_var->name_hint) == SCOPE_REG && isImm(op->index)) {
return Variable::make(Int(32), "tmp");
}
return IRMutator::Mutate_(op, e);
}
};
AlignDict GenNormalAlignDict(const StmtInfoList &com_info_list, bool is_spread, bool all_remained_axis = false) {
AlignDict dict;
for (const auto &com_info : com_info_list) {
if (com_info->var_.empty() && !all_remained_axis) {
MergeAlignInfo(dict[com_info->data_], AlignInfo(com_info->dtype_, 0, nullptr, is_spread));
continue;
}
bool min_stride_eq1 = !com_info->var_.empty() && GetIntConst(GetItem(com_info->strides_, -1)) == 1;
auto index_expr = IndexOptimizer().Mutate(com_info->index_); int GetCommonDivisor(std::vector<int> numbers) {
if (min_stride_eq1) { CHECK(numbers.size() >= 1);
auto var = GetItem(com_info->var_, -1); int divisor = numbers[0];
index_expr = Simplify(EliminateVarInExpr(index_expr, {var})); for (size_t i = 1; i < numbers.size(); i++) {
} divisor = air::ir::gcd(divisor, numbers[i]);
int64_t offset_gcd = 1;
int64_t continuity_len = min_stride_eq1 ? GetIntConst(GetItem(com_info->shape_, -1)) : 1;
index_expr = RegExprSub().run(index_expr);
auto vars = GetVarsInExpr(index_expr);
if (vars.empty()) {
CHECK(is_const(index_expr));
offset_gcd = std::abs(GetIntConst(index_expr));
} else {
auto strides = air::arith::DetectLinearEquation(index_expr, vars);
if (strides.empty()) {
offset_gcd = -2; // "-2" means no need to consider
} else {
CHECK(!strides.empty());
offset_gcd = 0;
for (const auto &e : strides) {
offset_gcd = air::ir::gcd(offset_gcd, GetIntConst(e));
}
}
}
AlignModifier func = std::bind(FixAlignBySize, _1, continuity_len);
MergeAlignInfo(dict[com_info->data_], AlignInfo(com_info->dtype_, offset_gcd, func, is_spread));
}
return dict;
}
bool IsNonLinearScalar(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list) {
if (std::any_of(dst_info_list.begin(), dst_info_list.end(),
[](const StmtStoreInfo &com_info) { return (!com_info->var_.empty()); })) {
return true;
}
if (std::any_of(src_info_list.begin(), src_info_list.end(),
[](const StmtStoreInfo &com_info) { return (!com_info->var_.empty()); })) {
return true;
}
return false;
}
inline bool IsTranspose(const StmtStoreInfo &dst, const StmtStoreInfo &src) {
return dst->var_.size() > 1 && src->var_.size() > 1 && Equal(GetItem(dst->var_, -2), GetItem(src->var_, -1)) &&
Equal(GetItem(dst->var_, -1), GetItem(src->var_, -2)) &&
Equal(GetItem(dst->shape_, -1), GetItem(src->shape_, -2)) &&
Equal(GetItem(dst->shape_, -2), GetItem(src->shape_, -1)) && GetIntConst(GetItem(dst->strides_, -1)) == 1 &&
GetIntConst(GetItem(src->strides_, -1)) == 1 && Equal(GetItem(dst->strides_, -2), GetItem(src->shape_, -2)) &&
Equal(GetItem(src->strides_, -2), GetItem(dst->shape_, -2));
}
void FixAlignByShape(int64_t &align, int64_t shape0, int64_t shape1) {
if (align >= shape0 * shape1) {
return;
} else if (align >= shape0) {
CHECK_NE(shape0, 0);
if (align % shape0 == 0) {
auto times = align / shape0;
align = shape0 * air::ir::gcd(times, shape1);
return;
}
}
align = air::ir::gcd(align, shape0);
}
AlignDict GenTransposeAlign(const StmtStoreInfo &ori_dst, const StmtStoreInfo &ori_src, StmtInfo &if_info,
StmtInfo &for_info) {
auto dst = ori_dst.Copy();
auto src = ori_src.Copy();
auto var_old = GetItem(dst->var_, -1);
auto var_new = GetItem(dst->var_, -2);
dst.GetNode()->var_ = RemoveItemAtIndex(dst->var_, -1);
src.GetNode()->var_ = RemoveItemAtIndex(src->var_, -2);
int64_t sh0 = GetIntConst(GetItem(dst->shape_, -1));
int64_t sh1 = GetIntConst(GetItem(dst->shape_, -2));
auto shape = static_cast<int32_t>(sh0 * sh1);
dst.GetNode()->shape_ = RemoveItemAtIndex(dst->shape_, -1);
src.GetNode()->shape_ = RemoveItemAtIndex(src->shape_, -1);
SetItem(dst.GetNode()->shape_, -1, Expr(shape));
SetItem(src.GetNode()->shape_, -1, Expr(shape));
dst.GetNode()->strides_ = RemoveItemAtIndex(dst->strides_, -2);
src.GetNode()->strides_ = RemoveItemAtIndex(src->strides_, -2);
Map<Var, Expr> map({{var_old, Expr(0)}, {var_new, Expr(0)}});
dst.GetNode()->index_ = Simplify(Substitute(dst->index_, map) + var_new);
src.GetNode()->index_ = Simplify(Substitute(src->index_, map) + var_new);
StmtInfoList dst_list({dst});
StmtInfoList src_list({src});
CompactComputationInfoList(dst_list, src_list, if_info, for_info);
auto dict = GenNormalAlignDict(MergeTwo(dst_list, src_list), false);
dict[dst->data_].modifiers.clear();
dict[dst->data_].modifiers.push_back(std::bind(FixAlignByShape, _1, sh0, sh1));
dict[src->data_].modifiers.clear();
dict[src->data_].modifiers.push_back(std::bind(FixAlignByShape, _1, sh1, sh0));
return dict;
}
bool IsScalarDMA(const Stmt &op) {
StmtInfo f_info;
StmtInfo i_info;
std::string intrin;
std::string dma;
StmtInfoList src_info_list;
StmtInfoList dst_info_list;
GetDmaComputationInfo(op, dst_info_list, src_info_list, i_info, f_info, dma, intrin);
const auto &d_info = dst_info_list[0];
const auto &s_info = src_info_list[0];
bool last_dim_equal = !d_info->var_.empty() && !s_info->var_.empty() &&
GetItem(d_info->var_, -1).get() == GetItem(s_info->var_, -1).get() &&
!d_info->strides_.empty() && !s_info->strides_.empty() &&
GetIntConst(GetItem(d_info->strides_, -1)) != GetIntConst(GetItem(s_info->strides_, -1));
bool is_broadcast =
((!s_info->strides_.empty() && GetIntConst(GetItem(s_info->strides_, -1)) != 1) || s_info->var_.empty()) &&
((!d_info->strides_.empty() && GetIntConst(GetItem(d_info->strides_, -1)) != 1) || d_info->var_.empty());
bool ubuf_scalar = (is_broadcast || last_dim_equal) && intrin == INTRIN_NAME_COPY_UB_TO_UB;
bool broadcast_scalar = is_broadcast && intrin == "broadcast";
if (broadcast_scalar || ubuf_scalar) {
int shape = GetInt32Const(GetItem(d_info->shape_, -1));
int stride = GetInt32Const(GetItem(d_info->strides_, -1));
int block_size = GetUbBlkSize(d_info->dtype_);
if (!(ubuf_scalar && shape < block_size && stride == block_size &&
IsTwoItemEqual(d_info->strides_, s_info->strides_, -1, true))) {
return true;
}
}
return false;
}
AlignDict GetDataAlign(const Stmt &op, const bool is_dma_copy, std::vector<StmtInfoList> &info_vec) {
StmtInfo if_info;
StmtInfo for_info;
StmtInfoList dst_info_list;
StmtInfoList src_info_list;
GetCompactComputationInfo(op, dst_info_list, src_info_list, if_info, for_info, false, true);
auto merged_com_list = MergeTwo(dst_info_list, src_info_list);
info_vec.push_back(merged_com_list);
Array<NodeRef> stores;
Array<NodeRef> loads;
GetStoreAndLoads(op, stores, loads);
auto org_dst_info_list = GetComputationInfo(stores, for_info);
auto org_src_info_list = GetComputationInfo(loads, for_info);
StmtInfoList empty_com_list;
// check load list
if (src_info_list.empty()) {
// broadcast/scalar mode, such as A[i, j] = 0.0 / A[1] = 2.0
if (dst_info_list[0]->var_.empty()) {
return GenFreeAlignDict(dst_info_list);
} else {
return GenNormalAlignDict(merged_com_list, false);
}
} else if (src_info_list.size() == 1) {
auto dst_info = dst_info_list[0];
auto src_info = src_info_list[0];
if (dst_info->scope_ == SCOPE_UBUF && src_info->scope_ == SCOPE_UBUF) {
if (dst_info->var_.empty() && src_info->var_.empty()) {
if (is_dma_copy) {
if (IsNonLinearScalar(org_dst_info_list, org_src_info_list)) {
// check if it is non-linear index scalar mov, such as
// for (cc2, 0, 4) {
// for (cc3, 0, 6) {
// T_tile_local_UB[((cc2*6) + cc3)] = data_local__ub[(((cc2 % 2)*2) + (cc3 % 2))]
// }
// }
CleanNonLinearVar(org_dst_info_list, empty_com_list, if_info);
auto align_src = GenFreeAlignDict(src_info_list);
auto align_dst = GenNormalAlignDict(org_dst_info_list, false);
return MergeAlignDict(align_src, align_dst);
}
// intrin_name = 'copy_ubuf_to_ubuf'
// scalar op, will not influence the align
return GenFreeAlignDict(merged_com_list);
}
// intrin_name = vadds or vmuls
return GenNormalAlignDict(merged_com_list, false, true);
} else if (src_info->var_.empty()) {
if (GetIntConst(GetItem(dst_info->strides_, -1)) == 1) {
// scalar broadcast
CleanNonLinearVar(org_dst_info_list, empty_com_list, if_info);
auto align_src = GenFreeAlignDict(src_info_list);
auto align_dst = GenNormalAlignDict(org_dst_info_list, false);
return MergeAlignDict(align_src, align_dst);
}
// intrin_name = vector_dup
return GenFreeAlignDict(merged_com_list);
} else if (!(dst_info->var_.empty()) && Equal(GetItem(dst_info->var_, -1), GetItem(src_info->var_, -1))) {
if (GetIntConst(GetItem(dst_info->strides_, -1)) == GetIntConst(GetItem(src_info->strides_, -1)) &&
Equal(GetItem(org_dst_info_list[0]->var_, -1), GetItem(org_src_info_list[0]->var_, -1))) {
// elemwise mode, intrin_name = copy_ubuf_to_ubuf
return GenNormalAlignDict(merged_com_list, true);
}
// scalar dma mode
return GenFreeAlignDict(merged_com_list);
} else if (IsTranspose(dst_info, src_info)) {
if (is_dma_copy) {
// intrin_name = vtranspose
int block_size = GetUbBlkSize(dst_info->dtype_);
CHECK_NE(block_size, 0);
int dst_shape = GetInt32Const(GetItem(dst_info->shape_, -1));
int src_shape = GetInt32Const(GetItem(src_info->shape_, -1));
if (dst_shape % block_size != 0 ||
(src_shape % block_size != 0 && (src_shape > block_size || dst_shape > block_size))) {
return GenTransposeAlign(dst_info, src_info, if_info, for_info);
} else {
// special case optimization
return GenNormalAlignDict(merged_com_list, false);
}
}
// align = 1
return GenSpecAlignDict(merged_com_list, 1, true);
} else if (dst_info->var_.size() > 1 && src_info->var_.size() > 1 &&
!Equal(GetItem(dst_info->var_, -1), GetItem(src_info->var_, -1)) &&
Equal(GetItem(dst_info->var_, -2), GetItem(src_info->var_, -2))) {
// intrin_name = broadcast
// special case of last dim axis broadcast issue #675
CleanNonLinearVar(org_dst_info_list, empty_com_list, if_info);
auto align_src = GenFreeAlignDict(src_info_list);
auto align_dst = GenNormalAlignDict(org_dst_info_list, false);
return MergeAlignDict(align_src, align_dst);
} else if (IsScalarDMA(op)) {
return GenFreeAlignDict(merged_com_list);
}
return GenNormalAlignDict(merged_com_list, false);
} else if (dst_info->scope_ != DMA_COPY_GLOBAL && src_info->scope_ != DMA_COPY_GLOBAL &&
dst_info->var_.size() > 1 && src_info->var_.size() > 1 &&
Equal(GetItem(dst_info->var_, -1), GetItem(src_info->var_, -2)) &&
Equal(GetItem(dst_info->var_, -2), GetItem(src_info->var_, -1))) {
// check transopse cbuf, ca, cb, cc
if (is_dma_copy) {
// intrin_name = vtranspose
int64_t align = GetIntConst(GetItem(dst_info->shape_, -1) * GetItem(src_info->shape_, -1));
return GenSpecAlignDict(merged_com_list, align, true);
}
// discontinuoust dma mov
return GenSpecAlignDict(merged_com_list, 1, true);
} else if (dst_info->var_.empty() && src_info->var_.empty()) {
// not ub to ub mode, discontinuous dma mov
return GenNormalAlignDict(merged_com_list, true, true);
} else if (dst_info->var_.empty()) {
LOG(FATAL) << "Error: Copy Vector into a scalar.";
} else if (src_info->var_.empty()) {
// broadcast between ub and gm
return GenNormalAlignDict(merged_com_list, true, true);
} else if (!Equal(GetItem(dst_info->var_, -1), GetItem(src_info->var_, -1)) ||
GetIntConst(GetItem(dst_info->strides_, -1)) != 1 || GetIntConst(GetItem(src_info->strides_, -1)) != 1) {
// discontinuoust dma mov
return GenSpecAlignDict(merged_com_list, 1, true);
}
return GenNormalAlignDict(merged_com_list, true);
} else if (src_info_list.size() < 5) { // src_info_list allowed max value + 1
if (IsLastAxisReduction(dst_info_list, src_info_list)) {
// reduction mode
if (GetIntConst(GetItem(dst_info_list[0]->shape_, -1)) == 1) {
// reduce to a scalar
return GenFreeAlignDict(merged_com_list);
}
// last dim is compacted separately
return GenNormalAlignDict(merged_com_list, false);
} else if (IsElementwise(dst_info_list, src_info_list)) {
// elementwise mode
return GenNormalAlignDict(merged_com_list, true, true);
} else if (IsBroadcast(dst_info_list, src_info_list)) {
// broadcast mode
bool need_spread = !IsLastAxisBroadcast(dst_info_list, src_info_list);
return GenNormalAlignDict(merged_com_list, need_spread);
}
return GenNormalAlignDict(merged_com_list, true);
} else {
LOG(FATAL) << "Error: Can not support more than 4 loads.";
} }
// error, and return empty map return divisor;
return AlignDict();
} }
class AlignVistor : public IRVisitor { namespace {
public:
explicit AlignVistor(const Var2Scope &table)
: min_align(), gbl_storage(), storage_scope_(table), all_aligns_(), spread_vec_(), info_vec_() {}
~AlignVistor() override = default;
void Run(const Stmt stmt) {
this->Visit(stmt);
UpdateAlign();
}
void Visit_(const AttrStmt *op) final {
// nested scop, just return
if (op->attr_key == "isolate_range") return;
if (auto str_ptr = op->node.as<StringImm>()) {
if (str_ptr->value == "no_align") {
return IRVisitor::Visit_(op);
}
}
// only scan dma insns
if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() &&
op->value.as<StringImm>()->value != "vec_binary_dropout" &&
op->value.as<StringImm>()->value != "mask_broadcast" &&
exclude_list.count(op->value.as<StringImm>()->value) == 0)) {
bool in_dma_copy = false;
if (op->value.as<StringImm>() && op->value.as<StringImm>()->value == "dma_copy") {
in_dma_copy = true;
}
auto dict = GetDataAlign(op->body, in_dma_copy, info_vec_);
for (auto it = dict.begin(); it != dict.end();) {
if (!IsInStorageScope(storage_scope_, it->first.get())) {
gbl_storage.insert(it->first.get());
it = dict.erase(it);
} else {
++it;
}
}
std::vector<Var> spread_var;
for (const auto &e : dict) {
if (e.second.need_spread) {
spread_var.push_back(e.first);
}
MergeAlignInfo(all_aligns_[e.first], e.second);
}
if (spread_var.size() > 1) {
spread_vec_.push_back(std::move(spread_var));
}
}
return IRVisitor::Visit_(op);
}
std::map<const Variable *, int64_t> min_align;
std::set<const Variable *> gbl_storage;
private:
void UpdateAlign() {
for (auto e : gbl_storage) {
auto var_ptr = const_cast<Variable *>(e);
all_aligns_.emplace(Var(GetObjectPtr<Object>(var_ptr)), AlignInfo(var_ptr->type));
}
do {
for (auto &e : all_aligns_) {
auto &info = e.second;
auto blk_sz = info.blk_sz;
CHECK_NE(blk_sz, 0);
if (info.base_offset % blk_sz != 0) {
while (info.base_offset != 1) {
bool done = true;
for (auto func : info.modifiers) {
auto old = info.base_offset;
func(info.base_offset);
CHECK_LE(info.base_offset, old);
if (info.base_offset < old) {
done = false;
}
}
if (done && FixLoopAxis()) {
break;
}
}
}
}
} while (!DealWithSpread());
for (const auto &e : all_aligns_) {
if (IsInStorageScope(storage_scope_, e.first.get())) {
min_align.emplace(e.first.get(), e.second.base_offset);
}
}
}
bool FixLoopAxis() {
for (const auto &vec_ele : info_vec_) {
// for_v -> times
std::map<Var, std::vector<int64_t>, VarComp> coef_table;
// for_v -> [buffer -> times]
std::map<Var, std::map<Var, int64_t, VarComp>, VarComp> buf_table;
for (const auto &info : vec_ele) {
auto it = all_aligns_.find(info->data_);
CHECK(it != all_aligns_.end());
if (it->second.base_offset <= 1) {
continue;
}
for (size_t i = 0; i != info->var_.size(); ++i) {
auto stride = std::abs(GetIntConst(info->strides_[i]));
auto extent = std::abs(GetIntConst(info->shape_[i]));
auto align = it->second.base_offset;
if (stride < align && stride * extent > align) {
CHECK_NE(stride, 0);
if (align % stride != 0) {
it->second.base_offset = air::ir::gcd(align, stride);
return false;
}
CHECK_NE((align / stride), 0);
if (extent % (align / stride) != 0) {
auto times = align / stride;
auto new_times = air::ir::gcd(extent, times);
it->second.base_offset = it->second.base_offset * new_times / times;
return false;
}
auto var = info->var_[i];
auto times = align / stride;
coef_table[var].push_back(times);
auto &times_record = buf_table[var][it->first];
CHECK(times_record == 0 || times_record == times);
times_record = times;
}
}
}
for (const auto &i : coef_table) {
auto align = i.second.front();
bool changed = false;
for (auto ele : i.second) {
changed = changed || (ele != align);
align = air::ir::gcd(align, ele);
}
if (changed) {
for (auto v : buf_table[i.first]) {
all_aligns_[v.first].base_offset *= align;
CHECK_NE(v.second, 0);
all_aligns_[v.first].base_offset /= v.second;
}
return false;
}
}
}
return true;
}
bool DealWithSpread() {
for (const auto &vec : spread_vec_) {
auto it = all_aligns_.find(vec.front());
CHECK(it != all_aligns_.end());
auto align = it->second.base_offset;
bool changed = false;
for (const auto &e : vec) {
auto it_in = all_aligns_.find(e);
CHECK(it_in != all_aligns_.end());
changed = changed || (it_in->second.base_offset != align);
align = air::ir::gcd(align, it_in->second.base_offset);
}
if (changed) {
for (const auto &e : vec) {
auto it_in = all_aligns_.find(e);
CHECK(it_in != all_aligns_.end());
it_in->second.base_offset = align;
}
return false;
}
}
return true;
}
// storage scope
const Var2Scope &storage_scope_;
// all align_ info
AlignDict all_aligns_;
std::vector<std::vector<Var>> spread_vec_;
std::vector<StmtInfoList> info_vec_;
};
// predicate is for GPU, use it to hold min align
class AlignInsert : public IRMutator {
public:
AlignInsert() : min_align_(), gbl_storage_() {}
~AlignInsert() override = default;
Stmt Run(const Stmt stmt, const Var2Scope &storage_scope) {
AlignVistor visitor(storage_scope);
visitor.Run(stmt);
min_align_ = std::move(visitor.min_align);
gbl_storage_ = std::move(visitor.gbl_storage);
return this->Mutate(stmt);
}
Stmt Mutate_(const Store *op, const Stmt &s) final {
Expr value = this->Mutate(op->value);
auto index = this->Mutate(op->index);
int64_t val = gbl_storage_.find(op->buffer_var.get()) == gbl_storage_.end() ? free_align_flag_ : 1;
auto it = min_align_.find(op->buffer_var.get());
if (it != min_align_.end()) {
val = GetAlignValue(it->second, op->value.type());
}
return Store::make(op->buffer_var, value, index, make_const(Int(32), val));
}
Expr Mutate_(const Load *op, const Expr &e) final {
auto index = this->Mutate(op->index);
int64_t val = gbl_storage_.find(op->buffer_var.get()) == gbl_storage_.end() ? free_align_flag_ : 1;
auto it = min_align_.find(op->buffer_var.get());
if (it != min_align_.end()) {
val = GetAlignValue(it->second, op->type);
}
return Load::make(op->type, op->buffer_var, index, make_const(Int(32), val));
}
private:
static int64_t GetAlignValue(int64_t val, const air::DataType dtype) {
int value = GetUbBlkSize(dtype);
CHECK_NE(value, 0);
return val % value == 0 ? FREE_ALIGN : val;
}
std::map<const Variable *, int64_t> min_align_;
std::set<const Variable *> gbl_storage_; using Var2Scope = std::map<const Variable *, std::string>;
const int free_align_flag_ = -2; bool IsInStorageScope(const Var2Scope &table, const Variable *var) { return table.find(var) != table.end(); }
};
class FindSameNameBuf : public IRVisitor { class FindSameNameBuf : public IRVisitor {
public: public:
...@@ -782,16 +130,35 @@ class InsertIsolate : public IRMutator { ...@@ -782,16 +130,35 @@ class InsertIsolate : public IRMutator {
bool insert_isolate_; bool insert_isolate_;
}; };
class CacheVisiter : public IRVisitor {
public:
CacheVisiter() = default;
~CacheVisiter() override = default;
void Visit_(const Allocate *op) final {
var_type_map[op->buffer_var.get()] = op->type;
IRVisitor::Visit_(op);
}
std::unordered_map<const Variable *, Type> var_type_map;
};
// process each isolate_range once a time // process each isolate_range once a time
class ProcessParts : public IRMutator { class ProcessParts : public IRMutator {
public: public:
explicit ProcessParts(const Var2Scope &table) : level_(0), storage_scope_(table) {} explicit ProcessParts(const Var2Scope &table) : level_(0), storage_scope_(table) {}
~ProcessParts() override = default; ~ProcessParts() override = default;
std::unordered_map<const Variable *, Type> var_type_map;
Stmt Run(Stmt stmt) { Stmt Run(Stmt stmt) {
CacheVisiter buffer_visitor;
buffer_visitor.Visit(stmt);
var_type_map = buffer_visitor.var_type_map;
stmt = this->Mutate(stmt); stmt = this->Mutate(stmt);
if (level_ == 0) { if (level_ == 0) {
stmt = AlignInsert().Run(stmt, storage_scope_); stmt = AlignGen().Run(stmt, var_type_map);
} }
return stmt; return stmt;
} }
...@@ -799,7 +166,7 @@ class ProcessParts : public IRMutator { ...@@ -799,7 +166,7 @@ class ProcessParts : public IRMutator {
Stmt Mutate_(const Block *op, const Stmt &s) final { Stmt Mutate_(const Block *op, const Stmt &s) final {
if (!HasIsolate(s)) { if (!HasIsolate(s)) {
Stmt stmt = s; Stmt stmt = s;
stmt = AlignInsert().Run(stmt, storage_scope_); stmt = AlignGen().Run(stmt, var_type_map);
level_++; level_++;
return stmt; return stmt;
} }
...@@ -813,7 +180,7 @@ class ProcessParts : public IRMutator { ...@@ -813,7 +180,7 @@ class ProcessParts : public IRMutator {
Stmt stmt = IRMutator::Mutate_(op, s); Stmt stmt = IRMutator::Mutate_(op, s);
// no isolate_range in this attr // no isolate_range in this attr
if (cur_level == level_) { if (cur_level == level_) {
stmt = AlignInsert().Run(stmt, storage_scope_); stmt = AlignGen().Run(stmt, var_type_map);
} }
return stmt; return stmt;
} }
...@@ -841,14 +208,14 @@ class ProcessParts : public IRMutator { ...@@ -841,14 +208,14 @@ class ProcessParts : public IRMutator {
Stmt AnalyzeMinAlignStatic(Stmt stmt) { Stmt AnalyzeMinAlignStatic(Stmt stmt) {
stmt = air::ir::ConvertSSA(stmt); stmt = air::ir::ConvertSSA(stmt);
CacheVisiter buffer_visitor;
buffer_visitor.Visit(stmt);
FindSameNameBuf find_visitor; FindSameNameBuf find_visitor;
find_visitor.Visit(stmt); find_visitor.Visit(stmt);
stmt = MergeLoops(stmt);
stmt = InsertIsolate(find_visitor.storage_scope_).Mutate(stmt); stmt = InsertIsolate(find_visitor.storage_scope_).Mutate(stmt);
stmt = ProcessParts(find_visitor.storage_scope_).Run(stmt); stmt = ProcessParts(find_visitor.storage_scope_).Run(stmt);
stmt = RewriteByAlignStatic(stmt); stmt = RewriteByAlignStatic(stmt);
return stmt; return stmt;
} }
......
...@@ -43,7 +43,7 @@ class LoopsCompacter : public IRMutator { ...@@ -43,7 +43,7 @@ class LoopsCompacter : public IRMutator {
Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() && if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() &&
!exclude_list.count(op->value.as<StringImm>()->value))) { !exclude_align_analyze_list.count(op->value.as<StringImm>()->value))) {
stores_ = Array<NodeRef>(); stores_ = Array<NodeRef>();
loads_ = Array<NodeRef>(); loads_ = Array<NodeRef>();
GetStoreAndLoads(op->body, stores_, loads_); GetStoreAndLoads(op->body, stores_, loads_);
......
...@@ -192,6 +192,7 @@ class MultiLastAxisReduction : public IRMutator { ...@@ -192,6 +192,7 @@ class MultiLastAxisReduction : public IRMutator {
lastResult = loadTmp + storeLeft; lastResult = loadTmp + storeLeft;
} }
broadcastNum = Call::make(type_tmp, "vector_dup", {broadcastNum}, Call::PureIntrinsic);
Stmt stForOnce = Store::make(tmpBuffer, storeResult, newIdx, storeTmp->predicate); Stmt stForOnce = Store::make(tmpBuffer, storeResult, newIdx, storeTmp->predicate);
Stmt stForTwice = Store::make(storeTmp->buffer_var, lastResult, storeTmp->index, storeTmp->predicate); Stmt stForTwice = Store::make(storeTmp->buffer_var, lastResult, storeTmp->index, storeTmp->predicate);
Stmt stBroadcast = Store::make(tmpBuffer, broadcastNum, newIdx, storeTmp->predicate); Stmt stBroadcast = Store::make(tmpBuffer, broadcastNum, newIdx, storeTmp->predicate);
...@@ -212,7 +213,7 @@ class MultiLastAxisReduction : public IRMutator { ...@@ -212,7 +213,7 @@ class MultiLastAxisReduction : public IRMutator {
stForOnce = AttrStmt::make(VarExpr("0", Int(32)), "pragma_emit_insn", Expr(str), stForOnce); stForOnce = AttrStmt::make(VarExpr("0", Int(32)), "pragma_emit_insn", Expr(str), stForOnce);
stForTwice = AttrStmt::make(VarExpr("0", Int(32)), "pragma_emit_insn", Expr(str), stForTwice); stForTwice = AttrStmt::make(VarExpr("0", Int(32)), "pragma_emit_insn", Expr(str), stForTwice);
stBroadcast = AttrStmt::make(VarExpr("0", Int(32)), "pragma_emit_insn", Expr("broadcast"), stBroadcast); stBroadcast = AttrStmt::make(VarExpr("0", Int(32)), "pragma_emit_insn", Expr("vector_dup"), stBroadcast);
stmt = Block::make({stBroadcast, stForOnce, stForTwice}); stmt = Block::make({stBroadcast, stForOnce, stForTwice});
stmt = Allocate::make(tmpBuffer, type_tmp, extentsArray, const_true(), stmt); stmt = Allocate::make(tmpBuffer, type_tmp, extentsArray, const_true(), stmt);
......
...@@ -147,7 +147,7 @@ class EstimateAlign : public IRMutator { ...@@ -147,7 +147,7 @@ class EstimateAlign : public IRMutator {
Stmt Mutate_(const AttrStmt *op, const Stmt &stmt) final { Stmt Mutate_(const AttrStmt *op, const Stmt &stmt) final {
if (air::ir::attr::IsPragmaKey(op->attr_key) && op->value.as<StringImm>()) { if (air::ir::attr::IsPragmaKey(op->attr_key) && op->value.as<StringImm>()) {
if (exclude_list.count(op->value.as<StringImm>()->value)) { if (exclude_align_analyze_list.count(op->value.as<StringImm>()->value)) {
return stmt; return stmt;
} }
......
...@@ -46,7 +46,7 @@ class AxisPartitioner : public IRMutator { ...@@ -46,7 +46,7 @@ class AxisPartitioner : public IRMutator {
Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() && if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() &&
exclude_list.count(op->value.as<StringImm>()->value) == 0)) { exclude_index_fix_list.count(op->value.as<StringImm>()->value) == 0)) {
in_insn_ = true; in_insn_ = true;
counter_ = 0; counter_ = 0;
auto ret = IRMutator::Mutate_(op, s); auto ret = IRMutator::Mutate_(op, s);
...@@ -180,7 +180,7 @@ class RewriteAllocateAndIndex : public IRMutator { ...@@ -180,7 +180,7 @@ class RewriteAllocateAndIndex : public IRMutator {
} }
} }
if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() && if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() &&
(exclude_list.count(op->value.as<StringImm>()->value) == 0 || (exclude_index_fix_list.count(op->value.as<StringImm>()->value) == 0 ||
op->value.as<StringImm>()->value == "scatter"))) { op->value.as<StringImm>()->value == "scatter"))) {
in_insn_ = true; in_insn_ = true;
auto ret = IRMutator::Mutate_(op, s); auto ret = IRMutator::Mutate_(op, s);
......
...@@ -46,7 +46,7 @@ class AxisPartitioner : public IRMutator { ...@@ -46,7 +46,7 @@ class AxisPartitioner : public IRMutator {
Stmt Mutate_(const AttrStmt *op, const Stmt &s) final { Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() && if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() &&
exclude_list.count(op->value.as<StringImm>()->value) == 0)) { exclude_index_fix_list.count(op->value.as<StringImm>()->value) == 0)) {
in_insn_ = true; in_insn_ = true;
counter_ = 0; counter_ = 0;
auto ret = IRMutator::Mutate_(op, s); auto ret = IRMutator::Mutate_(op, s);
...@@ -182,7 +182,7 @@ class RewriteAllocateAndIndex : public IRMutator { ...@@ -182,7 +182,7 @@ class RewriteAllocateAndIndex : public IRMutator {
} }
} }
if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() && if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() &&
(exclude_list.count(op->value.as<StringImm>()->value) == 0 || (exclude_index_fix_list.count(op->value.as<StringImm>()->value) == 0 ||
op->value.as<StringImm>()->value == "scatter"))) { op->value.as<StringImm>()->value == "scatter"))) {
in_insn_ = true; in_insn_ = true;
auto ret = IRMutator::Mutate_(op, s); auto ret = IRMutator::Mutate_(op, s);
...@@ -307,12 +307,7 @@ class RewriteAllocateAndIndex : public IRMutator { ...@@ -307,12 +307,7 @@ class RewriteAllocateAndIndex : public IRMutator {
CHECK_NE(align, 0); CHECK_NE(align, 0);
int64_t coef = GetIntConst(strides[0]); int64_t coef = GetIntConst(strides[0]);
if (std::abs(coef) < align) { if (std::abs(coef) < align) {
auto it = var2ext_.find(v.get()); rst += v * strides[0];
if (it != var2ext_.end() && std::abs(coef * it->second) <= align) {
rst += v * strides[0];
} else {
return SimpleFix(tmp_idx_bk, opt.var2expr, align, times);
}
} else if (coef % align == 0) { } else if (coef % align == 0) {
auto new_coef = coef * times / align; auto new_coef = coef * times / align;
rst += v * Expr(static_cast<int32_t>(new_coef)); rst += v * Expr(static_cast<int32_t>(new_coef));
...@@ -359,7 +354,8 @@ class RewriteAllocateAndIndex : public IRMutator { ...@@ -359,7 +354,8 @@ class RewriteAllocateAndIndex : public IRMutator {
Stmt RewriteByAlignStatic(Stmt stmt) { Stmt RewriteByAlignStatic(Stmt stmt) {
stmt = AxisPartitioner().Run(stmt); stmt = AxisPartitioner().Run(stmt);
stmt = RewriteAllocateAndIndex().Mutate(stmt); stmt = RewriteAllocateAndIndex().Mutate(stmt);
return MergeLoops(stmt); stmt = MergeLoops(stmt);
return stmt;
} }
} // namespace ir } // namespace ir
} // namespace akg } // namespace akg
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include "emit_insn/insn_info.h"
#include "emit_insn/ir_transform.h"
#include "analyze_align.h"
namespace akg {
namespace ir {
/// Packs last-dimension reductions inside emit-insn regions so that the
/// subsequent alignment analysis sees a single flattened store per insn.
class ReducePacker : public IRMutator {
 public:
  ReducePacker() = default;
  ~ReducePacker() override = default;

  /// Rewrites a "pragma_ub_gm" / non-excluded "pragma_emit_insn" region when
  /// its body is recognized as a last-dim reduce; otherwise leaves it as is.
  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    const auto *pragma = op->value.as<StringImm>();
    const bool insn_region = op->attr_key == "pragma_emit_insn" && pragma != nullptr &&
                             !exclude_align_analyze_list.count(pragma->value);
    if (op->attr_key != "pragma_ub_gm" && !insn_region) {
      return IRMutator::Mutate_(op, s);
    }
    IRInfo info;
    ParserVisitor(info, false).Run(s);
    if (!info.ChangeLastDimReduce()) {
      // Not a packable reduce: keep the region untouched (no recursion).
      return s;
    }
    Stmt packed_body = info.GenStmt();
    return AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr(info.arith_info.insn_type), packed_body);
  }
};
/// Pass entry point: canonicalize transpose copies first, then pack
/// last-dimension reductions for alignment analysis.
Stmt PackStore(Stmt stmt) {
  return ReducePacker().Mutate(TransposeTransform().Mutate(stmt));
}
} // namespace ir
} // namespace akg
\ No newline at end of file
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include "emit_insn/insn_info.h"
#include "analyze_align.h"
#include "emit_insn/ir_transform.h"
namespace akg {
namespace ir {
/// Undoes the packing done by PackStore: rewrites packed "reduce_*" pragmas
/// back into explicit read-modify-write stores under the corresponding
/// "vec_binary_*" pragma, restores "dma_copy_transpose" to "vtranspose",
/// and strips "align_info" attributes.
class ReduceRecover : public IRMutator {
 public:
  ReduceRecover() = default;
  ~ReduceRecover() override = default;

  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    if (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() &&
        op->value.as<StringImm>()->value.find("reduce_") != std::string::npos) {
      old_pragma_ = op->value.as<StringImm>()->value;
      // Preserve the previous pragma when the reduce kind is unknown
      // (mirrors the original fall-through behavior of the if/else chain).
      const std::string mapped = BinaryPragmaOf(old_pragma_);
      if (!mapped.empty()) {
        new_pragma_ = mapped;
      }
      in_reduce_ = true;
      auto body = this->Mutate(op->body);
      in_reduce_ = false;
      return AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr(new_pragma_), body);
    } else if (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() &&
               op->value.as<StringImm>()->value == "dma_copy_transpose") {
      return AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("vtranspose"), op->body);
    } else if (op->attr_key == "align_info") {
      // Alignment hints are consumed by earlier passes; keep only the body.
      return this->Mutate(op->body);
    }
    return IRMutator::Mutate_(op, s);
  }

  /// Rebuilds the explicit combine for a packed reduce store:
  ///   dst = combine(dst, rhs)
  /// where rhs is the single argument of the packed Call value.
  Stmt Mutate_(const Store *op, const Stmt &s) final {
    if (!in_reduce_) {
      return IRMutator::Mutate_(op, s);
    }
    if (old_pragma_ != "reduce_add" && old_pragma_ != "reduce_max" && old_pragma_ != "reduce_min" &&
        old_pragma_ != "reduce_fargmax" && old_pragma_ != "reduce_fargmin") {
      // Unknown reduce kind: leave the store untouched.
      return s;
    }
    const Call *packed = op->value.as<Call>();
    // Fail fast instead of dereferencing a null node (previously UB).
    CHECK(packed != nullptr) << "packed reduce store must carry a Call value";
    CHECK(!packed->args.empty()) << "packed reduce call has no source operand";
    Expr rhs = packed->args[0];
    // The destination appears on both sides of the recovered statement.
    Expr dst = Load::make(op->value.type(), op->buffer_var, op->index, op->predicate);
    Expr new_value;
    if (old_pragma_ == "reduce_add") {
      new_value = Add::make(dst, rhs);
    } else if (old_pragma_ == "reduce_max") {
      new_value = Max::make(dst, rhs);
    } else if (old_pragma_ == "reduce_min") {
      new_value = Min::make(dst, rhs);
    } else {
      // fargmax / fargmin are expressed as pure intrinsic calls.
      const std::string intrin = (old_pragma_ == "reduce_fargmax") ? "fargmax" : "fargmin";
      new_value = Call::make(rhs.type(), intrin, {dst, rhs}, Call::CallType::PureIntrinsic);
    }
    return Store::make(op->buffer_var, new_value, op->index, op->predicate);
  }

 private:
  /// Maps a packed reduce pragma to its vector-binary pragma; empty if unknown.
  static std::string BinaryPragmaOf(const std::string &pragma) {
    if (pragma == "reduce_add") return "vec_binary_add";
    if (pragma == "reduce_max") return "vec_binary_max";
    if (pragma == "reduce_min") return "vec_binary_min";
    if (pragma == "reduce_fargmax") return "vec_binary_fargmax";
    if (pragma == "reduce_fargmin") return "vec_binary_fargmin";
    return {};
  }

  std::string old_pragma_;
  std::string new_pragma_;
  bool in_reduce_{false};  // fix: was uninitialized -> indeterminate read in Mutate_(Store)
};
/// Maps an arithmetic op-type name (as recorded in IRInfo) to the scalar
/// intrinsic opcode used when re-emitting the pragma.
/// Returns an empty string for unsupported op types.
std::string GetOpCode(const std::string &op_type) {
  if (op_type == "Add") return "vadds";
  if (op_type == "Mul") return "vmuls";
  if (op_type == "vaxpy") return "vaxpy";
  if (op_type == "DMACopy") return "vector_dup";
  return {};
}
/// Refines generic "pragma_emit_insn" annotations into more specific
/// intrinsic pragmas (scalar_calc / scalar_dma / vector_dup / vadds / vmuls /
/// vaxpy) according to the arithmetic pattern recognized by ParserVisitor.
class FinetunePragma : public IRMutator {
 public:
  FinetunePragma() = default;
  ~FinetunePragma() override = default;

  /// Analyzes each non-excluded emit-insn region; when its computation matches
  /// a known scalar/vector pattern, re-emits it under the matching pragma.
  /// Regions that match no pattern are returned unchanged.
  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    if ((op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() &&
         !exclude_align_analyze_list.count(op->value.as<StringImm>()->value))) {
      IRInfo info;
      ParserVisitor(info, true).Run(s);
      std::string op_code = GetOpCode(info.arith_info.op_type);
      // Only rewrite computations whose destination (and first source, if any)
      // reside in UB and whose op type has a known opcode.
      if (!info.arith_info.dst_info.IsUB() || op_code.empty() ||
          (!info.arith_info.src_info.empty() && !info.arith_info.src_info[0].IsUB())) {
        return s;
      }
      // Non-float adds/muls with a single immediate go to the scalar pipeline.
      if (info.arith_info.insn_type == "simd" && info.arith_info.scalar_imm_num == 1 &&
          (op_code == "vmuls" || op_code == "vadds") && !info.arith_info.dst_info.p_store->value.type().is_float()) {
        return AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("scalar_calc"), op->body);
      }
      // NOTE(review): "vector_dump" — presumably intended to match the
      // producer's spelling; confirm against ParserVisitor ("vector_dup"?).
      if (info.arith_info.insn_type == "vector_scalar" || info.arith_info.insn_type == "vector_dump") {
        // Scalar operand comes from a Load (scalar_type == 0).
        return GenStore(info, op_code, 0);
      } else if (info.arith_info.insn_type == "simd" && info.arith_info.scalar_imm_num > 0) {
        // Scalar operand is an immediate (scalar_type == 1).
        CHECK_EQ(info.arith_info.scalar_imm_num, 1);
        return GenStore(info, op_code, 1);
      } else if (info.arith_info.insn_type == "simd" && info.arith_info.scalar_imm_num == 0 &&
                 info.arith_info.op_type == "DMACopy" && info.arith_info.dst_info.IsUB() &&
                 info.arith_info.src_info.size() == 1 && info.arith_info.src_info[0].IsUB() &&
                 info.arith_info.dst_info.p_store->value.type().is_float()) {
        /// change copy_ub_to_ub (fp16 or fp32) to adds (scalar = 0)
        op_code = "vadds";
        info.arith_info.scalar_imm_num = 1;
        info.arith_info.scalar_imm = FloatImm::make(info.arith_info.dst_info.p_store->value.type(), 0);
        return GenStore(info, op_code, 1);
      } else if (info.arith_info.op_type == "DMACopy" &&
                 (info.arith_info.insn_type == "scalar" || info.arith_info.insn_type == "discrete") &&
                 info.arith_info.dst_info.IsUB() &&
                 (info.arith_info.src_info.size() == 1 && info.arith_info.src_info[0].IsUB())) {
        // Scalar/discrete UB-to-UB copy: emit as scalar DMA.
        return AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("scalar_dma"), op->body);
      } else if (info.arith_info.op_type == "DMACopy" &&
                 (info.arith_info.insn_type == "scalar" || info.arith_info.insn_type == "discrete") &&
                 info.arith_info.dst_info.IsUB() && info.arith_info.scalar_imm_num == 1) {
        return GenStore(info, op_code, 1);
      } else if (op->value.as<StringImm>()->value == "vec_single_muls" ||
                 op->value.as<StringImm>()->value == "vec_single_adds") {
        // Explicit single-scalar pragmas override the inferred opcode.
        if (op->value.as<StringImm>()->value == "vec_single_muls") {
          op_code = "vmuls";
        } else if (op->value.as<StringImm>()->value == "vec_single_adds") {
          op_code = "vadds";
        }
        return GenStore(info, op_code, 1);
      }
      return s;
    }
    return IRMutator::Mutate_(op, s);
  }

  /// Builds the rewritten store `dst = intrin(src, scalar)` and re-wraps it
  /// with the region's loops under the new pragma.
  /// scalar_type == 0: scalar operand is the recognized scalar Load; loops
  /// that drive the scalar load are hoisted OUTSIDE the pragma.
  /// scalar_type == 1: scalar operand is the immediate; all loops stay inside.
  Stmt GenStore(IRInfo &info, const std::string &intrin_name, const int scalar_type = 0) {
    CHECK(intrin_name == "vector_dup" || intrin_name == "vadds" || intrin_name == "vmuls" || intrin_name == "vaxpy");
    /// scalar value
    Expr scalar_value =
      (scalar_type == 0) ? GetRef<Expr>(info.arith_info.scalar_load.p_load) : info.arith_info.scalar_imm;
    Array<Expr> call_args{};
    if (intrin_name == "vector_dup") {
      // vector_dup has no tensor operand: broadcast the scalar.
      call_args = {scalar_value};
    } else {
      Expr tensor_value = GetRef<Expr>(info.arith_info.src_info[0].p_load);
      call_args = {tensor_value, scalar_value};
    }
    /// set store
    auto old_ptr = info.arith_info.dst_info.p_store;
    Expr new_value = Call::make(old_ptr->value.type(), intrin_name, call_args, Call::PureIntrinsic);
    Stmt ret = Store::make(old_ptr->buffer_var, new_value, old_ptr->index, old_ptr->predicate);
    if (scalar_type == 0) {
      auto scalar_vars = info.arith_info.scalar_load.vars;
      /// set inner for loop
      for (int i = static_cast<int>(info.for_info.vars.size()) - 1; i >= 0; --i) {
        if (!IsInArray(scalar_vars, info.for_info.vars[i])) {
          ret = For::make(info.for_info.vars[i], 0, info.for_info.exts[i], ForType::Serial, DeviceAPI::None, ret);
        }
      }
      /// set attribute
      ret = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr(intrin_name), ret);
      /// set outer for loop
      for (int i = static_cast<int>(info.for_info.vars.size()) - 1; i >= 0; --i) {
        if (IsInArray(scalar_vars, info.for_info.vars[i])) {
          ret = For::make(info.for_info.vars[i], 0, info.for_info.exts[i], ForType::Serial, DeviceAPI::None, ret);
        }
      }
      return ret;
    } else {
      // Immediate scalar: every loop stays inside the pragma.
      for (int i = static_cast<int>(info.for_info.vars.size()) - 1; i >= 0; --i) {
        ret = For::make(info.for_info.vars[i], 0, info.for_info.exts[i], ForType::Serial, DeviceAPI::None, ret);
      }
      ret = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr(intrin_name), ret);
      return ret;
    }
  }
};
/// Pass entry point: reorder guarding ifs, specialize emit-insn pragmas, then
/// expand packed reduce stores back to read-modify-write form.
Stmt RecoverStore(Stmt stmt) {
  return ReduceRecover().Mutate(FinetunePragma().Mutate(IfReorder().Mutate(stmt)));
}
} // namespace ir
} // namespace akg
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册