提交 ad73b331 编写于 作者: Y Yu Yang 提交者: qingqing01

Eagerly drop local scope in iteration (#9838)

* Eagerly drop local scope in iteration

* Correct create var

* Fix typo

* Debug
上级 8d4d6eae
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include <string>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
...@@ -33,7 +35,7 @@ void ComputationOpHandle::RunImpl() { ...@@ -33,7 +35,7 @@ void ComputationOpHandle::RunImpl() {
} }
} }
op_->Run(*scope_->FindVar("@TMP_SCOPE@")->Get<Scope *>(), place_); op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
} }
std::string ComputationOpHandle::Name() const { return op_->Type(); } std::string ComputationOpHandle::Name() const { return op_->Type(); }
......
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
#include "paddle/fluid/framework/details/fetch_op_handle.h" #include "paddle/fluid/framework/details/fetch_op_handle.h"
#include <string>
#include <vector>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
...@@ -57,7 +60,10 @@ void FetchOpHandle::RunImpl() { ...@@ -57,7 +60,10 @@ void FetchOpHandle::RunImpl() {
for (size_t i = 0; i < scopes.size(); ++i) { for (size_t i = 0; i < scopes.size(); ++i) {
auto &scope = scopes[i]; auto &scope = scopes[i];
auto &t = scope->FindVar(var_name)->Get<framework::LoDTensor>(); auto &t = scope->FindVar(kLocalExecScopeName)
->Get<Scope *>()
->FindVar(var_name)
->Get<framework::LoDTensor>();
if (platform::is_gpu_place(var->place_)) { if (platform::is_gpu_place(var->place_)) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
TensorCopy(t, cpu, *dev_ctxes_[t.place()], &tensors_[i]); TensorCopy(t, cpu, *dev_ctxes_[t.place()], &tensors_[i]);
......
...@@ -24,6 +24,8 @@ namespace paddle { ...@@ -24,6 +24,8 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";
class OpHandleBase { class OpHandleBase {
private: private:
DISABLE_COPY_AND_ASSIGN(OpHandleBase); DISABLE_COPY_AND_ASSIGN(OpHandleBase);
......
...@@ -15,13 +15,15 @@ ...@@ -15,13 +15,15 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/ssa_graph.h" #include "paddle/fluid/framework/details/ssa_graph.h"
#include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/feed_fetch_type.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
class SSAGraphExecutor { class SSAGraphExecutor {
DISABLE_COPY_AND_ASSIGN(SSAGraphExecutor); DISABLE_COPY_AND_ASSIGN(SSAGraphExecutor);
......
...@@ -136,12 +136,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -136,12 +136,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
ready_ops.clear(); ready_ops.clear();
}; };
// Create local scopes.
for (auto &scope : local_scopes_) {
auto &local_scope = scope->NewScope();
*scope->Var("@TMP_SCOPE@")->GetMutable<Scope *>() = &local_scope;
}
// Step 3. Execution // Step 3. Execution
while (!pending_vars.empty() || !ready_ops.empty() || !delayed_ops.empty()) { while (!pending_vars.empty() || !ready_ops.empty() || !delayed_ops.empty()) {
// 1. Run All Ready ops // 1. Run All Ready ops
...@@ -189,34 +183,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -189,34 +183,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
PADDLE_ENFORCE(ready_ops.empty()); PADDLE_ENFORCE(ready_ops.empty());
PADDLE_ENFORCE(delayed_ops.empty()); PADDLE_ENFORCE(delayed_ops.empty());
PADDLE_ENFORCE(blocked_by_delayed_ops.empty()); PADDLE_ENFORCE(blocked_by_delayed_ops.empty());
++computation_count_;
auto sync_computation = [&] {
computation_count_ = 0;
// Wait All computational streams
for (auto p : this->places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
for (auto &scope : local_scopes_) {
scope->DropKids();
}
};
// Wait FetchOps. // Wait FetchOps.
if (!fetch_ops.empty()) { if (!fetch_ops.empty()) {
fetch_ops.clear(); fetch_ops.clear();
sync_computation();
}
if (computation_count_ == max_async_computation) {
sync_computation();
}
// NOTE: the temp scope can be dropped lazily if needed.
// Drop tmp scopes;
for (auto &scope : local_scopes_) {
auto &kid = *scope->Var("@TMP_SCOPE@")->GetMutable<Scope *>();
kid = nullptr;
} }
return fetch_data; return fetch_data;
......
...@@ -99,9 +99,6 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -99,9 +99,6 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
std::unique_ptr<platform::EnforceNotMet> exception_; std::unique_ptr<platform::EnforceNotMet> exception_;
std::atomic<int> running_ops_; std::atomic<int> running_ops_;
bool allow_op_delay_; bool allow_op_delay_;
size_t computation_count_{0};
size_t max_async_computation{100};
}; };
} // namespace details } // namespace details
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/parallel_executor.h" #include "paddle/fluid/framework/parallel_executor.h"
#include <string> #include <string>
#include <tuple>
#include <vector> #include <vector>
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -41,6 +42,8 @@ class ParallelExecutorPrivate { ...@@ -41,6 +42,8 @@ class ParallelExecutorPrivate {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_; std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
#endif #endif
std::vector<std::tuple<std::string, proto::VarType::Type, bool>> var_types_;
}; };
std::vector<Scope *> &ParallelExecutor::GetLocalScopes() { std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
...@@ -97,14 +100,9 @@ ParallelExecutor::ParallelExecutor( ...@@ -97,14 +100,9 @@ ParallelExecutor::ParallelExecutor(
allow_op_delay)); allow_op_delay));
// Step 3. Create vars in each scope; // Step 3. Create vars in each scope;
for (auto *scope : member_->local_scopes_) { for (auto *var : main_program.Block(0).AllVars()) {
for (auto *var : main_program.Block(0).AllVars()) { member_->var_types_.emplace_back(var->Name(), var->GetType(),
if (scope->FindVar(var->Name()) != nullptr) { var->Persistable());
continue;
}
InitializeVariable(scope->Var(var->Name()), var->GetType());
}
} }
} }
...@@ -163,9 +161,42 @@ void ParallelExecutor::Run( ...@@ -163,9 +161,42 @@ void ParallelExecutor::Run(
const std::unordered_map<std::string, LoDTensor> &feed_tensors) { const std::unordered_map<std::string, LoDTensor> &feed_tensors) {
platform::RecordBlock b(0); platform::RecordBlock b(0);
SplitTensorToPlaces(feed_tensors); SplitTensorToPlaces(feed_tensors);
// Create local scopes.
for (auto &scope : member_->local_scopes_) {
Scope &local_scope = scope->NewScope();
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
&local_scope;
for (auto &name_type_pair : member_->var_types_) {
if (scope->FindVar(std::get<0>(name_type_pair)) != nullptr) {
continue;
}
if (std::get<2>(name_type_pair)) { // Persistable
InitializeVariable(scope->Var(std::get<0>(name_type_pair)),
std::get<1>(name_type_pair));
} else {
InitializeVariable(scope->Var(std::get<0>(name_type_pair)),
std::get<1>(name_type_pair));
}
}
}
auto fetch_data = member_->executor_->Run(fetch_tensors); auto fetch_data = member_->executor_->Run(fetch_tensors);
*member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() = *member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
fetch_data; fetch_data;
// Wait All computational streams
for (auto p : member_->places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
for (auto &scope : member_->local_scopes_) {
auto &local_scope =
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
scope->DeleteScope(local_scope);
local_scope = nullptr;
}
} }
void ParallelExecutor::SplitTensorToPlaces( void ParallelExecutor::SplitTensorToPlaces(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册