/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/dot.h" #include "paddle/fluid/inference/analysis/node.h" namespace paddle { namespace inference { namespace analysis { // Ideally the inputs and outputs of this graph would be set manually // beforehand, but there must be a Pass that helps to prune the unnecessary ops that // do not contribute to the given targets, so in this pass, analyzing the graph to get the // inputs and outputs is OK.
void DataFlowGraph::Build() { inputs.clear(); outputs.clear(); std::unordered_set ins; std::unordered_set outs; for (auto &node : nodes.nodes()) { for (auto *in : node->inlinks) { ins.insert(in); } for (auto *out : node->outlinks) { outs.insert(out); } } // The nodes that in ins but not in outs is the graph's inputs // similarly, the nodes that in outs but not in ins is the graphs' outputs for (auto *in : ins) { if (!outs.count(in)) { inputs.push_back(in); } } for (auto *out : outs) { if (!outs.count(out)) { outputs.push_back(out); } } Clean(); } void DataFlowGraph::Clean() { for (auto &node : nodes.nodes()) { std::unordered_set inlinks_set(node->inlinks.begin(), node->inlinks.end()); std::unordered_set outlinks_set(node->outlinks.begin(), node->outlinks.end()); if (inlinks_set.size() < node->inlinks.size()) { LOG(INFO) << "Clean: node " << node->repr() << " prune duplicate inputs"; node->inlinks.assign(inlinks_set.begin(), inlinks_set.end()); } if (outlinks_set.size() < node->outlinks.size()) { LOG(INFO) << "Clean: node " << node->repr() << " prune duplicate inputs"; node->outlinks.assign(outlinks_set.begin(), outlinks_set.end()); } } } std::string DataFlowGraph::DotString() const { Dot dot; // Add nodes for (size_t i = 0; i < nodes.size(); i++) { const Node &node = nodes.Get(i); dot.AddNode(node.repr(), node.dot_attrs()); } // Add edges for (size_t i = 0; i < nodes.size(); i++) { const Node &node = nodes.Get(i); for (auto &in : node.inlinks) { dot.AddEdge(in->repr(), node.repr(), {}); } } return dot.Build(); } std::string DataFlowGraph::HumanReadableInfo(bool show_values, bool show_functions) const { std::stringstream values, functions; for (auto &n : nodes.nodes()) { if (show_values && n->IsValue()) { values << n->repr() << "\n"; } if (show_functions && n->IsFunction()) { functions << n->repr() << "\n"; } } return "Values:\n" + values.str() + "\n\n" + "Functions:\n" + functions.str(); } // // NodesBFSIterator // GraphTraits::NodesBFSIterator::NodesBFSIterator( 
const std::vector &source) : queue_(source.begin(), source.end()) {} // GraphTraits::NodesBFSIterator::NodesBFSIterator( // GraphTraits::NodesBFSIterator &&other) noexcept // : queue_(std::move(other.queue_)), // visited_(std::move(other.visited_)) {} GraphTraits::NodesBFSIterator::NodesBFSIterator( const GraphTraits::NodesBFSIterator &other) : queue_(other.queue_), visited_(other.visited_) {} Node &GraphTraits::NodesBFSIterator::operator*() { PADDLE_ENFORCE(!queue_.empty()); return *queue_.front(); } Node *GraphTraits::NodesBFSIterator::operator->() { PADDLE_ENFORCE(!queue_.empty()); return queue_.front(); } GraphTraits::NodesBFSIterator & GraphTraits::NodesBFSIterator::operator=( const GraphTraits::NodesBFSIterator &other) { queue_ = other.queue_; visited_ = other.visited_; return *this; } GraphTraits::NodesBFSIterator &GraphTraits::NodesBFSIterator::operator++() { PADDLE_ENFORCE(!queue_.empty()); auto *cur = queue_.front(); visited_.insert(cur); queue_.pop_front(); for (auto *output : cur->outlinks) { if (!visited_.count(output)) { queue_.push_back(output); visited_.insert(output); } } return *this; } bool GraphTraits::NodesBFSIterator::operator==( const GraphTraits::NodesBFSIterator &other) { if (queue_.empty()) return other.queue_.empty(); if ((!queue_.empty()) && (!other.queue_.empty())) { return queue_.front() == other.queue_.front() && visited_.size() == other.visited_.size(); // here need to check the // equality of queue and // visited. Just a light but week implementation. 
} return false; } // // NodesDFSIterator // GraphTraits::NodesDFSIterator::NodesDFSIterator( const std::vector &source) { for (auto *x : source) stack_.push(x); } // GraphTraits::NodesDFSIterator::NodesDFSIterator( // GraphTraits::NodesDFSIterator &&other) noexcept // : stack_(std::move(other.stack_)), // visited_(std::move(other.visited_)) {} GraphTraits::NodesDFSIterator::NodesDFSIterator( const GraphTraits::NodesDFSIterator &other) : stack_(other.stack_), visited_(other.visited_) {} Node &GraphTraits::NodesDFSIterator::operator*() { PADDLE_ENFORCE(!stack_.empty()); return *stack_.top(); } GraphTraits::NodesDFSIterator &GraphTraits::NodesDFSIterator::operator++() { if (stack_.empty()) return *this; visited_.insert(stack_.top()); auto *cur = stack_.top(); stack_.pop(); for (auto *x : cur->outlinks) { if (!visited_.count(x)) { stack_.push(x); visited_.insert(x); } } return *this; } bool GraphTraits::NodesDFSIterator::operator==( const GraphTraits::NodesDFSIterator &other) { if (stack_.empty()) return other.stack_.empty(); if ((!stack_.empty()) && (!other.stack_.empty())) { return stack_.top() == other.stack_.top(); } return false; } GraphTraits::NodesDFSIterator & GraphTraits::NodesDFSIterator::operator=( const GraphTraits::NodesDFSIterator &other) { stack_ = other.stack_; visited_ = other.visited_; return *this; } Node *GraphTraits::NodesDFSIterator::operator->() { return stack_.top(); } GraphTraits::NodesTSIterator::NodesTSIterator( const std::vector &source) { PADDLE_ENFORCE(!source.empty(), "Start points of topological sorting should not be empty!"); std::unordered_set visited; std::unordered_set to_visit{source.begin(), source.end()}; std::vector inlink_visited; while (!to_visit.empty()) { std::vector queue(to_visit.begin(), to_visit.end()); for (auto *p : queue) { inlink_visited.clear(); std::copy_if(p->inlinks.begin(), p->inlinks.end(), std::back_inserter(inlink_visited), [&](Node *x) { return visited.count(x); }); if (inlink_visited.size() == 
p->inlinks.size()) { sorted_.push_back(p); for (auto *_ : p->outlinks) { if (!visited.count(_)) { to_visit.insert(_); } } to_visit.erase(p); visited.insert(p); } } } } GraphTraits::NodesTSIterator::NodesTSIterator( const paddle::inference::analysis::GraphTraits< DataFlowGraph>::NodesTSIterator &other) : sorted_(other.sorted_), cursor_(other.cursor_) {} Node &GraphTraits::NodesTSIterator::operator*() { PADDLE_ENFORCE_LT(cursor_, sorted_.size()); return *sorted_[cursor_]; } paddle::inference::analysis::GraphTraits::NodesTSIterator &GraphTraits::NodesTSIterator::operator++() { if (++cursor_ >= sorted_.size()) { sorted_.clear(); cursor_ = 0; } return *this; } paddle::inference::analysis::GraphTraits::NodesTSIterator & GraphTraits::NodesTSIterator::operator=( const paddle::inference::analysis::GraphTraits< DataFlowGraph>::NodesTSIterator &other) { cursor_ = other.cursor_; sorted_ = other.sorted_; return *this; } bool GraphTraits::NodesTSIterator::operator==( const paddle::inference::analysis::GraphTraits< DataFlowGraph>::NodesTSIterator &other) { return sorted_ == other.sorted_ && cursor_ == other.cursor_; } Node *GraphTraits::NodesTSIterator::operator->() { PADDLE_ENFORCE_LT(cursor_, sorted_.size()); return sorted_[cursor_]; } } // namespace analysis } // namespace inference } // namespace paddle