提交 9605fcd1 编写于 作者: X Xin Pan

all graphs

上级 af79b192
......@@ -21,7 +21,6 @@
namespace paddle {
namespace framework {
namespace details {
struct SSAGraph;
class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
public:
......
......@@ -21,7 +21,7 @@
namespace paddle {
namespace framework {
namespace details {
struct SSAGraph;
class SSAGraphPrinter {
public:
virtual ~SSAGraphPrinter() {}
......
......@@ -14,13 +14,14 @@
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
namespace paddle {
namespace framework {
namespace details {
ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::unique_ptr<SSAGraph> &&graph)
const std::vector<platform::Place> &places, std::unique_ptr<Graph> &&graph)
: graph_(std::move(graph)),
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
: nullptr),
......@@ -43,18 +44,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
std::unordered_set<OpHandleBase *> delayed_ops;
// Transform SSAGraph to pending_ops & pending_vars
for (auto &var_map : graph_->vars_) {
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
InsertPendingVar(&pending_vars, &ready_vars, version_pair.get());
}
}
}
for (auto &var : graph_->dep_vars_) {
for (auto &var : graph_->Get<details::GraphDepVars>("dep_vars")) {
InsertPendingVar(&pending_vars, &ready_vars, var.get());
}
for (auto &op : graph_->ops_) {
for (auto &op : graph_->Get<details::GraphOps>("ops")) {
if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op.get());
} else {
......@@ -158,7 +159,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->vars_) {
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
......
......@@ -27,6 +27,7 @@
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
......@@ -39,7 +40,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::unique_ptr<SSAGraph> &&graph);
std::unique_ptr<Graph> &&graph);
// Run a SSAGraph by a thread pool
// Use topological sort algorithm
......@@ -52,7 +53,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
details::OpHandleBase *op);
private:
std::unique_ptr<SSAGraph> graph_;
std::unique_ptr<Graph> graph_;
std::unique_ptr<::ThreadPool> pool_;
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
......
......@@ -135,14 +135,8 @@ ParallelExecutor::ParallelExecutor(
builder_ = builder_factory.Create();
std::unique_ptr<Graph> graph = builder_->Build(ProgramToGraph(main_program));
std::unique_ptr<details::SSAGraph> ssa_graph(new details::SSAGraph);
ssa_graph->vars_ = std::move(graph->Get<details::GraphVars>("vars"));
ssa_graph->ops_ = std::move(graph->Get<details::GraphOps>("ops"));
ssa_graph->dep_vars_ =
std::move(graph->Get<details::GraphDepVars>("dep_vars"));
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(ssa_graph)));
exec_strategy, member_->local_scopes_, places, std::move(graph)));
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册