// Copyright (c) 2021 CINN Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/cinn/optim/eliminate_broadcast_in_forloop.h" #include #include #include "paddle/cinn/ir/ir_mutator.h" #include "paddle/cinn/ir/ir_printer.h" #include "paddle/cinn/ir/ir_visitor.h" #include "paddle/cinn/optim/ir_replace.h" namespace cinn { namespace optim { namespace detail { struct EliminateBroadcastInForloop : public ir::IRMutator { void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } void Visit(const ir::Store* op, Expr* expr) { // TODO(Superjom) Support single one level of forloop. if (forloop_stack.size() < 2) return; auto* node = expr->As(); auto broadcasts = ir::CollectIRNodes(node->value, [&](const Expr* expr) { return expr->As(); }); std::vector let_exprs; Var tmp; Expr let_expr; Var cur_level_loop_var = forloop_stack.back()->As() ? forloop_stack.back()->As()->loop_var : forloop_stack.back()->As()->iterator; for (Expr broadcast : broadcasts) { if (ContainsLoopVar(broadcast, cur_level_loop_var)) continue; VLOG(4) << "eliminating " << broadcast; std::tie(let_expr, tmp) = CreateTmpLet(broadcast); let_exprs.push_back(let_expr); optim::IrReplace(expr, broadcast, tmp); } // insert the let expressions to the outer forloop. Expr* outer_forloop = forloop_stack[forloop_stack.size() - 2]; auto& outer_forloop_body = outer_forloop->As() ? outer_forloop->As()->body : outer_forloop->As()->body; auto* outer_forloop_body_block = outer_forloop_body.As(); if (outer_forloop_body_block) { outer_forloop_body_block->stmts.insert( std::begin(outer_forloop_body_block->stmts), let_exprs.begin(), let_exprs.end()); } else { let_exprs.push_back(outer_forloop_body); outer_forloop_body = ir::Block::Make(let_exprs); } } bool ContainsLoopVar(Expr expr, Var loop_var) { return !ir::CollectIRNodes(expr, [&](const Expr* e) -> bool { return e->As() && e->As()->name == loop_var->name; }).empty(); } std::tuple CreateTmpLet(Expr body) { Var tmp(Context::Global().NewName("tmp"), body.type()); Expr let_expr = ir::Let::Make(tmp, body); return std::make_tuple(let_expr, tmp); } void Visit(const ir::For* op, Expr* expr) { forloop_stack.push_back(expr); ir::IRMutator<>::Visit(op, expr); forloop_stack.pop_back(); } void Visit(const ir::PolyFor* op, Expr* expr) { forloop_stack.push_back(expr); ir::IRMutator<>::Visit(op, expr); forloop_stack.pop_back(); } std::vector forloop_stack; }; } // namespace detail void EliminateBroadcastInForloop(Expr* expr) { detail::EliminateBroadcastInForloop mutator; mutator(expr); } } // namespace optim } // namespace cinn