// reference_count_pass.cc
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <queue>
#include <string>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/reference_count_pass.h"
#include "paddle/fluid/framework/details/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"

namespace paddle {
namespace framework {
namespace details {

// Performs a BFS from `op` along output-variable edges and returns the first
// ComputationOpHandle whose scope index equals `scope_idx` (which may be `op`
// itself). Returns nullptr if no such op is reachable.
static ComputationOpHandle *FindNextComputationOpHandleOrReturnItself(
    OpHandleBase *op, size_t scope_idx) {
  std::queue<OpHandleBase *> q;
  std::unordered_set<OpHandleBase *> visited;
  q.push(op);
  do {
    // Use a distinct name so the parameter `op` is not shadowed.
    auto *cur_op = q.front();
    q.pop();
    auto *compute_op = dynamic_cast<ComputationOpHandle *>(cur_op);
    if (compute_op != nullptr && compute_op->GetScopeIdx() == scope_idx) {
      return compute_op;
    }
    for (auto *out_var : cur_op->Outputs()) {
      for (auto *pending_op : out_var->PendingOps()) {
        if (visited.count(pending_op)) continue;
        visited.insert(pending_op);
        // BUG FIX: successors were marked visited but never enqueued, so the
        // search could not advance past the initial op. Push them so the BFS
        // actually traverses the downstream graph.
        q.push(pending_op);
      }
    }
  } while (!q.empty());
  return nullptr;
}

// Computes, for every non-persistable GC-eligible variable in every scope,
// (1) its initial reference count (number of ops that still consume its last
// version) and (2) the set of "last live" computation ops after which the
// variable may be released. Results are written into the pass attributes
// kGlobalReferenceCount and kLastLiveOpsOfVars; the graph itself is returned
// unmodified.
std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  auto &vars = graph->Get<GraphVars>(kGraphVars);
  auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
  auto &last_live_ops_of_vars =
      Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);

  // One map per scope; reset any stale contents from a previous run.
  last_live_ops_of_vars = std::vector<LastLiveOpsOfVars>(vars.size());
  ref_cnts = std::vector<ReferenceCountMap>(vars.size());

  for (size_t i = 0; i < vars.size(); ++i) {
    for (auto &name_var_pair : vars[i]) {
      if (name_var_pair.second.empty()) continue;
      auto *last_ver_var = name_var_pair.second.back();

      // Scan versions from newest to oldest for the first one that carries a
      // VarDesc. (The original code used std::find_if purely for this side
      // effect and discarded its result; a plain loop states the intent.)
      VarDesc *var_desc = nullptr;
      for (auto it = name_var_pair.second.rbegin();
           it != name_var_pair.second.rend(); ++it) {
        var_desc = (*it)->Node()->Var();
        if (var_desc != nullptr) break;
      }

      // Persistable variables (and variables with no desc at all) are never
      // garbage-collected.
      if (var_desc == nullptr || var_desc->Persistable()) {
        continue;
      }

      // Only these three variable types support eager deletion.
      auto var_type = var_desc->Proto()->type().type();
      if (var_type != proto::VarType::LOD_TENSOR &&
          var_type != proto::VarType::SELECTED_ROWS &&
          var_type != proto::VarType::LOD_TENSOR_ARRAY) {
        continue;
      }

      std::unordered_set<ComputationOpHandle *> last_live_op;
      // Resolve `op` to a computation op in scope `i` (possibly itself) and
      // record it as a last-live op of this variable.
      auto add_last_live_op = [&](OpHandleBase *op) {
        auto *compute_op = FindNextComputationOpHandleOrReturnItself(op, i);
        if (compute_op) {
          last_live_op.insert(compute_op);
        }
      };
      const std::string &var_name = name_var_pair.first;
      auto &pending_ops = last_ver_var->PendingOps();
      if (pending_ops.empty()) {
        // No consumer: the generating op (if any) is the last live op and the
        // variable starts with a reference count of 1.
        auto *generated_op = last_ver_var->GeneratedOp();
        if (generated_op) {
          ref_cnts[i].emplace(var_name, 1);
          add_last_live_op(generated_op);
        }
      } else {
        // One reference per pending consumer of the last version.
        ref_cnts[i].emplace(var_name, pending_ops.size());
        for (auto *pending_op : pending_ops) {
          add_last_live_op(pending_op);
        }
      }

      last_live_ops_of_vars[i].emplace(var_name, std::move(last_live_op));
    }
  }
  return graph;
}

}  // namespace details
}  // namespace framework
}  // namespace paddle

// Register the pass and declare the attributes callers must provide before
// applying it. (Scrape artifacts that split this statement were removed.)
REGISTER_PASS(reference_count_pass,
              paddle::framework::details::ReferenceCountPass)
    .RequirePassAttr(paddle::framework::details::kGlobalReferenceCount)
    .RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars);