未验证 提交 60e0d1aa 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #10023 from reyoung/feature/DtorOfPE

Correctly implement destructor of ParallelExecutor
...@@ -44,6 +44,7 @@ class ParallelExecutorPrivate { ...@@ -44,6 +44,7 @@ class ParallelExecutorPrivate {
#endif #endif
std::vector<std::tuple<std::string, proto::VarType::Type, bool>> var_types_; std::vector<std::tuple<std::string, proto::VarType::Type, bool>> var_types_;
bool own_local_scope;
}; };
std::vector<Scope *> &ParallelExecutor::GetLocalScopes() { std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
...@@ -63,11 +64,13 @@ ParallelExecutor::ParallelExecutor( ...@@ -63,11 +64,13 @@ ParallelExecutor::ParallelExecutor(
// Step 1. Bcast the params to devs. // Step 1. Bcast the params to devs.
// Create local scopes // Create local scopes
if (local_scopes.empty()) { if (local_scopes.empty()) {
member_->own_local_scope = true;
member_->local_scopes_.emplace_back(member_->global_scope_); member_->local_scopes_.emplace_back(member_->global_scope_);
for (size_t i = 1; i < member_->places_.size(); ++i) { for (size_t i = 1; i < member_->places_.size(); ++i) {
member_->local_scopes_.emplace_back(&scope->NewScope()); member_->local_scopes_.emplace_back(&scope->NewScope());
} }
} else { } else {
member_->own_local_scope = false;
PADDLE_ENFORCE_EQ(member_->places_.size(), local_scopes.size()); PADDLE_ENFORCE_EQ(member_->places_.size(), local_scopes.size());
for (size_t i = 0; i < member_->places_.size(); ++i) { for (size_t i = 0; i < member_->places_.size(); ++i) {
member_->local_scopes_.emplace_back(local_scopes[i]); member_->local_scopes_.emplace_back(local_scopes[i]);
...@@ -231,5 +234,13 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes( ...@@ -231,5 +234,13 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
} }
} }
ParallelExecutor::~ParallelExecutor() {
if (member_->own_local_scope) {
for (size_t i = 1; i < member_->local_scopes_.size(); ++i) {
member_->global_scope_->DeleteScope(member_->local_scopes_[i]);
}
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -42,6 +42,8 @@ class ParallelExecutor { ...@@ -42,6 +42,8 @@ class ParallelExecutor {
const std::vector<Scope*>& local_scopes, const std::vector<Scope*>& local_scopes,
bool allow_op_delay); bool allow_op_delay);
~ParallelExecutor();
std::vector<Scope*>& GetLocalScopes(); std::vector<Scope*>& GetLocalScopes();
/** /**
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册