Unverified commit 35acfeda authored by limingshu and committed by GitHub

Change cuDNN Conv kernel for auto tune feature (#41313)

* change cudnn helper for auto-tune

* Add FLAGS_use_autotune to set the global status of autotune and change the order of choosing algorithm.

* Fix the bug in calculating and printing current step cache hit rate.

* Improve the autotune cache and fix unittest.

* Change the key from AlgorithmType to int64_t.

* Fix unittest for cpu-only env.

* change ChooseAlgoByWorkspace for heuristic mode
Co-authored-by: Liu Yiqun <liuyiqun01@baidu.com>
Parent 10114859
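
Pieced together from the diffs below, the intended search order is: within the autotune step window, run the exhaustive search and cache the winning algorithm under a key derived from shapes and conv attributes; outside the window, serve cached results and fall back to heuristic mode (ChooseAlgoByWorkspace) on a miss. A minimal standalone sketch of that ordering; FindConvAlgo and the placeholder algorithm ids are illustrative, not Paddle APIs:

#include <cstddef>
#include <cstdint>
#include <unordered_map>

int64_t FindConvAlgo(size_t key, bool use_autotune,
                     std::unordered_map<size_t, int64_t>* cache) {
  auto it = cache->find(key);
  if (it != cache->end()) {
    return it->second;     // cache hit: reuse the tuned algorithm
  }
  if (use_autotune) {
    int64_t algo = 0;      // placeholder: exhaustively benchmark algorithms
    (*cache)[key] = algo;  // only tuned results enter the cache
    return algo;
  }
  return 1;                // placeholder: heuristic choice (workspace-based)
}
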
@@ -15,7 +15,7 @@ if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
add_subdirectory(pylayer)
cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulator)
add_dependencies(grad_tensor_holder eager_final_state_codegen)
cc_library(backward SRCS backward.cc DEPS grad_tensor_holder utils autograd_meta grad_node_info)
cc_library(backward SRCS backward.cc DEPS grad_tensor_holder utils autograd_meta grad_node_info switch_autotune)
endif()
cc_library(grad_node_info SRCS grad_node_info.cc DEPS phi_api phi_tensor)
......
@@ -9,8 +9,8 @@ cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_f
add_subdirectory(jit)
cc_library(amp SRCS amp_auto_cast.cc DEPS layer var_helper)
cc_library(tracer SRCS tracer.cc DEPS layer engine program_desc_tracer amp denormal garbage_collector var_helper)
cc_library(basic_engine SRCS basic_engine.cc DEPS layer gradient_accumulator)
cc_library(engine SRCS basic_engine.cc partial_grad_engine.cc DEPS layer gradient_accumulator)
cc_library(basic_engine SRCS basic_engine.cc DEPS layer gradient_accumulator switch_autotune)
cc_library(engine SRCS basic_engine.cc partial_grad_engine.cc DEPS layer gradient_accumulator switch_autotune)
cc_library(imperative_profiler SRCS profiler.cc DEPS flags)
if(NOT WIN32)
if(WITH_NCCL OR WITH_RCCL)
......
@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/conv_search_cache.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/autotune/cache.h"
namespace paddle {
namespace operators {
@@ -41,12 +42,22 @@ struct SearchAlgorithm {};
// As the container of SearchAlgorithm::Find() result.
template <typename AlgoT>
struct SearchResult {
public:
SearchResult() {}
explicit SearchResult(AlgoT a) : algo(a) {}
AlgoT algo = static_cast<AlgoT>(0);
float time = -1.f;
size_t workspace_size = 0;
};
template <typename T>
static std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
out << "[";
for (auto const& tmp : v) out << tmp << ",";
out << "]";
return out;
}
// As the container of conv relevant descriptors.
template <typename HandleT, typename DataT>
struct ConvArgsBase {
@@ -68,6 +79,17 @@ struct ConvArgsBase {
const framework::Tensor* o, const std::vector<int> s,
const std::vector<int> p, const std::vector<int> d, DataT dtype)
: x(x), w(w), o(o), s(s), p(p), d(d), cudnn_dtype(dtype) {}
template <typename T>
size_t GetCacheKey() const {
auto x_shape = phi::vectorize(x->dims());
auto w_shape = phi::vectorize(w->dims());
VLOG(10) << "[ConvArgs] x_dims=" << x_shape << ", w_dims=" << w_shape
<< ", strides=" << s << ", paddings=" << p << ", dilations=" << d;
return phi::autotune::ConvKey(
x_shape, w_shape, p, s, d,
paddle::experimental::CppTypeToDataType<T>::Type());
}
};
static inline void GetNCDHW(const framework::DDim& dims,
@@ -87,13 +109,5 @@ static inline void GetNCDHW(const framework::DDim& dims,
}
}
template <typename T>
static std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
out << "[";
for (auto const& tmp : v) out << tmp << ",";
out << "]";
return out;
}
} // namespace operators
} // namespace paddle
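
GetCacheKey() above funnels every value that can change the algorithm decision into phi::autotune::ConvKey(), whose body is only partially visible in this excerpt. A standalone sketch of the general idea, assuming a boost-style hash combiner (ToyConvKey and HashCombine are illustrative names, not the real implementation):

#include <cstddef>
#include <cstdint>
#include <vector>

static size_t HashCombine(size_t seed, size_t v) {
  // Classic boost::hash_combine mixing step.
  return seed ^ (v + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

size_t ToyConvKey(const std::vector<int64_t>& x_dims,
                  const std::vector<int64_t>& w_dims,
                  const std::vector<int>& paddings,
                  const std::vector<int>& strides,
                  const std::vector<int>& dilations, int64_t dtype) {
  size_t seed = 0;
  for (auto d : x_dims) seed = HashCombine(seed, static_cast<size_t>(d));
  for (auto d : w_dims) seed = HashCombine(seed, static_cast<size_t>(d));
  for (auto p : paddings) seed = HashCombine(seed, static_cast<size_t>(p));
  for (auto s : strides) seed = HashCombine(seed, static_cast<size_t>(s));
  for (auto d : dilations) seed = HashCombine(seed, static_cast<size_t>(d));
  return HashCombine(seed, static_cast<size_t>(dtype));
}
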
@@ -774,3 +774,12 @@ DEFINE_bool(enable_ins_parser_file, false,
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PADDLE_DEFINE_EXPORTED_bool(nccl_blocking_wait, false, "nccl blocking wait");
#endif
/**
* Autotune related FLAG
* Name: FLAGS_use_autotune
* Since Version: 2.3.0
* Value Range: bool, default=false
* Example: FLAGS_use_autotune=true would enable autotune.
*/
PADDLE_DEFINE_EXPORTED_bool(use_autotune, false, "Whether to enable autotune.");
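
For context, the flag is consumed from C++ via the usual gflags pattern; switch_autotune.cc below declares it with DECLARE_bool and flips it in EnableAutoTune()/DisableAutoTune(). A minimal sketch of a reader (ShouldTune is an illustrative name):

#include "gflags/gflags.h"

DECLARE_bool(use_autotune);  // pairs with the PADDLE_DEFINE_EXPORTED_bool above

bool ShouldTune() {
  // EnableAutoTune()/DisableAutoTune() in switch_autotune.cc set this flag.
  return FLAGS_use_autotune;
}
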
@@ -4469,7 +4469,7 @@ All parameter, weight, gradient are variables in Paddle.
return phi::autotune::AutoTuneStatus::Instance().DisableAutoTune();
});
m.def("autotune_range", [](int64_t start, int64_t stop) {
m.def("set_autotune_range", [](int64_t start, int64_t stop) {
return phi::autotune::AutoTuneStatus::Instance().SetAutoTuneRange(start,
stop);
});
@@ -4478,10 +4478,8 @@ All parameter, weight, gradient are variables in Paddle.
[] { return phi::autotune::AutoTuneStatus::Instance().Update(); });
m.def("autotune_status", [] {
phi::autotune::AutoTuneCache::Instance().UpdateStatus();
py::dict res;
res["use_autotune"] =
phi::autotune::AutoTuneStatus::Instance().UseAutoTune();
phi::autotune::AutoTuneCache::Instance().UpdateStatus();
res["step_id"] = phi::autotune::AutoTuneStatus::Instance().StepID();
res["cache_size"] = phi::autotune::AutoTuneCache::Instance().Size();
res["cache_hit_rate"] =
......
@@ -6,12 +6,15 @@ file(APPEND ${kernel_declare_file} "#include \"paddle/phi/core/kernel_registry.h
# phi functors and functions called by kernels
add_subdirectory(funcs)
# kernel autotune
add_subdirectory(autotune)
# phi depends all phi kernel targets
set_property(GLOBAL PROPERTY PHI_KERNELS "")
# [ 1. Common kernel compilation dependencies ]
set(COMMON_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils custom_kernel)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col vol2col concat_and_split_functor selected_rows_functor )
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col vol2col concat_and_split_functor selected_rows_functor)
# remove this dep after removing fluid deps on tensor creation
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_api_utils)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta)
@@ -27,12 +30,16 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel)
# Some kernels depend on some targets that are not commonly used.
# These targets are not suitable for common dependencies.
# In this case, you need to manually generate them here.
set(MANUAL_BUILD_KERNELS cross_entropy_kernel adam_kernel adamw_kernel deformable_conv_kernel deformable_conv_grad_kernel eigh_kernel
set(AUTOTUNE_KERNELS conv_kernel conv_grad_kernel conv_grad_grad_kernel conv_transpose_kernel conv_transpose_grad_kernel)
set(MANUAL_BUILD_KERNELS ${AUTOTUNE_KERNELS} cross_entropy_kernel adam_kernel adamw_kernel deformable_conv_kernel deformable_conv_grad_kernel eigh_kernel
gumbel_softmax_kernel gumbel_softmax_grad_kernel hierarchical_sigmoid_kernel hierarchical_sigmoid_grad_kernel
matrix_power_kernel matrix_power_grad_kernel maxout_kernel maxout_grad_kernel pool_kernel
put_along_axis_kernel put_along_axis_grad_kernel segment_pool_kernel segment_pool_grad_kernel
softmax_kernel softmax_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel
triangular_solve_grad_kernel determinant_grad_kernel reduce_kernel rnn_kernel rnn_grad_kernel warpctc_kernel warpctc_grad_kernel)
foreach(src ${AUTOTUNE_KERNELS})
kernel_library(${src} DEPS ${COMMON_KERNEL_DEPS} switch_autotune)
endforeach()
kernel_library(adam_kernel DEPS gflags glog flags ${COMMON_KERNEL_DEPS} selected_rows_functor threadpool jit_kernel_helper)
kernel_library(adamw_kernel DEPS ${COMMON_KERNEL_DEPS} adam_kernel)
kernel_library(cross_entropy_kernel DEPS ${COMMON_KERNEL_DEPS} softmax cross_entropy)
@@ -75,6 +82,3 @@ add_subdirectory(selected_rows)
copy_if_different(${kernel_declare_file} ${kernel_declare_file_final})
# For strings kernels
add_subdirectory(strings)
# 5. kernel autotune
add_subdirectory(autotune)
@@ -7,5 +7,6 @@ elseif (WITH_ROCM)
endif()
cc_library(cache SRCS cache.cc DEPS boost)
cc_library(switch_autotune SRCS switch_autotune.cc DEPS cache flags)
cc_test(cache_test SRCS cache_test.cc DEPS gtest cache)
@@ -13,6 +13,8 @@
// limitations under the License.
#include "paddle/phi/kernels/autotune/cache.h"
#include <iomanip>
#include "glog/logging.h"
namespace phi {
namespace autotune {
@@ -32,5 +34,40 @@ size_t ConvKey(const std::vector<int64_t>& x_dims,
static_cast<int64_t>(dtype));
}
std::string AlgorithmTypeString(int64_t algo_type) {
if (algo_type == static_cast<int64_t>(AlgorithmType::kConvForward)) {
return "conv_forward";
} else if (algo_type ==
static_cast<int64_t>(AlgorithmType::kConvBackwardData)) {
return "conv_backward_data";
} else if (algo_type ==
static_cast<int64_t>(AlgorithmType::kConvBackwardFilter)) {
return "conv_backward_filter";
}
return std::to_string(algo_type);
}
void AutoTuneCache::UpdateStatus() {
int64_t size = 0;
int64_t cache_hits = 0;
int64_t cache_misses = 0;
int name_width = 24;
std::cout.setf(std::ios::left);
for (auto& v : auto_tune_map_) {
VLOG(4) << "AlgoType: " << std::setfill(' ') << std::setw(name_width)
<< AlgorithmTypeString(v.first)
<< " Cache Size: " << v.second.Size()
<< " Hits: " << v.second.CacheHits()
<< " Misses: " << v.second.CacheMisses()
<< " Hit Rate: " << v.second.CacheHitRate();
size += v.second.Size();
cache_hits += v.second.CacheHits();
cache_misses += v.second.CacheMisses();
}
total_size_ = size;
total_cache_hits_ = cache_hits;
total_cache_misses_ = cache_misses;
}
} // namespace autotune
} // namespace phi
@@ -13,11 +13,12 @@
// limitations under the License.
#pragma once
#include <algorithm>
#include <mutex>
#include <numeric>
#include <unordered_map>
#include <vector>
#include "glog/logging.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
@@ -92,6 +93,13 @@ class AlgorithmsCache {
return ret;
}
void Clean() {
std::lock_guard<std::mutex> lock(*cache_mutex_);
hash_.clear();
cache_hits_ = 0;
cache_misses_ = 0;
}
void Set(size_t key, AlgorithmT algo) {
std::lock_guard<std::mutex> lock(*cache_mutex_);
hash_[key] = algo;
@@ -116,15 +124,22 @@ class AlgorithmsCache {
private:
std::unordered_map<size_t, AlgorithmT> hash_;
std::shared_ptr<std::mutex> cache_mutex_;
int64_t cache_hits_ = 0;
int64_t cache_misses_ = 0;
int64_t cache_hits_{0};
int64_t cache_misses_{0};
};
enum class AlgorithmType {
kConvForward = 1,
kConvBackwardData = 2,
kConvBackwardFilter = 3,
kAlgorithmCount = 4
};
// AlgorithmsConfigKey -> AlgorithmsID
using AlgorithmsConfigKeyMap = AlgorithmsCache<int64_t>;
// AlgorithmsType -> AlgorithmsCache
using AlgorithmsTypeMap =
std::unordered_map<std::string, AlgorithmsConfigKeyMap>;
using AlgorithmsCacheMap = AlgorithmsCache<int64_t>;
// AlgorithmType -> AlgorithmsCache
using AlgorithmsTypeMap = std::unordered_map<int64_t, AlgorithmsCacheMap>;
class AutoTuneCache {
public:
@@ -133,42 +148,30 @@ class AutoTuneCache {
return autotune_cache;
}
AlgorithmsConfigKeyMap& RegisterOrGet(const std::string& algo_type) {
std::lock_guard<std::mutex> lock(*autotune_cache_mutex_);
if (auto_tune_map_.find(algo_type) == auto_tune_map_.end()) {
AlgorithmsConfigKeyMap cache;
auto_tune_map_[algo_type] = cache;
AlgorithmsCacheMap& Get(const AlgorithmType& algo_type) {
return auto_tune_map_[static_cast<int64_t>(algo_type)];
}
return auto_tune_map_[algo_type];
AlgorithmsCacheMap& GetConvForward() {
return Get(AlgorithmType::kConvForward);
}
void Clean(float miss_rate) {
std::lock_guard<std::mutex> lock(*autotune_cache_mutex_);
// Set a small tolerance to avoid performance degradation
// due to large cache size under dynamic shape.
if (miss_rate > 0.01) {
auto_tune_map_.clear();
AlgorithmsCacheMap& GetConvBackwardData() {
return Get(AlgorithmType::kConvBackwardData);
}
AlgorithmsCacheMap& GetConvBackwardFilter() {
return Get(AlgorithmType::kConvBackwardFilter);
}
void UpdateStatus() {
int64_t size = 0;
int64_t cache_hits = 0;
int64_t cache_misses = 0;
void Clean() {
for (auto& v : auto_tune_map_) {
VLOG(4) << "AlgoType: " << v.first << " Cache Size: " << v.second.Size()
<< " Hits: " << v.second.CacheHits()
<< " Misses: " << v.second.CacheMisses()
<< " Hit Rate: " << v.second.CacheHitRate();
size += v.second.Size();
cache_hits += v.second.CacheHits();
cache_misses += v.second.CacheMisses();
v.second.Clean();
}
total_size_ = size;
total_cache_hits_ = cache_hits;
total_cache_misses_ = cache_misses;
}
void UpdateStatus();
// The total number of cached configs
int64_t Size() const { return total_size_; }
@@ -183,17 +186,30 @@ class AutoTuneCache {
total_cache_hit_rate = static_cast<float>(total_cache_hits_) /
static_cast<float>(total_num_accesses);
}
return total_cache_hit_rate;
}
private:
AutoTuneCache() : autotune_cache_mutex_(new std::mutex()) {}
AutoTuneCache() : autotune_cache_mutex_(new std::mutex()) {
for (int i = 1; i < static_cast<int>(AlgorithmType::kAlgorithmCount); ++i) {
Register(static_cast<AlgorithmType>(i));
}
}
void Register(const AlgorithmType& algo_type) {
std::lock_guard<std::mutex> lock(*autotune_cache_mutex_);
int64_t key = static_cast<int64_t>(algo_type);
if (auto_tune_map_.find(key) == auto_tune_map_.end()) {
AlgorithmsCacheMap cache;
auto_tune_map_[key] = cache;
}
}
AlgorithmsTypeMap auto_tune_map_;
std::shared_ptr<std::mutex> autotune_cache_mutex_;
int64_t total_cache_hits_ = 0;
int64_t total_cache_misses_ = 0;
int64_t total_size_ = 0;
int64_t total_cache_hits_{0};
int64_t total_cache_misses_{0};
int64_t total_size_{0};
};
} // namespace autotune
......
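
The hit/miss counters aggregated by UpdateStatus() are maintained inside AlgorithmsCache, whose lookup path is elided above. A condensed standalone sketch of that bookkeeping, assuming a find-style accessor (ToyAlgorithmsCache and FindOrMiss are illustrative names):

#include <cstddef>
#include <cstdint>
#include <mutex>
#include <unordered_map>

class ToyAlgorithmsCache {
 public:
  void Set(size_t key, int64_t algo) {
    std::lock_guard<std::mutex> lock(mutex_);
    hash_[key] = algo;
  }
  // Returns true and fills *algo on a hit; every call bumps a counter.
  bool FindOrMiss(size_t key, int64_t* algo) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = hash_.find(key);
    if (it == hash_.end()) {
      ++cache_misses_;
      return false;
    }
    ++cache_hits_;
    *algo = it->second;
    return true;
  }
  float CacheHitRate() const {  // unsynchronized read, fine for logging
    int64_t total = cache_hits_ + cache_misses_;
    return total == 0 ? 0.f : static_cast<float>(cache_hits_) / total;
  }

 private:
  std::unordered_map<size_t, int64_t> hash_;
  std::mutex mutex_;
  int64_t cache_hits_{0};
  int64_t cache_misses_{0};
};
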
@@ -22,7 +22,7 @@ enum ConvAlgos { GEMMKernel = 0, CuDNNKernel_1 = 1, CuDNNKernel_2 = 2 };
TEST(AlgosCache, AlgosCache) {
auto& autotune_cache = phi::autotune::AutoTuneCache::Instance();
auto& cache = autotune_cache.RegisterOrGet("conv_fw");
auto& cache = autotune_cache.GetConvForward();
std::vector<int64_t> x_shape = {4, 224, 224, 3};
std::vector<int64_t> w_shape = {32, 3, 3, 3};
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/autotune/switch_autotune.h"
#include "gflags/gflags.h"
#include "glog/logging.h"
DECLARE_bool(use_autotune);
namespace phi {
namespace autotune {
void AutoTuneStatus::EnableAutoTune() {
FLAGS_use_autotune = true;
Init();
}
void AutoTuneStatus::DisableAutoTune() {
FLAGS_use_autotune = false;
Init();
}
void AutoTuneStatus::Update() {
current_steps_id_ += 1;
if (!FLAGS_use_autotune) {
return;
}
// This function is called when each iteration finishes.
if (current_steps_id_ + 1 < start_step_id_) {
use_autotune_ = false;
} else if (current_steps_id_ + 1 >= start_step_id_ &&
current_steps_id_ + 1 < stop_step_id_) {
use_autotune_ = true;
AutoTuneCache::Instance().UpdateStatus();
step_hit_rates_.push_back(StepHitRate());
VLOG(3) << "Step ID: " << current_steps_id_
<< ", Accumulative Cache Hit Rate: "
<< static_cast<int>(AutoTuneCache::Instance().CacheHitRate() * 100)
<< "%, Cache Size: " << AutoTuneCache::Instance().Size()
<< ", Current Step Hit Rate: "
<< static_cast<int>(StepHitRate() * 100) << "%";
} else {
use_autotune_ = false;
// Set a small tolerance to avoid performance degradation
// due to large cache size under dynamic shape.
// TODO(limingshu): Currently this works for conv ops only; the
// method should be optimized when more ops are involved.
// float miss_rate = static_cast<float>(1) - RecentHitRate();
// if (current_steps_id_ == stop_step_id_) {
// AutoTuneCache::Instance().Clean(miss_rate);
// }
if (VLOG_IS_ON(4)) {
AutoTuneCache::Instance().UpdateStatus();
VLOG(4) << "Step ID: " << current_steps_id_ << ", Current Step Hit Rate: "
<< static_cast<int>(StepHitRate() * 100) << "%";
}
}
}
} // namespace autotune
} // namespace phi
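
Update() is called once per finished iteration, producing a tune-then-freeze schedule. A standalone walkthrough of just the window arithmetic above (the flag check and cache bookkeeping are omitted):

#include <cstdint>
#include <iostream>

int main() {
  int64_t start_step_id = 1, stop_step_id = 2;  // SetAutoTuneRange(1, 2)
  int64_t current_steps_id = -1;
  for (int iter = 0; iter < 3; ++iter) {
    current_steps_id += 1;  // Update() increments before the range check
    bool use_autotune = current_steps_id + 1 >= start_step_id &&
                        current_steps_id + 1 < stop_step_id;
    std::cout << "step " << current_steps_id << ": use_autotune = "
              << std::boolalpha << use_autotune << "\n";
  }
  // Prints true only for step 0, the sole step whose (id + 1) falls in
  // the [start, stop) window.
  return 0;
}
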
@@ -13,10 +13,8 @@
// limitations under the License.
#pragma once
#include <cmath>
#include <mutex>
#include <numeric>
#include "glog/logging.h"
#include "paddle/phi/kernels/autotune/cache.h"
namespace phi {
@@ -31,45 +29,11 @@ class AutoTuneStatus {
bool UseAutoTune() { return use_autotune_; }
// EnableAutoTune and DisableAutoTune Should be used for debug only.
void EnableAutoTune() {
use_autotune_ = true;
Init();
}
void DisableAutoTune() {
use_autotune_ = false;
Init();
}
void Update() {
current_steps_id_ += 1;
// EnableAutoTune and DisableAutoTune should be used for debug only.
void EnableAutoTune();
void DisableAutoTune();
if (!use_autotune_ && !update_use_autotune_) {
return;
}
if (current_steps_id_ < start_step_id_) {
use_autotune_ = false;
} else if (current_steps_id_ >= start_step_id_ &&
current_steps_id_ < stop_step_id_) {
use_autotune_ = true;
AutoTuneCache::Instance().UpdateStatus();
step_hit_rates_.push_back(StepHitRate());
VLOG(3) << "Step ID " << current_steps_id_
<< ", Accumulative Cache Hit Rate: "
<< AutoTuneCache::Instance().CacheHitRate()
<< ", Cache Size: " << AutoTuneCache::Instance().Size()
<< ", Current Step Hit Rate: " << StepHitRate();
} else if (current_steps_id_ == stop_step_id_) {
use_autotune_ = false;
update_use_autotune_ = false;
// clean cache according miss rate
float miss_rate = static_cast<float>(1) - RecentHitRate();
AutoTuneCache::Instance().Clean(miss_rate);
VLOG(3) << "Recent Miss Rate: " << miss_rate;
}
}
void Update();
int64_t StepID() { return current_steps_id_; }
@@ -84,6 +48,9 @@ class AutoTuneStatus {
// Hit Rate of Current Step
float StepHitRate() {
static int64_t last_step_id = -2;
if (last_step_id != current_steps_id_) {
int64_t current_hits = AutoTuneCache::Instance().CacheHits();
int64_t current_misses = AutoTuneCache::Instance().CacheMisses();
int64_t step_hits_ = current_hits - previous_hits_;
@@ -96,7 +63,10 @@
}
previous_hits_ = current_hits;
previous_misses_ = current_misses;
return step_hit_rate;
current_step_hit_rate_ = step_hit_rate;
last_step_id = current_steps_id_;
}
return current_step_hit_rate_;
}
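
Note that StepHitRate() reports only the accesses made since the previous step, while CacheHitRate() is cumulative. A worked example using the unit test's numbers below (3 conv caches accessed per iteration; entries written during iteration 1 and hit in iteration 2; assuming every lookup is counted, which the test's 3/9 expectation implies):

    after iteration 1: hits = 0, misses = 6
    after iteration 2: hits = 3, misses = 6
    StepHitRate at iteration 2 = (3 - 0) / ((3 + 6) - (0 + 6)) = 3 / 3 = 1.0
    CacheHitRate (cumulative)  = 3 / (3 + 6) = 1 / 3 ≈ 0.33333
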
void SetAutoTuneRange(int64_t start, int64_t stop) {
@@ -108,21 +78,21 @@
AutoTuneStatus() = default;
void Init() {
update_use_autotune_ = use_autotune_;
use_autotune_ = false;
current_steps_id_ = -1;
previous_hits_ = 0;
previous_misses_ = 0;
step_hit_rates_.clear();
AutoTuneCache::Instance().Clean(1.0);
AutoTuneCache::Instance().Clean();
}
int64_t start_step_id_ = 0;
int64_t stop_step_id_ = 10;
int64_t current_steps_id_ = -1;
bool use_autotune_ = false;
bool update_use_autotune_ = false;
int64_t previous_hits_ = 0;
int64_t previous_misses_ = 0;
bool use_autotune_{false};
int64_t start_step_id_{1};
int64_t stop_step_id_{10};
int64_t current_steps_id_{-1};
int64_t previous_hits_{0};
int64_t previous_misses_{0};
float current_step_hit_rate_{0.f};
std::vector<float> step_hit_rates_;
};
......
@@ -14,7 +14,7 @@
import paddle
import unittest
import numpy
import numpy as np
class SimpleNet(paddle.nn.Layer):
@@ -27,6 +27,7 @@ class SimpleNet(paddle.nn.Layer):
def train_dygraph(net, data):
data.stop_gradient = False
out = net(data)
loss = paddle.mean(out)
adam = paddle.optimizer.Adam(parameters=net.parameters())
@@ -36,6 +37,7 @@ def train_dygraph(net, data):
def static_program(net, data):
data.stop_gradient = False
out = net(data)
loss = paddle.mean(out)
adam = paddle.optimizer.Adam()
@@ -43,59 +45,63 @@ def static_program(net, data):
return loss
def set_flags(enable_autotune):
class TestAutoTune(unittest.TestCase):
def set_flags(self, enable_autotune):
if paddle.is_compiled_with_cuda():
if enable_autotune:
paddle.set_flags({'FLAGS_conv_workspace_size_limit': -1})
paddle.set_flags({'FLAGS_cudnn_exhaustive_search': 1})
else:
paddle.set_flags({'FLAGS_conv_workspace_size_limit': 512})
paddle.set_flags({'FLAGS_cudnn_exhaustive_search': 0})
def get_flags(self, name):
res = paddle.get_flags(name)
return res[name]
def get_expected_res(self, step_id, enable_autotune):
expected_res = {
"step_id": step_id,
"cache_size": 0,
"cache_hit_rate": 0
}
if paddle.is_compiled_with_cuda():
# 3 cache accesses per iteration (3 * num_iters in total); only the
# 3 accesses in iteration 2 hit the cache.
if enable_autotune and step_id >= 1:
expected_res["cache_size"] = 3
if enable_autotune and step_id == 2:
expected_res["cache_hit_rate"] = np.round(
float(3) / float(9), 5)
return expected_res
class TestAutoTune(unittest.TestCase):
def test_autotune(self):
paddle.fluid.core.disable_autotune()
status = paddle.fluid.core.autotune_status()
self.assertEqual(status["use_autotune"], False)
self.assertEqual(self.get_flags("FLAGS_use_autotune"), False)
paddle.fluid.core.enable_autotune()
status = paddle.fluid.core.autotune_status()
self.assertEqual(status["use_autotune"], True)
self.assertEqual(self.get_flags("FLAGS_use_autotune"), True)
def check_status(self, expected_res):
status = paddle.fluid.core.autotune_status()
for key in status.keys():
self.assertEqual(status[key], expected_res[key])
if key == "cache_hit_rate":
v = np.round(status[key], 5)
else:
v = status[key]
self.assertEqual(v, expected_res[key])
class TestDygraphAutoTuneStatus(TestAutoTune):
def run_program(self, enable_autotune):
set_flags(enable_autotune)
self.set_flags(enable_autotune)
if enable_autotune:
paddle.fluid.core.enable_autotune()
else:
paddle.fluid.core.disable_autotune()
paddle.fluid.core.autotune_range(1, 2)
paddle.fluid.core.set_autotune_range(1, 2)
x_var = paddle.uniform((1, 1, 8, 8), dtype='float32', min=-1., max=1.)
net = SimpleNet()
for i in range(3):
train_dygraph(net, x_var)
if i >= 1 and i < 2:
expected_res = {
"step_id": i,
"use_autotune": enable_autotune,
"cache_size": 0,
"cache_hit_rate": 0
}
self.check_status(expected_res)
else:
expected_res = {
"step_id": i,
"use_autotune": False,
"cache_size": 0,
"cache_hit_rate": 0
}
expected_res = self.get_expected_res(i, enable_autotune)
self.check_status(expected_res)
def func_enable_autotune(self):
@@ -118,43 +124,32 @@ class TestDygraphAutoTuneStatus(TestAutoTune):
class TestStaticAutoTuneStatus(TestAutoTune):
def run_program(self, enable_autotune):
paddle.enable_static()
set_flags(enable_autotune)
if enable_autotune:
paddle.fluid.core.enable_autotune()
else:
paddle.fluid.core.disable_autotune()
paddle.fluid.core.autotune_range(1, 2)
data_shape = [1, 1, 8, 8]
data = paddle.static.data(name='X', shape=data_shape, dtype='float32')
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
data = paddle.static.data(
name='X', shape=data_shape, dtype='float32')
net = SimpleNet()
loss = static_program(net, data)
place = paddle.CUDAPlace(0) if paddle.fluid.core.is_compiled_with_cuda(
) else paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
x = numpy.random.random(size=data_shape).astype('float32')
exe.run(startup_program)
x = np.random.random(size=data_shape).astype('float32')
self.set_flags(enable_autotune)
if enable_autotune:
paddle.fluid.core.enable_autotune()
else:
paddle.fluid.core.disable_autotune()
paddle.fluid.core.set_autotune_range(1, 2)
for i in range(3):
exe.run(feed={'X': x}, fetch_list=[loss])
exe.run(program=main_program, feed={'X': x}, fetch_list=[loss])
status = paddle.fluid.core.autotune_status()
# In static mode, the startup_program runs first, so the expected
# step_id is increased by 1.
if i >= 0 and i < 1:
expected_res = {
"step_id": i + 1,
"use_autotune": enable_autotune,
"cache_size": 0,
"cache_hit_rate": 0
}
self.check_status(expected_res)
else:
expected_res = {
"step_id": i + 1,
"use_autotune": False,
"cache_size": 0,
"cache_hit_rate": 0
}
expected_res = self.get_expected_res(i, enable_autotune)
self.check_status(expected_res)
paddle.disable_static()
@@ -162,16 +157,12 @@ class TestStaticAutoTuneStatus(TestAutoTune):
self.run_program(enable_autotune=True)
def test_enable_autotune(self):
with paddle.fluid.framework._test_eager_guard():
self.func_enable_autotune()
self.func_enable_autotune()
def func_disable_autotune(self):
self.run_program(enable_autotune=False)
def test_disable_autotune(self):
with paddle.fluid.framework._test_eager_guard():
self.func_disable_autotune()
self.func_disable_autotune()
......