From cae5a9314423fcd4e1b21e8befe4e022002c9b38 Mon Sep 17 00:00:00 2001 From: superjomn Date: Wed, 17 Apr 2019 20:32:26 +0800 Subject: [PATCH] add topological sort --- .../lite/core/mir/generate_program_pass.cc | 11 +++++++- .../lite/core/mir/generate_program_pass.h | 9 ++++++- paddle/fluid/lite/core/mir/ssa_graph.cc | 6 +++++ paddle/fluid/lite/core/mir/ssa_graph.h | 27 ++++++++++++++++++- 4 files changed, 50 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/lite/core/mir/generate_program_pass.cc b/paddle/fluid/lite/core/mir/generate_program_pass.cc index 0be516489db..659a959fec6 100644 --- a/paddle/fluid/lite/core/mir/generate_program_pass.cc +++ b/paddle/fluid/lite/core/mir/generate_program_pass.cc @@ -18,7 +18,16 @@ namespace paddle { namespace lite { namespace mir { -void GenerateProgramPass::Apply(std::unique_ptr &graph) {} + +void GenerateProgramPass::Apply(std::unique_ptr& graph) { + for (auto& item : graph->TopoloticalOrder()) { + if (item->IsInstruct()) { + auto& instruct = item->AsInstruct(); + kernels_.emplace_back(std::move(instruct.valid_kernels.front())); + } + } +} + } // namespace mir } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/core/mir/generate_program_pass.h b/paddle/fluid/lite/core/mir/generate_program_pass.h index bc78370b08d..0421f5f2ada 100644 --- a/paddle/fluid/lite/core/mir/generate_program_pass.h +++ b/paddle/fluid/lite/core/mir/generate_program_pass.h @@ -14,6 +14,8 @@ #pragma once +#include +#include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/core/mir/pass.h" namespace paddle { @@ -26,7 +28,12 @@ namespace mir { */ class GenerateProgramPass : public ProgramPass { public: - void Apply(std::unique_ptr &graph) override; + void Apply(std::unique_ptr& graph) override; + + std::list>& kernels() { return kernels_; } + + private: + std::list> kernels_; }; } // namespace mir diff --git a/paddle/fluid/lite/core/mir/ssa_graph.cc b/paddle/fluid/lite/core/mir/ssa_graph.cc index 54d570cab4b..cfab5b35f01 100644 --- a/paddle/fluid/lite/core/mir/ssa_graph.cc +++ b/paddle/fluid/lite/core/mir/ssa_graph.cc @@ -13,3 +13,9 @@ // limitations under the License. #include "paddle/fluid/lite/core/mir/ssa_graph.h" + +namespace paddle { +namespace lite { +namespace mir {} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/ssa_graph.h b/paddle/fluid/lite/core/mir/ssa_graph.h index 6dee98c76df..95d2d2d98d2 100644 --- a/paddle/fluid/lite/core/mir/ssa_graph.h +++ b/paddle/fluid/lite/core/mir/ssa_graph.h @@ -16,6 +16,7 @@ #include #include +#include #include #include #include "paddle/fluid/lite/core/mir/node.h" @@ -81,7 +82,31 @@ class SSAGraph : GraphBase { } } - std::vector TopoloticalOrder() const; + void sort_utils(mir::Node *n, std::map &visited, + std::stack &stack) { + visited[n] = true; + for (auto &out : n->outlinks) { + if (!visited[out]) { + sort_utils(out, visited, stack); + } + } + } + + std::vector TopoloticalOrder() { + std::map visited; + std::stack stack; + std::vector res; + + for (auto &n : mutable_nodes()) { + if (!visited[&n]) sort_utils(&n, visited, stack); + } + + while (!stack.empty()) { + res.push_back(stack.top()); + stack.pop(); + } + return res; + } const std::list &nodes() const { return node_storage_; } std::list &mutable_nodes() { return node_storage_; } -- GitLab