executor_cache.h 4.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <functional>
#include <memory>
19
#include <sstream>
20 21
#include <string>
#include <unordered_map>
22
#include <unordered_set>
23 24 25
#include <utility>
#include <vector>

T
Thunderbrook 已提交
26
#include "paddle/fluid/framework/op_proto_maker.h"
27
#include "paddle/fluid/framework/parallel_executor.h"
28 29
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/macros.h"
T
Thunderbrook 已提交
30
#include "paddle/fluid/string/string_helper.h"
31 32 33

namespace paddle {
namespace framework {
namespace ir {
// Forward declaration only; this header uses Graph solely through
// std::shared_ptr, so the full definition is not needed here.
class Graph;
}  // namespace ir

namespace details {
void AppendSkipDeletionVars(const std::vector<std::string>& append_vars,
                            std::vector<std::string>* all_vars);
41

42 43 44 45 46 47
void ParseSafeEagerDeletionSkipVars(
    const ProgramDesc& program, int64_t forward_op_nums,
    const std::vector<std::string>& output_var_names,
    std::vector<std::string>* skip_eager_delete_vars);

}  // namespace details
48 49

class ExecutorInfo {
50
 public:
51 52 53
  struct CacheValue {
    std::shared_ptr<ParallelExecutor> executor_{nullptr};
    std::shared_ptr<ir::Graph> graph_{nullptr};
54

55
    std::vector<std::string> skip_eager_delete_vars_;
56
  };
57

58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
  bool IsAvailable(bool is_grad) {
    const auto& executor =
        is_grad ? backward_info_.executor_ : forward_info_.executor_;
    return executor != nullptr;
  }

  CacheValue& GetMutable(bool is_grad) {
    return is_grad ? backward_info_ : forward_info_;
  }

 private:
  CacheValue forward_info_;
  CacheValue backward_info_;
};

class ExecutorInfoCache {
 public:
75 76
  static ExecutorInfoCache& Instance();

77 78 79 80 81 82 83
  const BuildStrategy& GetBuildStrategy(int64_t program_id) {
    // If not found, insert build_strategy with default value.
    return strategy_map_[program_id];
  }

  void SetBuildStrategy(int64_t program_id,
                        const BuildStrategy& build_strategy) {
84
    PADDLE_ENFORCE_EQ(
85 86 87 88
        strategy_map_.count(program_id), 0,
        platform::errors::PreconditionNotMet(
            "program_id: %s already exist in ExecutorInfoCache", program_id));
    strategy_map_[program_id] = build_strategy;
89 90
  }

91 92 93
  bool Has(int64_t program_id, bool is_grad) {
    return info_map_.find(program_id) != info_map_.end() &&
           info_map_[program_id].IsAvailable(is_grad);
94 95
  }

96 97
  ExecutorInfo::CacheValue& GetMutable(int64_t program_id, bool is_grad) {
    return info_map_[program_id].GetMutable(is_grad);
98 99
  }

100 101 102 103 104 105 106 107 108 109
  void UpdateSkipEagerDeleteVars(int64_t program_id, bool is_grad,
                                 const std::vector<std::string>& skip_vars) {
    auto& cached_value = GetMutable(program_id, is_grad);
    cached_value.skip_eager_delete_vars_ = std::move(skip_vars);
  }

  std::vector<std::string>& SkipEagerDeleteVars(int64_t program_id,
                                                bool is_grad) {
    auto& cached_value = GetMutable(program_id, is_grad);
    return cached_value.skip_eager_delete_vars_;
110 111
  }

112
  size_t Size() const { return info_map_.size(); }
113

114 115 116 117 118 119 120
  void Finalize() {
    // NOTE(Aurelius84): DO NOT perform finalize in destructor
    // to avoid problems caused by destructor order of static
    // object.
    info_map_.clear();
    strategy_map_.clear();
  }
121

122
 private:
123 124
  std::unordered_map<int64_t, ExecutorInfo> info_map_;
  std::unordered_map<int64_t, BuildStrategy> strategy_map_;
125 126
};

127 128 129
using CacheInfo =
    std::pair<std::shared_ptr<ParallelExecutor>, bool /*is_new_created*/>;

130 131 132 133
CacheInfo GetExecutorInfoFromCache(const ProgramDesc& program_desc,
                                   const platform::Place& place,
                                   int64_t start_op_index, int64_t end_op_index,
                                   bool is_grad, int64_t program_id,
134
                                   framework::Scope* scope);
135 136 137

}  // namespace framework
}  // namespace paddle