提交 cae5a931 编写于 作者: S superjomn

add topological sort

上级 aefca7f1
...@@ -18,7 +18,16 @@ ...@@ -18,7 +18,16 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace mir { namespace mir {
void GenerateProgramPass::Apply(std::unique_ptr<mir::SSAGraph> &graph) {}
void GenerateProgramPass::Apply(std::unique_ptr<mir::SSAGraph>& graph) {
for (auto& item : graph->TopoloticalOrder()) {
if (item->IsInstruct()) {
auto& instruct = item->AsInstruct();
kernels_.emplace_back(std::move(instruct.valid_kernels.front()));
}
}
}
} // namespace mir } // namespace mir
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#pragma once #pragma once
#include <list>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/mir/pass.h" #include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle { namespace paddle {
...@@ -26,7 +28,12 @@ namespace mir { ...@@ -26,7 +28,12 @@ namespace mir {
*/ */
class GenerateProgramPass : public ProgramPass { class GenerateProgramPass : public ProgramPass {
public: public:
void Apply(std::unique_ptr<mir::SSAGraph> &graph) override; void Apply(std::unique_ptr<mir::SSAGraph>& graph) override;
std::list<std::unique_ptr<KernelBase>>& kernels() { return kernels_; }
private:
std::list<std::unique_ptr<KernelBase>> kernels_;
}; };
} // namespace mir } // namespace mir
......
...@@ -13,3 +13,9 @@ ...@@ -13,3 +13,9 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/core/mir/ssa_graph.h" #include "paddle/fluid/lite/core/mir/ssa_graph.h"
namespace paddle {
namespace lite {
namespace mir {} // namespace mir
} // namespace lite
} // namespace paddle
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <list> #include <list>
#include <map> #include <map>
#include <stack>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/lite/core/mir/node.h" #include "paddle/fluid/lite/core/mir/node.h"
...@@ -81,7 +82,31 @@ class SSAGraph : GraphBase { ...@@ -81,7 +82,31 @@ class SSAGraph : GraphBase {
} }
} }
std::vector<mir::Node *> TopoloticalOrder() const; void sort_utils(mir::Node *n, std::map<mir::Node *, bool> &visited,
std::stack<mir::Node *> &stack) {
visited[n] = true;
for (auto &out : n->outlinks) {
if (!visited[out]) {
sort_utils(out, visited, stack);
}
}
}
std::vector<mir::Node *> TopoloticalOrder() {
std::map<mir::Node *, bool> visited;
std::stack<mir::Node *> stack;
std::vector<mir::Node *> res;
for (auto &n : mutable_nodes()) {
if (!visited[&n]) sort_utils(&n, visited, stack);
}
while (!stack.empty()) {
res.push_back(stack.top());
stack.pop();
}
return res;
}
const std::list<mir::Node> &nodes() const { return node_storage_; } const std::list<mir::Node> &nodes() const { return node_storage_; }
std::list<mir::Node> &mutable_nodes() { return node_storage_; } std::list<mir::Node> &mutable_nodes() { return node_storage_; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册