// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/errors.h" DECLARE_int32(search_cache_max_number); inline void HashCombine(std::size_t* seed) {} // combine hash value // https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x template inline void HashCombine(std::size_t* seed, const T& v, Rest... rest) { std::hash hasher; *seed ^= hasher(v) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2); *seed *= 0x00000100000001B3; HashCombine(seed, rest...); } // custom specialization of std::hash can be injected in namespace std // ref: https://en.cppreference.com/w/cpp/utility/hash namespace std { template struct hash> { std::size_t operator()(std::vector const& vec) const noexcept { std::size_t seed = 0xcbf29ce484222325; for (auto val : vec) { HashCombine(&seed, val); } return seed; } }; } // namespace std namespace phi { namespace autotune { struct ConvAutoTuneResult { ConvAutoTuneResult() {} ConvAutoTuneResult(int64_t a, size_t size, bool search) : algo(a), workspace_size(size), exhaustive_search(search) {} int64_t algo; size_t workspace_size = 0; bool exhaustive_search = false; }; template size_t GetKey(Args&&... args) { size_t seed = 0; HashCombine(&seed, std::forward(args)...); return seed; } struct ConvCacheKey { ConvCacheKey() {} ConvCacheKey(const std::vector& arg_x_dims, const std::vector& arg_w_dims, const std::vector& arg_strides, const std::vector& arg_paddings, const std::vector& arg_dilations, phi::DataType arg_dtype, int arg_groups, int64_t arg_data_layout) : x_dims(arg_x_dims), w_dims(arg_w_dims), strides(arg_strides), paddings(arg_paddings), dilations(arg_dilations), dtype(arg_dtype), groups(arg_groups), data_layout(arg_data_layout) {} size_t hash_value() const { return GetKey(x_dims, w_dims, strides, paddings, dilations, static_cast(dtype), groups, data_layout); } std::vector x_dims; std::vector w_dims; std::vector strides; std::vector paddings; std::vector dilations; phi::DataType dtype; int groups; int64_t data_layout; }; struct ConvCacheKeyHash { size_t operator()(const ConvCacheKey& cache) const { return cache.hash_value(); } }; struct ConvCacheKeyEqual { size_t operator()(const ConvCacheKey& first, const ConvCacheKey& second) const { if (first.x_dims != second.x_dims) return false; if (first.w_dims != second.w_dims) return false; if (first.strides != second.strides) return false; if (first.paddings != second.paddings) return false; if (first.dilations != second.dilations) return false; if (first.dtype != second.dtype) return false; if (first.groups != second.groups) return false; if (first.data_layout != second.data_layout) return false; return true; } }; class CudnnAlgorithmsCacheMap { public: CudnnAlgorithmsCacheMap() : cache_mutex_(new std::mutex()) { hash_.clear(); } ConvAutoTuneResult Get(const ConvCacheKey& key) { std::lock_guard lock(*cache_mutex_); PADDLE_ENFORCE_NE( hash_.find(key), hash_.end(), phi::errors::PreconditionNotMet("The key does not exist.")); return hash_[key]; } bool Find(const ConvCacheKey& key) { bool ret = false; std::lock_guard lock(*cache_mutex_); if (hash_.find(key) != hash_.end()) { cache_hits_++; ret = true; } else { cache_misses_++; } return ret; } void Clean() { std::lock_guard lock(*cache_mutex_); hash_.clear(); cache_hits_ = 0; cache_misses_ = 0; } void Set(const ConvCacheKey& key, ConvAutoTuneResult algo) { std::lock_guard lock(*cache_mutex_); if (hash_.size() > static_cast(FLAGS_search_cache_max_number)) { hash_.clear(); } hash_[key] = algo; } int64_t CacheMisses() const { return cache_misses_; } int64_t CacheHits() const { return cache_hits_; } float CacheHitRate() const { int64_t num_accesses = cache_hits_ + cache_misses_; float cache_hit_rate = 0.; if (num_accesses != 0) { cache_hit_rate = static_cast(cache_hits_) / static_cast(num_accesses); } return cache_hit_rate; } int64_t Size() const { return hash_.size(); } private: std::unordered_map hash_; std::shared_ptr cache_mutex_; int64_t cache_hits_{0}; int64_t cache_misses_{0}; }; size_t TransposeKey(const std::vector& x_dims, const std::vector& perm, phi::DataType dtype); template class AlgorithmsCache { public: AlgorithmsCache() : cache_mutex_(new std::mutex()) { hash_.clear(); } AlgorithmT Get(const size_t& key) { std::lock_guard lock(*cache_mutex_); PADDLE_ENFORCE_NE( hash_.find(key), hash_.end(), phi::errors::PreconditionNotMet("The key does not exist.")); return hash_[key]; } bool Find(const size_t& key) { bool ret = false; std::lock_guard lock(*cache_mutex_); if (hash_.find(key) != hash_.end()) { cache_hits_++; ret = true; } else { cache_misses_++; } return ret; } void Clean() { std::lock_guard lock(*cache_mutex_); hash_.clear(); cache_hits_ = 0; cache_misses_ = 0; } void Set(const size_t& key, AlgorithmT algo) { std::lock_guard lock(*cache_mutex_); hash_[key] = algo; } int64_t CacheMisses() const { return cache_misses_; } int64_t CacheHits() const { return cache_hits_; } float CacheHitRate() const { int64_t num_accesses = cache_hits_ + cache_misses_; float cache_hit_rate = 0.; if (num_accesses != 0) { cache_hit_rate = static_cast(cache_hits_) / static_cast(num_accesses); } return cache_hit_rate; } int64_t Size() const { return hash_.size(); } private: std::unordered_map hash_; std::shared_ptr cache_mutex_; int64_t cache_hits_{0}; int64_t cache_misses_{0}; }; enum class AlgorithmType { kConvForward = 1, kConvBackwardData = 2, kConvBackwardFilter = 3, kTranspose = 4, kAlgorithmCount = 5 }; // AlgorithmsConfigKey -> AlgorithmsID // (todo. hong) use cudnnConvolutionFwdAlgo_t using AlgorithmsCacheMap = AlgorithmsCache; // AlgorithmType -> AlgorithmsCache using AlgorithmsTypeMap = std::unordered_map; using CudnnAlgorithmsTypeMap = std::unordered_map; class AutoTuneCache { public: static AutoTuneCache& Instance() { static AutoTuneCache autotune_cache; return autotune_cache; } AlgorithmsCacheMap& Get(const AlgorithmType& algo_type) { return auto_tune_map_[static_cast(algo_type)]; } CudnnAlgorithmsCacheMap& GetConv(const AlgorithmType& algo_type) { return cudnn_auto_tune_map_[static_cast(algo_type)]; } AlgorithmsCacheMap& GetTranspose() { return Get(AlgorithmType::kTranspose); } void Clean() { for (auto& v : auto_tune_map_) { v.second.Clean(); } for (auto& v : cudnn_auto_tune_map_) { v.second.Clean(); } } void UpdateStatus(); // The number of total config cached int64_t Size() const { return total_size_; } int64_t CacheHits() const { return total_cache_hits_; } int64_t CacheMisses() const { return total_cache_misses_; } float CacheHitRate() const { float total_cache_hit_rate = 0.; int64_t total_num_accesses = total_cache_hits_ + total_cache_misses_; if (total_num_accesses != 0) { total_cache_hit_rate = static_cast(total_cache_hits_) / static_cast(total_num_accesses); } return total_cache_hit_rate; } private: AutoTuneCache() : autotune_cache_mutex_(new std::mutex()) { for (int i = 1; i < static_cast(AlgorithmType::kAlgorithmCount); ++i) { Register(static_cast(i)); } } void Register(const AlgorithmType& algo_type) { std::lock_guard lock(*autotune_cache_mutex_); if (algo_type == AlgorithmType::kConvForward || algo_type == AlgorithmType::kConvBackwardData || algo_type == AlgorithmType::kConvBackwardFilter) { int64_t key = static_cast(algo_type); if (auto_tune_map_.find(key) == auto_tune_map_.end()) { CudnnAlgorithmsCacheMap cache; cudnn_auto_tune_map_[key] = cache; } } else { int64_t key = static_cast(algo_type); if (auto_tune_map_.find(key) == auto_tune_map_.end()) { AlgorithmsCacheMap cache; auto_tune_map_[key] = cache; } } } AlgorithmsTypeMap auto_tune_map_; CudnnAlgorithmsTypeMap cudnn_auto_tune_map_; std::shared_ptr autotune_cache_mutex_; int64_t total_cache_hits_{0}; int64_t total_cache_misses_{0}; int64_t total_size_{0}; }; } // namespace autotune } // namespace phi