提交 7deef68a 编写于 作者: Z zhangting2020

use exhaustive_search for float16

上级 8ebcf948
...@@ -443,7 +443,8 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> { ...@@ -443,7 +443,8 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
bool deterministic, bool deterministic,
const framework::ExecutionContext& ctx) { const framework::ExecutionContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type; auto dtype = platform::CudnnDataType<T>::type;
bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF); // bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
bool exhaustive = exhaustive_search;
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024; size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
size_t workspace_size = 0; size_t workspace_size = 0;
bool has_got_workspace_size = true; bool has_got_workspace_size = true;
...@@ -539,7 +540,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> { ...@@ -539,7 +540,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t:" VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t:"
<< ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s" << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
<< args.s << ", args.p" << args.p << ", args.d" << args.d; << args.s << ", args.p" << args.p << ", args.d" << args.d;
/*
algo = algo_cache.GetAlgorithm( algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0, x_dims, w_dims, args.s, args.p, args.d, 0,
static_cast<int64_t>(args.cudnn_dtype), [&]() { static_cast<int64_t>(args.cudnn_dtype), [&]() {
...@@ -567,7 +568,34 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> { ...@@ -567,7 +568,34 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
} }
return perf_stat[0].algo; return perf_stat[0].algo;
}); });
*/
algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0,
static_cast<int64_t>(args.cudnn_dtype), [&]() {
algo_t sel_algo;
auto max_bwd_filt_algos = MaxBackwardFilterAlgos(args.handle);
std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filt_results(
max_bwd_filt_algos);
int actual_bwd_filter_algos = 0;
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnFindConvolutionBackwardFilterAlgorithm(
args.handle, args.idesc.desc(), args.odesc.desc(),
args.cdesc.desc(), args.wdesc.desc(),
bwd_filt_results.size(), &actual_bwd_filter_algos,
bwd_filt_results.data()));
bwd_filt_results.resize(actual_bwd_filter_algos);
AlgoFinalSelect<cudnnConvolutionBwdFilterAlgoPerf_t,
cudnnConvolutionBwdFilterAlgo_t>(
bwd_filt_results, "backprop-to-filter", -1,
workspace_size_limit, &sel_algo, deterministic);
workspace_size = GetWorkspaceSize(args, sel_algo);
if (workspace_size > workspace_size_limit) {
workspace_size = workspace_size_limit;
}
return sel_algo;
});
} }
VLOG(3) << "choose algo " << algo; VLOG(3) << "choose algo " << algo;
return algo; return algo;
} }
......
...@@ -1582,7 +1582,7 @@ def conv2d(input, ...@@ -1582,7 +1582,7 @@ def conv2d(input,
'use_mkldnn': False, 'use_mkldnn': False,
'fuse_relu_before_depthwise_conv': False, 'fuse_relu_before_depthwise_conv': False,
"padding_algorithm": padding_algorithm, "padding_algorithm": padding_algorithm,
"data_format": data_format, "data_format": data_format
}) })
if data_format == 'NCHW': if data_format == 'NCHW':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册