diff --git a/paddle/fluid/framework/details/inplace_op_pass.cc b/paddle/fluid/framework/details/inplace_op_pass.cc index 03b21ae0ae8863711113a9ee6a76874cf174c77f..d4812a01d8a141cc3dee1d4fbdec7d252e9b3fd1 100644 --- a/paddle/fluid/framework/details/inplace_op_pass.cc +++ b/paddle/fluid/framework/details/inplace_op_pass.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,20 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/details/inplace_op_pass.h" -#include -#include -#include -#include +#include #include -#include -#include #include -#include #include -#include #include "paddle/fluid/framework/details/memory_optimize_pass.h" +#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/op_info.h" // NOTE(dzhwinter): inplace means one op output variable reuse the input space. @@ -56,6 +50,10 @@ DEFINE_bool( DECLARE_string(memory_optimize_debug); +namespace paddle { +namespace framework { +namespace details { + // clang-format off const std::string kInplacedOpWhiteList[] = { // NOLINT "sigmoid", @@ -83,490 +81,378 @@ const std::string kInplacedOpWhiteList[] = { // NOLINT // but the static size during compiling time would be wrong. // Use a flag to indicate such ops. Please fix me when found a better way. static const std::unordered_set kSameShapeOpWhiteSet{ // NOLINT - "reshape2" + "reshape2", "reshape2_grad" }; // clang-format on -namespace paddle { -namespace framework { -namespace details { +class InplacePass : public ir::Pass { + public: + InplacePass(); -static inline ir::Node* GetNextCascadeInplacedVar(ir::Node* var) { - // if next op is inplaced, then return the output var - // otherwise return nullptr - PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar()); - ir::Node* inplaced_var = nullptr; - for (auto* next_op : var->outputs) { - for (auto* output : next_op->outputs) { - if (output->IsVar() && !output->IsCtrlVar() && - output->Name() == var->Name()) { - inplaced_var = output; - } + protected: + void ApplyImpl(ir::Graph *graph) const override; + + private: + // Collect vars that cannot be reused + // e.g.: subblock ops in/out, distributed ops in/out, op_role_var + void CollectSkipVars(ir::Graph *graph, + const std::vector &ops) const; + + // Check whether var_name should be skipped + bool IsSkipVar(const std::string &var_name) const; + + // Rename out with name of in, and guarantee that the graph is + // still a SSA graph + void RenameInOut(ir::Node *op, ir::Node *in, ir::Node *out) const; + + // Check whether var is the last version one in SSA graph + bool IsLastVersionVar(ir::Node *var) const; + + // Check whether all `ops` is the preceding ops of `op` + bool CheckOpDeps(ir::Node *op, const std::vector &ops) const; + + // Find node whose name is equal to the given name + static ir::Node *FindNodeByName(const std::string &name, + const std::vector &nodes); + + // Get all versions vars named var_name + std::vector *AllVersionVars(const std::string &var_name) const; + + private: + // SSA graph. var_name -> each version of vars + mutable std::map> ssa_map_; + + // Skip vars, including subblock ops in/out, distributed ops in/out, + // op_role_var + mutable std::unordered_set skip_vars_; + + // Op whitelist which should not peform inplace + // Only enabled when FLAGS_enable_inplace_whitelist is true. + mutable std::unordered_set whitelist_ops_; +}; + +InplacePass::InplacePass() { + if (FLAGS_enable_inplace_whitelist) { + for (auto &s : kInplacedOpWhiteList) { + whitelist_ops_.emplace(s); } } - return inplaced_var; } -static inline ir::Node* GetPrevCascadeInplacedVar(ir::Node* var) { - PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar()); - if (var->inputs.empty()) return nullptr; - auto* prev_op = var->inputs.at(0); - auto input_it = std::find_if(prev_op->inputs.begin(), prev_op->inputs.end(), - [&](ir::Node* node) { - if (node->IsVar() && !node->IsCtrlVar() && - node->Name() == var->Name()) { - return true; - } else { - return false; - } - }); - return input_it == prev_op->inputs.end() ? nullptr : *input_it; +std::vector *InplacePass::AllVersionVars( + const std::string &var_name) const { + auto iter = ssa_map_.find(var_name); + PADDLE_ENFORCE(iter != ssa_map_.end(), "cannot find var %s in ssa graph", + var_name); + PADDLE_ENFORCE(!iter->second.empty(), "var %s is empty in ssa graph", + var_name); + return &(iter->second); } -InplacePass::InplacePass() : Pass() { - if (FLAGS_enable_inplace_whitelist) { - for (auto& s : kInplacedOpWhiteList) { - whitelist_.emplace(s); +bool InplacePass::IsSkipVar(const std::string &var_name) const { + return skip_vars_.count(var_name) > 0; +} + +bool InplacePass::IsLastVersionVar(ir::Node *var) const { + return AllVersionVars(var->Name())->back() == var; +} + +bool InplacePass::CheckOpDeps(ir::Node *op, + const std::vector &ops) const { + std::unordered_set other_ops(ops.begin(), ops.end()); + other_ops.erase(op); + if (other_ops.empty()) return true; + + // Traverse all preceding ops of op + std::queue queue; + std::unordered_set visited_ops; + queue.push(op); + visited_ops.insert(op); + + // Visit all preceding ops of `op`, and erase it from other_ops if it is + // inside other_ops. Return true only if other_ops is empty(), which means + // that all `ops` are preceding ops of `op`. + while (!queue.empty()) { + auto *cur_op = queue.front(); + queue.pop(); + + for (auto *in_var : cur_op->inputs) { + for (auto *in_op : in_var->inputs) { + if (visited_ops.count(in_op) != 0) { + continue; + } + + visited_ops.insert(in_op); + queue.push(in_op); + other_ops.erase(in_op); + if (other_ops.empty()) return true; + } } } + return false; } -void InplacePass::InitSSAGraphNodes() const { - std::unordered_map> all_vars; - for (auto* op : view_.AllOps()) { - for (auto* node : op->inputs) { - if (!node->IsVar() || node->IsCtrlVar()) continue; - if (all_vars[node->Name()].count(node) == 0) { - all_vars[node->Name()].emplace(node); - var_nodes_[node->Name()].emplace_back(node); +void InplacePass::CollectSkipVars(ir::Graph *graph, + const std::vector &ops) const { + // 1. Collect op role vars + PADDLE_ENFORCE(graph->Has(details::kMemOptSkipVars), + "Graph should have attr %s", details::kMemOptSkipVars); + auto &mem_opt_whitelist = graph->Get(kMemOptSkipVars); + for (const auto &var : mem_opt_whitelist) { + skip_vars_.emplace(var); + } + + // 2. track the nodes which used by parameter server. + // these node can not be inplaced, otherwise trainer + // pserver can not find each other's name. + // Also check the ops which has sub-block + auto update_skip_set = [&](ir::Node *node) { + for (auto &in : node->inputs) { + if (in->IsVar() && in->Var() != nullptr) { + skip_vars_.emplace(in->Name()); } } - for (auto* node : op->outputs) { - if (!node->IsVar() || node->IsCtrlVar()) continue; - if (all_vars[node->Name()].count(node) == 0) { - all_vars[node->Name()].emplace(node); - var_nodes_[node->Name()].emplace_back(node); + for (auto &out : node->outputs) { + if (out->IsVar() && out->Var() != nullptr) { + skip_vars_.emplace(out->Name()); } } - } -} - -void InplacePass::ApplyImpl(ir::Graph* graph) const { - var_nodes_.clear(); - view_.Build(graph); - InitSSAGraphNodes(); + }; - auto cnt = 0; - for (auto* op : view_.AllOps()) { - VLOG(4) << "Handle op " << cnt++ << ": " << op->Name(); - if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name())) + for (auto *node : ops) { + if (!node->IsOp()) continue; + // avoid optimizing the variable used in sub-blocks + if (OpHasSubBlock(node->Op())) { + update_skip_set(node); continue; - TryInplaceOpInputOutput(op, graph); - } -} + } -void InplacePass::InplaceModifyDesc(const std::string& var, - const std::string& cache_var, - const size_t& idx) const { - for (size_t i = idx; i < view_.AllOps().size(); ++i) { - ir::Node* op = view_.AllOps()[i]; - PADDLE_ENFORCE(op->IsOp() && op->Op()); - auto* op_desc = op->Op(); - op_desc->RenameInput(var, cache_var); - op_desc->RenameOutput(var, cache_var); - - op_desc->Flush(); + auto node_name = node->Name(); + if (node_name == "send" || node_name == "recv" || node_name == "prefetch") { + update_skip_set(node); + } } } -const NodeSwapQueue InplacePass::TryInplaceModifyVar( - const std::string& var, const std::string& cache_var, const size_t& idx, - ir::Graph* graph) const { - PADDLE_ENFORCE(var_nodes_[var].size() >= 1 && - var_nodes_[var].at(0)->Var() != nullptr); - std::unique_ptr var_desc(new VarDesc(*var_nodes_[var].at(0)->Var())); - var_desc->SetName(cache_var); - - NodeSwapQueue swap_nodes; - - for (size_t i = idx; i < view_.AllOps().size(); ++i) { - auto* op = view_.AllOps()[i]; - - // redirect the input to the latest version of cache_var - for (auto* node : op->inputs) { - if (node->Name() == var) { - ir::Node* cache_node = graph->CreateVarNode(var_desc.get()); - - // swap node to cache_node - cache_node->outputs.insert(cache_node->outputs.end(), - node->outputs.begin(), node->outputs.end()); - PADDLE_ENFORCE(node->inputs.size() == 1 && node->inputs[0]->IsOp()); - auto* prev_op = node->inputs[0]; - std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node, - cache_node); - cache_node->inputs.emplace_back(prev_op); - for (auto* next_op : node->outputs) { - std::replace(next_op->inputs.begin(), next_op->inputs.end(), node, - cache_node); +void InplacePass::RenameInOut(ir::Node *op, ir::Node *in_var, + ir::Node *out_var) const { + auto out_var_name = out_var->Name(); + auto in_var_name = in_var->Name(); + + auto &all_out_nodes = *AllVersionVars(out_var_name); + auto &all_in_nodes = *AllVersionVars(in_var_name); + + auto iter = std::find(all_out_nodes.begin(), all_out_nodes.end(), out_var); + PADDLE_ENFORCE(iter != all_out_nodes.end(), "Cannot find out var %s", + out_var_name); + + // The following codes are designed to guarantee that ssa_map_ is still + // an ssa graph after inplace is performed. + // Step 1: Rename the following versions of out_var as the name of in_var + // Step 2: Remove the following versions of out_var and append them to in_var + // Be careful that the inputs of input op of out_var should not be renamed, + // but outputs should be renamed. + auto original_iter = iter; + while (iter != all_out_nodes.end()) { + auto *node = *iter; + /* Step 1 */ + node->RenameVar(in_var_name); + if (iter != original_iter) { + for (auto *in : node->inputs) { + if (in->IsOp() && in->Op()) { + in->Op()->RenameOutput(out_var_name, in_var_name); + in->Op()->RenameInput(out_var_name, in_var_name); + in->Op()->Flush(); } - - swap_nodes.emplace_back(std::make_pair(node, cache_node)); } } - // if we need to rename the output, - // always create a newer version of cache_var - for (auto* node : op->outputs) { - if (node->Name() == var) { - ir::Node* cache_node = graph->CreateVarNode(var_desc.get()); - // swap node to cache node - cache_node->outputs.insert(cache_node->outputs.end(), - node->outputs.begin(), node->outputs.end()); - cache_node->inputs.emplace_back(op); - std::replace(op->outputs.begin(), op->outputs.end(), node, cache_node); - for (auto* next_op : node->outputs) { - std::replace(next_op->inputs.begin(), next_op->inputs.end(), node, - cache_node); - } - - swap_nodes.emplace_back(std::make_pair(node, cache_node)); + for (auto *out : node->outputs) { + if (out->IsOp() && out->Op()) { + out->Op()->RenameOutput(out_var_name, in_var_name); + out->Op()->RenameInput(out_var_name, in_var_name); + out->Op()->Flush(); } } + + /* Step 2 */ + all_in_nodes.emplace_back(node); + ++iter; } - return swap_nodes; -} + /* Step 2 */ + all_out_nodes.erase(original_iter, all_out_nodes.end()); -void InplacePass::CommitModify(const NodeSwapQueue& swap_nodes, - ir::Graph* graph) const { - for (auto& pair : swap_nodes) { - auto *node = pair.first, *cache_node = pair.second; - const std::string var = node->Name(), cache_var = cache_node->Name(); - var_nodes_[cache_var].emplace_back(cache_node); - graph->RemoveNode(node); - auto& nodes = var_nodes_.at(var); - // release unused var in graph. Because python side memory optimize - // may reused the var in same name, so we only clear the var node - // after current inplaced index. - nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end()); + if (all_out_nodes.empty()) { + ssa_map_.erase(out_var_name); } + op->Op()->RenameOutput(out_var_name, in_var_name); + op->Op()->Flush(); } -void InplacePass::WithdrawModify(const NodeSwapQueue& nodes, - ir::Graph* graph) const { - for (auto& pair : nodes) { - auto *node = pair.first, *cache_node = pair.second; - const std::string var = node->Name(), cache_var = cache_node->Name(); - auto* prev_op = node->inputs[0]; - std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), cache_node, - node); - for (auto* next_op : node->outputs) { - std::replace(next_op->inputs.begin(), next_op->inputs.end(), cache_node, - node); +ir::Node *InplacePass::FindNodeByName(const std::string &name, + const std::vector &nodes) { + ir::Node *found_node = nullptr; + for (auto *node : nodes) { + if (node->Name() == name) { + PADDLE_ENFORCE(found_node == nullptr, "Find duplicate input nodes %s", + name); + found_node = node; } - graph->RemoveNode(cache_node); } + return found_node; } -void InplacePass::TryInplaceOpInputOutput(ir::Node* op, - ir::Graph* graph) const { - VLOG(4) << "Try to inplace op " << op->Name(); - // some pre-requirments need to meet if the op want to inplaced. - PADDLE_ENFORCE(op->Op() != nullptr, "op_desc is nullptr"); - - auto* op_desc = op->Op(); - auto& infer_inplace = - OpInfoMap::Instance().Get(op_desc->Type()).infer_inplace_; - - // 1. infer_inplace_ is registered. - if (!static_cast(infer_inplace)) return; - PADDLE_ENFORCE(static_cast(infer_inplace), - "%s's infer_inplace has not been registered", op_desc->Type()); - - auto in_to_outs = infer_inplace(*op_desc); - - auto& all_ops = view_.AllOps(); - auto cursor = std::find(all_ops.begin(), all_ops.end(), op); - size_t idx = std::distance(all_ops.begin(), cursor); - - for (auto& pair : in_to_outs) { - auto& in_para_name = pair.first; - auto& out_para_name = pair.second; - - auto input_vars = op->Op()->Input(in_para_name); - if (!input_vars.size()) { - VLOG(4) << "Parameter " << in_para_name << " is empty skip " - << in_para_name << " => " << out_para_name << " pair"; - continue; - } - auto output_vars = op->Op()->Output(out_para_name); - if (!output_vars.size()) { - VLOG(4) << "Parameter " << out_para_name << " is empty skip " - << in_para_name << " => " << out_para_name << " pair"; - continue; - } - auto in_var_name = input_vars.at(0); - auto out_var_name = output_vars.at(0); - auto* in_node = view_.GetNodeByName(in_var_name, op->inputs); - auto* out_node = view_.GetNodeByName(out_var_name, op->outputs); - - VLOG(4) << "Try to replace: " << in_var_name << " => " << out_var_name; - if (view_.InSkipSet(in_var_name)) { - VLOG(4) << string::Sprintf("SKIP: %s is in skip set", in_var_name); - continue; - } - - if (view_.InSkipSet(out_var_name)) { - VLOG(4) << string::Sprintf("SKIP: %s is in skip set", out_var_name); - continue; +void InplacePass::ApplyImpl(ir::Graph *graph) const { + // Step 1: topo sort ops, collect skip vars + auto ops = ir::TopologySortOperations(*graph); + CollectSkipVars(graph, ops); + + // Step 2: build ssa var map + for (auto *op_node : ops) { + for (auto *in : op_node->inputs) { + PADDLE_ENFORCE(in->IsVar()); + // Only create a new var node when var first occurs in input of op. + if (ssa_map_.count(in->Name()) == 0) { + ssa_map_[in->Name()].emplace_back(in); + } } - if (var_nodes_[in_var_name].back() != in_node) { - VLOG(4) << "SKIP since " << in_var_name - << " is also used as output by other ops"; - continue; + // Always create a new var node for each output of op. + for (auto *out : op_node->outputs) { + PADDLE_ENFORCE(out->IsVar()); + ssa_map_[out->Name()].emplace_back(out); } + } - bool can_replace = true; - if (in_var_name == out_var_name) { - can_replace = false; - VLOG(4) << "SKIP: Input variable " << in_var_name << " & Output variable " - << out_var_name << " are the same"; - } else if (!NodeCanReused(in_node)) { - can_replace = false; - VLOG(4) << "SKIP: Input variable " << in_var_name << "cannot be reused"; - } else if (!NodeCanReused(out_node)) { - can_replace = false; - VLOG(4) << "SKIP: Output variable " << out_var_name - << " cannot be reused"; - } else if (in_node->Var()->GetType() != out_node->Var()->GetType()) { - can_replace = false; - VLOG(4) << "SKIP: Input type : " << in_node->Var()->GetType() - << " does not match Output type : " << out_node->Var()->GetType(); - } else if (details::NodeSize(*in_node->Var()) != - details::NodeSize(*out_node->Var()) && - kSameShapeOpWhiteSet.count(op_desc->Type()) == 0) { - can_replace = false; - VLOG(4) << "SKIP: Input and Output varialbe size not match"; - } + // Step 3: traverse ops and try inplace if possible + for (auto *op_node : ops) { + PADDLE_ENFORCE_NOT_NULL(op_node->Op(), "op_desc is nullptr"); - if (!can_replace) continue; + auto *op_desc = op_node->Op(); + auto op_type = op_desc->Type(); - // 2. If the variable is the input of muliple ops, we need to make sure - // current op has dependecny on other ops use the same variable - if (in_node->outputs.size() > 1 && !view_.CheckDeps(in_node, op)) { - VLOG(4) << string::Sprintf( - "Skiped pair %s => %s. %s input has external dependency." - "inplace such pair will overwrite the memory.", - out_var_name, in_var_name, op->Name()); + // Skip op inside whitelist + if (whitelist_ops_.count(op_type) > 0) { continue; } - // Debug Interface. Which would be skipped by the pass. - if (out_node->Name() == FLAGS_memory_optimize_debug) { - VLOG(3) << "Skiped var by force. FLAGS_memory_optimize_debug=" - << out_node->Name(); + auto &infer_inplace = OpInfoMap::Instance().Get(op_type).infer_inplace_; + + if (!infer_inplace) { continue; } - // NOTE(dzhwinter): - // two stage commit of inplaced process. if after inplace happens generate a - // circle, - // then withdraw the changes. Otherwise, safely add the node. - auto swap_nodes = - TryInplaceModifyVar(out_var_name, in_var_name, idx, graph); - - if (!ir::HasCircle(*graph)) { - VLOG(3) << string::Sprintf("!!! %s, %s => %s inplaced", op->Name(), - out_var_name, in_var_name); - InplaceModifyDesc(out_var_name, in_var_name, idx); - CommitModify(swap_nodes, graph); - } else { - VLOG(3) << string::Sprintf( - "Skiped pair %s => %s, inplace will generate a circle. withdraw %s", - out_var_name, in_var_name, op->Name()); - WithdrawModify(swap_nodes, graph); - } - } -} + auto in_to_outs = infer_inplace(*op_desc); + for (auto &pair : in_to_outs) { + auto &in_param = pair.first; + auto &out_param = pair.second; -void GraphView::TopoSort(ir::Graph* graph) { - // - ops_.clear(); - auto deps_num = [](ir::Node* op) { - auto cnt = 0; - for (auto& var : op->inputs) - if (var->inputs.size() > 0) ++cnt; - return cnt; - }; + auto &in_args = op_desc->Input(in_param); + auto &out_args = op_desc->Output(out_param); - std::queue> ready_ops; + if (in_args.empty()) { + VLOG(4) << "Cannot inplace because Input(" << in_param + << ") is empty in " << op_type; + continue; + } - int level = 0; - auto nodes = graph->Nodes(); - std::unordered_map deps_map; - for (auto& node : nodes) { - if (node->IsOp() && node->Op() != nullptr) { - deps_map[node] = deps_num(node); - if (0 == deps_map[node]) { - ready_ops.push({node, level}); + if (out_args.empty()) { + VLOG(4) << "Cannot inplace because Output(" << out_param + << ") is empty in " << op_type; + continue; } - } - } - while (!ready_ops.empty()) { - auto item = ready_ops.front(); - ready_ops.pop(); + auto &in_arg = in_args[0]; + auto &out_arg = out_args[0]; - ops_.emplace_back(item.first); - // record level when pop from queue - op_level_[item.first] = item.second; + if (IsSkipVar(in_arg)) { + VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg + << " is skipped in " << op_type; + continue; + } - for (auto node : item.first->outputs) { - for (auto op : node->outputs) { - --deps_map[op]; - if (deps_map[op] == 0) ready_ops.push({op, item.second + 1}); + if (IsSkipVar(out_arg)) { + VLOG(4) << "Cannot inplace because Output(" << out_param + << ")=" << out_arg << " is skipped in " << op_type; + continue; } - } - } - bool all_ops_checked = true; - for (auto& node : nodes) { - if (node->IsOp() && node->Op() != nullptr && deps_map[node] > 0) { - all_ops_checked = false; - LOG(WARNING) - << "Node " << node->Name() << " has not been checked. " - << "Maybe some passes have not handle node dependency rightly."; - break; - } - } + if (in_arg == out_arg) { + VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg + << " is the same with Output(" << out_param << ")=" << out_arg + << " in " << op_type; + continue; + } - PADDLE_ENFORCE(all_ops_checked, "All ops deps should be 0 after analysis"); -} + auto *in_node = FindNodeByName(in_arg, op_node->inputs); + PADDLE_ENFORCE_NOT_NULL(in_node, "Input(%s)=%s cannot be found in op %s", + in_param, in_arg, op_type); -// return true if current op node depeneds on all other op that use the same -// variable node -bool GraphView::CheckDeps(ir::Node* var, ir::Node* current_op) const { - // get op list that rely on the same variable - auto op_list = var->outputs; - for (auto& op : op_list) { - if (op == current_op) continue; - - VLOG(4) << " GraphView::CheckDeps : " << op->Name() << " & " - << current_op->Name(); - if (!CheckOpDeps(op, current_op)) return false; - VLOG(4) << ""; - } - return true; -} - -// check if op2 depends on op1's output -bool GraphView::CheckOpDeps(ir::Node* op1, ir::Node* op2) const { - if (VLOG_IS_ON(4)) { - auto print_op = [&](ir::Node* op, const char* name) { - std::ostringstream os; - os << " " << name << " : " << op->Name() << " "; - os << "Input args : "; - for (auto& arg : op->inputs) os << arg->Name() << " "; - os << "Output args : "; - for (auto& arg : op->outputs) os << arg->Name() << " "; - os << "Level : " << op_level_.at(op); - VLOG(4) << os.str(); - }; - print_op(op1, "OP1"); - print_op(op2, "OP2"); - } - if (op1 == op2) return true; - if (op_level_.at(op1) >= op_level_.at(op2)) return false; + if (!NodeCanReused(in_node)) { + VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg + << " is not reusable in " << op_type; + continue; + } - for (auto& var : op2->inputs) - if (var->inputs.size() > 0 && CheckOpDeps(op1, var->inputs[0])) return true; + if (!IsLastVersionVar(in_node)) { + VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg + << " is not the last version in " << op_type; + continue; + } - return false; -} + // If in_node is used as inputs of many ops, check whether all of that ops + // depends on op_node. If not, in_node cannot be inplaced. + if (in_node->outputs.size() > 1 && + !CheckOpDeps(op_node, in_node->outputs)) { + VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg + << " is not lastly used in " << op_type; + continue; + } -ir::Node* GraphView::GetNodeByName(const std::string& name, - const std::vector& nodes) const { - // nodes should be op->inputs/outputs - // node in same node do have different name. - std::unordered_set nodes_in_op; - bool has_dup_node = - std::all_of(nodes.begin(), nodes.end(), [&nodes_in_op](ir::Node* node) { - if (!node->IsVar() || node->IsCtrlVar() || node->Var() == nullptr) { - if (nodes_in_op.count(node->Name())) return true; - nodes_in_op.emplace(node->Name()); - } - return false; - }); - PADDLE_ENFORCE(has_dup_node == false, "nodes has same name!"); - ir::Node* node = nullptr; - for (auto* it : nodes) { - if (!it->IsVar() || it->IsCtrlVar() || it->Var() == nullptr) continue; - if (it->Name() == name) { - node = it; - break; - } - } - PADDLE_ENFORCE(node != nullptr, - string::Sprintf("Not found var %s in nodes!", name)); - return node; -} + auto *out_node = FindNodeByName(out_arg, op_node->outputs); + PADDLE_ENFORCE_NOT_NULL(out_node, + "Output(%s)=%s cannot be found in op %s", + out_param, out_arg, op_type); -std::vector GraphView::PendingOpsOnVar(ir::Node* node) { - // get the pending ops depends on same var node. - // because node also maybe a inplaced variable, so need to backtrack all the - // previous inplaced vars. - std::vector pending_ops; - ir::Node* p = node; - while (p != nullptr) { - pending_ops.insert(pending_ops.end(), p->outputs.begin(), p->outputs.end()); - p = GetPrevCascadeInplacedVar(p); - } - return pending_ops; -} + if (!NodeCanReused(out_node)) { + VLOG(4) << "Cannot inplace because Output(" << out_param + << ")=" << out_arg << " is not reusable in " << op_type; + continue; + } -void GraphView::Build(ir::Graph* g) { - // track the var nodes in correct order. - // Because we insert some new created node. Which may have data race between - // nodes. - // resolve data harzards depends on the var nodes in right order. - TopoSort(g); + if (in_node->Var()->GetType() != out_node->Var()->GetType()) { + VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg + << " is not the same type with " + << "Output(" << out_param << ")=" << out_arg << " in " + << op_type; + continue; + } - // fill the skip_set_ - PADDLE_ENFORCE(g->Has(details::kMemOptSkipVars)); - auto& mem_opt_whitelist = g->Get(kMemOptSkipVars); - for (const auto& var : mem_opt_whitelist) skip_set_.emplace(var); + if (details::NodeSize(*in_node->Var()) != + details::NodeSize(*out_node->Var()) && + kSameShapeOpWhiteSet.count(op_desc->Type()) == 0) { + VLOG(4) << "Cannot inplace because Input(" << in_param << ")=" << in_arg + << " is not the same size with " + << "Output(" << out_param << ")=" << out_arg << " in " + << op_type; + continue; + } - // 2. track the nodes which used by parameter server. - // these node can not be inplaced, otherwise trainer - // pserver can not find each other name. - auto update_skip_set = [&](ir::Node* node) { - for (auto& in : node->inputs) { - if (in->IsVar() && in->Var() != nullptr) { - skip_set_.emplace(in->Name()); + // Debug Interface. Which would be skipped by the pass. + if (out_arg == FLAGS_memory_optimize_debug) { + VLOG(4) << "Skiped var by force. FLAGS_memory_optimize_debug=" + << out_node->Name(); + continue; } - } - for (auto& out : node->outputs) { - if (out->IsVar() && out->Var() != nullptr) skip_set_.emplace(out->Name()); - } - }; - for (auto& node : g->Nodes()) { - if (!node->IsOp()) continue; - // avoid optimize the variable used in sub-blocks - if (OpHasSubBlock(node->Op())) update_skip_set(node); - if (node->Name() == "send") update_skip_set(node); - if (node->Name() == "recv") update_skip_set(node); - if (node->Name() == "prefetch") update_skip_set(node); + VLOG(4) << "Rename " << out_node->Name() << " with " << in_node->Name() + << " in " << op_type; + RenameInOut(op_node, in_node, out_node); + } } } -const std::vector& GraphView::AllOps() { return ops_; } - -bool GraphView::InSkipSet(const std::string& var) const { - return skip_set_.count(var); -} - } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/inplace_op_pass.h b/paddle/fluid/framework/details/inplace_op_pass.h deleted file mode 100644 index 2cd6cbd1b0317c3ea301428f2537023b026e581e..0000000000000000000000000000000000000000 --- a/paddle/fluid/framework/details/inplace_op_pass.h +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may abtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include -#include -#include -#include -#include -#include -#include "paddle/fluid/framework/details/memory_optimize_helper.h" -#include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/ir/pass.h" - -namespace paddle { -namespace framework { -namespace details { - -class GraphView { - public: - GraphView() = default; - - void Build(ir::Graph* g); - - const std::vector& AllOps(); - - ir::Node* GetNodeByName(const std::string& name, - const std::vector& nodes) const; - - std::vector PendingOpsOnVar(ir::Node* var); - - // Will Deperated in the future. - // NOTE(dzhwinter) : - // 1. Python memory optimize will reuse - // memory based var name, so different op output may - // have the same variable name. enable inplace on such node - // will generate a circle in ssa graph. - // 2. DistributeTranspiler will use unique name to - // map the parameter and gradient, must be skipped. - bool InSkipSet(const std::string& var) const; - - bool CheckDeps(ir::Node* var, ir::Node* current_op) const; - bool CheckOpDeps(ir::Node* op1, ir::Node* op2) const; - void TopoSort(ir::Graph* g); - - private: - std::vector ops_; - std::unordered_set skip_set_; // mem opt affect nodes - std::map> adj_list_; - std::unordered_map op_level_; -}; - -// swap pairs in sequence -typedef std::vector> NodeSwapQueue; -class InplacePass : public ir::Pass { - public: - InplacePass(); - - protected: - void ApplyImpl(ir::Graph* graph) const override; - - void InitSSAGraphNodes() const; - - private: - const NodeSwapQueue TryInplaceModifyVar(const std::string& var, - const std::string& cache_var, - const size_t& idx, - ir::Graph* graph) const; - - void CommitModify(const NodeSwapQueue&, ir::Graph* graph) const; - - void WithdrawModify(const NodeSwapQueue& nodes, ir::Graph* graph) const; - - void InplaceModifyDesc(const std::string& in_var, const std::string& out_var, - const size_t& idx) const; - - void TryInplaceOpInputOutput(ir::Node* op, ir::Graph* graph) const; - - mutable std::map> var_nodes_; - - mutable std::unordered_set whitelist_; - mutable GraphView view_; -}; - -} // namespace details -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/details/op_graph_view.h b/paddle/fluid/framework/details/op_graph_view.h index 77aa02eba56acb3bb20a5c5a55c75af78a3c1c81..1585c6f728531acde1d97aaac5c51b09e27c7d50 100644 --- a/paddle/fluid/framework/details/op_graph_view.h +++ b/paddle/fluid/framework/details/op_graph_view.h @@ -56,7 +56,7 @@ bool OpGraphView::VisitAllPendingOps(OpHandleBase *op, std::unordered_set visited; std::queue q; q.push(op); - do { + while (!q.empty()) { op = q.front(); q.pop(); for (auto &pending_op : pending_ops_.at(op)) { @@ -65,9 +65,10 @@ bool OpGraphView::VisitAllPendingOps(OpHandleBase *op, if (!callback(pending_op)) { return false; } + q.push(pending_op); } } - } while (!q.empty()); + } return true; } diff --git a/paddle/fluid/framework/details/reference_count_pass.cc b/paddle/fluid/framework/details/reference_count_pass.cc index 2b2a1ac51613e2520d96dd7ab6829f219678feea..31c32cc2e7b0354b2f624f457326f33409d276e2 100644 --- a/paddle/fluid/framework/details/reference_count_pass.cc +++ b/paddle/fluid/framework/details/reference_count_pass.cc @@ -118,82 +118,6 @@ class ShrinkDepsOpFunctor { const OpGraphView graph_; }; -/** - * Find the nearest downstream computation op handle. If the op is a - * computation op, just return itself. - */ -static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself( - OpHandleBase *op, size_t scope_idx) { - std::queue q; - std::unordered_set visited; - q.push(op); - do { - auto *op = q.front(); - q.pop(); - auto *compute_op = dynamic_cast(op); - if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) { - return compute_op; - } - for (auto *out_var : op->Outputs()) { - for (auto *pending_op : out_var->PendingOps()) { - if (visited.count(pending_op)) continue; - visited.insert(pending_op); - q.push(pending_op); - } - } - } while (!q.empty()); - return nullptr; -} - -static std::unordered_set -ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx, - const ShrinkDepsOpFunctor &shrink_func, - bool *ok) { - // stage one. Get last op for variable. - std::unordered_set candidates; - { - if (var->PendingOps().empty() && var->GeneratedOp()) { - // No operator depends on this variable. So the last operator is the op - // who generates this variable. - candidates.emplace(var->GeneratedOp()); - } else { - candidates = var->PendingOps(); - } - - // No pending ops or generated op is nullptr - if (candidates.empty()) { - *ok = false; - return {}; - } - } - - // stage two. Try to cast them to computation op. - // return (*ok=false) when failed. - // - // The reason why we cannot make any types of op handle to be the last lived - // op is: - // some op handle may operate on many DeviceContext, however, our garbage - // collector can only wait one DeviceContext for now. So currently, we wait - // the nearest compute op. - std::unordered_set computation_op; - { - for (auto *op : candidates) { - auto *compute_op = - FindNextComputationOpHandleOrReturnItself(op, scope_idx); - if (compute_op == nullptr) { - *ok = false; - return {}; - } - computation_op.emplace(compute_op); - } - } - - // stage three. Try to shrink computation op if they depend on each other. - // Get the smallest set of the most ops. - *ok = true; - return shrink_func(computation_op); -} - /** * Shrink op dependencies according to no need buffer vars. * @@ -267,6 +191,99 @@ static bool ShrinkNoNeedBufferVarOpDependency( } } +/** + * Find the nearest downstream computation op handle. If the op is a + * computation op, just return itself. + */ +static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself( + OpHandleBase *op, size_t scope_idx) { + std::queue q; + std::unordered_set visited; + q.push(op); + while (!q.empty()) { + auto *op = q.front(); + q.pop(); + auto *compute_op = dynamic_cast(op); + if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) { + return compute_op; + } + for (auto *out_var : op->Outputs()) { + for (auto *pending_op : out_var->PendingOps()) { + if (visited.count(pending_op)) continue; + visited.insert(pending_op); + q.push(pending_op); + } + } + } + return nullptr; +} + +enum LastLiveOpSearchStatus { kSuccess, kFailure, kShouldPrecede }; + +static std::unordered_set +ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx, + const std::string &var_name, + const ShrinkDepsOpFunctor &shrink_func, + LastLiveOpSearchStatus *status) { + // stage one. Get last op for variable. + std::unordered_set candidates; + { + if (var->PendingOps().empty() && var->GeneratedOp()) { + // No operator depends on this variable. So the last operator is the op + // who generates this variable. + candidates.emplace(var->GeneratedOp()); + } else { + candidates = var->PendingOps(); + } + + // No pending ops or generated op is nullptr + if (candidates.empty()) { + *status = LastLiveOpSearchStatus::kFailure; + return {}; + } + } + + // stage two. Try to cast them to computation op. + // return (*status=kFailure) when failed. + // + // The reason why we cannot make any types of op handle to be the last lived + // op is: + // some op handle may operate on many DeviceContext, however, our garbage + // collector can only wait one DeviceContext for now. So currently, we wait + // the nearest compute op. + std::unordered_set computation_op; + { + for (auto *op : candidates) { + auto *compute_op = + FindNextComputationOpHandleOrReturnItself(op, scope_idx); + if (compute_op == nullptr) { + *status = LastLiveOpSearchStatus::kFailure; + return {}; + } + computation_op.emplace(compute_op); + } + } + + // stage three. Try to shrink computation op if any of them does + // not need the buffer of var_name. + // If all computation ops do not need the buffer of var_name, + // return empty computation op set, and mark the status as kShouldPrecede, + // which means that the last living ops of var_name should be + // found in the previous version of var_name. + if (ShrinkNoNeedBufferVarOpDependency(var_name, &computation_op)) { + *status = LastLiveOpSearchStatus::kShouldPrecede; + return {}; + } + + PADDLE_ENFORCE(!computation_op.empty(), + "Computation ops should not be empty"); + + // stage four. Try to shrink computation op if they depend on each other. + // Get the smallest set of the most ops. + *status = LastLiveOpSearchStatus::kSuccess; + return shrink_func(computation_op); +} + void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { auto &ref_cnts = Get>(kGlobalReferenceCount); auto &last_live_ops_of_vars = @@ -284,12 +301,12 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { ShrinkDepsOpFunctor shrink_func( ir::FilterByNodeWrapper(*graph)); + VLOG(1) << "Place number: " << vars.size(); for (size_t i = 0; i < vars.size(); ++i) { for (auto &name_var_pair : vars[i]) { // Whether this variable can be reused or deleted? If not, we do not // compute reference counts and dependencies. VarDesc *var_desc = TryGetLatestVarDesc(name_var_pair.second); - if (var_desc == nullptr || var_desc->Persistable()) { continue; } @@ -305,34 +322,33 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { auto &var_name = name_var_pair.first; auto &var_handles = name_var_pair.second; + PADDLE_ENFORCE_EQ(var_desc->Name(), var_name); + for (auto iter = var_handles.rbegin(); iter != var_handles.rend(); ++iter) { - bool ok; - auto result = - ExtractComputationOpFromLastLivedVar(*iter, i, shrink_func, &ok); + VLOG(10) << "Try to find last living ops of " << var_name << " " + << (iter - var_handles.rbegin()) << " time"; + LastLiveOpSearchStatus status = LastLiveOpSearchStatus::kFailure; + auto result = ExtractComputationOpFromLastLivedVar( + *iter, i, var_name, shrink_func, &status); // Seldomly, some vars may have no pending or preceding computation ops // Just break; - if (!ok) break; - VLOG(10) << "Extract " << result.size() << " ops of var " << var_name; + if (status == LastLiveOpSearchStatus::kFailure) { + break; + } - size_t original_op_deps = result.size(); - // If all ops do not need buffer of var_name, calculate reference count - // of the previous version of var_name. - if (ShrinkNoNeedBufferVarOpDependency(var_name, &result)) { + if (status == LastLiveOpSearchStatus::kShouldPrecede) { VLOG(10) << "Try to precede reference count computing at var " << var_name; continue; } - size_t final_op_deps = result.size(); - if (final_op_deps < original_op_deps) { - VLOG(5) << "Shrink op deps from " << original_op_deps << " to " - << final_op_deps; - } - + PADDLE_ENFORCE_EQ(status, LastLiveOpSearchStatus::kSuccess); PADDLE_ENFORCE(!result.empty(), "Last living ops of %s cannot be empty", var_name); + + VLOG(10) << "Extract " << result.size() << " ops of var " << var_name; ref_cnts[i].emplace(var_name, result.size()); last_live_ops_of_vars[i].emplace(var_name, std::move(result)); break; diff --git a/paddle/fluid/framework/inplace_op_inference_test.cc b/paddle/fluid/framework/inplace_op_inference_test.cc index a2c213945d7d3c0c6f540d994873f633694eeee9..b2141628d2bbca548cd157a2d323348071125421 100644 --- a/paddle/fluid/framework/inplace_op_inference_test.cc +++ b/paddle/fluid/framework/inplace_op_inference_test.cc @@ -18,7 +18,6 @@ #include #include #include "gtest/gtest.h" -#include "paddle/fluid/framework/details/inplace_op_pass.h" #include "paddle/fluid/framework/details/memory_optimize_helper.h" #include "paddle/fluid/framework/ir/pass_builder.h" #include "paddle/fluid/framework/op_info.h" @@ -27,9 +26,15 @@ #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/var_type_inference.h" +USE_PASS(inplace_pass); + namespace paddle { namespace framework { +std::unique_ptr CreateInplacePass() { + return ir::PassRegistry::Instance().Get("inplace_pass"); +} + class NOP : public OperatorBase { public: NOP(const std::string& type, const VariableNameMap& inputs, @@ -202,7 +207,7 @@ ir::Node* GetNodeFromGraph(ir::Graph* g, std::string name) { std::unique_ptr test_SingleOpInplaceInToOut( std::unique_ptr g) { - std::unique_ptr pass(new details::InplacePass()); + auto pass = CreateInplacePass(); ir::Node* op_node = GetNodeFromGraph(g.get(), "single_op"); EXPECT_NE(op_node, nullptr); pass->Apply(g.get()); @@ -268,7 +273,7 @@ TEST(InferInplace, MultiOutInplaceInToOut) { std::unique_ptr g(new ir::Graph(prog)); g->Set(details::kMemOptSkipVars, new std::unordered_set()); - std::unique_ptr pass(new details::InplacePass()); + auto pass = CreateInplacePass(); pass->Apply(g.get()); auto op_node = GetNodeFromGraph(g.get(), "multi_out_op"); ASSERT_TRUE(op_node != nullptr); @@ -304,7 +309,7 @@ TEST(InferInplace, MultiGradInplaceInToOut) { std::unique_ptr g(new ir::Graph(prog)); g->Set(details::kMemOptSkipVars, new std::unordered_set()); - std::unique_ptr pass(new details::InplacePass()); + auto pass = CreateInplacePass(); pass->Apply(g.get()); auto op_node = GetNodeFromGraph(g.get(), "multi_out_grad"); ASSERT_TRUE(op_node != nullptr); diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 72fb876d98dc84164398583baf22c49014af483a..09a4613ba5484470f87b17b8e1977a7107570881 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -108,11 +108,18 @@ class Node { Name().find(ir::Node::kControlDepVarName) != std::string::npos; } + void RenameVar(const std::string& new_name) { + PADDLE_ENFORCE(type_ == Type::kVariable && var_desc_, + "Must be type of variable"); + name_ = new_name; + var_desc_->SetName(new_name); + } + std::vector inputs; std::vector outputs; protected: - const std::string name_; + std::string name_; std::unique_ptr var_desc_; std::unique_ptr op_desc_; Type type_; diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index 1c2f5eae8d8dd88481aad0a7d7f86a588f5c480d..70eec7af99b157627918df0771c45e2a5bcf1421 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -220,16 +220,6 @@ class SoftmaxOpGradMaker : public framework::SingleGradOpDescMaker { } }; -class SoftmaxInplaceInToOut : public framework::InplaceOpInference { - public: - std::unordered_map operator()( - const framework::OpDesc& op_desc) const override { - return std::unordered_map{ - {"X", "Out"}, - }; - } -}; - } // namespace operators } // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py b/python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py index 4e196758efc990506957089fb5b88ebb099cca29..988b67733664e5caf91f8864b40d5d6a12a2da87 100644 --- a/python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py +++ b/python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py @@ -74,3 +74,7 @@ class TestIrInplace(TestParallelExecutorBase): self.assertAlmostEqual(loss00, loss10, delta=delta) self.assertAlmostEqual(loss00, loss01, delta=delta) self.assertAlmostEqual(loss00, loss11, delta=delta) + + +if __name__ == '__main__': + unittest.main()