// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdint>
#include <functional>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/string/string_helper.h"

#include "paddle/fluid/ir_adaptor/translator/program_translator.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/program.h"

namespace paddle {
namespace framework {
namespace ir {
class Graph;
}

class InterpreterCore;

namespace details {
void AppendSkipDeletionVars(const std::vector<std::string>& append_vars,
                            std::vector<std::string>* all_vars);

void ParseSafeEagerDeletionSkipVars(
    const ProgramDesc& program,
    int64_t forward_op_nums,
    const std::vector<std::string>& output_var_names,
    std::vector<std::string>* skip_eager_delete_vars);

void AppendSkipDeletionVars(const std::vector<std::string>& append_vars,
                            std::set<std::string>* all_vars);

// TODO(Aurelius84): Remove skip_no_need_buffer once CINN fixes this problem.
std::set<std::string> ParseSafeEagerDeletionSkipVarsSet(
    const ProgramDesc& backward_program, bool skip_no_need_buffer = false);

}  // namespace details

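// Caches the compiled artifacts of one program: a forward entry and a
// backward entry, each holding the ParallelExecutor, the graph it was built
// from, and the variable names that eager deletion must skip.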
class ExecutorInfo {
 public:
  struct CacheValue {
    std::shared_ptr<ParallelExecutor> executor_{nullptr};
    std::shared_ptr<ir::Graph> graph_{nullptr};

    std::vector<std::string> skip_eager_delete_vars_;
  };

  bool IsAvailable(bool is_grad) {
    const auto& executor =
        is_grad ? backward_info_.executor_ : forward_info_.executor_;
    return executor != nullptr;
  }

  CacheValue& GetMutable(bool is_grad) {
    return is_grad ? backward_info_ : forward_info_;
  }

 private:
  CacheValue forward_info_;
  CacheValue backward_info_;
};

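// Process-wide singleton that maps a program_id to its cached ExecutorInfo
// and BuildStrategy, so repeated executions of the same program can reuse an
// already-built ParallelExecutor instead of rebuilding it.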
class ExecutorInfoCache {
 public:
  static ExecutorInfoCache& Instance();

  const BuildStrategy& GetBuildStrategy(int64_t program_id) {
    // If program_id is not found, a default-constructed BuildStrategy is
    // inserted and returned.
    return strategy_map_[program_id];
  }

  void SetBuildStrategy(int64_t program_id,
                        const BuildStrategy& build_strategy) {
    PADDLE_ENFORCE_EQ(
        strategy_map_.count(program_id),
        0,
        platform::errors::PreconditionNotMet(
            "program_id: %s already exists in ExecutorInfoCache", program_id));
    strategy_map_[program_id] = build_strategy;
  }

  bool Has(int64_t program_id, bool is_grad) {
    return info_map_.find(program_id) != info_map_.end() &&
           info_map_[program_id].IsAvailable(is_grad);
  }

  ExecutorInfo::CacheValue& GetMutable(int64_t program_id, bool is_grad) {
    return info_map_[program_id].GetMutable(is_grad);
  }

  void UpdateSkipEagerDeleteVars(int64_t program_id,
                                 bool is_grad,
                                 const std::vector<std::string>& skip_vars) {
    auto& cached_value = GetMutable(program_id, is_grad);
    cached_value.skip_eager_delete_vars_ = skip_vars;
  }

  std::vector<std::string>& SkipEagerDeleteVars(int64_t program_id,
                                                bool is_grad) {
    auto& cached_value = GetMutable(program_id, is_grad);
    return cached_value.skip_eager_delete_vars_;
  }

  size_t Size() const { return info_map_.size(); }

  void Finalize() {
    // NOTE(Aurelius84): DO NOT perform finalize in the destructor,
    // to avoid problems caused by the destruction order of static
    // objects.
    info_map_.clear();
    strategy_map_.clear();
  }

 private:
  std::unordered_map<int64_t, ExecutorInfo> info_map_;
  std::unordered_map<int64_t, BuildStrategy> strategy_map_;
};

using CacheInfo =
    std::pair<std::shared_ptr<ParallelExecutor>, bool /*is_new_created*/>;

using PEAndGraphPair =
    std::pair<std::shared_ptr<ParallelExecutor>, std::shared_ptr<ir::Graph>>;

CacheInfo GetExecutorInfoFromCache(const ProgramDesc& program_desc,
                                   const platform::Place& place,
                                   int64_t start_op_index,
                                   int64_t end_op_index,
                                   bool is_grad,
                                   int64_t program_id,
                                   framework::Scope* scope);
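
// A minimal usage sketch (hypothetical caller code; program_desc, place,
// program_id and scope are assumed to be prepared by the caller):
//
//   CacheInfo cache_info = GetExecutorInfoFromCache(
//       program_desc, place, /*start_op_index=*/0,
//       /*end_op_index=*/static_cast<int64_t>(program_desc.Block(0).OpSize()),
//       /*is_grad=*/false, program_id, scope);
//   std::shared_ptr<ParallelExecutor> pe = cache_info.first;
//   bool is_new_created = cache_info.second;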

PEAndGraphPair CreateFixOrderExecutorInfo(const ProgramDesc& program_desc,
                                          const platform::Place& place,
                                          int64_t start_op_index,
                                          int64_t end_op_index,
                                          framework::Scope* scope);

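// Same caching structure as ExecutorInfo, but for the new InterpreterCore:
// one forward and one backward entry, each with its core and the variable
// names to skip during eager deletion.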
class InterpreterCoreInfo {
 public:
  struct CacheValue {
    std::shared_ptr<InterpreterCore> core_{nullptr};
    std::set<std::string> skip_eager_delete_vars_;
  };

  bool IsAvailable(bool is_grad) {
    const auto& core = is_grad ? backward_info_.core_ : forward_info_.core_;
    return core != nullptr;
  }

  CacheValue& GetMutable(bool is_grad) {
    return is_grad ? backward_info_ : forward_info_;
  }

 private:
  CacheValue forward_info_;
  CacheValue backward_info_;
};

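// Singleton cache of InterpreterCoreInfo entries. Unlike ExecutorInfoCache,
// the lookup key mixes the program_id with the Scope pointer (see Has()
// below), so the same program run in different scopes gets separate cores.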
class InterpreterCoreInfoCache {
 public:
  static InterpreterCoreInfoCache& Instance();

  bool Has(int64_t program_id, const framework::Scope* scope, bool is_grad) {
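    // Mix the scope pointer into program_id (hash_combine-style) so that each
    // (program, scope) pair gets its own cache entry.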
    int64_t scope_i = reinterpret_cast<std::uintptr_t>(scope);
    program_id += 0x9e3779b9 + (program_id << 6) + (scope_i >> 2);
    return info_map_.find(program_id) != info_map_.end() &&
           info_map_[program_id].IsAvailable(is_grad);
  }

  InterpreterCoreInfo::CacheValue& GetMutable(int64_t program_id,
                                              const framework::Scope* scope,
                                              bool is_grad) {
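    // Same key mixing as in Has().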
    int64_t scope_i = reinterpret_cast<std::uintptr_t>(scope);
    program_id += 0x9e3779b9 + (program_id << 6) + (scope_i >> 2);
    return info_map_[program_id].GetMutable(is_grad);
  }

  void UpdateSkipEagerDeleteVars(int64_t program_id,
                                 const framework::Scope* scope,
                                 bool is_grad,
                                 const std::set<std::string>& skip_vars) {
    auto& cached_value = GetMutable(program_id, scope, is_grad);
    cached_value.skip_eager_delete_vars_ = skip_vars;
  }

  std::set<std::string>& GetSkipEagerDeleteVars(int64_t program_id,
                                                const framework::Scope* scope,
                                                bool is_grad) {
    auto& cached_value = GetMutable(program_id, scope, is_grad);
    return cached_value.skip_eager_delete_vars_;
  }

  size_t Size() const { return info_map_.size(); }

  void Finalize() {
    // NOTE(Aurelius84): DO NOT perform finalize in the destructor,
    // to avoid problems caused by the destruction order of static
    // objects.
    info_map_.clear();
  }

 private:
  std::unordered_map<int64_t, InterpreterCoreInfo> info_map_;
};

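// Builds an InterpreterCore for `program_desc` and, as the function name
// suggests, registers it in InterpreterCoreInfoCache under
// (program_id, scope, is_grad) before returning it.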
std::shared_ptr<InterpreterCore> CreateProgramInterpreterCoreInfoToCache(
    const ProgramDesc& program_desc,
    const platform::Place& place,
    bool is_grad,
    int64_t program_id,
    framework::Scope* scope);

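// New-IR counterpart of the factory above: consumes an already translated
// ::ir::Program instead of a legacy ProgramDesc.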
std::shared_ptr<InterpreterCore> CreateNewIRInterpreterCoreInfoToCache(
    std::unique_ptr<::ir::Program> ir_prog,
    const platform::Place& place,
    bool is_grad,
    int64_t program_id,
    framework::Scope* scope);

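// Translates the forward block of a legacy program into a new-IR
// ::ir::Program, taking into account the input tensors `x`, the `params`,
// and the outputs that the backward block will later consume.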
std::unique_ptr<::ir::Program> ConstructFowardIrProgram(
    const paddle::framework::BlockDesc* forward_global_block,
    const paddle::framework::BlockDesc* backward_global_block,
    const std::vector<std::string> output_names,
    const std::vector<paddle::Tensor>& x,
    const std::vector<paddle::Tensor>& params);

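// Translates the backward block into a new-IR ::ir::Program, wired to the
// incoming output gradients and to the x/params gradients it should produce.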
std::unique_ptr<::ir::Program> ConstructBackwardIrProgram(
    const paddle::framework::BlockDesc* backward_global_block,
    const std::vector<paddle::Tensor>& out_grad,
    const std::vector<paddle::Tensor*>& x_grad,
    const std::vector<paddle::Tensor*>& params_grad,
    const paddle::framework::Scope* scope);

}  // namespace framework
}  // namespace paddle