// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <numeric>

#include "paddle/phi/common/data_type.h"
#include "paddle/phi/kernels/autotune/cache_base.h"
#ifdef PADDLE_WITH_CUDNN_FRONTEND
#include "paddle/phi/kernels/autotune/cache_cudnn_frontend.h"
#endif
namespace phi {
namespace autotune {

// Result of a convolution algorithm search, cached by the conv autotune
// cache: the chosen algorithm id, the workspace it needs, and whether it
// came from an exhaustive search.
struct ConvAutoTuneResult {
  ConvAutoTuneResult() {}
  ConvAutoTuneResult(int64_t a, size_t size, bool search)
      : algo(a), workspace_size(size), exhaustive_search(search) {}

  // In-class initializer added: previously a default-constructed result
  // carried an indeterminate algorithm id (its siblings were initialized).
  int64_t algo = 0;
  size_t workspace_size = 0;
  bool exhaustive_search = false;
};

// Computes the cache key for a transpose configuration from the input
// dims, the permutation and the data type (defined in the .cc file).
size_t TransposeKey(const std::vector<int64_t>& x_dims,
                    const std::vector<int32_t>& perm,
                    phi::DataType dtype);

// Identifies which autotune cache bucket a kernel uses. The values are
// used as map keys in AutoTuneCache, so they must stay stable; enumeration
// starts at 1 and kAlgorithmCount is one past the last valid value.
enum class AlgorithmType {
  kConvForward = 1,
  kConvBackwardData = 2,
  kConvBackwardFilter = 3,
  kTranspose = 4,
#ifdef PADDLE_WITH_CUDNN_FRONTEND
  kConvForwardV8 = 5,
  kConvBackwardDataV8 = 6,
  kConvBackwardFilterV8 = 7,
  kAlgorithmCount = 8
#else
  kAlgorithmCount = 5
#endif
};

// AlgorithmsConfigKey -> AlgorithmsID
H
hong 已提交
58
// (todo. hong) use cudnnConvolutionFwdAlgo_t
59
using AlgorithmsCacheMap = AlgorithmsCache<size_t, int64_t>;
60 61
// AlgorithmType -> AlgorithmsCache
using AlgorithmsTypeMap = std::unordered_map<int64_t, AlgorithmsCacheMap>;
62 63 64
using ConvAlgorithmsCacheMap = ConvAlgorithmsCache<ConvAutoTuneResult>;
using ConvAlgorithmsTypeMap =
    std::unordered_map<int64_t, ConvAlgorithmsCacheMap>;
65 66 67 68
#ifdef PADDLE_WITH_CUDNN_FRONTEND
using CudnnV8AlgorithmsTypeMap =
    std::unordered_map<int64_t, CudnnFrontendPlanCache>;
#endif
class AutoTuneCache {
 public:
  static AutoTuneCache& Instance() {
    static AutoTuneCache autotune_cache;
    return autotune_cache;
  }

76 77
  AlgorithmsCacheMap& Get(const AlgorithmType& algo_type) {
    return auto_tune_map_[static_cast<int64_t>(algo_type)];
78 79
  }

80 81
  ConvAlgorithmsCacheMap& GetConv(const AlgorithmType& algo_type) {
    return conv_auto_tune_map_[static_cast<int64_t>(algo_type)];
Y
Yiqun Liu 已提交
82 83
  }

84 85 86 87 88 89
#ifdef PADDLE_WITH_CUDNN_FRONTEND
  CudnnFrontendPlanCache& GetConvV8(const AlgorithmType& algo_type) {
    return cudnn_v8_auto_tune_map_[static_cast<int64_t>(algo_type)];
  }
#endif

90 91
  AlgorithmsCacheMap& GetTranspose() { return Get(AlgorithmType::kTranspose); }

92
  void Clean() {
93
    for (auto& v : auto_tune_map_) {
94
      v.second.Clean();
95
    }
H
hong 已提交
96

97
    for (auto& v : conv_auto_tune_map_) {
H
hong 已提交
98 99
      v.second.Clean();
    }
100 101 102 103 104 105

#ifdef PADDLE_WITH_CUDNN_FRONTEND
    for (auto& v : cudnn_v8_auto_tune_map_) {
      v.second.Clean();
    }
#endif
106 107
  }

108 109
  void UpdateStatus();

110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
  // The number of total config cached
  int64_t Size() const { return total_size_; }

  int64_t CacheHits() const { return total_cache_hits_; }

  int64_t CacheMisses() const { return total_cache_misses_; }

  float CacheHitRate() const {
    float total_cache_hit_rate = 0.;
    int64_t total_num_accesses = total_cache_hits_ + total_cache_misses_;
    if (total_num_accesses != 0) {
      total_cache_hit_rate = static_cast<float>(total_cache_hits_) /
                             static_cast<float>(total_num_accesses);
    }
    return total_cache_hit_rate;
125 126 127
  }

 private:
128 129 130 131 132 133 134 135
  AutoTuneCache() : autotune_cache_mutex_(new std::mutex()) {
    for (int i = 1; i < static_cast<int>(AlgorithmType::kAlgorithmCount); ++i) {
      Register(static_cast<AlgorithmType>(i));
    }
  }

  void Register(const AlgorithmType& algo_type) {
    std::lock_guard<std::mutex> lock(*autotune_cache_mutex_);
H
hong 已提交
136 137 138 139 140
    if (algo_type == AlgorithmType::kConvForward ||
        algo_type == AlgorithmType::kConvBackwardData ||
        algo_type == AlgorithmType::kConvBackwardFilter) {
      int64_t key = static_cast<int64_t>(algo_type);
      if (auto_tune_map_.find(key) == auto_tune_map_.end()) {
141 142
        ConvAlgorithmsCacheMap cache;
        conv_auto_tune_map_[key] = cache;
H
hong 已提交
143
      }
144 145 146 147 148 149 150 151 152 153
#ifdef PADDLE_WITH_CUDNN_FRONTEND
    } else if (algo_type == AlgorithmType::kConvForwardV8 ||
               algo_type == AlgorithmType::kConvBackwardDataV8 ||
               algo_type == AlgorithmType::kConvBackwardFilterV8) {
      int64_t key = static_cast<int64_t>(algo_type);
      if (cudnn_v8_auto_tune_map_.find(key) == cudnn_v8_auto_tune_map_.end()) {
        CudnnFrontendPlanCache cache;
        cudnn_v8_auto_tune_map_[key] = cache;
      }
#endif
H
hong 已提交
154 155 156 157 158 159
    } else {
      int64_t key = static_cast<int64_t>(algo_type);
      if (auto_tune_map_.find(key) == auto_tune_map_.end()) {
        AlgorithmsCacheMap cache;
        auto_tune_map_[key] = cache;
      }
160 161 162
    }
  }

163
  AlgorithmsTypeMap auto_tune_map_;
164
  ConvAlgorithmsTypeMap conv_auto_tune_map_;
165 166 167
#ifdef PADDLE_WITH_CUDNN_FRONTEND
  CudnnV8AlgorithmsTypeMap cudnn_v8_auto_tune_map_;
#endif
168
  std::shared_ptr<std::mutex> autotune_cache_mutex_;
169 170 171
  int64_t total_cache_hits_{0};
  int64_t total_cache_misses_{0};
  int64_t total_size_{0};
172 173
};

}  // namespace autotune
}  // namespace phi