提交 a8d1a104 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!45 fix variable broadcast_idx redefinition error when pragma dma_copy is replaced by opt_broadcast

Merge pull request !45 from looop5/opt_broadcast
......@@ -86,19 +86,19 @@ Stmt MultiMaskEmitter(const Stmt &stmt) {
int i_loop_elements = LeastCommonMultiple(broadcast_len, block_size);
int mask_num = i_loop_elements / broadcast_len;
auto i_var = VarExpr("broadcast_idx");
Stmt body = Evaluate::make(0);
for (int i = 0; i < std::min(mask_num, data_len); i++) {
auto i_var = VarExpr("broadcast_idx" + std::to_string(i));
auto dst_block_offset = (i * broadcast_len) % block_size;
auto dst_block_cnt = (i * broadcast_len) - dst_block_offset;
Expr base_addr_offset = i_var * i_loop_elements + dst_block_cnt;
auto base_src = src_buffer_id.vload({i_var * mask_num + i}, i_type);
CHECK_NE(i_loop_elements, 0);
auto loop_num = (data_len + mask_num - 1 - i) * broadcast_len / i_loop_elements;
// GenHead
Expr base_addr_offset = i_var * i_loop_elements + dst_block_cnt;
auto base_src = src_buffer_id.vload({i_var * mask_num + i}, i_type);
int head_size = std::min(simd_size - dst_block_offset, broadcast_len);
auto vec_mask_head = GetVecMaskWithOffset(head_size, dst_block_offset, i_type);
auto head_mask = EmitSetVecMaskIntrin(Stmt(), i_type, vec_mask_head);
......@@ -111,30 +111,36 @@ Stmt MultiMaskEmitter(const Stmt &stmt) {
// GenBody
if (dst_block_offset + broadcast_len >= simd_size * 2) {
auto i_var_body = VarExpr(i_var->name_hint + "_body");
Expr base_addr_offset_body = i_var_body * i_loop_elements + dst_block_cnt;
auto base_src_body = src_buffer_id.vload({i_var_body * mask_num + i}, i_type);
int repeat_size = (dst_block_offset + broadcast_len) / simd_size - 1;
auto vec_mask_body = GetVecMaskWithOffset(simd_size, 0, i_type);
auto full_mask = EmitSetVecMaskIntrin(Stmt(), i_type, vec_mask_body);
int body_addr_offset = simd_size;
Expr body_dst = GetAccessPtr(dst_buffer_id, "w", base_addr_offset + body_addr_offset);
Expr body_dst = GetAccessPtr(dst_buffer_id, "w", base_addr_offset_body + body_addr_offset);
auto body_dump =
Evaluate::make(Call::make(i_type, "vector_dup", {body_dst, base_src, repeat_size, 1, 1, 8, 8}, Call::Extern));
Evaluate::make(Call::make(i_type, "vector_dup", {body_dst, base_src_body, repeat_size, 1, 1, 8, 8}, Call::Extern));
auto body_gen = Block::make({full_mask, body_dump});
auto body_stmt = For::make(i_var, Expr(0), Expr(loop_num), ForType::Serial, DeviceAPI::None, body_gen);
auto body_stmt = For::make(i_var_body, Expr(0), Expr(loop_num), ForType::Serial, DeviceAPI::None, body_gen);
ret_stmt = Block::make(ret_stmt, body_stmt);
}
// GenTail
if ((dst_block_offset + broadcast_len) % simd_size != 0 && dst_block_offset + broadcast_len > simd_size) {
auto i_var_tail = VarExpr(i_var->name_hint + "_tail");
Expr base_addr_offset_body = i_var_tail * i_loop_elements + dst_block_cnt;
auto base_src_body = src_buffer_id.vload({i_var_tail * mask_num + i}, i_type);
int tail_size = (dst_block_offset + broadcast_len) % simd_size;
auto vec_mask_tail = GetVecMaskWithOffset(tail_size, 0, i_type);
auto tail_mask = EmitSetVecMaskIntrin(Stmt(), i_type, vec_mask_tail);
int tail_addr_offset = dst_block_offset + broadcast_len - tail_size;
Expr tail_dst = GetAccessPtr(dst_buffer_id, "w", base_addr_offset + tail_addr_offset);
Expr tail_dst = GetAccessPtr(dst_buffer_id, "w", base_addr_offset_body + tail_addr_offset);
auto tail_dump =
Evaluate::make(Call::make(i_type, "vector_dup", {tail_dst, base_src, 1, 1, 1, 1, 1}, Call::Extern));
Evaluate::make(Call::make(i_type, "vector_dup", {tail_dst, base_src_body, 1, 1, 1, 1, 1}, Call::Extern));
auto tail = Block::make({tail_mask, tail_dump});
auto tail_stmt = For::make(i_var, Expr(0), Expr(loop_num), ForType::Serial, DeviceAPI::None, tail);
auto tail_stmt = For::make(i_var_tail, Expr(0), Expr(loop_num), ForType::Serial, DeviceAPI::None, tail);
ret_stmt = Block::make(ret_stmt, tail_stmt);
}
......
......@@ -138,7 +138,11 @@ class EstimateAlign : public IRMutator {
public:
bool IsSimpleAddress(const Stmt &stmt) {
Mutate(stmt);
return all_simple_addressing_;
// Returns true only when the numbers of Store in IR that is not elementwise
// is only 1 or less, in this case, we can consider optimizing broadcast by
// using variable length mask in insn emitting pass safely because at most
// 1 Store does not need to cosider block alignment.
return (not_simple_addressing_cnt_ < 2);
}
Stmt Mutate_(const AttrStmt *op, const Stmt &stmt) final {
......@@ -146,22 +150,19 @@ class EstimateAlign : public IRMutator {
if (exclude_list.count(op->value.as<StringImm>()->value)) {
return stmt;
}
is_candidate_ = true;
StmtInfoList dst_info_list, src_info_list;
StmtInfo if_info, for_info;
GetCompactComputationInfo(op->body, dst_info_list, src_info_list, if_info, for_info, false);
if (!src_info_list.empty() && !IsElementwise(dst_info_list, src_info_list)) {
all_simple_addressing_ = false;
not_simple_addressing_cnt_++;
}
is_candidate_ = false;
}
return IRMutator::Mutate_(op, stmt);
}
bool is_candidate_{false};
bool all_simple_addressing_{true};
int not_simple_addressing_cnt_{0}; // records the number of stores that are not elementwise
};
Stmt OptimizePragma(Stmt stmt) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册