Commit 59f460e7 authored by: C cy

fix alignment and pragma

Parent 0d9d3012
......@@ -115,6 +115,9 @@ REGISTER_PASS(AnalyzeMinAlignStatic);
REGISTER_PASS(AnalyzeMinAlignDynamic);
REGISTER_PASS(RewriteBroadcastVector);
REGISTER_PASS(OptimizePragma);
REGISTER_PASS(PackStore);
REGISTER_PASS(RecoverStore);
REGISTER_PASS(MergeLoops);
REGISTER_PASS(ExpandC0);
REGISTER_PASS(ForEliminate);
REGISTER_PASS(FixLoopExtent);
......
......@@ -738,16 +738,19 @@ NodeRef Lower(Schedule sch, const Array<NodeRef> &in_args, const Array<NodeRef>
if (global_attrs.GetBoolAttr(kDeadCodeElim, false)) {
stmt = NEXT_PASS(DeadCodeElim, stmt);
}
if (!is_dynamic) {
stmt = NEXT_PASS(RewriteBroadcastVector, stmt);
stmt = NEXT_PASS(OptimizePragma, stmt);
}
if (is_dynamic) {
stmt = NEXT_PASS(AnalyzeMinAlignDynamic, stmt, global_attrs.GetIntAttr(kEnableConvAnalyzeAlign, true),
global_attrs.GetIntAttr(kEnableScalarAlign, false));
global_attrs.GetIntAttr(kEnableScalarAlign, false));
} else {
stmt = NEXT_PASS(RewriteBroadcastVector, stmt);
stmt = NEXT_PASS(OptimizePragma, stmt);
stmt = NEXT_PASS(MergeLoops, stmt, false);
stmt = NEXT_PASS(PackStore, stmt);
stmt = NEXT_PASS(AnalyzeMinAlignStatic, stmt);
stmt = NEXT_PASS(RecoverStore, stmt);
}
stmt = NEXT_PASS(MultiLastAxisReductions, stmt, is_dynamic);
stmt = NEXT_PASS(AutoReorder, stmt);
if (enable_multicore != 0) {
......
......@@ -25,6 +25,9 @@
#include "insn_info.h"
#include "cce_params.h"
namespace akg {
enum SingleType {SIMD, Tensor_Scalar, Vector_Dump};
struct MutableMaskParams {
Var mask_var_;
Expr loop_var_;
......@@ -239,8 +242,11 @@ class VectorInsnBuilder : public InsnBuilder {
class SingleVecInsnBuilder : public VectorInsnBuilder {
public:
SingleVecInsnBuilder(const StmtStoreInfo &dst, const StmtStoreInfo &src, const ArgInfo &args,
const std::string &intrin_name, const Buffer &tmp_buf = Buffer())
: VectorInsnBuilder(dst, {src}, args, intrin_name), src_info_(src_info_list_[0]), tmp_buffer_(tmp_buf) {
const std::string &intrin_name, const Expr &scalar_src = Expr(),
const SingleType insn_type = SingleType::SIMD)
: VectorInsnBuilder(dst, {src}, args, intrin_name),
src_info_(src_info_list_[0]),
scalar_src_(scalar_src), insn_type_(insn_type) {
CHECK(src_info_.defined());
}
~SingleVecInsnBuilder() override = default;
......@@ -254,8 +260,10 @@ class SingleVecInsnBuilder : public VectorInsnBuilder {
Stmt CreateBroadcast(const VectorArgInfo &arg_info, const Var &local_var, Stmt stmt);
StmtStoreInfo src_info_;
Buffer tmp_buffer_;
Buffer broadcast_buffer_;
Expr scalar_src_;
SingleType insn_type_; // 0 simd : 1 vector_scalar : 2 vector_dup
};
class MultiVecInsnBuilder : public VectorInsnBuilder {
......
......@@ -92,9 +92,6 @@ Stmt SingleVecInsnBuilder::EmitExpandedIntrin(const VectorArgInfo &arg_info) {
Expr dst_offset = dst_info_->insn_offset_;
Expr src_offset = src_info_->insn_offset_;
Var local_var = Var("broadcast_for_vec_local_UB", Handle());
stmt = CreateBroadcast(arg_info, local_var, stmt);
// Handle stride_m1 loop of single vector intrin, if stride_m1 > 255, it will be separated
if (dst_stride_m1 >= MAX_STRIDE_M1 || src_stride_m1 >= MAX_STRIDE_M1) {
auto var = Var("repeatStrideM1Idx");
......@@ -112,14 +109,6 @@ Stmt SingleVecInsnBuilder::EmitExpandedIntrin(const VectorArgInfo &arg_info) {
}
}
if (!dst_info_->var_.empty() && src_info_->var_.empty() && intrin_name_ != INTRIN_NAME_VECTOR_DUP) {
// need to broadcast src first
stmt = Allocate::make(local_var, src_info_->dtype_, {Expr(src_block_size * FULL_BLOCK_NUM)}, const_true(), stmt);
if (!src_info_->scope_.empty()) {
stmt = AttrStmt::make(local_var, STORAGE_SCOPE, StringImm::make(src_info_->scope_), stmt);
}
}
CHECK(stmt.defined()) << "Error: Stmt is undefined!";
return stmt;
......@@ -131,70 +120,36 @@ Stmt SingleVecInsnBuilder::EmitExpandedIntrin(const VectorArgInfo &arg_info) {
/// \return
Stmt SingleVecInsnBuilder::EmitIntrinBody(const VectorArgInfo &arg_info, const Map<std::string, Expr> &args) {
Stmt body;
CHECK(!arg_info->src_stride_m0_list_.empty());
CHECK(!arg_info->src_stride_m1_list_.empty());
auto dst_buffer_id = GenBufferId(dst_info_);
auto src_buffer_id = GenBufferId(src_info_);
Expr repeat = args["repeat"];
auto dst_buffer_id = GenBufferId(dst_info_);
Expr dst_offset = Sub::make(args["dstOffset"], arg_info->block_offset_);
Expr src_offset = args["srcOffset"];
Expr src_stride_m1 = arg_info->src_stride_m1_list_[0];
auto dst = GetAccessPtr(dst_buffer_id, "w", dst_offset);
auto src = GetAccessPtr(src_buffer_id, "r", src_offset);
if (broadcast_buffer_.defined()) {
src_stride_m1 = 0;
src = GetAccessPtr(broadcast_buffer_, "r", Expr(0));
Array<Expr> insn_args {};
if (insn_type_ == SingleType::Vector_Dump) {
insn_args = {dst, scalar_src_, repeat};
} else {
auto src_buffer_id = GenBufferId(src_info_);
Expr src_offset = args["srcOffset"];
auto src = GetAccessPtr(src_buffer_id, "r", src_offset);
if (insn_type_ == SingleType::SIMD) {
insn_args = {dst, src, repeat};
} else if (insn_type_ == SingleType::Tensor_Scalar) {
insn_args = {dst, src, scalar_src_, repeat};
} else {
CHECK(0) << "\nUnknown insn_type_\n";
}
}
Array<Expr> stride_args = {arg_info->dst_stride_m0_, arg_info->src_stride_m0_list_[0], arg_info->dst_stride_m1_,
src_stride_m1};
Array<Expr> insn_args = {dst, src, repeat};
if (arg_info->scalar_.defined()) {
auto scalar = arg_info->scalar_;
if (tmp_buffer_.defined()) {
dst = GetAccessPtr(tmp_buffer_, "w", dst_offset);
}
insn_args = {dst, scalar, repeat};
if (intrin_name_ != INTRIN_NAME_VECTOR_DUP) {
Insert(insn_args, 1, src);
}
}
arg_info->src_stride_m1_list_[0]};
insn_args = MergeTwo(insn_args, stride_args);
body = EmitCceIntrinTemplate(Stmt(), dst.type(), insn_args, intrin_name_);
return body;
}
/// Create broadcast intrin if src is scalar
/// \param arg_info
/// \param local_var
/// \param stmt
/// \return
/// Create a broadcast (vector_dup) prologue when the source operand is a
/// scalar: the scalar is duplicated into a small UB buffer so the main vector
/// intrin can read it with vector addressing. No-op when dst is scalar, src is
/// already a vector, or the intrin itself is vector_dup.
/// \param arg_info - Vector arg info supplying the mask for the main intrin
/// \param local_var - Handle variable backing the temporary broadcast buffer
/// \param stmt - Statement the broadcast sequence is attached to
/// \return The stmt with the broadcast intrins inserted (unchanged if no
///         broadcast is required)
Stmt SingleVecInsnBuilder::CreateBroadcast(const VectorArgInfo &arg_info, const Var &local_var, Stmt stmt) {
if (!dst_info_->var_.empty() && src_info_->var_.empty() && intrin_name_ != INTRIN_NAME_VECTOR_DUP) {
// need to broadcast src first
auto src_block_size = GetUbBlkSize(src_info_->dtype_);
// Buffer holds FULL_BLOCK_NUM blocks of src dtype — one full repeat's worth
// of the duplicated scalar. Stored in broadcast_buffer_ for EmitIntrinBody.
broadcast_buffer_ = BufferNode::make(local_var, src_info_->dtype_, {Expr(src_block_size * FULL_BLOCK_NUM)}, {},
src_info_->elem_offset_, "broadcast_for_vec_local_UB", src_info_->scope_,
src_info_->data_alignment_, 1, BufferType::kDefault);
auto broad_dst = GetAccessPtr(broadcast_buffer_, "w", 0);
// vector_dup args: dst access ptr, scalar value (src element 0), repeat = 1,
// then stride fields.
Array<Expr> args = {
broad_dst, GenBufferId(src_info_).vload({Expr(0)}, src_info_->dtype_), Expr(1), Expr(1), Expr(1), Expr(0),
Expr(0)};
// Duplicate under a full mask, then switch to the mask the main intrin uses.
stmt = EmitSetVecMaskIntrin(stmt, src_info_->dtype_, GetAllMask(src_info_->dtype_));
stmt = InsertBody(stmt, EmitCceIntrinTemplate(Stmt(), src_info_->dtype_, args, INTRIN_NAME_VECTOR_DUP));
stmt = EmitSetVecMaskIntrin(stmt, dst_info_->dtype_, arg_info->vec_mask_);
}
return stmt;
}
/// if repeat-size > cce_max_repeat, then split it into loop as "Davinci ISA User Guide t6.3 (8.2.2)" mentioned
/// max_cce_repeat = 255, considering params are about 2 cycles, set it to be 255 // 2 = 127
......@@ -1250,8 +1205,10 @@ Stmt EmitCceBinaryVectorToReduceLastAxis(const StmtStoreInfo &dst_info, const St
auto vec_dup_arg_info = GenReduceHelperArgInfo(vec_dup_dst_info, for_extent, scalar, "VecDup");
vec_dup_dst_info.GetNode()->data_ = final_var;
vec_dup_dst_info.GetNode()->name_ = final_var->name_hint;
SingleVecInsnBuilder single_vec_builder = SingleVecInsnBuilder(vec_dup_dst_info, vec_dup_dst_info, vec_dup_arg_info,
INTRIN_NAME_VECTOR_DUP, final_dst_buffer);
INTRIN_NAME_VECTOR_DUP, scalar, SingleType::Vector_Dump);
auto insn_list = single_vec_builder.EmitIntrin();
auto stmt = std::accumulate(insn_list.begin(), insn_list.end(), Stmt(),
[](const Stmt &s0, const Stmt &s1) { return InsertBody(s0, s1); });
......
......@@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "emit_insn/insn_emitter.h"
#include "insn_emitter.h"
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
......@@ -53,145 +53,68 @@ std::vector<size_t> SortIndexes(const std::vector<int> &v) {
/// \param intrin_name - The CCE intrin name
/// \param broadcast_last_axis - Tag of broadcast_last_axis mode
/// \return Stmt of emitted CCE intrin
Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name, bool broadcast_last_axis = false) {
Stmt SingleVecEmitter(const Stmt &op, std::string intrin_name) {
CHECK(op);
Stmt result;
// optimization of copy_ubuf_to_ubuf
bool is_dma_opt = false;
if (intrin_name == INTRIN_NAME_COPY_UB_TO_UB) {
CommentManager::GetInstance().AddComment("Insn_type", "dma_copy");
CommentManager::GetInstance().AddComment("Insn_name", INTRIN_NAME_COPY_UB_TO_UB);
CommentManager::GetInstance().AddComment("Vadds_replace_copy", "enable");
intrin_name = "vadds";
is_dma_opt = true;
} else {
CommentManager::GetInstance().AddComment("Insn_type", "single_vector");
CommentManager::GetInstance().AddComment("Insn_name", intrin_name);
}
CommentManager::GetInstance().AddComment("Insn_type", "single_vector");
CommentManager::GetInstance().AddComment("Insn_name", intrin_name);
StmtInfoList dst_info_list;
StmtInfoList src_info_list;
StmtStoreInfo scalar_info;
StmtInfo for_info;
StmtInfo if_info;
std::string mode = GetSingleVecComputationInfo(op, intrin_name, dst_info_list, src_info_list, if_info, for_info);
bool same_dtype = intrin_name.find("vconv_") == std::string::npos;
GetCompactComputationInfo(op, dst_info_list, src_info_list, if_info, for_info, same_dtype, true);
CHECK(!dst_info_list.empty());
if (broadcast_last_axis) {
mode = "broadcast_last_axis";
// In this case, must come from binary vec, so must have two src
CHECK(src_info_list.size() >= 2) << "Broadcast last axis mode must have at least two srcs.";
if (!IsTwoItemEqual(src_info_list[0]->var_, dst_info_list[0]->var_, -1)) {
scalar_info = src_info_list[0];
src_info_list.Set(0, src_info_list[1]);
} else if (!IsTwoItemEqual(src_info_list[1]->var_, dst_info_list[0]->var_, -1)) {
scalar_info = src_info_list[1];
}
} else {
if (mode == "broadcast" && !src_info_list.empty() && dst_info_list.size() == 1) {
if (!IsTwoItemEqual(src_info_list[0]->var_, dst_info_list[0]->var_, -1)) {
mode = "broadcast_last_axis";
Array<Expr> call_args;
int call_cnt = 0;
if (intrin_name == "vector_dup" || intrin_name == "vadds" ||
intrin_name == "vmuls" || intrin_name == "vaxpy") {
auto GetCallInfo = [&intrin_name, &call_args, &call_cnt](const NodeRef &op) {
if (op.as<Call>() && op.as<Call>()->name == intrin_name) {
call_args = op.as<Call>()->args;
call_cnt = call_cnt + 1;
}
if (src_info_list.size() > 1) {
if (!IsTwoItemEqual(src_info_list[1]->var_, dst_info_list[0]->var_, -1)) {
mode = "broadcast_last_axis";
} else {
scalar_info = src_info_list[0];
src_info_list.Set(0, src_info_list[1]);
}
}
}
}
if (broadcast_last_axis) {
mode = "broadcast_last_axis";
};
PostOrderVisit(op, GetCallInfo);
CHECK_EQ(call_cnt, 1);
}
if (intrin_name == INTRIN_NAME_VECTOR_DUP) {
auto dst_info = dst_info_list[0];
if (dst_info->var_.size() > 1 &&
GetIntConst(GetItem(dst_info->strides_, -1)) == GetIntConst(GetItem(dst_info->shape_, -1)) + 1) {
// diagnoal broadcast case
return op;
}
dst_info.CleanFlexVar();
SingleType insn_type {SingleType::SIMD};
Expr scalar_src {};
if (intrin_name == "vector_dup") {
insn_type = SingleType::Vector_Dump;
src_info_list = {};
scalar_src = call_args[0];
} else if (intrin_name == "vadds" || intrin_name == "vmuls" || intrin_name == "vaxpy") {
insn_type = SingleType::Tensor_Scalar;
src_info_list = {src_info_list[0]};
scalar_src = call_args[1];
}
// check is single vector broadcast reduce mode exist
SingleVecPatternGenerator generator = SingleVecPatternGenerator(dst_info_list, src_info_list, for_info, mode);
SingleVecPatternGenerator generator = SingleVecPatternGenerator(dst_info_list, src_info_list, for_info);
auto params = generator.GetInsnArgs();
dst_info_list = params.dst_info_list;
src_info_list = params.src_info_list;
for_info = params.for_info;
ArgInfo arg_info = params.arg_info;
CommentManager::GetInstance().AddComment("Compute_type", mode);
CommentManager::GetInstance().AddComment("Compute_type", intrin_name);
CommentManager::GetInstance().AddComment("Pattern", arg_info.GetPattern());
if (intrin_name == "vadds" || intrin_name == "vmuls" || intrin_name == INTRIN_NAME_VECTOR_DUP) {
auto stores = GetStores(op);
auto store = stores[0].as<Store>();
auto scalar = Expr(0);
if (intrin_name == "vadds" || intrin_name == "vmuls") {
if (!dst_info_list.empty()) {
scalar = FloatImm::make(dst_info_list[0]->dtype_, 0.000000);
}
if (!dst_info_list[0]->dtype_.is_float()) {
return op;
}
if (!is_dma_opt) {
if (!scalar_info.defined()) {
auto children = GetBinaryOpExprChildren(store->value);
if (children.empty()) {
LOG(FATAL) << store->value << " is not binary op.";
}
scalar = children[1];
} else {
scalar = Load::make(scalar_info->dtype_, scalar_info->data_, scalar_info->index_, Expr(1));
}
}
} else if (intrin_name == INTRIN_NAME_VECTOR_DUP) {
if (store->value->IsInstance<Load>()) {
// scale is load
scalar =
Load::make(src_info_list[0]->dtype_, store->value.as<Load>()->buffer_var, src_info_list[0]->index_, Expr(1));
} else {
// scale is imm
scalar = store->value;
}
}
if (arg_info->body_arg_info_.defined()) {
arg_info->body_arg_info_.GetNode()->scalar_ = scalar;
}
if (arg_info->tail_arg_info_.defined()) {
arg_info->tail_arg_info_.GetNode()->scalar_ = scalar;
}
}
if (intrin_name == "vconv_deq") {
result = InsertBody(
result, Evaluate::make(Call::make(Float(16), "set_deqscale", {FloatImm::make(Float(16), 1.0)}, Call::Extern)));
}
SingleVecInsnBuilder single_vec_builder =
SingleVecInsnBuilder(dst_info_list[0], src_info_list[0], arg_info, intrin_name);
SingleVecInsnBuilder(dst_info_list[0], src_info_list[0], arg_info, intrin_name, scalar_src, insn_type);
auto insn_list = single_vec_builder.EmitIntrin();
if (intrin_name == INTRIN_NAME_VECTOR_DUP && dst_info_list[0]->var_.empty()) {
Stmt store;
auto ScanStore = [&store](const NodeRef &op) {
const auto e = op.as<Store>();
if (e != nullptr) {
store = Store::make(e->buffer_var, e->value, e->index, e->predicate);
}
};
air::ir::PostOrderVisit(op, ScanStore);
store = EmitSetVecMaskIntrin(store, dst_info_list[0]->dtype_);
insn_list = {store};
}
return FoldInsnWithForInfo(insn_list, if_info, for_info, result);
auto ret = FoldInsnWithForInfo(insn_list, if_info, for_info, result);
return ret;
}
/// Function to emit binary vector intrin
......@@ -211,11 +134,6 @@ Stmt BinaryVecEmitter(const Stmt &op, std::string intrin_name, bool enable_bisec
CommentManager::GetInstance().AddComment("Insn_name", intrin_name);
switch (arg_info->arg_type_) {
case ARG_VECTOR_BROADCAST_LAST_AXIS: {
CommentManager::GetInstance().CleanComments();
intrin_name += "s";
return SingleVecEmitter(op, intrin_name, true);
}
case ARG_VECTOR_REDUCTION_LAST_AXIS: {
CommentManager::GetInstance().AddComment("Compute_type", "reduce_last_axis");
auto dst_info = dst_info_list[0];
......@@ -928,83 +846,8 @@ Stmt DmaMovEmitter(const Stmt &op, bool enable_cover_protect) {
StmtInfo for_info;
GetDmaComputationInfo(op, dst_info_list, src_info_list, if_info, for_info, dma_mode, intrin_name);
auto check_alignment = [](const Expr &align, const Array<Expr> &shape) {
if (GetIntConst(align) == 1 || shape.size() == 1u) {
return true;
}
if (shape.empty()) {
return false;
}
Expr sz = 1;
for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
sz = sz * shape[i];
if (GetIntConst(align) == GetIntConst(sz)) {
return true;
}
}
return false;
};
const auto &dst_info = dst_info_list[0];
const auto &src_info = src_info_list[0];
int block_size = GetUbBlkSize(dst_info->dtype_);
// check scalar to scalar
// check if dst is considered as scalar
// check if src is considered as scalar
bool is_broadcast =
(dst_info->var_.empty() || (!dst_info->strides_.empty() && GetIntConst(GetItem(dst_info->strides_, -1)) != 1)) &&
(src_info->var_.empty() || (!src_info->strides_.empty() && GetIntConst(GetItem(src_info->strides_, -1)) != 1));
// check vector to vector, but in scalar dma mode
bool last_dim_equal = !dst_info->var_.empty() && !src_info->var_.empty() && !dst_info->strides_.empty() &&
!src_info->strides_.empty() &&
GetItem(dst_info->var_, -1).get() == GetItem(src_info->var_, -1).get() &&
GetIntConst(GetItem(dst_info->strides_, -1)) != GetIntConst(GetItem(src_info->strides_, -1));
bool broadcast_scalar = intrin_name == "broadcast" && is_broadcast;
bool ubuf_scalar = intrin_name == INTRIN_NAME_COPY_UB_TO_UB && (is_broadcast || last_dim_equal);
if (broadcast_scalar || ubuf_scalar) {
int shape1 = GetInt32Const(GetItem(dst_info->shape_, -1));
int stride1 = GetInt32Const(GetItem(dst_info->strides_, -1));
if (ubuf_scalar && shape1 < block_size && stride1 == block_size &&
IsTwoItemEqual(dst_info->strides_, src_info->strides_, -1, true) && src_info->dtype_.bits() != 64) {
// if last dim small than blocksize, then use vadds
return SingleVecEmitter(op, intrin_name);
}
CommentManager::GetInstance().AddComment("Insn_type", "dma_copy");
CommentManager::GetInstance().AddComment("Insn_name", "scalar");
if (src_info->var_.empty() && dst_info->var_.empty()) {
return op;
} else {
// check align
if (!check_alignment(dst_info->data_alignment_, dst_info->shape_)) {
return op;
}
Stmt base_stmt = EmitScalarDmaIntrinTemplate(op, src_info, dst_info);
return GenIfAndFor(base_stmt, if_info, for_info, false);
}
}
if (intrin_name == "broadcast") {
return SingleVecEmitter(op, INTRIN_NAME_VECTOR_DUP);
} else if (intrin_name == INTRIN_NAME_COPY_UB_TO_UB) {
// Use vadds to optimize dma copy
if (if_info.vars_.empty() && dst_info->dtype_.is_float() && src_info->dtype_.is_float()) {
if ((dst_info->dtype_.bits() == 32 && src_info->dtype_.bits() == 32) ||
(dst_info->dtype_.bits() == 16 && src_info->dtype_.bits() == 16)) {
int repeat_len = block_size * FULL_BLOCK_NUM;
CHECK_NE(block_size, 0);
int shape1 = GetInt32Const(GetItem(dst_info->shape_, -1));
if ((shape1 >= repeat_len / 2 && shape1 <= repeat_len) ||
(dst_info->shape_.size() >= 3 && shape1 <= block_size) ||
(dst_info->shape_.size() >= 2 && shape1 % block_size == 0)) {
// if last dim shape is too small, there is no need to opt
return SingleVecEmitter(op, intrin_name);
}
}
}
}
CommentManager::GetInstance().AddComment("Insn_type", "dma_copy");
......@@ -1014,31 +857,10 @@ Stmt DmaMovEmitter(const Stmt &op, bool enable_cover_protect) {
Map<std::string, Expr> ub_copy_post;
auto arg_info_map =
GetDmaCopyInsnArgs(intrin_name, dst_info_list, src_info_list, for_info, ub_copy_pre, ub_copy_post);
if (intrin_name == "vtranspose_scalar") {
base_stmt = EmitScalarDmaIntrinTemplate(op, src_info, dst_info);
CommentManager::GetInstance().AddComment("Insn_name", "scalar");
} else if (intrin_name == "vtranspose") {
Array<Expr> args = {arg_info_map["loop_width"], arg_info_map["loop_height"], arg_info_map["shape_width"]};
Array<Expr> pre_ub_copy_args;
if (!ub_copy_pre.empty()) {
pre_ub_copy_args = Array<Expr>(
{ub_copy_pre["nBurst"], ub_copy_pre["lenBurst"], ub_copy_pre["srcStride"], ub_copy_pre["dstStride"]});
}
Array<Expr> post_ub_copy_args;
if (!ub_copy_post.empty()) {
post_ub_copy_args = Array<Expr>(
{ub_copy_post["nBurst"], ub_copy_post["lenBurst"], ub_copy_post["srcStride"], ub_copy_post["dstStride"]});
}
TransposeInsnBuilder builder =
TransposeInsnBuilder(dst_info, src_info, args, pre_ub_copy_args, post_ub_copy_args);
base_stmt = builder.EmitSingleIntrin();
CommentManager::GetInstance().AddComment("Insn_name", intrin_name);
} else {
DmaInsnBuilder dma_builder =
DmaInsnBuilder(dst_info, src_info, intrin_name, arg_info_map, false, false, enable_cover_protect);
base_stmt = dma_builder.EmitSingleIntrin();
CommentManager::GetInstance().AddComment("Insn_name", intrin_name);
}
DmaInsnBuilder dma_builder =
DmaInsnBuilder(dst_info, src_info, intrin_name, arg_info_map, false, false, enable_cover_protect);
base_stmt = dma_builder.EmitSingleIntrin();
CommentManager::GetInstance().AddComment("Insn_name", intrin_name);
} else if (dma_mode == "cce_load") {
auto arg_info_map = GetDmaLoad2DInsnArgs(intrin_name, dst_info_list, src_info_list, for_info);
DmaInsnBuilder builder = DmaInsnBuilder(dst_info, src_info, intrin_name, arg_info_map, true);
......@@ -1104,6 +926,19 @@ Stmt DmaAtomicAddEmitter(const Stmt &op) {
return stmt;
}
/// Function to emit the vtranspose intrin
/// \param op - The input stmt to be emitted as intrin
/// \return Stmt of the emitted extern vtranspose call
Stmt VTransposeEmitter(const Stmt &op) {
StmtInfoList dst_info_list;
StmtInfoList src_info_list;
StmtInfo for_info;
StmtInfo if_info;
GetCompactComputationInfo(op, dst_info_list, src_info_list, if_info, for_info, true, true);
// Guard before indexing, matching the CHECK style of the sibling emitters.
CHECK(!dst_info_list.empty()) << "vtranspose requires a dst operand";
CHECK(!src_info_list.empty()) << "vtranspose requires a src operand";
auto dst_buffer_id = GenBufferId(dst_info_list[0]);
auto src_buffer_id = GenBufferId(src_info_list[0]);
auto dst = GetAccessPtr(dst_buffer_id, "w", 0);
auto src = GetAccessPtr(src_buffer_id, "r", 0);
// Extern call is typed Float(16); args are the dst/src access pointers.
return Evaluate::make(Call::make(Float(16), "vtranspose", {dst, src}, Call::Extern));
}
/// Function to emit dropout intrin
/// \param op - The input stmt to be emitted as intrin
/// \return Stmt of emitted CCE intrin
......@@ -1913,97 +1748,6 @@ Stmt ReduceCombineEmitter(const Stmt &op, bool enable_bisect) {
Stmt InsnEmit(std::string insn_name, const Stmt &op, bool enable_bisect, bool enable_cover_protect, int comment_level) {
CHECK(op.defined());
static const std::map<std::string, std::string> ReplaceAttrPragmaMap = {
// vector binary
{"binary_vcadd", "vec_binary_add"},
{"vaxpy", "vec_binary_axpy"},
// vector single
{"vec_single_fabs", "vec_single_abs"},
{"broadcast", "vec_broadcast"},
// cube
{"mad", "cube_mad"},
{"ub2gm", "cube_ub2gm"},
{"im2col", "cube_img2col"},
// special attrs
{"vec_binary_proposal_sort", "vec_proposal_sort"},
{"vec_binary_topk_sort", "vec_topk_sort"},
{"vec_binary_dropout", "vec_dropout"},
{"vec_binary_fargmax", "vec_argmax"},
{"vec_binary_fargmin", "vec_argmin"},
{"vec_binary_iou", "vec_iou"},
{"vec_binary_nms", "vec_nms"},
{"mask_broadcast", "vec_broadcast"},
};
static const std::map<std::string, std::string> BinaryVecInsnMap = {
// vadd.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vadd.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vadd.f32 support target:mini_v100 cloud_v100
// vadd contains two situations:
// 1. normal elewise vector add
// - all src[i].shape = dst.shape
// 2. reductive vector add
// - exist src[i].shape != dst.shape
{"vec_binary_add", "vadd"},
// vsub.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vsub.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vsub.f32 support target:mini_v100 cloud_v100
{"vec_binary_sub", "vsub"},
// vmul.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmul.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmul.f32 support target:mini_v100 cloud_v100
{"vec_binary_mul", "vmul"},
// vmin.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmin.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmin.f32 support target:mini_v100 cloud_v100
{"vec_binary_min", "vmin"},
// vmax.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmax.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmax.f32 support target:mini_v100 cloud_v100
{"vec_binary_max", "vmax"},
{"vec_binary_div", "vdiv"},
{"vec_binary_and", "vand"},
{"vec_binary_bitwise_and", "vand"},
{"vec_binary_or", "vor"},
{"vec_binary_bitwise_or", "vor"},
{"vec_binary_vmadd", "vmadd"},
{"vec_binary_vmaddrelu", "vmaddrelu"},
{"vec_binary_vmla", "vmla"}};
static const std::map<std::string, std::string> SingleVecInsnMap = {
// vmuls.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmuls.f32 supporttarget:mini_v100 cloud_v100
{"vec_single_muls", "vmuls"},
// vadds.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vadds.f32 support target:mini_v100 cloud_v100
{"vec_single_adds", "vadds"},
// vrelu.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
{"vec_single_relu", "vrelu"},
// vabs.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vabs.f32 support target:mini_v100 cloud_v100
{"vec_single_abs", "vabs"},
// vln.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vln.f32 support target:cloud_v100
{"vec_single_log", "vln"},
// vexp.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vexp.f32 support target:cloud_v100
{"vec_single_exp", "vexp"},
// vrec.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vrec.f32 support target:mini_v100 cloud_v100
{"vec_single_rec", "vrec"},
// vnot support target:mini_v100 tiny_v100 lite_v100 cloud_v100
{"vec_single_not", "vnot"},
{"vec_single_bitwise_not", "vnot"},
// vsqrt support target:cloud_v100
{"vec_single_sqrt", "vsqrt"},
{"vec_single_rsqrt", "vrsqrt"},
{"vec_broadcast", "vector_dup"}};
static const std::map<std::string, std::string> SingleCastInsnMap = {
{"vec_single_floor", "f"}, {"vec_single_round", "r"}, {"vec_single_ceil", "c"}, {"vec_single_trunc", "z"}};
static const std::set<std::string> ReturnOpInsnSet = {"scalar_dma", "scatter", "vec_binary_select_loop_var"};
static const std::map<std::string, std::function<Stmt(const Stmt &)>> InsnFunctorMap = {
{"dma_atomic_add", DmaAtomicAddEmitter},
{"vec_single_cast", SingleCastEmitter},
......@@ -2017,9 +1761,9 @@ Stmt InsnEmit(std::string insn_name, const Stmt &op, bool enable_bisect, bool en
{"vec_dropout", BinaryDropoutEmitter},
{"cube_mad", MadEmitter},
{"vec_select_scalar", SelectWithScalarEmitter},
{"vec_binary_axpy", VaxpyEmitter},
{"opt_broadcast", MultiMaskEmitter},
{"vec_single_four2five_nchw", VnchwconvEmitter}};
{"vec_single_four2five_nchw", VnchwconvEmitter},
{"vtranspose", VTransposeEmitter}};
if (ReplaceAttrPragmaMap.count(insn_name) != 0) {
insn_name = ReplaceAttrPragmaMap.find(insn_name)->second;
......
......@@ -30,6 +30,100 @@
namespace akg {
namespace ir {
static const std::map<std::string, std::string> ReplaceAttrPragmaMap = {
// vector binary
{"binary_vcadd", "vec_binary_add"},
// vector single
{"vec_single_fabs", "vec_single_abs"},
{"broadcast", "vec_broadcast"},
// cube
{"mad", "cube_mad"},
{"ub2gm", "cube_ub2gm"},
{"im2col", "cube_img2col"},
// special attrs
{"vec_binary_proposal_sort", "vec_proposal_sort"},
{"vec_binary_topk_sort", "vec_topk_sort"},
{"vec_binary_dropout", "vec_dropout"},
{"vec_binary_fargmax", "vec_argmax"},
{"vec_binary_fargmin", "vec_argmin"},
{"vec_binary_iou", "vec_iou"},
{"vec_binary_nms", "vec_nms"},
{"mask_broadcast", "vec_broadcast"},
};
static const std::map<std::string, std::string> BinaryVecInsnMap = {
// vadd.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vadd.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vadd.f32 support target:mini_v100 cloud_v100
// vadd contains two situations:
// 1. normal elewise vector add
// - all src[i].shape = dst.shape
// 2. reductive vector add
// - exist src[i].shape != dst.shape
{"vec_binary_add", "vadd"},
// vsub.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vsub.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vsub.f32 support target:mini_v100 cloud_v100
{"vec_binary_sub", "vsub"},
// vmul.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmul.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmul.f32 support target:mini_v100 cloud_v100
{"vec_binary_mul", "vmul"},
// vmin.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmin.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmin.f32 support target:mini_v100 cloud_v100
{"vec_binary_min", "vmin"},
// vmax.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmax.s32 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmax.f32 support target:mini_v100 cloud_v100
{"vec_binary_max", "vmax"},
{"vec_binary_div", "vdiv"},
{"vec_binary_and", "vand"},
{"vec_binary_bitwise_and", "vand"},
{"vec_binary_or", "vor"},
{"vec_binary_bitwise_or", "vor"},
{"vec_binary_vmadd", "vmadd"},
{"vec_binary_vmaddrelu", "vmaddrelu"},
{"vec_binary_vmla", "vmla"}};
static const std::map<std::string, std::string> SingleVecInsnMap = {
// vmuls.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vmuls.f32 supporttarget:mini_v100 cloud_v100
{"vec_single_muls", "vmuls"},
// vadds.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vadds.f32 support target:mini_v100 cloud_v100
{"vec_single_adds", "vadds"},
// vrelu.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
{"vec_single_relu", "vrelu"},
// vabs.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vabs.f32 support target:mini_v100 cloud_v100
{"vec_single_abs", "vabs"},
// vln.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vln.f32 support target:cloud_v100
{"vec_single_log", "vln"},
// vexp.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vexp.f32 support target:cloud_v100
{"vec_single_exp", "vexp"},
// vrec.f16 support target:mini_v100 tiny_v100 lite_v100 cloud_v100
// vrec.f32 support target:mini_v100 cloud_v100
{"vec_single_rec", "vrec"},
// vnot support target:mini_v100 tiny_v100 lite_v100 cloud_v100
{"vec_single_not", "vnot"},
{"vec_single_bitwise_not", "vnot"},
// vsqrt support target:cloud_v100
{"vec_single_sqrt", "vsqrt"},
{"vec_single_rsqrt", "vrsqrt"},
{"vaxpy", "vaxpy"},
{"vec_broadcast", "vector_dup"},
{"vadds", "vadds"},
{"vmuls", "vmuls"},
{"vector_dup", "vector_dup"},
};
// Map single-operand cast pragmas to a one-letter rounding-mode suffix
// (per the key names: f = floor, r = round, c = ceil, z = trunc).
static const std::map<std::string, std::string> SingleCastInsnMap = {
{"vec_single_floor", "f"}, {"vec_single_round", "r"}, {"vec_single_ceil", "c"}, {"vec_single_trunc", "z"}};
// Pragmas whose statements are returned as-is, without emitting an intrin.
static const std::set<std::string> ReturnOpInsnSet = {"scalar_calc", "scalar_dma", "scatter", "vec_binary_select_loop_var"};
// Emit intrinsics for statements containing dynamic shapes (defined elsewhere).
Stmt EmitInsnWithDynamicShapes(const Stmt &s, const Map<Tensor, Buffer> &extern_buffer);
......
......@@ -935,7 +935,7 @@ void GetCompactComputationInfo(const Stmt &stmt, StmtInfoList &dst_info_list, St
/// \param if_info - The if-condition as input
/// \param for_info - The for-loop info to be modified
void CompactComputationInfoList(StmtInfoList &dst_info_list, StmtInfoList &src_info_list, const StmtInfo &if_info,
StmtInfo &for_info) {
StmtInfo &for_info) {
auto MergeTwoVar = [](const Var &keep_var, const Var &delete_var, StmtInfoList &dst_info_list,
StmtInfoList &src_info_list, StmtInfo &for_info) {
for (auto info : dst_info_list) {
......@@ -1059,8 +1059,7 @@ void CompactComputationInfoList(StmtInfoList &dst_info_list, StmtInfoList &src_i
bool find_merge = false;
for (size_t i = 0; (i < var_cnt - 1) && (!find_merge); i++) {
for (size_t j = i + 1; j < var_cnt; j++) {
if (CanMergeTwoVar(for_info.vars_[i], for_info.vars_[j], dst_info_list, src_info_list,
for_info)) {
if (CanMergeTwoVar(for_info.vars_[i], for_info.vars_[j], dst_info_list, src_info_list, for_info)) {
find_merge = true;
break;
}
......@@ -1075,7 +1074,6 @@ void CompactComputationInfoList(StmtInfoList &dst_info_list, StmtInfoList &src_i
}
}
/// A helper function for single dst_info's compact
/// \param dst_info
/// \param src_info_list
......@@ -1357,6 +1355,43 @@ int GetVectorizedVarPosition(const Expr &index, Array<Var> &loop_vars) {
return pos;
}
std::string GetOpType(const Expr &value) {
if (value.as<Add>()) {
return value.as<Add>()->_type_key;
}
if (value.as<Sub>()) {
return value.as<Sub>()->_type_key;
}
if (value.as<Mul>()) {
return value.as<Mul>()->_type_key;
}
if (value.as<Div>()) {
return value.as<Div>()->_type_key;
}
if (value.as<Mod>()) {
return value.as<Mod>()->_type_key;
}
if (value.as<FloorDiv>()) {
return value.as<FloorDiv>()->_type_key;
}
if (value.as<FloorMod>()) {
return value.as<FloorMod>()->_type_key;
}
if (value.as<Min>()) {
return value.as<Min>()->_type_key;
}
if (value.as<Max>()) {
return value.as<Max>()->_type_key;
}
if (value.as<Call>()) {
return value.as<Call>()->name;
}
if (value.as<Load>() || value.as<IntImm>() || value.as<FloatImm>()) {
return "DMACopy";
}
return "undefined";
}
/// TVM Function Register, enable python code to call these cpp function.
TVM_REGISTER_API("cce_util.GetCceAxis").set_body([](TVMArgs args, TVMRetValue *ret) { *ret = GetCceAxis(); });
......
......@@ -49,13 +49,7 @@ enum ArgType {
ARG_NOT_DEFINE
};
enum PatternType {
PATTERN_3D = 1,
PATTERN_PARTIAL_3D,
PATTERN_2D,
PATTERN_2D_BLOCK,
PATTERN_1D
};
enum PatternType { PATTERN_3D = 1, PATTERN_PARTIAL_3D, PATTERN_2D, PATTERN_2D_BLOCK, PATTERN_1D };
class StmtStoreInfoNode : public Node {
public:
......@@ -98,13 +92,9 @@ class StmtStoreInfo : public NodeRef {
explicit StmtStoreInfo(const ObjectPtr<Object> &n) : NodeRef(n), node_(n) {}
~StmtStoreInfo() = default;
inline StmtStoreInfoNode *GetNode() const {
return static_cast<StmtStoreInfoNode *>(node_.get());
}
inline StmtStoreInfoNode *GetNode() const { return static_cast<StmtStoreInfoNode *>(node_.get()); }
inline const StmtStoreInfoNode *operator->() const {
return static_cast<const StmtStoreInfoNode *>(node_.get());
}
inline const StmtStoreInfoNode *operator->() const { return static_cast<const StmtStoreInfoNode *>(node_.get()); }
void CleanFlexVar();
......@@ -188,13 +178,9 @@ class VectorArgInfo : public NodeRef {
explicit VectorArgInfo(const ObjectPtr<Object> &n) : NodeRef(n), node_(n) {}
~VectorArgInfo() = default;
inline VectorArgInfoNode *GetNode() const {
return static_cast<VectorArgInfoNode *>(node_.get());
}
inline VectorArgInfoNode *GetNode() const { return static_cast<VectorArgInfoNode *>(node_.get()); }
inline const VectorArgInfoNode *operator->() const {
return static_cast<const VectorArgInfoNode *>(node_.get());
}
inline const VectorArgInfoNode *operator->() const { return static_cast<const VectorArgInfoNode *>(node_.get()); }
void Print() const {
LOG(DEBUG) << "[ body_num: " << GetNode()->body_num_ << ", body_offset: " << GetNode()->body_offset_
......@@ -235,13 +221,9 @@ class ArgInfo : public NodeRef {
explicit ArgInfo(const ObjectPtr<Object> &n) : NodeRef(n), node_(n) {}
~ArgInfo() = default;
inline ArgInfoNode *GetNode() const {
return static_cast<ArgInfoNode *>(node_.get());
}
inline ArgInfoNode *GetNode() const { return static_cast<ArgInfoNode *>(node_.get()); }
inline const ArgInfoNode *operator->() const {
return static_cast<const ArgInfoNode *>(node_.get());
}
inline const ArgInfoNode *operator->() const { return static_cast<const ArgInfoNode *>(node_.get()); }
inline std::string GetPattern() const {
switch (GetNode()->pattern_) {
......@@ -373,6 +355,8 @@ bool IsBisectionReduction(const StmtInfoList &dst_info_list, const StmtInfoList
bool HasVars(const Expr &index, const Var &vec_var);
int GetVectorizedVarPosition(const Expr &index, Array<Var> &loop_vars);
std::string GetOpType(const Expr &value);
} // namespace akg
namespace air {
......
......@@ -77,7 +77,7 @@ class PatternGenerator {
class SingleVecPatternGenerator : public PatternGenerator {
public:
SingleVecPatternGenerator(const StmtInfoList &dst_info_list, const StmtInfoList &src_info_list,
const StmtInfo &for_info, const std::string &mode)
const StmtInfo &for_info, const std::string &mode = "elewise")
: PatternGenerator(dst_info_list, for_info),
arg_info(ArgInfo(make_node<ArgInfoNode>())),
body_args(VectorArgInfo()),
......
......@@ -33,9 +33,11 @@
#include "insn_info.h"
#include "insn_pattern.h"
#include "insn_emitter.h"
#include "ir_transform.h"
namespace akg {
namespace ir {
Expr GetVarCoefExpr(const Expr &index, const Var &loop_var) {
Expr ret = Expr();
Array<Expr> coefs = air::arith::DetectLinearEquation(index, {loop_var});
......@@ -203,7 +205,7 @@ class HasScalarVarValue : public IRVisitor {
class AdjustPragma : public IRMutator {
public:
Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
if (air::ir::attr::IsPragmaKey(op->attr_key) && op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>()) {
if (op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>()) {
is_candidate_ = true;
loop_vars_ = {};
loop_extends_ = {};
......@@ -295,7 +297,7 @@ class AdjustPragma : public IRMutator {
Array<Expr> srcs = call_ptr->args;
CHECK_EQ(srcs.size(), 2);
is_argmax_min_ = true;
reduce_type_ = (op->value.as<Call>()->name == "fargmin") ? "arg_min" : "arg_max";
reduce_type_ = (op->value.as<Call>()->name == "fargmin") ? "reduce_fargmin" : "reduce_fargmax";
return Store::make(op->buffer_var, Call::make(call_ptr->type, reduce_type_, {srcs[1]}, Call::CallType::Extern),
op->index, op->predicate);
} else if ((op->value.as<FloatImm>() || op->value.as<IntImm>() || op->value.as<UIntImm>()) &&
......@@ -484,353 +486,6 @@ class AdjustPragma : public IRMutator {
Array<Var> transpose_vars_;
};
// Rewrites UB->UB transposes found under a "dma_copy" pragma.  A 16x16
// transpose (both key axes have extent TransAxisLen) is left untouched;
// larger tiles are split into TransAxisLen-sized blocks that are copied
// through two temporary UB buffers (pre-copy -> transpose -> post-copy),
// each step emitted under its own "dma_copy" pragma.
class TransposeTransform : public IRMutator {
 public:
  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    if (air::ir::attr::IsPragmaKey(op->attr_key) && op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>() &&
        op->value.as<StringImm>()->value == "dma_copy") {
      // Reset per-pragma state before walking the body.
      pre_transpose_buffer = Var("srcTranspose_local_UB");
      post_transpose_buffer = Var("dstTranspose_local_UB");
      loop_vars_ = {};
      loop_extends_ = {};
      is_candidate_ = true;
      is_block_transpose_ = false;
      auto body = this->Mutate(op->body);
      is_candidate_ = false;
      if (is_block_transpose_) {
        is_block_transpose_ = false;
        // Wrap the rewritten body with the two temporary UB buffers
        // (TransTotalSize elements each) used by the blocked transpose.
        auto allocate_pre_buffer = Allocate::make(pre_transpose_buffer, t_type, {TransTotalSize}, const_true(1), body);
        auto attr_pre_buffer =
          AttrStmt::make(pre_transpose_buffer, "storage_scope", Expr("local.UB"), allocate_pre_buffer);
        auto allocate_post_buffer =
          Allocate::make(post_transpose_buffer, t_type, {TransTotalSize}, const_true(1), attr_pre_buffer);
        auto attr_post_buffer =
          AttrStmt::make(post_transpose_buffer, "storage_scope", Expr("local.UB"), allocate_post_buffer);
        return attr_post_buffer;
      } else {
        return AttrStmt::make(op->node, op->attr_key, op->value, body);
      }
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }

  Stmt Mutate_(const For *op, const Stmt &s) final {
    if (is_candidate_) {
      // Record the loop nest; loops over the two transposed axes are dropped
      // here because the blocked rewrite re-creates its own loops.
      loop_vars_.push_back(op->loop_var);
      loop_extends_.push_back(op->extent);
      Stmt body = this->Mutate(op->body);
      if (is_block_transpose_ && IsInArray(trans_vars_, op->loop_var)) {
        return body;
      } else {
        return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body);
      }
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const Store *op, const Stmt &s) final {
    if (is_candidate_) {
      auto value = op->value;
      if (auto cast = op->value.as<Cast>()) {
        value = cast->value;
      }
      CHECK(value.as<Load>());
      auto src_ptr = value.as<Load>();
      if (GetBufferType(op->buffer_var) == SCOPE_UBUF && GetBufferType(src_ptr->buffer_var) == SCOPE_UBUF) {
        // dst_pos / src_pos are the vectorized (innermost-key) axes of the
        // store and load; a transpose has them distinct and swapped.
        int dst_pos = GetVectorizedVarPosition(op->index, loop_vars_);
        int src_pos = GetVectorizedVarPosition(src_ptr->index, loop_vars_);
        if (dst_pos != -1 && src_pos != -1 && dst_pos != src_pos &&
            floormod(loop_extends_[dst_pos], TransAxisLen).as<IntImm>() &&
            floormod(loop_extends_[dst_pos], TransAxisLen).as<IntImm>()->value == 0 &&
            Equal(GetVarCoefExpr(op->index, loop_vars_[src_pos]), loop_extends_[dst_pos])) {
          if (loop_extends_[dst_pos].as<IntImm>() && loop_extends_[dst_pos].as<IntImm>()->value == TransAxisLen &&
              loop_extends_[src_pos].as<IntImm>() && loop_extends_[src_pos].as<IntImm>()->value == TransAxisLen) {
            // Exactly one 16x16 tile: leave the store for the native path.
            return s;
          } else {
            // Larger tile: emit the blocked pre-copy/transpose/post-copy.
            is_block_transpose_ = true;
            t_type = src_ptr->type;
            trans_vars_ = {};
            trans_vars_.push_back(loop_vars_[src_pos]);
            trans_vars_.push_back(loop_vars_[dst_pos]);
            Expr ori_w = GetVarCoefExpr(src_ptr->index, loop_vars_[dst_pos]);
            Expr ori_h = loop_extends_[dst_pos];
            Expr ori_block_w = floordiv(ori_w, TransAxisLen);
            Expr ori_block_h = floordiv(ori_h, TransAxisLen);
            Var loop_w = Var("block_w");
            Var loop_h = Var("block_h");
            // Base offsets with the transposed axes eliminated.
            Expr src_base_index = EliminateVarInExpr(src_ptr->index, trans_vars_);
            Expr dst_base_index = EliminateVarInExpr(op->index, trans_vars_);
            Var tt0 = Var("tt0");
            Var tt1 = Var("tt1");
            // 1) copy one 16x16 block from the source into pre_transpose_buffer
            auto pre_copy = Store::make(
              pre_transpose_buffer,
              Load::make(t_type, src_ptr->buffer_var,
                         src_base_index + loop_h * TransAxisLen * ori_w + loop_w * TransAxisLen + tt1 * ori_w + tt0, 1),
              tt1 * TransAxisLen + tt0, 1);
            auto pre_l0 = For::make(tt0, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, pre_copy);
            auto pre_l1 = For::make(tt1, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, pre_l0);
            auto pre_attr = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy"), pre_l1);
            // 2) transpose the block into post_transpose_buffer
            auto transpose =
              Store::make(post_transpose_buffer, Load::make(t_type, pre_transpose_buffer, tt1 * TransAxisLen + tt0, 1),
                          tt0 * 16 + tt1, 1);
            auto trans_l0 = For::make(tt0, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, transpose);
            auto trans_l1 = For::make(tt1, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, trans_l0);
            auto trans_attr = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy"), trans_l1);
            // 3) copy the transposed block to its destination position
            auto post_copy = Store::make(
              op->buffer_var, Load::make(t_type, post_transpose_buffer, tt1 * TransAxisLen + tt0, 1),
              dst_base_index + loop_w * TransAxisLen * ori_h + loop_h * TransAxisLen + tt1 * ori_h + tt0, 1);
            auto post_l0 = For::make(tt0, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, post_copy);
            auto post_l1 = For::make(tt1, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, post_l0);
            auto post_attr = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy"), post_l1);
            auto full_inner = Block::make(Block::make(pre_attr, trans_attr), post_attr);
            // Iterate the three steps over all blocks of the tile.
            auto inner_w = For::make(loop_w, 0, ori_block_w, ForType::Serial, DeviceAPI::None, full_inner);
            auto inner_h = For::make(loop_h, 0, ori_block_h, ForType::Serial, DeviceAPI::None, inner_w);
            return inner_h;
          }
        }
      }
    }
    return s;
  }

  bool is_candidate_{false};        // inside a "dma_copy" pragma
  bool is_block_transpose_{false};  // blocked rewrite triggered for this pragma
  Array<Var> trans_vars_;           // the two transposed loop vars
  Array<Var> loop_vars_;            // loop nest seen so far (outer to inner)
  Array<Expr> loop_extends_;        // extents matching loop_vars_
  Type t_type;                      // element type of the transposed data
  Var pre_transpose_buffer;
  Var post_transpose_buffer;
};
// Hoists IfThenElse nodes out of a "pragma_emit_insn" body (except "mad"):
// the pragma is re-wrapped with the conditions outside it, and loops whose
// variable appears in a hoisted condition are re-emitted outside as well, so
// the insn body itself stays branch-free.
class IfReorder : public IRMutator {
 public:
  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    if (air::ir::attr::IsPragmaKey(op->attr_key) && op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>() &&
        op->value.as<StringImm>()->value != "mad") {
      in_insn_ = true;
      for_vars_.clear();
      if_vars_.clear();
      for_vec_.clear();
      if_vec_.clear();
      auto body = this->Mutate(op->body);
      in_insn_ = false;
      if (!if_vec_.empty()) {
        // Rebuild: pragma innermost, then the hoisted ifs, then the loops
        // whose vars remain in for_vars_ (i.e. loops the ifs depend on).
        Stmt new_s = AttrStmt::make(op->node, op->attr_key, op->value, body);
        for (auto if_op : if_vec_) {
          new_s = IfThenElse::make(if_op->condition, new_s);
        }
        for (auto for_op = for_vec_.rbegin(); for_op != for_vec_.rend(); ++for_op) {
          bool find_flag = false;
          for (auto for_iter = for_vars_.begin(); for_iter != for_vars_.end(); ++for_iter) {
            if (Equal((*for_iter), (*for_op)->loop_var)) {
              find_flag = true;
              break;
            }
          }
          if (find_flag) {
            new_s = For::make((*for_op)->loop_var, (*for_op)->min, (*for_op)->extent, ForType::Serial, DeviceAPI::None,
                              new_s);
          }
        }
        return new_s;
      } else {
        return s;
      }
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const For *op, const Stmt &s) final {
    if (in_insn_) {
      for_vec_.push_back(op);
      for_vars_.push_back(op->loop_var);
      Stmt body = this->Mutate(op->body);
      // Locate this loop's entry in for_vars_ so it can be removed when the
      // loop is kept in place (i.e. no hoisted if references it).
      std::vector<Var>::iterator for_iter;
      for (for_iter = for_vars_.begin(); for_iter != for_vars_.end(); ++for_iter) {
        if (Equal((*for_iter), op->loop_var)) {
          break;
        }
      }
      if (!if_vec_.empty()) {
        std::vector<Var>::iterator if_iter;
        bool find_flag = false;
        for (if_iter = if_vars_.begin(); if_iter != if_vars_.end(); ++if_iter) {
          if (Equal((*if_iter), op->loop_var)) {
            find_flag = true;
            break;
          }
        }
        if (find_flag) {
          // Loop var is used by a hoisted condition: drop the loop here, it
          // is re-created outside the pragma by the AttrStmt mutator.
          return body;
        } else {
          for_vars_.erase(for_iter);
          return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body);
        }
      } else {
        for_vars_.erase(for_iter);
        return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body);
      }
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const IfThenElse *op, const Stmt &s) final {
    if (in_insn_) {
      // Record the condition and every enclosing loop var it mentions,
      // then splice the if out by returning only its then-branch.
      if_vec_.push_back(op);
      for (auto loop_var : for_vars_) {
        if (HasVars(op->condition, loop_var)) {
          if_vars_.push_back(loop_var);
        }
      }
      Stmt body = this->Mutate(op->then_case);
      return body;
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const Store *op, const Stmt &s) final {
    if (in_insn_) {
      // Stores inside the insn body are left untouched.
      return s;
    }
    return IRMutator::Mutate_(op, s);
  }

  bool in_insn_{false};
  std::vector<const IfThenElse *> if_vec_;  // hoisted conditions
  std::vector<Var> if_vars_;                // loop vars referenced by them
  std::vector<Var> for_vars_;               // loop vars still owned by loops
  std::vector<const For *> for_vec_;        // all loops seen in the body
  std::vector<const For *> before_if_;
};
// Reorders the loop nest inside a "pragma_emit_insn" region so that the
// vectorized key axis becomes innermost: for element-wise / broadcast ops
// the destination's key axis, for reduce ops the first source's key axis.
// Loops are stripped on the way down and re-emitted around the pragma in
// the chosen order (var_order_ is built innermost-first).
class LoopReorder : public IRMutator {
  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    if (air::ir::attr::IsPragmaKey(op->attr_key) && op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>()) {
      in_insn_ = true;
      pragma = op->value.as<StringImm>()->value;
      for_map_.clear();
      ori_vars_ = {};
      var_order_.clear();
      auto ret = this->Mutate(op->body);
      in_insn_ = false;
      if (!has_changed_) {
        return s;
      } else {
        if (var_order_.empty()) {
          // No explicit new order: rebuild the loops in their original order,
          // but now outside the pragma.
          ret = AttrStmt::make(op->node, op->attr_key, op->value, ret);
          for (size_t i = 0; i < ori_vars_.size(); ++i) {
            CHECK_GT(for_map_.count(ori_vars_[i].get()), 0);
            auto ptr = for_map_[ori_vars_[i].get()];
            ret = For::make(ptr->loop_var, ptr->min, ptr->extent, ptr->for_type, ptr->device_api, ret);
          }
        } else {
          // Rebuild in the computed order (var_order_ is innermost-first).
          for (size_t i = 0; i < var_order_.size(); ++i) {
            CHECK_GT(for_map_.count(var_order_[i].get()), 0);
            auto ptr = for_map_[var_order_[i].get()];
            ret = For::make(ptr->loop_var, ptr->min, ptr->extent, ptr->for_type, ptr->device_api, ret);
          }
          ret = AttrStmt::make(op->node, op->attr_key, op->value, ret);
        }
        return ret;
      }
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const For *op, const Stmt &s) final {
    if (in_insn_) {
      // Strip the loop; it is re-created around the pragma afterwards.
      for_map_[(op->loop_var).get()] = op;
      ori_vars_.push_back(op->loop_var);
      auto body = this->Mutate(op->body);
      return body;
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }

  Stmt Mutate_(const Store *op, const Stmt &s) final {
    int dst_pos = GetVectorizedVarPosition(op->index, ori_vars_);
    int len = static_cast<int>(ori_vars_.size());
    std::vector<const Load *> srcs;
    auto get_loads = [&srcs](const NodeRef &node) {
      if (const auto v = node.as<Load>()) {
        srcs.push_back(v);
      }
    };
    PostOrderVisit(op->value, get_loads);
    bool same_pos = true;
    std::vector<int> srcs_pos;
    for (int i = 0; i < static_cast<int>(srcs.size()); ++i) {
      int temp_pos = GetVectorizedVarPosition(srcs[i]->index, ori_vars_);
      srcs_pos.push_back(temp_pos);
      if (temp_pos != dst_pos) {
        same_pos = false;
      }
    }
    has_changed_ = false;
    if (dst_pos >= 0 && len >= 2 && dst_pos != (len - 1) && (same_pos || pragma == "broadcast")) {
      // Src Load empty; all Load and Dst has the same key axis; broadcast
      has_changed_ = true;
      var_order_.push_back(ori_vars_[dst_pos]);
      for (int i = len - 1; i >= 0; i--) {
        if (i != dst_pos) {
          var_order_.push_back(ori_vars_[i]);
        }
      }
    } else if (pragma.find("reduce") != pragma.npos && len >= 2 && srcs_pos[0] != (len - 1)) {
      // based on dst key axis: reduce
      // NOTE(review): srcs_pos[0] is read without checking srcs_pos is
      // non-empty — a reduce store with no Load on the rhs would be
      // out-of-bounds here; verify callers always supply a Load.
      has_changed_ = true;
      var_order_.push_back(ori_vars_[srcs_pos[0]]);
      for (int i = len - 1; i >= 0; i--) {
        if (i != srcs_pos[0]) {
          var_order_.push_back(ori_vars_[i]);
        }
      }
    }
    return s;
  }

  std::unordered_map<const Variable *, const For *> for_map_;  // loop var -> original For
  std::vector<Var> var_order_;  // new loop order, innermost first
  Array<Var> ori_vars_;         // original loop order, outer to inner
  bool has_changed_{false};
  bool in_insn_{false};
  std::string pragma;
};
// Renames duplicated loop variables so that every For node in the tree binds
// a distinct Variable; duplicates are replaced by fresh "iiN" vars and the
// rename is substituted through the loop body.
class ForVarUnique : public IRMutator {
 public:
  Stmt Mutate_(const For *op, const Stmt &s) final {
    auto mutated_body = this->Mutate(op->body);
    const Variable *raw = op->loop_var.get();
    if (var_maps_.count(raw) == 0) {
      // First occurrence: remember it and rebuild the loop unchanged.
      var_maps_[raw] = 1;
      return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, mutated_body);
    }
    // Variable already used by another loop: mint a fresh one.
    Var fresh_var = Var("ii" + std::to_string(++index_));
    std::unordered_map<const Variable *, Expr> rename_map;
    rename_map[raw] = fresh_var;
    var_maps_[fresh_var.get()] = 1;
    return For::make(fresh_var, op->min, op->extent, ForType::Serial, DeviceAPI::None,
                     Substitute(mutated_body, rename_map));
  }

  std::unordered_map<const Variable *, int> var_maps_;  // vars already seen
  int index_{0};                                        // fresh-name counter
};
class GenSIMD {
public:
GenSIMD(CCEInfo &t_info, Map<std::string, Buffer> &buffer_map, const std::string &pragma)
......@@ -1520,9 +1175,9 @@ class GenReduce {
~GenReduce() = default;
Stmt Run(int pre_index) {
is_arg_type_ = (pragma_ == "arg_max" || pragma_ == "arg_min");
is_arg_type_ = (pragma_ == "reduce_fargmax" || pragma_ == "reduce_fargmin");
RemoveVectorizedIndex(t_info_, 0);
if (pragma_.find("sum") != std::string::npos) {
if (pragma_.find("sum") != std::string::npos || pragma_.find("add") != std::string::npos) {
insn_intrinsic_ = "vcadd";
expansion_factor_ = 1;
} else if (pragma_.find("max") != std::string::npos) {
......@@ -1769,7 +1424,7 @@ class EmitVariableInsns : public IRMutator {
}
Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
if (air::ir::attr::IsPragmaKey(op->attr_key) && op->attr_key == "pragma_emit_insn") {
if (op->attr_key == "pragma_emit_insn") {
CHECK(op->value.as<StringImm>());
pragma = op->value.as<StringImm>()->value;
Stmt r;
......@@ -1791,8 +1446,7 @@ class EmitVariableInsns : public IRMutator {
if (!r.same_as(s)) {
return r;
}
} else if (air::ir::attr::IsPragmaKey(op->attr_key) &&
(op->attr_key == "pragma_im2col" || op->attr_key == "pragma_load3d")) {
} else if (op->attr_key == "pragma_im2col" || op->attr_key == "pragma_load3d") {
if (paramters_.defined() && Downcast<Map<std::string, NodeRef>>(paramters_).count("feature")) {
auto feature = Downcast<Map<std::string, NodeRef>>(paramters_)["feature"].as<StringImm>();
CHECK(feature);
......@@ -1842,13 +1496,13 @@ class EmitVariableInsns : public IRMutator {
if (pragma.find("vec_select") != std::string::npos) {
EmitSelect(op, t_info);
} else if (pragma.find("dma_copy") == 0) {
} else if (pragma.find("dma_copy") != std::string::npos) {
EmitDMA(t_info);
} else if (pragma.find("vec_binary") == 0 || pragma.find("vec_single") == 0) {
} else if (pragma.find("vec_binary") != std::string::npos || pragma.find("vec_single") != std::string::npos) {
EmitSIMD(t_info);
} else if (pragma.find("reduce") == 0 || pragma.find("arg_") == 0) {
} else if (pragma.find("reduce") != std::string::npos || pragma.find("arg_") != std::string::npos) {
EmitReduce(t_info);
} else if (pragma.find("broadcast") == 0) {
} else if (pragma.find("broadcast") != std::string::npos) {
if (loops_vars_.empty()) {
gen_cce = t_info.ori_stmt;
} else {
......
......@@ -31,8 +31,7 @@
namespace akg {
namespace ir {
const int TransTotalSize = 256;
const int TransAxisLen = 16;
const int64_t FullReduceMaskValue = 6148914691236517205;
class CCEInsn {
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef IR_TRANSFORM_H_
#define IR_TRANSFORM_H_
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/registry.h>
#include <unordered_set>
#include <map>
#include <numeric>
#include <set>
#include <algorithm>
#include "ir_pass.h"
#include "common/array_api.h"
#include "insn_with_variable.h"
#include "insn_builder.h"
#include "insn_info.h"
#include "insn_pattern.h"
#include "../pass/analyze_align.h"
const int TransTotalSize = 256;
const int TransAxisLen = 16;
namespace akg {
namespace ir {
Expr GetVarCoefExpr(const Expr &index, const Var &loop_var);
std::string GetBufferType(Expr address);
// Rewrites UB->UB fp16 "dma_copy" transposes.  A single 16x16 transpose is
// only re-tagged as "dma_copy_transpose" for the intrinsic emitter (native
// path); larger tiles are split into TransAxisLen x TransAxisLen blocks moved
// through two temporary UB buffers (pre-copy -> transpose -> post-copy).
// Float(32) data is routed through extra fp16 cast buffers, since the
// transpose step itself operates on Float(16).
class TransposeTransform : public IRMutator {
 public:
  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    if (op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>() &&
        (op->value.as<StringImm>()->value == "dma_copy")) {
      // Reset per-pragma state before walking the body.
      pre_transpose_buffer_ = Var("srcTranspose_local_UB");
      post_transpose_buffer_ = Var("dstTranspose_local_UB");
      pre_trans_cast_ = Var("pre_trans_cast__local_UB");
      post_trans_cast_ = Var("post_trans_cast__local_UB");
      loop_vars_ = {};
      loop_extends_ = {};
      is_candidate_ = true;
      is_block_transpose_ = false;
      is_native_transpose_ = false;
      align_value = FREE_ALIGN;
      remain_fors_.clear();
      auto body = this->Mutate(op->body);
      is_candidate_ = false;
      if (is_block_transpose_) {
        is_block_transpose_ = false;
        if (t_type_ == Float(32)) {  // need cast
          // Extra fp16 staging buffers for the cast path.
          body = Allocate::make(pre_trans_cast_, Float(16), {TransTotalSize}, const_true(1), body);
          body = AttrStmt::make(pre_trans_cast_, "storage_scope", Expr("local.UB"), body);
          body = Allocate::make(post_trans_cast_, Float(16), {TransTotalSize}, const_true(1), body);
          body = AttrStmt::make(post_trans_cast_, "storage_scope", Expr("local.UB"), body);
        }
        // The two TransTotalSize-element UB buffers used by every block.
        auto allocate_pre_buffer =
          Allocate::make(pre_transpose_buffer_, t_type_, {TransTotalSize}, const_true(1), body);
        auto attr_pre_buffer =
          AttrStmt::make(pre_transpose_buffer_, "storage_scope", Expr("local.UB"), allocate_pre_buffer);
        auto allocate_post_buffer =
          Allocate::make(post_transpose_buffer_, t_type_, {TransTotalSize}, const_true(1), attr_pre_buffer);
        auto attr_post_buffer =
          AttrStmt::make(post_transpose_buffer_, "storage_scope", Expr("local.UB"), allocate_post_buffer);
        Stmt ret = attr_post_buffer;
        if (align_value != FREE_ALIGN) {
          // Publish the source row stride so later align analysis sees it.
          ret = AttrStmt::make(align_buffer_, "align_info", Expr(align_value), ret);
        }
        return ret;
      }
      if (is_native_transpose_) {
        // Re-tag for the transpose intrinsic and re-create the loops that
        // were stripped below the pragma.
        Stmt ret = AttrStmt::make(op->node, op->attr_key, Expr("dma_copy_transpose"), body);
        for (int i = 0; i <= static_cast<int>(remain_fors_.size()) - 1; ++i) {
          ret = For::make(remain_fors_[i]->loop_var, remain_fors_[i]->min, remain_fors_[i]->extent, ForType::Serial,
                          DeviceAPI::None, ret);
        }
        return ret;
      }
      return AttrStmt::make(op->node, op->attr_key, op->value, body);
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const For *op, const Stmt &s) final {
    if (is_candidate_) {
      loop_vars_.push_back(op->loop_var);
      loop_extends_.push_back(op->extent);
      Stmt body = this->Mutate(op->body);
      if (is_block_transpose_ && IsInArray(trans_vars_, op->loop_var)) {
        // Transposed axes are re-created by the blocked rewrite.
        return body;
      }
      if (is_native_transpose_) {
        // Keep only the two transposed loops inside the pragma; the rest are
        // hoisted outside via remain_fors_.
        if (IsInArray(trans_vars_, op->loop_var)) {
          return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body);
        }
        remain_fors_.push_back(op);
        return body;
      }
      return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body);
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const Store *op, const Stmt &s) final {
    if (is_candidate_) {
      auto value = op->value;
      if (auto cast = op->value.as<Cast>()) {
        value = cast->value;
      }
      CHECK(value.as<Load>());
      auto src_ptr = value.as<Load>();
      if (GetBufferType(op->buffer_var) == SCOPE_UBUF && GetBufferType(src_ptr->buffer_var) == SCOPE_UBUF &&
          src_ptr->type == Float(16)) {
        // dst_pos / src_pos are the vectorized key axes of store and load;
        // a transpose has them distinct and cross-referenced.
        int dst_pos = GetVectorizedVarPosition(op->index, loop_vars_);
        int src_pos = GetVectorizedVarPosition(src_ptr->index, loop_vars_);
        if (dst_pos != -1 && src_pos != -1 && dst_pos != src_pos && HasVars(src_ptr->index, loop_vars_[dst_pos]) &&
            HasVars(op->index, loop_vars_[src_pos]) && floormod(loop_extends_[dst_pos], TransAxisLen).as<IntImm>() &&
            floormod(loop_extends_[dst_pos], TransAxisLen).as<IntImm>()->value == 0 &&
            Equal(GetVarCoefExpr(op->index, loop_vars_[src_pos]), loop_extends_[dst_pos])) {
          if (loop_extends_[dst_pos].as<IntImm>() && loop_extends_[dst_pos].as<IntImm>()->value == TransAxisLen &&
              loop_extends_[src_pos].as<IntImm>() && loop_extends_[src_pos].as<IntImm>()->value == TransAxisLen) {
            // Exactly one 16x16 tile: take the native intrinsic path.
            trans_vars_ = {};
            trans_vars_.push_back(loop_vars_[src_pos]);
            trans_vars_.push_back(loop_vars_[dst_pos]);
            is_native_transpose_ = true;
            return s;
          }
          is_block_transpose_ = true;
          if (GetVarCoefExpr(src_ptr->index, loop_vars_[dst_pos]).as<IntImm>()) {
            int coef_t = GetVarCoefExpr(src_ptr->index, loop_vars_[dst_pos]).as<IntImm>()->value;
            if (coef_t % TransAxisLen != 0) {
              // Row stride not 16-aligned: record it for the align_info attr.
              align_value = coef_t;
              align_buffer_ = src_ptr->buffer_var;
            }
          }
          t_type_ = src_ptr->type;
          trans_vars_ = {};
          trans_vars_.push_back(loop_vars_[src_pos]);
          trans_vars_.push_back(loop_vars_[dst_pos]);
          Expr ori_w = GetVarCoefExpr(src_ptr->index, loop_vars_[dst_pos]);
          Expr ori_h = loop_extends_[dst_pos];
          Expr ori_block_w = floordiv(ori_w, TransAxisLen);
          // padding the width
          Expr unit_width = TransAxisLen;
          if (!Equal(floormod(ori_w, TransAxisLen), 0)) {
            ori_block_w = ori_block_w + 1;
          }
          if (ori_w.as<IntImm>() && ori_w.as<IntImm>()->value < TransAxisLen) {
            unit_width = ori_w;
          }
          Expr ori_block_h = floordiv(ori_h, TransAxisLen);
          Var loop_w = Var("block_w");
          Var loop_h = Var("block_h");
          // Base offsets with the transposed axes eliminated.
          Expr src_base_index = EliminateVarInExpr(src_ptr->index, trans_vars_);
          Expr dst_base_index = EliminateVarInExpr(op->index, trans_vars_);
          Var tt0 = Var("tt0");
          Var tt1 = Var("tt1");
          // 1) copy one block from the source into pre_transpose_buffer_
          auto pre_copy = Store::make(
            pre_transpose_buffer_,
            Load::make(t_type_, src_ptr->buffer_var,
                       src_base_index + loop_h * TransAxisLen * ori_w + loop_w * TransAxisLen + tt1 * ori_w + tt0, 1),
            tt1 * TransAxisLen + tt0, 1);
          auto pre_l0 = For::make(tt0, 0, unit_width, ForType::Serial, DeviceAPI::None, pre_copy);
          auto pre_l1 = For::make(tt1, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, pre_l0);
          auto pre_attr = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy"), pre_l1);
          // 2) transpose the block (fp16 directly, or cast fp32->fp16->fp32)
          Stmt trans_attr = Stmt();
          if (t_type_ == Float(16)) {
            auto transpose =
              Store::make(post_transpose_buffer_,
                          Load::make(t_type_, pre_transpose_buffer_, tt1 * TransAxisLen + tt0, 1), tt0 * 16 + tt1, 1);
            auto trans_l0 = For::make(tt0, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, transpose);
            auto trans_l1 = For::make(tt1, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, trans_l0);
            trans_attr = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy_transpose"), trans_l1);
          } else {
            auto pre_cast_store = Store::make(
              pre_trans_cast_, Cast::make(Float(16), Load::make(t_type_, pre_transpose_buffer_, tt0, 1)), tt0, 1);
            auto pre_cast_for = For::make(tt0, 0, TransTotalSize, ForType::Serial, DeviceAPI::None, pre_cast_store);
            auto pre_cast_attr =
              AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("vec_single_cast"), pre_cast_for);
            auto transpose = Store::make(
              post_trans_cast_, Load::make(Float(16), pre_trans_cast_, tt1 * TransAxisLen + tt0, 1), tt0 * 16 + tt1, 1);
            auto trans_l0 = For::make(tt0, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, transpose);
            auto trans_l1 = For::make(tt1, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, trans_l0);
            auto trans_block =
              AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy_transpose"), trans_l1);
            auto post_cast_store = Store::make(
              post_transpose_buffer_, Cast::make(t_type_, Load::make(Float(16), post_trans_cast_, tt0, 1)), tt0, 1);
            auto post_cast_for = For::make(tt0, 0, TransTotalSize, ForType::Serial, DeviceAPI::None, post_cast_store);
            auto post_cast_attr =
              AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("vec_single_cast"), post_cast_for);
            trans_attr = Block::make(Block::make(pre_cast_attr, trans_block), post_cast_attr);
          }
          // 3) copy the transposed block to its destination position
          auto post_copy =
            Store::make(op->buffer_var, Load::make(t_type_, post_transpose_buffer_, tt1 * TransAxisLen + tt0, 1),
                        dst_base_index + loop_w * TransAxisLen * ori_h + loop_h * TransAxisLen + tt1 * ori_h + tt0, 1);
          auto post_l0 = For::make(tt0, 0, TransAxisLen, ForType::Serial, DeviceAPI::None, post_copy);
          auto post_l1 = For::make(tt1, 0, unit_width, ForType::Serial, DeviceAPI::None, post_l0);
          auto post_attr = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("dma_copy"), post_l1);
          auto full_inner = Block::make(Block::make(pre_attr, trans_attr), post_attr);
          // Iterate the three steps over all blocks; single-block dimensions
          // are substituted with 0 and simplified away.
          auto inner_w = For::make(loop_w, 0, ori_block_w, ForType::Serial, DeviceAPI::None, full_inner);
          if (ori_block_w.as<IntImm>() && ori_block_w.as<IntImm>()->value == 1) {
            std::unordered_map<const Variable *, Expr> init;
            init[loop_w.get()] = 0;
            inner_w = Simplify(Substitute(full_inner, init));
          }
          auto inner_h = For::make(loop_h, 0, ori_block_h, ForType::Serial, DeviceAPI::None, inner_w);
          if (ori_block_h.as<IntImm>() && ori_block_h.as<IntImm>()->value == 1) {
            std::unordered_map<const Variable *, Expr> init;
            init[loop_h.get()] = 0;
            inner_h = Simplify(Substitute(inner_w, init));
          }
          return inner_h;
        }
      }
    }
    return s;
  }

 private:
  bool is_candidate_{false};        // inside a "dma_copy" pragma
  bool is_native_transpose_{false}; // single 16x16 tile: intrinsic path
  bool is_block_transpose_{false};  // blocked rewrite triggered
  int align_value{FREE_ALIGN};      // unaligned source row stride, if any
  Var align_buffer_;                // buffer the align_info attr refers to
  Array<Var> trans_vars_;           // the two transposed loop vars
  Array<Var> loop_vars_;            // loop nest seen so far (outer to inner)
  Array<Expr> loop_extends_;        // extents matching loop_vars_
  std::vector<const For *> remain_fors_;  // loops hoisted in the native path
  Type t_type_;                     // element type of the transposed data
  Var pre_transpose_buffer_;
  Var pre_trans_cast_;
  Var post_trans_cast_;
  Var post_transpose_buffer_;
};
// Guarantees that each For node in the tree binds a unique loop Variable.
// When a variable is encountered a second time it is replaced by a fresh
// "iiN" variable, substituted through the (already mutated) loop body.
class ForVarUnique : public IRMutator {
 public:
  Stmt Mutate_(const For *op, const Stmt &s) final {
    Stmt body = this->Mutate(op->body);
    const Variable *var_ptr = op->loop_var.get();
    const bool seen_before = var_maps_.count(var_ptr) != 0;
    if (!seen_before) {
      var_maps_[var_ptr] = 1;
      return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body);
    }
    // Duplicate: rename to a fresh variable and rewrite the body.
    Var replacement("ii" + std::to_string(++index_));
    std::unordered_map<const Variable *, Expr> subst;
    subst[var_ptr] = replacement;
    body = Substitute(body, subst);
    var_maps_[replacement.get()] = 1;
    return For::make(replacement, op->min, op->extent, ForType::Serial, DeviceAPI::None, body);
  }

 private:
  std::unordered_map<const Variable *, int> var_maps_;  // vars already bound
  int index_{0};                                        // fresh-name counter
};
// Reorders the loop nest inside a "pragma_emit_insn" region so that the
// vectorized key axis becomes innermost: for element-wise / broadcast ops
// the destination's key axis, for reduce ops the first source Load's key
// axis.  Loops are stripped while descending and re-emitted around the
// pragma in the chosen order (var_order_ is built innermost-first).
class LoopReorder : public IRMutator {
 public:
  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    if (op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>()) {
      in_insn_ = true;
      pragma_ = op->value.as<StringImm>()->value;
      for_map_.clear();
      ori_vars_ = {};
      var_order_.clear();
      auto ret = this->Mutate(op->body);
      in_insn_ = false;
      if (!has_changed_) {
        return s;
      }
      if (var_order_.empty()) {
        // No explicit new order: rebuild the loops in their original order,
        // now outside the pragma.
        ret = AttrStmt::make(op->node, op->attr_key, op->value, ret);
        for (size_t i = 0; i < ori_vars_.size(); ++i) {
          CHECK_GT(for_map_.count(ori_vars_[i].get()), 0);
          auto ptr = for_map_[ori_vars_[i].get()];
          ret = For::make(ptr->loop_var, ptr->min, ptr->extent, ptr->for_type, ptr->device_api, ret);
        }
        return ret;
      }
      // Rebuild in the computed order (innermost first in var_order_).
      for (size_t i = 0; i < var_order_.size(); ++i) {
        CHECK_GT(for_map_.count(var_order_[i].get()), 0);
        auto ptr = for_map_[var_order_[i].get()];
        ret = For::make(ptr->loop_var, ptr->min, ptr->extent, ptr->for_type, ptr->device_api, ret);
      }
      ret = AttrStmt::make(op->node, op->attr_key, op->value, ret);
      return ret;
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const For *op, const Stmt &s) final {
    if (in_insn_) {
      // Strip the loop; it is re-created around the pragma afterwards.
      for_map_[(op->loop_var).get()] = op;
      ori_vars_.push_back(op->loop_var);
      auto body = this->Mutate(op->body);
      return body;
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }

  Stmt Mutate_(const Store *op, const Stmt &s) final {
    // Key (vectorized) axis of the destination and of each source Load.
    int dst_pos = GetVectorizedVarPosition(op->index, ori_vars_);
    int len = static_cast<int>(ori_vars_.size());
    std::vector<const Load *> srcs;
    auto get_loads = [&srcs](const NodeRef &node) {
      if (const auto v = node.as<Load>()) {
        srcs.push_back(v);
      }
    };
    PostOrderVisit(op->value, get_loads);
    bool same_pos = true;
    std::vector<int> srcs_pos;
    for (int i = 0; i < static_cast<int>(srcs.size()); ++i) {
      int temp_pos = GetVectorizedVarPosition(srcs[i]->index, ori_vars_);
      srcs_pos.push_back(temp_pos);
      if (temp_pos != dst_pos) {
        same_pos = false;
      }
    }
    has_changed_ = false;
    if (dst_pos >= 0 && len >= 2 && dst_pos != (len - 1) && (same_pos || pragma_ == "broadcast")) {
      // Src Load empty; all Load and Dst has the same key axis; broadcast
      has_changed_ = true;
      var_order_.push_back(ori_vars_[dst_pos]);
      for (int i = len - 1; i >= 0; i--) {
        if (i != dst_pos) {
          var_order_.push_back(ori_vars_[i]);
        }
      }
    } else if (pragma_.find("reduce") != pragma_.npos && !srcs_pos.empty() && len >= 2 &&
               srcs_pos[0] != (len - 1)) {
      // based on first src's key axis: reduce.
      // Bug fix: guard with !srcs_pos.empty() — a reduce-tagged store whose
      // rhs contains no Load previously indexed srcs_pos[0] out of bounds.
      has_changed_ = true;
      var_order_.push_back(ori_vars_[srcs_pos[0]]);
      for (int i = len - 1; i >= 0; i--) {
        if (i != srcs_pos[0]) {
          var_order_.push_back(ori_vars_[i]);
        }
      }
    }
    return s;
  }

 private:
  std::unordered_map<const Variable *, const For *> for_map_;  // loop var -> original For
  std::vector<Var> var_order_;  // new loop order, innermost first
  Array<Var> ori_vars_;         // original loop order, outer to inner
  bool has_changed_{false};
  bool in_insn_{false};
  std::string pragma_;
};
// Hoists IfThenElse nodes out of a "pragma_emit_insn" body (skipping pragmas
// in exclude_align_analyze_list): the pragma is re-wrapped with the hoisted
// conditions outside it, and loops whose variable appears in a hoisted
// condition are re-emitted outside as well, so the insn body is branch-free.
class IfReorder : public IRMutator {
 public:
  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    if (op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>() &&
        !exclude_align_analyze_list.count(op->value.as<StringImm>()->value)) {
      in_insn_ = true;
      for_vars_.clear();
      if_vars_.clear();
      for_vec_.clear();
      if_vec_.clear();
      auto body = this->Mutate(op->body);
      in_insn_ = false;
      if (!if_vec_.empty()) {
        // Rebuild: pragma innermost, then the hoisted ifs, then the loops
        // still listed in for_vars_ (those the conditions depend on).
        Stmt new_s = AttrStmt::make(op->node, op->attr_key, op->value, body);
        for (auto if_op : if_vec_) {
          new_s = IfThenElse::make(if_op->condition, new_s);
        }
        for (auto for_op = for_vec_.rbegin(); for_op != for_vec_.rend(); ++for_op) {
          bool find_flag = false;
          for (auto for_iter = for_vars_.begin(); for_iter != for_vars_.end(); ++for_iter) {
            if (Equal((*for_iter), (*for_op)->loop_var)) {
              find_flag = true;
              break;
            }
          }
          if (find_flag) {
            new_s = For::make((*for_op)->loop_var, (*for_op)->min, (*for_op)->extent, ForType::Serial, DeviceAPI::None,
                              new_s);
          }
        }
        return new_s;
      }
      return s;
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const For *op, const Stmt &s) final {
    if (in_insn_) {
      for_vec_.push_back(op);
      for_vars_.push_back(op->loop_var);
      Stmt body = this->Mutate(op->body);
      // Locate this loop's entry in for_vars_; it is erased when the loop is
      // kept in place (i.e. no hoisted if references its variable).
      std::vector<Var>::iterator for_iter;
      for (for_iter = for_vars_.begin(); for_iter != for_vars_.end(); ++for_iter) {
        if (Equal((*for_iter), op->loop_var)) {
          break;
        }
      }
      if (!if_vec_.empty()) {
        std::vector<Var>::iterator if_iter;
        bool find_flag = false;
        for (if_iter = if_vars_.begin(); if_iter != if_vars_.end(); ++if_iter) {
          if (Equal((*if_iter), op->loop_var)) {
            find_flag = true;
            break;
          }
        }
        if (find_flag) {
          // Loop var is used by a hoisted condition: drop the loop here; the
          // AttrStmt mutator re-creates it outside the pragma.
          return body;
        }
        for_vars_.erase(for_iter);
        return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body);
      }
      for_vars_.erase(for_iter);
      return For::make(op->loop_var, op->min, op->extent, ForType::Serial, DeviceAPI::None, body);
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const IfThenElse *op, const Stmt &s) final {
    if (in_insn_) {
      // Record the condition plus every enclosing loop var it mentions, then
      // splice the if out by returning only its then-branch.
      if_vec_.push_back(op);
      for (auto loop_var : for_vars_) {
        if (HasVars(op->condition, loop_var)) {
          if_vars_.push_back(loop_var);
        }
      }
      Stmt body = this->Mutate(op->then_case);
      return body;
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const Store *op, const Stmt &s) final {
    if (in_insn_) {
      // Stores inside the insn body are left untouched.
      return s;
    }
    return IRMutator::Mutate_(op, s);
  }

 private:
  bool in_insn_{false};
  std::vector<const IfThenElse *> if_vec_;  // hoisted conditions
  std::vector<Var> if_vars_;                // loop vars referenced by them
  std::vector<Var> for_vars_;               // loop vars still owned by loops
  std::vector<const For *> for_vec_;        // all loops seen in the body
  std::vector<const For *> before_if_;
};
} // namespace ir
} // namespace akg
#endif // IR_TRANSFORM_H_
\ No newline at end of file
......@@ -265,6 +265,10 @@ Stmt RewriteBroadcastVector(Stmt stmt);
Stmt OptimizePragma(Stmt stmt);
Stmt PackStore(Stmt stmt);
Stmt RecoverStore(Stmt stmt);
Stmt RewriteByAlignDynamic(Stmt stmt);
Stmt EliminateAtomicDma(Stmt stmt);
......
此差异已折叠。
......@@ -466,7 +466,7 @@ class AlignVistor : public IRVisitor {
// only scan dma insns
if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value.as<StringImm>() &&
op->value.as<StringImm>()->value != "vec_binary_dropout" &&
exclude_list.count(op->value.as<StringImm>()->value) == 0)) {
exclude_align_analyze_list.count(op->value.as<StringImm>()->value) == 0)) {
bool in_dma_copy = false;
if (op->value.as<StringImm>() && op->value.as<StringImm>()->value == "dma_copy") {
in_dma_copy = true;
......
此差异已折叠。
......@@ -43,7 +43,7 @@ class LoopsCompacter : public IRMutator {
Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() &&
!exclude_list.count(op->value.as<StringImm>()->value))) {
!exclude_align_analyze_list.count(op->value.as<StringImm>()->value))) {
stores_ = Array<NodeRef>();
loads_ = Array<NodeRef>();
GetStoreAndLoads(op->body, stores_, loads_);
......
......@@ -192,6 +192,7 @@ class MultiLastAxisReduction : public IRMutator {
lastResult = loadTmp + storeLeft;
}
broadcastNum = Call::make(type_tmp, "vector_dup", {broadcastNum}, Call::PureIntrinsic);
Stmt stForOnce = Store::make(tmpBuffer, storeResult, newIdx, storeTmp->predicate);
Stmt stForTwice = Store::make(storeTmp->buffer_var, lastResult, storeTmp->index, storeTmp->predicate);
Stmt stBroadcast = Store::make(tmpBuffer, broadcastNum, newIdx, storeTmp->predicate);
......@@ -212,7 +213,7 @@ class MultiLastAxisReduction : public IRMutator {
stForOnce = AttrStmt::make(VarExpr("0", Int(32)), "pragma_emit_insn", Expr(str), stForOnce);
stForTwice = AttrStmt::make(VarExpr("0", Int(32)), "pragma_emit_insn", Expr(str), stForTwice);
stBroadcast = AttrStmt::make(VarExpr("0", Int(32)), "pragma_emit_insn", Expr("broadcast"), stBroadcast);
stBroadcast = AttrStmt::make(VarExpr("0", Int(32)), "pragma_emit_insn", Expr("vector_dup"), stBroadcast);
stmt = Block::make({stBroadcast, stForOnce, stForTwice});
stmt = Allocate::make(tmpBuffer, type_tmp, extentsArray, const_true(), stmt);
......
......@@ -147,7 +147,7 @@ class EstimateAlign : public IRMutator {
Stmt Mutate_(const AttrStmt *op, const Stmt &stmt) final {
if (air::ir::attr::IsPragmaKey(op->attr_key) && op->value.as<StringImm>()) {
if (exclude_list.count(op->value.as<StringImm>()->value)) {
if (exclude_align_analyze_list.count(op->value.as<StringImm>()->value)) {
return stmt;
}
......
......@@ -46,7 +46,7 @@ class AxisPartitioner : public IRMutator {
Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() &&
exclude_list.count(op->value.as<StringImm>()->value) == 0)) {
exclude_index_fix_list.count(op->value.as<StringImm>()->value) == 0)) {
in_insn_ = true;
counter_ = 0;
auto ret = IRMutator::Mutate_(op, s);
......@@ -180,7 +180,7 @@ class RewriteAllocateAndIndex : public IRMutator {
}
}
if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() &&
(exclude_list.count(op->value.as<StringImm>()->value) == 0 ||
(exclude_index_fix_list.count(op->value.as<StringImm>()->value) == 0 ||
op->value.as<StringImm>()->value == "scatter"))) {
in_insn_ = true;
auto ret = IRMutator::Mutate_(op, s);
......
......@@ -46,7 +46,7 @@ class AxisPartitioner : public IRMutator {
Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() &&
exclude_list.count(op->value.as<StringImm>()->value) == 0)) {
exclude_index_fix_list.count(op->value.as<StringImm>()->value) == 0)) {
in_insn_ = true;
counter_ = 0;
auto ret = IRMutator::Mutate_(op, s);
......@@ -182,7 +182,7 @@ class RewriteAllocateAndIndex : public IRMutator {
}
}
if (op->attr_key == "pragma_ub_gm" || (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() &&
(exclude_list.count(op->value.as<StringImm>()->value) == 0 ||
(exclude_index_fix_list.count(op->value.as<StringImm>()->value) == 0 ||
op->value.as<StringImm>()->value == "scatter"))) {
in_insn_ = true;
auto ret = IRMutator::Mutate_(op, s);
......@@ -307,12 +307,7 @@ class RewriteAllocateAndIndex : public IRMutator {
CHECK_NE(align, 0);
int64_t coef = GetIntConst(strides[0]);
if (std::abs(coef) < align) {
auto it = var2ext_.find(v.get());
if (it != var2ext_.end() && std::abs(coef * it->second) <= align) {
rst += v * strides[0];
} else {
return SimpleFix(tmp_idx_bk, opt.var2expr, align, times);
}
rst += v * strides[0];
} else if (coef % align == 0) {
auto new_coef = coef * times / align;
rst += v * Expr(static_cast<int32_t>(new_coef));
......@@ -359,7 +354,8 @@ class RewriteAllocateAndIndex : public IRMutator {
/// Static alignment rewrite pipeline: partition axes, rewrite allocates and
/// indices to the aligned layout, then merge the resulting loops.
/// Fix: the pasted diff left both the old `return MergeLoops(stmt);` and the
/// new two-line form in place, producing unreachable duplicate statements;
/// only the post-change form is kept.
Stmt RewriteByAlignStatic(Stmt stmt) {
  stmt = AxisPartitioner().Run(stmt);
  stmt = RewriteAllocateAndIndex().Mutate(stmt);
  stmt = MergeLoops(stmt);
  return stmt;
}
} // namespace ir
} // namespace akg
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include "emit_insn/insn_info.h"
#include "emit_insn/ir_transform.h"
#include "analyze_align.h"
namespace akg {
namespace ir {
// Detects last-dimension reductions inside emit-insn scopes and repacks them
// into the compact "reduce_*" pragma form produced by the arithmetic parser.
class ReducePacker : public IRMutator {
 public:
  ReducePacker() = default;
  ~ReducePacker() override = default;

  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    // An instruction scope is either a ub->gm DMA or an emit-insn pragma that
    // is not excluded from alignment analysis.
    const bool insn_scope =
        op->attr_key == "pragma_ub_gm" ||
        (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() &&
         exclude_align_analyze_list.count(op->value.as<StringImm>()->value) == 0);
    if (!insn_scope) {
      return IRMutator::Mutate_(op, s);
    }
    IRInfo info;
    ParserVisitor(info, false).Run(s);
    if (!info.ChangeLastDimReduce()) {
      // Not a last-dim reduction: leave the scope untouched.
      return s;
    }
    return AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr(info.arith_info.insn_type), info.GenStmt());
  }
};
/// Pass entry point: canonicalize transposes first, then pack last-axis
/// reductions into the "reduce_*" pragma form.
Stmt PackStore(Stmt stmt) {
  Stmt transformed = TransposeTransform().Mutate(stmt);
  return ReducePacker().Mutate(transformed);
}
} // namespace ir
} // namespace akg
\ No newline at end of file
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include "emit_insn/insn_info.h"
#include "analyze_align.h"
#include "emit_insn/ir_transform.h"
namespace akg {
namespace ir {
// Restores packed "reduce_*" pragmas (created by the PackStore pass) to the
// explicit binary-vector form understood by the emitter ("vec_binary_*"),
// expanding each packed store into a read-modify-write of the destination.
class ReduceRecover : public IRMutator {
 public:
  ReduceRecover() = default;
  ~ReduceRecover() override = default;

  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    if (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() &&
        op->value.as<StringImm>()->value.find("reduce_") != std::string::npos) {
      old_pragma_ = op->value.as<StringImm>()->value;
      if (old_pragma_ == "reduce_add") {
        new_pragma_ = "vec_binary_add";
      } else if (old_pragma_ == "reduce_max") {
        new_pragma_ = "vec_binary_max";
      } else if (old_pragma_ == "reduce_min") {
        new_pragma_ = "vec_binary_min";
      } else if (old_pragma_ == "reduce_fargmax") {
        new_pragma_ = "vec_binary_fargmax";
      } else if (old_pragma_ == "reduce_fargmin") {
        new_pragma_ = "vec_binary_fargmin";
      } else {
        // Fix: previously an unknown "reduce_*" pragma silently reused the
        // mapping left over from the last visited scope (stale member state).
        new_pragma_ = old_pragma_;
      }
      in_reduce_ = true;
      auto body = this->Mutate(op->body);
      in_reduce_ = false;
      return AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr(new_pragma_), body);
    } else if (op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() &&
               op->value.as<StringImm>()->value == "dma_copy_transpose") {
      // Transpose copies map directly onto the vtranspose intrinsic.
      return AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("vtranspose"), op->body);
    } else if (op->attr_key == "align_info") {
      // Alignment hints have been consumed by earlier passes; strip them.
      return this->Mutate(op->body);
    }
    return IRMutator::Mutate_(op, s);
  }

  Stmt Mutate_(const Store *op, const Stmt &s) final {
    if (!in_reduce_) {
      return IRMutator::Mutate_(op, s);
    }
    if (old_pragma_ == "reduce_add" || old_pragma_ == "reduce_max" || old_pragma_ == "reduce_min" ||
        old_pragma_ == "reduce_fargmax" || old_pragma_ == "reduce_fargmin") {
      return MakeReduceStore(op);
    }
    // Unrecognized reduce pragma: keep the store untouched.
    return s;
  }

 private:
  // Expands `dst = pack(src)` into `dst = dst <op> src`; the reduction source
  // is carried as the first Call argument of the packed store's value.
  Stmt MakeReduceStore(const Store *op) const {
    const Call *packed = op->value.as<Call>();
    // Fix: the packed value was dereferenced without a check before.
    CHECK(packed != nullptr) << "packed reduce store must wrap its source in a Call";
    CHECK(!packed->args.empty());
    Expr src = packed->args[0];
    Expr dst_load = Load::make(op->value.type(), op->buffer_var, op->index, op->predicate);
    Expr new_value;
    if (old_pragma_ == "reduce_fargmax") {
      new_value = Call::make(src.type(), "fargmax", {dst_load, src}, Call::CallType::PureIntrinsic);
    } else if (old_pragma_ == "reduce_fargmin") {
      new_value = Call::make(src.type(), "fargmin", {dst_load, src}, Call::CallType::PureIntrinsic);
    } else if (old_pragma_ == "reduce_add") {
      new_value = Add::make(dst_load, src);
    } else if (old_pragma_ == "reduce_max") {
      new_value = Max::make(dst_load, src);
    } else {  // reduce_min
      new_value = Min::make(dst_load, src);
    }
    return Store::make(op->buffer_var, new_value, op->index, op->predicate);
  }

  std::string old_pragma_;   // pragma of the scope currently being rewritten
  std::string new_pragma_;   // its "vec_binary_*" replacement
  bool in_reduce_{false};    // fix: was uninitialized -> UB before first reduce scope
};
// Maps a parsed arithmetic op type to its CCE intrinsic op code; returns an
// empty string for op types that have no scalar-intrinsic counterpart.
std::string GetOpCode(const std::string &op_type) {
  if (op_type == "Add") return "vadds";
  if (op_type == "Mul") return "vmuls";
  if (op_type == "vaxpy") return "vaxpy";
  if (op_type == "DMACopy") return "vector_dup";
  return std::string{};
}
// Refines generic emit-insn pragmas into concrete intrinsic pragmas
// (vector_dup / vadds / vmuls / vaxpy / scalar_calc / scalar_dma) based on
// the arithmetic pattern parsed from each instruction scope.
class FinetunePragma : public IRMutator {
 public:
  FinetunePragma() = default;
  ~FinetunePragma() override = default;

  Stmt Mutate_(const AttrStmt *op, const Stmt &s) final {
    if ((op->attr_key == "pragma_emit_insn" && op->value->IsInstance<StringImm>() &&
         !exclude_align_analyze_list.count(op->value.as<StringImm>()->value))) {
      IRInfo info;
      ParserVisitor(info, true).Run(s);
      std::string op_code = GetOpCode(info.arith_info.op_type);
      // Only UB-resident computations with a known op code are rewritten.
      if (!info.arith_info.dst_info.IsUB() || op_code.empty() ||
          (!info.arith_info.src_info.empty() && !info.arith_info.src_info[0].IsUB())) {
        return s;
      }
      // Integer add/mul with a single immediate: cheaper as a scalar calc.
      if (info.arith_info.insn_type == "simd" && info.arith_info.scalar_imm_num == 1 &&
          (op_code == "vmuls" || op_code == "vadds") && !info.arith_info.dst_info.p_store->value.type().is_float()) {
        return AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("scalar_calc"), op->body);
      }
      if (info.arith_info.insn_type == "vector_scalar" || info.arith_info.insn_type == "vector_dump") {
        // Scalar operand is a Load (scalar_type 0); GenStore hoists the loops
        // that index the scalar outside the pragma.
        // NOTE(review): "vector_dump" matches the parser's spelling (cf. the
        // Vector_Dump enum) — presumably means vector_dup; confirm upstream.
        return GenStore(info, op_code, 0);
      } else if (info.arith_info.insn_type == "simd" && info.arith_info.scalar_imm_num > 0) {
        CHECK_EQ(info.arith_info.scalar_imm_num, 1);
        return GenStore(info, op_code, 1);
      } else if (info.arith_info.insn_type == "simd" && info.arith_info.scalar_imm_num == 0 &&
                 info.arith_info.op_type == "DMACopy" && info.arith_info.dst_info.IsUB() &&
                 info.arith_info.src_info.size() == 1 && info.arith_info.src_info[0].IsUB() &&
                 info.arith_info.dst_info.p_store->value.type().is_float()) {
        /// change copy_ub_to_ub (fp16 or fp32) to adds (scalar = 0)
        op_code = "vadds";
        info.arith_info.scalar_imm_num = 1;
        info.arith_info.scalar_imm = FloatImm::make(info.arith_info.dst_info.p_store->value.type(), 0);
        return GenStore(info, op_code, 1);
      } else if (info.arith_info.op_type == "DMACopy" &&
                 (info.arith_info.insn_type == "scalar" || info.arith_info.insn_type == "discrete") &&
                 info.arith_info.dst_info.IsUB() &&
                 (info.arith_info.src_info.size() == 1 && info.arith_info.src_info[0].IsUB())) {
        // UB->UB copy with scalar/discrete access: use scalar DMA.
        return AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr("scalar_dma"), op->body);
      } else if (info.arith_info.op_type == "DMACopy" &&
                 (info.arith_info.insn_type == "scalar" || info.arith_info.insn_type == "discrete") &&
                 info.arith_info.dst_info.IsUB() && info.arith_info.scalar_imm_num == 1) {
        // Immediate broadcast (op_code is "vector_dup" for DMACopy).
        return GenStore(info, op_code, 1);
      } else if (op->value.as<StringImm>()->value == "vec_single_muls" ||
                 op->value.as<StringImm>()->value == "vec_single_adds") {
        if (op->value.as<StringImm>()->value == "vec_single_muls") {
          op_code = "vmuls";
        } else if (op->value.as<StringImm>()->value == "vec_single_adds") {
          op_code = "vadds";
        }
        return GenStore(info, op_code, 1);
      }
      return s;
    }
    return IRMutator::Mutate_(op, s);
  }

  // Emits `dst = intrin(src..., scalar)` wrapped in the recorded loop nest.
  // scalar_type == 0: the scalar comes from scalar_load (a Load); loops that
  // index the scalar are placed OUTSIDE the new pragma.  scalar_type == 1:
  // the scalar is the recorded immediate; all loops stay inside the pragma.
  Stmt GenStore(IRInfo &info, const std::string &intrin_name, const int scalar_type = 0) {
    CHECK(intrin_name == "vector_dup" || intrin_name == "vadds" || intrin_name == "vmuls" || intrin_name == "vaxpy");
    /// scalar value
    Expr scalar_value =
        (scalar_type == 0) ? GetRef<Expr>(info.arith_info.scalar_load.p_load) : info.arith_info.scalar_imm;
    Array<Expr> call_args{};
    if (intrin_name == "vector_dup") {
      call_args = {scalar_value};
    } else {
      Expr tensor_value = GetRef<Expr>(info.arith_info.src_info[0].p_load);
      call_args = {tensor_value, scalar_value};
    }
    /// set store
    auto old_ptr = info.arith_info.dst_info.p_store;
    Expr new_value = Call::make(old_ptr->value.type(), intrin_name, call_args, Call::PureIntrinsic);
    Stmt ret = Store::make(old_ptr->buffer_var, new_value, old_ptr->index, old_ptr->predicate);
    if (scalar_type == 0) {
      auto scalar_vars = info.arith_info.scalar_load.vars;
      /// set inner for loop
      for (int i = static_cast<int>(info.for_info.vars.size()) - 1; i >= 0; --i) {
        if (!IsInArray(scalar_vars, info.for_info.vars[i])) {
          ret = For::make(info.for_info.vars[i], 0, info.for_info.exts[i], ForType::Serial, DeviceAPI::None, ret);
        }
      }
      /// set attribute
      ret = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr(intrin_name), ret);
      /// set outer for loop
      for (int i = static_cast<int>(info.for_info.vars.size()) - 1; i >= 0; --i) {
        if (IsInArray(scalar_vars, info.for_info.vars[i])) {
          ret = For::make(info.for_info.vars[i], 0, info.for_info.exts[i], ForType::Serial, DeviceAPI::None, ret);
        }
      }
      return ret;
    } else {
      for (int i = static_cast<int>(info.for_info.vars.size()) - 1; i >= 0; --i) {
        ret = For::make(info.for_info.vars[i], 0, info.for_info.exts[i], ForType::Serial, DeviceAPI::None, ret);
      }
      ret = AttrStmt::make(make_zero(Int(32)), "pragma_emit_insn", Expr(intrin_name), ret);
      return ret;
    }
  }
};
/// Pass entry point: hoist ifs out of insn scopes, refine generic pragmas
/// into concrete intrinsics, then restore packed reduce stores.
Stmt RecoverStore(Stmt stmt) {
  Stmt ret = IfReorder().Mutate(stmt);
  ret = FinetunePragma().Mutate(ret);
  return ReduceRecover().Mutate(ret);
}
} // namespace ir
} // namespace akg
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册