“2eae3616006ef1e1ff440211c5bb8c4399089318”上不存在“develop/doc/howto/cluster/multi_cluster/k8s_aws_en.html”
未验证 提交 b0f8000e 编写于 作者: Z Zhang Ting 提交者: GitHub

Implement AutoTuneStatus class for Kernel Auto Tune (#41218)

* switch autotune

* implement AutoTuneCache

* implement AutoTuneCache class

* add pybind api

* add dygraph test

* support static mode and eager mode and improve unittests

* rename the SwitchAutoTune Class and improve tests

* improve AutoTuneStatus and reduce the cost of tests
上级 3b0e911c
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "paddle/fluid/imperative/op_base.h" #include "paddle/fluid/imperative/op_base.h"
#include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#include "paddle/phi/kernels/autotune/switch_autotune.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
DECLARE_bool(sort_sum_gradient); DECLARE_bool(sort_sum_gradient);
...@@ -645,6 +646,8 @@ void BasicEngine::Execute() { ...@@ -645,6 +646,8 @@ void BasicEngine::Execute() {
Clear(); Clear();
VLOG(1) << "Backward op number: " << op_num; VLOG(1) << "Backward op number: " << op_num;
phi::autotune::AutoTuneStatus::Instance().Update();
} }
void BasicEngine::Clear() { void BasicEngine::Clear() {
......
...@@ -168,6 +168,8 @@ limitations under the License. */ ...@@ -168,6 +168,8 @@ limitations under the License. */
#include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/pybind/eager_utils.h" #include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/phi/api/ext/op_meta_info.h" #include "paddle/phi/api/ext/op_meta_info.h"
#include "paddle/phi/kernels/autotune/cache.h"
#include "paddle/phi/kernels/autotune/switch_autotune.h"
#include "pybind11/stl.h" #include "pybind11/stl.h"
DECLARE_bool(use_mkldnn); DECLARE_bool(use_mkldnn);
...@@ -4419,6 +4421,34 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -4419,6 +4421,34 @@ All parameter, weight, gradient are variables in Paddle.
.def("is_pattern_enabled", &platform::ipu::IpuStrategy::IsPatternEnabled); .def("is_pattern_enabled", &platform::ipu::IpuStrategy::IsPatternEnabled);
#endif #endif
m.def("enable_autotune", [] {
return phi::autotune::AutoTuneStatus::Instance().EnableAutoTune();
});
m.def("disable_autotune", [] {
return phi::autotune::AutoTuneStatus::Instance().DisableAutoTune();
});
m.def("autotune_range", [](int64_t start, int64_t stop) {
return phi::autotune::AutoTuneStatus::Instance().SetAutoTuneRange(start,
stop);
});
m.def("update_autotune_status",
[] { return phi::autotune::AutoTuneStatus::Instance().Update(); });
m.def("autotune_status", [] {
phi::autotune::AutoTuneCache::Instance().UpdateStatus();
py::dict res;
res["use_autotune"] =
phi::autotune::AutoTuneStatus::Instance().UseAutoTune();
res["step_id"] = phi::autotune::AutoTuneStatus::Instance().StepID();
res["cache_size"] = phi::autotune::AutoTuneCache::Instance().Size();
res["cache_hit_rate"] =
phi::autotune::AutoTuneCache::Instance().CacheHitRate();
return res;
});
BindFleetWrapper(&m); BindFleetWrapper(&m);
BindIO(&m); BindIO(&m);
......
...@@ -6,4 +6,6 @@ elseif (WITH_ROCM) ...@@ -6,4 +6,6 @@ elseif (WITH_ROCM)
hip_test(auto_tune_test SRCS auto_tune_test.cu DEPS gtest) hip_test(auto_tune_test SRCS auto_tune_test.cu DEPS gtest)
endif() endif()
cc_test(cache_test SRCS cache_test.cc DEPS gtest) cc_library(cache SRCS cache.cc DEPS)
cc_test(cache_test SRCS cache_test.cc DEPS gtest cache)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/autotune/cache.h"
namespace phi {
namespace autotune {
// Define the cache key of operator
size_t ConvKey(const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& w_dims,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
phi::DataType dtype) {
return GetKey(x_dims,
w_dims,
strides,
paddings,
dilations,
static_cast<int64_t>(dtype));
}
} // namespace autotune
} // namespace phi
...@@ -64,14 +64,7 @@ size_t ConvKey(const std::vector<int64_t>& x_dims, ...@@ -64,14 +64,7 @@ size_t ConvKey(const std::vector<int64_t>& x_dims,
const std::vector<int>& strides, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const std::vector<int>& dilations, const std::vector<int>& dilations,
phi::DataType dtype) { phi::DataType dtype);
return GetKey(x_dims,
w_dims,
strides,
paddings,
dilations,
static_cast<int64_t>(dtype));
}
template <typename AlgorithmT> template <typename AlgorithmT>
class AlgorithmsCache { class AlgorithmsCache {
...@@ -104,14 +97,21 @@ class AlgorithmsCache { ...@@ -104,14 +97,21 @@ class AlgorithmsCache {
hash_[key] = algo; hash_[key] = algo;
} }
int64_t CacheMisses() const { return cache_misses_; }
int64_t CacheHits() const { return cache_hits_; }
float CacheHitRate() const { float CacheHitRate() const {
int64_t num_accesses = cache_hits_ + cache_misses_; int64_t num_accesses = cache_hits_ + cache_misses_;
float cache_hit_rate = float cache_hit_rate = 0.;
static_cast<float>(cache_hits_) / static_cast<float>(num_accesses); if (num_accesses != 0) {
cache_hit_rate =
static_cast<float>(cache_hits_) / static_cast<float>(num_accesses);
}
return cache_hit_rate; return cache_hit_rate;
} }
int64_t Size() { return hash_.size(); } int64_t Size() const { return hash_.size(); }
private: private:
std::unordered_map<size_t, AlgorithmT> hash_; std::unordered_map<size_t, AlgorithmT> hash_;
...@@ -142,20 +142,58 @@ class AutoTuneCache { ...@@ -142,20 +142,58 @@ class AutoTuneCache {
return auto_tune_map_[algo_type]; return auto_tune_map_[algo_type];
} }
// The number of total config cached void Clean(float miss_rate) {
int64_t Size() { std::lock_guard<std::mutex> lock(*autotune_cache_mutex_);
int64_t total = 0; // Set a small tolerance to avoid performance degradation
// due to large cache size under dynamic shape.
if (miss_rate > 0.01) {
auto_tune_map_.clear();
}
}
void UpdateStatus() {
int64_t size = 0;
int64_t cache_hits = 0;
int64_t cache_misses = 0;
for (auto& v : auto_tune_map_) { for (auto& v : auto_tune_map_) {
VLOG(3) << v.first << " " << v.second.Size(); VLOG(4) << "AlgoType: " << v.first << " Cache Size: " << v.second.Size()
total += v.second.Size(); << " Hits: " << v.second.CacheHits()
<< " Misses: " << v.second.CacheMisses()
<< " Hit Rate: " << v.second.CacheHitRate();
size += v.second.Size();
cache_hits += v.second.CacheHits();
cache_misses += v.second.CacheMisses();
} }
return total; total_size_ = size;
total_cache_hits_ = cache_hits;
total_cache_misses_ = cache_misses;
}
// The number of total config cached
int64_t Size() const { return total_size_; }
int64_t CacheHits() const { return total_cache_hits_; }
int64_t CacheMisses() const { return total_cache_misses_; }
float CacheHitRate() const {
float total_cache_hit_rate = 0.;
int64_t total_num_accesses = total_cache_hits_ + total_cache_misses_;
if (total_num_accesses != 0) {
total_cache_hit_rate = static_cast<float>(total_cache_hits_) /
static_cast<float>(total_num_accesses);
}
return total_cache_hit_rate;
} }
private: private:
AutoTuneCache() : autotune_cache_mutex_(new std::mutex()) {} AutoTuneCache() : autotune_cache_mutex_(new std::mutex()) {}
AlgorithmsTypeMap auto_tune_map_; AlgorithmsTypeMap auto_tune_map_;
std::shared_ptr<std::mutex> autotune_cache_mutex_; std::shared_ptr<std::mutex> autotune_cache_mutex_;
int64_t total_cache_hits_ = 0;
int64_t total_cache_misses_ = 0;
int64_t total_size_ = 0;
}; };
} // namespace autotune } // namespace autotune
......
...@@ -46,8 +46,15 @@ TEST(AlgosCache, AlgosCache) { ...@@ -46,8 +46,15 @@ TEST(AlgosCache, AlgosCache) {
EXPECT_EQ(cache.Find(key), false); EXPECT_EQ(cache.Find(key), false);
cache.Set(key, ConvAlgos::CuDNNKernel_1); cache.Set(key, ConvAlgos::CuDNNKernel_1);
EXPECT_EQ(cache.Size(), 2); EXPECT_EQ(cache.Size(), 2);
EXPECT_EQ(autotune_cache.Size(), 2); EXPECT_EQ(cache.CacheHits(), 1);
EXPECT_EQ(cache.CacheMisses(), 2);
float cache_hit_rate = static_cast<float>(1) / static_cast<float>(3); float cache_hit_rate = static_cast<float>(1) / static_cast<float>(3);
EXPECT_LT(std::abs(cache_hit_rate - cache.CacheHitRate()), 1e-5); EXPECT_LT(std::abs(cache_hit_rate - cache.CacheHitRate()), 1e-5);
autotune_cache.UpdateStatus();
EXPECT_EQ(autotune_cache.Size(), 2);
EXPECT_EQ(autotune_cache.CacheHits(), 1);
EXPECT_EQ(autotune_cache.CacheMisses(), 2);
EXPECT_LT(std::abs(cache_hit_rate - autotune_cache.CacheHitRate()), 1e-5);
} }
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include <mutex>
#include <numeric>
#include "glog/logging.h"
#include "paddle/phi/kernels/autotune/cache.h"
namespace phi {
namespace autotune {
class AutoTuneStatus {
public:
static AutoTuneStatus& Instance() {
static AutoTuneStatus switch_autotune;
return switch_autotune;
}
bool UseAutoTune() { return use_autotune_; }
// EnableAutoTune and DisableAutoTune Should be used for debug only.
void EnableAutoTune() {
use_autotune_ = true;
Init();
}
void DisableAutoTune() {
use_autotune_ = false;
Init();
}
void Update() {
current_steps_id_ += 1;
if (!use_autotune_ && !update_use_autotune_) {
return;
}
if (current_steps_id_ < start_step_id_) {
use_autotune_ = false;
} else if (current_steps_id_ >= start_step_id_ &&
current_steps_id_ < stop_step_id_) {
use_autotune_ = true;
AutoTuneCache::Instance().UpdateStatus();
step_hit_rates_.push_back(StepHitRate());
VLOG(3) << "Step ID " << current_steps_id_
<< ", Accumulative Cache Hit Rate: "
<< AutoTuneCache::Instance().CacheHitRate()
<< ", Cache Size: " << AutoTuneCache::Instance().Size()
<< ", Current Step Hit Rate: " << StepHitRate();
} else if (current_steps_id_ == stop_step_id_) {
use_autotune_ = false;
update_use_autotune_ = false;
// clean cache according miss rate
float miss_rate = static_cast<float>(1) - RecentHitRate();
AutoTuneCache::Instance().Clean(miss_rate);
VLOG(3) << "Recent Miss Rate: " << miss_rate;
}
}
int64_t StepID() { return current_steps_id_; }
float RecentHitRate() {
int recent_step_nums = std::ceil(step_hit_rates_.size() * 0.3);
float sum = std::accumulate(step_hit_rates_.rbegin(),
step_hit_rates_.rbegin() + recent_step_nums,
0.0);
float mean = sum / recent_step_nums;
return mean;
}
// Hit Rate of Current Step
float StepHitRate() {
int64_t current_hits = AutoTuneCache::Instance().CacheHits();
int64_t current_misses = AutoTuneCache::Instance().CacheMisses();
int64_t step_hits_ = current_hits - previous_hits_;
int64_t step_misses_ = current_misses - previous_misses_;
float step_hit_rate = 0.;
int64_t step_num_accesses = step_hits_ + step_misses_;
if (step_num_accesses != 0) {
step_hit_rate = static_cast<float>(step_hits_) /
static_cast<float>(step_num_accesses);
}
previous_hits_ = current_hits;
previous_misses_ = current_misses;
return step_hit_rate;
}
void SetAutoTuneRange(int64_t start, int64_t stop) {
start_step_id_ = start;
stop_step_id_ = stop;
}
private:
AutoTuneStatus() = default;
void Init() {
update_use_autotune_ = use_autotune_;
current_steps_id_ = -1;
previous_hits_ = 0;
previous_misses_ = 0;
step_hit_rates_.clear();
AutoTuneCache::Instance().Clean(1.0);
}
int64_t start_step_id_ = 0;
int64_t stop_step_id_ = 10;
int64_t current_steps_id_ = -1;
bool use_autotune_ = false;
bool update_use_autotune_ = false;
int64_t previous_hits_ = 0;
int64_t previous_misses_ = 0;
std::vector<float> step_hit_rates_;
};
} // namespace autotune
} // namespace phi
...@@ -1276,7 +1276,7 @@ class Executor(object): ...@@ -1276,7 +1276,7 @@ class Executor(object):
""" """
try: try:
return self._run_impl( res = self._run_impl(
program=program, program=program,
feed=feed, feed=feed,
fetch_list=fetch_list, fetch_list=fetch_list,
...@@ -1287,6 +1287,8 @@ class Executor(object): ...@@ -1287,6 +1287,8 @@ class Executor(object):
use_program_cache=use_program_cache, use_program_cache=use_program_cache,
use_prune=use_prune, use_prune=use_prune,
return_merged=return_merged) return_merged=return_merged)
core.update_autotune_status()
return res
except Exception as e: except Exception as e:
six.reraise(*sys.exc_info()) six.reraise(*sys.exc_info())
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
import numpy
class SimpleNet(paddle.nn.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
self.conv = paddle.nn.Conv2D(1, 2, (3, 3))
def forward(self, image, label=None):
return self.conv(image)
def train_dygraph(net, data):
out = net(data)
loss = paddle.mean(out)
adam = paddle.optimizer.Adam(parameters=net.parameters())
out.backward()
adam.step()
adam.clear_grad()
def static_program(net, data):
out = net(data)
loss = paddle.mean(out)
adam = paddle.optimizer.Adam()
adam.minimize(loss)
return loss
class TestAutoTune(unittest.TestCase):
def test_autotune(self):
paddle.fluid.core.disable_autotune()
status = paddle.fluid.core.autotune_status()
self.assertEqual(status["use_autotune"], False)
paddle.fluid.core.enable_autotune()
status = paddle.fluid.core.autotune_status()
self.assertEqual(status["use_autotune"], True)
def check_status(self, expected_res):
status = paddle.fluid.core.autotune_status()
for key in status.keys():
self.assertEqual(status[key], expected_res[key])
class TestDygraphAutoTuneStatus(TestAutoTune):
def run_program(self, enable_autotune):
if enable_autotune:
paddle.fluid.core.enable_autotune()
else:
paddle.fluid.core.disable_autotune()
paddle.fluid.core.autotune_range(1, 2)
x_var = paddle.uniform((1, 1, 8, 8), dtype='float32', min=-1., max=1.)
net = SimpleNet()
for i in range(3):
train_dygraph(net, x_var)
if i >= 1 and i < 2:
expected_res = {
"step_id": i,
"use_autotune": enable_autotune,
"cache_size": 0,
"cache_hit_rate": 0
}
self.check_status(expected_res)
else:
expected_res = {
"step_id": i,
"use_autotune": False,
"cache_size": 0,
"cache_hit_rate": 0
}
self.check_status(expected_res)
def test_enable_autotune(self):
self.run_program(enable_autotune=True)
def test_disable_autotune(self):
self.run_program(enable_autotune=False)
class TestStaticAutoTuneStatus(TestAutoTune):
def run_program(self, enable_autotune):
paddle.enable_static()
if enable_autotune:
paddle.fluid.core.enable_autotune()
else:
paddle.fluid.core.disable_autotune()
paddle.fluid.core.autotune_range(1, 2)
data_shape = [1, 1, 8, 8]
data = paddle.static.data(name='X', shape=data_shape, dtype='float32')
net = SimpleNet()
loss = static_program(net, data)
place = paddle.CUDAPlace(0) if paddle.fluid.core.is_compiled_with_cuda(
) else paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
x = numpy.random.random(size=data_shape).astype('float32')
for i in range(3):
exe.run(feed={'X': x}, fetch_list=[loss])
status = paddle.fluid.core.autotune_status()
# In static mode, the startup_program will run at first.
# The expected step_id will be increased by 1.
if i >= 0 and i < 1:
expected_res = {
"step_id": i + 1,
"use_autotune": enable_autotune,
"cache_size": 0,
"cache_hit_rate": 0
}
self.check_status(expected_res)
else:
expected_res = {
"step_id": i + 1,
"use_autotune": False,
"cache_size": 0,
"cache_hit_rate": 0
}
self.check_status(expected_res)
paddle.disable_static()
def test_enable_autotune(self):
self.run_program(enable_autotune=True)
def test_disable_autotune(self):
self.run_program(enable_autotune=False)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册