diff --git a/paddle/phi/kernels/autotune/CMakeLists.txt b/paddle/phi/kernels/autotune/CMakeLists.txt index c7bb30d2d767cfc712fc19152f35bb406a89eac9..d1579f66e29e064a90dde7ecb4eb6eb9cb62da45 100644 --- a/paddle/phi/kernels/autotune/CMakeLists.txt +++ b/paddle/phi/kernels/autotune/CMakeLists.txt @@ -3,3 +3,5 @@ if (WITH_GPU) elseif (WITH_ROCM) hip_test(gpu_timer_test SRCS gpu_timer_test.cu DEPS gtest) endif() + +cc_test(cache_test SRCS cache_test.cc DEPS gtest) diff --git a/paddle/phi/kernels/autotune/cache.h b/paddle/phi/kernels/autotune/cache.h new file mode 100644 index 0000000000000000000000000000000000000000..c5b068c28994d6a37dbf616af03d29a90ac14d9d --- /dev/null +++ b/paddle/phi/kernels/autotune/cache.h @@ -0,0 +1,122 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include "glog/logging.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" + +inline void HashCombine(std::size_t* seed) {} + +// combine hash value +// https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x +template +inline void HashCombine(std::size_t* seed, const T& v, Rest... rest) { + std::hash hasher; + *seed ^= hasher(v) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2); + HashCombine(seed, rest...); +} + +// custom specialization of std::hash can be injected in namespace std +// ref: https://en.cppreference.com/w/cpp/utility/hash +namespace std { +template +struct hash> { + std::size_t operator()(std::vector const& vec) const noexcept { + std::size_t seed = 0; + for (auto val : vec) { + HashCombine(&seed, val); + } + return seed; + } +}; +} // namespace std + +namespace phi { +namespace autotune { + +template +class AlgorithmsCache { + public: + AlgorithmsCache() { hash_.clear(); } + + template + size_t GetKey(Args&&... args) { + size_t seed = 0; + HashCombine(&seed, std::forward(args)...); + return seed; + } + + AlgorithmT Get(size_t key) { + std::lock_guard lock(cache_mutex_); + PADDLE_ENFORCE_NE( + hash_.find(key), + hash_.end(), + phi::errors::PreconditionNotMet("The key does not exist.")); + return hash_[key]; + } + + bool Find(size_t key) { + bool ret = false; + std::lock_guard lock(cache_mutex_); + if (hash_.find(key) != hash_.end()) { + cache_hits_++; + ret = true; + } else { + cache_misses_++; + } + return ret; + } + + void Set(size_t key, AlgorithmT algo) { + std::lock_guard lock(cache_mutex_); + hash_[key] = algo; + } + + float CacheHitRate() const { + int64_t num_accesses = cache_hits_ + cache_misses_; + float cache_hit_rate = + static_cast(cache_hits_) / static_cast(num_accesses); + return cache_hit_rate; + } + + // Define the cache key of operator + size_t ConvKey(const std::vector& x_dims, + const std::vector& w_dims, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + phi::DataType dtype) { + return GetKey(x_dims, + w_dims, + strides, + paddings, + dilations, + static_cast(dtype)); + } + + private: + std::unordered_map hash_; + std::mutex cache_mutex_; + int64_t cache_hits_ = 0; + int64_t cache_misses_ = 0; +}; + +} // namespace autotune +} // namespace phi diff --git a/paddle/phi/kernels/autotune/cache_test.cc b/paddle/phi/kernels/autotune/cache_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b08a6cfc14ae887b1543aab219706df5997eb9d3 --- /dev/null +++ b/paddle/phi/kernels/autotune/cache_test.cc @@ -0,0 +1,45 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/autotune/cache.h" +#include +#include +#include +#include "glog/logging.h" + +void Algo() { VLOG(3) << "algo test"; } + +TEST(AlgosCache, AlgosCache) { + phi::autotune::AlgorithmsCache> cache; + std::vector x_shape = {4, 224, 224, 3}; + std::vector w_shape = {32, 3, 3, 3}; + std::vector paddings = {0, 0}; + std::vector strides = {2, 2}; + std::vector dilations = {1, 1}; + phi::DataType dtype = paddle::experimental::CppTypeToDataType::Type(); + + auto key = + cache.ConvKey(x_shape, w_shape, paddings, strides, dilations, dtype); + EXPECT_EQ(cache.Find(key), false); + cache.Set(key, Algo); + EXPECT_EQ(cache.Find(key), true); + auto algo = cache.Get(key); + algo(); + + x_shape = {4, 128, 128, 3}; + key = cache.ConvKey(x_shape, w_shape, paddings, strides, dilations, dtype); + EXPECT_EQ(cache.Find(key), false); + float cache_hit_rate = static_cast(1) / static_cast(3); + EXPECT_LT(std::abs(cache_hit_rate - cache.CacheHitRate()), 1e-5); +}