未验证 提交 8b853b30 编写于 作者: W wangchaochaohu 提交者: GitHub

fix the number of perf algo for conv cudnn in exhaustive mode (#28694)

上级 8c0ea4bf
...@@ -429,7 +429,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> { ...@@ -429,7 +429,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
x_dims, w_dims, args.s, args.p, args.d, 0, x_dims, w_dims, args.s, args.p, args.d, 0,
static_cast<int64_t>(args.cudnn_dtype), [&]() { static_cast<int64_t>(args.cudnn_dtype), [&]() {
int returned_algo_count; int returned_algo_count;
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat; std::array<perf_t, kNUM_CUDNN_BWD_DATA_ALGS> perf_stat;
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) { auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
...@@ -561,7 +561,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> { ...@@ -561,7 +561,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
x_dims, w_dims, args.s, args.p, args.d, 0, x_dims, w_dims, args.s, args.p, args.d, 0,
static_cast<int64_t>(args.cudnn_dtype), [&]() { static_cast<int64_t>(args.cudnn_dtype), [&]() {
int returned_algo_count; int returned_algo_count;
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat; std::array<perf_t, kNUM_CUDNN_BWD_FILTER_ALGS> perf_stat;
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) { auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload:: platform::dynload::
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册