// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/phi/kernels/autotune/cache.h" #include #include #include #include "glog/logging.h" enum ConvAlgos { GEMMKernel = 0, CuDNNKernel_1 = 1, CuDNNKernel_2 = 2 }; TEST(AlgosCache, AlgosCache) { auto autotune_cache = phi::autotune::AutoTuneCache::Instance(); auto& cache = autotune_cache.GetConvForward(); std::vector x_shape = {4, 224, 224, 3}; std::vector w_shape = {32, 3, 3, 3}; std::vector paddings = {0, 0}; std::vector strides = {2, 2}; std::vector dilations = {1, 1}; phi::DataType dtype = paddle::experimental::CppTypeToDataType::Type(); phi::autotune::ConvCacheKey key( x_shape, w_shape, paddings, strides, dilations, dtype, 0, 0); EXPECT_EQ(cache.Find(key), false); phi::autotune::DnnNode node(static_cast(ConvAlgos::GEMMKernel), 0); cache.Set(key, node); EXPECT_EQ(cache.Size(), 1); EXPECT_EQ(cache.Find(key), true); auto algo = cache.Get(key); EXPECT_EQ(algo.algo, ConvAlgos::GEMMKernel); x_shape = {4, 128, 128, 3}; phi::autotune::ConvCacheKey key1( x_shape, w_shape, paddings, strides, dilations, dtype, 0, 1); EXPECT_EQ(cache.Find(key1), false); phi::autotune::DnnNode node1(static_cast(ConvAlgos::CuDNNKernel_1), 0); cache.Set(key1, node1); EXPECT_EQ(cache.Size(), 2); EXPECT_EQ(cache.CacheHits(), 1); EXPECT_EQ(cache.CacheMisses(), 2); float cache_hit_rate = static_cast(1) / static_cast(3); EXPECT_LT(std::abs(cache_hit_rate - cache.CacheHitRate()), 1e-5); autotune_cache.UpdateStatus(); EXPECT_EQ(autotune_cache.Size(), 2); EXPECT_EQ(autotune_cache.CacheHits(), 1); EXPECT_EQ(autotune_cache.CacheMisses(), 2); EXPECT_LT(std::abs(cache_hit_rate - autotune_cache.CacheHitRate()), 1e-5); }