提交 c70b60dd 编写于 作者: Y Yu Yang

Make executor steal graph inside

上级 4c3361cd
......@@ -37,8 +37,9 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
}
}
void MultiDevSSAGraphBuilder::Build(const ProgramDesc &program,
SSAGraph *graph) const {
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const {
auto graph = new SSAGraph();
SSAGraph &result = *graph;
result.vars_.resize(places_.size());
......@@ -134,6 +135,8 @@ void MultiDevSSAGraphBuilder::Build(const ProgramDesc &program,
harzaeds need to be handled.
*/
PolishGraphToSupportDataHazards(&result);
return std::unique_ptr<SSAGraph>(graph);
}
} // namespace details
} // namespace framework
......
......@@ -32,7 +32,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs);
void Build(const ProgramDesc &program, SSAGraph *graph) const override;
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
private:
std::string loss_var_name_;
......
......@@ -18,6 +18,7 @@
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/place.h"
#include <memory>
#include <string>
namespace paddle {
......@@ -28,7 +29,7 @@ class SSAGraphBuilder {
public:
SSAGraphBuilder() {}
virtual ~SSAGraphBuilder() {}
virtual void Build(const ProgramDesc &program, SSAGraph *graph) const = 0;
virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
......
......@@ -34,16 +34,16 @@ class SSAGraphExecutor {
DISABLE_COPY_AND_ASSIGN(SSAGraphExecutor);
public:
explicit SSAGraphExecutor(SSAGraph *graph) : graph_(*graph) {}
// Steal graph inside
explicit SSAGraphExecutor(std::unique_ptr<SSAGraph> &&graph)
: graph_(std::move(graph)) {}
virtual ~SSAGraphExecutor() {}
virtual void Run(Scope *global_scope,
const std::vector<std::string> &fetch_tensors,
const std::string &fetch_list_name) = 0;
virtual FeedFetchList Run(const std::vector<std::string> &fetch_tensors) = 0;
protected:
SSAGraph &graph_;
std::unique_ptr<SSAGraph> graph_;
};
class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
......@@ -51,16 +51,17 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
ThreadedSSAGraphExecutor(size_t num_threads, bool use_event,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
SSAGraph *graph)
: SSAGraphExecutor(graph),
std::unique_ptr<SSAGraph> &&graph)
: SSAGraphExecutor(std::move(graph)),
pool_(num_threads >= 2 ? new ::ThreadPool(num_threads) : nullptr),
local_scopes_(local_scopes),
places_(places),
fetch_ctxs_(places),
use_event_(use_event) {}
void Run(Scope *global_scope, const std::vector<std::string> &fetch_tensors,
const std::string &fetch_list_name) override {
// Run a SSAGraph by a thread pool
// Use topological sort algorithm
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override {
std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_map<VarHandleBase *, std::atomic<bool>> pending_vars;
std::unordered_set<OpHandleBase *> ready_ops;
......@@ -74,18 +75,18 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
};
// Transform SSAGraph to pending_ops & pending_vars
for (auto &var_map : graph_.vars_) {
for (auto &var_map : graph_->vars_) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
InsertPendingVar(version_pair.second);
}
}
}
for (auto &var : graph_.dep_vars_) {
for (auto &var : graph_->dep_vars_) {
InsertPendingVar(*var);
}
for (auto &op : graph_.ops_) {
for (auto &op : graph_->ops_) {
if (op->inputs_.empty()) { // Special case, Op has no input.
ready_ops.insert(op.get());
} else {
......@@ -101,7 +102,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_.vars_) {
for (auto &var_map : graph_->vars_) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second);
......@@ -182,8 +183,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
fetch_op.WaitAndMergeCPUTensors();
}
*global_scope->Var(fetch_list_name)->GetMutable<FeedFetchList>() =
fetch_data;
return fetch_data;
}
~ThreadedSSAGraphExecutor() {}
......@@ -240,8 +240,6 @@ class ParallelExecutorPrivate {
std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
details::SSAGraph graph_;
std::unique_ptr<SSAGraphExecutor> executor_;
};
......@@ -274,10 +272,10 @@ ParallelExecutor::ParallelExecutor(
details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
params, member_->local_scopes_,
member_->nccl_ctxs_.get());
builder.Build(main_program, &member_->graph_);
auto graph = builder.Build(main_program);
member_->executor_.reset(new ThreadedSSAGraphExecutor(
num_threads, true, member_->local_scopes_, places, &member_->graph_));
num_threads, true, member_->local_scopes_, places, std::move(graph)));
// Step 3. Create vars in each scope;
for (auto *scope : member_->local_scopes_) {
......@@ -338,8 +336,9 @@ void ParallelExecutor::BuildNCCLCommunicator() const {
void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
const std::string &fetched_var_name) {
member_->executor_->Run(member_->global_scope_, fetch_tensors,
fetched_var_name);
auto fetch_data = member_->executor_->Run(fetch_tensors);
*member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
fetch_data;
}
} // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册