From d6d3e6afe2d07a17bff9a8f9d94e37793c5cb724 Mon Sep 17 00:00:00 2001
From: dzhwinter
Date: Mon, 28 Jan 2019 15:05:10 +0800
Subject: [PATCH] add more skip strategies

---
 .../framework/details/graph_print_pass.cc          |  65 ++++-
 .../framework/details/graph_print_pass.h           |   2 +
 .../details/graph_print_pass_test.cc               | 111 ++++++++
 .../framework/details/inplace_op_pass.cc           | 248 ++++++++++++------
 paddle/fluid/framework/details/inplace_op_pass.h   |  22 +-
 paddle/fluid/framework/ir/graph_helper.cc          |  31 ++-
 paddle/fluid/framework/ir/graph_helper.h           |   5 +
 .../fluid/framework/ir/graph_helper_test.cc        |  11 +
 .../unittests/parallel_executor_test_base.py       |   9 +-
 .../tests/unittests/test_ir_inplace_pass.py        |  14 +-
 10 files changed, 425 insertions(+), 93 deletions(-)

diff --git a/paddle/fluid/framework/details/graph_print_pass.cc b/paddle/fluid/framework/details/graph_print_pass.cc
index b0a87810dbf..69ebb4bcbd6 100644
--- a/paddle/fluid/framework/details/graph_print_pass.cc
+++ b/paddle/fluid/framework/details/graph_print_pass.cc
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/details/graph_print_pass.h"
 #include
 #include
+#include "paddle/fluid/framework/ir/graph_helper.h"

 namespace paddle {
 namespace framework {
@@ -54,6 +55,11 @@ class GraphvizOp : public GraphvizNode {
     }
   }

+  template <typename Callback>
+  void AddCustomEdge(const Callback& cb) {
+    stream_ << cb() << std::endl;
+  }
+
  private:
   std::ostringstream stream_;
 };
@@ -68,12 +74,47 @@ std::vector<T*> FilterByNodeWrapper(const Container& con) {
   return ret;
 }

 std::unordered_map<ir::Node*, int> SSAGraphPrinterImpl::ToGraphvizNode(
     const ir::Graph& graph) const {
   // Convert to GraphvizNode format
   auto& graphviz_nodes = graph.Get<GraphvizNodes>(kGraphviz);
   graphviz_nodes.clear();
   std::unordered_map<ir::Node*, int> vars;
+  std::unordered_map<ir::Node*, GraphvizOp*> ops;
   int var_id = 0;
   int op_id = 0;
   for (auto& node : graph.Nodes()) {
     if (node->IsVar()) {
       graphviz_nodes.emplace(new GraphvizVar(node, var_id));
       vars.emplace(std::make_pair(node, var_id++));
     } else if (node->IsOp()) {
-      graphviz_nodes.emplace(new GraphvizOp(node, op_id++));
+      std::unique_ptr<GraphvizOp> op(new GraphvizOp(node, op_id++));
+      ops[node] = op.get();
+      graphviz_nodes.emplace(std::move(op));
     } else {
       PADDLE_THROW("Unknown op type");
     }
   }
+
+  // Detect circles and draw each circle's edges in red.
+  std::vector<std::vector<ir::Node*>> circles;
+  const std::string kCircleEdge = "[color=red,penwidth=3.0]";
+  if (ir::FindCircleSubGraph(graph, &circles)) {
+    VLOG(3) << "Graph has circle! circle count: " << circles.size();
+    for (auto& circle : circles) {
+      for (size_t i = 0; i < circle.size() - 1; ++i) {
+        GraphvizOp* prev = ops[circle[i]];
+        GraphvizOp* next = ops[circle[i + 1]];
+        std::string prev_op = "op_" + std::to_string(prev->Id());
+        std::string next_op = "op_" + std::to_string(next->Id());
+        prev->AddCustomEdge([&]() -> std::string {
+          return prev_op + "->" + next_op + kCircleEdge;
+        });
+      }
+    }
+  }
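+
+  // Illustration (hypothetical ids): for each pair of consecutive ops on a
+  // detected circle, the lambda above appends a dot edge such as
+  //   op_0->op_1[color=red,penwidth=3.0]
+  // so circles stand out in red when the dumped file is rendered by graphviz.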
   return vars;
 }
diff --git a/paddle/fluid/framework/details/graph_print_pass.h b/paddle/fluid/framework/details/graph_print_pass.h
index 10ff8c321bb..5ff98609ce2 100644
--- a/paddle/fluid/framework/details/graph_print_pass.h
+++ b/paddle/fluid/framework/details/graph_print_pass.h
@@ -31,6 +31,8 @@ class GraphvizNode {
   GraphvizNode(ir::Node* n, const int& i) : node_(n), id_(i) {}
   virtual ~GraphvizNode() = default;

+  int Id() const { return id_; }
+
  protected:
   ir::Node* node_;
   int id_;
diff --git a/paddle/fluid/framework/details/graph_print_pass_test.cc b/paddle/fluid/framework/details/graph_print_pass_test.cc
index 1149d1684eb..d8fd1beba38 100644
--- a/paddle/fluid/framework/details/graph_print_pass_test.cc
+++ b/paddle/fluid/framework/details/graph_print_pass_test.cc
@@ -19,6 +19,9 @@
 REGISTER_OPERATOR(sum, paddle::framework::DummyOp,
                   paddle::framework::SumOpMaker);
 REGISTER_OPERATOR(split, paddle::framework::DummyOp,
                   paddle::framework::SplitOpMaker);
+REGISTER_OPERATOR(assign, paddle::framework::DummyOp,
+                  paddle::framework::AssignOpMaker,
+                  paddle::framework::DummyVarTypeInference);

 /* a @ b
 inline static ProgramDesc FillProgramDesc() {
     op->SetInput("X", {"d", "e"});
     op->SetOutput("Out", {"d"});
   }
+  {
+    auto* op = prog.MutableBlock(0)->AppendOp();
+    op->SetType("assign");
+    op->SetInput("X", {"d"});
+    op->SetOutput("Out", {"d"});
+  }
   return prog;
 }

@@ -74,6 +83,108 @@ TEST(SSAGraphPrinter, Normal) {
   printer->Print(*graph, *fout);
 }

+using ir::Graph;
+using ir::Node;
+void BuildCircleGraph(Graph* g) {
+  ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
+  ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
+
+  o1->outputs.push_back(v1);
+  o1->inputs.push_back(v1);
+  v1->inputs.push_back(o1);
+  v1->outputs.push_back(o1);
+}
+
+void BuildCircleGraph2(Graph* g) {
+  ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
+  ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
+  ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
+  ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);
+
+  o1->outputs.push_back(v1);
+  o2->inputs.push_back(v1);
+  v1->inputs.push_back(o1);
+  v1->outputs.push_back(o2);
+
+  o2->outputs.push_back(v2);
+  o1->inputs.push_back(v2);
+  v2->inputs.push_back(o2);
+  v2->outputs.push_back(o1);
+}
+
+void BuildNoCircleGraph(Graph* g) {
+  ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
+  ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
+  ir::Node* o3 = g->CreateEmptyNode("op3", Node::Type::kOperation);
+  ir::Node* o4 = g->CreateEmptyNode("op4", Node::Type::kOperation);
+  ir::Node* o5 = g->CreateEmptyNode("op5", Node::Type::kOperation);
+  ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
+  ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);
+  ir::Node* v3 = g->CreateEmptyNode("var3", Node::Type::kVariable);
+  ir::Node* v4 = g->CreateEmptyNode("var4", Node::Type::kVariable);
+
+  // o1->v1->o2
+  o1->outputs.push_back(v1);
+  o2->inputs.push_back(v1);
+  v1->inputs.push_back(o1);
+  v1->outputs.push_back(o2);
+  // o2->v2->o3
+  // o2->v2->o4
+  o2->outputs.push_back(v2);
+  o3->inputs.push_back(v2);
+  o4->inputs.push_back(v2);
+  v2->inputs.push_back(o2);
+  v2->outputs.push_back(o3);
+  v2->outputs.push_back(o4);
+  // o2->v3->o5
+  o2->outputs.push_back(v3);
+  o5->inputs.push_back(v3);
+  v3->inputs.push_back(o2);
+  v3->outputs.push_back(o5);
+  // o3->v4->o5
+  o3->outputs.push_back(v4);
+  o5->inputs.push_back(v4);
+  v4->inputs.push_back(o3);
+  v4->outputs.push_back(o5);
+
+  // o2->v3->o1
+  v3->outputs.push_back(o1);
+  o1->inputs.push_back(v3);
+}
+
+TEST(SSAGraphPrinter, SimpleCircle) {
+  ProgramDesc prog;
+
+  Graph graph(prog);
+  BuildCircleGraph(&graph);
+  ASSERT_TRUE(HasCircle(graph));
+
+  graph.Set(kGraphviz, new GraphvizNodes);
+  std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);
+
+  // redirect the debug graph to a file.
+  constexpr char graph_path[] = "graph_print_pass_simple_circle.txt";
+  std::unique_ptr<std::ostream> fout(new std::ofstream(graph_path));
+  PADDLE_ENFORCE(fout->good());
+  printer->Print(graph, *fout);
+}
+
+TEST(SSAGraphPrinter, ComplexCircle) {
+  ProgramDesc prog;
+  Graph graph(prog);
+  BuildCircleGraph2(&graph);
+  ASSERT_TRUE(HasCircle(graph));
+
+  graph.Set(kGraphviz, new GraphvizNodes);
+  std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);
+
+  // redirect the debug graph to a file.
+  constexpr char graph_path[] = "graph_print_pass_complex_circle.txt";
+  std::unique_ptr<std::ostream> fout(new std::ofstream(graph_path));
+  PADDLE_ENFORCE(fout->good());
+  printer->Print(graph, *fout);
+}
+
 } // namespace details
 } // namespace framework
 } // namespace paddle
diff --git a/paddle/fluid/framework/details/inplace_op_pass.cc b/paddle/fluid/framework/details/inplace_op_pass.cc
index 11ecc383b40..d8a6be8573a 100644
--- a/paddle/fluid/framework/details/inplace_op_pass.cc
+++ b/paddle/fluid/framework/details/inplace_op_pass.cc
@@ -23,6 +23,7 @@
 #include
 #include "paddle/fluid/framework/details/graph_print_pass.h"
 #include "paddle/fluid/framework/details/memory_optimize_pass.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/op_info.h"

 // NOTE(dzhwinter): inplace means an op's output variable reuses the input's
 // memory space.
 ...
 // auto* out_ptr = out->mutable_data(ctx.GetPlace());
 // out_ptr[0] = 0;  // the input content is overwritten.
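
// Illustration (hypothetical, not part of this pass's API): an op registered
// for inplace inference exposes an input->output map, e.g.
//   std::unordered_map<std::string, std::string> in_to_outs = {{"X", "Out"}};
// meaning output "Out" may reuse the buffer of input "X".
// TryInplaceOpInputOutput below walks such pairs and rewrites the graph only
// for the pairs that pass all four checks listed in that function.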
"Such as scale, elementwise_add" "By default, it's turned on"); +DECLARE_string(memory_optimize_debug); + // clang-format off const std::string kInplacedOpWhiteList[] = { // NOLINT "sigmoid", @@ -77,63 +82,6 @@ namespace paddle { namespace framework { namespace details { -static inline std::string NodeDebugString(ir::Node* var) { - std::ostringstream os; - if (var->IsCtrlVar()) { - os << "kControlDepVarName" - << " "; - } else if (var->IsOp()) { - os << "kOperation" - << " " << var->Name(); - PADDLE_ENFORCE(var->Op() != nullptr && var->Op()->Type() == var->Name()); - } else if (var->IsVar()) { - os << "kVariable" - << " " << var->Name(); - PADDLE_ENFORCE(var->Var() != nullptr && var->Var()->Name() == var->Name()); - } else { - PADDLE_THROW("Unknown node type."); - } - return os.str(); -} - -static inline std::string OpDebugString(ir::Node* var) { - ir::Node* op = var; - if (var->IsVar()) op = var->inputs.at(0); - std::stringstream os; - os << op->Name() << " : "; - - os << "Input "; - VLOG(3) << op->Name(); - for (auto* var : op->inputs) { - if (var->IsVar() && !var->IsCtrlVar()) { - PADDLE_ENFORCE(var->Var() != nullptr && var->Var()->Name() == var->Name(), - "unmatched desc and var"); - // os << var << ":" << var->Name() << " "; - os << var->Name() << " "; - } - } - os << "Output "; - VLOG(3) << op->Name(); - for (auto* var : op->outputs) { - VLOG(3) << var; - VLOG(3) << var->Name(); - if (!var->IsVar()) { - VLOG(3) << "error"; - } - // VLOG(3) << var->Var()->Name(); - if (var->IsVar() && !var->IsCtrlVar()) { - PADDLE_ENFORCE(var->Var() != nullptr && var->Var()->Name() == var->Name(), - "unmatched desc and var"); - // os << var << ":" << var->Name() << " "; - os << var->Name() << " "; - } - if (var->Name() == "fc_10.tmp_0") { - VLOG(3) << NodeDebugString(var); - } - } - return os.str(); -} - static inline ir::Node* GetNextCascadeInplacedVar(ir::Node* var) { // if next op is inplaced, then return the output var // otherwise return nullptr @@ -218,6 +166,10 @@ std::unique_ptr InplacePass::ApplyImpl( InitSSAGraphNodes(); std::unique_ptr printer(new SSAGraphPrinterImpl); + constexpr char graph_path1[] = "ir_graph_before_inplaced.txt"; + std::unique_ptr fout1(new std::ofstream(graph_path1)); + PADDLE_ENFORCE(fout1->good()); + printer->Print(*graph, *fout1); for (auto* op : view_.AllOps()) { if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name())) @@ -230,9 +182,6 @@ std::unique_ptr InplacePass::ApplyImpl( std::unique_ptr fout(new std::ofstream(graph_path)); PADDLE_ENFORCE(fout->good()); printer->Print(*graph, *fout); - // for(auto* op : view_.AllOps()) { - // VLOG(3) << OpDebugString(op); - // } return graph; } @@ -250,6 +199,92 @@ void InplacePass::InplaceModifyDesc(const std::string& var, } } +const SSANodeVector InplacePass::TryInplaceModifyVar( + const std::string& var, const std::string& cache_var, const size_t& idx, + ir::Graph* graph) const { + PADDLE_ENFORCE(var_nodes_[var].size() >= 1 && + var_nodes_[var].at(0)->Var() != nullptr); + std::unique_ptr var_desc(new VarDesc(*var_nodes_[var].at(0)->Var())); + var_desc->SetName(cache_var); + + SSANodeVector swap_nodes; + for (size_t i = idx; i < view_.AllOps().size(); ++i) { + auto* op = view_.AllOps()[i]; + + // redirect the input to the latest version of cache_var + for (auto* node : op->inputs) { + if (node->Name() == var) { + ir::Node* cache_node = graph->CreateVarNode(var_desc.get()); + // swap node to cache_node + cache_node->outputs.insert(cache_node->outputs.end(), + node->outputs.begin(), node->outputs.end()); + 
+        PADDLE_ENFORCE(node->inputs.size() == 1 && node->inputs[0]->IsOp());
+        auto* prev_op = node->inputs[0];
+        std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node,
+                     cache_node);
+        cache_node->inputs.emplace_back(prev_op);
+        for (auto* next_op : node->outputs) {
+          std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
+                       cache_node);
+        }
+
+        swap_nodes[node].emplace_back(cache_node);
+      }
+    }
+    for (auto* node : op->outputs) {
+      if (node->Name() == var) {
+        ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
+        // swap node to cache node; registering cache_node into var_nodes_ is
+        // deferred to CommitModify(), so a withdrawn change leaves no trace.
+        cache_node->outputs.insert(cache_node->outputs.end(),
+                                   node->outputs.begin(), node->outputs.end());
+        cache_node->inputs.emplace_back(op);
+        std::replace(op->outputs.begin(), op->outputs.end(), node, cache_node);
+        for (auto* next_op : node->outputs) {
+          std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
+                       cache_node);
+        }
+        swap_nodes[node].emplace_back(cache_node);
+      }
+    }
+  }
+  return swap_nodes;
+}
+
+void InplacePass::CommitModify(const SSANodeVector& swap_nodes,
+                               ir::Graph* graph) const {
+  for (auto& pair : swap_nodes) {
+    auto* node = pair.first;
+    const std::string var = node->Name();
+    for (auto* cache_node : pair.second) {
+      const std::string cache_var = cache_node->Name();
+      var_nodes_[cache_var].emplace_back(cache_node);
+    }
+    auto& nodes = var_nodes_.at(var);
+    nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
+    graph->RemoveNode(node);
+  }
+}
+
+void InplacePass::WithDrawModify(const SSANodeVector& nodes,
+                                 ir::Graph* graph) const {
+  for (auto& pair : nodes) {
+    auto* node = pair.first;
+    for (auto* cache_node : pair.second) {
+      auto* prev_op = node->inputs[0];
+      std::replace(prev_op->outputs.begin(), prev_op->outputs.end(),
+                   cache_node, node);
+      for (auto* next_op : node->outputs) {
+        std::replace(next_op->inputs.begin(), next_op->inputs.end(),
+                     cache_node, node);
+      }
+      graph->RemoveNode(cache_node);
+    }
+  }
+}
+
 void InplacePass::InplaceModifyVar(const std::string& var,
                                    const std::string& cache_var,
                                    const size_t& idx, ir::Graph* graph) const {
@@ -318,7 +353,7 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
                                           ir::Graph* graph) const {
   PADDLE_ENFORCE(op->Op() != nullptr && op->Op()->Block() != nullptr,
                  "op_desc is nullptr");
-  // 3 pre-requirments need to meet if the op want to inplaced.
+  // 4 prerequisites need to be met before an op can be inplaced.
   // 1. infer_inplace_ is registered.
   auto* op_desc = op->Op();
   auto& infer_inplace =
@@ -333,36 +368,68 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
   auto& all_ops = view_.AllOps();
   auto cursor = std::find(all_ops.begin(), all_ops.end(), op);
   size_t idx = std::distance(all_ops.begin(), cursor);
-  VLOG(3) << op->Name() << idx;

   for (auto& pair : in_to_outs) {
     auto& in_var_name = pair.first;
     auto& out_var_name = pair.second;
     auto* in_node = view_.GetNodeByName(in_var_name, op->inputs);
     auto* out_node = view_.GetNodeByName(out_var_name, op->outputs);
+
+    // 2. there is no external pending op on the input node
     if (view_.PendingOpsOnVar(in_node).size() > 1) {
-      VLOG(3) << string::Sprintf(
-          "!!! %s input has external dependency, can not inplaced, %s => %s "
-          "skiped",
-          op->Name(), out_var_name, in_var_name);
+      VLOG(4) << string::Sprintf(
          "Skipped pair %s => %s: the input of op %s has an external "
          "dependency; inplacing it would overwrite memory another consumer "
          "still needs.",
+          out_var_name, in_var_name, op->Name());
       continue;
     }
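+
+    // Example of such an external dependency (hypothetical graph): if in_node
+    // also feeds a second op that runs later, writing this op's output into
+    // in_node's buffer would clobber a value that consumer still needs.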
+ "inplace such pair will overwrite the memory.", + out_var_name, in_var_name, op->Name()); continue; } + // 3. if output reuse input inplaced, the dependency group is not changed. // For detail, check // the function description in "OutConnectInputByCtrlVar" if (view_.OutConnectInputByCtrlVar(in_node, out_node)) { - VLOG(3) << string::Sprintf( - "!!! %s input output connect by ctrl var, cannot inplaced, %s => %s " - "skiped", - op->Name(), out_var_name, in_var_name); + VLOG(4) << string::Sprintf( + "Skiped pair %s => %s. %s input and output connect by ctrl var." + "inplace such pair will generate a circle.", + out_var_name, in_var_name, op->Name()); continue; } - VLOG(3) << string::Sprintf("!!! %s, %s => %s inplaced", op->Name(), - out_var_name, in_var_name); - // VLOG(3) << "Out " << OpDebugString(op); - InplaceModifyDesc(out_var_name, in_var_name, idx); - InplaceModifyVar(out_var_name, in_var_name, idx, graph); + + // 4. if output has been memory optimize by python(fluid.memory_optmize()). + // this candidate can not be inplaced. Will be deprecated in the future. + if (view_.ReusedInPythonMemOpt(out_node->Name())) { + VLOG(4) << string::Sprintf( + "Skiped %s => %s reused previous memory block in python memory " + "optmize," + "it inplace may generate a circle", + out_var_name, in_var_name, op->Name()); + continue; + } + + // Debug Interface. Which would be skipped by the pass. + if (out_node->Name() == FLAGS_memory_optimize_debug) { + VLOG(3) << "Skiped var by force. FLAGS_memory_optimize_debug=" + << out_node->Name(); + continue; + } + + auto swap_nodes = + TryInplaceModifyVar(out_var_name, in_var_name, idx, graph); + + // NOTE(dzhwinter): + // two stage commit of inplaced op. If add such node generate a circle, + // then withdraw the changes. Otherwise, safely add the node. + if (!ir::HasCircle(*graph)) { + VLOG(3) << string::Sprintf("!!! %s, %s => %s inplaced", op->Name(), + out_var_name, in_var_name); + CommitModify(swap_nodes, graph); + InplaceModifyDesc(out_var_name, in_var_name, idx); + } else { + VLOG(3) << string::Sprintf( + "Skiped pair %s => %s, inplace will generate a circle. withdraw %s", + out_var_name, in_var_name, op->Name()); + WithDrawModify(swap_nodes, graph); + } } } @@ -406,7 +473,28 @@ std::vector GraphView::PendingOpsOnVar(ir::Node* node) { return pending_ops; } -void GraphView::Build(ir::Graph* g) { ops_ = SortOpLikeDescOrder(*g); } +void GraphView::Build(ir::Graph* g) { + // track the var nodes in correct order. + // Because we insert some new created node. Which may have data race between + // nodes. + // resolve data harzards depends on the var nodes in right order. + ops_ = SortOpLikeDescOrder(*g); + + // track the nodes which reused previous node in Python memory optimize. + // these node can not be inplaced, otherwise may generate a circle in graph. 
+  std::unordered_set<std::string> all_vars;
+  for (auto& node : g->Nodes()) {
+    if (node->IsVar()) continue;
+    for (auto& out : node->outputs) {
+      if (out->IsCtrlVar() || out->Var() == nullptr) continue;
+      if (all_vars.count(out->Name())) {
+        dup_nodes_.emplace(out->Name());
+      } else {
+        all_vars.emplace(out->Name());
+      }
+    }
+  }
+}

 const std::vector<ir::Node*> GraphView::AllOps() { return ops_; }
@@ -452,6 +540,10 @@ bool GraphView::OutConnectInputByCtrlVar(ir::Node* in_var, ir::Node* out_var) {
   return ConnectByCtrlVar(in_var_set, out_var_set);
 }

+bool GraphView::ReusedInPythonMemOpt(const std::string& var) const {
+  return dup_nodes_.count(var);
+}
+
 } // namespace details
 } // namespace framework
 } // namespace paddle
diff --git a/paddle/fluid/framework/details/inplace_op_pass.h b/paddle/fluid/framework/details/inplace_op_pass.h
index c2b565a7435..cf1099323a9 100644
--- a/paddle/fluid/framework/details/inplace_op_pass.h
+++ b/paddle/fluid/framework/details/inplace_op_pass.h
@@ -15,6 +15,7 @@
 #pragma once
 #include
 #include
+#include
 #include
 #include
 #include "paddle/fluid/framework/details/memory_optimize_helper.h"
@@ -40,10 +41,20 @@ class GraphView {
   bool OutConnectInputByCtrlVar(ir::Node* in_var, ir::Node* out_var);

+  // Will be deprecated in the future.
+  // NOTE(dzhwinter): Python memory optimization reuses memory based on var
+  // names, so the outputs of different ops may share one variable name;
+  // enabling inplace on such a node would create a circle in the SSA graph.
+  bool ReusedInPythonMemOpt(const std::string& var) const;
+
  private:
   std::vector<ir::Node*> ops_;
+  std::unordered_set<std::string> dup_nodes_;  // nodes affected by mem opt
+  std::map<ir::Node*, std::unordered_set<ir::Node*>> adj_list_;
 };

+typedef std::unordered_map<ir::Node*, std::vector<ir::Node*>> SSANodeVector;
+
 class InplacePass : public ir::Pass {
  public:
   InplacePass();
@@ -58,6 +69,15 @@
   void InplaceModifyVar(const std::string& in_var, const std::string& out_var,
                         const size_t& idx, ir::Graph* graph) const;

+  const SSANodeVector TryInplaceModifyVar(const std::string& var,
+                                          const std::string& cache_var,
+                                          const size_t& idx,
+                                          ir::Graph* graph) const;
+
+  void CommitModify(const SSANodeVector&, ir::Graph* graph) const;
+
+  void WithDrawModify(const SSANodeVector& nodes, ir::Graph* graph) const;
+
   void InplaceModifyDesc(const std::string& in_var, const std::string& out_var,
                          const size_t& idx) const;
diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc
index 8de93cf285e..22d4c0a91cc 100644
--- a/paddle/fluid/framework/ir/graph_helper.cc
+++ b/paddle/fluid/framework/ir/graph_helper.cc
@@ -52,16 +52,29 @@ bool HasCircleHelper(
     ir::Node *node,
     const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
     std::unordered_set<ir::Node *> *visited,
-    std::unordered_set<ir::Node *> *in_trace) {
+    std::unordered_set<ir::Node *> *in_trace,
+    std::vector<std::vector<ir::Node *>> *circles) {
   if (visited->find(node) == visited->end()) {
     visited->insert(node);
     in_trace->insert(node);

     for (ir::Node *in : adj_list.at(node)) {
       if (visited->find(in) == visited->end() &&
-          HasCircleHelper(in, adj_list, visited, in_trace)) {
+          HasCircleHelper(in, adj_list, visited, in_trace, circles)) {
         return true;
       } else if (in_trace->find(in) != in_trace->end()) {
+        if (circles != nullptr) {
+          // Recover one representative node sequence of the circle by walking
+          // along edges that stay inside the current DFS trace. Remembering
+          // the nodes already taken guarantees termination.
+          std::vector<ir::Node *> circle;
+          circle.emplace_back(in);
+          std::unordered_set<ir::Node *> added{in};
+          ir::Node *p = in;
+          bool advanced = true;
+          while (advanced) {
+            advanced = false;
+            for (ir::Node *adj : adj_list.at(p)) {
+              if (in_trace->count(adj) != 0 && added.count(adj) == 0) {
+                circle.emplace_back(adj);
+                added.emplace(adj);
+                p = adj;
+                advanced = true;
+                break;
+              }
+            }
+          }
+          circles->emplace_back(circle);
+        }
         return true;
       }
     }
   }
   in_trace->erase(node);
   return false;
 }
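
 // Note: adj_list is the op-to-op adjacency built by BuildOperationAdjList,
 // so each recovered circle is a sequence of operation nodes; the variable
 // nodes lying on the circle are not recorded.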
@@ -71,11 +84,12 @@
 bool HasCircleInternal(
-    const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list) {
+    const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
+    std::vector<std::vector<ir::Node *>> *circles) {
   std::unordered_set<ir::Node *> visited;
   std::unordered_set<ir::Node *> in_trace;
   for (auto &adj : adj_list) {
-    if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace)) {
+    if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace, circles)) {
       return true;
     }
   }
   return false;
 }
@@ -84,13 +98,18 @@
 } // namespace

 bool HasCircle(const Graph &graph) {
-  return HasCircleInternal(BuildOperationAdjList(graph));
+  return HasCircleInternal(BuildOperationAdjList(graph), nullptr);
+}
+
+bool FindCircleSubGraph(const Graph &graph,
+                        std::vector<std::vector<ir::Node *>> *circles) {
+  return HasCircleInternal(BuildOperationAdjList(graph), circles);
 }

 std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
   std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
       BuildOperationAdjList(graph);
-  PADDLE_ENFORCE(!HasCircleInternal(adj_list));
+  PADDLE_ENFORCE(!HasCircleInternal(adj_list, nullptr));
   std::unordered_set<ir::Node *> visited;
   std::vector<ir::Node *> ret;
   for (auto adj : adj_list) {
diff --git a/paddle/fluid/framework/ir/graph_helper.h b/paddle/fluid/framework/ir/graph_helper.h
index fba4936f2c5..214de9ec7d8 100644
--- a/paddle/fluid/framework/ir/graph_helper.h
+++ b/paddle/fluid/framework/ir/graph_helper.h
@@ -28,6 +28,11 @@ namespace ir {
 // Test if the graph contains circle.
 bool HasCircle(const Graph &graph);

+// Find circle subgraphs for debugging;
+// the nodes of each detected circle are stored as one entry of *circles.
+bool FindCircleSubGraph(const Graph &graph,
+                        std::vector<std::vector<ir::Node *>> *circles);
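+//
+// Usage sketch (illustrative only):
+//   std::vector<std::vector<ir::Node *>> circles;
+//   if (FindCircleSubGraph(graph, &circles)) {
+//     for (auto &circle : circles) {
+//       for (ir::Node *op : circle) VLOG(3) << op->Name();
+//     }
+//   }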
+
 size_t GraphNum(const Graph &graph);

 // Topology Sort the operations in the graph from inputs to outputs.
diff --git a/paddle/fluid/framework/ir/graph_helper_test.cc b/paddle/fluid/framework/ir/graph_helper_test.cc
index 260a73ae763..8ea3dbbf241 100644
--- a/paddle/fluid/framework/ir/graph_helper_test.cc
+++ b/paddle/fluid/framework/ir/graph_helper_test.cc
@@ -195,6 +195,17 @@ void BuildTwoGraphs(Graph* g) {
   // v4->outputs.push_back(o5);
 }

+TEST(GraphHelperTest, Circles) {
+  ProgramDesc prog;
+
+  Graph g(prog);
+  BuildCircleGraph(&g);
+
+  std::vector<std::vector<ir::Node *>> circles;
+  ASSERT_TRUE(FindCircleSubGraph(g, &circles));
+  ASSERT_EQ(circles.size(), 1UL);
+}
+
 TEST(GraphHelperTest, GraphNum) {
   ProgramDesc prog;

diff --git a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
index 5e5e6033d8e..eaf2ebb62fd 100644
--- a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
+++ b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
@@ -32,7 +32,7 @@ class TestParallelExecutorBase(unittest.TestCase):
     def check_network_convergence(self,
                                   method,
                                   use_cuda=True,
-                                  memory_opt=True,
+                                  memory_opt=False,
                                   iter=50,
                                   batch_size=None,
                                   allow_op_delay=False,
@@ -67,8 +67,6 @@ class TestParallelExecutorBase(unittest.TestCase):
             if memory_opt:
                 fluid.memory_optimize(main)

-            with open("program_model.txt", "w") as f:
-                f.write(str(main))
             place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
             exe = fluid.Executor(place)
             exe.run(startup)
@@ -82,9 +80,10 @@ class TestParallelExecutorBase(unittest.TestCase):
             build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
             build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv
             build_strategy.memory_optimize = use_ir_memory_optimize
-            build_strategy.enable_inplace = enable_inplace
+            # Python memory optimization conflicts with the inplace pass;
+            # the correct way is to run IR-graph memory optimization after
+            # the inplace pass.
+            build_strategy.enable_inplace = False if memory_opt else enable_inplace
             build_strategy.enable_sequential_execution = enable_sequential_execution
-            build_strategy.debug_graphviz_path = "debug_ir_graph_"

             if use_cuda and core.is_compiled_with_cuda():
                 build_strategy.remove_unnecessary_lock = True
diff --git a/python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py b/python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py
index 0c9cd99322d..b87407e31e4 100644
--- a/python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py
+++ b/python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py
@@ -46,7 +46,10 @@ class TestIrInplace(TestParallelExecutorBase):
     def setUpClass(cls):
         os.environ['CPU_NUM'] = str(4)

-    def _fc_with_batchnorm(self, ir_memory_optimize, enable_inplace):
+    def _fc_with_batchnorm(self,
+                           ir_memory_optimize,
+                           enable_inplace,
+                           memory_opt=False):
         np.random.seed(5)
         img = np.random.random(size=[32, 784]).astype(np.float32)
         label = np.ones(shape=[32, 1], dtype='int64')
@@ -55,7 +58,7 @@ class TestIrInplace(TestParallelExecutorBase):
             feed_dict={"image": img,
                        "label": label},
             use_cuda=True,
-            memory_opt=False,  # inplace is conflict with memory opt
+            memory_opt=memory_opt,
             use_ir_memory_optimize=ir_memory_optimize,
             enable_inplace=enable_inplace)

@@ -67,3 +70,10 @@ class TestIrInplace(TestParallelExecutorBase):
         self.assertAlmostEqual(loss00, loss10, delta=delta)
         self.assertAlmostEqual(loss00, loss01, delta=delta)
         self.assertAlmostEqual(loss00, loss11, delta=delta)
+
+    def test_fc_with_batchnorm_memory_opt(self, delta=1e-3):
+        loss00 = self._fc_with_batchnorm(False, True, False)
+        loss10 = self._fc_with_batchnorm(False, True, True)
+        loss11 = self._fc_with_batchnorm(True, True, True)
+        self.assertAlmostEqual(loss00, loss10, delta=delta)
+        self.assertAlmostEqual(loss00, loss11, delta=delta)
-- 
GitLab