// Copyright (c) 2022 CINN Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/cinn/ir/ir_schedule_util.h" #include #include #include #include #include #include #include #include "paddle/cinn/common/cas.h" #include "paddle/cinn/common/ir_util.h" #include "paddle/cinn/ir/collect_ir_nodes.h" #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_operators.h" #include "paddle/cinn/ir/ir_printer.h" #include "paddle/cinn/ir/ir_visitor.h" #include "paddle/cinn/lang/compute.h" #include "paddle/cinn/optim/ir_copy.h" #include "paddle/cinn/optim/ir_simplify.h" #include "paddle/cinn/optim/replace_var_with_expr.h" namespace cinn { namespace ir { Tensor GetTensor(const Expr& block) { CHECK(block.As()); auto find_tensor = ir::CollectIRNodesWithoutTensor( block, [&](const Expr* x) { return x->As(); }, true); CHECK_EQ(find_tensor.size(), 1U) << "One block should only have one Store node!(except for root block)"; CHECK((*find_tensor.begin()).As()->tensor.as_tensor()); Tensor tensor = (*find_tensor.begin()).As()->tensor.as_tensor_ref(); return tensor; } Tensor GetReadTensor(const Expr& block, int index) { CHECK(block.As()); auto find_tensor = ir::CollectIRNodesWithoutTensor( block, [&](const Expr* x) { return x->As(); }, true); CHECK_EQ(find_tensor.size(), 1U) << "One block should only have one Store node!(except for root block)"; std::vector res; auto find_read_tensor = ir::CollectIRNodesWithoutTensor(block, [&](const Expr* x) { if (x->As()) res.push_back(x->As()->tensor.as_tensor_ref()); return x->As(); }); CHECK_EQ(find_read_tensor.size(), res.size()); CHECK(!find_read_tensor.empty()) << "Didn't find Load tensor in block!"; CHECK_LT(index, (int)find_read_tensor.size()) << "Index is not < read tensor's size!"; return res[index]; } int GetLoopExtent(const Expr& loop) { CHECK(loop.As()); CHECK(common::is_zero(loop.As()->min)); CHECK(loop.As()->extent.is_constant()); return (int)loop.As()->extent.get_constant(); } void SetCudaAxisInfo(Expr* lowered_func) { if (!lowered_func->as_lowered_func()) { LOG(ERROR) << "The input of SetCudaAxisInfo should be lowered_func!"; return; } auto func_body = lowered_func->as_lowered_func_ref()->body; CudaAxisInfo info; auto block_nodes = ir::CollectIRNodes(func_body, [&](const Expr* x) { if (x->As() && x->As()->bind_info().valid()) { auto bind_info = x->As()->bind_info(); info.set_valid(true); if (bind_info.for_type == ForType::GPUThread) { CHECK(common::is_zero(x->As()->min)); CHECK(x->As()->extent.is_constant()); int range = x->As()->extent.get_constant(); range = range > info.block_dim(bind_info.offset) ? range : info.block_dim(bind_info.offset); VLOG(3) << "Set block dim[" << bind_info.offset << "] with range " << range; info.set_block_dim(bind_info.offset, range); } else if (bind_info.for_type == ForType::GPUBlock) { CHECK(common::is_zero(x->As()->min)); CHECK(x->As()->extent.is_constant()); int range = x->As()->extent.get_constant(); range = range > info.grid_dim(bind_info.offset) ? range : info.grid_dim(bind_info.offset); info.set_grid_dim(bind_info.offset, range); VLOG(3) << "Set grid dim[" << bind_info.offset << "] with range " << range; } else { LOG(FATAL) << "The for loop's bind info should be gpu block or thread!"; } } return (x->As() && x->As()->bind_info().valid()); }); lowered_func->as_lowered_func_ref()->cuda_axis_info = info; } bool Contains(const Expr& container, const Expr& expr) { auto find_expr = ir::CollectIRNodesWithoutTensor( container, [&](const Expr* x) { return (x->node_type() == expr.node_type() && *x == expr); }, true); return (!find_expr.empty()); } Expr GetNextForLoop(const Expr& for_loop) { Expr result; CHECK(for_loop.As()) << "The input of GetNextForLoop should be ir::For!"; Expr for_body = for_loop.As()->body; ir::Block* for_body_block = for_body.As(); CHECK(for_body_block) << "The for_loop's body shoule be Block!"; // Only support for body block contains a sub for loop int next_idx = -1; for (int i = 0; i < for_body_block->stmts.size(); ++i) { Expr stmt = for_body_block->stmts[i]; if (stmt.As() || stmt.As()) { if (next_idx == -1) { next_idx = i; } else { // More then one sub for loop, Return undefined. return result; } } } if (next_idx == -1) { // More then one sub for loop, Return undefined. return result; } Expr block_body = for_body_block->stmts[next_idx]; if (block_body.As()) { // TODO(zhhsplendid): is it right to only handle true case? // It may be wrong, but the code is written by previous developer, for us, // we will check it later in the future. CHECK(block_body.As()->true_case.As()); Expr true_case = block_body.As()->true_case; if (true_case.As()->stmts.size() != 1U || !true_case.As()->stmts[0].As()) return result; result = true_case.As()->stmts[0]; return result; } else if (block_body.As()) { return block_body; } else { return result; } } std::vector GetIfThenElseInRange(const Expr& top, const Expr& bottom) { std::vector if_nodes; CHECK(top.As()); CHECK(bottom.As()); for (auto loop_iter = top; loop_iter != bottom;) { CHECK(loop_iter.As()); CHECK(loop_iter.As()->body.As()) << "For node's body should be Block!"; auto block = loop_iter.As()->body.As(); for (Expr tmp : block->stmts) { if (tmp.As()) { if_nodes.push_back(tmp); CHECK(tmp.As()->true_case.As()); Expr true_case = tmp.As()->true_case; CHECK(true_case.As()->stmts.size() == 1U && true_case.As()->stmts[0].As()); tmp = true_case.As()->stmts[0]; } if (tmp.As()) { loop_iter = tmp; } } } return if_nodes; } void ReplaceExpr(Expr* source, const std::vector& replaced, const std::vector& candidates) { CHECK_EQ(replaced.size(), candidates.size()) << "In ReplaceExpr, the size of Vars to be replaced must be equal to the size of cadidate Exprs! Please check."; if (replaced.empty()) return; std::map replacing_map; for (int i = 0; i < replaced.size(); ++i) { // If the Var to be replaced is equal to the candidate, we skip it. if (candidates[i].is_var() && candidates[i].as_var_ref() == replaced[i]) continue; replacing_map[replaced[i]] = candidates[i]; } MappingVarToExprMutator mapper(replacing_map); mapper(source); return; } std::vector ValidateFactors(const std::vector& factors, int total_extent) { CHECK(!factors.empty()) << "The factors param of Split should not be empty! Please check."; bool has_minus_one = false; int product = 1; for (auto& i : factors) { CHECK(i != 0) << "The params in factors of Split should not be 0! Please check."; CHECK(i >= -1) << "The params in factors of Split should not be less than -1! Please check."; if (i == -1) { CHECK(!has_minus_one) << "The params in factors of Split should not have more than one -1! Please check."; has_minus_one = true; } else { product *= i; } } std::vector validated_factors = factors; if (!has_minus_one) { CHECK_GE(product, total_extent) << "In Split, the factors' product should be equal to original loop's extent! Please check."; return validated_factors; } else { CHECK_LE(product, total_extent) << "In Split, when there is -1 in factors, the other factors' product should be <= " "original loop's extent! Please check."; int minus_one_candidate = (int)ceil((double)total_extent / (double)product); for (int i = 0; i < validated_factors.size(); ++i) { if (validated_factors[i] == -1) { validated_factors[i] = minus_one_candidate; } } return validated_factors; } } void CHECKRfactorValidation(const Expr& rf_loop, int rf_axis) { auto* rf_for = rf_loop.As(); CHECK(rf_for) << "Expr param of Rfactor must be For node! Please check."; // check the rf_loop only has one schedule block auto block_nodes = ir::CollectIRNodesWithoutTensor( rf_loop, [&](const Expr* x) { return x->As(); }, true); CHECK_EQ(block_nodes.size(), 1U) << "Rfactor Loop should only have one schedule block"; auto find_store = ir::CollectIRNodesWithoutTensor( rf_loop, [&](const Expr* x) { return x->As(); }, true); CHECK_EQ(find_store.size(), 1U); auto indice = find_store.begin()->As()->indices; // check rf_axis CHECK_LE(rf_axis, indice.size()) << "rf_axis should not be greater than store's domain size"; // check rfactor loop is reduce auto* sch_block_realize = block_nodes.begin()->As(); auto* sch_block = sch_block_realize->schedule_block.As(); CHECK(sch_block); auto& iter_values = sch_block_realize->iter_values; auto& iter_vars = sch_block->iter_vars; CHECK_EQ(iter_values.size(), iter_vars.size()); auto rf_loop_var = rf_for->loop_var; Var rf_block_var; for (int i = 0; i < iter_values.size(); ++i) { if (ContainVar({iter_values[i]}, rf_loop_var->name)) { CHECK(!rf_block_var.defined()) << "rfactor loop var can only be binded to one block var"; auto iter_value = iter_values[i].As<_Var_>(); CHECK(iter_value) << "not support complex reduce bindings"; rf_block_var = iter_vars[i]; auto it = std::find_if(indice.begin(), indice.end(), [&](const Expr& x) { return x.As<_Var_>() && x.As<_Var_>()->name == rf_block_var->name; }); CHECK(it == indice.end()) << "rfactor loop var is not reduce, please check!"; } } } std::vector GetLoopsOfExpr(const Expr& expr, const Expr& root) { auto loop_nodes = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) { return x->As() && Contains(*x, expr); }); std::vector result(loop_nodes.begin(), loop_nodes.end()); if (result.empty()) LOG(FATAL) << "Didn't find expr's : \n" << expr << "\n loops in root : \n" << root; std::sort(result.begin(), result.end(), [&](Expr i, Expr j) { return (utils::GetStreamCnt(i).size() > utils::GetStreamCnt(j).size()); }); return result; } IterRange GetAccessedRange(const Expr& index, const std::vector& iter_vars, const std::vector& iter_ranges) { CHECK_EQ(iter_vars.size(), iter_ranges.size()); std::vector var_mins, var_maxs; for (const auto& range : iter_ranges) { var_mins.emplace_back(range.min); var_maxs.emplace_back(range.min + range.extent - 1); } Expr indice_min = optim::IRCopy(index); Expr indice_max = optim::IRCopy(index); // replace the var by the corresponding iter_value ReplaceExpr(&indice_min, iter_vars, var_mins); ReplaceExpr(&indice_max, iter_vars, var_maxs); // simplify expression indice_min = common::AutoSimplify(indice_min); indice_max = common::AutoSimplify(indice_max); Expr indice_extent; Expr mod_extent(0); if (indice_min.As() && indice_min.As()->b().is_constant()) mod_extent = indice_min.As()->b(); if (indice_min == indice_max) { if (common::is_zero(mod_extent)) { // If a index keeps constant, its extent should be 1. indice_extent = Expr(1); } else { indice_extent = mod_extent; } } else { indice_extent = common::AutoSimplify(common::AutoSimplify(indice_max) - common::AutoSimplify(indice_min) + 1); } if (indice_extent.is_constant() && indice_extent.get_constant() < 0) { VLOG(3) << "deduced indices are not constant"; indice_min = indice_max; indice_extent = Expr(-indice_extent.get_constant()); } VLOG(3) << "indice_min=" << indice_min << ", indice_max=" << indice_max << ", indice_extent=" << indice_extent; return IterRange(indice_min, indice_extent); } std::vector CalculateTensorRegions(const Expr& block, const std::vector& tensor_indices, const Tensor& tensor, const Expr& root) { CHECK(block.As()); auto iter_vars = block.As()->schedule_block.As()->iter_vars; auto iter_values = block.As()->iter_values; std::vector loop_vars; std::vector loop_ranges; auto outer_loops = GetLoopsOfExpr(block, root); for (auto& loop : outer_loops) { CHECK(loop.As()); loop_vars.emplace_back(loop.As()->loop_var); loop_ranges.emplace_back(IterRange(loop.As()->min, loop.As()->extent)); } std::vector result; for (int i = 0; i < tensor_indices.size(); ++i) { Expr binded_index = optim::IRCopy(tensor_indices[i]); ReplaceExpr(&binded_index, iter_vars, iter_values); auto range = GetAccessedRange(binded_index, loop_vars, loop_ranges); // in generally, the range should be constant, but in some cases our AutoSimplify // (algebraic simplification function) can't simplify completely where we use the whole // shape in this indice as the accessed range conservatively if (!range.min.is_constant() || !range.extent.is_constant()) { VLOG(3) << "deduced range is not constant, range.min=" << range.min << ", range.extent=" << range.extent; if (tensor->buffer.defined()) { CHECK_GT((int)tensor->buffer->shape.size(), i); result.emplace_back(IterRange(Expr(0), tensor->buffer->shape[i])); } else { CHECK_GT((int)tensor->shape.size(), i); result.emplace_back(IterRange(Expr(0), tensor->shape[i])); } } else { result.emplace_back(std::move(range)); } } return result; } Expr GetNthAccessExpr(const Expr& block, int index, bool is_write) { CHECK(block.As()); auto compute_body = block.As()->schedule_block.As()->body; if (is_write) { std::vector find_store_vec; auto find_store = ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) { if (x->As()) find_store_vec.push_back(*x); return x->As(); }); CHECK_EQ(find_store.size(), find_store_vec.size()); CHECK_LT(index, (int)find_store.size()); Expr store_index = find_store_vec[index]; return store_index; } else { std::vector find_load_vec; auto find_load = ir::CollectIRNodesWithoutTensor(compute_body, [&](const Expr* x) { if (x->As()) find_load_vec.push_back(*x); return x->As(); }); CHECK_EQ(find_load.size(), find_load_vec.size()); CHECK_LT(index, (int)find_load.size()); Expr load_index = find_load_vec[index]; return load_index; } } Tensor MakeCacheTensor(const Tensor& tensor, const std::string& memory_type) { auto cache_tensor = lang::Compute( tensor->shape, [=](const std::vector& dims) { return tensor(dims); }, tensor->name + "_" + memory_type + "_temp_buffer"); cache_tensor->WithBuffer(memory_type); return cache_tensor; } Expr MakeCacheBlock(const std::vector& buffer_ranges, CacheBlockInfo* info, const std::string& memory_type, DeviceAPI device_api) { // loop variables std::vector loop_vars; // bindings in block realize std::vector iter_values; // Create loop vars and block vars' binding_value for (const auto& range : buffer_ranges) { Var loop_var(common::UniqName("cache_ax" + std::to_string(loop_vars.size()))); // Var loop_var("ax" + std::to_string(loop_vars.size())); loop_vars.push_back(loop_var); iter_values.push_back(common::AutoSimplify(range.min + loop_var)); } // block variables std::vector block_vars; Tensor new_tensor = info->alloc; // Create block vars, block's accessed region and accessing indices CHECK(new_tensor->buffer.defined()); for (auto& dim : new_tensor->buffer->shape) { Var var(Expr(0), dim, "v" + std::to_string(block_vars.size()), false); block_vars.push_back(var); } auto body = new_tensor->tensor_store_expanded_body(); std::vector axis_vars = common::GenDefaultAxis(new_tensor->domain.size()); axis_vars.insert(axis_vars.end(), new_tensor->reduce_axis.begin(), new_tensor->reduce_axis.end()); for (int i = 0; i < axis_vars.size(); ++i) { optim::ReplaceVarWithExpr(&body, axis_vars[i], block_vars[i]); } Expr block = ir::ScheduleBlockRealize::Make( iter_values, ir::ScheduleBlock::Make(block_vars, {}, {}, new_tensor->name, Block::Make({body}))); Expr new_body = block; for (int i = (int)loop_vars.size() - 1; i >= 0; i--) { new_body = For::Make(loop_vars[i], Expr(0), common::AutoSimplify(buffer_ranges[i].extent), ir::ForType::Serial, device_api, ir::Block::Make({new_body})); } info->cache_block = std::move(new_body); return block; } void FindInsertionPoint(Expr& root, CacheBlockInfo* info, bool is_write) { Expr find_tensor = is_write ? Expr(info->write_tensor) : Expr(info->read_tensor); auto find_produce_read = ir::CollectIRNodesWithoutTensor( root, [&](const Expr* x) { return x->As() && x->As()->tensor == find_tensor; }); if (find_produce_read.empty()) { CHECK(root.As()->schedule_block.As()); CHECK(root.As()->schedule_block.As()->body.As()); info->loc_block = root.As()->schedule_block.As()->body; info->loc_pos = 0; return; } CHECK_EQ(find_produce_read.size(), 1U); Expr producer = *(find_produce_read.begin()); CHECK(root.As()->schedule_block.As()); CHECK(root.As()->schedule_block.As()->body.As()); info->loc_block = root.As()->schedule_block.As()->body; for (int i = 0; i < (int)info->loc_block.As()->stmts.size(); ++i) { if (Contains(info->loc_block.As()->stmts[i], producer)) { info->loc_pos = i + 1; break; } } } const std::set CollectLoopsToSet(const std::vector& loops) { std::set for_loops; for (auto& i : loops) { CHECK(i.As()) << "loops should be For node! Please check."; auto inserted = for_loops.insert(i); if (!inserted.second) { LOG(FATAL) << "There should be no duplicate elements in loops! Please check."; } } return for_loops; } // This function is used in Reorder schedule primitive. Since input loop // Expr(s) of Reorder doesn't give original for loop order, we have to // find the top (most outter) loop and bottom (most inner) among loop Expr(s) std::pair GetBoundaryOfReorderRange(const std::set& loop_set) { Expr top = *loop_set.begin(); Expr bottom; std::set visited; bool first_traversal = true; for (Expr loop_i : loop_set) { if (visited.count(loop_i)) { continue; } Expr v_for = loop_i; CHECK(v_for.As()); while (v_for.defined()) { // If loop_i's sub loop is visited it must be pre-visited top. // Then loop_i should be the new top if (visited.count(v_for)) { if (v_for != top) { LOG(FATAL) << "Loops in GetBoundaryOfReorderRange is not a chain! Please check."; } top = loop_i; break; } // This while loop always GetNextForLoop(sub loop), so the last // visited v_for in the first traversal will be the bottom. if (first_traversal && loop_set.count(v_for)) { bottom = v_for; } visited.insert(v_for); v_for = GetNextForLoop(v_for); } first_traversal = false; } CHECK(top.As()); CHECK(bottom.defined()); CHECK(bottom.As()); return std::make_pair(top, bottom); } std::vector GetLoopsInRange(const Expr& top, const Expr& bottom) { std::vector chain; CHECK(top.As()); CHECK(bottom.As()); for (auto loop_iter = top; loop_iter != bottom;) { Expr tmp = GetNextForLoop(loop_iter); if (!tmp.defined()) LOG(FATAL) << "Loops in GetLoopsInReorderRange is not a chain! Please check."; chain.push_back(loop_iter); loop_iter = tmp; } chain.push_back(bottom); return chain; } // Construct a loop chain such that: // // loops[i_1] { // loops[i_2] { // ... // loops[i_n] { // stmts; // } // } // } // // where reordered_indices = {i_1, i_2, ... i_n } // // This is a helper function which constructs non-main chain for other body // statements in Reorder. See comment and call place in ConstructNewLoopChain Expr ConstructOtherStmtChain(const std::vector& stmts, const std::vector& loops, const std::vector reordered_indices) { Expr new_loop; for (int i = reordered_indices.size() - 1; i >= 0; --i) { Expr temp = optim::IRCopy(loops[reordered_indices[i]]); CHECK(temp.defined()); CHECK(temp.As()); if (new_loop.defined()) { temp.As()->body = Block::Make({new_loop}); } else { temp.As()->body = Block::Make({stmts}); } new_loop = temp; } return new_loop; } Expr ConstructNewLoopChain(const std::vector& chain, const std::vector& ordered_loops, const std::set& loop_set, std::vector& if_nodes) { std::vector> condition_vars; // In each IfThenElse node, find the vars its condition depends on. for (auto& if_expr : if_nodes) { CHECK(if_expr.As()); auto var_set = ir::CollectIRNodes(if_expr.As()->condition, [&](const Expr* x) { return x->as_var(); }); std::set var_name_set; for (auto& i : var_set) var_name_set.insert(i.as_var()->name); condition_vars.push_back(var_name_set); } Expr new_loop; int index = static_cast(ordered_loops.size()) - 1; std::vector reordered_loop_chain; // Construct the main loop chain from bottom to top. for (int i = static_cast(chain.size()) - 1; i >= 0; i--) { auto& loop_in_chain = chain[i]; CHECK(loop_in_chain.As()); Expr temp; if (loop_set.count(loop_in_chain)) { CHECK_GE(index, 0); temp = optim::IRCopy(ordered_loops[index]); --index; } else { temp = optim::IRCopy(loop_in_chain); } CHECK(temp.defined()); CHECK(temp.As()); // Main chain, each loop's body only contains sub_loop or bottom loop's body if (new_loop.defined()) { temp.As()->body = Block::Make({new_loop}); } else { temp.As()->body = loop_in_chain.As()->body; } Expr original_temp = temp; // Here we handle the IfThenElse nodes. for (int i = 0; i < static_cast(if_nodes.size()); ++i) { if (condition_vars[i].count(original_temp.As()->loop_var->name)) { Expr temp_body = temp.As()->body; if (temp_body.As() && temp_body.As()->stmts.size() == 1U) temp_body = temp_body.As()->stmts[0]; temp.As()->body = IfThenElse::Make( if_nodes[i].As()->condition, temp_body, if_nodes[i].As()->false_case); temp.As()->body = Block::Make({temp.As()->body}); if_nodes.erase(if_nodes.begin() + i); condition_vars.erase(condition_vars.begin() + i); i--; } } new_loop = temp; reordered_loop_chain.push_back(new_loop); } CHECK(new_loop.defined()); // new_loop_chain, which represents the main loop chain, now is from top to bottom. std::reverse(reordered_loop_chain.begin(), reordered_loop_chain.end()); // In the main loop chain, each loop's body only contains sub_loop or bottom // loop's body, but the origin loop chain may contain some other body stmts. // The main loop chain lost those other body stmts. // For example: // // for (i, 0, 32) { Reorder j, i for (j, 0, 64) { // other_body_stmts above main chine // for (j, 0, 64) { ------------------> for (i, 0, 32) { // bottom_loop_body bottom_loop_body // } } // } } // // We go throuph origin loop and check other body stmts, adding it as another // chain, such as: // // for (i, 0, 32) { // other_body_stmts // } // for (j, 0, 64) { // for (i, 0, 32) { // bottom_loop_body // } // } // // Construct the complete loop chain from origin loop top to bottom. CHECK_EQ(chain.size(), reordered_loop_chain.size()) << "origin loop chain size not equals reordered requirement when ConstructNewLoopChain in Reorder"; std::unordered_set origin_loop_var_names; Expr ret = new_loop; // Maintain an index to add stmt (other body stmt chain) // // stmt stmt MainChainLoop stmt stmt // index index+1 // // The index of this MainChainLoop points the place before next MainChainLoop // We can insert statements before MainChainLoop at the index, and insert // statements after MainChainLoop at the index + 1 int add_other_chain_index = 0; for (int i = 0; i < chain.size() - 1; ++i) { // we just check i < chain.size() - 1 // because bottom loop's body stmts have been all added const ir::For* loop_in_chain = chain[i].As(); ir::For* reordered_in_chain = reordered_loop_chain[i].As(); origin_loop_var_names.insert(loop_in_chain->loop_var->name); CHECK_EQ(origin_loop_var_names.size(), i + 1) << "Duplicate loop var name in origin Chain during Reorder"; const ir::Block* body_block = loop_in_chain->body.As(); if (body_block != nullptr && body_block->stmts.size() > 1) { // contains other body stmts // Get the other body statements before loop and after loop bool other_stmt_body_before_loop = true; std::vector stmts_before_loop; std::vector stmts_after_loop; for (int j = 0; j < body_block->stmts.size(); ++j) { if (body_block->stmts[j].As() && body_block->stmts[j].As()->loop_var->name == chain[i + 1].As()->loop_var->name) { other_stmt_body_before_loop = false; continue; } if (other_stmt_body_before_loop) { stmts_before_loop.push_back(body_block->stmts[j]); } else { stmts_after_loop.push_back(body_block->stmts[j]); } } // Find the chain that other body stmts shares with main loop chain std::vector reordered_indices; for (int j = 0; j < reordered_loop_chain.size(); ++j) { if (origin_loop_var_names.count(reordered_loop_chain[j].As()->loop_var->name)) { reordered_indices.push_back(j); } } CHECK_EQ(reordered_indices.size(), origin_loop_var_names.size()) << "Reordered chain loop var names doesn't match other stmt chain loop var names"; // Add other stmts chain to root Block if other stmts exist if (!stmts_before_loop.empty()) { Expr before_chain = ConstructOtherStmtChain(stmts_before_loop, reordered_loop_chain, reordered_indices); if (ret.As() == nullptr) { ret = ir::Block::Make({ret}); } std::vector& inplace_stmts = ret.As()->stmts; auto pos = inplace_stmts.begin() + add_other_chain_index; inplace_stmts.insert(pos, before_chain); ++add_other_chain_index; } if (!stmts_after_loop.empty()) { Expr after_chain = ConstructOtherStmtChain(stmts_after_loop, reordered_loop_chain, reordered_indices); if (ret.As() == nullptr) { ret = ir::Block::Make({ret}); } std::vector& inplace_stmts = ret.As()->stmts; auto pos = inplace_stmts.begin() + add_other_chain_index + 1; inplace_stmts.insert(pos, after_chain); } } } return ret; } std::vector GetProducers(const Expr& block, const Expr& root) { CHECK(block.As()); CHECK(root.As()); std::vector producers; // collect all producers' tensor names std::set producer_tensor_names; auto compute_body = block.As()->schedule_block.As()->body; ir::CollectIRNodesWithoutTensor(compute_body, [&producer_tensor_names](const Expr* x) { auto* load = x->As(); if (load) { producer_tensor_names.insert(load->tensor.as_tensor()->name); return true; } return false; }); // traverse each of other blocks and filter those ones which contain at least one producer tensor; auto find_blocks = ir::CollectIRNodesWithoutTensor( root, [&block, &root](const Expr* x) { return x->As() && *x != block && *x != root; }); for (auto&& cur : find_blocks) { auto* cur_block = cur.As()->schedule_block.As(); CHECK(cur_block) << "block result should be a ScheduleBlockRealize"; auto find_stores = ir::CollectIRNodesWithoutTensor(cur_block->body, [&producer_tensor_names](const Expr* x) { return x->As() && producer_tensor_names.count(x->As()->tensor.as_tensor()->name) > 0; }); if (!find_stores.empty()) producers.emplace_back(cur); } return producers; } std::vector GetConsumers(const Expr& block, const Expr& root) { CHECK(block.As()); CHECK(root.As()); std::vector consumers; std::string block_tensor = GetTensor(block)->name; auto find_block = ir::CollectIRNodesWithoutTensor( root, [&](const Expr* x) { return x->As() && *x != block && *x != root; }); for (auto& i : find_block) { CHECK(i.As()->schedule_block.As()); auto block_body = i.As()->schedule_block.As()->body; auto find_load = ir::CollectIRNodesWithoutTensor(block_body, [&](const Expr* x) { return x->As() && x->As()->tensor.as_tensor_ref()->name == block_tensor; }); if (!find_load.empty()) consumers.emplace_back(i); } return consumers; } void CheckComputeAtValidation(const Expr& block, const Expr& loop, const Expr& root) { auto find_block = ir::CollectIRNodesWithoutTensor( root, [&](const Expr* x) { return x->As() && *x == block; }, true); CHECK(!find_block.empty()) << "Didn't find block in root!"; auto find_loop = ir::CollectIRNodesWithoutTensor( root, [&](const Expr* x) { return x->As() && *x == loop; }, true); CHECK(!find_loop.empty()) << "Didn't find loop in root!"; auto find_block_in_loop = ir::CollectIRNodesWithoutTensor( loop, [&](const Expr* x) { return x->As() && *x == block; }, true); CHECK(find_block_in_loop.empty()) << "loop should not be block's ancestor!"; } void InsertBlock(Expr& for_loop, const Expr& insertion, int index) { CHECK(for_loop.As()); CHECK(for_loop.As()->body.As()); ir::Block* dst_block = for_loop.As()->body.As(); CHECK(index == -1 || index >= 0 && index < dst_block->stmts.size()) << "index = " << index << ", it should be -1 or between [0, block stmts size)"; if (index == -1) { dst_block->stmts.emplace_back(insertion); } else { auto dst_it = dst_block->stmts.begin() + index; if (dst_it->As()) { auto* inserted_block = dst_it->As()->true_case.As(); CHECK(inserted_block) << "the IfThenElse node to be inserted shuold contain a true_case block"; inserted_block->stmts.insert(inserted_block->stmts.begin(), insertion); } else { dst_block->stmts.insert(dst_it, insertion); } } } IterRange RangeUnion(const IterRange& range1, const IterRange& range2) { Expr new_min = common::AutoSimplify(Min::Make(range1.min, range2.min)); Expr new_extent = common::AutoSimplify( common::AutoSimplify(Max::Make(range1.min + range1.extent, range2.min + range2.extent)) - new_min); return IterRange(new_min, new_extent); } std::vector CalculateRequiredRegions(const Expr& block, const Expr& loop, const Expr& root, const std::vector& required_blocks, bool is_store_provided) { CHECK(block.As()) << "Param block should be a ir::ScheduleBlockRealize node"; CHECK(loop.As()) << "Param loop should be a ir::For node"; std::set provided_nodes; if (is_store_provided) { provided_nodes = ir::CollectIRNodesWithoutTensor(block, [&](const Expr* x) { return x->As(); }); } else { provided_nodes = ir::CollectIRNodesWithoutTensor(block, [&](const Expr* x) { return x->As(); }); } std::vector required_buffer_range; // deduce accessed regions of the provided tensor in block by itering each required block for (const Expr& pro_node : provided_nodes) { const std::string& provided_tensor_name = is_store_provided ? pro_node.As()->tensor.as_tensor()->name : pro_node.As()->tensor.as_tensor()->name; for (const Expr& req_block : required_blocks) { CHECK(req_block.As()); Expr block_body = optim::IRCopy(req_block.As()->schedule_block.As()->body); auto iter_vars = req_block.As()->schedule_block.As()->iter_vars; auto iter_values = req_block.As()->iter_values; ReplaceExpr(&block_body, iter_vars, iter_values); // Notice that we look for For nodes in loop's body instead of loop itself. auto find_loops = ir::CollectIRNodesWithoutTensor( loop.As()->body, [&](const Expr* x) { return x->As() && Contains(*x, req_block); }); // collect vars and their ranges of each loop under the input loop std::vector loop_vars; std::vector loop_ranges; for (const auto& for_loop : find_loops) { loop_vars.emplace_back(for_loop.As()->loop_var); loop_ranges.emplace_back(for_loop.As()->min, for_loop.As()->extent); } std::set required_nodes; if (is_store_provided) { required_nodes = ir::CollectIRNodesWithoutTensor(block_body, [&](const Expr* x) { return x->As() && x->As()->tensor.as_tensor_ref()->name == provided_tensor_name; }); } else { required_nodes = ir::CollectIRNodesWithoutTensor(block_body, [&](const Expr* x) { return x->As() && x->As()->tensor.as_tensor_ref()->name == provided_tensor_name; }); } // deducing range by indices of each required node for (const Expr& req_node : required_nodes) { const auto& indices = is_store_provided ? req_node.As()->indices : req_node.As()->indices; if (find_loops.empty()) { for (int i = 0; i < indices.size(); ++i) { if (i >= required_buffer_range.size()) required_buffer_range.emplace_back(indices[i], Expr(1)); else required_buffer_range[i] = RangeUnion(required_buffer_range[i], IterRange(indices[i], Expr(1))); } } else { for (int i = 0; i < indices.size(); ++i) { auto range = GetAccessedRange(indices[i], loop_vars, loop_ranges); if (i >= required_buffer_range.size()) { required_buffer_range.emplace_back(std::move(range)); } else { required_buffer_range[i] = RangeUnion(required_buffer_range[i], range); } } } } // end for load_nodes } } int iter_size = block.As()->iter_values.size(); // maybe some dimensions are not accessed by consumers so we should append them if (iter_size > required_buffer_range.size()) { for (int i = required_buffer_range.size(); i < iter_size; ++i) { CHECK(block.As()->iter_values[i].as_var() || block.As()->iter_values[i].is_constant()); if (block.As()->iter_values[i].as_var()) { auto find_for_loops = ir::CollectIRNodesWithoutTensor(root, [&](const Expr* x) { return x->As() && x->As()->loop_var->name == block.As()->iter_values[i].as_var_ref()->name; }); CHECK_EQ(find_for_loops.size(), 1U); required_buffer_range.emplace_back((*find_for_loops.begin()).As()->min, (*find_for_loops.begin()).As()->extent); } else { int cons = (int)block.As()->iter_values[i].is_constant(); required_buffer_range.emplace_back(Expr(cons), Expr(1)); } } } return required_buffer_range; } Expr CheckComputeInlineValidationAndGetStore(const Expr& schedule_block, const Expr& root) { CHECK(schedule_block.As()); auto compute_body = schedule_block.As()->schedule_block.As()->body; // 1. Check the schedule block to be inlined is not a reduce tensor. auto find_store = ir::CollectIRNodesWithoutTensor( compute_body, [&](const Expr* x) { return x->As(); }, true); CHECK_EQ(find_store.size(), 1U); Expr tensor = (*find_store.begin()).As()->tensor; CHECK(!tensor.as_tensor_ref()->is_reduce_tensor()); // 2. Check this schedule block is the only writer of the tensor. find_store = ir::CollectIRNodesWithoutTensor( root, [&](const Expr* x) { return x->As() && (x->As()->tensor).as_tensor_ref()->name == tensor.as_tensor_ref()->name; }, true); CHECK_EQ(find_store.size(), 1U); // 3. Check there is no overlap between the buffers the schedule block reads and writes. auto find_load = ir::CollectIRNodesWithoutTensor( compute_body, [&](const Expr* x) { return x->As() && x->As()->tensor == tensor; }); CHECK(find_load.empty()); return (*find_store.begin()); } std::tuple CheckReverseComputeInlineValidationAndGetExprs(const Expr& schedule_block, const Expr& root) { CHECK(schedule_block.As()); auto compute_body = schedule_block.As()->schedule_block.As()->body; // 1. Check the schedule block to be reverse inlined is not a reduce tensor. auto find_inlined_load = ir::CollectIRNodesWithoutTensor( compute_body, [&](const Expr* x) { return x->As(); }, true); CHECK_EQ(find_inlined_load.size(), 1U); Expr tensor = (*find_inlined_load.begin()).As()->tensor; CHECK(!tensor.as_tensor_ref()->is_reduce_tensor()); auto inlined_load = *find_inlined_load.begin(); // 2. Check this schedule block is the only reader of the tensor. auto find_load = ir::CollectIRNodesWithoutTensor( root, [&](const Expr* x) { return x->As() && (x->As()->tensor).as_tensor_ref()->name == tensor.as_tensor_ref()->name; }, true); CHECK_EQ(find_load.size(), 1U); // 3. Check there is no overlap between the buffers the schedule block reads and writes. auto find_store = ir::CollectIRNodesWithoutTensor( compute_body, [&](const Expr* x) { return x->As() && x->As()->tensor == tensor; }); CHECK(find_store.empty()); // 4. Get store that will be inlined. auto find_inlined_store = ir::CollectIRNodesWithoutTensor( root, [&](const Expr* x) { return x->As() && x->As()->tensor == tensor; }); CHECK_EQ(find_inlined_store.size(), 1U); auto inlined_store = *find_inlined_store.begin(); // 5. Get target store. auto find_target_store = ir::CollectIRNodesWithoutTensor( compute_body, [&](const Expr* x) { return x->As(); }, true); CHECK_EQ(find_target_store.size(), 1U); auto target_store = *find_target_store.begin(); return {inlined_load, inlined_store, target_store}; } bool ContainVar(const std::vector& exprs, const std::string& var_name) { for (auto& expr : exprs) { auto find_expr = ir::CollectIRNodesWithoutTensor( expr, [&](const Expr* x) { return x->As<_Var_>() && x->As<_Var_>()->name == var_name; }, true); if (!find_expr.empty()) return true; } return false; } std::unordered_map PrimeFactorize(int n) { std::unordered_map factors; while (n % 2 == 0) { ++factors[2]; n /= 2; } for (int i = 3; i <= sqrt(n); i += 2) { while (n % i == 0) { ++factors[i]; n /= i; } } if (n > 2) { factors[n] = 1; } return factors; } std::vector SampleTile(utils::LinearRandomEngine::StateType* rand_seed, int n, int extent) { std::vector tile; while (n > 1) { std::unordered_map factors = PrimeFactorize(extent); int product = 1; for (auto& factor : factors) { if (factor.second >= 1) { int num = utils::SampleUniformInt(1, factor.second + 1, rand_seed); product *= std::pow(factor.first, num); } } tile.push_back(product); extent /= product; --n; } tile.push_back(extent); return tile; } } // namespace ir } // namespace cinn