Unverified commit 67c8dade authored by chengduo, committed by GitHub

Add Event in ScopeBuffer Executor (#17667)

* add event for fast executor and add threads for scopebuffer executor
test=develop
Parent: bba57cdd
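The commit wraps the executor's prepare and run phases in profiler events. Below is a minimal, standalone sketch of the scoped-event pattern it relies on; `ScopedEvent`, `PreparePhase`, and `RunPhase` are hypothetical stand-ins for illustration, not Paddle's actual `platform::RecordEvent` API.

```cpp
// Minimal sketch of an RAII "record event": timing starts at construction and
// is reported at destruction. Paddle's platform::RecordEvent follows the same
// scoped pattern; this class is only an illustrative stand-in.
#include <chrono>
#include <iostream>
#include <memory>
#include <string>

class ScopedEvent {
 public:
  explicit ScopedEvent(std::string name)
      : name_(std::move(name)), start_(std::chrono::steady_clock::now()) {}
  ~ScopedEvent() {
    auto us = std::chrono::duration_cast<std::chrono::microseconds>(
                  std::chrono::steady_clock::now() - start_)
                  .count();
    std::cout << name_ << " took " << us << " us\n";
  }

 private:
  std::string name_;
  std::chrono::steady_clock::time_point start_;
};

int main() {
  // Holding the event in a unique_ptr lets the caller end it early with
  // reset(nullptr), as the fast executor does once its prepare phase is done.
  std::unique_ptr<ScopedEvent> event(new ScopedEvent("PreparePhase"));
  // ... prepare work would run here ...
  event.reset(nullptr);  // timing for the prepare phase ends here

  {
    ScopedEvent run_event("RunPhase");  // plain scoped form, ends at block exit
    // ... run work ...
  }
  return 0;
}
```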
@@ -20,6 +20,7 @@
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace framework {
@@ -50,6 +51,8 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
FeedFetchList FastThreadedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
VLOG(3) << "enter FastThreadedSSAGraphExecutor Run";
std::unique_ptr<platform::RecordEvent> event(
new platform::RecordEvent("FastThreadedSSAGraphExecutorPrepare"));
std::unique_ptr<std::unordered_map<OpHandleBase *, std::atomic<int>>>
op_deps = atomic_op_deps_.get();
PrepareAtomicOpDeps();
@@ -64,7 +67,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
InsertFetchOps(fetch_tensors, &fetches, &fetched_vars, op_deps.get(),
&fetch_ops, &ready_fetch_ops);
event.reset(nullptr);
if (strategy_.num_threads_ == 1 && traced_ops_.size() == num_ops) {
// If the num_threads is 1, we can record the order of operator's
// execution in the first iteration, and in subsequent iterations,
......
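The truncated comment above describes a single-threaded fast path: when `num_threads_` is 1, the executor records the order in which ops actually ran during the first iteration and simply replays that order afterwards instead of re-resolving dependencies each time. A rough standalone sketch of the idea follows; `Op` and `TracedRunner` are illustrative names, not Paddle's actual types or members.

```cpp
#include <functional>
#include <vector>

// Illustrative stand-in for an operator handle; not Paddle's OpHandleBase.
struct Op {
  std::function<void()> run;
};

class TracedRunner {
 public:
  void Run(const std::vector<Op*>& ops_in_dependency_order) {
    if (!trace_.empty()) {
      // Subsequent iterations: replay the recorded order sequentially.
      for (Op* op : trace_) op->run();
      return;
    }
    // First iteration: run ops via the dependency-driven path and remember
    // the order (the real executor also checks the trace covers every op).
    for (Op* op : ops_in_dependency_order) {
      op->run();
      trace_.push_back(op);
    }
  }

 private:
  std::vector<Op*> trace_;
};
```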
@@ -36,26 +36,10 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
if (drop_scope_counter_ == 0) {
// Create local scopes.
for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) {
auto &scope = *it;
Scope &local_scope = scope->NewScope();
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
&local_scope;
for (auto &info : var_infos_) {
if (scope->FindVar(info.name_) != nullptr) {
continue;
}
if (info.persistable_) { // Persistable
InitializeVariable(scope->Var(info.name_), info.type_);
} else {
InitializeVariable(local_scope.Var(info.name_), info.type_);
}
}
}
platform::RecordEvent e("InitLocalExeScopes");
PrepareLocalExeScopes();
}
std::vector<framework::LoDTensor> fetch_data;
std::exception_ptr eptr = nullptr;
try {
@@ -64,9 +48,7 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
eptr = std::current_exception();
}
platform::RecordEvent e("ScopeBufferedSSAGraphExecutorAfterRun");
++drop_scope_counter_;
if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
DropLocalExeScopes();
}
@@ -78,11 +60,11 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
}
void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() {
platform::RecordEvent drop_scope_event("DropLocalExeScopes");
drop_scope_counter_ = 0;
for (auto p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
for (auto &scope : local_scopes_) {
auto &local_scope =
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
@@ -91,6 +73,26 @@ void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() {
}
}
void ScopeBufferedSSAGraphExecutor::PrepareLocalExeScopes() {
// Create local scopes.
for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) {
auto &scope = *it;
Scope &local_scope = scope->NewScope();
*scope->Var(kLocalExecScopeName)->GetMutable<Scope *>() = &local_scope;
for (auto &info : var_infos_) {
if (scope->FindVar(info.name_) != nullptr) {
continue;
}
if (info.persistable_) { // Persistable
InitializeVariable(scope->Var(info.name_), info.type_);
} else {
InitializeVariable(local_scope.Var(info.name_), info.type_);
}
}
}
}
bool ScopeBufferedSSAGraphExecutor::NeedCreateLocalExeScope() {
return drop_scope_counter_ == 0;
}
......
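Taken together, the refactored `Run()`, `DropLocalExeScopes()`, and the new `PrepareLocalExeScopes()` implement a counter-driven scope lifecycle: local scopes are (re)created when `drop_scope_counter_` is zero, reused for the following iterations, and dropped once the counter reaches `num_iteration_per_drop_scope_`. The standalone sketch below mirrors that lifecycle; the class and method names are illustrative, not Paddle's.

```cpp
#include <cstddef>
#include <iostream>

// Sketch of the drop-scope counter kept by the scope-buffered executor:
// scopes are prepared lazily, reused across iterations, and cleared every
// num_iteration_per_drop_scope runs to bound memory growth.
class ScopeLifecycle {
 public:
  explicit ScopeLifecycle(size_t num_iteration_per_drop_scope)
      : drop_every_(num_iteration_per_drop_scope) {}

  void Run() {
    if (drop_scope_counter_ == 0) PrepareLocalScopes();
    // ... execute the graph against the prepared local scopes ...
    ++drop_scope_counter_;
    if (drop_scope_counter_ == drop_every_) DropLocalScopes();
  }

 private:
  void PrepareLocalScopes() { std::cout << "create local scopes\n"; }
  void DropLocalScopes() {
    std::cout << "drop local scopes\n";
    drop_scope_counter_ = 0;
  }

  size_t drop_scope_counter_ = 0;
  size_t drop_every_;
};

int main() {
  ScopeLifecycle exec(/*num_iteration_per_drop_scope=*/3);
  for (int i = 0; i < 7; ++i) exec.Run();  // scopes dropped after runs 3 and 6
  return 0;
}
```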
@@ -13,7 +13,8 @@
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <list>
#include <memory>
#include <string>
#include <vector>
@@ -51,6 +52,8 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
bool NeedCreateLocalExeScope();
void PrepareLocalExeScopes();
private:
size_t drop_scope_counter_{0};
ExecutionStrategy strategy_;
......
@@ -586,6 +586,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
platform::RecordBlock b(0);
if (member_->HasGarbageCollectors()) {
platform::RecordEvent event("PrepareGarbageCollectors");
member_->ResetRuntimeReferenceCount(fetch_tensors, fetched_var_name);
}
......