Unverified · Commit e336dc86 authored by chengduo, committed by GitHub

[Speed] Refine the Executor when the num_thread=1 (#17405)

Refine the Executor when the num_thread=1
Parent 30e178fa
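The idea behind this change: when ExecutionStrategy.num_threads is 1, the executor records the order in which operators run during the first iteration (RecordOps / traced_ops_) and, in later iterations, replays that recorded sequence directly (RunTracedOps) instead of going through the asynchronous, dependency-driven scheduler. Below is a minimal, self-contained C++ sketch of that trace-then-replay pattern; the Op struct and all names are illustrative only, not Paddle's actual classes.

// Sketch of the trace-then-replay idea (illustrative, not Paddle code).
#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct Op {  // hypothetical stand-in for OpHandleBase
  std::string name;
  std::function<void()> body;
  void Run() const { body(); }
};

class SingleThreadTracingExecutor {
 public:
  explicit SingleThreadTracingExecutor(std::vector<Op *> graph_ops)
      : graph_ops_(std::move(graph_ops)) {}

  void Run() {
    if (traced_ops_.size() == graph_ops_.size()) {
      // Fast path: replay the order recorded in the first iteration.
      for (Op *op : traced_ops_) op->Run();
      return;
    }
    traced_ops_.clear();
    // Slow path: schedule normally (here simply the given order) and record
    // each op as it runs, mirroring RecordOps() in this patch.
    for (Op *op : graph_ops_) {
      op->Run();
      traced_ops_.emplace_back(op);
    }
  }

 private:
  std::vector<Op *> graph_ops_;
  std::vector<Op *> traced_ops_;  // filled during the first iteration only
};

int main() {
  Op a{"relu", [] { std::cout << "run relu\n"; }};
  Op b{"fc", [] { std::cout << "run fc\n"; }};
  SingleThreadTracingExecutor exec({&a, &b});
  exec.Run();  // first iteration: schedule and record
  exec.Run();  // later iterations: replay the traced order
}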
......@@ -43,7 +43,7 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
bootstrap_ops_.emplace_back(op);
}
}
PADDLE_ENFORCE_GT(op_deps_.size(), 0, "The graph doesn't have operators.");
PrepareAtomicOpDeps();
}
......@@ -52,26 +52,85 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>
op_deps = atomic_op_deps_.get();
PrepareAtomicOpDeps();
size_t num_ops = op_deps->size();
paddle::framework::FeedFetchList fetches;
fetches.resize(fetch_tensors.size());
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
std::vector<FetchOpHandle *> fetch_ops;
std::vector<OpHandleBase *> fetch_ops;
std::vector<OpHandleBase *> ready_fetch_ops;
exception_.Clear();
InsertFetchOps(fetch_tensors, &fetches, &fetched_vars, op_deps.get(),
&fetch_ops, &ready_fetch_ops);
if (strategy_.num_threads_ == 1 && traced_ops_.size() == num_ops) {
// If num_threads is 1, we can record the execution order of the operators
// in the first iteration, and in subsequent iterations run the recorded
// operators directly. This strategy can make execution faster.
VLOG(3) << "Run the traced ops.";
RunTracedOps(traced_ops_);
RunTracedOps(fetch_ops);
if (exception_.IsCaught()) {
ExecutionFinal(&fetch_ops);
}
} else {
traced_ops_.clear();
remaining_ = 0;
auto complete_q = std::make_shared<BlockingQueue<size_t>>();
for (auto op : bootstrap_ops_) {
RunOpAsync(op_deps.get(), op, complete_q);
}
for (auto op : ready_fetch_ops) {
RunOpAsync(op_deps.get(), op, complete_q);
}
size_t num_complete = 0;
while (num_complete != op_deps->size()) {
size_t num_comp = complete_q->Pop();
if (num_comp == -1UL) {
int remaining = 0;
while (true) {
remaining = remaining_;
if (remaining == 0) {
break;
}
for (int i = 0; i < remaining; ++i) {
complete_q->Pop();
}
}
if (exception_.IsCaught()) {
ExecutionFinal(&fetch_ops);
}
}
num_complete += num_comp;
}
}
// Wait FetchOps.
ClearFetchOp(graph_, &fetch_ops);
return fetches;
}
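In the multi-threaded fallback branch above, every RunOpAsync chain pushes the number of ops it completed into complete_q; the main thread keeps summing these counts until all ops have reported, and a pushed value of -1UL acts as an error sentinel that drains the queue and triggers ExecutionFinal. A self-contained sketch of that counting loop follows; the BlockingQueue here is a simplified stand-in, not Paddle's implementation.

// Sketch of counting completed ops through a blocking queue, with an
// error sentinel (illustrative; Paddle's BlockingQueue is richer).
#include <condition_variable>
#include <cstddef>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>

class BlockingQueue {
 public:
  void Push(size_t v) {
    {
      std::lock_guard<std::mutex> guard(mutex_);
      queue_.push(v);
    }
    cv_.notify_one();
  }
  size_t Pop() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [&] { return !queue_.empty(); });
    size_t v = queue_.front();
    queue_.pop();
    return v;
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  std::queue<size_t> queue_;
};

int main() {
  constexpr size_t kNumOps = 8;
  BlockingQueue complete_q;
  // Worker threads report how many ops each of them finished; a worker that
  // hits an exception would push -1UL instead.
  std::thread w1([&] { complete_q.Push(3); });
  std::thread w2([&] { complete_q.Push(5); });

  size_t num_complete = 0;
  while (num_complete != kNumOps) {
    size_t num_comp = complete_q.Pop();
    if (num_comp == static_cast<size_t>(-1)) {
      std::cerr << "an op failed; rethrow the captured exception here\n";
      break;
    }
    num_complete += num_comp;
  }
  w1.join();
  w2.join();
}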
void FastThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors, FeedFetchList *fetches,
std::unordered_map<std::string, std::vector<VarHandleBase *>> *fetched_vars,
std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
std::vector<OpHandleBase *> *fetch_ops,
std::vector<OpHandleBase *> *ready_fetch_ops) {
for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
for (auto &var_map : graph_->Get<GraphVars>(kGraphVars)) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(*it->second.rbegin());
(*fetched_vars)[fetch_var_name].push_back(*it->second.rbegin());
}
}
}
for (size_t i = 0; i < fetch_tensors.size(); ++i) {
auto &var_name = fetch_tensors[i];
auto fetched_var_it = fetched_vars.find(var_name);
PADDLE_ENFORCE(fetched_var_it != fetched_vars.end(),
auto &var_name = fetch_tensors.at(i);
auto fetched_var_it = fetched_vars->find(var_name);
PADDLE_ENFORCE(fetched_var_it != fetched_vars->end(),
"Cannot find fetched variable(%s).(Perhaps the main_program "
"is not set to ParallelExecutor)",
var_name);
......@@ -80,8 +139,8 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
ir::Node *fetch_node =
graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
auto *op = new FetchOpHandle(fetch_node, &fetches, i, &local_scopes_);
fetch_ops.emplace_back(op);
auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_);
fetch_ops->emplace_back(op);
for (auto &p : places_) {
op->SetDeviceContext(p, fetch_ctxs_.Get(p));
......@@ -94,55 +153,22 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
int dep = static_cast<int>(op->NotReadyInputSize());
(*op_deps)[op] = dep;
if (dep == 0) {
ready_fetch_ops.emplace_back(op);
}
}
size_t num_complete = 0;
remaining_ = 0;
auto complete_q = std::make_shared<BlockingQueue<size_t>>();
for (auto op : bootstrap_ops_) {
RunOpAsync(op_deps.get(), op, complete_q);
}
for (auto op : ready_fetch_ops) {
RunOpAsync(op_deps.get(), op, complete_q);
}
while (num_complete != op_deps->size()) {
size_t num_comp = complete_q->Pop();
if (num_comp == -1UL) {
int remaining = 0;
while (true) {
remaining = remaining_;
if (remaining == 0) {
break;
}
for (int i = 0; i < remaining; ++i) {
complete_q->Pop();
}
}
if (exception_.IsCaught()) {
ClearFetchOp(graph_, &fetch_ops);
exception_.ReThrow();
}
ready_fetch_ops->emplace_back(op);
}
num_complete += num_comp;
}
// Wait FetchOps.
ClearFetchOp(graph_, &fetch_ops);
return fetches;
}
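InsertFetchOps above seeds each new FetchOpHandle's entry in op_deps with NotReadyInputSize(); ops whose count is already zero go straight into ready_fetch_ops, and the rest become runnable only when their counters are decremented to zero as producers finish (the decrement happens in RunOpAsync, not shown in this hunk). A minimal sketch of that reference-count style scheduling, with illustrative names:

// Sketch of dependency-count scheduling: an op becomes ready once its
// pending-input counter drops to zero (illustrative names only).
#include <atomic>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct Op {
  std::string name;
  std::vector<Op *> consumers;  // ops that consume this op's outputs
  int not_ready_inputs = 0;     // cf. OpHandleBase::NotReadyInputSize()
};

int main() {
  Op load{"load"}, fc{"fc"}, fetch{"fetch"};
  load.consumers = {&fc};
  fc.consumers = {&fetch};
  fc.not_ready_inputs = 1;
  fetch.not_ready_inputs = 1;

  std::unordered_map<Op *, std::atomic<int>> op_deps;
  std::vector<Op *> ready;
  for (Op *op : {&load, &fc, &fetch}) {
    op_deps[op] = op->not_ready_inputs;
    if (op->not_ready_inputs == 0) ready.push_back(op);  // cf. ready_fetch_ops
  }

  // Run ready ops; finishing one may release its consumers.
  while (!ready.empty()) {
    Op *op = ready.back();
    ready.pop_back();
    std::cout << "run " << op->name << "\n";
    for (Op *next : op->consumers) {
      if (--op_deps[next] == 0) ready.push_back(next);
    }
  }
}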
bool FastThreadedSSAGraphExecutor::RunOp(
OpHandleBase *op, const std::shared_ptr<BlockingQueue<size_t>> &complete_q,
size_t *complete) {
try {
RunOpSync(op);
if (LIKELY(!exception_.IsCaught())) {
if (LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_cuda_);
RecordOps(op);
}
++(*complete);
return true;
} catch (...) {
exception_.Catch(std::current_exception());
} else {
--remaining_;
complete_q->Push(-1UL);
return false;
......@@ -194,6 +220,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
complete_q->Push(complete);
});
}
void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
atomic_op_deps_ = prepare_pool_.enqueue([&] {
auto *op_deps = new std::unordered_map<OpHandleBase *, std::atomic<int>>;
......@@ -206,6 +233,44 @@ void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
}
const ir::Graph &FastThreadedSSAGraphExecutor::Graph() const { return *graph_; }
void FastThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) {
if (strategy_.num_threads_ == 1 && !dynamic_cast<FetchOpHandle *>(op)) {
traced_ops_.emplace_back(op);
}
}
void FastThreadedSSAGraphExecutor::ExecutionFinal(
std::vector<OpHandleBase *> *fetch_ops) {
VLOG(3) << "caught exception " << exception_.Type() << ", rethrow it";
ClearFetchOp(graph_, fetch_ops);
exception_.ReThrow();
}
void FastThreadedSSAGraphExecutor::RunTracedOps(
const std::vector<OpHandleBase *> &traced_ops) {
for (auto &op : traced_ops) {
if (exception_.IsCaught()) {
return;
}
RunOpSync(op);
}
}
void FastThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
try {
if (VLOG_IS_ON(10)) {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
}
if (LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_cuda_);
}
VLOG(10) << op << " " << op->Name() << " Done ";
} catch (...) {
exception_.Catch(std::current_exception());
}
}
} // namespace details
} // namespace framework
} // namespace paddle
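The new RunOpSync / ExecutionFinal pair implements a capture-now, rethrow-later error model: an exception thrown while running an op is stored via std::current_exception(), and only after the fetch ops are cleaned up is it rethrown on the calling thread. A self-contained sketch of that pattern (the holder below is a simplified stand-in for Paddle's exception holder):

// Sketch of capture-now, rethrow-later exception handling across ops.
#include <exception>
#include <iostream>
#include <stdexcept>

class SimpleExceptionHolder {
 public:
  void Catch(std::exception_ptr p) { captured_ = p; }
  bool IsCaught() const { return captured_ != nullptr; }
  void ReThrow() { std::rethrow_exception(captured_); }

 private:
  std::exception_ptr captured_;
};

void RunOpSync(int i, SimpleExceptionHolder *holder) {
  try {
    if (i == 2) throw std::runtime_error("op 2 failed");
    std::cout << "op " << i << " done\n";
  } catch (...) {
    holder->Catch(std::current_exception());  // remember it, do not propagate yet
  }
}

int main() {
  SimpleExceptionHolder holder;
  for (int i = 0; i < 4 && !holder.IsCaught(); ++i) RunOpSync(i, &holder);
  // ... clean up fetch ops here, then surface the stored failure ...
  if (holder.IsCaught()) {
    try {
      holder.ReThrow();
    } catch (const std::exception &e) {
      std::cerr << "caught exception: " << e.what() << ", rethrow it\n";
    }
  }
}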
......@@ -60,6 +60,8 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
::ThreadPool pool_;
::ThreadPool prepare_pool_;
std::vector<OpHandleBase *> traced_ops_;
bool RunOp(OpHandleBase *op,
const std::shared_ptr<BlockingQueue<size_t>> &complete_q,
size_t *complete);
......@@ -69,6 +71,22 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
const std::shared_ptr<BlockingQueue<size_t>> &complete_q);
void PrepareAtomicOpDeps();
inline void RecordOps(OpHandleBase *op);
inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops);
inline void RunOpSync(OpHandleBase *op);
void RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
void InsertFetchOps(
const std::vector<std::string> &fetch_tensors, FeedFetchList *fetches,
std::unordered_map<std::string, std::vector<VarHandleBase *>>
*fetched_vars,
std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
std::vector<OpHandleBase *> *fetch_ops,
std::vector<OpHandleBase *> *ready_fetch_ops);
};
} // namespace details
} // namespace framework
......
......@@ -19,10 +19,13 @@ namespace framework {
namespace details {
SSAGraphExecutor::~SSAGraphExecutor() {}
void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops) {
void ClearFetchOp(ir::Graph* graph, std::vector<OpHandleBase*>* fetch_ops) {
if (fetch_ops->empty()) return;
for (auto& op : *fetch_ops) {
PADDLE_ENFORCE_NOT_NULL(
dynamic_cast<FetchOpHandle*>(op),
"The input ops of ClearFetchOp function should be FetchOpHandle.");
for (auto& out_var : op->Node()->outputs) {
graph->RemoveNode(out_var);
}
......
......@@ -38,7 +38,7 @@ class SSAGraphExecutor {
virtual FeedFetchList Run(const std::vector<std::string>& fetch_tensors) = 0;
};
void ClearFetchOp(ir::Graph* graph, std::vector<FetchOpHandle*>* fetch_ops);
void ClearFetchOp(ir::Graph* graph, std::vector<OpHandleBase*>* fetch_ops);
} // namespace details
} // namespace framework
} // namespace paddle
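ClearFetchOp now takes std::vector<OpHandleBase*> so both executors can pass their fetch_ops vectors directly, but it still requires every element to actually be a FetchOpHandle and enforces that with a dynamic_cast check. A minimal sketch of validating the dynamic type behind a base-class pointer (class names are illustrative):

// Sketch of accepting base-class pointers while enforcing the dynamic type.
#include <cassert>
#include <iostream>
#include <vector>

struct OpHandleBase {
  virtual ~OpHandleBase() = default;
};
struct FetchOpHandle : OpHandleBase {};

void ClearFetchOps(std::vector<OpHandleBase *> *fetch_ops) {
  if (fetch_ops->empty()) return;
  for (OpHandleBase *op : *fetch_ops) {
    // dynamic_cast yields nullptr if op is not really a FetchOpHandle.
    assert(dynamic_cast<FetchOpHandle *>(op) != nullptr &&
           "The input ops of ClearFetchOps should be FetchOpHandle.");
    delete op;  // safe: OpHandleBase has a virtual destructor
  }
  fetch_ops->clear();
}

int main() {
  std::vector<OpHandleBase *> ops{new FetchOpHandle, new FetchOpHandle};
  ClearFetchOps(&ops);
  std::cout << "remaining fetch ops: " << ops.size() << "\n";
}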
......@@ -53,74 +53,84 @@ inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare"));
std::unique_ptr<OpDependentData> op_deps = op_deps_futures_.get();
CopyOpDeps();
VLOG(10) << "ThreadedSSAGraphExecutor::Run";
std::shared_ptr<BlockingQueue<VarHandleBase *>> ready_vars(
new BlockingQueue<VarHandleBase *>);
auto &pending_ops = op_deps->pending_ops_;
auto &pending_vars = op_deps->pending_vars_;
auto &ready_ops = op_deps->ready_ops_;
// For ops (e.g. nccl_all_reduce) that need to coordinate multiple
// streams from multiple GPUs, it's faster to buffer them and schedule
// together since we currently cannot overlap computation and memcpy streams.
// Should revisit it if overlapping is available.
std::unordered_set<OpHandleBase *> delayed_ops;
size_t num_ops = op_deps->num_ops_;
// Step 2. Insert FetchOps
std::vector<FetchOpHandle *> fetch_ops;
std::vector<OpHandleBase *> fetch_ops;
std::unordered_set<VarHandleBase *> fetch_dependencies;
FeedFetchList fetch_data(fetch_tensors.size());
InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &ready_ops,
&pending_ops, &pending_vars, &fetch_data);
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
for (auto *op : set) {
RunOp(ready_vars, op);
}
set.clear();
};
// Clean run context
run_op_futures_.clear();
exception_holder_.Clear();
event.reset(nullptr);
// Step 3. Execution
while (!pending_vars.empty()) {
// 1. Run All Ready ops
// Keep loop until all vars are ready.
run_all_ops(ready_ops);
// 2. Find ready variable
bool timeout;
auto cur_ready_vars = ready_vars->PopAll(1, &timeout);
if (timeout) {
if (exception_holder_.IsCaught()) {
VLOG(3) << "caught exception " << exception_holder_.Type()
<< ", rethrow it";
if (strategy_.num_threads_ == 1 && traced_ops_.size() == num_ops) {
// If num_threads is 1, we can record the execution order of the operators
// in the first iteration, and in subsequent iterations run the recorded
// operators directly. This strategy can make execution faster.
VLOG(3) << "Run the traced ops.";
RunTracedOps(traced_ops_);
RunTracedOps(fetch_ops);
if (exception_holder_.IsCaught()) {
ExecutionFinal(&fetch_ops);
}
} else {
traced_ops_.clear();
auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
for (auto *op : set) {
RunOp(ready_vars, op);
}
set.clear();
};
// Clean run context
run_op_futures_.clear();
while (!pending_vars.empty()) {
// 1. Run All Ready ops
// Keep loop until all vars are ready.
run_all_ops(ready_ops);
// 2. Find ready variable
bool timeout;
auto cur_ready_vars = ready_vars->PopAll(1, &timeout);
if (timeout) {
for (auto &run_op_future : run_op_futures_) {
run_op_future.wait();
}
ClearFetchOp(graph_, &fetch_ops);
exception_holder_.ReThrow();
} else {
continue;
if (exception_holder_.IsCaught()) {
ExecutionFinal(&fetch_ops);
} else {
continue;
}
}
}
// 3. Remove the dependency of ready_var.
// Find the ready_ops after the ready_var.
for (auto ready_var : cur_ready_vars) {
pending_vars.erase(ready_var);
for (auto *op : ready_var->PendingOps()) {
auto &deps = pending_ops[op];
--deps;
if (deps == 0) {
ready_ops.insert(op);
// 3. Remove the dependency of ready_var.
// Find the ready_ops after the ready_var.
for (auto ready_var : cur_ready_vars) {
pending_vars.erase(ready_var);
for (auto *op : ready_var->PendingOps()) {
auto &deps = pending_ops[op];
--deps;
if (deps == 0) {
ready_ops.insert(op);
}
}
}
}
PADDLE_ENFORCE(ready_ops.empty());
}
PADDLE_ENFORCE(ready_ops.empty());
// Wait FetchOps.
ClearFetchOp(graph_, &fetch_ops);
......@@ -137,7 +147,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
void ThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors,
std::vector<FetchOpHandle *> *fetch_ops,
std::vector<OpHandleBase *> *fetch_ops,
std::unordered_set<VarHandleBase *> *fetch_dependencies,
std::unordered_set<OpHandleBase *> *ready_ops,
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
......@@ -243,6 +253,9 @@ void ThreadedSSAGraphExecutor::PrepareOpDeps() {
InsertPendingOp(&pending_ops, op);
}
}
op_deps_->num_ops_ = ready_ops.size() + pending_ops.size();
PADDLE_ENFORCE_GT(op_deps_->num_ops_, 0, "The graph doesn't have operators.");
for (auto ready_var : ready_vars) {
pending_vars.erase(ready_var);
for (auto *op : ready_var->PendingOps()) {
......@@ -264,6 +277,7 @@ void ThreadedSSAGraphExecutor::CopyOpDeps() {
op_deps_->pending_vars_.end());
op_deps->ready_ops_.insert(op_deps_->ready_ops_.begin(),
op_deps_->ready_ops_.end());
op_deps->num_ops_ = op_deps_->num_ops_;
return std::unique_ptr<OpDependentData>(op_deps);
});
}
......@@ -272,25 +286,59 @@ void ThreadedSSAGraphExecutor::RunOp(
const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q,
details::OpHandleBase *op) {
auto op_run = [ready_var_q, op, this] {
RunOpSync(op);
try {
if (VLOG_IS_ON(10)) {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
}
if (LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_cuda_);
}
VLOG(10) << op << " " << op->Name() << " Done ";
ready_var_q->Extend(op->Outputs());
VLOG(10) << op << " " << op->Name() << " Signal posted";
} catch (...) {
exception_holder_.Catch(std::current_exception());
}
};
if (pool_) {
run_op_futures_.emplace_back(pool_->enqueue(op_run));
} else {
op_run();
}
RecordOps(op);
}
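ThreadedSSAGraphExecutor::RunOp wraps the actual execution in a lambda and either enqueues it onto the thread pool (collecting futures so the run loop can wait for stragglers on timeout) or, when no pool exists because num_threads is 1, simply calls it inline. A self-contained sketch of that enqueue-or-run-inline pattern, using std::async as a stand-in for Paddle's ::ThreadPool:

// Sketch of "enqueue on a pool if present, otherwise run inline"
// (std::async stands in for the thread pool; names are illustrative).
#include <future>
#include <iostream>
#include <list>
#include <string>

int main() {
  bool use_pool = true;  // false would model num_threads == 1
  std::list<std::future<void>> run_op_futures;

  for (std::string name : {"fill", "fc", "softmax"}) {
    auto op_run = [name] { std::cout << "run " << name << "\n"; };
    if (use_pool) {
      run_op_futures.emplace_back(std::async(std::launch::async, op_run));
    } else {
      op_run();  // single-threaded: execute in place
    }
  }
  for (auto &f : run_op_futures) f.wait();  // cf. waiting on run_op_futures_
}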
void ThreadedSSAGraphExecutor::RunTracedOps(
const std::vector<OpHandleBase *> &traced_ops) {
for (auto &op : traced_ops) {
if (exception_holder_.IsCaught()) {
return;
}
RunOpSync(op);
}
}
void ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
try {
if (VLOG_IS_ON(10)) {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
}
if (LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_cuda_);
}
VLOG(10) << op << " " << op->Name() << " Done ";
} catch (...) {
exception_holder_.Catch(std::current_exception());
}
}
void ThreadedSSAGraphExecutor::ExecutionFinal(
std::vector<OpHandleBase *> *fetch_ops) {
VLOG(3) << "caught exception " << exception_holder_.Type() << ", rethrow it";
ClearFetchOp(graph_, fetch_ops);
exception_holder_.ReThrow();
}
void ThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) {
if (strategy_.num_threads_ == 1 && !dynamic_cast<FetchOpHandle *>(op)) {
traced_ops_.emplace_back(op);
}
}
} // namespace details
} // namespace framework
......
......@@ -44,6 +44,7 @@ struct OpDependentData {
std::unordered_map<OpHandleBase *, size_t> pending_ops_;
std::unordered_set<VarHandleBase *> pending_vars_;
std::unordered_set<OpHandleBase *> ready_ops_;
size_t num_ops_{0};
};
class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
......@@ -80,6 +81,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
std::list<std::future<void>> run_op_futures_;
::ThreadPool prepare_pool_;
std::unique_ptr<::ThreadPool> pool_;
std::vector<OpHandleBase *> traced_ops_;
void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops,
OpHandleBase *op_instance) const;
......@@ -89,7 +91,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
VarHandleBase *var) const;
void InsertFetchOps(const std::vector<std::string> &fetch_tensors,
std::vector<FetchOpHandle *> *fetch_ops,
std::vector<OpHandleBase *> *fetch_ops,
std::unordered_set<VarHandleBase *> *fetch_dependencies,
std::unordered_set<OpHandleBase *> *ready_ops,
std::unordered_map<OpHandleBase *, size_t> *pending_ops,
......@@ -97,7 +99,16 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
FeedFetchList *fetch_data);
void PrepareOpDeps();
void CopyOpDeps();
inline void RecordOps(OpHandleBase *op);
inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops);
inline void RunOpSync(OpHandleBase *op);
void RunTracedOps(const std::vector<OpHandleBase *> &traced_ops);
};
} // namespace details
......
......@@ -45,7 +45,8 @@ class TestFetchAndFeed(unittest.TestCase):
def parallel_exe(self,
use_cuda,
run_parallel_exe,
use_experimental_executor=False,
use_faster_executor=False,
num_threads=4,
seed=1):
main_program = fluid.Program()
startup = fluid.Program()
......@@ -72,7 +73,8 @@ class TestFetchAndFeed(unittest.TestCase):
build_strategy.enable_inplace = False
build_strategy.memory_optimize = False
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.use_experimental_executor = use_experimental_executor
exec_strategy.use_experimental_executor = use_faster_executor
exec_strategy.num_threads = num_threads
train_cp = compiler.CompiledProgram(main_program).with_data_parallel(
loss_name=loss.name,
build_strategy=build_strategy,
......@@ -143,24 +145,25 @@ class TestFetchAndFeed(unittest.TestCase):
if batch_id == 2:
break
def test_fetch_with_threaded_executor(self):
if core.is_compiled_with_cuda():
self.parallel_exe(
use_cuda=True,
run_parallel_exe=self.run_parallel_exe_with_fetch)
self.parallel_exe(
use_cuda=False, run_parallel_exe=self.run_parallel_exe_with_fetch)
def test_fetch_with_fast_threaded_executor(self):
def check_executor(self, use_faster_executor=False, num_threads=4):
if core.is_compiled_with_cuda():
self.parallel_exe(
use_cuda=True,
run_parallel_exe=self.run_parallel_exe_with_fetch,
use_experimental_executor=True)
use_faster_executor=use_faster_executor,
num_threads=num_threads)
self.parallel_exe(
use_cuda=False,
run_parallel_exe=self.run_parallel_exe_with_fetch,
use_experimental_executor=True)
use_faster_executor=use_faster_executor,
num_threads=num_threads)
def test_fetch(self):
for use_faster_executor in {True, False}:
self.check_executor(
use_faster_executor=use_faster_executor, num_threads=4)
self.check_executor(
use_faster_executor=use_faster_executor, num_threads=1)
def test_feed(self):
if core.is_compiled_with_cuda():
......