Commit ad73b331 authored by Yu Yang, committed by qingqing01

Eagerly drop local scope in iteration (#9838)

* Eagerly drop local scope in iteration

* Correct create var

* Fix typo

* Debug
Parent 8d4d6eae
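The change in a nutshell: before this commit, each iteration's temporary variables lived in a child scope that was only reclaimed lazily (DropKids() every max_async_computation = 100 iterations), so peak memory could hold up to 100 iterations' worth of temporaries. After this commit, ParallelExecutor::Run creates one child scope per device at the start of each iteration and deletes it as soon as the device streams have drained. Below is a minimal, self-contained sketch of the eager pattern; the toy Scope class merely stands in for paddle::framework::Scope and is not the real API:

    // Toy model of the eager scope-drop pattern introduced by this commit.
    #include <cstddef>
    #include <iostream>
    #include <memory>
    #include <vector>

    class Scope {
     public:
      Scope *NewScope() {
        kids_.emplace_back(new Scope);
        return kids_.back().get();
      }
      void DeleteScope(Scope *kid) {
        for (auto it = kids_.begin(); it != kids_.end(); ++it) {
          if (it->get() == kid) {
            kids_.erase(it);
            return;
          }
        }
      }
      std::size_t NumKids() const { return kids_.size(); }

     private:
      std::vector<std::unique_ptr<Scope>> kids_;
    };

    int main() {
      Scope global;
      for (int iter = 0; iter < 3; ++iter) {
        Scope *local = global.NewScope();  // per-iteration temporaries live here
        // ... run all ops against *local; the real code first waits on all
        // device streams so no kernel still references the scope ...
        global.DeleteScope(local);  // reclaimed eagerly, not via DropKids()
      }
      std::cout << "live child scopes: " << global.NumKids() << "\n";  // 0
    }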
paddle/fluid/framework/details/computation_op_handle.cc
@@ -14,6 +14,8 @@
 #include "paddle/fluid/framework/details/computation_op_handle.h"
 
+#include <string>
+
 namespace paddle {
 namespace framework {
 namespace details {
@@ -33,7 +35,7 @@ void ComputationOpHandle::RunImpl() {
     }
   }
 
-  op_->Run(*scope_->FindVar("@TMP_SCOPE@")->Get<Scope *>(), place_);
+  op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
 }
 
 std::string ComputationOpHandle::Name() const { return op_->Type(); }
paddle/fluid/framework/details/fetch_op_handle.cc
@@ -14,6 +14,9 @@
 #include "paddle/fluid/framework/details/fetch_op_handle.h"
 
+#include <string>
+#include <vector>
+
 namespace paddle {
 namespace framework {
 namespace details {
@@ -57,7 +60,10 @@ void FetchOpHandle::RunImpl() {
   for (size_t i = 0; i < scopes.size(); ++i) {
     auto &scope = scopes[i];
-    auto &t = scope->FindVar(var_name)->Get<framework::LoDTensor>();
+    auto &t = scope->FindVar(kLocalExecScopeName)
+                  ->Get<Scope *>()
+                  ->FindVar(var_name)
+                  ->Get<framework::LoDTensor>();
     if (platform::is_gpu_place(var->place_)) {
 #ifdef PADDLE_WITH_CUDA
       TensorCopy(t, cpu, *dev_ctxes_[t.place()], &tensors_[i]);
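Both handles above now reach the per-iteration scope through one level of indirection: the long-lived scope holds a Scope* variable under kLocalExecScopeName (defined in op_handle_base.h below; the value string in the source transposes LOCAL as "@LCOAL_SCOPE@", which is harmless because it is only used as a lookup key), and every op dereferences it at run time instead of caching the short-lived scope. Op handles therefore never need rewiring when the local scope is replaced between iterations. A sketch with toy types, not the real Scope API:

    // Toy sketch of the kLocalExecScopeName indirection; local_exec_scope
    // stands in for FindVar(kLocalExecScopeName)->Get<Scope *>().
    #include <cassert>
    #include <map>
    #include <string>

    struct Scope {
      std::map<std::string, int> vars;    // toy variables
      Scope *local_exec_scope = nullptr;  // slot rewired each iteration
    };

    // Schematically what ComputationOpHandle and FetchOpHandle now do.
    int ReadThroughLocalScope(const Scope &outer, const std::string &var_name) {
      Scope *exec = outer.local_exec_scope;
      assert(exec != nullptr);  // set by ParallelExecutor::Run before ops run
      return exec->vars.at(var_name);
    }

    int main() {
      Scope outer, local;
      local.vars["x"] = 42;
      outer.local_exec_scope = &local;  // wired up at the start of each Run()
      assert(ReadThroughLocalScope(outer, "x") == 42);
      return 0;
    }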
paddle/fluid/framework/details/op_handle_base.h
@@ -24,6 +24,8 @@ namespace paddle {
 namespace framework {
 namespace details {
 
+constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";
+
 class OpHandleBase {
  private:
  DISABLE_COPY_AND_ASSIGN(OpHandleBase);
paddle/fluid/framework/details/ssa_graph_executor.h
@@ -15,13 +15,15 @@
 #pragma once
 
 #include <memory>
+#include <string>
+#include <vector>
 
 #include "paddle/fluid/framework/details/ssa_graph.h"
 #include "paddle/fluid/framework/feed_fetch_type.h"
 
 namespace paddle {
 namespace framework {
 namespace details {
 
 class SSAGraphExecutor {
   DISABLE_COPY_AND_ASSIGN(SSAGraphExecutor);
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -136,12 +136,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
     ready_ops.clear();
   };
 
-  // Create local scopes.
-  for (auto &scope : local_scopes_) {
-    auto &local_scope = scope->NewScope();
-    *scope->Var("@TMP_SCOPE@")->GetMutable<Scope *>() = &local_scope;
-  }
-
   // Step 3. Execution
   while (!pending_vars.empty() || !ready_ops.empty() || !delayed_ops.empty()) {
     // 1. Run All Ready ops
@@ -189,34 +183,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
   PADDLE_ENFORCE(ready_ops.empty());
   PADDLE_ENFORCE(delayed_ops.empty());
   PADDLE_ENFORCE(blocked_by_delayed_ops.empty());
-  ++computation_count_;
-
-  auto sync_computation = [&] {
-    computation_count_ = 0;
-    // Wait All computational streams
-    for (auto p : this->places_) {
-      platform::DeviceContextPool::Instance().Get(p)->Wait();
-    }
-    for (auto &scope : local_scopes_) {
-      scope->DropKids();
-    }
-  };
 
   // Wait FetchOps.
   if (!fetch_ops.empty()) {
     fetch_ops.clear();
-    sync_computation();
   }
 
-  if (computation_count_ == max_async_computation) {
-    sync_computation();
-  }
-
-  // NOTE: the temp scope can be dropped lazily if needed.
-  // Drop tmp scopes;
-  for (auto &scope : local_scopes_) {
-    auto &kid = *scope->Var("@TMP_SCOPE@")->GetMutable<Scope *>();
-    kid = nullptr;
-  }
-
   return fetch_data;
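For contrast, what the deleted code did: child scopes accumulated across iterations and were dropped only when a fetch forced a sync or when computation_count_ hit max_async_computation (100). The toy loop below just counts scopes to show the peak the old throttling allowed; the names and the constant come from the removed members, everything else is illustrative:

    // Toy illustration of the removed throttling: up to 100 iterations of
    // temporaries could be live before sync_computation() dropped them.
    #include <algorithm>
    #include <cstddef>
    #include <iostream>

    int main() {
      std::size_t computation_count = 0;
      const std::size_t max_async_computation = 100;  // removed constant
      std::size_t live_scopes = 0, peak = 0;

      for (int iter = 0; iter < 250; ++iter) {
        ++live_scopes;  // NewScope() per iteration, never freed eagerly
        peak = std::max(peak, live_scopes);
        if (++computation_count == max_async_computation) {
          computation_count = 0;  // sync_computation(): Wait() all streams...
          live_scopes = 0;        // ...then scope->DropKids()
        }
      }
      std::cout << "peak live scopes: " << peak << "\n";  // prints 100
    }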
paddle/fluid/framework/details/threaded_ssa_graph_executor.h
@@ -99,9 +99,6 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
   std::unique_ptr<platform::EnforceNotMet> exception_;
   std::atomic<int> running_ops_;
   bool allow_op_delay_;
-
-  size_t computation_count_{0};
-  size_t max_async_computation{100};
 };
 
 }  // namespace details
paddle/fluid/framework/parallel_executor.cc
@@ -15,6 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/parallel_executor.h"
 #include <string>
+#include <tuple>
 #include <vector>
 
 #ifdef PADDLE_WITH_CUDA
@@ -41,6 +42,8 @@ class ParallelExecutorPrivate {
 #ifdef PADDLE_WITH_CUDA
   std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
 #endif
+
+  std::vector<std::tuple<std::string, proto::VarType::Type, bool>> var_types_;
 };
 
 std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
@@ -97,14 +100,9 @@ ParallelExecutor::ParallelExecutor(
       allow_op_delay));
 
   // Step 3. Create vars in each scope;
-  for (auto *scope : member_->local_scopes_) {
-    for (auto *var : main_program.Block(0).AllVars()) {
-      if (scope->FindVar(var->Name()) != nullptr) {
-        continue;
-      }
-      InitializeVariable(scope->Var(var->Name()), var->GetType());
-    }
+  for (auto *var : main_program.Block(0).AllVars()) {
+    member_->var_types_.emplace_back(var->Name(), var->GetType(),
+                                     var->Persistable());
   }
 }
@@ -163,9 +161,42 @@ void ParallelExecutor::Run(
     const std::unordered_map<std::string, LoDTensor> &feed_tensors) {
   platform::RecordBlock b(0);
   SplitTensorToPlaces(feed_tensors);
+
+  // Create local scopes.
+  for (auto &scope : member_->local_scopes_) {
+    Scope &local_scope = scope->NewScope();
+    *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
+        &local_scope;
+
+    for (auto &name_type_pair : member_->var_types_) {
+      if (scope->FindVar(std::get<0>(name_type_pair)) != nullptr) {
+        continue;
+      }
+
+      if (std::get<2>(name_type_pair)) {  // Persistable
+        InitializeVariable(scope->Var(std::get<0>(name_type_pair)),
+                           std::get<1>(name_type_pair));
+      } else {
+        InitializeVariable(local_scope.Var(std::get<0>(name_type_pair)),
+                           std::get<1>(name_type_pair));
+      }
+    }
+  }
+
   auto fetch_data = member_->executor_->Run(fetch_tensors);
   *member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
       fetch_data;
+
+  // Wait All computational streams
+  for (auto p : member_->places_) {
+    platform::DeviceContextPool::Instance().Get(p)->Wait();
+  }
+  for (auto &scope : member_->local_scopes_) {
+    auto &local_scope =
+        *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
+    scope->DeleteScope(local_scope);
+    local_scope = nullptr;
+  }
 }
 
 void ParallelExecutor::SplitTensorToPlaces(
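Putting the new Run lifecycle together: variable descriptions are recorded once at construction (name, type, persistable), and every Run decides placement per variable — persistable variables (parameters) are created once in the long-lived scope, everything else in the fresh per-iteration scope, which is deleted after the streams drain. A self-contained sketch of just that placement decision, with toy types standing in for framework::Scope and InitializeVariable:

    // Toy sketch of per-Run variable placement; the real code uses
    // framework::Scope, proto::VarType::Type, and InitializeVariable.
    #include <iostream>
    #include <set>
    #include <string>
    #include <vector>

    struct Scope {
      std::set<std::string> vars;  // toy "variables"
    };
    struct VarInfo {
      std::string name;
      bool persistable;  // std::get<2>(name_type_pair) in the real code
    };

    void CreateVars(Scope *scope, Scope *local,
                    const std::vector<VarInfo> &types) {
      for (const auto &v : types) {
        if (scope->vars.count(v.name)) continue;  // made in a prior Run
        (v.persistable ? scope : local)->vars.insert(v.name);
      }
    }

    int main() {
      Scope global;
      const std::vector<VarInfo> types = {{"w", true}, {"tmp", false}};
      for (int iter = 0; iter < 2; ++iter) {
        Scope local;  // NewScope(); deleted at iteration end
        CreateVars(&global, &local, types);
        // ... run the graph, wait on device streams, then drop `local` ...
      }
      std::cout << "persistent vars: " << global.vars.size() << "\n";  // 1 ("w")
    }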