diff --git a/src/emit_insn/insn_emitter_multimask.cc b/src/emit_insn/insn_emitter_multimask.cc index 4ce181b5f97e950e3376b20950e3e1341022babb..2c737eec0983b8fe9f0ca9f9590401ee1f839e36 100644 --- a/src/emit_insn/insn_emitter_multimask.cc +++ b/src/emit_insn/insn_emitter_multimask.cc @@ -86,19 +86,19 @@ Stmt MultiMaskEmitter(const Stmt &stmt) { int i_loop_elements = LeastCommonMultiple(broadcast_len, block_size); int mask_num = i_loop_elements / broadcast_len; - auto i_var = VarExpr("broadcast_idx"); Stmt body = Evaluate::make(0); for (int i = 0; i < std::min(mask_num, data_len); i++) { + auto i_var = VarExpr("broadcast_idx" + std::to_string(i)); auto dst_block_offset = (i * broadcast_len) % block_size; auto dst_block_cnt = (i * broadcast_len) - dst_block_offset; - Expr base_addr_offset = i_var * i_loop_elements + dst_block_cnt; - auto base_src = src_buffer_id.vload({i_var * mask_num + i}, i_type); CHECK_NE(i_loop_elements, 0); auto loop_num = (data_len + mask_num - 1 - i) * broadcast_len / i_loop_elements; // GenHead + Expr base_addr_offset = i_var * i_loop_elements + dst_block_cnt; + auto base_src = src_buffer_id.vload({i_var * mask_num + i}, i_type); int head_size = std::min(simd_size - dst_block_offset, broadcast_len); auto vec_mask_head = GetVecMaskWithOffset(head_size, dst_block_offset, i_type); auto head_mask = EmitSetVecMaskIntrin(Stmt(), i_type, vec_mask_head); @@ -111,30 +111,36 @@ Stmt MultiMaskEmitter(const Stmt &stmt) { // GenBody if (dst_block_offset + broadcast_len >= simd_size * 2) { + auto i_var_body = VarExpr(i_var->name_hint + "_body"); + Expr base_addr_offset_body = i_var_body * i_loop_elements + dst_block_cnt; + auto base_src_body = src_buffer_id.vload({i_var_body * mask_num + i}, i_type); int repeat_size = (dst_block_offset + broadcast_len) / simd_size - 1; auto vec_mask_body = GetVecMaskWithOffset(simd_size, 0, i_type); auto full_mask = EmitSetVecMaskIntrin(Stmt(), i_type, vec_mask_body); int body_addr_offset = simd_size; - Expr body_dst = 
GetAccessPtr(dst_buffer_id, "w", base_addr_offset + body_addr_offset); + Expr body_dst = GetAccessPtr(dst_buffer_id, "w", base_addr_offset_body + body_addr_offset); auto body_dump = - Evaluate::make(Call::make(i_type, "vector_dup", {body_dst, base_src, repeat_size, 1, 1, 8, 8}, Call::Extern)); + Evaluate::make(Call::make(i_type, "vector_dup", {body_dst, base_src_body, repeat_size, 1, 1, 8, 8}, Call::Extern)); auto body_gen = Block::make({full_mask, body_dump}); - auto body_stmt = For::make(i_var, Expr(0), Expr(loop_num), ForType::Serial, DeviceAPI::None, body_gen); + auto body_stmt = For::make(i_var_body, Expr(0), Expr(loop_num), ForType::Serial, DeviceAPI::None, body_gen); ret_stmt = Block::make(ret_stmt, body_stmt); } // GenTail if ((dst_block_offset + broadcast_len) % simd_size != 0 && dst_block_offset + broadcast_len > simd_size) { + auto i_var_tail = VarExpr(i_var->name_hint + "_tail"); + Expr base_addr_offset_body = i_var_tail * i_loop_elements + dst_block_cnt; + auto base_src_body = src_buffer_id.vload({i_var_tail * mask_num + i}, i_type); int tail_size = (dst_block_offset + broadcast_len) % simd_size; auto vec_mask_tail = GetVecMaskWithOffset(tail_size, 0, i_type); auto tail_mask = EmitSetVecMaskIntrin(Stmt(), i_type, vec_mask_tail); int tail_addr_offset = dst_block_offset + broadcast_len - tail_size; - Expr tail_dst = GetAccessPtr(dst_buffer_id, "w", base_addr_offset + tail_addr_offset); + Expr tail_dst = GetAccessPtr(dst_buffer_id, "w", base_addr_offset_body + tail_addr_offset); auto tail_dump = - Evaluate::make(Call::make(i_type, "vector_dup", {tail_dst, base_src, 1, 1, 1, 1, 1}, Call::Extern)); + Evaluate::make(Call::make(i_type, "vector_dup", {tail_dst, base_src_body, 1, 1, 1, 1, 1}, Call::Extern)); auto tail = Block::make({tail_mask, tail_dump}); - auto tail_stmt = For::make(i_var, Expr(0), Expr(loop_num), ForType::Serial, DeviceAPI::None, tail); + auto tail_stmt = For::make(i_var_tail, Expr(0), Expr(loop_num), ForType::Serial, DeviceAPI::None, tail); 
ret_stmt = Block::make(ret_stmt, tail_stmt); } diff --git a/src/pass/optimize_pragma.cc b/src/pass/optimize_pragma.cc index 14877fb2162df39da0d8640eccb3dede45a93ceb..a3bb32610b8e0bab7baaa4a52e63a404f1aa909b 100644 --- a/src/pass/optimize_pragma.cc +++ b/src/pass/optimize_pragma.cc @@ -138,7 +138,11 @@ class EstimateAlign : public IRMutator { public: bool IsSimpleAddress(const Stmt &stmt) { Mutate(stmt); - return all_simple_addressing_; + // Returns true only when the number of Stores in IR that are not elementwise + // is 1 or less; in this case, we can consider optimizing broadcast by + // using variable length mask in insn emitting pass safely because at most + // 1 Store does not need to consider block alignment. + return (not_simple_addressing_cnt_ < 2); } Stmt Mutate_(const AttrStmt *op, const Stmt &stmt) final { @@ -146,22 +150,19 @@ class EstimateAlign : public IRMutator { if (exclude_list.count(op->value.as()->value)) { return stmt; } - is_candidate_ = true; StmtInfoList dst_info_list, src_info_list; StmtInfo if_info, for_info; GetCompactComputationInfo(op->body, dst_info_list, src_info_list, if_info, for_info, false); if (!src_info_list.empty() && !IsElementwise(dst_info_list, src_info_list)) { - all_simple_addressing_ = false; + not_simple_addressing_cnt_++; } - is_candidate_ = false; } return IRMutator::Mutate_(op, stmt); } - bool is_candidate_{false}; - bool all_simple_addressing_{true}; + int not_simple_addressing_cnt_{0}; // records the number of stores that are not elementwise }; Stmt OptimizePragma(Stmt stmt) {