Unverified commit 01b688c0, authored by Zhang Ting, committed by GitHub

Implement a common AlgorithmsCache for kernel auto-tune (#40793)
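The cache maps a hashed kernel configuration to the algorithm previously selected for it, so the expensive tuning search runs only once per configuration. A minimal usage sketch (illustrative only; PickConvAlgo and SearchBestConvAlgo are hypothetical names, not part of this change):

#include "paddle/phi/kernels/autotune/cache.h"

// Stand-in for a real tuning step that benchmarks candidate algorithms.
int SearchBestConvAlgo() { return 0; }

int PickConvAlgo(phi::autotune::AlgorithmsCache<int>* cache,
                 const std::vector<int64_t>& x_dims,
                 const std::vector<int64_t>& w_dims,
                 const std::vector<int>& strides,
                 const std::vector<int>& paddings,
                 const std::vector<int>& dilations,
                 phi::DataType dtype) {
  size_t key =
      cache->ConvKey(x_dims, w_dims, strides, paddings, dilations, dtype);
  if (cache->Find(key)) {
    return cache->Get(key);  // hit: reuse the algorithm chosen earlier
  }
  int algo = SearchBestConvAlgo();  // miss: tune once, then remember the result
  cache->Set(key, algo);
  return algo;
}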

Parent 54632b5c
@@ -3,3 +3,5 @@ if (WITH_GPU)
elseif (WITH_ROCM)
hip_test(gpu_timer_test SRCS gpu_timer_test.cu DEPS gtest)
endif()
cc_test(cache_test SRCS cache_test.cc DEPS gtest)
paddle/phi/kernels/autotune/cache.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <mutex>
#include <unordered_map>
#include <vector>
#include "glog/logging.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
inline void HashCombine(std::size_t* seed) {}
// combine hash value
// https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x
template <typename T, typename... Rest>
inline void HashCombine(std::size_t* seed, const T& v, Rest... rest) {
std::hash<T> hasher;
*seed ^= hasher(v) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
HashCombine(seed, rest...);
}
// custom specialization of std::hash can be injected in namespace std
// ref: https://en.cppreference.com/w/cpp/utility/hash
namespace std {
template <typename T>
struct hash<std::vector<T>> {
std::size_t operator()(std::vector<T> const& vec) const noexcept {
std::size_t seed = 0;
for (auto val : vec) {
HashCombine(&seed, val);
}
return seed;
}
};
} // namespace std
namespace phi {
namespace autotune {
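// Caches the algorithm chosen for a given problem configuration, keyed by the
// hash produced by GetKey/ConvKey. Find/Get/Set are guarded by a mutex, and
// hit/miss counters feed CacheHitRate().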
template <typename AlgorithmT>
class AlgorithmsCache {
public:
AlgorithmsCache() { hash_.clear(); }
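// Hash all arguments into a single cache key.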
template <typename... Args>
size_t GetKey(Args&&... args) {
size_t seed = 0;
HashCombine(&seed, std::forward<Args>(args)...);
return seed;
}
AlgorithmT Get(size_t key) {
std::lock_guard<std::mutex> lock(cache_mutex_);
PADDLE_ENFORCE_NE(
hash_.find(key),
hash_.end(),
phi::errors::PreconditionNotMet("The key does not exist."));
return hash_[key];
}
bool Find(size_t key) {
bool ret = false;
std::lock_guard<std::mutex> lock(cache_mutex_);
if (hash_.find(key) != hash_.end()) {
cache_hits_++;
ret = true;
} else {
cache_misses_++;
}
return ret;
}
void Set(size_t key, AlgorithmT algo) {
std::lock_guard<std::mutex> lock(cache_mutex_);
hash_[key] = algo;
}
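// Note: returns NaN before the first Find() call (0 hits / 0 accesses).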
float CacheHitRate() const {
int64_t num_accesses = cache_hits_ + cache_misses_;
float cache_hit_rate =
static_cast<float>(cache_hits_) / static_cast<float>(num_accesses);
return cache_hit_rate;
}
// Define the cache key of operator
size_t ConvKey(const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& w_dims,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
phi::DataType dtype) {
return GetKey(x_dims,
w_dims,
strides,
paddings,
dilations,
static_cast<int64_t>(dtype));
}
private:
std::unordered_map<size_t, AlgorithmT> hash_;
std::mutex cache_mutex_;
int64_t cache_hits_ = 0;
int64_t cache_misses_ = 0;
};
} // namespace autotune
} // namespace phi
paddle/phi/kernels/autotune/cache_test.cc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/autotune/cache.h"
#include <gtest/gtest.h>
#include <cmath>
#include <functional>
#include "glog/logging.h"
void Algo() { VLOG(3) << "algo test"; }
TEST(AlgosCache, AlgosCache) {
phi::autotune::AlgorithmsCache<std::function<void()>> cache;
std::vector<int64_t> x_shape = {4, 224, 224, 3};
std::vector<int64_t> w_shape = {32, 3, 3, 3};
std::vector<int> paddings = {0, 0};
std::vector<int> strides = {2, 2};
std::vector<int> dilations = {1, 1};
phi::DataType dtype = paddle::experimental::CppTypeToDataType<float>::Type();
auto key =
cache.ConvKey(x_shape, w_shape, strides, paddings, dilations, dtype);
EXPECT_EQ(cache.Find(key), false);
cache.Set(key, Algo);
EXPECT_EQ(cache.Find(key), true);
auto algo = cache.Get(key);
algo();
x_shape = {4, 128, 128, 3};
key = cache.ConvKey(x_shape, w_shape, strides, paddings, dilations, dtype);
EXPECT_EQ(cache.Find(key), false);
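// Find() has been called three times so far: one hit and two misses.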
float cache_hit_rate = static_cast<float>(1) / static_cast<float>(3);
EXPECT_LT(std::abs(cache_hit_rate - cache.CacheHitRate()), 1e-5);
}