Commit e3144393 authored by Yu Yang

Extract Executors into independent modules

Parent c70b60dd
paddle/fluid/framework/CMakeLists.txt
@@ -89,8 +89,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
 framework_proto backward glog lod_rank_table feed_fetch_method)
-cc_library(parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope
-        backward glog lod_rank_table simple_threadpool multi_devices_graph_builder fetch_op_handle)
+cc_library(parallel_executor SRCS parallel_executor.cc DEPS multi_devices_graph_builder threaded_ssa_graph_executor)
 cc_library(prune SRCS prune.cc DEPS framework_proto)
 cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
...
paddle/fluid/framework/details/CMakeLists.txt
@@ -10,3 +10,6 @@ cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
 cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
 cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
         nccl_all_reduce_op_handle scale_loss_grad_op_handle)
+cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph)
+cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
+        simple_threadpool device_context)
paddle/fluid/framework/details/ssa_graph_executor.cc (new file)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
namespace paddle {
namespace framework {
namespace details {
SSAGraphExecutor::SSAGraphExecutor(std::unique_ptr<SSAGraph> &&graph)
: graph_(std::move(graph)) {}
SSAGraphExecutor::~SSAGraphExecutor() {}
} // namespace details
} // namespace framework
} // namespace paddle
paddle/fluid/framework/details/ssa_graph_executor.h (new file)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/details/ssa_graph.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
namespace paddle {
namespace framework {
namespace details {
class SSAGraphExecutor {
DISABLE_COPY_AND_ASSIGN(SSAGraphExecutor);
public:
  // Takes ownership of ("steals") the graph.
explicit SSAGraphExecutor(std::unique_ptr<SSAGraph> &&graph);
virtual ~SSAGraphExecutor();
virtual FeedFetchList Run(const std::vector<std::string> &fetch_tensors) = 0;
protected:
std::unique_ptr<SSAGraph> graph_;
};
} // namespace details
} // namespace framework
} // namespace paddle
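To make the ownership contract of this interface concrete, here is a small standalone sketch, not part of this commit and not Paddle code: Graph, Executor, SequentialExecutor, and main are illustrative stand-ins. It shows the same pattern the header sets up: the base class takes the graph as an rvalue unique_ptr and "steals" it, so a concrete executor only has to implement Run(). ThreadedSSAGraphExecutor below is the real subclass.

// Standalone sketch of the "steal the graph" ownership pattern (illustrative only).
#include <cassert>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Graph {
  std::vector<std::string> ops;  // stand-in for the SSA graph contents
};

class Executor {
 public:
  // Steal the graph: after construction the caller's pointer is null.
  explicit Executor(std::unique_ptr<Graph> &&graph) : graph_(std::move(graph)) {}
  virtual ~Executor() = default;
  virtual void Run() = 0;

 protected:
  std::unique_ptr<Graph> graph_;
};

class SequentialExecutor : public Executor {
 public:
  using Executor::Executor;
  void Run() override {
    for (const auto &op : graph_->ops) std::cout << "run " << op << "\n";
  }
};

int main() {
  auto graph = std::make_unique<Graph>(Graph{{"fc", "relu", "softmax"}});
  SequentialExecutor exec(std::move(graph));
  assert(graph == nullptr);  // ownership was transferred to the executor
  exec.Run();
}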
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc (new file)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
namespace details {
ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
size_t num_threads, bool use_event,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::unique_ptr<SSAGraph> &&graph)
: SSAGraphExecutor(std::move(graph)),
pool_(num_threads >= 2 ? new ::ThreadPool(num_threads) : nullptr),
local_scopes_(local_scopes),
places_(places),
fetch_ctxs_(places),
use_event_(use_event) {}
FeedFetchList ThreadedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_map<VarHandleBase *, std::atomic<bool>> pending_vars;
std::unordered_set<OpHandleBase *> ready_ops;
auto InsertPendingVar = [&pending_vars](VarHandleBase &var) {
pending_vars[&var] = var.generated_op_ == nullptr;
};
auto InsertPendingOp = [&pending_ops](OpHandleBase &op_instance) {
pending_ops.insert({&op_instance, op_instance.inputs_.size()});
};
  // Step 1. Transform the SSAGraph into pending_ops and pending_vars.
for (auto &var_map : graph_->vars_) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
InsertPendingVar(version_pair.second);
}
}
}
for (auto &var : graph_->dep_vars_) {
InsertPendingVar(*var);
}
for (auto &op : graph_->ops_) {
if (op->inputs_.empty()) { // Special case, Op has no input.
ready_ops.insert(op.get());
} else {
InsertPendingOp(*op);
}
}
// Step 2. Insert FetchOps
std::vector<FetchOpHandle> fetch_ops;
std::vector<DummyVarHandle> dummy_vars;
FeedFetchList fetch_data(fetch_tensors.size());
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->vars_) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second);
}
}
}
for (size_t i = 0; i < fetch_tensors.size(); ++i) {
auto &var_name = fetch_tensors[i];
auto &vars = fetched_vars[var_name];
fetch_ops.emplace_back(&fetch_data, i, &local_scopes_);
details::FetchOpHandle *op = &fetch_ops.back();
// FIXME: Use new device context
for (auto &p : places_) {
op->dev_ctx_[p] = fetch_ctxs_.Get(p);
}
for (auto *var : vars) {
op->AddInput(var);
}
dummy_vars.emplace_back();
auto *var = &dummy_vars.back();
var->generated_op_ = nullptr;
op->AddOutput(var);
InsertPendingVar(*var);
InsertPendingOp(*op);
}
auto run_all_ready_ops = [&] {
for (auto *op : ready_ops) {
RunOp(pending_vars, op);
}
ready_ops.clear();
};
// Step 3. Execution
while (!pending_vars.empty()) {
// 1. Run All Ready ops
run_all_ready_ops();
// 2. Find ready variable
VarHandleBase *ready_var = nullptr;
for (auto &pair : pending_vars) {
if (pair.second.load(std::memory_order_acquire)) {
ready_var = pair.first;
break;
}
}
// if there is no variable ready
if (ready_var == nullptr) {
      // FIXME: use a condition variable instead of busy waiting.
// if there is an exception, throw it
if (exception_) {
throw * exception_;
}
      // keep waiting for ready variables
continue;
}
// 3. Remove the dependency of ready_var.
// Find the ready_ops after the ready_var.
pending_vars.erase(ready_var);
for (auto *op : ready_var->pending_ops_) {
auto &deps = pending_ops[op];
--deps;
if (deps == 0) {
ready_ops.insert(op);
}
}
    // Keep looping until all variables are ready.
}
  // Wait for FetchOps to finish.
for (auto &fetch_op : fetch_ops) {
fetch_op.WaitAndMergeCPUTensors();
}
return fetch_data;
}
void ThreadedSSAGraphExecutor::RunOp(
std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars,
details::OpHandleBase *op) {
std::vector<std::atomic<bool> *> *ready_buffer =
new std::vector<std::atomic<bool> *>();
for (auto *var : op->outputs_) {
ready_buffer->emplace_back(&pending_vars[var]);
}
auto op_run = [ready_buffer, op, this] {
try {
VLOG(10) << op->DebugString();
op->Run(use_event_);
for (auto *ready : *ready_buffer) {
ready->store(true, std::memory_order_release);
}
delete ready_buffer;
    } catch (platform::EnforceNotMet &ex) {
      exception_.reset(new platform::EnforceNotMet(ex));
    } catch (...) {
      LOG(FATAL) << "Unknown exception caught";
}
};
if (pool_) {
pool_->enqueue(op_run);
} else {
op_run();
}
}
} // namespace details
} // namespace framework
} // namespace paddle
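Run() above is a dependency-counting scheduler: every variable carries an atomic "ready" flag, every op tracks how many of its input variables are still pending, and the main thread busy-polls for ready variables (the busy wait the FIXME refers to), releasing consumer ops once all their inputs are available. The following is a minimal, standalone model of that loop, not Paddle code: Var, Op, RunGraph, and the tiny graph in main are illustrative stand-ins, the fetch ops, dummy dependency variables, and exception forwarding are left out, and std::async stands in for the thread pool.

// Standalone model of the dependency-counting scheduling loop (illustrative only).
#include <atomic>
#include <future>
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

struct Var;

struct Op {
  std::string name;
  std::vector<Var *> inputs, outputs;
  void Run();
};

struct Var {
  std::string name;
  Op *generated_op = nullptr;     // producer; null for graph inputs
  std::vector<Op *> pending_ops;  // consumers
  std::atomic<bool> ready{false};
};

void Op::Run() {
  std::cout << "run " << name << std::endl;
  // Mark every output as ready so the scheduler can release its consumers.
  for (Var *v : outputs) v->ready.store(true, std::memory_order_release);
}

void RunGraph(const std::vector<Op *> &ops, const std::vector<Var *> &vars) {
  std::unordered_map<Op *, size_t> pending_ops;  // op -> unmet input count
  std::unordered_set<Var *> pending_vars;        // vars not yet consumed by the scheduler
  std::unordered_set<Op *> ready_ops;            // ops whose inputs are all ready
  std::vector<std::future<void>> futures;

  for (Var *v : vars) {
    pending_vars.insert(v);
    if (v->generated_op == nullptr) v->ready.store(true);  // graph inputs
  }
  for (Op *op : ops) {
    if (op->inputs.empty()) {
      ready_ops.insert(op);
    } else {
      pending_ops[op] = op->inputs.size();
    }
  }

  while (!pending_vars.empty()) {
    // 1. Launch every op that became ready (the thread-pool enqueue in Paddle).
    for (Op *op : ready_ops)
      futures.push_back(std::async(std::launch::async, [op] { op->Run(); }));
    ready_ops.clear();

    // 2. Busy-wait for some pending variable to become ready.
    Var *ready_var = nullptr;
    for (Var *v : pending_vars) {
      if (v->ready.load(std::memory_order_acquire)) {
        ready_var = v;
        break;
      }
    }
    if (ready_var == nullptr) continue;

    // 3. Consume it: each consumer op loses one unmet dependency.
    pending_vars.erase(ready_var);
    for (Op *op : ready_var->pending_ops)
      if (--pending_ops[op] == 0) ready_ops.insert(op);
  }
  for (auto &f : futures) f.get();  // join all launched ops
}

int main() {
  // Tiny graph: A -> x -> B -> y
  Var x{"x"}, y{"y"};
  Op a{"A", {}, {&x}}, b{"B", {&x}, {&y}};
  x.generated_op = &a;
  x.pending_ops = {&b};
  y.generated_op = &b;
  RunGraph({&a, &b}, {&x, &y});
}

As in the real executor, only the main thread touches the bookkeeping maps; worker threads merely flip the atomic ready flags, which is why the acquire/release pair is sufficient synchronization.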
paddle/fluid/framework/details/threaded_ssa_graph_executor.h (new file)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <atomic>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "ThreadPool.h"  // ThreadPool from third party
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
class Scope;
namespace details {
class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
public:
ThreadedSSAGraphExecutor(size_t num_threads, bool use_event,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::unique_ptr<SSAGraph> &&graph);
  // Run an SSAGraph with a thread pool, scheduling ops in topological order.
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
~ThreadedSSAGraphExecutor() {}
private:
void RunOp(
std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars,
details::OpHandleBase *op);
private:
std::unique_ptr<::ThreadPool> pool_;
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
platform::DeviceContextPool fetch_ctxs_;
const bool use_event_;
std::unique_ptr<platform::EnforceNotMet> exception_;
};
} // namespace details
} // namespace framework
} // namespace paddle
paddle/fluid/framework/parallel_executor.cc
@@ -13,221 +13,17 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/framework/parallel_executor.h"
 #include "ThreadPool.h"
-#include "lod_tensor.h"
-#include "op_registry.h"
-#include "paddle/fluid/framework/details/fetch_op_handle.h"
-#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
-#include "paddle/fluid/framework/details/ssa_graph.h"
 #include "paddle/fluid/platform/nccl_helper.h"
+#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
+#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
 namespace paddle {
 namespace framework {
-using details::DummyVarHandle;
-using details::FetchOpHandle;
-using details::OpHandleBase;
-using details::SSAGraph;
-using details::VarHandleBase;
[removed here: the inline SSAGraphExecutor and ThreadedSSAGraphExecutor definitions, which duplicate the code extracted into details/ssa_graph_executor.{h,cc} and details/threaded_ssa_graph_executor.{h,cc} shown above]
 class ParallelExecutorPrivate {
  public:
  explicit ParallelExecutorPrivate(const std::vector<platform::Place> &places)
...
@@ -239,8 +35,7 @@ class ParallelExecutorPrivate {
   Scope *global_scope_;
   std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
-  std::unique_ptr<SSAGraphExecutor> executor_;
+  std::unique_ptr<details::SSAGraphExecutor> executor_;
 };
 ParallelExecutor::ParallelExecutor(
...
@@ -274,7 +69,7 @@ ParallelExecutor::ParallelExecutor(
       member_->nccl_ctxs_.get());
   auto graph = builder.Build(main_program);
-  member_->executor_.reset(new ThreadedSSAGraphExecutor(
+  member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
       num_threads, true, member_->local_scopes_, places, std::move(graph)));
   // Step 3. Create vars in each scope;
...