scope_buffered_ssa_graph_executor.cc 3.5 KB
Newer Older
Y
yuyang18 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
16
#include <stdexcept>
Y
yuyang18 已提交
17
#include <string>
18
#include <utility>
Y
yuyang18 已提交
19
#include <vector>
W
Wang Guibao 已提交
20
#include "paddle/fluid/framework/variable_helper.h"
S
sneaxiy 已提交
21
#include "paddle/fluid/platform/profiler.h"
Y
yuyang18 已提交
22 23 24 25 26 27

namespace paddle {
namespace framework {
namespace details {
// Builds an executor that forwards each Run() to `underlying_executor` (whose
// ownership is taken over) while managing buffered local execution scopes
// beneath every scope in `local_scopes`.
//
// - `strategy` supplies num_iteration_per_drop_scope_, i.e. how many Run()
//   calls happen before the local scopes are dropped.
// - `var_infos` lists the variables created when the local scopes are built.
// - `places` lists the devices whose contexts are waited on before a drop.
//
// All arguments are moved into members. The initializers are listed in the
// same order as the original code, which presumably matches the member
// declaration order in the header.
ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
    ExecutionStrategy strategy, std::vector<Scope *> local_scopes,
    std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
    std::unique_ptr<SSAGraphExecutor> &&underlying_executor)
    : strategy_(std::move(strategy)),                        // drop policy
      underlying_executor_(std::move(underlying_executor)),  // owned executor
      local_scopes_(std::move(local_scopes)),                // parent scopes
      var_infos_(std::move(var_infos)),                      // vars to create
      places_(std::move(places)) {}                          // devices to sync

// Runs one iteration of the underlying executor inside the buffered local
// execution scopes.
//
// Scope lifecycle:
// - The local execution scopes are (re)built lazily on the first iteration
//   after a drop, as reported by NeedCreateLocalExeScope().
// - They are dropped again once strategy_.num_iteration_per_drop_scope_
//   iterations have been counted.
// - The counting and dropping happen even if the underlying executor throws.
//   In that case the exception is captured and rethrown afterwards.
//
// @param fetch_tensors  names of the variables to fetch.
// @return the tensors fetched by the underlying executor.
// @throws whatever underlying_executor_->Run() throws, after scope bookkeeping.
FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
    const std::vector<std::string> &fetch_tensors) {
  if (NeedCreateLocalExeScope()) {
    platform::RecordEvent e("InitLocalExeScopes");
    PrepareLocalExeScopes();
  }

  FeedFetchList fetch_data;
  std::exception_ptr eptr = nullptr;
  try {
    fetch_data = underlying_executor_->Run(fetch_tensors);
  } catch (...) {
    // Defer the exception so that the iteration counting and scope dropping
    // below still run. This acts like a `finally` clause.
    eptr = std::current_exception();
  }

  ++drop_scope_counter_;
  if (drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
    DropLocalExeScopes();
  }

  if (eptr) {
    std::rethrow_exception(eptr);
  }
  return fetch_data;
}
61 62

void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() {
63
  platform::RecordEvent drop_scope_event("DropLocalExeScopes");
64 65 66 67
  drop_scope_counter_ = 0;
  for (auto p : places_) {
    platform::DeviceContextPool::Instance().Get(p)->Wait();
  }
C
chengduo 已提交
68

69
  for (auto &scope : local_scopes_) {
C
chengduo 已提交
70 71 72 73
    auto *local_scope_var = scope->FindLocalVar(details::kLocalExecScopeName);
    if (local_scope_var != nullptr) {
      auto &local_scope = *local_scope_var->GetMutable<Scope *>();
      scope->DeleteScope(local_scope);
74
      scope->EraseVars({std::string(details::kLocalExecScopeName)});
C
chengduo 已提交
75 76
      VLOG(3) << "Drop local execution scope: " << local_scope;
    }
77 78 79
  }
}

80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
void ScopeBufferedSSAGraphExecutor::PrepareLocalExeScopes() {
  // Create local scopes.
  for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) {
    auto &scope = *it;
    Scope &local_scope = scope->NewScope();
    *scope->Var(kLocalExecScopeName)->GetMutable<Scope *>() = &local_scope;

    for (auto &info : var_infos_) {
      if (scope->FindVar(info.name_) != nullptr) {
        continue;
      }
      if (info.persistable_) {  // Persistable
        InitializeVariable(scope->Var(info.name_), info.type_);
      } else {
        InitializeVariable(local_scope.Var(info.name_), info.type_);
      }
    }
  }
}

100 101 102 103
// Reports whether the local execution scopes must be (re)built before the
// next iteration. This is true exactly when the iteration counter is zero:
// DropLocalExeScopes() has just reset it, or (presumably) no iteration has
// run yet.
bool ScopeBufferedSSAGraphExecutor::NeedCreateLocalExeScope() {
  const bool at_cycle_start = (drop_scope_counter_ == 0);
  return at_cycle_start;
}

Y
yuyang18 已提交
104 105 106
}  // namespace details
}  // namespace framework
}  // namespace paddle