提交 7deef68a 编写于 作者: Z zhangting2020

Use exhaustive_search for float16 in the cuDNN backward-filter algorithm search (no longer skipped when dtype is CUDNN_DATA_HALF)

上级 8ebcf948
......@@ -443,7 +443,8 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
bool deterministic,
const framework::ExecutionContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type;
bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
// bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
bool exhaustive = exhaustive_search;
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
size_t workspace_size = 0;
bool has_got_workspace_size = true;
......@@ -539,7 +540,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t:"
<< ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
<< args.s << ", args.p" << args.p << ", args.d" << args.d;
/*
algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0,
static_cast<int64_t>(args.cudnn_dtype), [&]() {
......@@ -567,7 +568,34 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
}
return perf_stat[0].algo;
});
*/
algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0,
static_cast<int64_t>(args.cudnn_dtype), [&]() {
algo_t sel_algo;
auto max_bwd_filt_algos = MaxBackwardFilterAlgos(args.handle);
std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filt_results(
max_bwd_filt_algos);
int actual_bwd_filter_algos = 0;
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnFindConvolutionBackwardFilterAlgorithm(
args.handle, args.idesc.desc(), args.odesc.desc(),
args.cdesc.desc(), args.wdesc.desc(),
bwd_filt_results.size(), &actual_bwd_filter_algos,
bwd_filt_results.data()));
bwd_filt_results.resize(actual_bwd_filter_algos);
AlgoFinalSelect<cudnnConvolutionBwdFilterAlgoPerf_t,
cudnnConvolutionBwdFilterAlgo_t>(
bwd_filt_results, "backprop-to-filter", -1,
workspace_size_limit, &sel_algo, deterministic);
workspace_size = GetWorkspaceSize(args, sel_algo);
if (workspace_size > workspace_size_limit) {
workspace_size = workspace_size_limit;
}
return sel_algo;
});
}
VLOG(3) << "choose algo " << algo;
return algo;
}
......
......@@ -1582,7 +1582,7 @@ def conv2d(input,
'use_mkldnn': False,
'fuse_relu_before_depthwise_conv': False,
"padding_algorithm": padding_algorithm,
"data_format": data_format,
"data_format": data_format
})
if data_format == 'NCHW':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册