diff --git a/paddle/fluid/lite/core/mir/generate_program_pass.cc b/paddle/fluid/lite/core/mir/generate_program_pass.cc index 0be516489db6a0a66a66db6e87cdf9dc875f75ba..659a959fec63d1dbdf5d49b22c74b93338378658 100644 --- a/paddle/fluid/lite/core/mir/generate_program_pass.cc +++ b/paddle/fluid/lite/core/mir/generate_program_pass.cc @@ -18,7 +18,16 @@ namespace paddle { namespace lite { namespace mir { -void GenerateProgramPass::Apply(std::unique_ptr &graph) {} + +void GenerateProgramPass::Apply(std::unique_ptr& graph) { + for (auto& item : graph->TopoloticalOrder()) { + if (item->IsInstruct()) { + auto& instruct = item->AsInstruct(); + kernels_.emplace_back(std::move(instruct.valid_kernels.front())); + } + } +} + } // namespace mir } // namespace lite } // namespace paddle diff --git a/paddle/fluid/lite/core/mir/generate_program_pass.h b/paddle/fluid/lite/core/mir/generate_program_pass.h index bc78370b08d8785586ab5bb4babd883c9377988c..0421f5f2adafdabbab3868771b3c4502765e6ad9 100644 --- a/paddle/fluid/lite/core/mir/generate_program_pass.h +++ b/paddle/fluid/lite/core/mir/generate_program_pass.h @@ -14,6 +14,8 @@ #pragma once +#include +#include "paddle/fluid/lite/core/kernel.h" #include "paddle/fluid/lite/core/mir/pass.h" namespace paddle { @@ -26,7 +28,12 @@ namespace mir { */ class GenerateProgramPass : public ProgramPass { public: - void Apply(std::unique_ptr &graph) override; + void Apply(std::unique_ptr& graph) override; + + std::list>& kernels() { return kernels_; } + + private: + std::list> kernels_; }; } // namespace mir diff --git a/paddle/fluid/lite/core/mir/ssa_graph.cc b/paddle/fluid/lite/core/mir/ssa_graph.cc index 54d570cab4b69331a7888463536eab331f2864b4..cfab5b35f01dc0306a746f05f2103f18398a66fa 100644 --- a/paddle/fluid/lite/core/mir/ssa_graph.cc +++ b/paddle/fluid/lite/core/mir/ssa_graph.cc @@ -13,3 +13,9 @@ // limitations under the License. #include "paddle/fluid/lite/core/mir/ssa_graph.h" + +namespace paddle { +namespace lite { +namespace mir {} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/ssa_graph.h b/paddle/fluid/lite/core/mir/ssa_graph.h index 6dee98c76df044aeb7d72178b109ee9408f5c990..95d2d2d98d26454d1448c574f51e24c68a8aa43b 100644 --- a/paddle/fluid/lite/core/mir/ssa_graph.h +++ b/paddle/fluid/lite/core/mir/ssa_graph.h @@ -16,6 +16,7 @@ #include #include +#include #include #include #include "paddle/fluid/lite/core/mir/node.h" @@ -81,7 +82,31 @@ class SSAGraph : GraphBase { } } - std::vector TopoloticalOrder() const; + void sort_utils(mir::Node *n, std::map &visited, + std::stack &stack) { + visited[n] = true; + for (auto &out : n->outlinks) { + if (!visited[out]) { + sort_utils(out, visited, stack); + } + } + } + + std::vector TopoloticalOrder() { + std::map visited; + std::stack stack; + std::vector res; + + for (auto &n : mutable_nodes()) { + if (!visited[&n]) sort_utils(&n, visited, stack); + } + + while (!stack.empty()) { + res.push_back(stack.top()); + stack.pop(); + } + return res; + } const std::list &nodes() const { return node_storage_; } std::list &mutable_nodes() { return node_storage_; }