// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/autotune/cache.h"

#include <iomanip>

#include "glog/logging.h"

namespace phi {
namespace autotune {

size_t TransposeKey(const std::vector<int64_t>& x_dims,
                    const std::vector<int32_t>& perm,
                    phi::DataType dtype) {
  const auto rank = perm.size();
28
  return GenKey(x_dims, perm, rank, static_cast<int>(dtype));
29 30
}

// Maps an algorithm-type id to a printable name.
//
// The id is compared against the integer value of each known AlgorithmType
// enumerator. The *_v8 / fused-op enumerators are only considered when
// PADDLE_WITH_CUDNN_FRONTEND is defined. An id that matches none of them is
// returned as its decimal representation.
std::string AlgorithmTypeString(int64_t algo_type) {
  // True when `algo_type` equals the integer value of `candidate`.
  const auto matches = [algo_type](AlgorithmType candidate) {
    return algo_type == static_cast<int64_t>(candidate);
  };

  if (matches(AlgorithmType::kConvForward)) {
    return "conv_forward";
  }
  if (matches(AlgorithmType::kConvBackwardData)) {
    return "conv_backward_data";
  }
  if (matches(AlgorithmType::kConvBackwardFilter)) {
    return "conv_backward_filter";
  }
#ifdef PADDLE_WITH_CUDNN_FRONTEND
  if (matches(AlgorithmType::kConvForwardV8)) {
    return "conv_forward_v8";
  }
  if (matches(AlgorithmType::kConvBackwardDataV8)) {
    return "conv_backward_data_v8";
  }
  if (matches(AlgorithmType::kConvBackwardFilterV8)) {
    return "conv_backward_filter_v8";
  }
  if (matches(AlgorithmType::kScaleBiasReluConvBNstats)) {
    return "scale_bias_relu_conv_bnstats";
  }
  if (matches(AlgorithmType::kBNFinalize)) {
    return "bn_finalize";
  }
#endif
  // Unknown id: fall back to the raw number.
  return std::to_string(algo_type);
}

void AutoTuneCache::UpdateStatus() {
  int64_t size = 0;
  int64_t cache_hits = 0;
  int64_t cache_misses = 0;
  int name_width = 24;
  std::cout.setf(std::ios::left);
  for (auto& v : auto_tune_map_) {
    VLOG(4) << "AlgoType: " << std::setfill(' ') << std::setw(name_width)
            << AlgorithmTypeString(v.first)
            << " Cache Size: " << v.second.Size()
            << " Hits: " << v.second.CacheHits()
            << " Misses: " << v.second.CacheMisses()
            << " Hit Rate: " << v.second.CacheHitRate();
    size += v.second.Size();
    cache_hits += v.second.CacheHits();
    cache_misses += v.second.CacheMisses();
  }
H
hong 已提交
77

78
  for (auto& v : conv_auto_tune_map_) {
H
hong 已提交
79 80 81 82 83 84 85 86 87 88 89
    VLOG(4) << "AlgoType: " << std::setfill(' ') << std::setw(name_width)
            << AlgorithmTypeString(v.first)
            << " Cache Size: " << v.second.Size()
            << " Hits: " << v.second.CacheHits()
            << " Misses: " << v.second.CacheMisses()
            << " Hit Rate: " << v.second.CacheHitRate();
    size += v.second.Size();
    cache_hits += v.second.CacheHits();
    cache_misses += v.second.CacheMisses();
  }

90 91 92 93 94 95 96 97 98 99 100 101 102 103
#ifdef PADDLE_WITH_CUDNN_FRONTEND
  for (auto& v : cudnn_v8_auto_tune_map_) {
    VLOG(4) << "AlgoType: " << std::setfill(' ') << std::setw(name_width)
            << AlgorithmTypeString(v.first)
            << " Cache Size: " << v.second.Size()
            << " Hits: " << v.second.CacheHits()
            << " Misses: " << v.second.CacheMisses()
            << " Hit Rate: " << v.second.CacheHitRate();
    size += v.second.Size();
    cache_hits += v.second.CacheHits();
    cache_misses += v.second.CacheMisses();
  }
#endif

104 105 106 107 108
  total_size_ = size;
  total_cache_hits_ = cache_hits;
  total_cache_misses_ = cache_misses;
}

}  // namespace autotune
}  // namespace phi