Unverified commit e336dc86, authored by chengduo, committed by GitHub

[Speed] Refine the Executor when the num_thread=1 (#17405)
Parent 30e178fa
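The heart of this change: when the executor runs with a single scheduling thread, the order in which operators execute is deterministic, so the first iteration can record that order (traced_ops_) and later iterations can simply replay it, bypassing the dependency-counting scheduler. The standalone C++ sketch below illustrates only that record-and-replay pattern; MiniExecutor, Op, and RunDependencyDriven are hypothetical names, and the real executors in the diff use atomic counters and thread pools rather than this simplified loop.

// Minimal sketch of trace-and-replay execution, assuming a deterministic
// single-threaded schedule. Not Paddle code; names are illustrative.
#include <functional>
#include <iostream>
#include <utility>
#include <vector>

struct Op {
  const char *name;
  std::function<void()> body;
};

class MiniExecutor {
 public:
  MiniExecutor(std::vector<Op *> ops, int num_threads)
      : ops_(std::move(ops)), num_threads_(num_threads) {}

  void Run() {
    // Fast path: one thread and a complete trace from a prior iteration.
    if (num_threads_ == 1 && traced_ops_.size() == ops_.size()) {
      std::cout << "replaying traced ops\n";
      for (Op *op : traced_ops_) op->body();
      return;
    }
    traced_ops_.clear();  // trace is absent or stale; rebuild it
    RunDependencyDriven();
  }

 private:
  // Stand-in for the dependency-counting scheduler: a real executor pops
  // ready ops from a queue as their not-ready-input counters reach zero.
  void RunDependencyDriven() {
    for (Op *op : ops_) {
      std::cout << "scheduling " << op->name << "\n";
      op->body();
      if (num_threads_ == 1) traced_ops_.push_back(op);  // record the order
    }
  }

  std::vector<Op *> ops_;
  int num_threads_;
  std::vector<Op *> traced_ops_;
};

int main() {
  Op a{"a", [] { std::cout << "run a\n"; }};
  Op b{"b", [] { std::cout << "run b\n"; }};
  MiniExecutor exec({&a, &b}, /*num_threads=*/1);
  exec.Run();  // first iteration: schedules and records the trace
  exec.Run();  // second iteration: replays the recorded order
}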
paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc:

@@ -43,7 +43,7 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
       bootstrap_ops_.emplace_back(op);
     }
   }
+  PADDLE_ENFORCE_GT(op_deps_.size(), 0, "The graph doesn't have operators.");
   PrepareAtomicOpDeps();
 }

@@ -52,53 +52,31 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
   std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>
       op_deps = atomic_op_deps_.get();
   PrepareAtomicOpDeps();
+  size_t num_ops = op_deps->size();

   paddle::framework::FeedFetchList fetches;
   fetches.resize(fetch_tensors.size());
   std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
-  std::vector<FetchOpHandle *> fetch_ops;
+  std::vector<OpHandleBase *> fetch_ops;
   std::vector<OpHandleBase *> ready_fetch_ops;
+  exception_.Clear();
-  for (auto &fetch_var_name : fetch_tensors) {
-    for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
-      auto it = var_map.find(fetch_var_name);
-      if (it != var_map.end()) {
-        fetched_vars[fetch_var_name].push_back(*it->second.rbegin());
-      }
-    }
-  }
-  for (size_t i = 0; i < fetch_tensors.size(); ++i) {
-    auto &var_name = fetch_tensors[i];
-    auto fetched_var_it = fetched_vars.find(var_name);
-    PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(),
-                   "Cannot find fetched variable(%s).(Perhaps the main_program "
-                   "is not set to ParallelExecutor)",
-                   var_name);
-    auto &vars = fetched_var_it->second;
-    ir::Node *fetch_node =
-        graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
-    auto *op = new FetchOpHandle(fetch_node, &fetches, i, &local_scopes_);
-    fetch_ops.emplace_back(op);
-    for (auto &p : places_) {
-      op->SetDeviceContext(p, fetch_ctxs_.Get(p));
-    }
-    for (auto *var : vars) {
-      op->AddInput(var);
-    }
-    int dep = static_cast<int>(op->NotReadyInputSize());
-    (*op_deps)[op] = dep;
-    if (dep == 0) {
-      ready_fetch_ops.emplace_back(op);
-    }
-  }
-  size_t num_complete = 0;
+  InsertFetchOps(fetch_tensors, &fetches, &fetched_vars, op_deps.get(),
+                 &fetch_ops, &ready_fetch_ops);
+
+  if (strategy_.num_threads_ == 1 && traced_ops_.size() == num_ops) {
+    // If the num_threads is 1, we can record the order of operator's
+    // execution in the first iteration, and in subsequent iterations,
+    // run the recorded operators directly. This strategy could make the
+    // execution faster.
+    VLOG(3) << "Run the traced ops.";
+    RunTracedOps(traced_ops_);
+    RunTracedOps(fetch_ops);
+    if (exception_.IsCaught()) {
+      ExecutionFinal(&fetch_ops);
+    }
+  } else {
+    traced_ops_.clear();
   remaining_ = 0;
   auto complete_q = std::make_shared<BlockingQueue<size_t>>();
   for (auto op : bootstrap_ops_) {

@@ -107,6 +85,8 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
   for (auto op : ready_fetch_ops) {
     RunOpAsync(op_deps.get(), op, complete_q);
   }
+
+  size_t num_complete = 0;
   while (num_complete != op_deps->size()) {
     size_t num_comp = complete_q->Pop();
     if (num_comp == -1UL) {

@@ -121,28 +101,74 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
        }
      }
      if (exception_.IsCaught()) {
-       ClearFetchOp(graph_, &fetch_ops);
-       exception_.ReThrow();
+       ExecutionFinal(&fetch_ops);
      }
    }
    num_complete += num_comp;
  }
+  }
   // Wait FetchOps.
   ClearFetchOp(graph_, &fetch_ops);
   return fetches;
 }

+void FastThreadedSSAGraphExecutor::InsertFetchOps(
+    const std::vector<std::string> &fetch_tensors, FeedFetchList *fetches,
+    std::unordered_map<std::string, std::vector<VarHandleBase *>> *fetched_vars,
+    std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
+    std::vector<OpHandleBase *> *fetch_ops,
+    std::vector<OpHandleBase *> *ready_fetch_ops) {
+  for (auto &fetch_var_name : fetch_tensors) {
+    for (auto &var_map : graph_->Get<GraphVars>(kGraphVars)) {
+      auto it = var_map.find(fetch_var_name);
+      if (it != var_map.end()) {
+        (*fetched_vars)[fetch_var_name].push_back(*it->second.rbegin());
+      }
+    }
+  }
+
+  for (size_t i = 0; i < fetch_tensors.size(); ++i) {
+    auto &var_name = fetch_tensors.at(i);
+    auto fetched_var_it = fetched_vars->find(var_name);
+    PADDLE_ENFORCE(fetched_var_it != fetched_vars->end(),
+                   "Cannot find fetched variable(%s).(Perhaps the main_program "
+                   "is not set to ParallelExecutor)",
+                   var_name);
+    auto &vars = fetched_var_it->second;
+    ir::Node *fetch_node =
+        graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
+    auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_);
+    fetch_ops->emplace_back(op);
+    for (auto &p : places_) {
+      op->SetDeviceContext(p, fetch_ctxs_.Get(p));
+    }
+    for (auto *var : vars) {
+      op->AddInput(var);
+    }
+    int dep = static_cast<int>(op->NotReadyInputSize());
+    (*op_deps)[op] = dep;
+    if (dep == 0) {
+      ready_fetch_ops->emplace_back(op);
+    }
+  }
+}
+
 bool FastThreadedSSAGraphExecutor::RunOp(
     OpHandleBase *op, const std::shared_ptr<BlockingQueue<size_t>> &complete_q,
     size_t *complete) {
-  try {
+  RunOpSync(op);
+  if (LIKELY(!exception_.IsCaught())) {
     if (LIKELY(!strategy_.dry_run_)) {
-      op->Run(strategy_.use_cuda_);
+      RecordOps(op);
     }
     ++(*complete);
     return true;
-  } catch (...) {
-    exception_.Catch(std::current_exception());
+  } else {
     --remaining_;
     complete_q->Push(-1UL);
     return false;

@@ -194,6 +220,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
     complete_q->Push(complete);
   });
 }
+
 void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
   atomic_op_deps_ = prepare_pool_.enqueue([&] {
     auto *op_deps = new std::unordered_map<OpHandleBase *, std::atomic<int>>;

@@ -206,6 +233,44 @@ void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
 }

 const ir::Graph &FastThreadedSSAGraphExecutor::Graph() const { return *graph_; }

+void FastThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) {
+  if (strategy_.num_threads_ == 1 && !dynamic_cast<FetchOpHandle *>(op)) {
+    traced_ops_.emplace_back(op);
+  }
+}
+
+void FastThreadedSSAGraphExecutor::ExecutionFinal(
+    std::vector<OpHandleBase *> *fetch_ops) {
+  VLOG(3) << "caught exception " << exception_.Type() << ", rethrow it";
+  ClearFetchOp(graph_, fetch_ops);
+  exception_.ReThrow();
+}
+
+void FastThreadedSSAGraphExecutor::RunTracedOps(
+    const std::vector<OpHandleBase *> &traced_ops) {
+  for (auto &op : traced_ops) {
+    if (exception_.IsCaught()) {
+      return;
+    }
+    RunOpSync(op);
+  }
+}
+
+void FastThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
+  try {
+    if (VLOG_IS_ON(10)) {
+      VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
+    }
+    if (LIKELY(!strategy_.dry_run_)) {
+      op->Run(strategy_.use_cuda_);
+    }
+    VLOG(10) << op << " " << op->Name() << " Done ";
+  } catch (...) {
+    exception_.Catch(std::current_exception());
+  }
+}
+
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
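A detail worth noting in Run() above: the dependency map consumed by this iteration was prepared during the previous one (atomic_op_deps_.get()), and PrepareAtomicOpDeps() immediately enqueues preparation of the next map on prepare_pool_, so map construction overlaps with execution. Below is a minimal sketch of that prepare-ahead pattern, using std::async in place of Paddle's ThreadPool; DepPreparer and DepMap are hypothetical names.

// Sketch of the PrepareAtomicOpDeps double-buffering pattern, under the
// assumption that building the map is independent of running the graph.
#include <future>
#include <iostream>
#include <memory>
#include <unordered_map>

using DepMap = std::unordered_map<int, int>;  // op id -> unready input count

class DepPreparer {
 public:
  DepPreparer() { Prepare(); }  // mirrors PrepareAtomicOpDeps() in the ctor

  std::unique_ptr<DepMap> Take() {
    auto deps = next_.get();  // block until the prepared map is ready
    Prepare();                // immediately start building the next one
    return deps;
  }

 private:
  void Prepare() {
    next_ = std::async(std::launch::async, [] {
      auto deps = std::make_unique<DepMap>();
      (*deps)[0] = 0;  // placeholder for real per-op dependency counts
      return deps;
    });
  }

  std::future<std::unique_ptr<DepMap>> next_;
};

int main() {
  DepPreparer preparer;
  for (int iter = 0; iter < 3; ++iter) {
    auto deps = preparer.Take();  // ready now; the next map is already building
    std::cout << "iteration " << iter << " got " << deps->size() << " deps\n";
  }
}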
paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h:

@@ -60,6 +60,8 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
   ::ThreadPool pool_;
   ::ThreadPool prepare_pool_;
+  std::vector<OpHandleBase *> traced_ops_;
+
   bool RunOp(OpHandleBase *op,
              const std::shared_ptr<BlockingQueue<size_t>> &complete_q,
              size_t *complete);

@@ -69,6 +71,22 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
       const std::shared_ptr<BlockingQueue<size_t>> &complete_q);
   void PrepareAtomicOpDeps();
+
+  inline void RecordOps(OpHandleBase *op);
+
+  inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops);
+
+  inline void RunOpSync(OpHandleBase *op);
+
+  void RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
+
+  void InsertFetchOps(
+      const std::vector<std::string> &fetch_tensors, FeedFetchList *fetches,
+      std::unordered_map<std::string, std::vector<VarHandleBase *>>
+          *fetched_vars,
+      std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
+      std::vector<OpHandleBase *> *fetch_ops,
+      std::vector<OpHandleBase *> *ready_fetch_ops);
 };
 }  // namespace details
 }  // namespace framework
paddle/fluid/framework/details/ssa_graph_executor.cc:

@@ -19,10 +19,13 @@ namespace framework {
 namespace details {
 SSAGraphExecutor::~SSAGraphExecutor() {}

-void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops) {
+void ClearFetchOp(ir::Graph* graph, std::vector<OpHandleBase*>* fetch_ops) {
   if (fetch_ops->empty()) return;

   for (auto& op : *fetch_ops) {
+    PADDLE_ENFORCE_NOT_NULL(
+        dynamic_cast<FetchOpHandle*>(op),
+        "The input ops of ClearFetchOp function should be FetchOpHandle.");
     for (auto& out_var : op->Node()->outputs) {
       graph->RemoveNode(out_var);
     }
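Because fetch_ops now holds OpHandleBase* rather than FetchOpHandle*, the guard above verifies at runtime that every element really is a fetch handle. Below is a self-contained sketch of the same dynamic_cast check, with assert standing in for PADDLE_ENFORCE_NOT_NULL and all types reduced to hypothetical stubs.

// Sketch of a runtime downcast guard: dynamic_cast yields nullptr when the
// object is not actually the derived type, so it doubles as a type check.
#include <cassert>
#include <vector>

struct OpHandleBase { virtual ~OpHandleBase() = default; };
struct FetchOpHandle : OpHandleBase {};
struct ComputeOpHandle : OpHandleBase {};  // would trip the guard below

void ClearFetchOps(const std::vector<OpHandleBase*>& fetch_ops) {
  for (auto* op : fetch_ops) {
    // Mirrors the PADDLE_ENFORCE_NOT_NULL guard in the diff above.
    assert(dynamic_cast<FetchOpHandle*>(op) != nullptr);
    // ... release fetch-related graph nodes here ...
  }
}

int main() {
  FetchOpHandle fetch;
  std::vector<OpHandleBase*> ops{&fetch};
  ClearFetchOps(ops);  // passes; inserting a ComputeOpHandle would assert
}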
paddle/fluid/framework/details/ssa_graph_executor.h:

@@ -38,7 +38,7 @@ class SSAGraphExecutor {
   virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0;
 };

-void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops);
+void ClearFetchOp(ir::Graph* graph, std::vector<OpHandleBase*>* fetch_ops);
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc:

@@ -53,27 +53,40 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
       new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare"));
   std::unique_ptr<OpDependentData> op_deps = op_deps_futures_.get();
   CopyOpDeps();
+
   VLOG(10) << "ThreadedSSAGraphExecutor::Run";
   std::shared_ptr<BlockingQueue<VarHandleBase *>> ready_vars(
       new BlockingQueue<VarHandleBase *>);
   auto &pending_ops = op_deps->pending_ops_;
   auto &pending_vars = op_deps->pending_vars_;
   auto &ready_ops = op_deps->ready_ops_;
+  size_t num_ops = op_deps->num_ops_;
-  // For ops (e.g. nccl_all_reduce) that need to coordinate multiple
-  // streams from multiple GPUs, it's faster to buffer them and schedule
-  // together since we currently cannot overlap computation and memcpy streams.
-  // Should revisit it if overlapping is available.
-  std::unordered_set<OpHandleBase *> delayed_ops;

   // Step 2. Insert FetchOps
-  std::vector<FetchOpHandle *> fetch_ops;
+  std::vector<OpHandleBase *> fetch_ops;
   std::unordered_set<VarHandleBase *> fetch_dependencies;
   FeedFetchList fetch_data(fetch_tensors.size());

   InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &ready_ops,
                  &pending_ops, &pending_vars, &fetch_data);
+
+  exception_holder_.Clear();
+  event.reset(nullptr);
+
+  // Step 3. Execution
+  if (strategy_.num_threads_ == 1 && traced_ops_.size() == num_ops) {
+    // If the num_threads is 1, we can record the order of operator's
+    // execution in the first iteration, and in subsequent iterations,
+    // run the recorded operators directly. This strategy could make the
+    // execution faster.
+    VLOG(3) << "Run the traced ops.";
+    RunTracedOps(traced_ops_);
+    RunTracedOps(fetch_ops);
+    if (exception_holder_.IsCaught()) {
+      ExecutionFinal(&fetch_ops);
+    }
+  } else {
+    traced_ops_.clear();
   auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
     for (auto *op : set) {
       RunOp(ready_vars, op);

@@ -82,9 +95,7 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
   };

   // Clean run context
   run_op_futures_.clear();
-  exception_holder_.Clear();
-  event.reset(nullptr);

-  // Step 3. Execution
   while (!pending_vars.empty()) {
     // 1. Run All Ready ops
     // Keep loop until all vars are ready.

@@ -94,14 +105,11 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
     bool timeout;
     auto cur_ready_vars = ready_vars->PopAll(1, &timeout);
     if (timeout) {
-      if (exception_holder_.IsCaught()) {
-        VLOG(3) << "caught exception " << exception_holder_.Type()
-                << ", rethrow it";
       for (auto &run_op_future : run_op_futures_) {
         run_op_future.wait();
       }
-        ClearFetchOp(graph_, &fetch_ops);
-        exception_holder_.ReThrow();
+      if (exception_holder_.IsCaught()) {
+        ExecutionFinal(&fetch_ops);
       } else {
         continue;
       }

@@ -121,6 +129,8 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
     }
   }
   PADDLE_ENFORCE(ready_ops.empty());
+  }
+
   // Wait FetchOps.
   ClearFetchOp(graph_, &fetch_ops);

@@ -137,7 +147,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
 void ThreadedSSAGraphExecutor::InsertFetchOps(
     const std::vector<std::string> &fetch_tensors,
-    std::vector<FetchOpHandle *> *fetch_ops,
+    std::vector<OpHandleBase *> *fetch_ops,
     std::unordered_set<VarHandleBase *> *fetch_dependencies,
     std::unordered_set<OpHandleBase *> *ready_ops,
     std::unordered_map<OpHandleBase *, size_t> *pending_ops,

@@ -243,6 +253,9 @@ void ThreadedSSAGraphExecutor::PrepareOpDeps() {
       InsertPendingOp(&pending_ops, op);
     }
   }
+  op_deps_->num_ops_ = ready_ops.size() + pending_ops.size();
+  PADDLE_ENFORCE_GT(op_deps_->num_ops_, 0, "The graph doesn't have operators.");
+
   for (auto ready_var : ready_vars) {
     pending_vars.erase(ready_var);
     for (auto *op : ready_var->PendingOps()) {

@@ -264,6 +277,7 @@ void ThreadedSSAGraphExecutor::CopyOpDeps() {
                                   op_deps_->pending_vars_.end());
     op_deps->ready_ops_.insert(op_deps_->ready_ops_.begin(),
                                op_deps_->ready_ops_.end());
+    op_deps->num_ops_ = op_deps_->num_ops_;
     return std::unique_ptr<OpDependentData>(op_deps);
   });
 }

@@ -272,25 +286,59 @@ void ThreadedSSAGraphExecutor::RunOp(
     const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q,
     details::OpHandleBase *op) {
   auto op_run = [ready_var_q, op, this] {
+    RunOpSync(op);
     try {
-      if (VLOG_IS_ON(10)) {
-        VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
-      }
-      if (LIKELY(!strategy_.dry_run_)) {
-        op->Run(strategy_.use_cuda_);
-      }
-      VLOG(10) << op << " " << op->Name() << " Done ";
       ready_var_q->Extend(op->Outputs());
       VLOG(10) << op << " " << op->Name() << " Signal posted";
     } catch (...) {
      exception_holder_.Catch(std::current_exception());
     }
   };

   if (pool_) {
     run_op_futures_.emplace_back(pool_->enqueue(op_run));
   } else {
     op_run();
   }
+
+  RecordOps(op);
 }

+void ThreadedSSAGraphExecutor::RunTracedOps(
+    const std::vector<OpHandleBase *> &traced_ops) {
+  for (auto &op : traced_ops) {
+    if (exception_holder_.IsCaught()) {
+      return;
+    }
+    RunOpSync(op);
+  }
+}
+
+void ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
+  try {
+    if (VLOG_IS_ON(10)) {
+      VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
+    }
+    if (LIKELY(!strategy_.dry_run_)) {
+      op->Run(strategy_.use_cuda_);
+    }
+    VLOG(10) << op << " " << op->Name() << " Done ";
+  } catch (...) {
+    exception_holder_.Catch(std::current_exception());
+  }
+}
+
+void ThreadedSSAGraphExecutor::ExecutionFinal(
+    std::vector<OpHandleBase *> *fetch_ops) {
+  VLOG(3) << "caught exception " << exception_holder_.Type() << ", rethrow it";
+  ClearFetchOp(graph_, fetch_ops);
+  exception_holder_.ReThrow();
+}
+
+void ThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) {
+  if (strategy_.num_threads_ == 1 && !dynamic_cast<FetchOpHandle *>(op)) {
+    traced_ops_.emplace_back(op);
+  }
+}
 }  // namespace details
 }  // namespace framework
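Both executors follow the same exception discipline: RunOpSync catches into an exception holder rather than letting the error escape the scheduling loop, RunTracedOps stops replaying once anything has been caught, and ExecutionFinal clears the fetch ops before rethrowing on the caller's thread. A simplified sketch of that flow follows; ExceptionHolder here is a stripped-down stand-in for Paddle's class of the same name.

// Sketch of deferred exception propagation: catch per-op, short-circuit the
// replay loop, rethrow once at the end. Assumes single-threaded replay.
#include <exception>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <vector>

class ExceptionHolder {
 public:
  void Catch(std::exception_ptr e) { e_ = e; }
  bool IsCaught() const { return e_ != nullptr; }
  void ReThrow() { std::rethrow_exception(e_); }

 private:
  std::exception_ptr e_;
};

using Op = std::function<void()>;

void RunTracedOps(const std::vector<Op>& ops, ExceptionHolder* holder) {
  for (const auto& op : ops) {
    if (holder->IsCaught()) return;  // stop replaying after a failure
    try {
      op();
    } catch (...) {
      holder->Catch(std::current_exception());  // defer, don't propagate yet
    }
  }
}

int main() {
  ExceptionHolder holder;
  std::vector<Op> ops{
      [] { std::cout << "op1\n"; },
      [] { throw std::runtime_error("op2 failed"); },
      [] { std::cout << "op3 (never runs)\n"; }};
  RunTracedOps(ops, &holder);
  if (holder.IsCaught()) {
    // ExecutionFinal equivalent: clean up fetch ops, then rethrow.
    try {
      holder.ReThrow();
    } catch (const std::exception& e) {
      std::cout << "rethrown: " << e.what() << "\n";
    }
  }
}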
paddle/fluid/framework/details/threaded_ssa_graph_executor.h:

@@ -44,6 +44,7 @@ struct OpDependentData {
   std::unordered_map<OpHandleBase *, size_t> pending_ops_;
   std::unordered_set<VarHandleBase *> pending_vars_;
   std::unordered_set<OpHandleBase *> ready_ops_;
+  size_t num_ops_{0};
 };

 class ThreadedSSAGraphExecutor : public SSAGraphExecutor {

@@ -80,6 +81,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
   std::list<std::future<void>> run_op_futures_;
   ::ThreadPool prepare_pool_;
   std::unique_ptr<::ThreadPool> pool_;
+  std::vector<OpHandleBase *> traced_ops_;

   void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops,
                        OpHandleBase *op_instance) const;

@@ -89,7 +91,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
                        VarHandleBase *var) const;

   void InsertFetchOps(const std::vector<std::string> &fetch_tensors,
-                      std::vector<FetchOpHandle *> *fetch_ops,
+                      std::vector<OpHandleBase *> *fetch_ops,
                       std::unordered_set<VarHandleBase *> *fetch_dependencies,
                       std::unordered_set<OpHandleBase *> *ready_ops,
                       std::unordered_map<OpHandleBase *, size_t> *pending_ops,

@@ -97,7 +99,16 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
                       FeedFetchList *fetch_data);

   void PrepareOpDeps();
   void CopyOpDeps();
+
+  inline void RecordOps(OpHandleBase *op);
+
+  inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops);
+
+  inline void RunOpSync(OpHandleBase *op);
+
+  void RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
 };
 }  // namespace details
...@@ -45,7 +45,8 @@ class TestFetchAndFeed(unittest.TestCase): ...@@ -45,7 +45,8 @@ class TestFetchAndFeed(unittest.TestCase):
def parallel_exe(self, def parallel_exe(self,
use_cuda, use_cuda,
run_parallel_exe, run_parallel_exe,
use_experimental_executor=False, use_faster_executor=False,
num_threads=4,
seed=1): seed=1):
main_program = fluid.Program() main_program = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
...@@ -72,7 +73,8 @@ class TestFetchAndFeed(unittest.TestCase): ...@@ -72,7 +73,8 @@ class TestFetchAndFeed(unittest.TestCase):
build_strategy.enable_inplace = False build_strategy.enable_inplace = False
build_strategy.memory_optimize = False build_strategy.memory_optimize = False
exec_strategy = fluid.ExecutionStrategy() exec_strategy = fluid.ExecutionStrategy()
exec_strategy.use_experimental_executor = use_experimental_executor exec_strategy.use_experimental_executor = use_faster_executor
exec_strategy.num_threads = num_threads
train_cp = compiler.CompiledProgram(main_program).with_data_parallel( train_cp = compiler.CompiledProgram(main_program).with_data_parallel(
loss_name=loss.name, loss_name=loss.name,
build_strategy=build_strategy, build_strategy=build_strategy,
...@@ -143,24 +145,25 @@ class TestFetchAndFeed(unittest.TestCase): ...@@ -143,24 +145,25 @@ class TestFetchAndFeed(unittest.TestCase):
if batch_id == 2: if batch_id == 2:
break break
def test_fetch_with_threaded_executor(self): def check_executor(self, use_faster_executor=False, num_threads=4):
if core.is_compiled_with_cuda():
self.parallel_exe(
use_cuda=True,
run_parallel_exe=self.run_parallel_exe_with_fetch)
self.parallel_exe(
use_cuda=False, run_parallel_exe=self.run_parallel_exe_with_fetch)
def test_fetch_with_fast_threaded_executor(self):
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
self.parallel_exe( self.parallel_exe(
use_cuda=True, use_cuda=True,
run_parallel_exe=self.run_parallel_exe_with_fetch, run_parallel_exe=self.run_parallel_exe_with_fetch,
use_experimental_executor=True) use_faster_executor=use_faster_executor,
num_threads=num_threads)
self.parallel_exe( self.parallel_exe(
use_cuda=False, use_cuda=False,
run_parallel_exe=self.run_parallel_exe_with_fetch, run_parallel_exe=self.run_parallel_exe_with_fetch,
use_experimental_executor=True) use_faster_executor=use_faster_executor,
num_threads=num_threads)
def test_fetch(self):
for use_faster_executor in {True, False}:
self.check_executor(
use_faster_executor=use_faster_executor, num_threads=4)
self.check_executor(
use_faster_executor=use_faster_executor, num_threads=1)
def test_feed(self): def test_feed(self):
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
......