未验证 提交 2265d091 编写于 作者: C chengduo 提交者: GitHub

Fix threaded executor bug (#16508)

* fix threaded executor bug
test=develop

* change the order of class member
test=develop

* Fix Travis CI
test=develop
上级 f7f5044b
...@@ -31,9 +31,10 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor( ...@@ -31,9 +31,10 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
local_scopes_(local_scopes), local_scopes_(local_scopes),
places_(places), places_(places),
graph_(graph), graph_(graph),
fetch_ctxs_(places),
pool_(strategy.num_threads_), pool_(strategy.num_threads_),
prepare_pool_(1), // add one more thread for generate op_deps // add one more thread for generate op_deps
fetch_ctxs_(places) { prepare_pool_(1) {
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) { for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
int dep = static_cast<int>(op->NotReadyInputSize()); int dep = static_cast<int>(op->NotReadyInputSize());
op_deps_.emplace(op, dep); op_deps_.emplace(op, dep);
......
...@@ -14,7 +14,9 @@ ...@@ -14,7 +14,9 @@
#pragma once #pragma once
#include <ThreadPool.h> #include <ThreadPool.h>
#include <memory>
#include <string> #include <string>
#include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/blocking_queue.h" #include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/exception_holder.h" #include "paddle/fluid/framework/details/exception_holder.h"
...@@ -37,6 +39,8 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -37,6 +39,8 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
const ir::Graph &Graph() const override; const ir::Graph &Graph() const override;
private: private:
// Note(zcd): the ThreadPool should be placed last so that ThreadPool should
// be destroyed first.
ExecutionStrategy strategy_; ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
...@@ -45,21 +49,22 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -45,21 +49,22 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
std::unordered_map<OpHandleBase *, int> op_deps_; std::unordered_map<OpHandleBase *, int> op_deps_;
std::vector<OpHandleBase *> bootstrap_ops_; std::vector<OpHandleBase *> bootstrap_ops_;
::ThreadPool pool_;
::ThreadPool prepare_pool_;
platform::DeviceContextPool fetch_ctxs_; platform::DeviceContextPool fetch_ctxs_;
std::atomic<int> remaining_; std::atomic<int> remaining_;
std::future<
std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>>
atomic_op_deps_;
ExceptionHolder exception_;
::ThreadPool pool_;
::ThreadPool prepare_pool_;
void RunOpAsync(std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps, void RunOpAsync(std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
OpHandleBase *op, OpHandleBase *op,
const std::shared_ptr<BlockingQueue<size_t>> &complete_q); const std::shared_ptr<BlockingQueue<size_t>> &complete_q);
void PrepareAtomicOpDeps(); void PrepareAtomicOpDeps();
std::future<
std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>>
atomic_op_deps_;
ExceptionHolder exception_;
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -24,13 +24,13 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( ...@@ -24,13 +24,13 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes, const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, ir::Graph *graph) const std::vector<platform::Place> &places, ir::Graph *graph)
: graph_(graph), : graph_(graph),
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
: nullptr),
prepare_pool_(1),
local_scopes_(local_scopes), local_scopes_(local_scopes),
places_(places), places_(places),
fetch_ctxs_(places), fetch_ctxs_(places),
strategy_(strategy) { strategy_(strategy),
prepare_pool_(1),
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
: nullptr) {
PrepareOpDeps(); PrepareOpDeps();
CopyOpDeps(); CopyOpDeps();
} }
......
...@@ -63,13 +63,20 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -63,13 +63,20 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
details::OpHandleBase *op); details::OpHandleBase *op);
private: private:
// Note(zcd): the ThreadPool should be placed last so that ThreadPool should
// be destroyed first.
ir::Graph *graph_; ir::Graph *graph_;
std::unique_ptr<::ThreadPool> pool_;
::ThreadPool prepare_pool_;
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
platform::DeviceContextPool fetch_ctxs_; platform::DeviceContextPool fetch_ctxs_;
ExceptionHolder exception_holder_; ExceptionHolder exception_holder_;
std::unique_ptr<OpDependentData> op_deps_;
std::future<std::unique_ptr<OpDependentData>> op_deps_futures_;
ExecutionStrategy strategy_;
// use std::list because clear(), push_back, and for_each are O(1)
std::list<std::future<void>> run_op_futures_;
::ThreadPool prepare_pool_;
std::unique_ptr<::ThreadPool> pool_;
void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops, void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops,
OpHandleBase *op_instance) const; OpHandleBase *op_instance) const;
...@@ -88,14 +95,6 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -88,14 +95,6 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
void PrepareOpDeps(); void PrepareOpDeps();
void CopyOpDeps(); void CopyOpDeps();
private:
std::future<std::unique_ptr<OpDependentData>> op_deps_futures_;
ExecutionStrategy strategy_;
std::unique_ptr<OpDependentData> op_deps_;
// use std::list because clear(), push_back, and for_each are O(1)
std::list<std::future<void>> run_op_futures_;
}; };
} // namespace details } // namespace details
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册