未验证 提交 35acfeda 编写于 作者: L limingshu 提交者: GitHub

Change cuDNN Conv kernel for auto tune feature (#41313)

* change cudnn helper for auto-tune

* Add FLAGS_use_autotune to set the global status of autotune and change the order of choosing algorithm.

* Fix the bug in calculating and printing current step cache hit rate.

* Improve the autotune cache and fix unittest.

* Change the key from AlgorithmType to int64_t.

* Fix unittest for cpu-only env.

* change ChooseAlgoByWorkspace for heuristic mode
Co-authored-by: NLiu Yiqun <liuyiqun01@baidu.com>
上级 10114859
...@@ -15,7 +15,7 @@ if(NOT ((NOT WITH_PYTHON) AND ON_INFER)) ...@@ -15,7 +15,7 @@ if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
add_subdirectory(pylayer) add_subdirectory(pylayer)
cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulator) cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulator)
add_dependencies(grad_tensor_holder eager_final_state_codegen) add_dependencies(grad_tensor_holder eager_final_state_codegen)
cc_library(backward SRCS backward.cc DEPS grad_tensor_holder utils autograd_meta grad_node_info) cc_library(backward SRCS backward.cc DEPS grad_tensor_holder utils autograd_meta grad_node_info switch_autotune)
endif() endif()
cc_library(grad_node_info SRCS grad_node_info.cc DEPS phi_api phi_tensor) cc_library(grad_node_info SRCS grad_node_info.cc DEPS phi_api phi_tensor)
......
...@@ -9,8 +9,8 @@ cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_f ...@@ -9,8 +9,8 @@ cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_f
add_subdirectory(jit) add_subdirectory(jit)
cc_library(amp SRCS amp_auto_cast.cc DEPS layer var_helper) cc_library(amp SRCS amp_auto_cast.cc DEPS layer var_helper)
cc_library(tracer SRCS tracer.cc DEPS layer engine program_desc_tracer amp denormal garbage_collector var_helper) cc_library(tracer SRCS tracer.cc DEPS layer engine program_desc_tracer amp denormal garbage_collector var_helper)
cc_library(basic_engine SRCS basic_engine.cc DEPS layer gradient_accumulator) cc_library(basic_engine SRCS basic_engine.cc DEPS layer gradient_accumulator switch_autotune)
cc_library(engine SRCS basic_engine.cc partial_grad_engine.cc DEPS layer gradient_accumulator) cc_library(engine SRCS basic_engine.cc partial_grad_engine.cc DEPS layer gradient_accumulator switch_autotune)
cc_library(imperative_profiler SRCS profiler.cc DEPS flags) cc_library(imperative_profiler SRCS profiler.cc DEPS flags)
if(NOT WIN32) if(NOT WIN32)
if(WITH_NCCL OR WITH_RCCL) if(WITH_NCCL OR WITH_RCCL)
......
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/conv_search_cache.h" #include "paddle/fluid/framework/conv_search_cache.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/autotune/cache.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -41,12 +42,22 @@ struct SearchAlgorithm {}; ...@@ -41,12 +42,22 @@ struct SearchAlgorithm {};
// As the container of searchAlgorithm::Find() result. // As the container of searchAlgorithm::Find() result.
template <typename AlgoT> template <typename AlgoT>
struct SearchResult { struct SearchResult {
public: SearchResult() {}
explicit SearchResult(AlgoT a) : algo(a) {}
AlgoT algo = static_cast<AlgoT>(0); AlgoT algo = static_cast<AlgoT>(0);
float time = -1.f; float time = -1.f;
size_t workspace_size = 0; size_t workspace_size = 0;
}; };
template <typename T>
static std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
out << "[";
for (auto const& tmp : v) out << tmp << ",";
out << "]";
return out;
}
// As the container of conv relevant descriptors. // As the container of conv relevant descriptors.
template <typename HandleT, typename DataT> template <typename HandleT, typename DataT>
struct ConvArgsBase { struct ConvArgsBase {
...@@ -68,6 +79,17 @@ struct ConvArgsBase { ...@@ -68,6 +79,17 @@ struct ConvArgsBase {
const framework::Tensor* o, const std::vector<int> s, const framework::Tensor* o, const std::vector<int> s,
const std::vector<int> p, const std::vector<int> d, DataT dtype) const std::vector<int> p, const std::vector<int> d, DataT dtype)
: x(x), w(w), o(o), s(s), p(p), d(d), cudnn_dtype(dtype) {} : x(x), w(w), o(o), s(s), p(p), d(d), cudnn_dtype(dtype) {}
template <typename T>
size_t GetCacheKey() const {
auto x_shape = phi::vectorize(x->dims());
auto w_shape = phi::vectorize(w->dims());
VLOG(10) << "[ConvArgs] x_dims=" << x_shape << ", w_dims=" << w_shape
<< ", strides=" << s << ", paddings=" << p << ", dilations=" << d;
return phi::autotune::ConvKey(
x_shape, w_shape, p, s, d,
paddle::experimental::CppTypeToDataType<T>::Type());
}
}; };
static inline void GetNCDHW(const framework::DDim& dims, static inline void GetNCDHW(const framework::DDim& dims,
...@@ -87,13 +109,5 @@ static inline void GetNCDHW(const framework::DDim& dims, ...@@ -87,13 +109,5 @@ static inline void GetNCDHW(const framework::DDim& dims,
} }
} }
template <typename T>
static std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
out << "[";
for (auto const& tmp : v) out << tmp << ",";
out << "]";
return out;
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -774,3 +774,12 @@ DEFINE_bool(enable_ins_parser_file, false, ...@@ -774,3 +774,12 @@ DEFINE_bool(enable_ins_parser_file, false,
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PADDLE_DEFINE_EXPORTED_bool(nccl_blocking_wait, false, "nccl blocking wait"); PADDLE_DEFINE_EXPORTED_bool(nccl_blocking_wait, false, "nccl blocking wait");
#endif #endif
/**
* Autotune related FLAG
* Name: FLAGS_use_autotune
* Since Version: 2.3.0
* Value Range: bool, default=false
* Example:
*/
PADDLE_DEFINE_EXPORTED_bool(use_autotune, false, "Whether enable autotune.");
...@@ -4469,7 +4469,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -4469,7 +4469,7 @@ All parameter, weight, gradient are variables in Paddle.
return phi::autotune::AutoTuneStatus::Instance().DisableAutoTune(); return phi::autotune::AutoTuneStatus::Instance().DisableAutoTune();
}); });
m.def("autotune_range", [](int64_t start, int64_t stop) { m.def("set_autotune_range", [](int64_t start, int64_t stop) {
return phi::autotune::AutoTuneStatus::Instance().SetAutoTuneRange(start, return phi::autotune::AutoTuneStatus::Instance().SetAutoTuneRange(start,
stop); stop);
}); });
...@@ -4478,10 +4478,8 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -4478,10 +4478,8 @@ All parameter, weight, gradient are variables in Paddle.
[] { return phi::autotune::AutoTuneStatus::Instance().Update(); }); [] { return phi::autotune::AutoTuneStatus::Instance().Update(); });
m.def("autotune_status", [] { m.def("autotune_status", [] {
phi::autotune::AutoTuneCache::Instance().UpdateStatus();
py::dict res; py::dict res;
res["use_autotune"] = phi::autotune::AutoTuneCache::Instance().UpdateStatus();
phi::autotune::AutoTuneStatus::Instance().UseAutoTune();
res["step_id"] = phi::autotune::AutoTuneStatus::Instance().StepID(); res["step_id"] = phi::autotune::AutoTuneStatus::Instance().StepID();
res["cache_size"] = phi::autotune::AutoTuneCache::Instance().Size(); res["cache_size"] = phi::autotune::AutoTuneCache::Instance().Size();
res["cache_hit_rate"] = res["cache_hit_rate"] =
......
...@@ -6,12 +6,15 @@ file(APPEND ${kernel_declare_file} "#include \"paddle/phi/core/kernel_registry.h ...@@ -6,12 +6,15 @@ file(APPEND ${kernel_declare_file} "#include \"paddle/phi/core/kernel_registry.h
# phi functors and functions called by kernels # phi functors and functions called by kernels
add_subdirectory(funcs) add_subdirectory(funcs)
# kernel autotune
add_subdirectory(autotune)
# phi depends all phi kernel targets # phi depends all phi kernel targets
set_property(GLOBAL PROPERTY PHI_KERNELS "") set_property(GLOBAL PROPERTY PHI_KERNELS "")
# [ 1. Common kernel compilation dependencies ] # [ 1. Common kernel compilation dependencies ]
set(COMMON_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils custom_kernel) set(COMMON_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils custom_kernel)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col vol2col concat_and_split_functor selected_rows_functor ) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col vol2col concat_and_split_functor selected_rows_functor)
# remove this dep after removing fluid deps on tensor creation # remove this dep after removing fluid deps on tensor creation
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_api_utils) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_api_utils)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta)
...@@ -27,12 +30,16 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel) ...@@ -27,12 +30,16 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel)
# Some kernels depend on some targets that are not commonly used. # Some kernels depend on some targets that are not commonly used.
# These targets are not suitable for common dependencies. # These targets are not suitable for common dependencies.
# In this case, you need to manually generate them here. # In this case, you need to manually generate them here.
set(MANUAL_BUILD_KERNELS cross_entropy_kernel adam_kernel adamw_kernel deformable_conv_kernel deformable_conv_grad_kernel eigh_kernel set(AUTOTUNE_KERNELS conv_kernel conv_grad_kernel conv_grad_grad_kernel conv_transpose_kernel conv_transpose_grad_kernel)
set(MANUAL_BUILD_KERNELS ${AUTOTUNE_KERNELS} cross_entropy_kernel adam_kernel adamw_kernel deformable_conv_kernel deformable_conv_grad_kernel eigh_kernel
gumbel_softmax_kernel gumbel_softmax_grad_kernel hierarchical_sigmoid_kernel hierarchical_sigmoid_grad_kernel gumbel_softmax_kernel gumbel_softmax_grad_kernel hierarchical_sigmoid_kernel hierarchical_sigmoid_grad_kernel
matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel
put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel
softmax_kernel softmax_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel softmax_kernel softmax_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel
triangular_solve_grad_kernel determinant_grad_kernel reduce_kernel rnn_kernel rnn_grad_kernel warpctc_kernel warpctc_grad_kernel) triangular_solve_grad_kernel determinant_grad_kernel reduce_kernel rnn_kernel rnn_grad_kernel warpctc_kernel warpctc_grad_kernel)
foreach(src ${AUTOTUNE_KERNELS})
kernel_library(${src} DEPS ${COMMON_KERNEL_DEPS} switch_autotune)
endforeach()
kernel_library(adam_kernel DEPS gflags glog flags ${COMMON_KERNEL_DEPS} selected_rows_functor threadpool jit_kernel_helper) kernel_library(adam_kernel DEPS gflags glog flags ${COMMON_KERNEL_DEPS} selected_rows_functor threadpool jit_kernel_helper)
kernel_library(adamw_kernel DEPS ${COMMON_KERNEL_DEPS} adam_kernel) kernel_library(adamw_kernel DEPS ${COMMON_KERNEL_DEPS} adam_kernel)
kernel_library(cross_entropy_kernel DEPS ${COMMON_KERNEL_DEPS} softmax cross_entropy) kernel_library(cross_entropy_kernel DEPS ${COMMON_KERNEL_DEPS} softmax cross_entropy)
...@@ -75,6 +82,3 @@ add_subdirectory(selected_rows) ...@@ -75,6 +82,3 @@ add_subdirectory(selected_rows)
copy_if_different(${kernel_declare_file} ${kernel_declare_file_final}) copy_if_different(${kernel_declare_file} ${kernel_declare_file_final})
# For strings kernels # For strings kernels
add_subdirectory(strings) add_subdirectory(strings)
# 5. kernel autotune
add_subdirectory(autotune)
if (WITH_GPU) if (WITH_GPU)
nv_test(gpu_timer_test SRCS gpu_timer_test.cu DEPS gtest) nv_test(gpu_timer_test SRCS gpu_timer_test.cu DEPS gtest)
nv_test(auto_tune_test SRCS auto_tune_test.cu DEPS gtest) nv_test(auto_tune_test SRCS auto_tune_test.cu DEPS gtest)
elseif (WITH_ROCM) elseif (WITH_ROCM)
hip_test(gpu_timer_test SRCS gpu_timer_test.cu DEPS gtest) hip_test(gpu_timer_test SRCS gpu_timer_test.cu DEPS gtest)
hip_test(auto_tune_test SRCS auto_tune_test.cu DEPS gtest) hip_test(auto_tune_test SRCS auto_tune_test.cu DEPS gtest)
endif() endif()
cc_library(cache SRCS cache.cc DEPS boost) cc_library(cache SRCS cache.cc DEPS boost)
cc_library(switch_autotune SRCS switch_autotune.cc DEPS cache flags)
cc_test(cache_test SRCS cache_test.cc DEPS gtest cache) cc_test(cache_test SRCS cache_test.cc DEPS gtest cache)
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
// limitations under the License. // limitations under the License.
#include "paddle/phi/kernels/autotune/cache.h" #include "paddle/phi/kernels/autotune/cache.h"
#include <iomanip>
#include "glog/logging.h"
namespace phi { namespace phi {
namespace autotune { namespace autotune {
...@@ -32,5 +34,40 @@ size_t ConvKey(const std::vector<int64_t>& x_dims, ...@@ -32,5 +34,40 @@ size_t ConvKey(const std::vector<int64_t>& x_dims,
static_cast<int64_t>(dtype)); static_cast<int64_t>(dtype));
} }
std::string AlgorithmTypeString(int64_t algo_type) {
if (algo_type == static_cast<int64_t>(AlgorithmType::kConvForward)) {
return "conv_forward";
} else if (algo_type ==
static_cast<int64_t>(AlgorithmType::kConvBackwardData)) {
return "conv_backward_data";
} else if (algo_type ==
static_cast<int64_t>(AlgorithmType::kConvBackwardFilter)) {
return "conv_backward_filter";
}
return std::to_string(algo_type);
}
void AutoTuneCache::UpdateStatus() {
int64_t size = 0;
int64_t cache_hits = 0;
int64_t cache_misses = 0;
int name_width = 24;
std::cout.setf(std::ios::left);
for (auto& v : auto_tune_map_) {
VLOG(4) << "AlgoType: " << std::setfill(' ') << std::setw(name_width)
<< AlgorithmTypeString(v.first)
<< " Cache Size: " << v.second.Size()
<< " Hits: " << v.second.CacheHits()
<< " Misses: " << v.second.CacheMisses()
<< " Hit Rate: " << v.second.CacheHitRate();
size += v.second.Size();
cache_hits += v.second.CacheHits();
cache_misses += v.second.CacheMisses();
}
total_size_ = size;
total_cache_hits_ = cache_hits;
total_cache_misses_ = cache_misses;
}
} // namespace autotune } // namespace autotune
} // namespace phi } // namespace phi
...@@ -13,11 +13,12 @@ ...@@ -13,11 +13,12 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include <mutex> #include <mutex>
#include <numeric>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "glog/logging.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h" #include "paddle/phi/core/errors.h"
...@@ -92,6 +93,13 @@ class AlgorithmsCache { ...@@ -92,6 +93,13 @@ class AlgorithmsCache {
return ret; return ret;
} }
void Clean() {
std::lock_guard<std::mutex> lock(*cache_mutex_);
hash_.clear();
cache_hits_ = 0;
cache_misses_ = 0;
}
void Set(size_t key, AlgorithmT algo) { void Set(size_t key, AlgorithmT algo) {
std::lock_guard<std::mutex> lock(*cache_mutex_); std::lock_guard<std::mutex> lock(*cache_mutex_);
hash_[key] = algo; hash_[key] = algo;
...@@ -116,15 +124,22 @@ class AlgorithmsCache { ...@@ -116,15 +124,22 @@ class AlgorithmsCache {
private: private:
std::unordered_map<size_t, AlgorithmT> hash_; std::unordered_map<size_t, AlgorithmT> hash_;
std::shared_ptr<std::mutex> cache_mutex_; std::shared_ptr<std::mutex> cache_mutex_;
int64_t cache_hits_ = 0;
int64_t cache_misses_ = 0; int64_t cache_hits_{0};
int64_t cache_misses_{0};
};
enum class AlgorithmType {
kConvForward = 1,
kConvBackwardData = 2,
kConvBackwardFilter = 3,
kAlgorithmCount = 4
}; };
// AlgorithmsConfigKey -> AlgorithmsID // AlgorithmsConfigKey -> AlgorithmsID
using AlgorithmsConfigKeyMap = AlgorithmsCache<int64_t>; using AlgorithmsCacheMap = AlgorithmsCache<int64_t>;
// AlgorithmsType -> AlgorithmsCache // AlgorithmType -> AlgorithmsCache
using AlgorithmsTypeMap = using AlgorithmsTypeMap = std::unordered_map<int64_t, AlgorithmsCacheMap>;
std::unordered_map<std::string, AlgorithmsConfigKeyMap>;
class AutoTuneCache { class AutoTuneCache {
public: public:
...@@ -133,42 +148,30 @@ class AutoTuneCache { ...@@ -133,42 +148,30 @@ class AutoTuneCache {
return autotune_cache; return autotune_cache;
} }
AlgorithmsConfigKeyMap& RegisterOrGet(const std::string& algo_type) { AlgorithmsCacheMap& Get(const AlgorithmType& algo_type) {
std::lock_guard<std::mutex> lock(*autotune_cache_mutex_); return auto_tune_map_[static_cast<int64_t>(algo_type)];
if (auto_tune_map_.find(algo_type) == auto_tune_map_.end()) {
AlgorithmsConfigKeyMap cache;
auto_tune_map_[algo_type] = cache;
}
return auto_tune_map_[algo_type];
} }
void Clean(float miss_rate) { AlgorithmsCacheMap& GetConvForward() {
std::lock_guard<std::mutex> lock(*autotune_cache_mutex_); return Get(AlgorithmType::kConvForward);
// Set a small tolerance to avoid performance degradation }
// due to large cache size under dynamic shape.
if (miss_rate > 0.01) { AlgorithmsCacheMap& GetConvBackwardData() {
auto_tune_map_.clear(); return Get(AlgorithmType::kConvBackwardData);
} }
AlgorithmsCacheMap& GetConvBackwardFilter() {
return Get(AlgorithmType::kConvBackwardFilter);
} }
void UpdateStatus() { void Clean() {
int64_t size = 0;
int64_t cache_hits = 0;
int64_t cache_misses = 0;
for (auto& v : auto_tune_map_) { for (auto& v : auto_tune_map_) {
VLOG(4) << "AlgoType: " << v.first << " Cache Size: " << v.second.Size() v.second.Clean();
<< " Hits: " << v.second.CacheHits()
<< " Misses: " << v.second.CacheMisses()
<< " Hit Rate: " << v.second.CacheHitRate();
size += v.second.Size();
cache_hits += v.second.CacheHits();
cache_misses += v.second.CacheMisses();
} }
total_size_ = size;
total_cache_hits_ = cache_hits;
total_cache_misses_ = cache_misses;
} }
void UpdateStatus();
// The number of total config cached // The number of total config cached
int64_t Size() const { return total_size_; } int64_t Size() const { return total_size_; }
...@@ -183,17 +186,30 @@ class AutoTuneCache { ...@@ -183,17 +186,30 @@ class AutoTuneCache {
total_cache_hit_rate = static_cast<float>(total_cache_hits_) / total_cache_hit_rate = static_cast<float>(total_cache_hits_) /
static_cast<float>(total_num_accesses); static_cast<float>(total_num_accesses);
} }
return total_cache_hit_rate; return total_cache_hit_rate;
} }
private: private:
AutoTuneCache() : autotune_cache_mutex_(new std::mutex()) {} AutoTuneCache() : autotune_cache_mutex_(new std::mutex()) {
for (int i = 1; i < static_cast<int>(AlgorithmType::kAlgorithmCount); ++i) {
Register(static_cast<AlgorithmType>(i));
}
}
void Register(const AlgorithmType& algo_type) {
std::lock_guard<std::mutex> lock(*autotune_cache_mutex_);
int64_t key = static_cast<int64_t>(algo_type);
if (auto_tune_map_.find(key) == auto_tune_map_.end()) {
AlgorithmsCacheMap cache;
auto_tune_map_[key] = cache;
}
}
AlgorithmsTypeMap auto_tune_map_; AlgorithmsTypeMap auto_tune_map_;
std::shared_ptr<std::mutex> autotune_cache_mutex_; std::shared_ptr<std::mutex> autotune_cache_mutex_;
int64_t total_cache_hits_ = 0; int64_t total_cache_hits_{0};
int64_t total_cache_misses_ = 0; int64_t total_cache_misses_{0};
int64_t total_size_ = 0; int64_t total_size_{0};
}; };
} // namespace autotune } // namespace autotune
......
...@@ -22,7 +22,7 @@ enum ConvAlgos { GEMMKernel = 0, CuDNNKernel_1 = 1, CuDNNKernel_2 = 2 }; ...@@ -22,7 +22,7 @@ enum ConvAlgos { GEMMKernel = 0, CuDNNKernel_1 = 1, CuDNNKernel_2 = 2 };
TEST(AlgosCache, AlgosCache) { TEST(AlgosCache, AlgosCache) {
auto autotune_cache = phi::autotune::AutoTuneCache::Instance(); auto autotune_cache = phi::autotune::AutoTuneCache::Instance();
auto& cache = autotune_cache.RegisterOrGet("conv_fw"); auto& cache = autotune_cache.GetConvForward();
std::vector<int64_t> x_shape = {4, 224, 224, 3}; std::vector<int64_t> x_shape = {4, 224, 224, 3};
std::vector<int64_t> w_shape = {32, 3, 3, 3}; std::vector<int64_t> w_shape = {32, 3, 3, 3};
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/autotune/switch_autotune.h"
#include "gflags/gflags.h"
#include "glog/logging.h"
DECLARE_bool(use_autotune);
namespace phi {
namespace autotune {
void AutoTuneStatus::EnableAutoTune() {
FLAGS_use_autotune = true;
Init();
}
void AutoTuneStatus::DisableAutoTune() {
FLAGS_use_autotune = false;
Init();
}
void AutoTuneStatus::Update() {
current_steps_id_ += 1;
if (!FLAGS_use_autotune) {
return;
}
// This fuction is called when each iter finished.
if (current_steps_id_ + 1 < start_step_id_) {
use_autotune_ = false;
} else if (current_steps_id_ + 1 >= start_step_id_ &&
current_steps_id_ + 1 < stop_step_id_) {
use_autotune_ = true;
AutoTuneCache::Instance().UpdateStatus();
step_hit_rates_.push_back(StepHitRate());
VLOG(3) << "Step ID: " << current_steps_id_
<< ", Accumulative Cache Hit Rate: "
<< static_cast<int>(AutoTuneCache::Instance().CacheHitRate() * 100)
<< "%, Cache Size: " << AutoTuneCache::Instance().Size()
<< ", Current Step Hit Rate: "
<< static_cast<int>(StepHitRate() * 100) << "%";
} else {
use_autotune_ = false;
// Set a small tolerance to avoid performance degradation
// due to large cache size under dynamic shape.
// TODO(limingshu): Currently works for conv op only, this
// method shall be opimized when more ops involved in.
// float miss_rate = static_cast<float>(1) - RecentHitRate();
// if (current_steps_id_ == stop_step_id_) {
// AutoTuneCache::Instance().Clean(miss_rate);
// }
if (VLOG_IS_ON(4)) {
AutoTuneCache::Instance().UpdateStatus();
VLOG(4) << "Step ID: " << current_steps_id_ << ", Current Step Hit Rate: "
<< static_cast<int>(StepHitRate() * 100) << "%";
}
}
}
} // namespace autotune
} // namespace phi
...@@ -13,10 +13,8 @@ ...@@ -13,10 +13,8 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <cmath> #include <cmath>
#include <mutex>
#include <numeric>
#include "glog/logging.h"
#include "paddle/phi/kernels/autotune/cache.h" #include "paddle/phi/kernels/autotune/cache.h"
namespace phi { namespace phi {
...@@ -31,45 +29,11 @@ class AutoTuneStatus { ...@@ -31,45 +29,11 @@ class AutoTuneStatus {
bool UseAutoTune() { return use_autotune_; } bool UseAutoTune() { return use_autotune_; }
// EnableAutoTune and DisableAutoTune Should be used for debug only. // EnableAutoTune and DisableAutoTune should be used for debug only.
void EnableAutoTune() { void EnableAutoTune();
use_autotune_ = true; void DisableAutoTune();
Init();
}
void DisableAutoTune() {
use_autotune_ = false;
Init();
}
void Update() { void Update();
current_steps_id_ += 1;
if (!use_autotune_ && !update_use_autotune_) {
return;
}
if (current_steps_id_ < start_step_id_) {
use_autotune_ = false;
} else if (current_steps_id_ >= start_step_id_ &&
current_steps_id_ < stop_step_id_) {
use_autotune_ = true;
AutoTuneCache::Instance().UpdateStatus();
step_hit_rates_.push_back(StepHitRate());
VLOG(3) << "Step ID " << current_steps_id_
<< ", Accumulative Cache Hit Rate: "
<< AutoTuneCache::Instance().CacheHitRate()
<< ", Cache Size: " << AutoTuneCache::Instance().Size()
<< ", Current Step Hit Rate: " << StepHitRate();
} else if (current_steps_id_ == stop_step_id_) {
use_autotune_ = false;
update_use_autotune_ = false;
// clean cache according miss rate
float miss_rate = static_cast<float>(1) - RecentHitRate();
AutoTuneCache::Instance().Clean(miss_rate);
VLOG(3) << "Recent Miss Rate: " << miss_rate;
}
}
int64_t StepID() { return current_steps_id_; } int64_t StepID() { return current_steps_id_; }
...@@ -84,19 +48,25 @@ class AutoTuneStatus { ...@@ -84,19 +48,25 @@ class AutoTuneStatus {
// Hit Rate of Current Step // Hit Rate of Current Step
float StepHitRate() { float StepHitRate() {
int64_t current_hits = AutoTuneCache::Instance().CacheHits(); static int64_t last_step_id = -2;
int64_t current_misses = AutoTuneCache::Instance().CacheMisses();
int64_t step_hits_ = current_hits - previous_hits_; if (last_step_id != current_steps_id_) {
int64_t step_misses_ = current_misses - previous_misses_; int64_t current_hits = AutoTuneCache::Instance().CacheHits();
float step_hit_rate = 0.; int64_t current_misses = AutoTuneCache::Instance().CacheMisses();
int64_t step_num_accesses = step_hits_ + step_misses_; int64_t step_hits_ = current_hits - previous_hits_;
if (step_num_accesses != 0) { int64_t step_misses_ = current_misses - previous_misses_;
step_hit_rate = static_cast<float>(step_hits_) / float step_hit_rate = 0.;
static_cast<float>(step_num_accesses); int64_t step_num_accesses = step_hits_ + step_misses_;
if (step_num_accesses != 0) {
step_hit_rate = static_cast<float>(step_hits_) /
static_cast<float>(step_num_accesses);
}
previous_hits_ = current_hits;
previous_misses_ = current_misses;
current_step_hit_rate_ = step_hit_rate;
last_step_id = current_steps_id_;
} }
previous_hits_ = current_hits; return current_step_hit_rate_;
previous_misses_ = current_misses;
return step_hit_rate;
} }
void SetAutoTuneRange(int64_t start, int64_t stop) { void SetAutoTuneRange(int64_t start, int64_t stop) {
...@@ -108,21 +78,21 @@ class AutoTuneStatus { ...@@ -108,21 +78,21 @@ class AutoTuneStatus {
AutoTuneStatus() = default; AutoTuneStatus() = default;
void Init() { void Init() {
update_use_autotune_ = use_autotune_; use_autotune_ = false;
current_steps_id_ = -1; current_steps_id_ = -1;
previous_hits_ = 0; previous_hits_ = 0;
previous_misses_ = 0; previous_misses_ = 0;
step_hit_rates_.clear(); step_hit_rates_.clear();
AutoTuneCache::Instance().Clean(1.0); AutoTuneCache::Instance().Clean();
} }
int64_t start_step_id_ = 0; bool use_autotune_{false};
int64_t stop_step_id_ = 10; int64_t start_step_id_{1};
int64_t current_steps_id_ = -1; int64_t stop_step_id_{10};
bool use_autotune_ = false; int64_t current_steps_id_{-1};
bool update_use_autotune_ = false; int64_t previous_hits_{0};
int64_t previous_hits_ = 0; int64_t previous_misses_{0};
int64_t previous_misses_ = 0; float current_step_hit_rate_{0.f};
std::vector<float> step_hit_rates_; std::vector<float> step_hit_rates_;
}; };
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import paddle import paddle
import unittest import unittest
import numpy import numpy as np
class SimpleNet(paddle.nn.Layer): class SimpleNet(paddle.nn.Layer):
...@@ -27,6 +27,7 @@ class SimpleNet(paddle.nn.Layer): ...@@ -27,6 +27,7 @@ class SimpleNet(paddle.nn.Layer):
def train_dygraph(net, data): def train_dygraph(net, data):
data.stop_gradient = False
out = net(data) out = net(data)
loss = paddle.mean(out) loss = paddle.mean(out)
adam = paddle.optimizer.Adam(parameters=net.parameters()) adam = paddle.optimizer.Adam(parameters=net.parameters())
...@@ -36,6 +37,7 @@ def train_dygraph(net, data): ...@@ -36,6 +37,7 @@ def train_dygraph(net, data):
def static_program(net, data): def static_program(net, data):
data.stop_gradient = False
out = net(data) out = net(data)
loss = paddle.mean(out) loss = paddle.mean(out)
adam = paddle.optimizer.Adam() adam = paddle.optimizer.Adam()
...@@ -43,60 +45,64 @@ def static_program(net, data): ...@@ -43,60 +45,64 @@ def static_program(net, data):
return loss return loss
def set_flags(enable_autotune):
if paddle.is_compiled_with_cuda():
if enable_autotune:
paddle.set_flags({'FLAGS_conv_workspace_size_limit': -1})
paddle.set_flags({'FLAGS_cudnn_exhaustive_search': 1})
else:
paddle.set_flags({'FLAGS_conv_workspace_size_limit': 512})
paddle.set_flags({'FLAGS_cudnn_exhaustive_search': 0})
class TestAutoTune(unittest.TestCase): class TestAutoTune(unittest.TestCase):
def set_flags(self, enable_autotune):
if paddle.is_compiled_with_cuda():
if enable_autotune:
paddle.set_flags({'FLAGS_conv_workspace_size_limit': -1})
else:
paddle.set_flags({'FLAGS_conv_workspace_size_limit': 512})
def get_flags(self, name):
res = paddle.get_flags(name)
return res[name]
def get_expected_res(self, step_id, enable_autotune):
expected_res = {
"step_id": step_id,
"cache_size": 0,
"cache_hit_rate": 0
}
if paddle.is_compiled_with_cuda():
# Total 3 * num_iters cache accesses, only iter 2 hits the cache.
if enable_autotune and step_id >= 1:
expected_res["cache_size"] = 3
if enable_autotune and step_id == 2:
expected_res["cache_hit_rate"] = np.round(
float(3) / float(9), 5)
return expected_res
def test_autotune(self): def test_autotune(self):
paddle.fluid.core.disable_autotune() paddle.fluid.core.disable_autotune()
status = paddle.fluid.core.autotune_status() self.assertEqual(self.get_flags("FLAGS_use_autotune"), False)
self.assertEqual(status["use_autotune"], False)
paddle.fluid.core.enable_autotune() paddle.fluid.core.enable_autotune()
status = paddle.fluid.core.autotune_status() self.assertEqual(self.get_flags("FLAGS_use_autotune"), True)
self.assertEqual(status["use_autotune"], True)
def check_status(self, expected_res): def check_status(self, expected_res):
status = paddle.fluid.core.autotune_status() status = paddle.fluid.core.autotune_status()
for key in status.keys(): for key in status.keys():
self.assertEqual(status[key], expected_res[key]) if key == "cache_hit_rate":
v = np.round(status[key], 5)
else:
v = status[key]
self.assertEqual(v, expected_res[key])
class TestDygraphAutoTuneStatus(TestAutoTune): class TestDygraphAutoTuneStatus(TestAutoTune):
def run_program(self, enable_autotune): def run_program(self, enable_autotune):
set_flags(enable_autotune) self.set_flags(enable_autotune)
if enable_autotune: if enable_autotune:
paddle.fluid.core.enable_autotune() paddle.fluid.core.enable_autotune()
else: else:
paddle.fluid.core.disable_autotune() paddle.fluid.core.disable_autotune()
paddle.fluid.core.autotune_range(1, 2) paddle.fluid.core.set_autotune_range(1, 2)
x_var = paddle.uniform((1, 1, 8, 8), dtype='float32', min=-1., max=1.) x_var = paddle.uniform((1, 1, 8, 8), dtype='float32', min=-1., max=1.)
net = SimpleNet() net = SimpleNet()
for i in range(3): for i in range(3):
train_dygraph(net, x_var) train_dygraph(net, x_var)
if i >= 1 and i < 2: expected_res = self.get_expected_res(i, enable_autotune)
expected_res = { self.check_status(expected_res)
"step_id": i,
"use_autotune": enable_autotune,
"cache_size": 0,
"cache_hit_rate": 0
}
self.check_status(expected_res)
else:
expected_res = {
"step_id": i,
"use_autotune": False,
"cache_size": 0,
"cache_hit_rate": 0
}
self.check_status(expected_res)
def func_enable_autotune(self): def func_enable_autotune(self):
self.run_program(enable_autotune=True) self.run_program(enable_autotune=True)
...@@ -118,60 +124,45 @@ class TestDygraphAutoTuneStatus(TestAutoTune): ...@@ -118,60 +124,45 @@ class TestDygraphAutoTuneStatus(TestAutoTune):
class TestStaticAutoTuneStatus(TestAutoTune): class TestStaticAutoTuneStatus(TestAutoTune):
def run_program(self, enable_autotune): def run_program(self, enable_autotune):
paddle.enable_static() paddle.enable_static()
set_flags(enable_autotune)
if enable_autotune:
paddle.fluid.core.enable_autotune()
else:
paddle.fluid.core.disable_autotune()
paddle.fluid.core.autotune_range(1, 2)
data_shape = [1, 1, 8, 8] data_shape = [1, 1, 8, 8]
data = paddle.static.data(name='X', shape=data_shape, dtype='float32') main_program = paddle.static.Program()
net = SimpleNet() startup_program = paddle.static.Program()
loss = static_program(net, data) with paddle.static.program_guard(main_program, startup_program):
data = paddle.static.data(
name='X', shape=data_shape, dtype='float32')
net = SimpleNet()
loss = static_program(net, data)
place = paddle.CUDAPlace(0) if paddle.fluid.core.is_compiled_with_cuda( place = paddle.CUDAPlace(0) if paddle.fluid.core.is_compiled_with_cuda(
) else paddle.CPUPlace() ) else paddle.CPUPlace()
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program()) exe.run(startup_program)
x = numpy.random.random(size=data_shape).astype('float32') x = np.random.random(size=data_shape).astype('float32')
self.set_flags(enable_autotune)
if enable_autotune:
paddle.fluid.core.enable_autotune()
else:
paddle.fluid.core.disable_autotune()
paddle.fluid.core.set_autotune_range(1, 2)
for i in range(3): for i in range(3):
exe.run(feed={'X': x}, fetch_list=[loss]) exe.run(program=main_program, feed={'X': x}, fetch_list=[loss])
status = paddle.fluid.core.autotune_status() status = paddle.fluid.core.autotune_status()
# In static mode, the startup_program will run at first. expected_res = self.get_expected_res(i, enable_autotune)
# The expected step_id will be increased by 1. self.check_status(expected_res)
if i >= 0 and i < 1:
expected_res = {
"step_id": i + 1,
"use_autotune": enable_autotune,
"cache_size": 0,
"cache_hit_rate": 0
}
self.check_status(expected_res)
else:
expected_res = {
"step_id": i + 1,
"use_autotune": False,
"cache_size": 0,
"cache_hit_rate": 0
}
self.check_status(expected_res)
paddle.disable_static() paddle.disable_static()
def func_enable_autotune(self): def func_enable_autotune(self):
self.run_program(enable_autotune=True) self.run_program(enable_autotune=True)
def test_enable_autotune(self): def test_enable_autotune(self):
with paddle.fluid.framework._test_eager_guard():
self.func_enable_autotune()
self.func_enable_autotune() self.func_enable_autotune()
def func_disable_autotune(self): def func_disable_autotune(self):
self.run_program(enable_autotune=False) self.run_program(enable_autotune=False)
def test_disable_autotune(self): def test_disable_autotune(self):
with paddle.fluid.framework._test_eager_guard():
self.func_disable_autotune()
self.func_disable_autotune() self.func_disable_autotune()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册