未验证 提交 f0ee1592 编写于 作者: Z Zhang Ting 提交者: GitHub

enable exhaustive_search for forward and backward algos when dtype is float16 (#30959)

* enable exhaustive_search for input_grad when dtype is float16

* enable exhaustive_search for forward algos
上级 9b54fe41
...@@ -203,7 +203,6 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> { ...@@ -203,7 +203,6 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
const framework::ExecutionContext& ctx) { const framework::ExecutionContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type; auto dtype = platform::CudnnDataType<T>::type;
bool has_got_workspace_size = true; bool has_got_workspace_size = true;
bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024; size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
size_t workspace_size = 0; size_t workspace_size = 0;
algo_t algo; algo_t algo;
...@@ -227,7 +226,7 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> { ...@@ -227,7 +226,7 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
} }
#endif #endif
if (!exhaustive && !deterministic) { if (!exhaustive_search && !deterministic) {
#if CUDNN_VERSION >= 7001 #if CUDNN_VERSION >= 7001
int perf_count; int perf_count;
int best_algo_idx = 0; int best_algo_idx = 0;
...@@ -337,7 +336,6 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> { ...@@ -337,7 +336,6 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
bool deterministic, bool deterministic,
const framework::ExecutionContext& ctx) { const framework::ExecutionContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type; auto dtype = platform::CudnnDataType<T>::type;
bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024; size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
size_t workspace_size = 0; size_t workspace_size = 0;
bool has_got_workspace_size = true; bool has_got_workspace_size = true;
...@@ -361,7 +359,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> { ...@@ -361,7 +359,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
} }
#endif #endif
if (!exhaustive && !deterministic) { if (!exhaustive_search && !deterministic) {
#if CUDNN_VERSION >= 7001 #if CUDNN_VERSION >= 7001
int perf_count; int perf_count;
int best_algo_idx = 0; int best_algo_idx = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册