executor_cache.h 4.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <functional>
#include <memory>
19
#include <sstream>
20 21
#include <string>
#include <unordered_map>
22
#include <unordered_set>
23 24 25
#include <utility>
#include <vector>

T
Thunderbrook 已提交
26
#include "paddle/fluid/framework/op_proto_maker.h"
27
#include "paddle/fluid/framework/parallel_executor.h"
28 29
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/macros.h"
T
Thunderbrook 已提交
30
#include "paddle/fluid/string/string_helper.h"
31 32 33

namespace paddle {
namespace framework {
34 35 36
namespace ir {
// Forward declaration: within this header Graph is only referred to through
// std::shared_ptr<ir::Graph>.
class Graph;
}
37

38 39 40
namespace details {
void AppendSkipDeletionVars(const std::vector<std::string>& append_vars,
                            std::vector<std::string>* all_vars);
41

42
void ParseSafeEagerDeletionSkipVars(
43 44
    const ProgramDesc& program,
    int64_t forward_op_nums,
45 46 47 48
    const std::vector<std::string>& output_var_names,
    std::vector<std::string>* skip_eager_delete_vars);

}  // namespace details
49 50

class ExecutorInfo {
51
 public:
52 53 54
  struct CacheValue {
    std::shared_ptr<ParallelExecutor> executor_{nullptr};
    std::shared_ptr<ir::Graph> graph_{nullptr};
55

56
    std::vector<std::string> skip_eager_delete_vars_;
57
  };
58

59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
  bool IsAvailable(bool is_grad) {
    const auto& executor =
        is_grad ? backward_info_.executor_ : forward_info_.executor_;
    return executor != nullptr;
  }

  CacheValue& GetMutable(bool is_grad) {
    return is_grad ? backward_info_ : forward_info_;
  }

 private:
  CacheValue forward_info_;
  CacheValue backward_info_;
};

class ExecutorInfoCache {
 public:
76 77
  static ExecutorInfoCache& Instance();

78 79 80 81 82 83 84
  const BuildStrategy& GetBuildStrategy(int64_t program_id) {
    // If not found, insert build_strategy with default value.
    return strategy_map_[program_id];
  }

  void SetBuildStrategy(int64_t program_id,
                        const BuildStrategy& build_strategy) {
85
    PADDLE_ENFORCE_EQ(
86 87
        strategy_map_.count(program_id),
        0,
88 89 90
        platform::errors::PreconditionNotMet(
            "program_id: %s already exist in ExecutorInfoCache", program_id));
    strategy_map_[program_id] = build_strategy;
91 92
  }

93 94 95
  bool Has(int64_t program_id, bool is_grad) {
    return info_map_.find(program_id) != info_map_.end() &&
           info_map_[program_id].IsAvailable(is_grad);
96 97
  }

98 99
  ExecutorInfo::CacheValue& GetMutable(int64_t program_id, bool is_grad) {
    return info_map_[program_id].GetMutable(is_grad);
100 101
  }

102 103
  void UpdateSkipEagerDeleteVars(int64_t program_id,
                                 bool is_grad,
104 105 106 107 108 109 110 111 112
                                 const std::vector<std::string>& skip_vars) {
    auto& cached_value = GetMutable(program_id, is_grad);
    cached_value.skip_eager_delete_vars_ = std::move(skip_vars);
  }

  std::vector<std::string>& SkipEagerDeleteVars(int64_t program_id,
                                                bool is_grad) {
    auto& cached_value = GetMutable(program_id, is_grad);
    return cached_value.skip_eager_delete_vars_;
113 114
  }

115
  size_t Size() const { return info_map_.size(); }
116

117 118 119 120 121 122 123
  void Finalize() {
    // NOTE(Aurelius84): DO NOT perform finalize in destructor
    // to avoid problems caused by destructor order of static
    // object.
    info_map_.clear();
    strategy_map_.clear();
  }
124

125
 private:
126 127
  std::unordered_map<int64_t, ExecutorInfo> info_map_;
  std::unordered_map<int64_t, BuildStrategy> strategy_map_;
128 129
};

130 131 132
// Result type of GetExecutorInfoFromCache: the executor plus a flag that the
// inline comment below names is_new_created.
using CacheInfo =
    std::pair<std::shared_ptr<ParallelExecutor>, bool /*is_new_created*/>;

133 134 135
// Result type of CreateFixOrderExecutorInfo: an executor together with a
// graph, both held by shared_ptr.
using PEAndGraphPair =
    std::pair<std::shared_ptr<ParallelExecutor>, std::shared_ptr<ir::Graph>>;

136 137
// Declared here, defined elsewhere. Returns a CacheInfo, i.e. a
// ParallelExecutor plus the bool annotated as is_new_created.
// NOTE(review): presumably looks up (program_id, is_grad) in
// ExecutorInfoCache and builds a new executor for the ops in
// [start_op_index, end_op_index) of `program_desc` on `place` when none is
// cached - the range convention and the role of `scope` are not visible from
// this header; confirm against the definition.
CacheInfo GetExecutorInfoFromCache(const ProgramDesc& program_desc,
                                   const platform::Place& place,
                                   int64_t start_op_index,
                                   int64_t end_op_index,
                                   bool is_grad,
                                   int64_t program_id,
                                   framework::Scope* scope);
143

144 145 146 147 148 149
// Declared here, defined elsewhere. Returns an executor together with its
// graph (see PEAndGraphPair).
// NOTE(review): the name suggests the executor runs the ops between
// `start_op_index` and `end_op_index` of `program_desc` in a fixed order on
// `place`; unlike GetExecutorInfoFromCache it takes no program_id, so it
// presumably bypasses the cache - confirm against the definition.
PEAndGraphPair CreateFixOrderExecutorInfo(const ProgramDesc& program_desc,
                                          const platform::Place& place,
                                          int64_t start_op_index,
                                          int64_t end_op_index,
                                          framework::Scope* scope);

150 151
}  // namespace framework
}  // namespace paddle