diff --git a/paddle/fluid/framework/details/graph_print_pass.cc b/paddle/fluid/framework/details/graph_print_pass.cc
index b0a87810dbff2f7d204a5565eb9e5de0671f78ed..69ebb4bcbd651851db4fe7ff41e660f8c6d190ee 100644
--- a/paddle/fluid/framework/details/graph_print_pass.cc
+++ b/paddle/fluid/framework/details/graph_print_pass.cc
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/details/graph_print_pass.h"
 #include <string>
 #include <vector>
+#include "paddle/fluid/framework/ir/graph_helper.h"
 
 namespace paddle {
 namespace framework {
@@ -54,6 +55,11 @@ class GraphvizOp : public GraphvizNode {
     }
   }
 
+  template <typename Callback>
+  void AddCustomEdge(const Callback& cb) {
+    stream_ << cb() << std::endl;
+  }
+
  private:
  std::ostringstream stream_;
 };
@@ -68,12 +74,13 @@ std::vector<T*> FilterByNodeWrapper(const Container& con) {
   return ret;
 }
 
 std::unordered_map<ir::Node*, int> SSAGraphPrinterImpl::ToGraphvizNode(
     const ir::Graph& graph) const {
   // Convert to GraphvizNode format
   auto& graphviz_nodes = graph.Get<GraphvizNodes>(kGraphviz);
   graphviz_nodes.clear();
   std::unordered_map<ir::Node*, int> vars;
+  std::unordered_map<ir::Node*, GraphvizOp*> ops;
   int var_id = 0;
   int op_id = 0;
   for (auto& node : graph.Nodes()) {
@@ -81,11 +88,31 @@ std::unordered_map<ir::Node*, int> SSAGraphPrinterImpl::ToGraphvizNode(
     graphviz_nodes.emplace(new GraphvizVar(node, var_id));
     vars.emplace(std::make_pair(node, var_id++));
     } else if (node->IsOp()) {
-      graphviz_nodes.emplace(new GraphvizOp(node, op_id++));
+      std::unique_ptr<GraphvizOp> op(new GraphvizOp(node, op_id++));
+      ops[node] = op.get();
+      graphviz_nodes.emplace(std::move(op));
     } else {
       PADDLE_THROW("Unknown op type");
     }
   }
+
+  // Detect circles and draw every edge on a circle in a distinct style.
+  std::vector<std::vector<ir::Node*>> circles;
+  const std::string kCircleEdge = "[color=red,penwidth=3.0]";
+  if (ir::FindCircleSubGraph(graph, &circles)) {
+    VLOG(3) << "Graph has circle! circles count: " << circles.size();
+    for (auto& circle : circles) {
+      for (size_t i = 0; i < circle.size() - 1; ++i) {
+        GraphvizOp* prev = ops[circle[i]];
+        GraphvizOp* next = ops[circle[i + 1]];
+        std::string prev_op = "op_" + std::to_string(prev->Id());
+        std::string next_op = "op_" + std::to_string(next->Id());
+        prev->AddCustomEdge([&]() -> std::string {
+          return prev_op + "->" + next_op + kCircleEdge;
+        });
+      }
+    }
+  }
   return vars;
 }
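For intuition, here is a minimal, self-contained sketch of the DOT text this hunk aims to emit: ordinary op edges plus a red penwidth-3.0 edge for every edge on a detected circle. It uses plain std C++ with a hard-coded toy graph; the edge list and the "detected" circle below are illustrative stand-ins for what FindCircleSubGraph would return, not Paddle's actual API.

#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Toy 3-op graph: op_0 -> op_1 -> op_2 -> op_0 (last edge closes a circle).
  std::vector<std::pair<int, int>> edges = {{0, 1}, {1, 2}, {2, 0}};
  std::vector<std::pair<int, int>> circle_edges = {{2, 0}};  // assumed detected
  const std::string kCircleEdge = "[color=red,penwidth=3.0]";

  std::cout << "digraph G {" << std::endl;
  for (auto& e : edges) {
    std::string line = "op_" + std::to_string(e.first) + "->op_" +
                       std::to_string(e.second);
    // Append the highlight style if this edge lies on a detected circle.
    for (auto& c : circle_edges) {
      if (c == e) line += kCircleEdge;
    }
    std::cout << "  " << line << std::endl;
  }
  std::cout << "}" << std::endl;
}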
diff --git a/paddle/fluid/framework/details/graph_print_pass.h b/paddle/fluid/framework/details/graph_print_pass.h
index 10ff8c321bb69aa53275a08900b8389e7344aacb..5ff98609ce2507a4fe0758caa07bfaebe866e4bd 100644
--- a/paddle/fluid/framework/details/graph_print_pass.h
+++ b/paddle/fluid/framework/details/graph_print_pass.h
@@ -31,6 +31,8 @@ class GraphvizNode {
   GraphvizNode(ir::Node* n, const int& i) : node_(n), id_(i) {}
   virtual ~GraphvizNode() = default;
 
+  int Id() const { return id_; }
+
  protected:
   ir::Node* node_;
   int id_;
diff --git a/paddle/fluid/framework/details/graph_print_pass_test.cc b/paddle/fluid/framework/details/graph_print_pass_test.cc
index 1149d1684ebba8cf8675cddd1c8dfad8175b3bc9..d8fd1beba38d64c3c1f14d13122b503c8b4f657f 100644
--- a/paddle/fluid/framework/details/graph_print_pass_test.cc
+++ b/paddle/fluid/framework/details/graph_print_pass_test.cc
@@ -19,6 +19,9 @@
 REGISTER_OPERATOR(sum, paddle::framework::DummyOp,
                   paddle::framework::SumOpMaker);
 REGISTER_OPERATOR(split, paddle::framework::DummyOp,
                   paddle::framework::SplitOpMaker);
+REGISTER_OPERATOR(assign, paddle::framework::DummyOp,
+                  paddle::framework::AssignOpMaker,
+                  paddle::framework::DummyVarTypeInference);
 /*
   a @ b
@@ -54,6 +57,12 @@ inline static ProgramDesc FillProgramDesc() {
     op->SetInput("X", {"d", "e"});
     op->SetOutput("Out", {"d"});
   }
+  {
+    auto* op = prog.MutableBlock(0)->AppendOp();
+    op->SetType("assign");
+    op->SetInput("X", {"d"});
+    op->SetOutput("Out", {"d"});
+  }
   return prog;
 }
 
@@ -74,6 +83,108 @@ TEST(SSAGraphPrinter, Normal) {
   printer->Print(*graph, *fout);
 }
 
+using ir::Graph;
+using ir::Node;
+void BuildCircleGraph(Graph* g) {
+  ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
+  ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
+
+  o1->outputs.push_back(v1);
+  o1->inputs.push_back(v1);
+  v1->inputs.push_back(o1);
+  v1->outputs.push_back(o1);
+}
+
+void BuildCircleGraph2(Graph* g) {
+  ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
+  ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
+  ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
+  ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);
+
+  o1->outputs.push_back(v1);
+  o2->inputs.push_back(v1);
+  v1->inputs.push_back(o1);
+  v1->outputs.push_back(o2);
+
+  o2->outputs.push_back(v2);
+  o1->inputs.push_back(v2);
+  v2->inputs.push_back(o2);
+  v2->outputs.push_back(o1);
+}
+
+void BuildNoCircleGraph(Graph* g) {
+  ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
+  ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
+  ir::Node* o3 = g->CreateEmptyNode("op3", Node::Type::kOperation);
+  ir::Node* o4 = g->CreateEmptyNode("op4", Node::Type::kOperation);
+  ir::Node* o5 = g->CreateEmptyNode("op5", Node::Type::kOperation);
+  ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
+  ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);
+  ir::Node* v3 = g->CreateEmptyNode("var3", Node::Type::kVariable);
+  ir::Node* v4 = g->CreateEmptyNode("var4", Node::Type::kVariable);
+
+  // o1->v1->o2
+  o1->outputs.push_back(v1);
+  o2->inputs.push_back(v1);
+  v1->inputs.push_back(o1);
+  v1->outputs.push_back(o2);
+  // o2->v2->o3
+  // o2->v2->o4
+  o2->outputs.push_back(v2);
+  o3->inputs.push_back(v2);
+  o4->inputs.push_back(v2);
+  v2->inputs.push_back(o2);
+  v2->outputs.push_back(o3);
+  v2->outputs.push_back(o4);
+  // o2->v3->o5
+  o2->outputs.push_back(v3);
+  o5->inputs.push_back(v3);
+  v3->inputs.push_back(o2);
+  v3->outputs.push_back(o5);
+  // o3->v4->o5
+  o3->outputs.push_back(v4);
+  o5->inputs.push_back(v4);
+  v4->inputs.push_back(o3);
+  v4->outputs.push_back(o5);
+
+  // o2->v3->o1 (this back edge closes a circle)
+  v3->outputs.push_back(o1);
+  o1->inputs.push_back(v3);
+}
+
+TEST(SSAGraphPrinter, SimpleCircle) {
+  ProgramDesc prog;
+
+  Graph graph(prog);
+  BuildCircleGraph(&graph);
+  ASSERT_TRUE(HasCircle(graph));
+
+  graph.Set(kGraphviz, new GraphvizNodes);
+  std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);
+
+  // redirect debug graph to a file.
+  constexpr char graph_path[] = "graph_print_pass_simple_circle.txt";
+  std::unique_ptr<std::ostream> fout(new std::ofstream(graph_path));
+  PADDLE_ENFORCE(fout->good());
+  printer->Print(graph, *fout);
+}
+
+TEST(SSAGraphPrinter, ComplexCircle) {
+  ProgramDesc prog;
+  Graph graph(prog);
+  BuildCircleGraph2(&graph);
+  ASSERT_TRUE(HasCircle(graph));
+
+  graph.Set(kGraphviz, new GraphvizNodes);
+  std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);
+
+  // redirect debug graph to a file.
+  constexpr char graph_path[] = "graph_print_pass_complex_circle.txt";
+  std::unique_ptr<std::ostream> fout(new std::ofstream(graph_path));
+  PADDLE_ENFORCE(fout->good());
+  printer->Print(graph, *fout);
+}
+
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
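The next file, inplace_op_pass.cc, flips the default of enable_inplace_whitelist to false. A self-contained sketch of the gate that flag controls is shown below; the FLAGS_ variable and the op names are illustrative stand-ins, not Paddle's real gflags binding.

#include <iostream>
#include <string>
#include <unordered_set>

static bool FLAGS_enable_inplace_whitelist = false;  // the new default

int main() {
  const std::unordered_set<std::string> whitelist = {"sigmoid", "scale"};
  const char* ops[] = {"sigmoid", "conv2d"};
  for (const char* name : ops) {
    // With the whitelist off, every op registered with an inplace inferer is
    // a candidate; with it on, ops outside the list are skipped.
    if (FLAGS_enable_inplace_whitelist && !whitelist.count(name)) {
      std::cout << name << ": skipped (not in whitelist)" << std::endl;
      continue;
    }
    std::cout << name << ": inplace candidate" << std::endl;
  }
}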
"Such as scale, elementwise_add" "By default, it's turned on"); +DECLARE_string(memory_optimize_debug); + // clang-format off const std::string kInplacedOpWhiteList[] = { // NOLINT "sigmoid", @@ -77,63 +82,6 @@ namespace paddle { namespace framework { namespace details { -static inline std::string NodeDebugString(ir::Node* var) { - std::ostringstream os; - if (var->IsCtrlVar()) { - os << "kControlDepVarName" - << " "; - } else if (var->IsOp()) { - os << "kOperation" - << " " << var->Name(); - PADDLE_ENFORCE(var->Op() != nullptr && var->Op()->Type() == var->Name()); - } else if (var->IsVar()) { - os << "kVariable" - << " " << var->Name(); - PADDLE_ENFORCE(var->Var() != nullptr && var->Var()->Name() == var->Name()); - } else { - PADDLE_THROW("Unknown node type."); - } - return os.str(); -} - -static inline std::string OpDebugString(ir::Node* var) { - ir::Node* op = var; - if (var->IsVar()) op = var->inputs.at(0); - std::stringstream os; - os << op->Name() << " : "; - - os << "Input "; - VLOG(3) << op->Name(); - for (auto* var : op->inputs) { - if (var->IsVar() && !var->IsCtrlVar()) { - PADDLE_ENFORCE(var->Var() != nullptr && var->Var()->Name() == var->Name(), - "unmatched desc and var"); - // os << var << ":" << var->Name() << " "; - os << var->Name() << " "; - } - } - os << "Output "; - VLOG(3) << op->Name(); - for (auto* var : op->outputs) { - VLOG(3) << var; - VLOG(3) << var->Name(); - if (!var->IsVar()) { - VLOG(3) << "error"; - } - // VLOG(3) << var->Var()->Name(); - if (var->IsVar() && !var->IsCtrlVar()) { - PADDLE_ENFORCE(var->Var() != nullptr && var->Var()->Name() == var->Name(), - "unmatched desc and var"); - // os << var << ":" << var->Name() << " "; - os << var->Name() << " "; - } - if (var->Name() == "fc_10.tmp_0") { - VLOG(3) << NodeDebugString(var); - } - } - return os.str(); -} - static inline ir::Node* GetNextCascadeInplacedVar(ir::Node* var) { // if next op is inplaced, then return the output var // otherwise return nullptr @@ -218,6 +166,10 @@ std::unique_ptr InplacePass::ApplyImpl( InitSSAGraphNodes(); std::unique_ptr printer(new SSAGraphPrinterImpl); + constexpr char graph_path1[] = "ir_graph_before_inplaced.txt"; + std::unique_ptr fout1(new std::ofstream(graph_path1)); + PADDLE_ENFORCE(fout1->good()); + printer->Print(*graph, *fout1); for (auto* op : view_.AllOps()) { if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name())) @@ -230,9 +182,6 @@ std::unique_ptr InplacePass::ApplyImpl( std::unique_ptr fout(new std::ofstream(graph_path)); PADDLE_ENFORCE(fout->good()); printer->Print(*graph, *fout); - // for(auto* op : view_.AllOps()) { - // VLOG(3) << OpDebugString(op); - // } return graph; } @@ -250,6 +199,92 @@ void InplacePass::InplaceModifyDesc(const std::string& var, } } +const SSANodeVector InplacePass::TryInplaceModifyVar( + const std::string& var, const std::string& cache_var, const size_t& idx, + ir::Graph* graph) const { + PADDLE_ENFORCE(var_nodes_[var].size() >= 1 && + var_nodes_[var].at(0)->Var() != nullptr); + std::unique_ptr var_desc(new VarDesc(*var_nodes_[var].at(0)->Var())); + var_desc->SetName(cache_var); + + SSANodeVector swap_nodes; + for (size_t i = idx; i < view_.AllOps().size(); ++i) { + auto* op = view_.AllOps()[i]; + + // redirect the input to the latest version of cache_var + for (auto* node : op->inputs) { + if (node->Name() == var) { + ir::Node* cache_node = graph->CreateVarNode(var_desc.get()); + // swap node to cache_node + cache_node->outputs.insert(cache_node->outputs.end(), + node->outputs.begin(), node->outputs.end()); + 
@@ -250,6 +199,91 @@ void InplacePass::InplaceModifyDesc(const std::string& var,
   }
 }
 
+const SSANodeVector InplacePass::TryInplaceModifyVar(
+    const std::string& var, const std::string& cache_var, const size_t& idx,
+    ir::Graph* graph) const {
+  PADDLE_ENFORCE(var_nodes_[var].size() >= 1 &&
+                 var_nodes_[var].at(0)->Var() != nullptr);
+  std::unique_ptr<VarDesc> var_desc(new VarDesc(*var_nodes_[var].at(0)->Var()));
+  var_desc->SetName(cache_var);
+
+  SSANodeVector swap_nodes;
+  for (size_t i = idx; i < view_.AllOps().size(); ++i) {
+    auto* op = view_.AllOps()[i];
+
+    // redirect the input to the latest version of cache_var
+    for (auto* node : op->inputs) {
+      if (node->Name() == var) {
+        ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
+        // swap node to cache_node
+        cache_node->outputs.insert(cache_node->outputs.end(),
+                                   node->outputs.begin(), node->outputs.end());
+        PADDLE_ENFORCE(node->inputs.size() == 1 && node->inputs[0]->IsOp());
+        auto* prev_op = node->inputs[0];
+        std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node,
+                     cache_node);
+        cache_node->inputs.emplace_back(prev_op);
+        for (auto* next_op : node->outputs) {
+          std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
+                       cache_node);
+        }
+
+        swap_nodes[node].emplace_back(cache_node);
+      }
+    }
+    for (auto* node : op->outputs) {
+      if (node->Name() == var) {
+        ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
+        // swap node to cache node; var_nodes_ is only updated on commit, so
+        // a withdraw leaves no stale pointer behind.
+        cache_node->outputs.insert(cache_node->outputs.end(),
+                                   node->outputs.begin(), node->outputs.end());
+        cache_node->inputs.emplace_back(op);
+        std::replace(op->outputs.begin(), op->outputs.end(), node, cache_node);
+        for (auto* next_op : node->outputs) {
+          std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
+                       cache_node);
+        }
+        swap_nodes[node].emplace_back(cache_node);
+      }
+    }
+  }
+  return swap_nodes;
+}
+
+void InplacePass::CommitModify(const SSANodeVector& swap_nodes,
+                               ir::Graph* graph) const {
+  for (auto& pair : swap_nodes) {
+    auto* node = pair.first;
+    const std::string var = node->Name();
+    for (auto* cache_node : pair.second) {
+      const std::string cache_var = cache_node->Name();
+      var_nodes_[cache_var].emplace_back(cache_node);
+    }
+    auto& nodes = var_nodes_.at(var);
+    nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
+    graph->RemoveNode(node);
+  }
+}
+
+void InplacePass::WithDrawModify(const SSANodeVector& nodes,
+                                 ir::Graph* graph) const {
+  for (auto& pair : nodes) {
+    auto* node = pair.first;
+    for (auto* cache_node : pair.second) {
+      auto* prev_op = node->inputs[0];
+      std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), cache_node,
+                   node);
+      for (auto* next_op : node->outputs) {
+        std::replace(next_op->inputs.begin(), next_op->inputs.end(), cache_node,
+                     node);
+      }
+      graph->RemoveNode(cache_node);
+    }
+  }
+}
+
 void InplacePass::InplaceModifyVar(const std::string& var,
                                    const std::string& cache_var,
                                    const size_t& idx, ir::Graph* graph) const {
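The rewiring in TryInplaceModifyVar is easier to see on a toy graph. A sketch under simplified assumptions — a bare Node struct instead of ir::Node, and the single-writer SSA invariant — showing how a var node is swapped for a renamed cache node:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

struct Node {  // toy stand-in for ir::Node
  std::string name;
  std::vector<Node*> inputs, outputs;
};

// Redirect the producer and all consumers of `var` onto `cache`, mirroring
// the input/output swap loops above.
void SwapVar(Node* var, Node* cache) {
  Node* prev_op = var->inputs.at(0);  // SSA: exactly one writer per var node
  std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), var, cache);
  cache->inputs.push_back(prev_op);
  for (Node* next_op : var->outputs) {
    std::replace(next_op->inputs.begin(), next_op->inputs.end(), var, cache);
    cache->outputs.push_back(next_op);
  }
}

int main() {
  Node op1{"op1", {}, {}}, d{"d", {}, {}}, op2{"op2", {}, {}};
  op1.outputs = {&d};
  d.inputs = {&op1};
  d.outputs = {&op2};
  op2.inputs = {&d};

  Node cache{"d_reused", {}, {}};
  SwapVar(&d, &cache);
  std::cout << op1.outputs[0]->name << " " << op2.inputs[0]->name
            << std::endl;  // prints: d_reused d_reused
}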
+ "inplace such pair will overwrite the memory.", + out_var_name, in_var_name, op->Name()); continue; } + // 3. if output reuse input inplaced, the dependency group is not changed. // For detail, check // the function description in "OutConnectInputByCtrlVar" if (view_.OutConnectInputByCtrlVar(in_node, out_node)) { - VLOG(3) << string::Sprintf( - "!!! %s input output connect by ctrl var, cannot inplaced, %s => %s " - "skiped", - op->Name(), out_var_name, in_var_name); + VLOG(4) << string::Sprintf( + "Skiped pair %s => %s. %s input and output connect by ctrl var." + "inplace such pair will generate a circle.", + out_var_name, in_var_name, op->Name()); continue; } - VLOG(3) << string::Sprintf("!!! %s, %s => %s inplaced", op->Name(), - out_var_name, in_var_name); - // VLOG(3) << "Out " << OpDebugString(op); - InplaceModifyDesc(out_var_name, in_var_name, idx); - InplaceModifyVar(out_var_name, in_var_name, idx, graph); + + // 4. if output has been memory optimize by python(fluid.memory_optmize()). + // this candidate can not be inplaced. Will be deprecated in the future. + if (view_.ReusedInPythonMemOpt(out_node->Name())) { + VLOG(4) << string::Sprintf( + "Skiped %s => %s reused previous memory block in python memory " + "optmize," + "it inplace may generate a circle", + out_var_name, in_var_name, op->Name()); + continue; + } + + // Debug Interface. Which would be skipped by the pass. + if (out_node->Name() == FLAGS_memory_optimize_debug) { + VLOG(3) << "Skiped var by force. FLAGS_memory_optimize_debug=" + << out_node->Name(); + continue; + } + + auto swap_nodes = + TryInplaceModifyVar(out_var_name, in_var_name, idx, graph); + + // NOTE(dzhwinter): + // two stage commit of inplaced op. If add such node generate a circle, + // then withdraw the changes. Otherwise, safely add the node. + if (!ir::HasCircle(*graph)) { + VLOG(3) << string::Sprintf("!!! %s, %s => %s inplaced", op->Name(), + out_var_name, in_var_name); + CommitModify(swap_nodes, graph); + InplaceModifyDesc(out_var_name, in_var_name, idx); + } else { + VLOG(3) << string::Sprintf( + "Skiped pair %s => %s, inplace will generate a circle. withdraw %s", + out_var_name, in_var_name, op->Name()); + WithDrawModify(swap_nodes, graph); + } } } @@ -406,7 +473,28 @@ std::vector GraphView::PendingOpsOnVar(ir::Node* node) { return pending_ops; } -void GraphView::Build(ir::Graph* g) { ops_ = SortOpLikeDescOrder(*g); } +void GraphView::Build(ir::Graph* g) { + // track the var nodes in correct order. + // Because we insert some new created node. Which may have data race between + // nodes. + // resolve data harzards depends on the var nodes in right order. + ops_ = SortOpLikeDescOrder(*g); + + // track the nodes which reused previous node in Python memory optimize. + // these node can not be inplaced, otherwise may generate a circle in graph. 
diff --git a/paddle/fluid/framework/details/inplace_op_pass.h b/paddle/fluid/framework/details/inplace_op_pass.h
index c2b565a7435b92291124bd1ef483ce6d7b1d4097..cf1099323a9279bd85e373e789669b379c8a2916 100644
--- a/paddle/fluid/framework/details/inplace_op_pass.h
+++ b/paddle/fluid/framework/details/inplace_op_pass.h
@@ -15,6 +15,7 @@
 #pragma once
 #include <map>
 #include <string>
+#include <unordered_set>
 #include <utility>
 #include <vector>
 #include "paddle/fluid/framework/details/memory_optimize_helper.h"
@@ -40,10 +41,20 @@ class GraphView {
 
   bool OutConnectInputByCtrlVar(ir::Node* in_var, ir::Node* out_var);
 
+  // Will be deprecated in the future.
+  // NOTE(dzhwinter): Python memory optimize reuses
+  // memory based on the var name, so different op outputs may
+  // have the same variable name. Enabling inplace on such a node
+  // will generate a circle in the SSA graph.
+  bool ReusedInPythonMemOpt(const std::string& var) const;
+
  private:
   std::vector<ir::Node*> ops_;
+  std::unordered_set<std::string> dup_nodes_;  // nodes affected by mem opt
+  std::map<ir::Node*, std::unordered_set<ir::Node*>> adj_list_;
 };
 
+typedef std::unordered_map<ir::Node*, std::vector<ir::Node*>> SSANodeVector;
 class InplacePass : public ir::Pass {
  public:
   InplacePass();
@@ -58,6 +69,15 @@ class InplacePass : public ir::Pass {
   void InplaceModifyVar(const std::string& in_var, const std::string& out_var,
                         const size_t& idx, ir::Graph* graph) const;
 
+  const SSANodeVector TryInplaceModifyVar(const std::string& var,
+                                          const std::string& cache_var,
+                                          const size_t& idx,
+                                          ir::Graph* graph) const;
+
+  void CommitModify(const SSANodeVector&, ir::Graph* graph) const;
+
+  void WithDrawModify(const SSANodeVector& nodes, ir::Graph* graph) const;
+
   void InplaceModifyDesc(const std::string& in_var, const std::string& out_var,
                          const size_t& idx) const;
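The dup_nodes_ set declared above is filled by the duplicate-name scan in GraphView::Build shown earlier. Reduced to its essentials, with toy data (each inner vector stands for one op's output var names):

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

int main() {
  std::vector<std::vector<std::string>> op_outputs = {{"a"}, {"b"}, {"a"}};
  std::unordered_set<std::string> all_vars, dup_nodes;
  for (auto& outs : op_outputs) {
    for (auto& name : outs) {
      // A name written by two different ops marks a reused memory block.
      if (!all_vars.insert(name).second) dup_nodes.insert(name);
    }
  }
  for (auto& name : dup_nodes) std::cout << name << " reused" << std::endl;
}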
diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc
index 8de93cf285e4bf34c2d2bf425fa5f3459704b3d6..22d4c0a91cc1638264a8c57aa2841ff4e65a1400 100644
--- a/paddle/fluid/framework/ir/graph_helper.cc
+++ b/paddle/fluid/framework/ir/graph_helper.cc
@@ -52,16 +52,36 @@ bool HasCircleHelper(
     ir::Node *node,
     const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
     std::unordered_set<ir::Node *> *visited,
-    std::unordered_set<ir::Node *> *in_trace) {
+    std::unordered_set<ir::Node *> *in_trace,
+    std::vector<std::vector<ir::Node *>> *circles) {
   if (visited->find(node) == visited->end()) {
     visited->insert(node);
     in_trace->insert(node);
 
     for (ir::Node *in : adj_list.at(node)) {
       if (visited->find(in) == visited->end() &&
-          HasCircleHelper(in, adj_list, visited, in_trace)) {
+          HasCircleHelper(in, adj_list, visited, in_trace, circles)) {
         return true;
       } else if (in_trace->find(in) != in_trace->end()) {
+        if (circles != nullptr) {
+          // Best-effort reconstruction of the circle: starting from `in`,
+          // follow successors that are still on the DFS trace until the
+          // walk returns to `in` or gets stuck.
+          std::vector<ir::Node *> circle;
+          ir::Node *p = in;
+          do {
+            circle.emplace_back(p);
+            ir::Node *next = nullptr;
+            for (auto &adj : adj_list.at(p)) {
+              if (in_trace->count(adj) && adj != p) {
+                next = adj;
+                break;
+              }
+            }
+            if (next == nullptr) break;
+            p = next;
+          } while (p != in);
+          circles->emplace_back(circle);
+        }
         return true;
       }
     }
@@ -71,11 +91,12 @@ bool HasCircleHelper(
 }
 
 bool HasCircleInternal(
-    const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list) {
+    const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
+    std::vector<std::vector<ir::Node *>> *circles) {
   std::unordered_set<ir::Node *> visited;
   std::unordered_set<ir::Node *> in_trace;
   for (auto &adj : adj_list) {
-    if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace)) {
+    if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace, circles)) {
       return true;
     }
  }
@@ -84,13 +105,18 @@ bool HasCircleInternal(
 }
 }  // namespace
 
 bool HasCircle(const Graph &graph) {
-  return HasCircleInternal(BuildOperationAdjList(graph));
+  return HasCircleInternal(BuildOperationAdjList(graph), nullptr);
+}
+
+bool FindCircleSubGraph(const Graph &graph,
+                        std::vector<std::vector<ir::Node *>> *circles) {
+  return HasCircleInternal(BuildOperationAdjList(graph), circles);
 }
 
 std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
   std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
       BuildOperationAdjList(graph);
-  PADDLE_ENFORCE(!HasCircleInternal(adj_list));
+  PADDLE_ENFORCE(!HasCircleInternal(adj_list, nullptr));
   std::unordered_set<ir::Node *> visited;
   std::vector<ir::Node *> ret;
   for (auto adj : adj_list) {
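The reconstruction inside HasCircleHelper is best-effort because in_trace is an unordered set. Keeping the DFS stack as an ordered vector makes recovering the circle exact; here is a standalone sketch of that variant on a toy int graph (plain std C++, not Paddle's API):

#include <algorithm>
#include <iostream>
#include <map>
#include <set>
#include <vector>

static bool Dfs(int u, const std::map<int, std::set<int>>& adj,
                std::set<int>* visited, std::vector<int>* stack,
                std::vector<std::vector<int>>* circles) {
  visited->insert(u);
  stack->push_back(u);
  for (int v : adj.at(u)) {
    if (!visited->count(v)) {
      if (Dfs(v, adj, visited, stack, circles)) return true;
    } else {
      auto it = std::find(stack->begin(), stack->end(), v);
      if (it != stack->end()) {
        // The stack suffix starting at v is exactly the circle.
        circles->emplace_back(it, stack->end());
        return true;
      }
    }
  }
  stack->pop_back();
  return false;
}

int main() {
  std::map<int, std::set<int>> adj = {{0, {1}}, {1, {2}}, {2, {0}}};
  std::set<int> visited;
  std::vector<int> stack;
  std::vector<std::vector<int>> circles;
  for (auto& kv : adj) {
    if (!visited.count(kv.first) &&
        Dfs(kv.first, adj, &visited, &stack, &circles))
      break;
  }
  for (int n : circles.at(0)) std::cout << n << " ";  // prints: 0 1 2
  std::cout << std::endl;
}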
diff --git a/paddle/fluid/framework/ir/graph_helper.h b/paddle/fluid/framework/ir/graph_helper.h
index fba4936f2c5c971f6c63a452ec4480ff091db25c..214de9ec7d85aee6021b18866295777e317aa79d 100644
--- a/paddle/fluid/framework/ir/graph_helper.h
+++ b/paddle/fluid/framework/ir/graph_helper.h
@@ -28,6 +28,11 @@ namespace ir {
 // Test if the graph contains circle.
 bool HasCircle(const Graph &graph);
 
+// Find all circles, for debugging;
+// the nodes of each circle are stored as one subgraph in *circles.
+bool FindCircleSubGraph(const Graph &graph,
+                        std::vector<std::vector<ir::Node *>> *circles);
+
 size_t GraphNum(const Graph &graph);
 
 // Topology Sort the operations in the graph from inputs to outputs.
diff --git a/paddle/fluid/framework/ir/graph_helper_test.cc b/paddle/fluid/framework/ir/graph_helper_test.cc
index 260a73ae763bd2cdea9948e4d928377a7c718dda..8ea3dbbf241876ae8eea6ae2c5b144fca02f6615 100644
--- a/paddle/fluid/framework/ir/graph_helper_test.cc
+++ b/paddle/fluid/framework/ir/graph_helper_test.cc
@@ -195,6 +195,17 @@ void BuildTwoGraphs(Graph* g) {
   //  v4->outputs.push_back(o5);
 }
 
+TEST(GraphHelperTest, Circles) {
+  ProgramDesc prog;
+
+  Graph g(prog);
+  BuildCircleGraph(&g);
+
+  std::vector<std::vector<ir::Node *>> circles;
+  ASSERT_TRUE(FindCircleSubGraph(g, &circles));
+  ASSERT_EQ(circles.size(), 1UL);
+}
+
 TEST(GraphHelperTest, GraphNum) {
   ProgramDesc prog;
 
diff --git a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
index 5e5e6033d8e7301fee6685043f83492258f7e454..eaf2ebb62fdcb691017a6d521f80f880ddad751b 100644
--- a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
+++ b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py
@@ -32,7 +32,7 @@ class TestParallelExecutorBase(unittest.TestCase):
     def check_network_convergence(self,
                                   method,
                                   use_cuda=True,
-                                  memory_opt=True,
+                                  memory_opt=False,
                                   iter=50,
                                   batch_size=None,
                                   allow_op_delay=False,
@@ -67,8 +67,6 @@ class TestParallelExecutorBase(unittest.TestCase):
 
         if memory_opt:
             fluid.memory_optimize(main)
-        with open("program_model.txt", "w") as f:
-            f.write(str(main))
         place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
         exe = fluid.Executor(place)
         exe.run(startup)
@@ -82,9 +80,10 @@ class TestParallelExecutorBase(unittest.TestCase):
         build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
         build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv
         build_strategy.memory_optimize = use_ir_memory_optimize
-        build_strategy.enable_inplace = enable_inplace
+        # Python memory optimization conflicts with the inplace pass; use IR
+        # graph memory optimization after the inplace pass instead.
+        build_strategy.enable_inplace = False if memory_opt else enable_inplace
         build_strategy.enable_sequential_execution = enable_sequential_execution
-        build_strategy.debug_graphviz_path = "debug_ir_graph_"
 
         if use_cuda and core.is_compiled_with_cuda():
             build_strategy.remove_unnecessary_lock = True
diff --git a/python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py b/python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py
index 0c9cd99322dad3b26b8ba7614b83906e3fef79f8..b87407e31e41c7284f5eae21d349d4b5f066347b 100644
--- a/python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py
+++ b/python/paddle/fluid/tests/unittests/test_ir_inplace_pass.py
@@ -46,7 +46,10 @@ class TestIrInplace(TestParallelExecutorBase):
     def setUpClass(cls):
         os.environ['CPU_NUM'] = str(4)
 
-    def _fc_with_batchnorm(self, ir_memory_optimize, enable_inplace):
+    def _fc_with_batchnorm(self,
+                           ir_memory_optimize,
+                           enable_inplace,
+                           memory_opt=False):
         np.random.seed(5)
         img = np.random.random(size=[32, 784]).astype(np.float32)
         label = np.ones(shape=[32, 1], dtype='int64')
@@ -55,7 +58,7 @@ class TestIrInplace(TestParallelExecutorBase):
             feed_dict={"image": img,
                        "label": label},
             use_cuda=True,
-            memory_opt=False,  # inplace is conflict with memory opt
+            memory_opt=memory_opt,
             use_ir_memory_optimize=ir_memory_optimize,
             enable_inplace=enable_inplace)
 
@@ -67,3 +70,10 @@ class TestIrInplace(TestParallelExecutorBase):
         self.assertAlmostEqual(loss00, loss10, delta=delta)
         self.assertAlmostEqual(loss00, loss01, delta=delta)
         self.assertAlmostEqual(loss00, loss11, delta=delta)
+
+    def test_fc_with_batchnorm_memory_opt(self, delta=1e-3):
+        loss00 = self._fc_with_batchnorm(False, True, False)
+        loss10 = self._fc_with_batchnorm(False, True, True)
+        loss11 = self._fc_with_batchnorm(True, True, True)
+        self.assertAlmostEqual(loss00, loss10, delta=delta)
+        self.assertAlmostEqual(loss00, loss11, delta=delta)