Unverified commit 3bc4b850, authored by Yiqun Liu, committed via GitHub

Enable recording whether the conv algo was obtained by exhaustive search, to fix an autotune cache bug. (#47065)

Parent: af4bdede
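The gist of the fix, as a minimal sketch (the struct is taken from the diff below; the rule in the comments is a paraphrase, not the literal Paddle implementation): each cache entry now records how its algorithm was obtained, so a heuristic result cached before the auto-tune window no longer blocks a later exhaustive search from replacing it.

    #include <cstddef>
    #include <cstdint>

    // Sketch of the cache entry introduced by this commit. An entry with
    // exhaustive_search == true is treated as final; an entry produced by the
    // heuristic path may still be overwritten during auto-tuning.
    struct ConvAutoTuneResult {
      std::int64_t algo = 0;
      std::size_t workspace_size = 0;
      bool exhaustive_search = false;  // new: how the algo was obtained
    };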
@@ -40,9 +40,6 @@ using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
 template <typename AlgoT>
 struct SearchResult {
   SearchResult() {}
-  explicit SearchResult(const phi::autotune::DnnNode& node)
-      : algo(static_cast<AlgoT>(node.algo)),
-        workspace_size(node.workspace_size) {}
   explicit SearchResult(AlgoT a) : algo(a) {}
   explicit SearchResult(AlgoT a, float t, size_t size)
@@ -51,12 +48,21 @@ struct SearchResult {
   AlgoT algo = static_cast<AlgoT>(0);
   float time = -1.f;
   size_t workspace_size = 0;
+  bool exhaustive_search = false;
 };

 template <typename T>
 static std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
   out << "[";
-  for (auto const& tmp : v) out << tmp << ",";
+  bool is_first = true;
+  for (auto const& tmp : v) {
+    if (is_first) {
+      out << tmp;
+      is_first = false;
+    } else {
+      out << ", " << tmp;
+    }
+  }
   out << "]";
   return out;
 }
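As a usage illustration (a hypothetical standalone program; the printer is copied from the hunk above so the snippet compiles on its own), the rewritten loop emits the separator between elements rather than after each one:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Same printer as in the hunk above, repeated here for self-containment.
    template <typename T>
    std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
      out << "[";
      bool is_first = true;
      for (auto const& tmp : v) {
        if (is_first) {
          out << tmp;
          is_first = false;
        } else {
          out << ", " << tmp;
        }
      }
      return out << "]";
    }

    int main() {
      std::vector<std::int64_t> dims = {4, 224, 224, 3};
      std::cout << dims << "\n";  // new output: [4, 224, 224, 3]
                                  // old loop printed: [4,224,224,3,]
    }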
@@ -109,7 +115,7 @@ struct ConvArgsBase {
     auto w_shape = phi::vectorize(w->dims());
     VLOG(10) << "[ConvArgs] x_dims=" << x_shape << ", w_dims=" << w_shape
              << ", strides=" << s << ", paddings=" << p << ", dilations=" << d
-             << ",data= " << paddle::experimental::CppTypeToDataType<T>::Type()
+             << ", data=" << paddle::experimental::CppTypeToDataType<T>::Type()
              << ", group=" << group
              << ", data layout=" << static_cast<int64_t>(data_layout);
......
@@ -14,12 +14,11 @@ limitations under the License. */

 #pragma once

-#include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/operators/conv_base_helper.h"
 #include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
 #include "paddle/fluid/platform/device/gpu/gpu_info.h"
-#include "paddle/fluid/platform/profiler.h"
 #include "paddle/phi/kernels/autotune/switch_autotune.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"

 namespace paddle {
@@ -53,10 +52,8 @@ static void RemovePaddingSlice(const phi::GPUContext& context,
   }

   auto in_t =
-      framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
-          *input);
-  auto out_t =
-      framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
+      phi::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(*input);
+  auto out_t = phi::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
       *out, new_out_dims);
   phi::funcs::EigenSlice<std::decay_t<decltype(place)>, T, D>::Eval(
@@ -161,6 +158,8 @@ struct SearchAlgorithmBase<cudnnConvolutionFwdAlgoPerf_t> {
   constexpr static phi::autotune::AlgorithmType kAlgoType =
       phi::autotune::AlgorithmType::kConvForward;

+  static const std::string GetPerfName() { return "ConvForward"; }
+
   static size_t GetWorkspaceSize(const ConvArgs& args,
                                  cudnnConvolutionFwdAlgo_t algo) {
     size_t workspace_size = 0;
@@ -334,6 +333,8 @@ struct SearchAlgorithmBase<cudnnConvolutionBwdDataAlgoPerf_t> {
   constexpr static phi::autotune::AlgorithmType kAlgoType =
       phi::autotune::AlgorithmType::kConvBackwardData;

+  static const std::string GetPerfName() { return "ConvBackwardData"; }
+
   static size_t GetWorkspaceSize(const ConvArgs& args,
                                  cudnnConvolutionBwdDataAlgo_t algo) {
     size_t workspace_size = 0;
@@ -514,6 +515,8 @@ struct SearchAlgorithmBase<cudnnConvolutionBwdFilterAlgoPerf_t> {
   constexpr static phi::autotune::AlgorithmType kAlgoType =
       phi::autotune::AlgorithmType::kConvBackwardFilter;

+  static const std::string GetPerfName() { return "ConvBackwardFilter"; }
+
   static size_t GetWorkspaceSize(const ConvArgs& args,
                                  cudnnConvolutionBwdFilterAlgo_t algo) {
     platform::CUDAGraphCaptureModeGuard guard;
@@ -752,11 +755,13 @@ struct SearchAlgorithm : public SearchAlgorithmBase<PerfT> {
   using AlgoT = typename SearchAlgorithmBase<PerfT>::AlgoT;

   template <typename T>
-  static SearchResult<AlgoT> Find(const ConvArgs& args,
+  static SearchResult<AlgoT> Find(const phi::GPUContext& ctx,
+                                  const ConvArgs& args,
                                   bool exhaustive_search,
                                   bool deterministic,
-                                  const phi::GPUContext& ctx) {
+                                  bool enable_autotune = true) {
     SearchResult<AlgoT> result;
+    bool use_autotune = false;
     auto dtype = platform::CudnnDataType<T>::type;
     SetConvMathType(ctx, dtype, args.cdesc);
@@ -764,33 +769,50 @@ struct SearchAlgorithm : public SearchAlgorithmBase<PerfT> {
       result = SearchAlgorithmBase<PerfT>::FindAlgoDeterministic(args);
     } else {
       // 1. Once turning on exhaustive FLAGS, always get exhaustive_search.
-      // 2. Once turning on auto-tune, runn heuristic search(default) before
+      // 2. Once turning on auto-tune, run heuristic search (default) before
       //    auto-tune process, run exhaustive_search during mentioned process.
+      //    Auto-tune is only enabled within the specified range.
       // 3. After auto-tune process, run cached algorithm if cached, run
       //    default mode for the rest.
       auto key = args.Convert2ConvCacheKey<T>();
       auto& cache = phi::autotune::AutoTuneCache::Instance().GetConv(
           SearchAlgorithmBase<PerfT>::kAlgoType);
-      if (cache.Find(key)) {
+      bool find_in_cache = cache.Find(key);
+      if (find_in_cache) {
         auto t = cache.Get(key);
         result.algo = static_cast<AlgoT>(t.algo);
         result.workspace_size = t.workspace_size;
-      } else {
-        bool use_autotune =
+        result.exhaustive_search = t.exhaustive_search;
+      }
+      if (!result.exhaustive_search) {
+        bool need_update_cache = false;
+        // In conv2d_transpose, enable_autotune is set to false because some
+        // algorithms picked by the exhaustive search method produce wrong
+        // results.
+        use_autotune = enable_autotune &&
             phi::autotune::AutoTuneStatus::Instance().UseAutoTune();
         if (exhaustive_search || use_autotune) {
+          // Once autotune is enabled, the autotuned result can overwrite a
+          // previous cache entry found by the heuristic method.
           result =
               SearchAlgorithmBase<PerfT>::template FindAlgoExhaustiveSearch<T>(
                   args, ctx);
-        } else {
+          need_update_cache = true;
+        } else if (!find_in_cache) {
           result = SearchAlgorithmBase<PerfT>::FindAlgoHeuristic(args, ctx);
+          need_update_cache = true;
         }
-        phi::autotune::DnnNode node(static_cast<int64_t>(result.algo),
-                                    result.workspace_size);
-        cache.Set(key, node);
+        if (need_update_cache) {
+          phi::autotune::ConvAutoTuneResult node(
+              static_cast<int64_t>(result.algo),
+              result.workspace_size,
+              exhaustive_search || use_autotune);
+          cache.Set(key, node);
+        }
       }
     }
-    VLOG(3) << "[cuDNN Convoltion] exhaustive_search=" << exhaustive_search
+    VLOG(3) << "[cuDNN " << SearchAlgorithmBase<PerfT>::GetPerfName()
+            << "] exhaustive_search=" << exhaustive_search
+            << ", use_autotune=" << use_autotune
            << ", deterministic=" << deterministic
            << ", choose algo=" << result.algo
            << ", workspace=" << ToMegaBytes(result.workspace_size) << " MB";
......
@@ -56,12 +56,14 @@ struct hash<std::vector<T>> {
 namespace phi {
 namespace autotune {

-struct DnnNode {
-  DnnNode() {}
-  explicit DnnNode(int64_t a, size_t size) : algo(a), workspace_size(size) {}
+struct ConvAutoTuneResult {
+  ConvAutoTuneResult() {}
+  ConvAutoTuneResult(int64_t a, size_t size, bool search)
+      : algo(a), workspace_size(size), exhaustive_search(search) {}

   int64_t algo;
   size_t workspace_size = 0;
+  bool exhaustive_search = false;
 };

 template <typename... Args>
@@ -73,40 +75,41 @@ size_t GetKey(Args&&... args) {

 struct ConvCacheKey {
   ConvCacheKey() {}
-  explicit ConvCacheKey(const std::vector<int64_t>& x_dims,
-                        const std::vector<int64_t>& w_dims,
-                        const std::vector<int>& strides,
-                        const std::vector<int>& paddings,
-                        const std::vector<int>& dilations,
-                        phi::DataType dtype,
-                        int groups,
-                        int64_t data_layout)
-      : x_dims_(x_dims),
-        w_dims_(w_dims),
-        strides_(strides),
-        paddings_(paddings),
-        dilations_(dilations),
-        dtype_(dtype),
-        groups_(groups),
-        data_layout_(data_layout) {}
+  ConvCacheKey(const std::vector<int64_t>& arg_x_dims,
+               const std::vector<int64_t>& arg_w_dims,
+               const std::vector<int>& arg_strides,
+               const std::vector<int>& arg_paddings,
+               const std::vector<int>& arg_dilations,
+               phi::DataType arg_dtype,
+               int arg_groups,
+               int64_t arg_data_layout)
+      : x_dims(arg_x_dims),
+        w_dims(arg_w_dims),
+        strides(arg_strides),
+        paddings(arg_paddings),
+        dilations(arg_dilations),
+        dtype(arg_dtype),
+        groups(arg_groups),
+        data_layout(arg_data_layout) {}

   size_t hash_value() const {
-    return GetKey(x_dims_,
-                  w_dims_,
-                  strides_,
-                  paddings_,
-                  dilations_,
-                  static_cast<int64_t>(dtype_),
-                  groups_,
-                  data_layout_);
+    return GetKey(x_dims,
+                  w_dims,
+                  strides,
+                  paddings,
+                  dilations,
+                  static_cast<int64_t>(dtype),
+                  groups,
+                  data_layout);
   }

-  std::vector<int64_t> x_dims_;
-  std::vector<int64_t> w_dims_;
-  std::vector<int> strides_;
-  std::vector<int> paddings_;
-  std::vector<int> dilations_;
-  phi::DataType dtype_;
-  int groups_;
-  int64_t data_layout_;
+  std::vector<int64_t> x_dims;
+  std::vector<int64_t> w_dims;
+  std::vector<int> strides;
+  std::vector<int> paddings;
+  std::vector<int> dilations;
+  phi::DataType dtype;
+  int groups;
+  int64_t data_layout;
 };
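The members lose their trailing underscore because that suffix is conventionally reserved for private class members, and these are public struct fields. A hedged usage sketch of the renamed key (values borrowed from the unit test further down; assumes this header is included):

    // Illustrative only: arguments follow the constructor signature above.
    phi::autotune::ConvCacheKey key({4, 224, 224, 3},  // x_dims
                                    {32, 3, 3, 3},     // w_dims
                                    {1, 1},            // strides
                                    {0, 0},            // paddings
                                    {1, 1},            // dilations
                                    phi::DataType::FLOAT32,
                                    /*groups=*/1,
                                    /*data_layout=*/0);
    size_t h = key.hash_value();  // combines every field via GetKey(...)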
 struct ConvCacheKeyHash {
@@ -118,14 +121,14 @@ struct ConvCacheKeyHash {
 struct ConvCacheKeyEqual {
   size_t operator()(const ConvCacheKey& first,
                     const ConvCacheKey& second) const {
-    if (first.x_dims_ != second.x_dims_) return false;
-    if (first.w_dims_ != second.w_dims_) return false;
-    if (first.strides_ != second.strides_) return false;
-    if (first.paddings_ != second.paddings_) return false;
-    if (first.dilations_ != second.dilations_) return false;
-    if (first.dtype_ != second.dtype_) return false;
-    if (first.groups_ != second.groups_) return false;
-    if (first.data_layout_ != second.data_layout_) return false;
+    if (first.x_dims != second.x_dims) return false;
+    if (first.w_dims != second.w_dims) return false;
+    if (first.strides != second.strides) return false;
+    if (first.paddings != second.paddings) return false;
+    if (first.dilations != second.dilations) return false;
+    if (first.dtype != second.dtype) return false;
+    if (first.groups != second.groups) return false;
+    if (first.data_layout != second.data_layout) return false;
     return true;
   }
@@ -135,7 +138,7 @@ class CudnnAlgorithmsCacheMap {
  public:
   CudnnAlgorithmsCacheMap() : cache_mutex_(new std::mutex()) { hash_.clear(); }

-  DnnNode Get(const ConvCacheKey& key) {
+  ConvAutoTuneResult Get(const ConvCacheKey& key) {
     std::lock_guard<std::mutex> lock(*cache_mutex_);
     PADDLE_ENFORCE_NE(
         hash_.find(key),
@@ -163,7 +166,7 @@ class CudnnAlgorithmsCacheMap {
     cache_misses_ = 0;
   }

-  void Set(const ConvCacheKey& key, DnnNode algo) {
+  void Set(const ConvCacheKey& key, ConvAutoTuneResult algo) {
     std::lock_guard<std::mutex> lock(*cache_mutex_);
     if (hash_.size() > static_cast<size_t>(FLAGS_search_cache_max_number)) {
       hash_.clear();
@@ -188,7 +191,10 @@ class CudnnAlgorithmsCacheMap {
   int64_t Size() const { return hash_.size(); }

  private:
-  std::unordered_map<ConvCacheKey, DnnNode, ConvCacheKeyHash, ConvCacheKeyEqual>
+  std::unordered_map<ConvCacheKey,
+                     ConvAutoTuneResult,
+                     ConvCacheKeyHash,
+                     ConvCacheKeyEqual>
       hash_;
   std::shared_ptr<std::mutex> cache_mutex_;
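The reflowed declaration above is the standard std::unordered_map customization point: the third and fourth template arguments supply hashing and equality for a key type the standard library cannot hash itself. For reference, a minimal standalone version of the same pattern (all names hypothetical):

    #include <cstddef>
    #include <unordered_map>
    #include <vector>

    struct Key { std::vector<int> dims; };
    struct KeyHash {
      std::size_t operator()(const Key& k) const {
        std::size_t seed = 0;
        for (int d : k.dims) seed = seed * 31u + static_cast<std::size_t>(d);
        return seed;
      }
    };
    struct KeyEqual {
      bool operator()(const Key& a, const Key& b) const {
        return a.dims == b.dims;
      }
    };
    std::unordered_map<Key, int, KeyHash, KeyEqual> table;  // key -> algo id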
@@ -293,21 +299,6 @@ class AutoTuneCache {
     return cudnn_auto_tune_map_[static_cast<int64_t>(algo_type)];
   }

-  CudnnAlgorithmsCacheMap& GetConvForward() {
-    return cudnn_auto_tune_map_[static_cast<int64_t>(
-        AlgorithmType::kConvForward)];
-  }
-
-  CudnnAlgorithmsCacheMap& GetConvBackwardData() {
-    return cudnn_auto_tune_map_[static_cast<int64_t>(
-        AlgorithmType::kConvBackwardData)];
-  }
-
-  CudnnAlgorithmsCacheMap& GetConvBackwardFilter() {
-    return cudnn_auto_tune_map_[static_cast<int64_t>(
-        AlgorithmType::kConvBackwardFilter)];
-  }
-
   AlgorithmsCacheMap& GetTranspose() { return Get(AlgorithmType::kTranspose); }

   void Clean() {
......
@@ -25,7 +25,8 @@ enum ConvAlgos { GEMMKernel = 0, CuDNNKernel_1 = 1, CuDNNKernel_2 = 2 };

 TEST(AlgosCache, AlgosCache) {
   auto autotune_cache = phi::autotune::AutoTuneCache::Instance();
-  auto& cache = autotune_cache.GetConvForward();
+  auto& cache =
+      autotune_cache.GetConv(phi::autotune::AlgorithmType::kConvForward);
   std::vector<int64_t> x_shape = {4, 224, 224, 3};
   std::vector<int64_t> w_shape = {32, 3, 3, 3};
@@ -37,7 +38,8 @@ TEST(AlgosCache, AlgosCache) {
   phi::autotune::ConvCacheKey key(
       x_shape, w_shape, paddings, strides, dilations, dtype, 0, 0);
   EXPECT_EQ(cache.Find(key), false);
-  phi::autotune::DnnNode node(static_cast<int64_t>(ConvAlgos::GEMMKernel), 0);
+  phi::autotune::ConvAutoTuneResult node(
+      static_cast<int64_t>(ConvAlgos::GEMMKernel), 0, false);
   cache.Set(key, node);
   EXPECT_EQ(cache.Size(), 1);
   EXPECT_EQ(cache.Find(key), true);
@@ -48,8 +50,8 @@ TEST(AlgosCache, AlgosCache) {
   phi::autotune::ConvCacheKey key1(
       x_shape, w_shape, paddings, strides, dilations, dtype, 0, 1);
   EXPECT_EQ(cache.Find(key1), false);
-  phi::autotune::DnnNode node1(static_cast<int64_t>(ConvAlgos::CuDNNKernel_1),
-                               0);
+  phi::autotune::ConvAutoTuneResult node1(
+      static_cast<int64_t>(ConvAlgos::CuDNNKernel_1), 0, false);
   cache.Set(key1, node1);
   EXPECT_EQ(cache.Size(), 2);
   EXPECT_EQ(cache.CacheHits(), 1);
......
@@ -336,7 +336,7 @@ void ConvCudnnGradGradKernel(
 #else
       using search1 =
           paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
-      fwd_result1 = search1::Find<T>(args1, exhaustive_search, false, ctx);
+      fwd_result1 = search1::Find<T>(ctx, args1, exhaustive_search, false);
       workspace_size = search1::GetWorkspaceSize(args1, fwd_result1.algo);
 #endif
     }
@@ -364,7 +364,7 @@ void ConvCudnnGradGradKernel(
 #else
       using search2 =
           paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
-      fwd_result2 = search2::Find<T>(args2, exhaustive_search, false, ctx);
+      fwd_result2 = search2::Find<T>(ctx, args2, exhaustive_search, false);
       workspace_size = std::max(
           workspace_size, search2::GetWorkspaceSize(args2, fwd_result2.algo));
 #endif
@@ -394,7 +394,7 @@ void ConvCudnnGradGradKernel(
     using search3 =
         paddle::operators::SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
     filter_result =
-        search3::Find<T>(args3, exhaustive_search, deterministic, ctx);
+        search3::Find<T>(ctx, args3, exhaustive_search, deterministic);
     workspace_size = std::max(
         workspace_size, search3::GetWorkspaceSize(args3, filter_result.algo));
 #endif
@@ -424,7 +424,7 @@ void ConvCudnnGradGradKernel(
     using search4 =
         paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
     data_result =
-        search4::Find<T>(args4, exhaustive_search, deterministic, ctx);
+        search4::Find<T>(ctx, args4, exhaustive_search, deterministic);
     workspace_size = std::max(
         workspace_size, search4::GetWorkspaceSize(args4, data_result.algo));
 #endif
......
@@ -373,7 +373,7 @@ void ConvCudnnGradKernel(const Context& ctx,
 #else
     using search1 =
         paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
-    bwd_result = search1::Find<T>(args1, exhaustive_search, deterministic, ctx);
+    bwd_result = search1::Find<T>(ctx, args1, exhaustive_search, deterministic);
     workspace_size_d = std::max(workspace_size_d, bwd_result.workspace_size);
 #endif
   }
@@ -402,7 +402,7 @@ void ConvCudnnGradKernel(const Context& ctx,
     using search2 =
         paddle::operators::SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
     filter_result =
-        search2::Find<T>(args2, exhaustive_search, deterministic, ctx);
+        search2::Find<T>(ctx, args2, exhaustive_search, deterministic);
     VLOG(3) << "filter algo: " << filter_result.algo << ", time "
             << filter_result.time;
     workspace_size_w = std::max(workspace_size_w, filter_result.workspace_size);
......
@@ -25,7 +25,6 @@
 #endif

 #include "paddle/fluid/platform/cudnn_workspace_helper.h"
-#include "paddle/fluid/platform/float16.h"
 #include "paddle/fluid/platform/profiler.h"
 #include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
@@ -56,8 +55,7 @@ void ConvCudnnKernel(const Context& ctx,
   bool exhaustive_search = FLAGS_cudnn_exhaustive_search || exhaustive_search_t;
   bool deterministic = FLAGS_cudnn_deterministic;
-  auto exhaustive_deterministic = exhaustive_search && deterministic;
-  PADDLE_ENFORCE_EQ(exhaustive_deterministic,
+  PADDLE_ENFORCE_EQ(exhaustive_search && deterministic,
                     false,
                     phi::errors::InvalidArgument(
                         "Cann't set exhaustive_search True and "
@@ -315,7 +313,7 @@ void ConvCudnnKernel(const Context& ctx,
   paddle::operators::SearchResult<cudnnConvolutionFwdAlgo_t> fwd_result;
   using search =
       paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
-  fwd_result = search::Find<T>(args, exhaustive_search, deterministic, ctx);
+  fwd_result = search::Find<T>(ctx, args, exhaustive_search, deterministic);
   workspace_size = fwd_result.workspace_size;
 #endif
......
@@ -230,7 +230,7 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
 #else
     using search1 =
         paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
-    fwd_result = search1::Find<T>(args1, false, deterministic, ctx);
+    fwd_result = search1::Find<T>(ctx, args1, false, deterministic, false);
     workspace_size = std::max(
         workspace_size, search1::GetWorkspaceSize(args1, fwd_result.algo));
 #endif
@@ -257,7 +257,7 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
 #else
     using search2 =
         paddle::operators::SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
-    filter_result = search2::Find<T>(args2, false, deterministic, ctx);
+    filter_result = search2::Find<T>(ctx, args2, false, deterministic, false);
     workspace_size = std::max(
         workspace_size, search2::GetWorkspaceSize(args2, filter_result.algo));
 #endif
@@ -710,7 +710,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
 #else
     using search1 =
         paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
-    bwd_result1 = search1::Find<T>(args1, false, deterministic, ctx);
+    bwd_result1 = search1::Find<T>(ctx, args1, false, deterministic, false);
     workspace_size = search1::GetWorkspaceSize(args1, bwd_result1.algo);
 #endif
@@ -734,7 +734,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
 #else
     using search2 =
         paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
-    bwd_result2 = search2::Find<T>(args2, false, deterministic, ctx);
+    bwd_result2 = search2::Find<T>(ctx, args2, false, deterministic, false);
     workspace_size = std::max(
         workspace_size, search2::GetWorkspaceSize(args2, bwd_result2.algo));
 #endif
@@ -761,7 +761,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
 #else
     using search3 =
         paddle::operators::SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
-    filter_result = search3::Find<T>(args3, false, deterministic, ctx);
+    filter_result = search3::Find<T>(ctx, args3, false, deterministic, false);
     workspace_size = std::max(
         workspace_size, search3::GetWorkspaceSize(args3, filter_result.algo));
 #endif
@@ -789,7 +789,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
 #else
     using search4 =
         paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
-    fwd_result = search4::Find<T>(args4, false, deterministic, ctx);
+    fwd_result = search4::Find<T>(ctx, args4, false, deterministic, false);
     workspace_size = std::max(
         workspace_size, search4::GetWorkspaceSize(args4, fwd_result.algo));
 #endif
......
@@ -230,7 +230,7 @@ void ConvTransposeRawGPUDNNKernel(const Context& ctx,
   paddle::operators::SearchResult<cudnnConvolutionBwdDataAlgo_t> bwd_result;
   using search =
       paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
-  bwd_result = search::Find<T>(args, false, deterministic, ctx);
+  bwd_result = search::Find<T>(ctx, args, false, deterministic, false);
   workspace_size =
       std::max(workspace_size, search::GetWorkspaceSize(args, bwd_result.algo));
 #endif
......
@@ -72,15 +72,20 @@ def check_speed_result(case_name, develop_data, pr_data, pr_result):
     """
     pr_gpu_time = pr_data.get("gpu_time")
     develop_gpu_time = develop_data.get("gpu_time")
-    gpu_time_diff = (pr_gpu_time - develop_gpu_time) / develop_gpu_time
+    if develop_gpu_time != 0.0:
+        gpu_time_diff = (pr_gpu_time - develop_gpu_time) / develop_gpu_time
+        gpu_time_diff_str = "{:.5f}".format(gpu_time_diff * 100)
+    else:
+        gpu_time_diff = None
+        gpu_time_diff_str = ""

     pr_total_time = pr_data.get("total")
     develop_total_time = develop_data.get("total")
     total_time_diff = (pr_total_time - develop_total_time) / develop_total_time

     logging.info("------ OP: %s ------" % case_name)
-    logging.info("GPU time change: %.5f%% (develop: %.7f -> PR: %.7f)" %
-                 (gpu_time_diff * 100, develop_gpu_time, pr_gpu_time))
+    logging.info("GPU time change: %s (develop: %.7f -> PR: %.7f)" %
+                 (gpu_time_diff_str, develop_gpu_time, pr_gpu_time))
     logging.info("Total time change: %.5f%% (develop: %.7f -> PR: %.7f)" %
                  (total_time_diff * 100, develop_total_time, pr_total_time))
     logging.info("backward: %s" % pr_result.get("backward"))
@@ -196,7 +201,8 @@ if __name__ == "__main__":
                       args.develop_logs_dir)
     check_path_exists(args.pr_logs_dir)

-    for log_file in os.listdir(args.pr_logs_dir):
+    pr_log_files = os.listdir(args.pr_logs_dir)
+    for log_file in sorted(pr_log_files):
         develop_result = develop_result_dict.get(log_file)
         pr_result = parse_log_file(os.path.join(args.pr_logs_dir, log_file))
         if develop_result is None or pr_result is None:
......