Commit d3ed070e authored by: S sneaxiy

test=develop

Parent: fb6201e9
@@ -64,8 +64,6 @@ ParallelExecutor::ParallelExecutor(
     const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy,
     size_t num_trainers, size_t trainer_id)
     : member_(new ParallelExecutorPrivate(places)) {
-  is_alive_.test_and_set();
-
   member_->global_scope_ = scope;
   member_->use_cuda_ = exec_strategy.use_cuda_;
   member_->use_all_reduce_ =
@@ -248,15 +246,6 @@ void ParallelExecutor::BCastParamsToDevices(

 void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
                            const std::string &fetched_var_name) {
-  // If ParallelExecutor has been destructed,
-  // just return
-  if (!is_alive_.test_and_set()) return;
-
-  // If ParallelExecutor is running
-  if (is_running_.test_and_set()) {
-    PADDLE_THROW("The previous ParallelExecutor::Run() has not stopped");
-  }
-
   platform::RecordBlock b(0);
 #ifdef PADDLE_WITH_CUDA
   if (!gcs_.empty()) {
@@ -270,17 +259,9 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
     }
   }
 #endif
-  try {
-    auto fetch_data = member_->executor_->Run(fetch_tensors);
-    *member_->global_scope_->Var(fetched_var_name)
-        ->GetMutable<FeedFetchList>() = fetch_data;
-    is_running_.clear();
-  } catch (...) {
-    is_running_.clear();
-    if (is_alive_.test_and_set()) {
-      std::rethrow_exception(std::current_exception());
-    }
-  }
+  auto fetch_data = member_->executor_->Run(fetch_tensors);
+  *member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
+      fetch_data;
 }

 void ParallelExecutor::FeedTensorsIntoLocalScopes(
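For orientation, the hot-fix this commit reverts can be condensed into a standalone sketch. The `Runner` class and `DoWork()` below are placeholders, not Paddle types; only the flag choreography mirrors the removed lines:

```cpp
// Standalone sketch of the reverted guard pattern; names are placeholders.
#include <atomic>
#include <exception>
#include <stdexcept>

class Runner {
 public:
  Runner() { is_alive_.test_and_set(); }

  ~Runner() {
    is_alive_.clear();
    // Spin until a concurrent Run() clears is_running_ on its way out.
    while (is_running_.test_and_set()) {
    }
  }

  void Run() {
    // test_and_set() returns the previous value; false means the
    // destructor already cleared the flag, so silently bail out.
    if (!is_alive_.test_and_set()) return;
    // A second concurrent caller sees true here and must not proceed.
    if (is_running_.test_and_set()) {
      throw std::runtime_error("previous Run() has not stopped");
    }
    try {
      DoWork();
      is_running_.clear();
    } catch (...) {
      is_running_.clear();
      // Swallow the exception if the object was destructed mid-run,
      // e.g. after a keyboard interrupt on the Python side.
      if (is_alive_.test_and_set()) {
        std::rethrow_exception(std::current_exception());
      }
    }
  }

 private:
  void DoWork() { /* real work here */ }

  std::atomic_flag is_alive_ = ATOMIC_FLAG_INIT;
  std::atomic_flag is_running_ = ATOMIC_FLAG_INIT;
};
```

Note the subtlety the removed code had to honor: `Run()` must clear `is_running_` on every exit path (hence the `clear()` in both the try and catch branches), otherwise the destructor's spin-wait would never terminate.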
@@ -318,7 +299,6 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
 }

 ParallelExecutor::~ParallelExecutor() {
-  is_alive_.clear();
   if (member_->own_local_scope_) {
     for (size_t i = 1; i < member_->local_scopes_.size(); ++i) {
       Scope *local_scope = member_->local_scopes_[i];
@@ -328,10 +308,8 @@ ParallelExecutor::~ParallelExecutor() {
     }
   }
-  while (is_running_.test_and_set()) {
-    // wait until all threads have been stopped
-  }
-
   // member_ must be destructed before gcs_ since the destructor of
   // ReferenceCountOpHandle uses raw pointers of gcs_ inside.
   member_.reset();
 }
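The surviving comment above `member_.reset()` leans on a general C++ rule: non-static members are destroyed in reverse declaration order, so a member that holds raw pointers into a later-declared member must be torn down by hand first. A minimal sketch with hypothetical types (`ExecutorPrivate` and `GarbageCollector` are illustrative, not the Paddle definitions):

```cpp
// Why the explicit reset(): by default gcs_ (declared last) would be
// destroyed before member_, leaving member_'s destructor with dangling
// pointers into the already-freed collectors.
#include <memory>
#include <vector>

struct GarbageCollector { /* owns device buffers */ };

struct ExecutorPrivate {
  GarbageCollector *gc = nullptr;  // raw, non-owning pointer into gcs_
  ~ExecutorPrivate() { /* dereferences gc, so *gc must still be alive */ }
};

class Executor {
 public:
  ~Executor() {
    member_.reset();  // tear down member_ while gcs_ is still valid
  }

 private:
  std::unique_ptr<ExecutorPrivate> member_;              // declared first
  std::vector<std::unique_ptr<GarbageCollector>> gcs_;   // destroyed first by default
};
```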
@@ -77,19 +77,6 @@ class ParallelExecutor {
   std::unique_ptr<ParallelExecutorPrivate> member_;

-  // FIXME(zjl): HOT-FIX
-  // A flag to indicate whether ParallelExecutor is destructed.
-  // On the Python side, when users interrupt the process manually, such as
-  // with a keyboard interrupt, ParallelExecutor may be destructed before
-  // Run() ends. Thus, disturbing exception messages would occur when
-  // interrupted.
-  // If is_alive_ is false, we discard the last exception thrown by Run().
-  // Since std::atomic_flag is always lock-free and faster than
-  // std::atomic<bool>, we choose std::atomic_flag as the flag here.
-  std::atomic_flag is_alive_ = ATOMIC_FLAG_INIT;
-
-  // A flag to indicate whether ParallelExecutor is running.
-  std::atomic_flag is_running_ = ATOMIC_FLAG_INIT;

 #ifdef PADDLE_WITH_CUDA
   // ref_cnts_ is only initialized when ParallelExecutor is constructed, and
   // then stays unchanged
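The removed header comment justifies the choice of `std::atomic_flag` by its guaranteed lock-freedom: it is the only atomic type the C++ standard requires to be lock-free, whereas `std::atomic<bool>` merely may be. A minimal, self-contained demonstration of the API the hot-fix relied on:

```cpp
// std::atomic_flag basics: test_and_set() atomically sets the flag and
// returns its previous value; clear() resets it.
#include <atomic>
#include <iostream>

int main() {
  std::atomic_flag flag = ATOMIC_FLAG_INIT;

  bool first = flag.test_and_set();   // false: the flag was clear
  bool second = flag.test_and_set();  // true: the flag was already set
  flag.clear();                       // back to the clear state

  std::atomic<bool> b{false};
  std::cout << std::boolalpha << first << ' ' << second << '\n'
            << "atomic<bool> lock-free on this target: " << b.is_lock_free()
            << '\n';
  return 0;
}
```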