memory_optimize_helper.h
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <algorithm>
#include <iostream>
#include <iterator>
#include <list>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ir/graph.h"

namespace paddle {
namespace framework {
namespace details {

constexpr char kFetchedVars[] = "fetched_vars";
constexpr char kGraphNodePool[] = "graph_node_pool";

// NOTE(dzh): Each entry pairs a variable with the operators that consume it,
// for use by the early-delete pass.
// Because the analysis var pass is built on ir::Node, which may be released
// or modified between passes, OpDesc* is used to mark the ops instead.
using GraphNodePool = std::vector<
    std::pair<std::string /*var node*/, std::unordered_set<OpDesc*> /* ops */>>;
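
// A minimal sketch of what one pool entry might look like (names here are
// illustrative only and not part of this header):
//
//   GraphNodePool pool;
//   pool.emplace_back("fc_0.tmp_0", std::unordered_set<OpDesc*>{op1, op2});
//   // "fc_0.tmp_0" must not be deleted until op1 and op2 have both run.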

// NOTE(dzh): By default, nodes are kept sorted in ascending order of their
// size in bytes. In fluid, a -1 dimension means the batch_size is determined
// at runtime, so nodes whose batch_size equals -1 always rank ahead of nodes
// with a fixed batch_size. For example:
// node0[-1, 1], node1[-1, 1, 1], node2[1, 1], node3[1, 1024], ...
// Insert and delete are O(1).
class OrderedNodeList {
 public:
  using NodePair = std::pair<ir::Node*, std::unordered_set<ir::Node*>>;
  using Iter = typename std::list<NodePair>::iterator;
  using ConstIter = typename std::list<NodePair>::const_iterator;

  void Insert(ir::Node* var, ir::Node* op);

  void Erase(ir::Node* var);

  void Erase(const std::string& var);

  bool Has(ir::Node* var) { return mark_table_.count(var->Name()); }

  bool Has(const std::string& var) { return mark_table_.count(var); }

  ir::Node* NodeMatch(ir::Node* var) const;
  // The map stores non-const iterators, so constness cannot be guaranteed.
  int GetIndex(ir::Node* var);
  // Dump all nodes in the pool to a string.
  std::string ToString() const;

  Iter begin() { return nodes_.begin(); }
  Iter end() { return nodes_.end(); }
  ConstIter begin() const { return nodes_.begin(); }
  ConstIter end() const { return nodes_.end(); }
  size_t size() const { return nodes_.size(); }

  void Clear() {
    mark_table_.clear();
    nodes_.clear();
  }

 private:
  // For fast lookup by variable name.
  std::unordered_map<std::string, Iter> mark_table_;
  // Node swap pairs: var -> the ops that depend on the var.
  std::list<NodePair> nodes_;
};
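
// A minimal usage sketch (illustrative only; `var_node` and `op_node` are
// hypothetical ir::Node pointers taken from a graph):
//
//   OrderedNodeList pool;
//   pool.Insert(var_node, op_node);          // kept sorted by node byte size
//   if (pool.Has(var_node)) {
//     ir::Node* reusable = pool.NodeMatch(var_node);  // best-fit candidate
//     if (reusable != nullptr) pool.Erase(reusable);
//   }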

// Check whether a tensor can be reused.
bool NodeCanReused(ir::Node* node);

// Check whether a tensor can be reused.
bool NodeCanReused(const VarDesc& node);

// Check whether the op has a sub-block.
bool OpHasSubBlock(OpDesc* desc);

// node memory size in bytes
size_t NodeSizeInBytes(ir::Node* n);

// node memory size in bytes
size_t NodeSizeInBytes(const VarDesc&);

std::string DebugString(ir::Node* var);

VarDesc* FindVarDescInBlock(ir::Node* n);
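
// A minimal sketch of how these helpers fit together (illustrative only;
// `var_node` is a hypothetical ir::Node* of a variable node):
//
//   if (NodeCanReused(var_node)) {
//     size_t bytes = NodeSizeInBytes(var_node);  // used to rank reuse candidates
//     std::cout << DebugString(var_node) << " : " << bytes << " bytes\n";
//   }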

template <typename Container, typename Callback>
class FilterVariableImpl {
 public:
  void operator()(const Container& nodes, Callback callback) {
    for (auto* node : nodes) {
      callback(node);
    }
  }
};

// Filter the variable nodes in op->inputs/outputs, skipping control variables.
template <typename Callback>
class FilterVariableImpl<std::vector<ir::Node*>, Callback> {
 public:
  void operator()(const std::vector<ir::Node*>& nodes, Callback callback) {
    for (auto* var : nodes) {
      if (var->IsVar() && !var->IsCtrlVar()) {
        callback(var);
      }
    }
  }
};

template <typename Container, typename Callback>
void FilterVariables(const Container& nodes, Callback callback) {
  FilterVariableImpl<Container, Callback>()(nodes, callback);
}
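
// A minimal usage sketch (illustrative only; `op` is a hypothetical ir::Node*
// of an operator node):
//
//   FilterVariables(op->inputs, [&](ir::Node* var) {
//     // Only real (non-control) variable nodes reach this callback.
//     if (NodeCanReused(var)) { /* consider var as a reuse candidate */ }
//   });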

}  // namespace details
}  // namespace framework
}  // namespace paddle