// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/framework/details/graph_print_pass.h" #include #include #include "paddle/fluid/framework/ir/graph_helper.h" namespace paddle { namespace framework { namespace details { class GraphvizVar : public GraphvizNode { public: GraphvizVar(ir::Node* n, const int& i) : GraphvizNode(n, i) {} friend std::ostream& operator<<(std::ostream& sout, const GraphvizVar& var) { sout << "var_" << var.id_ << " [label=\"" << var.node_->Name() << "\"]" << std::endl; return sout; } }; class GraphvizOp : public GraphvizNode { public: GraphvizOp(ir::Node* n, const int& i) : GraphvizNode(n, i) {} friend std::ostream& operator<<(std::ostream& sout, const GraphvizOp& op) { sout << "op_" + std::to_string(op.id_) << " [label=\"" << op.node_->Name() << "\", shape=rect]" << std::endl; PADDLE_ENFORCE(op.stream_.rdbuf()->in_avail() != 0, "No inputs outputs. Please call AddEdge first!"); sout << op.stream_.str(); return sout; } template void AddEdge(const Callback& cb) { std::string op_name = "op_" + std::to_string(id_); for (auto var : node_->inputs) { std::string var_name = "var_" + std::to_string(cb(var)); stream_ << var_name << "->" << op_name << std::endl; } for (auto var : node_->outputs) { std::string var_name = "var_" + std::to_string(cb(var)); stream_ << op_name << "->" << var_name << std::endl; } } template void AddCustomEdge(const Callback& cb) { stream_ << cb() << std::endl; } private: std::ostringstream stream_; }; template std::vector FilterByNodeWrapper(const Container& con) { std::vector ret; for (auto& node : con) { auto i = dynamic_cast(node.get()); if (i != nullptr) ret.emplace_back(i); } return ret; } // bool DetectCircleRecursive(const std::map>, std::unordered_set* visited, // std::unordered_set *in_trace, std::vector>* // circles) { // if (visited->find(node) == visited->end()) { // visited->insert(node); // in_trace->insert(node); // for (ir::Node *in : adj_list.at(node)) { // if (visited->find(in) == visited->end() && // HasCircleHelper(in, adj_list, visited, in_trace)) { // return true; // } else if (in_trace->find(in) != in_trace->end()) { // circles->push_back(in_trace); // return true; // } // } // } // in_trace->erase(node); // return false; // } // bool DetectCircle(const std::map>& // adj_list, std::vector>* circles) { // std::unordered_set visited; // std::unordered_set in_trace; // bool has_circle = false; // for(auto& adj : adj_list) { // has_circle &= DetectCircleRecursive(adj, adj_list,&visited, &in_trace, // circles); // } // return has_circle; // } std::unordered_map SSAGraphPrinterImpl::ToGraphvizNode( const ir::Graph& graph) const { // Convert to GraphvizNode format auto& graphviz_nodes = graph.Get(kGraphviz); graphviz_nodes.clear(); std::unordered_map vars; std::unordered_map ops; int var_id = 0; int op_id = 0; for (auto& node : graph.Nodes()) { if (node->IsVar()) { graphviz_nodes.emplace(new GraphvizVar(node, var_id)); vars.emplace(std::make_pair(node, var_id++)); } else if (node->IsOp()) { std::unique_ptr op(new GraphvizOp(node, op_id++)); ops[node] = op.get(); graphviz_nodes.emplace(std::move(op)); // graphviz_nodes.emplace(new GraphvizOp(node, op_id++)); // ops.emplace(std::make_pair(node, graphviz_nodes.back().get())); } else { PADDLE_THROW("Unknown op type"); } } // Detect circle. Draw circle in different lines std::vector> circles; const std::string kCircleEdge = "[color=red,penwidth=3.0]"; if (ir::FindCircleSubGraph(graph, &circles)) { VLOG(3) << "Graph has circle! circles count : " << circles.size(); for (auto& circle : circles) { for (size_t i = 0; i < circle.size() - 1; ++i) { GraphvizOp* prev = ops[circle[i]]; GraphvizOp* next = ops[circle[i + 1]]; std::string prev_op = "op_" + std::to_string(prev->Id()); std::string next_op = "op_" + std::to_string(next->Id()); prev->AddCustomEdge([&]() -> std::string { return prev_op + "->" + next_op + kCircleEdge; }); } } } return vars; } void SSAGraphPrinterImpl::Print(const ir::Graph& graph, std::ostream& sout) const { auto vars = ToGraphvizNode(graph); auto& nodes = graph.Get(kGraphviz); sout << "digraph G {\n"; for (auto& var : FilterByNodeWrapper(nodes)) { sout << *var; } for (auto& op : FilterByNodeWrapper(nodes)) { op->AddEdge([&vars](ir::Node* var) { return vars.at(var); }); sout << *op; } sout << "}\n"; } std::unique_ptr SSAGraphPrintPass::ApplyImpl( std::unique_ptr graph) const { printer_.reset(new SSAGraphPrinterImpl()); std::unique_ptr fout( new std::ofstream(Get(kGraphvizPath))); PADDLE_ENFORCE(fout->good() == true, "Failed to open file."); printer_->Print(*graph, *fout); return graph; } } // namespace details } // namespace framework } // namespace paddle REGISTER_PASS(graph_print_pass, paddle::framework::details::SSAGraphPrintPass) .RequirePassAttr(paddle::framework::details::kGraphvizPath);