// eager_deletion_pass.cc (3.5 KB) — last committed by sneaxiy.
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <queue>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"

namespace paddle {
namespace framework {
namespace details {

// Inserts one EagerDeletionOpHandle after every ComputationOpHandle that is
// listed (in kLastLiveOpsOfVars) as a last live op of at least one variable.
// Each deletion handle receives the names of all such variables of its op.
//
// Inputs:
//   - graph attribute kGraphVars: only its size is used, to size
//     kRuntimeReferenceCount.
//   - pass attributes kRuntimeReferenceCount (overwritten here),
//     kLastLiveOpsOfVars, kGarbageCollector and kAllPlaces.
//
// Returns the same graph, with the new deletion nodes/handles and their
// dummy (control-dependency) vars added. New dummy vars are registered in the
// graph attribute kGraphDepVars.
std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  const auto &vars = graph->Get<GraphVars>(kGraphVars);

  auto &ref_cnts =
      Get<std::vector<AtomicReferenceCountMap>>(kRuntimeReferenceCount);
  const auto &last_live_ops =
      Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
  auto &gcs = Get<GarbageCollectorMap>(kGarbageCollector);
  const auto &places = Get<std::vector<platform::Place>>(kAllPlaces);

  // Reset the runtime reference counts: one fresh map per entry of `vars`,
  // later addressed by each op's scope index.
  ref_cnts = std::vector<AtomicReferenceCountMap>(vars.size());

  // Invert "var -> its last live ops" into "op -> vars it is last to use", so
  // each op gets exactly one deletion handle covering all of its variables.
  std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>
      op_vars_map;
  for (const auto &var_ops_map : last_live_ops) {
    for (const auto &var_ops_pair : var_ops_map) {
      const std::string &var_name = var_ops_pair.first;
      for (auto *op : var_ops_pair.second) {
        op_vars_map[op].insert(var_name);
      }
    }
  }

  // Loop-invariant: fetch the dependency-var registry once, not per dummy var.
  auto &dep_vars = graph->Get<GraphDepVars>(kGraphDepVars);

  for (auto &pair : op_vars_map) {
    auto *op = pair.first;
    auto &var_names = pair.second;
    const auto scope_idx = op->GetScopeIdx();

    auto *eager_deletion_node =
        graph->CreateEmptyNode("eager_deletion", ir::Node::Type::kOperation);
    // `.at()` instead of `[]`: `ref_cnts` is sized from `vars`, not `places`,
    // so a mismatched scope index now throws instead of indexing out of range.
    // `var_names` is moved out; op_vars_map's sets are not read again below.
    // NOTE(review): the raw `new` is not stored here; presumably ownership is
    // taken by eager_deletion_node when the handle wraps it — confirm in
    // OpHandleBase.
    auto *eager_deletion_op = new EagerDeletionOpHandle(
        eager_deletion_node, op->GetScope(), op->GetPlace(),
        std::move(var_names), gcs.at(places.at(scope_idx)).get(),
        &(ref_cnts.at(scope_idx)));

    // Order the deletion after `op`: reuse an existing dummy output of `op`
    // as the deletion handle's input, or create one if `op` has none.
    // Outputs() is called once so both iterators refer to the same container.
    const auto &outputs = op->Outputs();
    auto it = std::find_if(
        outputs.begin(), outputs.end(), [](VarHandleBase *var) {
          return dynamic_cast<DummyVarHandle *>(var) != nullptr;
        });

    if (it != outputs.end()) {
      eager_deletion_op->AddInput(*it);
    } else {
      // NOTE(review): presumably kGraphDepVars takes ownership of dep_var
      // (e.g. a set of unique_ptr) — confirm in multi_devices_helper.h.
      auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
      dep_vars.emplace(dep_var);
      op->AddOutput(dep_var);
      eager_deletion_op->AddInput(dep_var);
    }

    // Give the deletion handle its own fresh dummy output var.
    auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
    dep_vars.emplace(dummy_leaf);
    eager_deletion_op->AddOutput(dummy_leaf);
  }

  VLOG(10) << "Create " << op_vars_map.size() << " EagerDeletionOpHandle(s)";
  return graph;
}

}  // namespace details
}  // namespace framework
}  // namespace paddle

// Registers EagerDeletionPass as "eager_deletion_pass" and declares the pass
// attributes it requires. These are the four pass attributes
// EagerDeletionPass::ApplyImpl reads through Get<>().
REGISTER_PASS(eager_deletion_pass,
              paddle::framework::details::EagerDeletionPass)
    .RequirePassAttr(paddle::framework::details::kRuntimeReferenceCount)
    .RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars)
    .RequirePassAttr(paddle::framework::details::kAllPlaces)
    .RequirePassAttr(paddle::framework::details::kGarbageCollector);