Commit c70b60dd authored by Yu Yang

Make executor steal graph inside

Parent 4c3361cd
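
The change makes graph ownership explicit: SSAGraphBuilder::Build now returns a std::unique_ptr<SSAGraph>, and SSAGraphExecutor takes that pointer by rvalue reference and keeps it, so the executor owns ("steals") the graph for its whole lifetime instead of borrowing one held by ParallelExecutorPrivate. Below is a minimal, self-contained sketch of this ownership-transfer pattern; Graph, GraphBuilder, and GraphExecutor are illustrative stand-ins, not the real Paddle classes.

// Sketch of the "steal the graph" pattern (illustrative names, not Paddle's).
#include <memory>
#include <utility>

struct Graph {};  // stand-in for details::SSAGraph

struct GraphBuilder {
  // The builder creates the graph and hands ownership to the caller.
  std::unique_ptr<Graph> Build() const {
    return std::unique_ptr<Graph>(new Graph());
  }
};

class GraphExecutor {
 public:
  // Steal the graph: after construction the caller no longer owns it.
  explicit GraphExecutor(std::unique_ptr<Graph> &&graph)
      : graph_(std::move(graph)) {}

 private:
  std::unique_ptr<Graph> graph_;
};

int main() {
  GraphBuilder builder;
  auto graph = builder.Build();
  GraphExecutor executor(std::move(graph));  // ownership moves into the executor
  return 0;
}
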
@@ -37,8 +37,9 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
   }
 }
-void MultiDevSSAGraphBuilder::Build(const ProgramDesc &program,
-                                    SSAGraph *graph) const {
+std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
+    const ProgramDesc &program) const {
+  auto graph = new SSAGraph();
   SSAGraph &result = *graph;
   result.vars_.resize(places_.size());
@@ -134,6 +135,8 @@ void MultiDevSSAGraphBuilder::Build(const ProgramDesc &program,
      harzaeds need to be handled.
    */
   PolishGraphToSupportDataHazards(&result);
+  return std::unique_ptr<SSAGraph>(graph);
 }
 }  // namespace details
 }  // namespace framework
...
@@ -32,7 +32,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
                          const std::vector<Scope *> &local_scopes,
                          platform::NCCLContextMap *nccl_ctxs);
-  void Build(const ProgramDesc &program, SSAGraph *graph) const override;
+  std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
  private:
   std::string loss_var_name_;
...
@@ -18,6 +18,7 @@
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/platform/place.h"
+#include <memory>
 #include <string>
 namespace paddle {
@@ -28,7 +29,7 @@ class SSAGraphBuilder {
  public:
   SSAGraphBuilder() {}
   virtual ~SSAGraphBuilder() {}
-  virtual void Build(const ProgramDesc &program, SSAGraph *graph) const = 0;
+  virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
   DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
...
@@ -34,16 +34,16 @@ class SSAGraphExecutor {
   DISABLE_COPY_AND_ASSIGN(SSAGraphExecutor);
  public:
-  explicit SSAGraphExecutor(SSAGraph *graph) : graph_(*graph) {}
+  // Steal graph inside
+  explicit SSAGraphExecutor(std::unique_ptr<SSAGraph> &&graph)
+      : graph_(std::move(graph)) {}
   virtual ~SSAGraphExecutor() {}
-  virtual void Run(Scope *global_scope,
-                   const std::vector<std::string> &fetch_tensors,
-                   const std::string &fetch_list_name) = 0;
+  virtual FeedFetchList Run(const std::vector<std::string> &fetch_tensors) = 0;
  protected:
-  SSAGraph &graph_;
+  std::unique_ptr<SSAGraph> graph_;
 };
 class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
@@ -51,16 +51,17 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
   ThreadedSSAGraphExecutor(size_t num_threads, bool use_event,
                            const std::vector<Scope *> &local_scopes,
                            const std::vector<platform::Place> &places,
-                           SSAGraph *graph)
-      : SSAGraphExecutor(graph),
+                           std::unique_ptr<SSAGraph> &&graph)
+      : SSAGraphExecutor(std::move(graph)),
         pool_(num_threads >= 2 ? new ::ThreadPool(num_threads) : nullptr),
         local_scopes_(local_scopes),
         places_(places),
         fetch_ctxs_(places),
         use_event_(use_event) {}
-  void Run(Scope *global_scope, const std::vector<std::string> &fetch_tensors,
-           const std::string &fetch_list_name) override {
+  // Run a SSAGraph by a thread pool
+  // Use topological sort algorithm
+  FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override {
     std::unordered_map<OpHandleBase *, size_t> pending_ops;
     std::unordered_map<VarHandleBase *, std::atomic<bool>> pending_vars;
     std::unordered_set<OpHandleBase *> ready_ops;
@@ -74,18 +75,18 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
     };
     // Transform SSAGraph to pending_ops & pending_vars
-    for (auto &var_map : graph_.vars_) {
+    for (auto &var_map : graph_->vars_) {
       for (auto &name_pair : var_map) {
         for (auto &version_pair : name_pair.second) {
           InsertPendingVar(version_pair.second);
         }
       }
     }
-    for (auto &var : graph_.dep_vars_) {
+    for (auto &var : graph_->dep_vars_) {
       InsertPendingVar(*var);
     }
-    for (auto &op : graph_.ops_) {
+    for (auto &op : graph_->ops_) {
       if (op->inputs_.empty()) {  // Special case, Op has no input.
         ready_ops.insert(op.get());
       } else {
@@ -101,7 +102,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
     std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
     for (auto &fetch_var_name : fetch_tensors) {
-      for (auto &var_map : graph_.vars_) {
+      for (auto &var_map : graph_->vars_) {
         auto it = var_map.find(fetch_var_name);
         if (it != var_map.end()) {
           fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second);
@@ -182,8 +183,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
       fetch_op.WaitAndMergeCPUTensors();
     }
-    *global_scope->Var(fetch_list_name)->GetMutable<FeedFetchList>() =
-        fetch_data;
+    return fetch_data;
   }
   ~ThreadedSSAGraphExecutor() {}
@@ -240,8 +240,6 @@ class ParallelExecutorPrivate {
   std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
-  details::SSAGraph graph_;
   std::unique_ptr<SSAGraphExecutor> executor_;
 };
@@ -274,10 +272,10 @@ ParallelExecutor::ParallelExecutor(
   details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
                                            params, member_->local_scopes_,
                                            member_->nccl_ctxs_.get());
-  builder.Build(main_program, &member_->graph_);
+  auto graph = builder.Build(main_program);
   member_->executor_.reset(new ThreadedSSAGraphExecutor(
-      num_threads, true, member_->local_scopes_, places, &member_->graph_));
+      num_threads, true, member_->local_scopes_, places, std::move(graph)));
   // Step 3. Create vars in each scope;
   for (auto *scope : member_->local_scopes_) {
@@ -338,8 +336,9 @@ void ParallelExecutor::BuildNCCLCommunicator() const {
 void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
                            const std::string &fetched_var_name) {
-  member_->executor_->Run(member_->global_scope_, fetch_tensors,
-                          fetched_var_name);
+  auto fetch_data = member_->executor_->Run(fetch_tensors);
+  *member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
+      fetch_data;
 }
 }  // namespace framework
...
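
The comments added to ThreadedSSAGraphExecutor::Run describe the execution strategy: transform the SSAGraph into pending_ops and pending_vars, seed ready_ops with ops that have no inputs, then run ops in topological order on a thread pool and return the fetched tensors as a FeedFetchList. The sketch below is a simplified, single-threaded illustration of that dependency-counting schedule; Op and the integer variable ids are assumed stand-ins, not the real OpHandleBase/VarHandleBase machinery.

// Simplified illustration of dependency-counting (topological) scheduling.
#include <cstdio>
#include <deque>
#include <string>
#include <unordered_map>
#include <vector>

struct Op {
  std::string name;
  std::vector<int> inputs;   // variable ids this op consumes
  std::vector<int> outputs;  // variable ids this op produces
};

void RunTopologically(std::vector<Op> &ops) {
  // pending_ops: how many inputs of each op are still not produced.
  std::unordered_map<Op *, size_t> pending_ops;
  std::unordered_map<int, std::vector<Op *>> consumers;
  std::deque<Op *> ready_ops;

  for (auto &op : ops) {
    pending_ops[&op] = op.inputs.size();
    for (int in : op.inputs) consumers[in].push_back(&op);
    if (op.inputs.empty()) ready_ops.push_back(&op);  // op with no input is ready
  }

  while (!ready_ops.empty()) {
    Op *op = ready_ops.front();
    ready_ops.pop_front();
    std::printf("run %s\n", op->name.c_str());  // real executor enqueues to a thread pool
    for (int out : op->outputs) {
      for (Op *next : consumers[out]) {
        if (--pending_ops[next] == 0) ready_ops.push_back(next);
      }
    }
  }
}

int main() {
  std::vector<Op> ops = {
      {"read", {}, {0}}, {"fc", {0}, {1}}, {"loss", {1}, {2}}};
  RunTopologically(ops);
  return 0;
}

In the real executor, pending variables are std::atomic<bool> flags so worker threads can mark outputs ready concurrently, and the fetched results are merged (fetch_op.WaitAndMergeCPUTensors) before Run returns them to ParallelExecutor::Run, which now writes the FeedFetchList into the global scope itself.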