// eager_deletion_pass.cc
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <queue>
#include <string>
#include <vector>

#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/details/eager_deletion_pass.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"

namespace paddle {
namespace framework {
namespace details {

// Inserts one EagerDeletionOpHandle after every op that is recorded as a
// "last live op" of at least one variable.
//
// Pass attributes read through Get<>():
//   - kRuntimeReferenceCount: std::vector<AtomicReferenceCountMap>; must be
//     empty on entry and is resized here to the size of the graph's
//     kGraphVars attribute.
//   - kLastLiveOpsOfVars: for each entry, a map from variable name to the ops
//     that use the variable last.
//   - kGarbageCollector: map looked up by place; the matching collector is
//     handed to each created handle.
//   - kAllPlaces: places indexed by an op's scope index.
//
// Returns the same graph object that was passed in, after adding the new
// operation nodes and control-dependency variables to it.
std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  auto &ref_cnts =
      Get<std::vector<AtomicReferenceCountMap>>(kRuntimeReferenceCount);
  PADDLE_ENFORCE(ref_cnts.empty(),
                 "kRuntimeReferenceCount should be initialized here!");

  // One reference-count map per entry of the graph's kGraphVars attribute.
  const auto &vars = graph->Get<GraphVars>(kGraphVars);
  ref_cnts.resize(vars.size());

  const auto &last_live_ops =
      Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
  const auto &gcs = Get<GarbageCollectorMap>(kGarbageCollector);
  const auto &places = Get<std::vector<platform::Place>>(kAllPlaces);

  // A reverse map of last_live_ops,
  //   i.e., last op --> names of the variables which can be deleted after it.
  std::unordered_map<ComputationOpHandle *, std::unordered_set<std::string>>
      op_vars_map;

  for (auto &var_ops_map : last_live_ops) {
    for (auto &var_ops_pair : var_ops_map) {
      const std::string &var_name = var_ops_pair.first;
      for (auto *op : var_ops_pair.second) {
        op_vars_map[op].insert(var_name);
      }
    }
  }

  for (auto &pair : op_vars_map) {
    auto *op = pair.first;
    auto &var_names = pair.second;

    // NOTE(review): the handles below are created with raw `new`; presumably
    // the graph/node machinery takes ownership of them -- confirm.
    auto *eager_deletion_node =
        graph->CreateEmptyNode("eager_deletion", ir::Node::Type::kOperation);
    auto *eager_deletion_op = new EagerDeletionOpHandle(
        eager_deletion_node, op->GetScope(), op->GetPlace(), var_names,
        gcs.at(places[op->GetScopeIdx()]).get(),
        &(ref_cnts[op->GetScopeIdx()]));

    // Chain the deletion op behind `op`: reuse an existing DummyVarHandle
    // output of `op` if there is one, otherwise create a new
    // control-dependency variable for that purpose.
    auto it = std::find_if(
        op->Outputs().begin(), op->Outputs().end(), [](VarHandleBase *var) {
          return dynamic_cast<DummyVarHandle *>(var) != nullptr;
        });

    if (it != op->Outputs().end()) {
      eager_deletion_op->AddInput(*it);
    } else {
      auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
      graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
      op->AddOutput(dep_var);
      eager_deletion_op->AddInput(dep_var);
    }

    // Give the deletion op a dummy output of its own, also registered in
    // kGraphDepVars.
    auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
    graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
    eager_deletion_op->AddOutput(dummy_leaf);
  }

  VLOG(10) << "Create " << op_vars_map.size() << " EagerDeletionOpHandle(s)";
  return graph;
}

}  // namespace details
}  // namespace framework
}  // namespace paddle

// Registers EagerDeletionPass as `eager_deletion_pass` and declares the four
// pass attributes that ApplyImpl reads through Get<>() as required.
REGISTER_PASS(eager_deletion_pass,
              paddle::framework::details::EagerDeletionPass)
    .RequirePassAttr(paddle::framework::details::kRuntimeReferenceCount)
    .RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars)
    .RequirePassAttr(paddle::framework::details::kAllPlaces)
    .RequirePassAttr(paddle::framework::details::kGarbageCollector);