未验证 提交 7dfd3846 编写于 作者: Z Zhang Ting 提交者: GitHub

Implement AutotuneCache class for Kernel AutoTune (#41169)

上级 6744754f
...@@ -51,20 +51,35 @@ struct hash<std::vector<T>> { ...@@ -51,20 +51,35 @@ struct hash<std::vector<T>> {
namespace phi { namespace phi {
namespace autotune { namespace autotune {
template <typename... Args>
size_t GetKey(Args&&... args) {
size_t seed = 0;
HashCombine(&seed, std::forward<Args>(args)...);
return seed;
}
// Define the cache key of operator
size_t ConvKey(const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& w_dims,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
phi::DataType dtype) {
return GetKey(x_dims,
w_dims,
strides,
paddings,
dilations,
static_cast<int64_t>(dtype));
}
template <typename AlgorithmT> template <typename AlgorithmT>
class AlgorithmsCache { class AlgorithmsCache {
public: public:
AlgorithmsCache() { hash_.clear(); } AlgorithmsCache() : cache_mutex_(new std::mutex()) { hash_.clear(); }
template <typename... Args>
size_t GetKey(Args&&... args) {
size_t seed = 0;
HashCombine(&seed, std::forward<Args>(args)...);
return seed;
}
AlgorithmT Get(size_t key) { AlgorithmT Get(size_t key) {
std::lock_guard<std::mutex> lock(cache_mutex_); std::lock_guard<std::mutex> lock(*cache_mutex_);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
hash_.find(key), hash_.find(key),
hash_.end(), hash_.end(),
...@@ -74,7 +89,7 @@ class AlgorithmsCache { ...@@ -74,7 +89,7 @@ class AlgorithmsCache {
bool Find(size_t key) { bool Find(size_t key) {
bool ret = false; bool ret = false;
std::lock_guard<std::mutex> lock(cache_mutex_); std::lock_guard<std::mutex> lock(*cache_mutex_);
if (hash_.find(key) != hash_.end()) { if (hash_.find(key) != hash_.end()) {
cache_hits_++; cache_hits_++;
ret = true; ret = true;
...@@ -85,7 +100,7 @@ class AlgorithmsCache { ...@@ -85,7 +100,7 @@ class AlgorithmsCache {
} }
void Set(size_t key, AlgorithmT algo) { void Set(size_t key, AlgorithmT algo) {
std::lock_guard<std::mutex> lock(cache_mutex_); std::lock_guard<std::mutex> lock(*cache_mutex_);
hash_[key] = algo; hash_[key] = algo;
} }
...@@ -96,27 +111,52 @@ class AlgorithmsCache { ...@@ -96,27 +111,52 @@ class AlgorithmsCache {
return cache_hit_rate; return cache_hit_rate;
} }
// Define the cache key of operator int64_t Size() { return hash_.size(); }
size_t ConvKey(const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& w_dims,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
phi::DataType dtype) {
return GetKey(x_dims,
w_dims,
strides,
paddings,
dilations,
static_cast<int64_t>(dtype));
}
private: private:
std::unordered_map<size_t, AlgorithmT> hash_; std::unordered_map<size_t, AlgorithmT> hash_;
std::mutex cache_mutex_; std::shared_ptr<std::mutex> cache_mutex_;
int64_t cache_hits_ = 0; int64_t cache_hits_ = 0;
int64_t cache_misses_ = 0; int64_t cache_misses_ = 0;
}; };
// AlgorithmsConfigKey -> AlgorithmsID
using AlgorithmsConfigKeyMap = AlgorithmsCache<int64_t>;
// AlgorithmsType -> AlgorithmsCache
using AlgorithmsTypeMap =
std::unordered_map<std::string, AlgorithmsConfigKeyMap>;
class AutoTuneCache {
public:
static AutoTuneCache& Instance() {
static AutoTuneCache autotune_cache;
return autotune_cache;
}
AlgorithmsConfigKeyMap& RegisterOrGet(const std::string& algo_type) {
std::lock_guard<std::mutex> lock(*autotune_cache_mutex_);
if (auto_tune_map_.find(algo_type) == auto_tune_map_.end()) {
AlgorithmsConfigKeyMap cache;
auto_tune_map_[algo_type] = cache;
}
return auto_tune_map_[algo_type];
}
// The number of total config cached
int64_t Size() {
int64_t total = 0;
for (auto& v : auto_tune_map_) {
VLOG(3) << v.first << " " << v.second.Size();
total += v.second.Size();
}
return total;
}
private:
AutoTuneCache() : autotune_cache_mutex_(new std::mutex()) {}
AlgorithmsTypeMap auto_tune_map_;
std::shared_ptr<std::mutex> autotune_cache_mutex_;
};
} // namespace autotune } // namespace autotune
} // namespace phi } // namespace phi
...@@ -18,10 +18,12 @@ ...@@ -18,10 +18,12 @@
#include <functional> #include <functional>
#include "glog/logging.h" #include "glog/logging.h"
void Algo() { VLOG(3) << "algo test"; } enum ConvAlgos { GEMMKernel = 0, CuDNNKernel_1 = 1, CuDNNKernel_2 = 2 };
TEST(AlgosCache, AlgosCache) { TEST(AlgosCache, AlgosCache) {
phi::autotune::AlgorithmsCache<std::function<void()>> cache; auto autotune_cache = phi::autotune::AutoTuneCache::Instance();
auto& cache = autotune_cache.RegisterOrGet("conv_fw");
std::vector<int64_t> x_shape = {4, 224, 224, 3}; std::vector<int64_t> x_shape = {4, 224, 224, 3};
std::vector<int64_t> w_shape = {32, 3, 3, 3}; std::vector<int64_t> w_shape = {32, 3, 3, 3};
std::vector<int> paddings = {0, 0}; std::vector<int> paddings = {0, 0};
...@@ -29,17 +31,23 @@ TEST(AlgosCache, AlgosCache) { ...@@ -29,17 +31,23 @@ TEST(AlgosCache, AlgosCache) {
std::vector<int> dilations = {1, 1}; std::vector<int> dilations = {1, 1};
phi::DataType dtype = paddle::experimental::CppTypeToDataType<float>::Type(); phi::DataType dtype = paddle::experimental::CppTypeToDataType<float>::Type();
auto key = auto key = phi::autotune::ConvKey(
cache.ConvKey(x_shape, w_shape, paddings, strides, dilations, dtype); x_shape, w_shape, paddings, strides, dilations, dtype);
EXPECT_EQ(cache.Find(key), false); EXPECT_EQ(cache.Find(key), false);
cache.Set(key, Algo); cache.Set(key, ConvAlgos::GEMMKernel);
EXPECT_EQ(cache.Size(), 1);
EXPECT_EQ(cache.Find(key), true); EXPECT_EQ(cache.Find(key), true);
auto algo = cache.Get(key); auto algo = cache.Get(key);
algo(); EXPECT_EQ(algo, ConvAlgos::GEMMKernel);
x_shape = {4, 128, 128, 3}; x_shape = {4, 128, 128, 3};
key = cache.ConvKey(x_shape, w_shape, paddings, strides, dilations, dtype); key = phi::autotune::ConvKey(
x_shape, w_shape, paddings, strides, dilations, dtype);
EXPECT_EQ(cache.Find(key), false); EXPECT_EQ(cache.Find(key), false);
cache.Set(key, ConvAlgos::CuDNNKernel_1);
EXPECT_EQ(cache.Size(), 2);
EXPECT_EQ(autotune_cache.Size(), 2);
float cache_hit_rate = static_cast<float>(1) / static_cast<float>(3); float cache_hit_rate = static_cast<float>(1) / static_cast<float>(3);
EXPECT_LT(std::abs(cache_hit_rate - cache.CacheHitRate()), 1e-5); EXPECT_LT(std::abs(cache_hit_rate - cache.CacheHitRate()), 1e-5);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册