提交 e537634d 编写于 作者: D dzhwinter

delete graph print pass. test=develop

上级 4f01de63
...@@ -52,8 +52,7 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s ...@@ -52,8 +52,7 @@ cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base s
cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper) cc_library(memory_optimize_helper SRCS memory_optimize_helper.cc DEPS graph graph_helper)
cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass) cc_library(memory_optimize_pass SRCS memory_optimize_pass.cc DEPS memory_optimize_helper pass)
cc_library(graph_print_pass SRCS graph_print_pass.cc DEPS graph_helper pass) cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info)
cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_info graph_print_pass)
cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper) cc_library(modify_op_lock_and_record_event_pass SRCS modify_op_lock_and_record_event_pass.cc DEPS computation_op_handle op_graph_view multi_devices_helper)
cc_library(memory_early_delete_pass SRCS memory_early_delete_pass.cc DEPS memory_optimize_pass computation_op_handle scale_loss_grad_op_handle rpc_op_handle cc_library(memory_early_delete_pass SRCS memory_early_delete_pass.cc DEPS memory_optimize_pass computation_op_handle scale_loss_grad_op_handle rpc_op_handle
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass) all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
...@@ -74,7 +73,6 @@ if (WITH_GPU) ...@@ -74,7 +73,6 @@ if (WITH_GPU)
endif() endif()
cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph) cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph)
cc_test(memory_optimize_pass_test SRCS memory_optimize_pass_test.cc memory_optimize_pass.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry pass) cc_test(memory_optimize_pass_test SRCS memory_optimize_pass_test.cc memory_optimize_pass.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry pass)
cc_test(graph_print_pass_test SRCS graph_print_pass_test.cc DEPS graph_print_pass framework_proto graph graph_helper op_registry pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS}) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ${SSA_GRAPH_EXECUTOR_DEPS})
...@@ -99,4 +97,4 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS ...@@ -99,4 +97,4 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
multi_devices_graph_print_pass multi_devices_graph_check_pass multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass multi_batch_merge_pass fuse_elewise_add_act_pass multi_batch_merge_pass
fuse_relu_depthwise_conv_pass fuse_relu_depthwise_conv_pass
memory_optimize_pass lock_free_optimize_pass graph_print_pass) memory_optimize_pass lock_free_optimize_pass)
...@@ -17,7 +17,6 @@ limitations under the License. */ ...@@ -17,7 +17,6 @@ limitations under the License. */
#include <glog/logging.h> #include <glog/logging.h>
#include <memory> #include <memory>
#include "paddle/fluid/framework/details/graph_print_pass.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h" #include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h" #include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
...@@ -233,9 +232,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -233,9 +232,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
if (graph->Has(kAllOpDescs)) { if (graph->Has(kAllOpDescs)) {
graph->Erase(kAllOpDescs); graph->Erase(kAllOpDescs);
} }
if (!graph->Has(kGraphviz)) {
graph->Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);
}
graph->Set<const std::vector<OpDesc *>>( graph->Set<const std::vector<OpDesc *>>(
kAllOpDescs, kAllOpDescs,
new std::vector<OpDesc *>(main_program.Block(0).AllOps())); new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
...@@ -245,10 +241,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -245,10 +241,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
"GPU, skipped."; "GPU, skipped.";
continue; continue;
} }
} else if (pass->Type() == "graph_print_path") {
if (!graph->Has(kGraphviz)) {
graph->Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);
}
} }
graph = pass->Apply(std::move(graph)); graph = pass->Apply(std::move(graph));
} }
...@@ -274,5 +266,4 @@ USE_PASS(all_reduce_deps_pass); ...@@ -274,5 +266,4 @@ USE_PASS(all_reduce_deps_pass);
USE_PASS(modify_op_lock_and_record_event_pass); USE_PASS(modify_op_lock_and_record_event_pass);
USE_PASS(inplace_pass); USE_PASS(inplace_pass);
USE_PASS(lock_free_optimize_pass); USE_PASS(lock_free_optimize_pass);
USE_PASS(graph_print_pass);
USE_PASS(graph_to_program_pass); USE_PASS(graph_to_program_pass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/graph_print_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace details {
// Graphviz wrapper for a variable node of the IR graph.
// Streams as one DOT node statement: var_<id> [label="<node name>"].
class GraphvizVar : public GraphvizNode {
 public:
  GraphvizVar(ir::Node* n, const int& i) : GraphvizNode(n, i) {}

  // Writes the DOT declaration of this variable to `sout` and returns `sout`.
  // The line is terminated with '\n' instead of std::endl: the emitted text is
  // the same, but the stream is no longer flushed once per node.
  friend std::ostream& operator<<(std::ostream& sout, const GraphvizVar& var) {
    sout << "var_" << var.id_ << " [label=\"" << var.node_->Name() << "\"]\n";
    return sout;
  }
};
// Graphviz wrapper for an operator node of the IR graph.
// Streams as a rectangular DOT node "op_<id>" followed by every edge line that
// was recorded through AddEdge / AddCustomEdge.
class GraphvizOp : public GraphvizNode {
 public:
  GraphvizOp(ir::Node* n, const int& i) : GraphvizNode(n, i) {}

  // Writes the node declaration and then the accumulated edge lines.
  // Lines end with '\n' rather than std::endl so the destination stream is not
  // flushed for every op; the emitted text is unchanged.
  friend std::ostream& operator<<(std::ostream& sout, const GraphvizOp& op) {
    sout << "op_" + std::to_string(op.id_) << " [label=\"" << op.node_->Name()
         << "\", shape=rect]\n";
    sout << op.stream_.str();
    return sout;
  }

  // Records one edge per input ("var_<k>->op_<id>") and one per output
  // ("op_<id>->var_<k>") of the wrapped node. `cb` maps a neighbouring node to
  // the number used in its "var_<k>" name. Each call appends to the edge
  // buffer; edges recorded earlier are kept.
  template <typename Callback>
  void AddEdge(const Callback& cb) {
    const std::string op_name = "op_" + std::to_string(id_);
    for (auto var : node_->inputs) {
      stream_ << "var_" << std::to_string(cb(var)) << "->" << op_name << '\n';
    }
    for (auto var : node_->outputs) {
      stream_ << op_name << "->" << "var_" << std::to_string(cb(var)) << '\n';
    }
  }

  // Appends the text produced by `cb()` as one more edge line.
  template <typename Callback>
  void AddCustomEdge(const Callback& cb) {
    stream_ << cb() << '\n';
  }

 private:
  // Edge lines collected so far, emitted after the node declaration.
  std::ostringstream stream_;
};
// Returns raw pointers to every element of `con` whose dynamic type is T (or
// derives from T), in the iteration order of `con`. Elements are expected to
// be smart pointers exposing get(); ownership stays with the container.
template <typename T, typename Container>
std::vector<T*> FilterByNodeWrapper(const Container& con) {
  std::vector<T*> matched;
  for (const auto& holder : con) {
    if (T* typed = dynamic_cast<T*>(holder.get())) {
      matched.push_back(typed);
    }
  }
  return matched;
}
// Rebuilds the Graphviz wrapper set stored on `graph` under kGraphviz and
// returns a map from each variable node to the id used in its GraphvizVar.
//
// Every variable node becomes a GraphvizVar and every op node a GraphvizOp;
// any other node kind is rejected with PADDLE_THROW. If
// ir::FindCircleSubGraph reports circles, each pair of consecutive entries of
// a circle gets an extra highlighted edge (red, penwidth 3.0) recorded on the
// first op of the pair.
std::unordered_map<ir::Node*, int> SSAGraphPrinterImpl::ToGraphvizNode(
    const ir::Graph& graph) const {
  // Start from an empty set so it only describes the current graph.
  auto& graphviz_nodes = graph.Get<GraphvizNodes>(kGraphviz);
  graphviz_nodes.clear();

  std::unordered_map<ir::Node*, int> vars;
  std::unordered_map<ir::Node*, GraphvizOp*> ops;
  int var_id = 0;
  int op_id = 0;
  for (auto& node : graph.Nodes()) {
    if (node->IsVar()) {
      // Own the wrapper before inserting it, so it cannot leak if the
      // insertion itself throws.
      std::unique_ptr<GraphvizNode> var(new GraphvizVar(node, var_id));
      graphviz_nodes.emplace(std::move(var));
      vars.emplace(node, var_id++);
    } else if (node->IsOp()) {
      std::unique_ptr<GraphvizOp> op(new GraphvizOp(node, op_id++));
      ops[node] = op.get();
      graphviz_nodes.emplace(std::move(op));
    } else {
      PADDLE_THROW("Unknown op type");
    }
  }

  // Detect circle. Draw circle in different lines
  std::vector<std::vector<ir::Node*>> circles;
  const std::string kCircleEdge = "[color=red,penwidth=3.0]";
  if (ir::FindCircleSubGraph(graph, &circles)) {
    VLOG(3) << "Graph has circle! circles count : " << circles.size();
    for (auto& circle : circles) {
      // `i + 1 < size()` rather than `i < size() - 1`: the unsigned
      // subtraction wraps around when a circle is empty.
      for (size_t i = 0; i + 1 < circle.size(); ++i) {
        // find() instead of operator[]: the latter would insert a null entry
        // for an unknown node, which would then be dereferenced below.
        auto prev_it = ops.find(circle[i]);
        auto next_it = ops.find(circle[i + 1]);
        if (prev_it == ops.end() || next_it == ops.end()) {
          PADDLE_THROW("Circle contains a node that is not a known op");
        }
        GraphvizOp* prev = prev_it->second;
        GraphvizOp* next = next_it->second;
        std::string prev_op = "op_" + std::to_string(prev->Id());
        std::string next_op = "op_" + std::to_string(next->Id());
        prev->AddCustomEdge([&]() -> std::string {
          return prev_op + "->" + next_op + kCircleEdge;
        });
      }
    }
  }
  return vars;
}
// Writes `graph` to `sout` as a DOT digraph named G: first every variable
// node, then every op node together with its edges.
void SSAGraphPrinterImpl::Print(const ir::Graph& graph,
                                std::ostream& sout) const {
  // Refreshes the wrapper set on the graph and yields the variable -> id map.
  const auto var_ids = ToGraphvizNode(graph);
  auto& wrappers = graph.Get<GraphvizNodes>(kGraphviz);

  sout << "digraph G {\n";

  const std::vector<GraphvizVar*> var_nodes =
      FilterByNodeWrapper<GraphvizVar>(wrappers);
  for (GraphvizVar* var_node : var_nodes) {
    sout << *var_node;
  }

  const std::vector<GraphvizOp*> op_nodes =
      FilterByNodeWrapper<GraphvizOp>(wrappers);
  for (GraphvizOp* op_node : op_nodes) {
    // Edges are resolved here because they need the ids of the variables.
    op_node->AddEdge(
        [&var_ids](ir::Node* neighbour) { return var_ids.at(neighbour); });
    sout << *op_node;
  }

  sout << "}\n";
}
// Dumps `graph` in DOT format to the file named by the kGraphvizPath pass
// attribute and hands the graph back to the caller. Enforces that the file
// stream is usable before printing.
std::unique_ptr<ir::Graph> SSAGraphPrintPass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  // A fresh printer is installed on every application of the pass.
  printer_.reset(new SSAGraphPrinterImpl());

  // The stream lives on the stack; it is closed when this function returns.
  std::ofstream dot_file(Get<std::string>(kGraphvizPath));
  PADDLE_ENFORCE(dot_file.good() == true, "Failed to open file.");

  printer_->Print(*graph, dot_file);
  return graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
// Registers SSAGraphPrintPass under the name graph_print_pass and declares
// kGraphvizPath as a required pass attribute; ApplyImpl reads that attribute
// as the path of the output file.
REGISTER_PASS(graph_print_pass, paddle::framework::details::SSAGraphPrintPass)
.RequirePassAttr(paddle::framework::details::kGraphvizPath);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fstream>
#include <memory>
#include <unordered_map>
#include "paddle/fluid/framework/details/multi_devices_helper.h"
namespace paddle {
namespace framework {
namespace details {
// Name of the pass attribute that holds the output file path; it is required
// by the graph_print_pass registration and read in SSAGraphPrintPass::ApplyImpl.
constexpr char kGraphvizPath[] = "debug_graphviz_path";
// Key of the graph attribute that stores the GraphvizNodes set.
constexpr char kGraphviz[] = "graphviz";
// NOTE(dzhwinter): If the graph contains circles,
// it cannot be topologically sorted.
// This printer prints the whole graph
// and highlights the circles, which is quite useful
// for debugging deadlocks and circles.
// Base wrapper pairing an IR node with the integer id used when the node is
// written out. Polymorphic so that concrete wrappers can be told apart with
// dynamic_cast.
class GraphvizNode {
 public:
  GraphvizNode(ir::Node* n, const int& i)
      : node_(n),
        id_(i) {}

  virtual ~GraphvizNode() = default;

  // Id supplied at construction.
  int Id() const {
    return id_;
  }

 protected:
  // Wrapped IR node; held as a raw pointer and never deleted by this class.
  ir::Node* node_;
  // Number that identifies this wrapper in the printed output.
  int id_;
};
// Owning set of Graphviz wrappers; stored on a graph under the kGraphviz key.
using GraphvizNodes = std::unordered_set<std::unique_ptr<GraphvizNode>>;
// Interface for objects that can write a textual form of an IR graph.
class SSAGraphPrinter {
 public:
  virtual ~SSAGraphPrinter() = default;

  // Writes `graph` to `sout`; the output format is chosen by the implementation.
  virtual void Print(const ir::Graph& graph, std::ostream& sout) const = 0;
};
// SSAGraphPrinter implementation that emits a Graphviz DOT digraph.
class SSAGraphPrinterImpl : public SSAGraphPrinter {
 public:
  // Writes `graph` to `sout` in DOT format.
  void Print(const ir::Graph& graph, std::ostream& sout) const override;

 private:
  // Rebuilds the wrapper set kept on `graph` under kGraphviz and returns the
  // id assigned to each variable node.
  std::unordered_map<ir::Node*, int> ToGraphvizNode(
      const ir::Graph& graph) const;
};
// IR pass that prints the graph it receives and returns it.
class SSAGraphPrintPass : public ir::Pass {
 protected:
  // Prints `graph` to the file named by the kGraphvizPath attribute.
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;

 private:
  // Mutable because ApplyImpl is const yet installs a new printer on each call.
  mutable std::unique_ptr<SSAGraphPrinter> printer_;
};
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/graph_print_pass.h"
#include "paddle/fluid/framework/details/graph_test_base.h"
// Register the "sum", "split" and "assign" operators; these are the op types
// that FillProgramDesc() below puts into the test program.
// NOTE(review): DummyOp, the *OpMaker classes and DummyVarTypeInference are
// presumably provided by graph_test_base.h — confirm.
REGISTER_OPERATOR(sum, paddle::framework::DummyOp,
paddle::framework::SumOpMaker);
REGISTER_OPERATOR(split, paddle::framework::DummyOp,
paddle::framework::SplitOpMaker);
REGISTER_OPERATOR(assign, paddle::framework::DummyOp,
paddle::framework::AssignOpMaker,
paddle::framework::DummyVarTypeInference);
/*
a @ b
c
d @ e
*/
using paddle::framework::ProgramDesc;
using paddle::framework::proto::VarType;
// Builds the program used by the Normal test: LOD_TENSOR variables a..e in
// block 0 and four ops appended in this order:
//   sum(a, b) -> c,  split(c) -> d, e,  sum(d, e) -> d,  assign(d) -> d.
inline static ProgramDesc FillProgramDesc() {
  ProgramDesc prog;

  for (const char* var_name : {"a", "b", "c", "d", "e"}) {
    prog.MutableBlock(0)->Var(var_name)->SetType(VarType::LOD_TENSOR);
  }

  // Appends one op to block 0 reading slot "X" and writing slot "Out".
  const auto append_op = [&prog](const std::string& type,
                                 const std::vector<std::string>& inputs,
                                 const std::vector<std::string>& outputs) {
    auto* op = prog.MutableBlock(0)->AppendOp();
    op->SetType(type);
    op->SetInput("X", inputs);
    op->SetOutput("Out", outputs);
  };

  append_op("sum", {"a", "b"}, {"c"});
  append_op("split", {"c"}, {"d", "e"});
  append_op("sum", {"d", "e"}, {"d"});
  append_op("assign", {"d"}, {"d"});

  return prog;
}
namespace paddle {
namespace framework {
namespace details {
// Prints the graph built from FillProgramDesc() to graph_print_pass.txt.
// The test only checks that the file opens and that printing completes; the
// file contents are not inspected.
TEST(SSAGraphPrinter, Normal) {
  const auto program = FillProgramDesc();
  ir::Graph graph(program);
  // The printer looks this attribute up, so it has to exist beforehand.
  graph.Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);

  // redirect debug graph to a file.
  constexpr char graph_path[] = "graph_print_pass.txt";
  std::ofstream fout(graph_path);
  PADDLE_ENFORCE(fout.good());

  SSAGraphPrinterImpl printer;
  printer.Print(graph, fout);
}
using ir::Graph;
using ir::Node;
// Adds one op and one variable to `g` and wires them into the smallest loop:
// op1 -> var1 -> op1.
void BuildCircleGraph(Graph* g) {
  ir::Node* op = g->CreateEmptyNode("op1", Node::Type::kOperation);
  ir::Node* var = g->CreateEmptyNode("var1", Node::Type::kVariable);

  // op1 writes var1 and also reads it.
  op->outputs.push_back(var);
  op->inputs.push_back(var);

  // Mirror the same two links on the variable side.
  var->inputs.push_back(op);
  var->outputs.push_back(op);
}
// Adds two ops and two variables to `g` forming the loop
// op1 -> var1 -> op2 -> var2 -> op1.
void BuildCircleGraph2(Graph* g) {
  ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
  ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
  ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
  ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);

  // `op` writes `var`.
  const auto produce = [](ir::Node* op, ir::Node* var) {
    op->outputs.push_back(var);
    var->inputs.push_back(op);
  };
  // `op` reads `var`.
  const auto consume = [](ir::Node* var, ir::Node* op) {
    op->inputs.push_back(var);
    var->outputs.push_back(op);
  };

  produce(o1, v1);
  consume(v1, o2);
  produce(o2, v2);
  consume(v2, o1);
}
// Adds five ops and four variables to `g` and links them as listed below.
// NOTE(review): the last link (o2->v3->o1) together with o1->v1->o2 closes a
// loop o1 -> v1 -> o2 -> v3 -> o1, which does not match the function name.
// No caller is visible in this file — confirm the intent.
void BuildNoCircleGraph(Graph* g) {
  ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
  ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
  ir::Node* o3 = g->CreateEmptyNode("op3", Node::Type::kOperation);
  ir::Node* o4 = g->CreateEmptyNode("op4", Node::Type::kOperation);
  ir::Node* o5 = g->CreateEmptyNode("op5", Node::Type::kOperation);
  ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
  ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);
  ir::Node* v3 = g->CreateEmptyNode("var3", Node::Type::kVariable);
  ir::Node* v4 = g->CreateEmptyNode("var4", Node::Type::kVariable);

  // `op` writes `var`.
  const auto produce = [](ir::Node* op, ir::Node* var) {
    op->outputs.push_back(var);
    var->inputs.push_back(op);
  };
  // `op` reads `var`.
  const auto consume = [](ir::Node* var, ir::Node* op) {
    op->inputs.push_back(var);
    var->outputs.push_back(op);
  };

  // o1->v1->o2
  produce(o1, v1);
  consume(v1, o2);

  // o2->v2->o3
  // o2->v2->o4
  produce(o2, v2);
  consume(v2, o3);
  consume(v2, o4);

  // o2->v3->o5
  produce(o2, v3);
  consume(v3, o5);

  // o3->v4->o5
  produce(o3, v4);
  consume(v4, o5);

  // o2->v3->o1
  consume(v3, o1);
}
// Prints the one-op loop from BuildCircleGraph to
// graph_print_pass_simple_circle.txt after asserting that the graph has a
// circle. The file contents are not inspected.
TEST(SSAGraphPrinter, SimpleCircle) {
  ProgramDesc prog;
  Graph graph(prog);
  BuildCircleGraph(&graph);
  ASSERT_TRUE(HasCircle(graph));

  // The printer looks this attribute up, so it has to exist beforehand.
  graph.Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);

  // redirect debug graph to a file.
  constexpr char graph_path[] = "graph_print_pass_simple_circle.txt";
  std::ofstream fout(graph_path);
  PADDLE_ENFORCE(fout.good());

  SSAGraphPrinterImpl printer;
  printer.Print(graph, fout);
}
// Prints the two-op loop from BuildCircleGraph2 to
// graph_print_pass_complex_circle.txt after asserting that the graph has a
// circle. The file contents are not inspected.
TEST(SSAGraphPrinter, ComplexCircle) {
  ProgramDesc prog;
  Graph graph(prog);
  BuildCircleGraph2(&graph);
  ASSERT_TRUE(HasCircle(graph));

  // The printer looks this attribute up, so it has to exist beforehand.
  graph.Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);

  // redirect debug graph to a file.
  constexpr char graph_path[] = "graph_print_pass_complex_circle.txt";
  std::ofstream fout(graph_path);
  PADDLE_ENFORCE(fout.good());

  SSAGraphPrinterImpl printer;
  printer.Print(graph, fout);
}
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/graph_print_pass.h"
#include "paddle/fluid/framework/details/memory_optimize_pass.h" #include "paddle/fluid/framework/details/memory_optimize_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
...@@ -114,24 +113,6 @@ static inline ir::Node* GetPrevCascadeInplacedVar(ir::Node* var) { ...@@ -114,24 +113,6 @@ static inline ir::Node* GetPrevCascadeInplacedVar(ir::Node* var) {
return input_it == prev_op->inputs.end() ? nullptr : *input_it; return input_it == prev_op->inputs.end() ? nullptr : *input_it;
} }
template <typename Container>
static inline bool ConnectByCtrlVar(const Container& group1,
const Container& group2) {
bool connected = false;
std::unordered_set<ir::Node*> outputs;
for (auto* op : group1) {
for (auto* var : op->outputs) {
if (var->IsCtrlVar()) outputs.emplace(var);
}
}
for (auto* op : group2) {
for (auto* var : op->inputs) {
if (outputs.count(var)) connected = true;
}
}
return connected;
}
InplacePass::InplacePass() : Pass() { InplacePass::InplacePass() : Pass() {
if (FLAGS_enable_inplace_whitelist) { if (FLAGS_enable_inplace_whitelist) {
for (auto& s : kInplacedOpWhiteList) { for (auto& s : kInplacedOpWhiteList) {
...@@ -316,18 +297,7 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op, ...@@ -316,18 +297,7 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
continue; continue;
} }
// 3. if output reuse input inplaced, the dependency group is not changed. // 3. if output has been memory optimize by python(fluid.memory_optmize()).
// For detail, check
// the function description in "OutConnectInputByCtrlVar"
if (view_.OutConnectInputByCtrlVar(in_node, out_node)) {
VLOG(4) << string::Sprintf(
"Skiped pair %s => %s. %s input and output connect by ctrl var."
"inplace such pair will generate a circle.",
out_var_name, in_var_name, op->Name());
continue;
}
// 4. if output has been memory optimize by python(fluid.memory_optmize()).
// this candidate can not be inplaced. Will be deprecated in the future. // this candidate can not be inplaced. Will be deprecated in the future.
if (view_.ReusedInPythonMemOpt(out_node->Name())) { if (view_.ReusedInPythonMemOpt(out_node->Name())) {
VLOG(4) << string::Sprintf( VLOG(4) << string::Sprintf(
...@@ -431,48 +401,6 @@ void GraphView::Build(ir::Graph* g) { ...@@ -431,48 +401,6 @@ void GraphView::Build(ir::Graph* g) {
const std::vector<ir::Node*> GraphView::AllOps() { return ops_; } const std::vector<ir::Node*> GraphView::AllOps() { return ops_; }
// Decides whether inplacing `out_var` onto `in_var` must be refused because
// the two are linked through a control-dependency variable.
//
// Assume v_a0 and v_a1 are variables; v_a0 -> v_a0 and v_a1 -> v_a1 mean
// "already inplaced", and the question is whether v_a0 -> v_a1 may be
// inplaced as well:
//
//   v_a0 -> v_a0 -> v_a1 -> v_a1
//
// The function gathers `out_var` plus everything reached by repeatedly
// applying GetNextCascadeInplacedVar, and `in_var` plus everything reached by
// repeatedly applying GetPrevCascadeInplacedVar, then returns
// ConnectByCtrlVar(in set, out set).
// NOTE(review): both sets hold variable nodes, while ConnectByCtrlVar walks
// the `outputs` / `inputs` of its elements and calls IsCtrlVar() on those
// neighbours — confirm that this is the intended level of the check.
bool GraphView::OutConnectInputByCtrlVar(ir::Node* in_var, ir::Node* out_var) {
  // Chain of vars that follow out_var through cascaded inplace.
  std::unordered_set<ir::Node*> out_var_set;
  for (ir::Node* cur = out_var; cur != nullptr;
       cur = GetNextCascadeInplacedVar(cur)) {
    out_var_set.emplace(cur);
  }

  // Chain of vars that precede in_var through cascaded inplace.
  std::unordered_set<ir::Node*> in_var_set;
  for (ir::Node* cur = in_var; cur != nullptr;
       cur = GetPrevCascadeInplacedVar(cur)) {
    in_var_set.emplace(cur);
  }

  // A control-dependency link from the input chain to the output chain means
  // the pair must not be inplaced.
  return ConnectByCtrlVar(in_var_set, out_var_set);
}
bool GraphView::ReusedInPythonMemOpt(const std::string& var) const { bool GraphView::ReusedInPythonMemOpt(const std::string& var) const {
return dup_nodes_.count(var); return dup_nodes_.count(var);
} }
......
...@@ -40,8 +40,6 @@ class GraphView { ...@@ -40,8 +40,6 @@ class GraphView {
std::vector<ir::Node*> PendingOpsOnVar(ir::Node* var); std::vector<ir::Node*> PendingOpsOnVar(ir::Node* var);
bool OutConnectInputByCtrlVar(ir::Node* in_var, ir::Node* out_var);
// Will Deperated in the future. // Will Deperated in the future.
// NOTE(dzhwinter) : Python memory optimize will reuse // NOTE(dzhwinter) : Python memory optimize will reuse
// memory based var name, so different op output may // memory based var name, so different op output may
......
...@@ -19,12 +19,20 @@ ...@@ -19,12 +19,20 @@
#include <iosfwd> #include <iosfwd>
#include <ostream> #include <ostream>
#include <string> #include <string>
#include "paddle/fluid/framework/details/graph_print_pass.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
constexpr char kGraphvizPath[] = "debug_graphviz_path";
// Interface for objects that can write a textual form of an IR graph.
class SSAGraphPrinter {
 public:
  virtual ~SSAGraphPrinter() = default;

  // Writes `graph` to `sout`; the output format is chosen by the implementation.
  virtual void Print(const ir::Graph& graph, std::ostream& sout) const = 0;
};
class GraphvizSSAGraphPrinter : public SSAGraphPrinter { class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
public: public:
void Print(const ir::Graph& graph, std::ostream& sout) const override; void Print(const ir::Graph& graph, std::ostream& sout) const override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册