// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/framework/details/inplace_op_pass.h" #include #include #include #include #include #include #include #include #include "paddle/fluid/framework/details/graph_print_pass.h" #include "paddle/fluid/framework/details/memory_optimize_pass.h" #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/op_info.h" // NOTE(dzhwinter): inplace means one op output variable reuse the input space. // By our design, one operator only can read its input(const Variable), // write its output(non-const Variable). If one operator is inplaced, means // user have chance to write the space before reading happens. // Especially when some optimize code writing style is applied. // // // /* wrong case in operator */ // /*In this case, a larger allocation is allocated, input content is lost*/ // const Tensor* in = ctx.Input("In") // Tensor* out = ctx.Output("Out"); // auto* out_ptr = out->mutable_data(ctx.GetPlace()); // out_ptr[0] = 0; // input contect is overwrited. // NOTE(dzhwinter): // Only for backward compacity and stable. if enable_inplace_whitelist is turn // on. // only the ops in whitelist will be use inplace strategy. // if not, all the op will be inplaced if it registered with InplaceClass DEFINE_bool( enable_inplace_whitelist, false, "If this option turns on, only these op in whitelist can be inplaced." "If it turns off, all of the running op can be candidate of inplaced op." "Such as scale, elementwise_add" "By default, it's turned on"); DECLARE_string(memory_optimize_debug); // clang-format off const std::string kInplacedOpWhiteList[] = { // NOLINT "sigmoid", "exp", "relu", "tanh", "sqrt", "ceil", "floor", "reciprocal", "relu6", "soft_relu", "hard_sigmoid", "batch_norm", "batch_norm_grad", "sum", "sum_grad", "scale", "reshape", "elementwise_add", "elementwise_add_grad", }; // clang-format on namespace paddle { namespace framework { namespace details { static inline ir::Node* GetNextCascadeInplacedVar(ir::Node* var) { // if next op is inplaced, then return the output var // otherwise return nullptr PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar()); ir::Node* inplaced_var = nullptr; for (auto* next_op : var->outputs) { for (auto* output : next_op->outputs) { if (output->IsVar() && !output->IsCtrlVar() && output->Name() == var->Name()) { inplaced_var = output; } } } return inplaced_var; } static inline ir::Node* GetPrevCascadeInplacedVar(ir::Node* var) { PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar()); if (var->inputs.empty()) return nullptr; auto* prev_op = var->inputs.at(0); auto input_it = std::find_if(prev_op->inputs.begin(), prev_op->inputs.end(), [&](ir::Node* node) { if (node->IsVar() && !node->IsCtrlVar() && node->Name() == var->Name()) { return true; } else { return false; } }); return input_it == prev_op->inputs.end() ? nullptr : *input_it; } template static inline bool ConnectByCtrlVar(const Container& group1, const Container& group2) { bool connected = false; std::unordered_set outputs; for (auto* op : group1) { for (auto* var : op->outputs) { if (var->IsCtrlVar()) outputs.emplace(var); } } for (auto* op : group2) { for (auto* var : op->inputs) { if (outputs.count(var)) connected = true; } } return connected; } InplacePass::InplacePass() : Pass() { if (FLAGS_enable_inplace_whitelist) { for (auto& s : kInplacedOpWhiteList) { whitelist_.emplace(s); } } } void InplacePass::InitSSAGraphNodes() const { std::unordered_map> all_vars; for (auto* op : view_.AllOps()) { for (auto* node : op->inputs) { if (!node->IsVar() || node->IsCtrlVar()) continue; if (all_vars[node->Name()].count(node) == 0) { all_vars[node->Name()].emplace(node); var_nodes_[node->Name()].emplace_back(node); } } for (auto* node : op->outputs) { if (!node->IsVar() || node->IsCtrlVar()) continue; if (all_vars[node->Name()].count(node) == 0) { all_vars[node->Name()].emplace(node); var_nodes_[node->Name()].emplace_back(node); } } } } std::unique_ptr InplacePass::ApplyImpl( std::unique_ptr graph) const { var_nodes_.clear(); view_.Build(graph.get()); InitSSAGraphNodes(); for (auto* op : view_.AllOps()) { if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name())) continue; TryInplaceOpInputOutput(op, graph.get()); } graph->ResolveHazard(var_nodes_); return graph; } void InplacePass::InplaceModifyDesc(const std::string& var, const std::string& cache_var, const size_t& idx) const { for (size_t i = idx; i < view_.AllOps().size(); ++i) { ir::Node* op = view_.AllOps()[i]; PADDLE_ENFORCE(op->IsOp() && op->Op()); auto* op_desc = op->Op(); op_desc->RenameInput(var, cache_var); op_desc->RenameOutput(var, cache_var); if (op_desc->Block()->HasVar(var)) op_desc->Block()->RemoveVar(var); op_desc->Flush(); } } const SSANodePair InplacePass::TryInplaceModifyVar(const std::string& var, const std::string& cache_var, const size_t& idx, ir::Graph* graph) const { PADDLE_ENFORCE(var_nodes_[var].size() >= 1 && var_nodes_[var].at(0)->Var() != nullptr); std::unique_ptr var_desc(new VarDesc(*var_nodes_[var].at(0)->Var())); var_desc->SetName(cache_var); SSANodePair swap_nodes; for (size_t i = idx; i < view_.AllOps().size(); ++i) { auto* op = view_.AllOps()[i]; // redirect the input to the latest version of cache_var for (auto* node : op->inputs) { if (node->Name() == var) { ir::Node* cache_node = graph->CreateVarNode(var_desc.get()); // swap node to cache_node cache_node->outputs.insert(cache_node->outputs.end(), node->outputs.begin(), node->outputs.end()); PADDLE_ENFORCE(node->inputs.size() == 1 && node->inputs[0]->IsOp()); auto* prev_op = node->inputs[0]; std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node, cache_node); cache_node->inputs.emplace_back(prev_op); for (auto* next_op : node->outputs) { std::replace(next_op->inputs.begin(), next_op->inputs.end(), node, cache_node); } swap_nodes.emplace_back(std::make_pair(node, cache_node)); } } // if we need to rename the output, // always create a newer version of cache_var for (auto* node : op->outputs) { if (node->Name() == var) { ir::Node* cache_node = graph->CreateVarNode(var_desc.get()); // swap node to cache node cache_node->outputs.insert(cache_node->outputs.end(), node->outputs.begin(), node->outputs.end()); cache_node->inputs.emplace_back(op); std::replace(op->outputs.begin(), op->outputs.end(), node, cache_node); for (auto* next_op : node->outputs) { std::replace(next_op->inputs.begin(), next_op->inputs.end(), node, cache_node); } swap_nodes.emplace_back(std::make_pair(node, cache_node)); } } } return swap_nodes; } void InplacePass::CommitModify(const SSANodePair& swap_nodes, ir::Graph* graph) const { for (auto& pair : swap_nodes) { auto *node = pair.first, *cache_node = pair.second; const std::string var = node->Name(), cache_var = cache_node->Name(); var_nodes_[cache_var].emplace_back(cache_node); graph->RemoveNode(node); auto& nodes = var_nodes_.at(var); // release unused var in graph. Because python side memory optimize // may reused the var in same name, so we only clear the var node // after current inplaced index. nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end()); } } void InplacePass::WithdrawModify(const SSANodePair& nodes, ir::Graph* graph) const { for (auto& pair : nodes) { auto *node = pair.first, *cache_node = pair.second; const std::string var = node->Name(), cache_var = cache_node->Name(); auto* prev_op = node->inputs[0]; std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), cache_node, node); for (auto* next_op : node->outputs) { std::replace(next_op->inputs.begin(), next_op->inputs.end(), cache_node, node); } graph->RemoveNode(cache_node); } } void InplacePass::TryInplaceOpInputOutput(ir::Node* op, ir::Graph* graph) const { VLOG(4) << "Try to inplace op " << op->Name(); PADDLE_ENFORCE(op->Op() != nullptr && op->Op()->Block() != nullptr, "op_desc is nullptr"); // 4 pre-requirments need to meet if the op want to inplaced. // 1. infer_inplace_ is registered. auto* op_desc = op->Op(); auto& infer_inplace = OpInfoMap::Instance().Get(op_desc->Type()).infer_inplace_; if (!static_cast(infer_inplace)) return; PADDLE_ENFORCE(static_cast(infer_inplace), "%s's infer_inplace has not been registered", op_desc->Type()); auto* block = op_desc->Block(); auto in_to_outs = infer_inplace(*op_desc, block); auto& all_ops = view_.AllOps(); auto cursor = std::find(all_ops.begin(), all_ops.end(), op); size_t idx = std::distance(all_ops.begin(), cursor); for (auto& pair : in_to_outs) { auto& in_var_name = pair.first; auto& out_var_name = pair.second; auto* in_node = view_.GetNodeByName(in_var_name, op->inputs); auto* out_node = view_.GetNodeByName(out_var_name, op->outputs); // 2. there is no external pending op on the input node if (view_.PendingOpsOnVar(in_node).size() > 1) { VLOG(4) << string::Sprintf( "Skiped pair %s => %s. %s input has external dependency." "inplace such pair will overwrite the memory.", out_var_name, in_var_name, op->Name()); continue; } // 3. if output reuse input inplaced, the dependency group is not changed. // For detail, check // the function description in "OutConnectInputByCtrlVar" if (view_.OutConnectInputByCtrlVar(in_node, out_node)) { VLOG(4) << string::Sprintf( "Skiped pair %s => %s. %s input and output connect by ctrl var." "inplace such pair will generate a circle.", out_var_name, in_var_name, op->Name()); continue; } // 4. if output has been memory optimize by python(fluid.memory_optmize()). // this candidate can not be inplaced. Will be deprecated in the future. if (view_.ReusedInPythonMemOpt(out_node->Name())) { VLOG(4) << string::Sprintf( "Skiped %s => %s reused previous memory block in python memory " "optmize," "it inplace may generate a circle", out_var_name, in_var_name, op->Name()); continue; } // Debug Interface. Which would be skipped by the pass. if (out_node->Name() == FLAGS_memory_optimize_debug) { VLOG(3) << "Skiped var by force. FLAGS_memory_optimize_debug=" << out_node->Name(); continue; } // NOTE(dzhwinter): // two stage commit of inplaced process. if after inplace happens generate a // circle, // then withdraw the changes. Otherwise, safely add the node. auto swap_nodes = TryInplaceModifyVar(out_var_name, in_var_name, idx, graph); if (!ir::HasCircle(*graph)) { VLOG(3) << string::Sprintf("!!! %s, %s => %s inplaced", op->Name(), out_var_name, in_var_name); InplaceModifyDesc(out_var_name, in_var_name, idx); CommitModify(swap_nodes, graph); } else { VLOG(3) << string::Sprintf( "Skiped pair %s => %s, inplace will generate a circle. withdraw %s", out_var_name, in_var_name, op->Name()); WithdrawModify(swap_nodes, graph); } } } ir::Node* GraphView::GetNodeByName(const std::string& name, const std::vector& nodes) const { // nodes should be op->inputs/outputs // node in same node do have different name. std::unordered_set nodes_in_op; bool has_dup_node = std::all_of(nodes.begin(), nodes.end(), [&nodes_in_op](ir::Node* node) { if (!node->IsVar() || node->IsCtrlVar() || node->Var() == nullptr) { if (nodes_in_op.count(node->Name())) return true; nodes_in_op.emplace(node->Name()); } return false; }); PADDLE_ENFORCE(has_dup_node == false, "nodes has same name!"); ir::Node* node = nullptr; for (auto* it : nodes) { if (!it->IsVar() || it->IsCtrlVar() || it->Var() == nullptr) continue; if (it->Name() == name) { node = it; break; } } PADDLE_ENFORCE(node != nullptr, string::Sprintf("Not found var %s in nodes!", name)); return node; } std::vector GraphView::PendingOpsOnVar(ir::Node* node) { // get the pending ops depends on same var node. // because node also maybe a inplaced variable, so need to backtrack all the // previous inplaced vars. std::vector pending_ops; ir::Node* p = node; while (p != nullptr) { pending_ops.insert(pending_ops.end(), p->outputs.begin(), p->outputs.end()); p = GetPrevCascadeInplacedVar(p); } return pending_ops; } void GraphView::Build(ir::Graph* g) { // track the var nodes in correct order. // Because we insert some new created node. Which may have data race between // nodes. // resolve data harzards depends on the var nodes in right order. ops_ = SortOpLikeDescOrder(*g); // track the nodes which reused previous node in Python memory optimize. // these node can not be inplaced, otherwise may generate a circle in graph. std::unordered_set all_vars; for (auto& node : g->Nodes()) { if (node->IsVar()) continue; for (auto& out : node->outputs) { if (out->IsCtrlVar() || out->Var() == nullptr) continue; if (all_vars.count(out->Name())) { dup_nodes_.emplace(out->Name()); } else { all_vars.emplace(out->Name()); } } } } const std::vector GraphView::AllOps() { return ops_; } bool GraphView::OutConnectInputByCtrlVar(ir::Node* in_var, ir::Node* out_var) { // assume v_a0, v_a1 is variable. v_a0 -> v_a0 means already inplaced. // v_a1 -> v_a1 means already inplaced. // Currently we make decision to check if the v_a0 -> v_a1 can be inplace. // // v_a0 // + // | // v // v_a0 // + // | // v // v_a1 // + // | // v // v_a1 // start from the first inplaced input v_a0(on the top one). // Do a DFSSearch, get all its paths. If there is one path connect // the in_var and out_var which contains control dep var. // Means there a control path. out_var can not be inplaced use in_var. std::unordered_set out_var_set, in_var_set; ir::Node* out = out_var; // get the ops with same output name while (out != nullptr) { out_var_set.emplace(out); out = GetNextCascadeInplacedVar(out); } // get ops with same input name ir::Node* in = in_var; while (in != nullptr) { in_var_set.emplace(in); in = GetPrevCascadeInplacedVar(in); } // find if there is path with control dep var connect the in_var_set and // out_var_set return ConnectByCtrlVar(in_var_set, out_var_set); } bool GraphView::ReusedInPythonMemOpt(const std::string& var) const { return dup_nodes_.count(var); } } // namespace details } // namespace framework } // namespace paddle REGISTER_PASS(inplace_pass, paddle::framework::details::InplacePass);