scope_buffered_ssa_graph_executor.cc 3.4 KB
Newer Older
Y
yuyang18 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
16
#include <stdexcept>
Y
yuyang18 已提交
17 18
#include <string>
#include <vector>
W
Wang Guibao 已提交
19
#include "paddle/fluid/framework/variable_helper.h"
S
sneaxiy 已提交
20
#include "paddle/fluid/platform/profiler.h"
Y
yuyang18 已提交
21 22 23 24 25 26 27 28 29 30 31 32

namespace paddle {
namespace framework {
namespace details {
// Wraps `underlying_executor`, taking ownership of it and of the strategy,
// scope list, variable descriptions and places used to manage per-run
// local execution scopes.
//
// If the wrapped executor's graph carries a `kGarbageCollector` attribute,
// a pointer to that GarbageCollectorMap is kept in `gc_`. The map itself is
// owned by the graph; this object only borrows it.
ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
    ExecutionStrategy strategy, std::vector<Scope *> local_scopes,
    std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
    std::unique_ptr<SSAGraphExecutor> &&underlying_executor)
    : strategy_(std::move(strategy)),
      underlying_executor_(std::move(underlying_executor)),
      local_scopes_(std::move(local_scopes)),
      var_infos_(std::move(var_infos)),
      places_(std::move(places)) {
  auto &graph = Graph();
  if (!graph.Has(details::kGarbageCollector)) {
    // No garbage collector attached to the graph: leave gc_ untouched.
    return;
  }
  gc_ = &graph.Get<GarbageCollectorMap>(details::kGarbageCollector);
}
Y
yuyang18 已提交
38

S
fix bug  
sneaxiy 已提交
39 40
// Blocks until every garbage collector in the borrowed map has finished its
// pending work, then resets each one. Does nothing when no garbage-collector
// map was found on the graph (gc_ is null).
void ScopeBufferedSSAGraphExecutor::WaitAllGarbageCollectors() {
  if (!gc_) {
    return;
  }
  for (auto &entry : *gc_) {
    auto &collector = entry.second;
    collector->Wait();
    collector->Reset();
  }
}

Y
yuyang18 已提交
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
// Runs one iteration of the wrapped executor inside buffered local scopes.
//
// When drop_scope_counter_ is 0, a fresh child scope is created under each
// entry of local_scopes_ (iterating in reverse order). Its address is stored
// in that scope's `kLocalExecScopeName` variable. Every entry of var_infos_
// not already found via FindVar on the parent is then initialized:
// persistable variables go in the parent scope, all others in the new
// child scope.
//
// Any exception thrown by the wrapped executor is captured and rethrown only
// after the post-run bookkeeping below has executed. The local scopes are
// deleted when fetch_tensors is non-empty, or once the counter reaches
// strategy_.num_iteration_per_drop_scope_. Deletion happens after waiting on
// every place's device context and on all garbage collectors. Otherwise only
// the garbage collectors are waited on.
//
// Returns the tensors fetched by the wrapped executor.
FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
    const std::vector<std::string> &fetch_tensors) {
  if (drop_scope_counter_ == 0) {
    for (auto scope_it = local_scopes_.rbegin();
         scope_it != local_scopes_.rend(); ++scope_it) {
      Scope *parent = *scope_it;
      Scope &exec_scope = parent->NewScope();
      *parent->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
          &exec_scope;

      for (auto &var_info : var_infos_) {
        if (parent->FindVar(var_info.name_) == nullptr) {
          // Persistable variables are placed in the parent scope; everything
          // else goes into the child execution scope created above.
          Scope &owner = var_info.persistable_ ? *parent : exec_scope;
          InitializeVariable(owner.Var(var_info.name_), var_info.type_);
        }
      }
    }
  }

  // Defer any failure so the synchronization / scope cleanup below still runs.
  std::vector<framework::LoDTensor> fetched;
  std::exception_ptr pending_error;
  try {
    fetched = underlying_executor_->Run(fetch_tensors);
  } catch (...) {
    pending_error = std::current_exception();
  }

  platform::RecordEvent after_run_event("ScopeBufferedSSAGraphExecutorAfterRun",
                                        nullptr);
  ++drop_scope_counter_;

  const bool should_drop =
      !fetch_tensors.empty() ||
      drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_;
  if (should_drop) {
    drop_scope_counter_ = 0;
    // Wait for all computational streams before tearing down scopes.
    for (auto &place : places_) {
      platform::DeviceContextPool::Instance().Get(place)->Wait();
    }
    WaitAllGarbageCollectors();
    for (auto &parent : local_scopes_) {
      auto &exec_scope =
          *parent->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
      parent->DeleteScope(exec_scope);
    }
  } else {
    WaitAllGarbageCollectors();
  }

  if (pending_error) {
    std::rethrow_exception(pending_error);
  }
  return fetched;
}
}  // namespace details
}  // namespace framework
}  // namespace paddle