Commit c63bce8a authored by zhangting2020

tune algo only when dtype is float16

Parent 62eab2dc
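In short: inside the exhaustive-search branch of SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>, the previously commented-out cuDNN search path is restored for non-FP16 dtypes, and the new Find-plus-ChooseAlgo tuning runs only when dtype is float16. A minimal standalone sketch of that dtype gate follows, assuming plain cuDNN 7.x calls; PickBwdFilterAlgo and its first-fit selection are illustrative names only, while the real code goes through the algorithm cache and the workspace handle as shown in the diff below.

// sketch_bwd_filter_algo.cc -- illustrative only, not the Paddle implementation.
// Assumes cuDNN 7.x and already-created descriptors. The real code also caches
// the result and, for the non-FP16 branch, uses the *Ex search with a real
// workspace buffer.
#include <cudnn.h>
#include <vector>

cudnnConvolutionBwdFilterAlgo_t PickBwdFilterAlgo(
    cudnnHandle_t handle, cudnnDataType_t dtype,
    cudnnTensorDescriptor_t x_desc, cudnnTensorDescriptor_t dy_desc,
    cudnnConvolutionDescriptor_t conv_desc, cudnnFilterDescriptor_t dw_desc,
    size_t workspace_limit) {
  // Upper bound on the number of candidate algorithms (cuDNN >= 7).
  int max_algos = 0;
  cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(handle, &max_algos);

  std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> results(max_algos);
  int returned = 0;
  cudnnFindConvolutionBackwardFilterAlgorithm(
      handle, x_desc, dy_desc, conv_desc, dw_desc, max_algos, &returned,
      results.data());
  results.resize(returned);

  if (dtype != CUDNN_DATA_HALF) {
    // Non-FP16: keep cuDNN's own ranking and simply take the fastest result.
    return results.empty() ? CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1
                           : results.front().algo;
  }

  // FP16 only: tune the choice -- the fastest successful result that fits the
  // workspace limit (ChooseAlgo in the diff additionally prefers an equivalent
  // non-Tensor-Core result when it is within 1% of the Tensor Core time).
  for (const auto& r : results) {
    if (r.status == CUDNN_STATUS_SUCCESS && r.memory <= workspace_limit) {
      return r.algo;
    }
  }
  return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;  // conservative fallback
}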
@@ -91,7 +91,7 @@ std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
   return out;
 }
 
-inline int MaxBackwardFilterAlgos(cudnnHandle_t cudnn_handle) {
+inline int MaxBwdFilterAlgos(cudnnHandle_t cudnn_handle) {
   int max_algos = 0;
 #if CUDNN_VERSION_MIN(7, 0, 1)
   PADDLE_ENFORCE_CUDA_SUCCESS(
@@ -102,38 +102,23 @@ inline int MaxBackwardFilterAlgos(cudnnHandle_t cudnn_handle) {
 }
 
 template <typename PerfType, typename AlgoType>
-void AlgoFinalSelect(const std::vector<PerfType>& perf_results,
-                     std::string kernel_name, int32_t algo_preference,
-                     size_t workspace_byte,
-                     cudnnConvolutionBwdFilterAlgo_t* algo,
-                     bool deterministic) {
-  // Determine the fastest acceptable algo that matches the algo_preference (-1
-  // = any),
-  // regardless of mathType.
-  VLOG(3) << "=========Full results of algo=========" << kernel_name << ":";
+void ChooseAlgo(const std::vector<PerfType>& perf_results,
+                size_t workspace_byte, AlgoType* algo) {
+  VLOG(3) << "=========BwdFilterAlgo Perf result=========";
   for (const auto& result : perf_results) {
-    auto math_type_str = "-";
+    auto math_type_str = "0";
     if (result.mathType == CUDNN_TENSOR_OP_MATH) {
-      math_type_str = "+";
+      math_type_str = "1";
     }
-    VLOG(3) << "    algo: " << result.algo << ", TC" << math_type_str
+    VLOG(3) << "    algo: " << result.algo << ", TC: " << math_type_str
             << ", time: " << result.time << " ms"
             << ", wksp = " << result.memory << ", status = " << result.status;
   }
 
-  for (decltype(perf_results.size()) i = 0; i != perf_results.size(); ++i) {
+  for (size_t i = 0; i != perf_results.size(); ++i) {
     const auto& result = perf_results[i];
-    bool algo_is_tensor_core = false;
-    algo_is_tensor_core = result.mathType == CUDNN_TENSOR_OP_MATH;
-    bool algo_exclusion = 0;
     if (result.status == CUDNN_STATUS_SUCCESS &&
-        (!deterministic ||
-         result.determinism == cudnnDeterminism_t::CUDNN_DETERMINISTIC) &&
-        (result.memory <= workspace_byte) &&
-        (algo_preference == -1 || algo_preference == result.algo) &&
-        !algo_exclusion) {
+        (result.memory <= workspace_byte)) {
       if ((result.mathType == CUDNN_TENSOR_OP_MATH) &&
          (i != perf_results.size() - 1)) {
        const auto& next_result = perf_results[i + 1];
@@ -143,16 +128,17 @@ void AlgoFinalSelect(const std::vector<PerfType>& perf_results,
            next_result.mathType != CUDNN_TENSOR_OP_MATH &&
            next_result.time < 1.01 * result.time) {
          // Skip over this result- it's not really a Tensor Core algo.
-         // Prefer instead the next equivalent non-Tensor Core algo.
+         // Because it is only 1% performance difference.
+         // Prefer to choose the next equivalent non-Tensor Core algo.
          continue;
        }
      }
      *algo = result.algo;
-     auto math_type_str = "-";
+     auto math_type_str = "0";
      if (result.mathType == CUDNN_TENSOR_OP_MATH) {
-       math_type_str = "+";
+       math_type_str = "1";
      }
-     VLOG(3) << "    choose algo: " << result.algo << ", TC" << math_type_str
+     VLOG(3) << "    choose algo: " << result.algo << ", TC: " << math_type_str
             << ", time: " << result.time << " ms"
             << ", wksp = " << result.memory << ", status = " << result.status;
      return;
@@ -443,8 +429,6 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
                      bool deterministic,
                      const framework::ExecutionContext& ctx) {
     auto dtype = platform::CudnnDataType<T>::type;
-    // bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
-    bool exhaustive = exhaustive_search;
     size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
     size_t workspace_size = 0;
     bool has_got_workspace_size = true;
@@ -465,9 +449,8 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
 #endif
 
     algo_t algo;
-    if (!exhaustive && !deterministic) {
+    if (!exhaustive_search && !deterministic) {
 #if CUDNN_VERSION >= 7001
-      VLOG(3) << "=====Not exhaustive=====";
       using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
       int perf_count;
       int best_algo_idx = 0;
@@ -494,7 +477,6 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
     } else if (deterministic) {
       return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
     } else {
-      VLOG(3) << "=======exhaustive=======: " << exhaustive;
       auto& dev_ctx =
           ctx.template device_context<platform::CUDADeviceContext>();
       auto workspace_handle = dev_ctx.cudnn_workspace_handle();
@@ -507,7 +489,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
       VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t:"
                << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
                << args.s << ", args.p" << args.p << ", args.d" << args.d;
-      /*
+      if (dtype != CUDNN_DATA_HALF) {
       algo = algo_cache.GetAlgorithm(
           x_dims, w_dims, args.s, args.p, args.d, 0,
           static_cast<int64_t>(args.cudnn_dtype), [&]() {
@@ -525,9 +507,11 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
                     perf_stat.data(), cudnn_workspace_ptr,
                     workspace_size_limit));
           };
-          workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit);
+          workspace_handle.RunFuncSync(cudnn_find_func,
+                                       workspace_size_limit);
 
-          VLOG(3) << "BwdFilterAlgo Perf result: (algo: stat, time, memory)";
+          VLOG(3)
+              << "BwdFilterAlgo Perf result: (algo: stat, time, memory)";
           for (int i = 0; i < returned_algo_count; ++i) {
             const auto& stat = perf_stat[i];
             VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time
@@ -535,34 +519,28 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
           }
           return perf_stat[0].algo;
         });
-      */
+      } else {
+        auto max_algos = MaxBwdFilterAlgos(args.handle);
         algo = algo_cache.GetAlgorithm(
             x_dims, w_dims, args.s, args.p, args.d, 0,
             static_cast<int64_t>(args.cudnn_dtype), [&]() {
-              algo_t sel_algo;
-              auto max_bwd_filt_algos = MaxBackwardFilterAlgos(args.handle);
-              std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filt_results(
-                  max_bwd_filt_algos);
-              int actual_bwd_filter_algos = 0;
+              algo_t chosen_algo;
+              std::vector<perf_t> perf_results(max_algos);
+              int actual_algos = 0;
               PADDLE_ENFORCE_CUDA_SUCCESS(
-                  platform::dynload::cudnnFindConvolutionBackwardFilterAlgorithm(
+                  platform::dynload::
+                      cudnnFindConvolutionBackwardFilterAlgorithm(
                       args.handle, args.idesc.desc(), args.odesc.desc(),
                       args.cdesc.desc(), args.wdesc.desc(),
-                      bwd_filt_results.size(), &actual_bwd_filter_algos,
-                      bwd_filt_results.data()));
-              bwd_filt_results.resize(actual_bwd_filter_algos);
-              AlgoFinalSelect<cudnnConvolutionBwdFilterAlgoPerf_t,
-                              cudnnConvolutionBwdFilterAlgo_t>(
-                  bwd_filt_results, "backprop-to-filter", -1,
-                  workspace_size_limit, &sel_algo, deterministic);
-              workspace_size = GetWorkspaceSize(args, sel_algo);
-              if (workspace_size > workspace_size_limit) {
-                workspace_size = workspace_size_limit;
-              }
-              return sel_algo;
+                      perf_results.size(), &actual_algos,
+                      perf_results.data()));
+              perf_results.resize(actual_algos);
+              ChooseAlgo<perf_t, algo_t>(perf_results, workspace_size_limit,
+                                         &chosen_algo);
+              return chosen_algo;
             });
       }
+    }
     VLOG(3) << "choose algo " << algo;
     return algo;
   }
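The selection rule that ChooseAlgo applies for float16 can be read off the hunks above: walk the perf results in the order cuDNN returns them, take the first successful result whose workspace fits the limit, but skip a Tensor Core entry when the next entry is the same algorithm without Tensor Core math and its time is within 1% of the Tensor Core time. A condensed sketch of just that rule, with the VLOG reporting omitted (ChooseAlgoSketch is a stand-in name; field names follow cudnnConvolutionBwdFilterAlgoPerf_t):

#include <cudnn.h>
#include <vector>

// Condensed form of the ChooseAlgo selection rule from the diff above.
template <typename PerfType, typename AlgoType>
void ChooseAlgoSketch(const std::vector<PerfType>& perf_results,
                      size_t workspace_byte, AlgoType* algo) {
  for (size_t i = 0; i != perf_results.size(); ++i) {
    const auto& result = perf_results[i];
    if (result.status != CUDNN_STATUS_SUCCESS ||
        result.memory > workspace_byte) {
      continue;  // failed, or does not fit the workspace limit
    }
    if (result.mathType == CUDNN_TENSOR_OP_MATH &&
        i + 1 != perf_results.size()) {
      const auto& next = perf_results[i + 1];
      // A Tensor Core entry that is at most ~1% faster than the equivalent
      // non-Tensor-Core entry right after it: prefer the latter instead.
      if (next.status == CUDNN_STATUS_SUCCESS && next.algo == result.algo &&
          next.memory == result.memory &&
          next.mathType != CUDNN_TENSOR_OP_MATH &&
          next.time < 1.01 * result.time) {
        continue;
      }
    }
    *algo = result.algo;
    return;
  }
}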
......
@@ -336,11 +336,6 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
     int groups = ctx.Attr<int>("groups");
     bool exhaustive_search =
         FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
-    VLOG(3) << "=====exhaustive_search====: " << exhaustive_search;
-    VLOG(3) << "====FLAGS_cudnn_exhaustive_search====: "
-            << FLAGS_cudnn_exhaustive_search;
-    VLOG(3) << "====Attr: exhaustive_search====: "
-            << ctx.Attr<bool>("exhaustive_search");
     bool deterministic = FLAGS_cudnn_deterministic;
     if (exhaustive_search && deterministic) {
       PADDLE_THROW(
......
@@ -185,7 +185,8 @@ CUDNN_DNN_ROUTINE_EACH_R6(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
   __macro(cudnnCTCLoss);                                   \
   __macro(cudnnGetConvolutionBackwardDataAlgorithm_v7);    \
   __macro(cudnnGetConvolutionBackwardFilterAlgorithm_v7);  \
-  __macro(cudnnGetConvolutionForwardAlgorithm_v7);
+  __macro(cudnnGetConvolutionForwardAlgorithm_v7);         \
+  __macro(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount);
 
 CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
 #endif
@@ -195,8 +196,7 @@ CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
   __macro(cudnnBatchNormalizationForwardTrainingEx);              \
   __macro(cudnnGetBatchNormalizationBackwardExWorkspaceSize);     \
   __macro(cudnnBatchNormalizationBackwardEx);                     \
-  __macro(cudnnGetBatchNormalizationTrainingExReserveSpaceSize);  \
-  __macro(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount);
+  __macro(cudnnGetBatchNormalizationTrainingExReserveSpaceSize);
 
 CUDNN_DNN_ROUTINE_EACH_AFTER_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
 #endif
......