// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdint>
#include <functional>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/string/string_helper.h"

namespace paddle {
namespace framework {
namespace ir {
// Forward declaration: ExecutorInfoCache::ValueType below stores a
// std::shared_ptr<ir::Graph>, so only the type name is needed here.
class Graph;
}

namespace details {
void AppendSkipDeletionVars(const std::vector<std::string>& append_vars,
                            std::vector<std::string>* all_vars);
41

42 43 44 45 46 47
void ParseSafeEagerDeletionSkipVars(
    const ProgramDesc& program, int64_t forward_op_nums,
    const std::vector<std::string>& output_var_names,
    std::vector<std::string>* skip_eager_delete_vars);

}  // namespace details
class ExecutorInfoCache {
49
 public:
50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
  struct CacheKey {
    CacheKey(const ProgramDesc* program_desc, const platform::Place& place,
             int64_t start_op_index, int64_t end_op_index, bool is_grad)
        : program_desc_(program_desc),
          place_(place),
          start_op_index_(start_op_index),
          end_op_index_(end_op_index),
          is_grad_(is_grad) {
      device_type_ = platform::Place2DeviceType(place);
      PADDLE_ENFORCE_NOT_NULL(program_desc_,
                              "program_desc should not be null.");
    }

    std::string DebugString() const {
      std::stringstream ss;

      ss << "\n CacheKey(program_desc: " << program_desc_;
      ss << ", start_op_index: " << start_op_index_;
      ss << ", end_op_index: " << end_op_index_;
      ss << ", is_grad: " << is_grad_;
      ss << ", device_type: " << device_type_ << ")";

      return ss.str();
    }

    const ProgramDesc* program_desc_;
    platform::Place place_;
    int64_t start_op_index_;
    int64_t end_op_index_;
    bool is_grad_;
    platform::DeviceType device_type_;
81 82
  };

83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
  using KeyType = size_t;
  using ValueType =
      std::pair<std::shared_ptr<ParallelExecutor>, std::shared_ptr<ir::Graph>>;

  struct KeyHasher {
    size_t operator()(const CacheKey& key) const noexcept {
      size_t seed = 10;
      auto* prog_desc = key.program_desc_;
      /*
       * Note(Aurelius84): DO NOT use only ProgramDesc* to calculate hash value
       * because a new program will hold same pointer address after an older
       * program is destructed with a small probability. Add op size while
       * hashing because program may contains at least one block.
       */
      hash_combine(&seed, prog_desc);
      for (size_t i = 0; i < prog_desc->Size(); ++i) {
        hash_combine(&seed, &prog_desc->Block(i));
        hash_combine(&seed, prog_desc->Block(i).OpSize());
      }
      hash_combine(&seed, static_cast<int>(key.device_type_));
      hash_combine(&seed, key.start_op_index_);
      hash_combine(&seed, key.end_op_index_);
      hash_combine(&seed, key.is_grad_);
      VLOG(3) << "hash value is : " << seed
              << " of key:  " << key.DebugString();
      return seed;
    }

    template <typename T>
    void hash_combine(size_t* seed, const T& val) const {
      std::hash<T> hasher;
      (*seed) ^= hasher(val) + 0x9e3779b9 + ((*seed) << 6) + ((*seed >> 2));
    }
  };
117

118 119
  static ExecutorInfoCache& Instance();

120 121
  ValueType GetMutable(const CacheKey& key) {
    auto key_val = key_hash_func_(key);
122
    PADDLE_ENFORCE_EQ(
123 124 125 126
        Has(key_val), true,
        platform::errors::NotFound("%s doesn't exist in ExecutorInfoCache",
                                   key.DebugString()));
    return info_map_[key_val];
127 128
  }

129 130 131
  bool Has(const CacheKey& key) const {
    auto key_val = key_hash_func_(key);
    return Has(key_val);
132 133
  }

134 135
  bool Has(const KeyType& key) const {
    return info_map_.find(key) != info_map_.end();
136 137
  }

138 139 140 141 142 143 144
  void Insert(const CacheKey& key, ValueType value) {
    auto key_val = key_hash_func_(key);
    PADDLE_ENFORCE_EQ(
        Has(key_val), false,
        platform::errors::NotFound("%s has existed in ExecutorInfoCache",
                                   key.DebugString()));
    info_map_.insert({key_val, value});
145 146
  }

147
  size_t Size() const { return info_map_.size(); }
148

149
  void Finalize();
150

151
 private:
152 153 154 155 156
  ExecutorInfoCache() = default;
  DISABLE_COPY_AND_ASSIGN(ExecutorInfoCache);

  KeyHasher key_hash_func_;
  std::unordered_map<KeyType, ValueType> info_map_;
157 158
};

using CacheInfo =
    std::pair<std::shared_ptr<ParallelExecutor>, bool /*is_new_created*/>;

162
CacheInfo GetExecutorInfoFromCache(const ExecutorInfoCache::CacheKey& cache_key,
163
                                   framework::Scope* scope);

}  // namespace framework
}  // namespace paddle