未验证 提交 1cd7e68b 编写于 作者: H hong 提交者: GitHub

optimize conv algo cache (#41891)

* optimizer conv alog speed

* code polish

* remove useless code

* fix compile error

* fix cpu compile error

* not use cudnn alog t

* add search cache max number

* polish code

* fix cache test bug

* add groups data format to conv args

* fix cache test bug

* fix cudnn_deterministic bug

* fix test switch auto tune bug

* fix test swith autotune bug;

* fix conv cache bug

* fix cache test error

* fix cache test bug

* fix windows mac compile error

* fix workspace search error

* update cudnn cache

* fix cache test bug; test=develop

* fix autotune swith test error

* polish code

* oplish code
上级 f2f3f6e7
...@@ -44,7 +44,13 @@ struct SearchAlgorithm {}; ...@@ -44,7 +44,13 @@ struct SearchAlgorithm {};
template <typename AlgoT> template <typename AlgoT>
struct SearchResult { struct SearchResult {
SearchResult() {} SearchResult() {}
explicit SearchResult(const phi::autotune::DnnNode& node)
: algo(static_cast<AlgoT>(node.algo)),
workspace_size(node.workspace_size) {}
explicit SearchResult(AlgoT a) : algo(a) {} explicit SearchResult(AlgoT a) : algo(a) {}
explicit SearchResult(AlgoT a, float t, size_t size)
: algo(a), time(t), workspace_size(size) {}
AlgoT algo = static_cast<AlgoT>(0); AlgoT algo = static_cast<AlgoT>(0);
float time = -1.f; float time = -1.f;
...@@ -76,28 +82,50 @@ struct ConvArgsBase { ...@@ -76,28 +82,50 @@ struct ConvArgsBase {
// dilations // dilations
std::vector<int> d; std::vector<int> d;
// groups
int group;
// data foramt
DataLayout data_layout;
ConvArgsBase(const framework::Tensor* x, ConvArgsBase(const framework::Tensor* x,
const framework::Tensor* w, const framework::Tensor* w,
const framework::Tensor* o, const framework::Tensor* o,
const std::vector<int> s, const std::vector<int> s,
const std::vector<int> p, const std::vector<int> p,
const std::vector<int> d, const std::vector<int> d,
DataT dtype) DataT dtype,
: x(x), w(w), o(o), s(s), p(p), d(d), cudnn_dtype(dtype) {} int g,
DataLayout layout)
: x(x),
w(w),
o(o),
s(s),
p(p),
d(d),
cudnn_dtype(dtype),
group(g),
data_layout(layout) {}
template <typename T> template <typename T>
size_t GetCacheKey() const { phi::autotune::ConvCacheKey Convert2ConvCacheKey() const {
auto x_shape = phi::vectorize(x->dims()); auto x_shape = phi::vectorize(x->dims());
auto w_shape = phi::vectorize(w->dims()); auto w_shape = phi::vectorize(w->dims());
VLOG(10) << "[ConvArgs] x_dims=" << x_shape << ", w_dims=" << w_shape VLOG(10) << "[ConvArgs] x_dims=" << x_shape << ", w_dims=" << w_shape
<< ", strides=" << s << ", paddings=" << p << ", dilations=" << d; << ", strides=" << s << ", paddings=" << p << ", dilations=" << d
return phi::autotune::ConvKey( << ",data= " << paddle::experimental::CppTypeToDataType<T>::Type()
<< ", group=" << group
<< ", data layout=" << static_cast<int64_t>(data_layout);
return phi::autotune::ConvCacheKey(
x_shape, x_shape,
w_shape, w_shape,
p, p,
s, s,
d, d,
paddle::experimental::CppTypeToDataType<T>::Type()); paddle::experimental::CppTypeToDataType<T>::Type(),
group,
static_cast<int64_t>(data_layout));
} }
}; };
......
...@@ -191,32 +191,36 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> { ...@@ -191,32 +191,36 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
SetConvMathType(ctx, dtype, args.cdesc); SetConvMathType(ctx, dtype, args.cdesc);
if (deterministic) { if (deterministic) {
result = FindAlgoDeterministic(); result = FindAlgoDeterministic(args);
} else { } else {
// 1. Once turning on exhaustive FLAGS, always get exhaustive_search. // 1. Once turning on exhaustive FLAGS, always get exhaustive_search.
// 2. Once turning on auto-tune, runn heuristic search(default) before // 2. Once turning on auto-tune, runn heuristic search(default) before
// auto-tune process, run exhaustive_search during mentioned process. // auto-tune process, run exhaustive_search during mentioned process.
// 3. After auto-tune process, run cached algorithm if cached, run // 3. After auto-tune process, run cached algorithm if cached, run
// default mode for the rest. // default mode for the rest.
size_t key = args.GetCacheKey<T>(); auto key = args.Convert2ConvCacheKey<T>();
auto& cache = phi::autotune::AutoTuneCache::Instance().GetConvForward(); auto& cache = phi::autotune::AutoTuneCache::Instance().GetConvForward();
if (cache.Find(key)) { if (cache.Find(key)) {
result.algo = static_cast<AlgoT>(cache.Get(key)); auto t = cache.Get(key);
result.algo = static_cast<AlgoT>(t.algo);
result.workspace_size = t.workspace_size;
} else { } else {
bool use_autotune = bool use_autotune =
phi::autotune::AutoTuneStatus::Instance().UseAutoTune(); phi::autotune::AutoTuneStatus::Instance().UseAutoTune();
if (exhaustive_search || use_autotune) { if (exhaustive_search || use_autotune) {
result = FindAlgoExhaustiveSearch<T>(args, ctx); result = FindAlgoExhaustiveSearch<T>(args, ctx);
cache.Set(key, static_cast<int64_t>(result.algo));
} else { } else {
result = FindAlgoHeuristic(args, ctx); result = FindAlgoHeuristic(args, ctx);
} }
phi::autotune::DnnNode node(static_cast<int64_t>(result.algo),
result.workspace_size);
cache.Set(key, node);
} }
} }
VLOG(3) << "[cuDNN Convoltion] exhaustive_search=" << exhaustive_search VLOG(3) << "[cuDNN Convoltion] exhaustive_search=" << exhaustive_search
<< ", deterministic=" << deterministic << ", deterministic=" << deterministic
<< ", choose algo=" << result.algo << ", workspace=" << ", choose algo=" << result.algo
<< ToMegaBytes(GetWorkspaceSize(args, result.algo)) << " MB"; << ", workspace=" << ToMegaBytes(result.workspace_size) << " MB";
return result; return result;
} }
...@@ -236,8 +240,9 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> { ...@@ -236,8 +240,9 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
} }
private: private:
static SearchResult<AlgoT> FindAlgoDeterministic() { static SearchResult<AlgoT> FindAlgoDeterministic(const ConvArgs& args) {
return SearchResult<AlgoT>(static_cast<AlgoT>(1)); auto workspace_size = GetWorkspaceSize(args, static_cast<AlgoT>(1));
return SearchResult<AlgoT>(static_cast<AlgoT>(1), -1.0, workspace_size);
} }
// Heuristic search mode, calling the cudnnGetXxxAlgorithm. // Heuristic search mode, calling the cudnnGetXxxAlgorithm.
...@@ -298,6 +303,7 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> { ...@@ -298,6 +303,7 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
workspace_size_limit, workspace_size_limit,
&(result.algo))); &(result.algo)));
#endif #endif
result.workspace_size = GetWorkspaceSize(args, result.algo);
return result; return result;
} }
...@@ -343,6 +349,7 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> { ...@@ -343,6 +349,7 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
ChooseAlgoByWorkspace<PerfT, AlgoT>( ChooseAlgoByWorkspace<PerfT, AlgoT>(
perf_results, workspace_size_limit, &result); perf_results, workspace_size_limit, &result);
result.workspace_size = GetWorkspaceSize(args, result.algo);
return result; return result;
} }
...@@ -394,33 +401,37 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> { ...@@ -394,33 +401,37 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
SetConvMathType(ctx, dtype, args.cdesc); SetConvMathType(ctx, dtype, args.cdesc);
if (deterministic) { if (deterministic) {
result = FindAlgoDeterministic(); result = FindAlgoDeterministic(args);
} else { } else {
// 1. Once turning on exhaustive FLAGS, always get exhaustive_search. // 1. Once turning on exhaustive FLAGS, always get exhaustive_search.
// 2. Once turning on auto-tune, runn heuristic search(default) before // 2. Once turning on auto-tune, runn heuristic search(default) before
// auto-tune process, run exhaustive_search during mentioned process. // auto-tune process, run exhaustive_search during mentioned process.
// 3. After auto-tune process, run cached algorithm if cached, run // 3. After auto-tune process, run cached algorithm if cached, run
// default mode for the rest. // default mode for the rest.
size_t key = args.GetCacheKey<T>(); auto key = args.Convert2ConvCacheKey<T>();
auto& cache = auto& cache =
phi::autotune::AutoTuneCache::Instance().GetConvBackwardData(); phi::autotune::AutoTuneCache::Instance().GetConvBackwardData();
if (cache.Find(key)) { if (cache.Find(key)) {
result.algo = static_cast<AlgoT>(cache.Get(key)); auto t = cache.Get(key);
result.algo = static_cast<AlgoT>(t.algo);
result.workspace_size = t.workspace_size;
} else { } else {
bool use_autotune = bool use_autotune =
phi::autotune::AutoTuneStatus::Instance().UseAutoTune(); phi::autotune::AutoTuneStatus::Instance().UseAutoTune();
if (exhaustive_search || use_autotune) { if (exhaustive_search || use_autotune) {
result = FindAlgoExhaustiveSearch<T>(args, ctx); result = FindAlgoExhaustiveSearch<T>(args, ctx);
cache.Set(key, static_cast<int64_t>(result.algo));
} else { } else {
result = FindAlgoHeuristic(args, ctx); result = FindAlgoHeuristic(args, ctx);
} }
phi::autotune::DnnNode node(static_cast<int64_t>(result.algo),
result.workspace_size);
cache.Set(key, node);
} }
} }
VLOG(3) << "[cuDNN Convoltion] exhaustive_search=" << exhaustive_search VLOG(3) << "[cuDNN Convoltion] exhaustive_search=" << exhaustive_search
<< ", deterministic=" << deterministic << ", deterministic=" << deterministic
<< ", choose algo=" << result.algo << ", workspace=" << ", choose algo=" << result.algo
<< ToMegaBytes(GetWorkspaceSize(args, result.algo)) << " MB"; << ", workspace=" << ToMegaBytes(result.workspace_size) << " MB";
return result; return result;
} }
...@@ -440,8 +451,11 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> { ...@@ -440,8 +451,11 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
} }
private: private:
static SearchResult<AlgoT> FindAlgoDeterministic() { static SearchResult<AlgoT> FindAlgoDeterministic(const ConvArgs& args) {
return SearchResult<AlgoT>(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1); auto workspace_size =
GetWorkspaceSize(args, CUDNN_CONVOLUTION_BWD_DATA_ALGO_1);
return SearchResult<AlgoT>(
CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, -1.0, workspace_size);
} }
static SearchResult<AlgoT> FindAlgoHeuristic(const ConvArgs& args, static SearchResult<AlgoT> FindAlgoHeuristic(const ConvArgs& args,
...@@ -513,7 +527,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> { ...@@ -513,7 +527,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
workspace_size_limit, workspace_size_limit,
&(result.algo))); &(result.algo)));
#endif #endif
result.workspace_size = GetWorkspaceSize(args, result.algo);
return result; return result;
} }
...@@ -559,6 +573,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> { ...@@ -559,6 +573,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
ChooseAlgoByWorkspace<PerfT, AlgoT>( ChooseAlgoByWorkspace<PerfT, AlgoT>(
perf_results, workspace_size_limit, &result); perf_results, workspace_size_limit, &result);
result.workspace_size = GetWorkspaceSize(args, result.algo);
return result; return result;
} }
...@@ -609,33 +624,37 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> { ...@@ -609,33 +624,37 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
SetConvMathType(ctx, dtype, args.cdesc); SetConvMathType(ctx, dtype, args.cdesc);
if (deterministic) { if (deterministic) {
result = FindAlgoDeterministic(); result = FindAlgoDeterministic(args);
} else { } else {
// 1. Once turning on exhaustive FLAGS, always get exhaustive_search. // 1. Once turning on exhaustive FLAGS, always get exhaustive_search.
// 2. Once turning on auto-tune, runn heuristic search(default) before // 2. Once turning on auto-tune, runn heuristic search(default) before
// auto-tune process, run exhaustive_search during mentioned process. // auto-tune process, run exhaustive_search during mentioned process.
// 3. After auto-tune process, run cached algorithm if cached, run // 3. After auto-tune process, run cached algorithm if cached, run
// default mode for the rest. // default mode for the rest.
size_t key = args.GetCacheKey<T>(); auto key = args.Convert2ConvCacheKey<T>();
auto& cache = auto& cache =
phi::autotune::AutoTuneCache::Instance().GetConvBackwardFilter(); phi::autotune::AutoTuneCache::Instance().GetConvBackwardFilter();
if (cache.Find(key)) { if (cache.Find(key)) {
result.algo = static_cast<AlgoT>(cache.Get(key)); auto t = cache.Get(key);
result.algo = static_cast<AlgoT>(t.algo);
result.workspace_size = t.workspace_size;
} else { } else {
bool use_autotune = bool use_autotune =
phi::autotune::AutoTuneStatus::Instance().UseAutoTune(); phi::autotune::AutoTuneStatus::Instance().UseAutoTune();
if (exhaustive_search || use_autotune) { if (exhaustive_search || use_autotune) {
result = FindAlgoExhaustiveSearch<T>(args, ctx); result = FindAlgoExhaustiveSearch<T>(args, ctx);
cache.Set(key, static_cast<int64_t>(result.algo));
} else { } else {
result = FindAlgoHeuristic(args, ctx); result = FindAlgoHeuristic(args, ctx);
} }
phi::autotune::DnnNode node(static_cast<int64_t>(result.algo),
result.workspace_size);
cache.Set(key, node);
} }
} }
VLOG(3) << "[cuDNN Convoltion] exhaustive_search=" << exhaustive_search VLOG(3) << "[cuDNN Convoltion] exhaustive_search=" << exhaustive_search
<< ", deterministic=" << deterministic << ", deterministic=" << deterministic
<< ", choose algo=" << result.algo << ", workspace=" << ", choose algo=" << result.algo
<< ToMegaBytes(GetWorkspaceSize(args, result.algo)) << " MB"; << ", workspace=" << ToMegaBytes(result.workspace_size) << " MB";
return result; return result;
} }
...@@ -656,8 +675,11 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> { ...@@ -656,8 +675,11 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
} }
private: private:
static SearchResult<AlgoT> FindAlgoDeterministic() { static SearchResult<AlgoT> FindAlgoDeterministic(const ConvArgs& args) {
return SearchResult<AlgoT>(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1); auto workspace_size =
GetWorkspaceSize(args, CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1);
return SearchResult<AlgoT>(
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, -1.0, workspace_size);
} }
static SearchResult<AlgoT> FindAlgoHeuristic(const ConvArgs& args, static SearchResult<AlgoT> FindAlgoHeuristic(const ConvArgs& args,
...@@ -718,6 +740,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> { ...@@ -718,6 +740,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
&(result.algo))); &(result.algo)));
#endif #endif
result.workspace_size = GetWorkspaceSize(args, result.algo);
return result; return result;
} }
...@@ -786,6 +809,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> { ...@@ -786,6 +809,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
ChooseAlgo(perf_results, workspace_size_limit, &result); ChooseAlgo(perf_results, workspace_size_limit, &result);
} }
result.workspace_size = GetWorkspaceSize(args, result.algo);
return result; return result;
} }
......
...@@ -984,6 +984,17 @@ PADDLE_DEFINE_EXPORTED_bool(nccl_blocking_wait, false, "nccl blocking wait"); ...@@ -984,6 +984,17 @@ PADDLE_DEFINE_EXPORTED_bool(nccl_blocking_wait, false, "nccl blocking wait");
*/ */
PADDLE_DEFINE_EXPORTED_bool(use_autotune, false, "Whether enable autotune."); PADDLE_DEFINE_EXPORTED_bool(use_autotune, false, "Whether enable autotune.");
/**
* Conv Search cache max number related FLAG
* Name: FLAGS_search_cache_max_number
* Since Version: 2.3.0
* Value Range: int32, default=1000000
* Example:
*/
PADDLE_DEFINE_EXPORTED_int32(search_cache_max_number,
1000000,
"search_cache_max_number.");
/** /**
* Preformance related FLAG * Preformance related FLAG
* Name: einsum_opt * Name: einsum_opt
......
...@@ -21,21 +21,6 @@ ...@@ -21,21 +21,6 @@
namespace phi { namespace phi {
namespace autotune { namespace autotune {
// Define the cache key of operator
size_t ConvKey(const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& w_dims,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& dilations,
phi::DataType dtype) {
return GetKey(x_dims,
w_dims,
strides,
paddings,
dilations,
static_cast<int64_t>(dtype));
}
size_t TransposeKey(const std::vector<int64_t>& x_dims, size_t TransposeKey(const std::vector<int64_t>& x_dims,
const std::vector<int32_t>& perm, const std::vector<int32_t>& perm,
phi::DataType dtype) { phi::DataType dtype) {
...@@ -73,6 +58,19 @@ void AutoTuneCache::UpdateStatus() { ...@@ -73,6 +58,19 @@ void AutoTuneCache::UpdateStatus() {
cache_hits += v.second.CacheHits(); cache_hits += v.second.CacheHits();
cache_misses += v.second.CacheMisses(); cache_misses += v.second.CacheMisses();
} }
for (auto& v : cudnn_auto_tune_map_) {
VLOG(4) << "AlgoType: " << std::setfill(' ') << std::setw(name_width)
<< AlgorithmTypeString(v.first)
<< " Cache Size: " << v.second.Size()
<< " Hits: " << v.second.CacheHits()
<< " Misses: " << v.second.CacheMisses()
<< " Hit Rate: " << v.second.CacheHitRate();
size += v.second.Size();
cache_hits += v.second.CacheHits();
cache_misses += v.second.CacheMisses();
}
total_size_ = size; total_size_ = size;
total_cache_hits_ = cache_hits; total_cache_hits_ = cache_hits;
total_cache_misses_ = cache_misses; total_cache_misses_ = cache_misses;
......
...@@ -24,6 +24,8 @@ ...@@ -24,6 +24,8 @@
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h" #include "paddle/phi/core/errors.h"
DECLARE_int32(search_cache_max_number);
inline void HashCombine(std::size_t* seed) {} inline void HashCombine(std::size_t* seed) {}
// combine hash value // combine hash value
...@@ -32,6 +34,7 @@ template <typename T, typename... Rest> ...@@ -32,6 +34,7 @@ template <typename T, typename... Rest>
inline void HashCombine(std::size_t* seed, const T& v, Rest... rest) { inline void HashCombine(std::size_t* seed, const T& v, Rest... rest) {
std::hash<T> hasher; std::hash<T> hasher;
*seed ^= hasher(v) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2); *seed ^= hasher(v) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
*seed *= 0x00000100000001B3;
HashCombine(seed, rest...); HashCombine(seed, rest...);
} }
...@@ -41,7 +44,7 @@ namespace std { ...@@ -41,7 +44,7 @@ namespace std {
template <typename T> template <typename T>
struct hash<std::vector<T>> { struct hash<std::vector<T>> {
std::size_t operator()(std::vector<T> const& vec) const noexcept { std::size_t operator()(std::vector<T> const& vec) const noexcept {
std::size_t seed = 0; std::size_t seed = 0xcbf29ce484222325;
for (auto val : vec) { for (auto val : vec) {
HashCombine(&seed, val); HashCombine(&seed, val);
} }
...@@ -53,6 +56,14 @@ struct hash<std::vector<T>> { ...@@ -53,6 +56,14 @@ struct hash<std::vector<T>> {
namespace phi { namespace phi {
namespace autotune { namespace autotune {
struct DnnNode {
DnnNode() {}
explicit DnnNode(int64_t a, size_t size) : algo(a), workspace_size(size) {}
int64_t algo;
size_t workspace_size = 0;
};
template <typename... Args> template <typename... Args>
size_t GetKey(Args&&... args) { size_t GetKey(Args&&... args) {
size_t seed = 0; size_t seed = 0;
...@@ -60,13 +71,130 @@ size_t GetKey(Args&&... args) { ...@@ -60,13 +71,130 @@ size_t GetKey(Args&&... args) {
return seed; return seed;
} }
// Define the cache key of operator struct ConvCacheKey {
size_t ConvKey(const std::vector<int64_t>& x_dims, ConvCacheKey() {}
const std::vector<int64_t>& w_dims, explicit ConvCacheKey(const std::vector<int64_t>& x_dims,
const std::vector<int>& strides, const std::vector<int64_t>& w_dims,
const std::vector<int>& paddings, const std::vector<int>& strides,
const std::vector<int>& dilations, const std::vector<int>& paddings,
phi::DataType dtype); const std::vector<int>& dilations,
phi::DataType dtype,
int groups,
int64_t data_layout)
: x_dims_(x_dims),
w_dims_(w_dims),
strides_(strides),
paddings_(paddings),
dilations_(dilations),
dtype_(dtype),
groups_(groups),
data_layout_(data_layout) {}
size_t hash_value() const {
return GetKey(x_dims_,
w_dims_,
strides_,
paddings_,
dilations_,
static_cast<int64_t>(dtype_),
groups_,
data_layout_);
}
std::vector<int64_t> x_dims_;
std::vector<int64_t> w_dims_;
std::vector<int> strides_;
std::vector<int> paddings_;
std::vector<int> dilations_;
phi::DataType dtype_;
int groups_;
int64_t data_layout_;
};
struct ConvCacheKeyHash {
size_t operator()(const ConvCacheKey& cache) const {
return cache.hash_value();
}
};
struct ConvCacheKeyEqual {
size_t operator()(const ConvCacheKey& first,
const ConvCacheKey& second) const {
if (first.x_dims_ != second.x_dims_) return false;
if (first.w_dims_ != second.w_dims_) return false;
if (first.strides_ != second.strides_) return false;
if (first.paddings_ != second.paddings_) return false;
if (first.dilations_ != second.dilations_) return false;
if (first.dtype_ != second.dtype_) return false;
if (first.groups_ != second.groups_) return false;
if (first.data_layout_ != second.data_layout_) return false;
return true;
}
};
class CudnnAlgorithmsCacheMap {
public:
CudnnAlgorithmsCacheMap() : cache_mutex_(new std::mutex()) { hash_.clear(); }
DnnNode Get(const ConvCacheKey& key) {
std::lock_guard<std::mutex> lock(*cache_mutex_);
PADDLE_ENFORCE_NE(
hash_.find(key),
hash_.end(),
phi::errors::PreconditionNotMet("The key does not exist."));
return hash_[key];
}
bool Find(const ConvCacheKey& key) {
bool ret = false;
std::lock_guard<std::mutex> lock(*cache_mutex_);
if (hash_.find(key) != hash_.end()) {
cache_hits_++;
ret = true;
} else {
cache_misses_++;
}
return ret;
}
void Clean() {
std::lock_guard<std::mutex> lock(*cache_mutex_);
hash_.clear();
cache_hits_ = 0;
cache_misses_ = 0;
}
void Set(const ConvCacheKey& key, DnnNode algo) {
std::lock_guard<std::mutex> lock(*cache_mutex_);
if (hash_.size() > static_cast<size_t>(FLAGS_search_cache_max_number)) {
hash_.clear();
}
hash_[key] = algo;
}
int64_t CacheMisses() const { return cache_misses_; }
int64_t CacheHits() const { return cache_hits_; }
float CacheHitRate() const {
int64_t num_accesses = cache_hits_ + cache_misses_;
float cache_hit_rate = 0.;
if (num_accesses != 0) {
cache_hit_rate =
static_cast<float>(cache_hits_) / static_cast<float>(num_accesses);
}
return cache_hit_rate;
}
int64_t Size() const { return hash_.size(); }
private:
std::unordered_map<ConvCacheKey, DnnNode, ConvCacheKeyHash, ConvCacheKeyEqual>
hash_;
std::shared_ptr<std::mutex> cache_mutex_;
int64_t cache_hits_{0};
int64_t cache_misses_{0};
};
size_t TransposeKey(const std::vector<int64_t>& x_dims, size_t TransposeKey(const std::vector<int64_t>& x_dims,
const std::vector<int32_t>& perm, const std::vector<int32_t>& perm,
...@@ -77,7 +205,7 @@ class AlgorithmsCache { ...@@ -77,7 +205,7 @@ class AlgorithmsCache {
public: public:
AlgorithmsCache() : cache_mutex_(new std::mutex()) { hash_.clear(); } AlgorithmsCache() : cache_mutex_(new std::mutex()) { hash_.clear(); }
AlgorithmT Get(size_t key) { AlgorithmT Get(const size_t& key) {
std::lock_guard<std::mutex> lock(*cache_mutex_); std::lock_guard<std::mutex> lock(*cache_mutex_);
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
hash_.find(key), hash_.find(key),
...@@ -86,7 +214,7 @@ class AlgorithmsCache { ...@@ -86,7 +214,7 @@ class AlgorithmsCache {
return hash_[key]; return hash_[key];
} }
bool Find(size_t key) { bool Find(const size_t& key) {
bool ret = false; bool ret = false;
std::lock_guard<std::mutex> lock(*cache_mutex_); std::lock_guard<std::mutex> lock(*cache_mutex_);
if (hash_.find(key) != hash_.end()) { if (hash_.find(key) != hash_.end()) {
...@@ -105,7 +233,7 @@ class AlgorithmsCache { ...@@ -105,7 +233,7 @@ class AlgorithmsCache {
cache_misses_ = 0; cache_misses_ = 0;
} }
void Set(size_t key, AlgorithmT algo) { void Set(const size_t& key, AlgorithmT algo) {
std::lock_guard<std::mutex> lock(*cache_mutex_); std::lock_guard<std::mutex> lock(*cache_mutex_);
hash_[key] = algo; hash_[key] = algo;
} }
...@@ -143,9 +271,12 @@ enum class AlgorithmType { ...@@ -143,9 +271,12 @@ enum class AlgorithmType {
}; };
// AlgorithmsConfigKey -> AlgorithmsID // AlgorithmsConfigKey -> AlgorithmsID
// (todo. hong) use cudnnConvolutionFwdAlgo_t
using AlgorithmsCacheMap = AlgorithmsCache<int64_t>; using AlgorithmsCacheMap = AlgorithmsCache<int64_t>;
// AlgorithmType -> AlgorithmsCache // AlgorithmType -> AlgorithmsCache
using AlgorithmsTypeMap = std::unordered_map<int64_t, AlgorithmsCacheMap>; using AlgorithmsTypeMap = std::unordered_map<int64_t, AlgorithmsCacheMap>;
using CudnnAlgorithmsTypeMap =
std::unordered_map<int64_t, CudnnAlgorithmsCacheMap>;
class AutoTuneCache { class AutoTuneCache {
public: public:
...@@ -158,16 +289,19 @@ class AutoTuneCache { ...@@ -158,16 +289,19 @@ class AutoTuneCache {
return auto_tune_map_[static_cast<int64_t>(algo_type)]; return auto_tune_map_[static_cast<int64_t>(algo_type)];
} }
AlgorithmsCacheMap& GetConvForward() { CudnnAlgorithmsCacheMap& GetConvForward() {
return Get(AlgorithmType::kConvForward); return cudnn_auto_tune_map_[static_cast<int64_t>(
AlgorithmType::kConvForward)];
} }
AlgorithmsCacheMap& GetConvBackwardData() { CudnnAlgorithmsCacheMap& GetConvBackwardData() {
return Get(AlgorithmType::kConvBackwardData); return cudnn_auto_tune_map_[static_cast<int64_t>(
AlgorithmType::kConvBackwardData)];
} }
AlgorithmsCacheMap& GetConvBackwardFilter() { CudnnAlgorithmsCacheMap& GetConvBackwardFilter() {
return Get(AlgorithmType::kConvBackwardFilter); return cudnn_auto_tune_map_[static_cast<int64_t>(
AlgorithmType::kConvBackwardFilter)];
} }
AlgorithmsCacheMap& GetTranspose() { return Get(AlgorithmType::kTranspose); } AlgorithmsCacheMap& GetTranspose() { return Get(AlgorithmType::kTranspose); }
...@@ -176,6 +310,10 @@ class AutoTuneCache { ...@@ -176,6 +310,10 @@ class AutoTuneCache {
for (auto& v : auto_tune_map_) { for (auto& v : auto_tune_map_) {
v.second.Clean(); v.second.Clean();
} }
for (auto& v : cudnn_auto_tune_map_) {
v.second.Clean();
}
} }
void UpdateStatus(); void UpdateStatus();
...@@ -206,14 +344,25 @@ class AutoTuneCache { ...@@ -206,14 +344,25 @@ class AutoTuneCache {
void Register(const AlgorithmType& algo_type) { void Register(const AlgorithmType& algo_type) {
std::lock_guard<std::mutex> lock(*autotune_cache_mutex_); std::lock_guard<std::mutex> lock(*autotune_cache_mutex_);
int64_t key = static_cast<int64_t>(algo_type); if (algo_type == AlgorithmType::kConvForward ||
if (auto_tune_map_.find(key) == auto_tune_map_.end()) { algo_type == AlgorithmType::kConvBackwardData ||
AlgorithmsCacheMap cache; algo_type == AlgorithmType::kConvBackwardFilter) {
auto_tune_map_[key] = cache; int64_t key = static_cast<int64_t>(algo_type);
if (auto_tune_map_.find(key) == auto_tune_map_.end()) {
CudnnAlgorithmsCacheMap cache;
cudnn_auto_tune_map_[key] = cache;
}
} else {
int64_t key = static_cast<int64_t>(algo_type);
if (auto_tune_map_.find(key) == auto_tune_map_.end()) {
AlgorithmsCacheMap cache;
auto_tune_map_[key] = cache;
}
} }
} }
AlgorithmsTypeMap auto_tune_map_; AlgorithmsTypeMap auto_tune_map_;
CudnnAlgorithmsTypeMap cudnn_auto_tune_map_;
std::shared_ptr<std::mutex> autotune_cache_mutex_; std::shared_ptr<std::mutex> autotune_cache_mutex_;
int64_t total_cache_hits_{0}; int64_t total_cache_hits_{0};
int64_t total_cache_misses_{0}; int64_t total_cache_misses_{0};
......
...@@ -34,20 +34,23 @@ TEST(AlgosCache, AlgosCache) { ...@@ -34,20 +34,23 @@ TEST(AlgosCache, AlgosCache) {
std::vector<int> dilations = {1, 1}; std::vector<int> dilations = {1, 1};
phi::DataType dtype = paddle::experimental::CppTypeToDataType<float>::Type(); phi::DataType dtype = paddle::experimental::CppTypeToDataType<float>::Type();
auto key = phi::autotune::ConvKey( phi::autotune::ConvCacheKey key(
x_shape, w_shape, paddings, strides, dilations, dtype); x_shape, w_shape, paddings, strides, dilations, dtype, 0, 0);
EXPECT_EQ(cache.Find(key), false); EXPECT_EQ(cache.Find(key), false);
cache.Set(key, ConvAlgos::GEMMKernel); phi::autotune::DnnNode node(static_cast<int64_t>(ConvAlgos::GEMMKernel), 0);
cache.Set(key, node);
EXPECT_EQ(cache.Size(), 1); EXPECT_EQ(cache.Size(), 1);
EXPECT_EQ(cache.Find(key), true); EXPECT_EQ(cache.Find(key), true);
auto algo = cache.Get(key); auto algo = cache.Get(key);
EXPECT_EQ(algo, ConvAlgos::GEMMKernel); EXPECT_EQ(algo.algo, ConvAlgos::GEMMKernel);
x_shape = {4, 128, 128, 3}; x_shape = {4, 128, 128, 3};
key = phi::autotune::ConvKey( phi::autotune::ConvCacheKey key1(
x_shape, w_shape, paddings, strides, dilations, dtype); x_shape, w_shape, paddings, strides, dilations, dtype, 0, 1);
EXPECT_EQ(cache.Find(key), false); EXPECT_EQ(cache.Find(key1), false);
cache.Set(key, ConvAlgos::CuDNNKernel_1); phi::autotune::DnnNode node1(static_cast<int64_t>(ConvAlgos::CuDNNKernel_1),
0);
cache.Set(key1, node1);
EXPECT_EQ(cache.Size(), 2); EXPECT_EQ(cache.Size(), 2);
EXPECT_EQ(cache.CacheHits(), 1); EXPECT_EQ(cache.CacheHits(), 1);
EXPECT_EQ(cache.CacheMisses(), 2); EXPECT_EQ(cache.CacheMisses(), 2);
......
...@@ -254,6 +254,8 @@ void ConvCudnnGradGradKernel( ...@@ -254,6 +254,8 @@ void ConvCudnnGradGradKernel(
auto dtype = paddle::platform::CudnnDataType<T>::type; auto dtype = paddle::platform::CudnnDataType<T>::type;
auto handle = ctx.cudnn_handle(); auto handle = ctx.cudnn_handle();
auto layout = paddle::platform::GetCudnnTensorFormat(
paddle::platform::DataLayout::kNCHW);
paddle::operators::ConvArgs args1{&transformed_ddX, paddle::operators::ConvArgs args1{&transformed_ddX,
W, W,
...@@ -261,28 +263,36 @@ void ConvCudnnGradGradKernel( ...@@ -261,28 +263,36 @@ void ConvCudnnGradGradKernel(
strides, strides,
padding_common, padding_common,
dilations, dilations,
dtype}; dtype,
groups,
paddle::platform::DataLayout::kNCHW};
paddle::operators::ConvArgs args2{&transformed_X, paddle::operators::ConvArgs args2{&transformed_X,
ddW, ddW,
&transformed_ddO_channel, &transformed_ddO_channel,
strides, strides,
padding_common, padding_common,
dilations, dilations,
dtype}; dtype,
groups,
paddle::platform::DataLayout::kNCHW};
paddle::operators::ConvArgs args3{&transformed_ddX, paddle::operators::ConvArgs args3{&transformed_ddX,
dW, dW,
&transformed_dO_channel, &transformed_dO_channel,
strides, strides,
padding_common, padding_common,
dilations, dilations,
dtype}; dtype,
groups,
paddle::platform::DataLayout::kNCHW};
paddle::operators::ConvArgs args4{&transformed_dX, paddle::operators::ConvArgs args4{&transformed_dX,
ddW, ddW,
&transformed_dO_channel, &transformed_dO_channel,
strides, strides,
padding_common, padding_common,
dilations, dilations,
dtype}; dtype,
groups,
paddle::platform::DataLayout::kNCHW};
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
paddle::operators::SearchResult<miopenConvFwdAlgorithm_t> fwd_result1; paddle::operators::SearchResult<miopenConvFwdAlgorithm_t> fwd_result1;
...@@ -298,9 +308,6 @@ void ConvCudnnGradGradKernel( ...@@ -298,9 +308,6 @@ void ConvCudnnGradGradKernel(
filter_result; filter_result;
#endif #endif
auto layout = paddle::platform::GetCudnnTensorFormat(
paddle::platform::DataLayout::kNCHW);
// ddo = conv(ddI, W) + conv(I, ddW) // ddo = conv(ddI, W) + conv(I, ddW)
size_t workspace_size = 0; size_t workspace_size = 0;
......
...@@ -251,27 +251,33 @@ void ConvCudnnGradKernel(const Context& ctx, ...@@ -251,27 +251,33 @@ void ConvCudnnGradKernel(const Context& ctx,
T* input_grad_data = nullptr; T* input_grad_data = nullptr;
T* transformed_input_grad_data = nullptr; T* transformed_input_grad_data = nullptr;
paddle::platform::DataLayout layout =
compute_format == paddle::platform::DataLayout::kNHWC
? paddle::platform::DataLayout::kNHWC
: paddle::platform::DataLayout::kNCHW;
paddle::operators::ConvArgs args1{&transformed_input_grad, paddle::operators::ConvArgs args1{&transformed_input_grad,
&transformed_filter_channel, &transformed_filter_channel,
&transformed_output_grad_channel, &transformed_output_grad_channel,
strides, strides,
padding_common, padding_common,
dilations, dilations,
dtype}; dtype,
groups,
layout};
paddle::operators::ConvArgs args2{&transformed_input, paddle::operators::ConvArgs args2{&transformed_input,
&transformed_filter_grad_channel, &transformed_filter_grad_channel,
&transformed_output_grad_channel, &transformed_output_grad_channel,
strides, strides,
padding_common, padding_common,
dilations, dilations,
dtype}; dtype,
groups,
layout};
auto handle = ctx.cudnn_handle(); auto handle = ctx.cudnn_handle();
// TODO(phlrain): replace paddle::platform::DataLaytout to phi::DataLayout // TODO(phlrain): replace paddle::platform::DataLaytout to phi::DataLayout
paddle::platform::DataLayout layout =
compute_format == paddle::platform::DataLayout::kNHWC
? paddle::platform::DataLayout::kNHWC
: paddle::platform::DataLayout::kNCHW;
if (transformed_input.dims().size() == 5) { if (transformed_input.dims().size() == 5) {
layout = compute_format == paddle::platform::DataLayout::kNHWC layout = compute_format == paddle::platform::DataLayout::kNHWC
? paddle::platform::DataLayout::kNDHWC ? paddle::platform::DataLayout::kNDHWC
...@@ -368,8 +374,7 @@ void ConvCudnnGradKernel(const Context& ctx, ...@@ -368,8 +374,7 @@ void ConvCudnnGradKernel(const Context& ctx,
using search1 = using search1 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>; paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
bwd_result = search1::Find<T>(args1, exhaustive_search, deterministic, ctx); bwd_result = search1::Find<T>(args1, exhaustive_search, deterministic, ctx);
workspace_size_d = std::max( workspace_size_d = std::max(workspace_size_d, bwd_result.workspace_size);
workspace_size_d, search1::GetWorkspaceSize(args1, bwd_result.algo));
#endif #endif
} }
...@@ -400,8 +405,7 @@ void ConvCudnnGradKernel(const Context& ctx, ...@@ -400,8 +405,7 @@ void ConvCudnnGradKernel(const Context& ctx,
search2::Find<T>(args2, exhaustive_search, deterministic, ctx); search2::Find<T>(args2, exhaustive_search, deterministic, ctx);
VLOG(3) << "filter algo: " << filter_result.algo << ", time " VLOG(3) << "filter algo: " << filter_result.algo << ", time "
<< filter_result.time; << filter_result.time;
workspace_size_w = std::max( workspace_size_w = std::max(workspace_size_w, filter_result.workspace_size);
workspace_size_w, search2::GetWorkspaceSize(args2, filter_result.algo));
#endif #endif
} }
......
...@@ -213,7 +213,9 @@ void ConvCudnnKernel(const Context& ctx, ...@@ -213,7 +213,9 @@ void ConvCudnnKernel(const Context& ctx,
strides, strides,
padding_common, padding_common,
dilations, dilations,
dtype}; dtype,
groups,
compute_format};
auto handle = ctx.cudnn_handle(); auto handle = ctx.cudnn_handle();
auto workspace_handle = ctx.cudnn_workspace_handle(); auto workspace_handle = ctx.cudnn_workspace_handle();
...@@ -314,7 +316,7 @@ void ConvCudnnKernel(const Context& ctx, ...@@ -314,7 +316,7 @@ void ConvCudnnKernel(const Context& ctx,
using search = using search =
paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>; paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_result = search::Find<T>(args, exhaustive_search, deterministic, ctx); fwd_result = search::Find<T>(args, exhaustive_search, deterministic, ctx);
workspace_size = search::GetWorkspaceSize(args, fwd_result.algo); workspace_size = fwd_result.workspace_size;
#endif #endif
#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION_MIN(7, 0, 1) #if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION_MIN(7, 0, 1)
......
...@@ -179,14 +179,18 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx, ...@@ -179,14 +179,18 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
strides, strides,
padding_common, padding_common,
dilations_, dilations_,
dtype}; dtype,
groups,
layout};
paddle::operators::ConvArgs args2{&transformed_dout, paddle::operators::ConvArgs args2{&transformed_dout,
&filter, &filter,
&x_transpose, &x_transpose,
strides, strides,
padding_common, padding_common,
dilations_, dilations_,
dtype}; dtype,
groups,
layout};
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
paddle::operators::SearchResult<miopenConvFwdAlgorithm_t> fwd_result; paddle::operators::SearchResult<miopenConvFwdAlgorithm_t> fwd_result;
...@@ -625,6 +629,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( ...@@ -625,6 +629,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
auto dtype = paddle::platform::CudnnDataType<T>::type; auto dtype = paddle::platform::CudnnDataType<T>::type;
auto handle = ctx.cudnn_handle(); auto handle = ctx.cudnn_handle();
auto layout = paddle::platform::GetCudnnTensorFormat(GPUDNNDataLayout::kNCHW);
paddle::operators::ConvArgs args1{&transformed_ddout_channel, paddle::operators::ConvArgs args1{&transformed_ddout_channel,
&filter, &filter,
...@@ -632,14 +637,18 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( ...@@ -632,14 +637,18 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
strides, strides,
padding_common, padding_common,
dilations_, dilations_,
dtype}; dtype,
groups,
GPUDNNDataLayout::kNCHW};
paddle::operators::ConvArgs args2{&transformed_ddout_channel, paddle::operators::ConvArgs args2{&transformed_ddout_channel,
&ddfilter, &ddfilter,
&transformed_x, &transformed_x,
strides, strides,
padding_common, padding_common,
dilations_, dilations_,
dtype}; dtype,
groups,
GPUDNNDataLayout::kNCHW};
paddle::operators::ConvArgs args3{&transformed_dout, paddle::operators::ConvArgs args3{&transformed_dout,
dfilter, dfilter,
...@@ -647,14 +656,18 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( ...@@ -647,14 +656,18 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
strides, strides,
padding_common, padding_common,
dilations_, dilations_,
dtype}; dtype,
groups,
GPUDNNDataLayout::kNCHW};
paddle::operators::ConvArgs args4{&transformed_dout, paddle::operators::ConvArgs args4{&transformed_dout,
&ddfilter, &ddfilter,
&transformed_dx_channel, &transformed_dx_channel,
strides, strides,
padding_common, padding_common,
dilations_, dilations_,
dtype}; dtype,
groups,
GPUDNNDataLayout::kNCHW};
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
paddle::operators::SearchResult<miopenConvBwdDataAlgorithm_t> bwd_result1; paddle::operators::SearchResult<miopenConvBwdDataAlgorithm_t> bwd_result1;
paddle::operators::SearchResult<miopenConvBwdDataAlgorithm_t> bwd_result2; paddle::operators::SearchResult<miopenConvBwdDataAlgorithm_t> bwd_result2;
...@@ -669,8 +682,6 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( ...@@ -669,8 +682,6 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
paddle::operators::SearchResult<cudnnConvolutionFwdAlgo_t> fwd_result; paddle::operators::SearchResult<cudnnConvolutionFwdAlgo_t> fwd_result;
#endif #endif
auto layout = paddle::platform::GetCudnnTensorFormat(GPUDNNDataLayout::kNCHW);
// ddo = conv(ddI, filter) + conv(I, ddfilter) // ddo = conv(ddI, filter) + conv(I, ddfilter)
size_t workspace_size = 0; size_t workspace_size = 0;
......
...@@ -205,7 +205,9 @@ void ConvTransposeRawGPUDNNKernel(const Context& ctx, ...@@ -205,7 +205,9 @@ void ConvTransposeRawGPUDNNKernel(const Context& ctx,
strides, strides,
padding_common, padding_common,
dilations_, dilations_,
dtype}; dtype,
groups,
data_layout};
args.handle = handle; args.handle = handle;
args.idesc.set(transformed_out, iwo_groups); args.idesc.set(transformed_out, iwo_groups);
args.wdesc.set(filter, layout_tensor, iwo_groups); args.wdesc.set(filter, layout_tensor, iwo_groups);
......
...@@ -71,11 +71,8 @@ class TestAutoTune(unittest.TestCase): ...@@ -71,11 +71,8 @@ class TestAutoTune(unittest.TestCase):
} }
if paddle.is_compiled_with_cuda(): if paddle.is_compiled_with_cuda():
# Total 3 * num_iters cache accesses, only iter 2 hits the cache. # Total 3 * num_iters cache accesses, only iter 2 hits the cache.
if enable_autotune and step_id >= 1: expected_res["cache_size"] = 3
expected_res["cache_size"] = 3 expected_res["cache_hit_rate"] = (step_id + 0.0) / (step_id + 1.0)
if enable_autotune and step_id == 2:
expected_res["cache_hit_rate"] = np.round(
float(3) / float(9), 5)
return expected_res return expected_res
def test_autotune(self): def test_autotune(self):
...@@ -91,11 +88,11 @@ class TestAutoTune(unittest.TestCase): ...@@ -91,11 +88,11 @@ class TestAutoTune(unittest.TestCase):
def check_status(self, expected_res): def check_status(self, expected_res):
status = paddle.fluid.core.autotune_status() status = paddle.fluid.core.autotune_status()
for key in status.keys(): for key in status.keys():
v = status[key]
if key == "cache_hit_rate": if key == "cache_hit_rate":
v = np.round(status[key], 5) self.assertTrue(np.allclose(v, expected_res[key]))
else: else:
v = status[key] self.assertEqual(v, expected_res[key])
self.assertEqual(v, expected_res[key])
class TestDygraphAutoTuneStatus(TestAutoTune): class TestDygraphAutoTuneStatus(TestAutoTune):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册