diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc index d7478b18dba0616fdc995866d8892c7c052a0e35..ce3c5dd2fe5622c9e68f4c1f2ee5c2bb20c1766b 100644 --- a/paddle/fluid/imperative/basic_engine.cc +++ b/paddle/fluid/imperative/basic_engine.cc @@ -30,6 +30,7 @@ #include "paddle/fluid/imperative/op_base.h" #include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/platform/profiler.h" +#include "paddle/phi/kernels/autotune/switch_autotune.h" #include "paddle/phi/kernels/funcs/math_function.h" DECLARE_bool(sort_sum_gradient); @@ -645,6 +646,8 @@ void BasicEngine::Execute() { Clear(); VLOG(1) << "Backward op number: " << op_num; + + phi::autotune::AutoTuneStatus::Instance().Update(); } void BasicEngine::Clear() { diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 982aa52913d630fa97400294db38a2423792c48a..96d86ee1a3100457000d410679c2ddff1be1a815 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -168,6 +168,8 @@ limitations under the License. */ #include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/pybind/eager_utils.h" #include "paddle/phi/api/ext/op_meta_info.h" +#include "paddle/phi/kernels/autotune/cache.h" +#include "paddle/phi/kernels/autotune/switch_autotune.h" #include "pybind11/stl.h" DECLARE_bool(use_mkldnn); @@ -4419,6 +4421,34 @@ All parameter, weight, gradient are variables in Paddle. .def("is_pattern_enabled", &platform::ipu::IpuStrategy::IsPatternEnabled); #endif + m.def("enable_autotune", [] { + return phi::autotune::AutoTuneStatus::Instance().EnableAutoTune(); + }); + + m.def("disable_autotune", [] { + return phi::autotune::AutoTuneStatus::Instance().DisableAutoTune(); + }); + + m.def("autotune_range", [](int64_t start, int64_t stop) { + return phi::autotune::AutoTuneStatus::Instance().SetAutoTuneRange(start, + stop); + }); + + m.def("update_autotune_status", + [] { return phi::autotune::AutoTuneStatus::Instance().Update(); }); + + m.def("autotune_status", [] { + phi::autotune::AutoTuneCache::Instance().UpdateStatus(); + py::dict res; + res["use_autotune"] = + phi::autotune::AutoTuneStatus::Instance().UseAutoTune(); + res["step_id"] = phi::autotune::AutoTuneStatus::Instance().StepID(); + res["cache_size"] = phi::autotune::AutoTuneCache::Instance().Size(); + res["cache_hit_rate"] = + phi::autotune::AutoTuneCache::Instance().CacheHitRate(); + return res; + }); + BindFleetWrapper(&m); BindIO(&m); diff --git a/paddle/phi/kernels/autotune/CMakeLists.txt b/paddle/phi/kernels/autotune/CMakeLists.txt index a3fb9a06fe6718dc1ff0da9398dae72e5230e057..db094d85bf3fd09e351723b72153f7004bc1eb7d 100644 --- a/paddle/phi/kernels/autotune/CMakeLists.txt +++ b/paddle/phi/kernels/autotune/CMakeLists.txt @@ -6,4 +6,6 @@ elseif (WITH_ROCM) hip_test(auto_tune_test SRCS auto_tune_test.cu DEPS gtest) endif() -cc_test(cache_test SRCS cache_test.cc DEPS gtest) +cc_library(cache SRCS cache.cc DEPS) + +cc_test(cache_test SRCS cache_test.cc DEPS gtest cache) diff --git a/paddle/phi/kernels/autotune/cache.cc b/paddle/phi/kernels/autotune/cache.cc new file mode 100644 index 0000000000000000000000000000000000000000..bf68e2010151b28b6a0afab5078520ec8df4a31e --- /dev/null +++ b/paddle/phi/kernels/autotune/cache.cc @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/autotune/cache.h" + +namespace phi { +namespace autotune { + +// Define the cache key of operator +size_t ConvKey(const std::vector& x_dims, + const std::vector& w_dims, + const std::vector& strides, + const std::vector& paddings, + const std::vector& dilations, + phi::DataType dtype) { + return GetKey(x_dims, + w_dims, + strides, + paddings, + dilations, + static_cast(dtype)); +} + +} // namespace autotune +} // namespace phi diff --git a/paddle/phi/kernels/autotune/cache.h b/paddle/phi/kernels/autotune/cache.h index 990843e58f7f2334bfb487214064afaf5fe96c44..d492e7c151f916173b6e952619412c81de23c3ff 100644 --- a/paddle/phi/kernels/autotune/cache.h +++ b/paddle/phi/kernels/autotune/cache.h @@ -64,14 +64,7 @@ size_t ConvKey(const std::vector& x_dims, const std::vector& strides, const std::vector& paddings, const std::vector& dilations, - phi::DataType dtype) { - return GetKey(x_dims, - w_dims, - strides, - paddings, - dilations, - static_cast(dtype)); -} + phi::DataType dtype); template class AlgorithmsCache { @@ -104,14 +97,21 @@ class AlgorithmsCache { hash_[key] = algo; } + int64_t CacheMisses() const { return cache_misses_; } + + int64_t CacheHits() const { return cache_hits_; } + float CacheHitRate() const { int64_t num_accesses = cache_hits_ + cache_misses_; - float cache_hit_rate = - static_cast(cache_hits_) / static_cast(num_accesses); + float cache_hit_rate = 0.; + if (num_accesses != 0) { + cache_hit_rate = + static_cast(cache_hits_) / static_cast(num_accesses); + } return cache_hit_rate; } - int64_t Size() { return hash_.size(); } + int64_t Size() const { return hash_.size(); } private: std::unordered_map hash_; @@ -142,20 +142,58 @@ class AutoTuneCache { return auto_tune_map_[algo_type]; } - // The number of total config cached - int64_t Size() { - int64_t total = 0; + void Clean(float miss_rate) { + std::lock_guard lock(*autotune_cache_mutex_); + // Set a small tolerance to avoid performance degradation + // due to large cache size under dynamic shape. + if (miss_rate > 0.01) { + auto_tune_map_.clear(); + } + } + + void UpdateStatus() { + int64_t size = 0; + int64_t cache_hits = 0; + int64_t cache_misses = 0; for (auto& v : auto_tune_map_) { - VLOG(3) << v.first << " " << v.second.Size(); - total += v.second.Size(); + VLOG(4) << "AlgoType: " << v.first << " Cache Size: " << v.second.Size() + << " Hits: " << v.second.CacheHits() + << " Misses: " << v.second.CacheMisses() + << " Hit Rate: " << v.second.CacheHitRate(); + size += v.second.Size(); + cache_hits += v.second.CacheHits(); + cache_misses += v.second.CacheMisses(); } - return total; + total_size_ = size; + total_cache_hits_ = cache_hits; + total_cache_misses_ = cache_misses; + } + + // The number of total config cached + int64_t Size() const { return total_size_; } + + int64_t CacheHits() const { return total_cache_hits_; } + + int64_t CacheMisses() const { return total_cache_misses_; } + + float CacheHitRate() const { + float total_cache_hit_rate = 0.; + int64_t total_num_accesses = total_cache_hits_ + total_cache_misses_; + if (total_num_accesses != 0) { + total_cache_hit_rate = static_cast(total_cache_hits_) / + static_cast(total_num_accesses); + } + + return total_cache_hit_rate; } private: AutoTuneCache() : autotune_cache_mutex_(new std::mutex()) {} AlgorithmsTypeMap auto_tune_map_; std::shared_ptr autotune_cache_mutex_; + int64_t total_cache_hits_ = 0; + int64_t total_cache_misses_ = 0; + int64_t total_size_ = 0; }; } // namespace autotune diff --git a/paddle/phi/kernels/autotune/cache_test.cc b/paddle/phi/kernels/autotune/cache_test.cc index 9fcd9b796d0ae17922821270da1470f9d75fca49..92ba411624fc0b1b6bd1237ae8f9b446032249b8 100644 --- a/paddle/phi/kernels/autotune/cache_test.cc +++ b/paddle/phi/kernels/autotune/cache_test.cc @@ -46,8 +46,15 @@ TEST(AlgosCache, AlgosCache) { EXPECT_EQ(cache.Find(key), false); cache.Set(key, ConvAlgos::CuDNNKernel_1); EXPECT_EQ(cache.Size(), 2); - EXPECT_EQ(autotune_cache.Size(), 2); + EXPECT_EQ(cache.CacheHits(), 1); + EXPECT_EQ(cache.CacheMisses(), 2); float cache_hit_rate = static_cast(1) / static_cast(3); EXPECT_LT(std::abs(cache_hit_rate - cache.CacheHitRate()), 1e-5); + + autotune_cache.UpdateStatus(); + EXPECT_EQ(autotune_cache.Size(), 2); + EXPECT_EQ(autotune_cache.CacheHits(), 1); + EXPECT_EQ(autotune_cache.CacheMisses(), 2); + EXPECT_LT(std::abs(cache_hit_rate - autotune_cache.CacheHitRate()), 1e-5); } diff --git a/paddle/phi/kernels/autotune/switch_autotune.h b/paddle/phi/kernels/autotune/switch_autotune.h new file mode 100644 index 0000000000000000000000000000000000000000..2f9621ed2079e497f389efd1dd74f46a255fe842 --- /dev/null +++ b/paddle/phi/kernels/autotune/switch_autotune.h @@ -0,0 +1,130 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include "glog/logging.h" +#include "paddle/phi/kernels/autotune/cache.h" + +namespace phi { +namespace autotune { + +class AutoTuneStatus { + public: + static AutoTuneStatus& Instance() { + static AutoTuneStatus switch_autotune; + return switch_autotune; + } + + bool UseAutoTune() { return use_autotune_; } + + // EnableAutoTune and DisableAutoTune Should be used for debug only. + void EnableAutoTune() { + use_autotune_ = true; + Init(); + } + + void DisableAutoTune() { + use_autotune_ = false; + Init(); + } + + void Update() { + current_steps_id_ += 1; + + if (!use_autotune_ && !update_use_autotune_) { + return; + } + + if (current_steps_id_ < start_step_id_) { + use_autotune_ = false; + } else if (current_steps_id_ >= start_step_id_ && + current_steps_id_ < stop_step_id_) { + use_autotune_ = true; + AutoTuneCache::Instance().UpdateStatus(); + step_hit_rates_.push_back(StepHitRate()); + VLOG(3) << "Step ID " << current_steps_id_ + << ", Accumulative Cache Hit Rate: " + << AutoTuneCache::Instance().CacheHitRate() + << ", Cache Size: " << AutoTuneCache::Instance().Size() + << ", Current Step Hit Rate: " << StepHitRate(); + } else if (current_steps_id_ == stop_step_id_) { + use_autotune_ = false; + update_use_autotune_ = false; + // clean cache according miss rate + float miss_rate = static_cast(1) - RecentHitRate(); + AutoTuneCache::Instance().Clean(miss_rate); + VLOG(3) << "Recent Miss Rate: " << miss_rate; + } + } + + int64_t StepID() { return current_steps_id_; } + + float RecentHitRate() { + int recent_step_nums = std::ceil(step_hit_rates_.size() * 0.3); + float sum = std::accumulate(step_hit_rates_.rbegin(), + step_hit_rates_.rbegin() + recent_step_nums, + 0.0); + float mean = sum / recent_step_nums; + return mean; + } + + // Hit Rate of Current Step + float StepHitRate() { + int64_t current_hits = AutoTuneCache::Instance().CacheHits(); + int64_t current_misses = AutoTuneCache::Instance().CacheMisses(); + int64_t step_hits_ = current_hits - previous_hits_; + int64_t step_misses_ = current_misses - previous_misses_; + float step_hit_rate = 0.; + int64_t step_num_accesses = step_hits_ + step_misses_; + if (step_num_accesses != 0) { + step_hit_rate = static_cast(step_hits_) / + static_cast(step_num_accesses); + } + previous_hits_ = current_hits; + previous_misses_ = current_misses; + return step_hit_rate; + } + + void SetAutoTuneRange(int64_t start, int64_t stop) { + start_step_id_ = start; + stop_step_id_ = stop; + } + + private: + AutoTuneStatus() = default; + + void Init() { + update_use_autotune_ = use_autotune_; + current_steps_id_ = -1; + previous_hits_ = 0; + previous_misses_ = 0; + step_hit_rates_.clear(); + AutoTuneCache::Instance().Clean(1.0); + } + + int64_t start_step_id_ = 0; + int64_t stop_step_id_ = 10; + int64_t current_steps_id_ = -1; + bool use_autotune_ = false; + bool update_use_autotune_ = false; + int64_t previous_hits_ = 0; + int64_t previous_misses_ = 0; + std::vector step_hit_rates_; +}; + +} // namespace autotune +} // namespace phi diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 935f7b53eba5741e1896725d2d8c20c7eeb1a4e8..2232c34e63bd003115f8d48b6840ff49d0b10ace 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -1276,7 +1276,7 @@ class Executor(object): """ try: - return self._run_impl( + res = self._run_impl( program=program, feed=feed, fetch_list=fetch_list, @@ -1287,6 +1287,8 @@ class Executor(object): use_program_cache=use_program_cache, use_prune=use_prune, return_merged=return_merged) + core.update_autotune_status() + return res except Exception as e: six.reraise(*sys.exc_info()) diff --git a/python/paddle/fluid/tests/unittests/test_switch_autotune.py b/python/paddle/fluid/tests/unittests/test_switch_autotune.py new file mode 100644 index 0000000000000000000000000000000000000000..08cf120a0366ebe5e0d9eed0794f2578cfd1dd1e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_switch_autotune.py @@ -0,0 +1,147 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import unittest +import numpy + + +class SimpleNet(paddle.nn.Layer): + def __init__(self): + super(SimpleNet, self).__init__() + self.conv = paddle.nn.Conv2D(1, 2, (3, 3)) + + def forward(self, image, label=None): + return self.conv(image) + + +def train_dygraph(net, data): + out = net(data) + loss = paddle.mean(out) + adam = paddle.optimizer.Adam(parameters=net.parameters()) + out.backward() + adam.step() + adam.clear_grad() + + +def static_program(net, data): + out = net(data) + loss = paddle.mean(out) + adam = paddle.optimizer.Adam() + adam.minimize(loss) + return loss + + +class TestAutoTune(unittest.TestCase): + def test_autotune(self): + paddle.fluid.core.disable_autotune() + status = paddle.fluid.core.autotune_status() + self.assertEqual(status["use_autotune"], False) + + paddle.fluid.core.enable_autotune() + status = paddle.fluid.core.autotune_status() + self.assertEqual(status["use_autotune"], True) + + def check_status(self, expected_res): + status = paddle.fluid.core.autotune_status() + for key in status.keys(): + self.assertEqual(status[key], expected_res[key]) + + +class TestDygraphAutoTuneStatus(TestAutoTune): + def run_program(self, enable_autotune): + if enable_autotune: + paddle.fluid.core.enable_autotune() + else: + paddle.fluid.core.disable_autotune() + paddle.fluid.core.autotune_range(1, 2) + x_var = paddle.uniform((1, 1, 8, 8), dtype='float32', min=-1., max=1.) + net = SimpleNet() + for i in range(3): + train_dygraph(net, x_var) + if i >= 1 and i < 2: + expected_res = { + "step_id": i, + "use_autotune": enable_autotune, + "cache_size": 0, + "cache_hit_rate": 0 + } + self.check_status(expected_res) + else: + expected_res = { + "step_id": i, + "use_autotune": False, + "cache_size": 0, + "cache_hit_rate": 0 + } + self.check_status(expected_res) + + def test_enable_autotune(self): + self.run_program(enable_autotune=True) + + def test_disable_autotune(self): + self.run_program(enable_autotune=False) + + +class TestStaticAutoTuneStatus(TestAutoTune): + def run_program(self, enable_autotune): + paddle.enable_static() + if enable_autotune: + paddle.fluid.core.enable_autotune() + else: + paddle.fluid.core.disable_autotune() + paddle.fluid.core.autotune_range(1, 2) + + data_shape = [1, 1, 8, 8] + data = paddle.static.data(name='X', shape=data_shape, dtype='float32') + net = SimpleNet() + loss = static_program(net, data) + place = paddle.CUDAPlace(0) if paddle.fluid.core.is_compiled_with_cuda( + ) else paddle.CPUPlace() + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + x = numpy.random.random(size=data_shape).astype('float32') + + for i in range(3): + exe.run(feed={'X': x}, fetch_list=[loss]) + status = paddle.fluid.core.autotune_status() + # In static mode, the startup_program will run at first. + # The expected step_id will be increased by 1. + if i >= 0 and i < 1: + expected_res = { + "step_id": i + 1, + "use_autotune": enable_autotune, + "cache_size": 0, + "cache_hit_rate": 0 + } + self.check_status(expected_res) + else: + expected_res = { + "step_id": i + 1, + "use_autotune": False, + "cache_size": 0, + "cache_hit_rate": 0 + } + self.check_status(expected_res) + paddle.disable_static() + + def test_enable_autotune(self): + self.run_program(enable_autotune=True) + + def test_disable_autotune(self): + self.run_program(enable_autotune=False) + + +if __name__ == '__main__': + unittest.main()