/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/operators/conv_base_helper.h"
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"

namespace paddle {
namespace operators {

using ConvArgs = ConvArgsBase<cudnnHandle_t, cudnnDataType_t>;

template <typename DeviceContext, typename T, size_t D>
static void RemovePaddingSlice(const phi::GPUContext& context,
                               const Tensor* input, Tensor* out,
                               const std::vector<int>& starts,
                               const std::vector<int>& axes) {
  auto& place = *context.eigen_device();
  auto in_dims = input->dims();
  auto new_out_dims = out->dims();
  auto offsets = Eigen::DSizes<Eigen::DenseIndex, D>();
  auto extents = Eigen::DSizes<Eigen::DenseIndex, D>();
  for (size_t i = 0; i < D; ++i) {
    offsets[i] = 0;
    extents[i] = new_out_dims[i];
  }
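  // Shift the slice window along each sliced axis; a negative start counts
  // from the end of that dimension, as in Python-style slicing.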

  for (size_t i = 0; i < axes.size(); ++i) {
    int start = starts[i];
    if (start < 0) {
      start = (start + in_dims[axes[i]]);
    }
    start = std::max(start, 0);
    offsets[axes[i]] = start;
  }

  auto in_t =
      framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
          *input);
  auto out_t =
      framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
          *out, new_out_dims);

  phi::funcs::EigenSlice<std::decay_t<decltype(place)>, T, D>::Eval(
      place, out_t, in_t, offsets, extents);
}

static inline double ToMegaBytes(size_t bytes) {
  return static_cast<double>(bytes) / (1 << 20);
}

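// A non-negative FLAGS_conv_workspace_size_limit pins the cuDNN workspace to
// a fixed size (in MB); a negative value lets the algorithm search size the
// workspace from the memory actually available.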
static inline bool UseFixedWorkspace() {
  return FLAGS_conv_workspace_size_limit >= 0;
}

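// Estimate the largest workspace the algorithm search may use. With no fixed
// limit, take the most optimistic of: memory still available to allocate, the
// workspace already cached by the context, and the allocator's unused reserve
// (reserved - allocated).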
static size_t CalcWorkspaceLimitInBytes(const phi::GPUContext& ctx) {
  if (!UseFixedWorkspace()) {
    int device_id = platform::GetCurrentDeviceId();
    int64_t allocated = memory::StatGetCurrentValue("Allocated", device_id);
    int64_t reserved = memory::StatGetCurrentValue("Reserved", device_id);
    int64_t available = platform::GpuAvailableMemToAlloc();
    int64_t cur_workspace_size = ctx.cudnn_workspace_handle().WorkspaceSize();
    VLOG(3) << "[memory] allocated=" << ToMegaBytes(allocated)
            << " MB, reserved=" << ToMegaBytes(reserved)
            << " MB, available_to_alloc=" << ToMegaBytes(available)
            << " MB, current_workspace_size=" << ToMegaBytes(cur_workspace_size)
            << " MB.";
    return std::max(std::max(available, cur_workspace_size),
                    reserved - allocated);
  } else {
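    // The user-specified limit is in MB; convert it to bytes.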
    return FLAGS_conv_workspace_size_limit * 1024 * 1024;
  }
}

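// Format a batch of cuDNN perf results as a human-readable block for VLOG.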
template <typename PerfT>
std::string GetPerfResultString(std::string prefix,
                                const std::vector<PerfT>& perf_results,
                                int actual_algo_count, size_t workspace_limit) {
  std::ostringstream out;
  out << prefix << " (workspace limit=" << ToMegaBytes(workspace_limit)
      << " MB):\n";
  for (int i = 0; i < actual_algo_count; ++i) {
    const auto& result = perf_results[i];
    auto math_type_str = (result.mathType == CUDNN_TENSOR_OP_MATH) ? "T" : "F";
    out << "  algo=" << result.algo << ": tensor_core=" << math_type_str
        << ", time=" << result.time
        << " ms, memory=" << ToMegaBytes(result.memory)
        << " MB, status=" << result.status << "\n";
  }
  return out.str();
}

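// Pick the first perf result that succeeded and fits within the workspace
// limit. cuDNN returns perf results sorted by execution time, so this is the
// fastest algorithm that satisfies the memory constraint.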
template <typename PerfT, typename AlgoT>
void ChooseAlgoByWorkspace(const std::vector<PerfT>& perf_results,
                           size_t workspace_limit,
                           SearchResult<AlgoT>* algo_result) {
  for (size_t i = 0; i < perf_results.size(); ++i) {
    auto result = perf_results[i];
    if (result.status == CUDNN_STATUS_SUCCESS &&
        result.memory < workspace_limit) {
      algo_result->algo = result.algo;
      algo_result->time = result.time;
      algo_result->workspace_size = result.memory;
      VLOG(3) << "  algo=" << result.algo << ", time=" << result.time
              << " ms, memory=" << ToMegaBytes(result.memory)
              << " MB (limit=" << ToMegaBytes(workspace_limit)
              << " MB), status=" << result.status;
      return;
    }
  }
  VLOG(3) << "Cannot find an algorithm that requires memory < "
          << ToMegaBytes(workspace_limit) << " MB";
}
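
// Usage sketch (hypothetical, assuming `perf_results` was filled by one of
// the cudnnFind*AlgorithmEx calls below):
//
//   SearchResult<cudnnConvolutionFwdAlgo_t> chosen;
//   ChooseAlgoByWorkspace<cudnnConvolutionFwdAlgoPerf_t,
//                         cudnnConvolutionFwdAlgo_t>(
//       perf_results, workspace_size_limit, &chosen);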

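// Select the cuDNN math mode for this convolution: Tensor Core math for FP16
// on Volta (CC 7.0) and newer and for BF16 on Ampere (CC 8.0) and newer, FMA
// math for FP32 when TF32 is disabled, and the default math mode otherwise.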
static void SetConvMathType(const phi::GPUContext& ctx, cudnnDataType_t dtype,
                            const platform::ConvolutionDescriptor& cdesc) {
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
  if (ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
        cdesc.desc(), CUDNN_TENSOR_OP_MATH));
    VLOG(5) << "use cudnn_tensor_op_math";
#if CUDA_VERSION >= 11000
#if CUDNN_VERSION_MIN(8, 1, 0)
  } else if (ctx.GetComputeCapability() >= 80 && dtype == CUDNN_DATA_BFLOAT16) {
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
        cdesc.desc(), CUDNN_TENSOR_OP_MATH));
#endif  // CUDNN_VERSION_MIN(8, 1, 0)
  } else if (dtype == CUDNN_DATA_FLOAT && !cdesc.allow_tf32_) {
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
        cdesc.desc(), CUDNN_FMA_MATH));
#endif  // CUDA_VERSION >= 11000
  } else {
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
        cdesc.desc(), CUDNN_DEFAULT_MATH));
    VLOG(5) << "not using cudnn_tensor_op_math";
  }
#endif
}

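// Algorithm search for the convolution forward pass. Three strategies are
// supported: cuDNN's heuristic (the default), a fixed deterministic
// algorithm, and an exhaustive measured search whose winner is memoized in
// ConvSearchCache.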
template <>
struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
  using PerfT = cudnnConvolutionFwdAlgoPerf_t;
  using AlgoT = cudnnConvolutionFwdAlgo_t;

  template <typename T>
  static SearchResult<AlgoT> Find(const ConvArgs& args, bool exhaustive_search,
                                  bool deterministic,
                                  const phi::GPUContext& ctx) {
    SearchResult<AlgoT> result;
    auto dtype = platform::CudnnDataType<T>::type;
    size_t workspace_size_limit = CalcWorkspaceLimitInBytes(ctx);
    SetConvMathType(ctx, dtype, args.cdesc);

    if (!exhaustive_search && !deterministic) {
#if CUDNN_VERSION >= 7001
      int actual_perf_count;
      int best_algo_idx = 0;
      std::vector<PerfT> perf_results(kNUM_CUDNN_FWD_ALGS);
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnGetConvolutionForwardAlgorithm_v7(
              args.handle, args.idesc.desc(), args.wdesc.desc(),
              args.cdesc.desc(), args.odesc.desc(), kNUM_CUDNN_FWD_ALGS,
              &actual_perf_count, perf_results.data()));
      result.algo = perf_results[best_algo_idx].algo;
      result.workspace_size = perf_results[best_algo_idx].memory;

      if (result.workspace_size > workspace_size_limit) {
#if CUDNN_VERSION >= 8000
        // cudnnGetConvolutionForwardAlgorithm is removed in CUDNN-8
        ChooseAlgoByWorkspace<PerfT, AlgoT>(perf_results, workspace_size_limit,
                                            &result);
#else
        VLOG(3) << "Fallback to non-v7 method to find conv algorithm "
                   "because the workspace size request("
                << result.workspace_size << ") exceeds the limit("
                << workspace_size_limit << ")";
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cudnnGetConvolutionForwardAlgorithm(
                args.handle, args.idesc.desc(), args.wdesc.desc(),
                args.cdesc.desc(), args.odesc.desc(),
                CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
                workspace_size_limit, &(result.algo)));
#endif
      }
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnGetConvolutionForwardAlgorithm(
              args.handle, args.idesc.desc(), args.wdesc.desc(),
              args.cdesc.desc(), args.odesc.desc(),
              CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
              workspace_size_limit, &(result.algo)));
#endif
    } else if (deterministic) {
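      // Algo 1 is CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, a
      // deterministic algorithm that requires no search.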
      result.algo = static_cast<AlgoT>(1);
    } else {
      auto workspace_handle = ctx.cudnn_workspace_handle();
      auto x_dims = phi::vectorize(args.x->dims());
      auto w_dims = phi::vectorize(args.w->dims());
      VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t:"
               << " x_dims:" << x_dims << ", w_dims:" << w_dims
               << ", args.s:" << args.s << ", args.p:" << args.p
               << ", args.d:" << args.d;

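      // Memoize the exhaustive-search result, keyed on shapes, strides,
      // paddings, dilations, and dtype, so the costly measurement below runs
      // only once per distinct configuration.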
      AlgorithmsCache<AlgoT>& algo_cache =
          *(framework::ConvSearchCache::Instance().GetForward());

      result.algo = algo_cache.GetAlgorithm(
          x_dims, w_dims, args.s, args.p, args.d, 0,
          static_cast<int64_t>(args.cudnn_dtype), [&]() {
            int returned_algo_count;
            std::vector<PerfT> perf_results(kNUM_CUDNN_FWD_ALGS);
            size_t max_workspace_size =
                FindMaxWorkspaceSize(args, workspace_size_limit);
            VLOG(4) << "max_workspace_size=" << ToMegaBytes(max_workspace_size)
                    << " MB";

            auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
              PADDLE_ENFORCE_GPU_SUCCESS(
                  platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
                      args.handle, args.idesc.desc(), args.x->data<T>(),
                      args.wdesc.desc(), args.w->data<T>(), args.cdesc.desc(),
                      args.odesc.desc(), const_cast<T*>(args.o->data<T>()),
                      kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
                      perf_results.data(), cudnn_workspace_ptr,
                      max_workspace_size));
            };
            workspace_handle.RunFuncSync(cudnn_find_func, max_workspace_size,
                                         UseFixedWorkspace());

            VLOG(4) << GetPerfResultString<PerfT>(
                "[Exhaustive Search] FwdAlgo Perf result", perf_results,
                returned_algo_count, workspace_size_limit);
            result.time = perf_results[0].time;
            return perf_results[0].algo;
          });
    }
    VLOG(3) << "[cuDNN Convolution] exhaustive_search=" << exhaustive_search
            << ", deterministic=" << deterministic
            << ", choose algo=" << result.algo << ", workspace="
            << ToMegaBytes(GetWorkspaceSize(args, result.algo)) << " MB";
    return result;
  }

  static size_t GetWorkspaceSize(const ConvArgs& args,
                                 cudnnConvolutionFwdAlgo_t algo) {
    size_t workspace_size = 0;
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
            args.handle, args.idesc.desc(), args.wdesc.desc(),
            args.cdesc.desc(), args.odesc.desc(), algo, &workspace_size));
    return workspace_size;
  }

 private:
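  // Upper-bound the workspace over all candidate algorithms so that a single
  // allocation can serve every candidate tried during exhaustive search.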
  static size_t FindMaxWorkspaceSize(const ConvArgs& args,
                                     size_t workspace_size_limit) {
    if (!UseFixedWorkspace()) {
      size_t max_workspace_size = 0;
      for (size_t algo = 0; algo < kNUM_CUDNN_FWD_ALGS; ++algo) {
        size_t workspace_size = 0;
        auto status =
            platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
                args.handle, args.idesc.desc(), args.wdesc.desc(),
                args.cdesc.desc(), args.odesc.desc(),
                static_cast<cudnnConvolutionFwdAlgo_t>(algo), &workspace_size);
        if (status == CUDNN_STATUS_SUCCESS &&
            workspace_size <= workspace_size_limit) {
          max_workspace_size = std::max(workspace_size, max_workspace_size);
        }
      }
      return max_workspace_size;
    } else {
      return workspace_size_limit;
    }
  }
};
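
// Usage sketch (hypothetical, assuming `args` is fully populated and `ctx`
// is the current phi::GPUContext):
//
//   using Search = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
//   auto found = Search::Find<float>(args, /*exhaustive_search=*/false,
//                                    /*deterministic=*/false, ctx);
//   size_t workspace_size = Search::GetWorkspaceSize(args, found.algo);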

template <>
struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
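  // Backward-data algorithm search; mirrors the forward-pass logic above.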
  using PerfT = cudnnConvolutionBwdDataAlgoPerf_t;
  using AlgoT = cudnnConvolutionBwdDataAlgo_t;

  template <typename T>
  static SearchResult<AlgoT> Find(const ConvArgs& args, bool exhaustive_search,
                                  bool deterministic,
                                  const phi::GPUContext& ctx) {
    SearchResult<AlgoT> result;
    auto dtype = platform::CudnnDataType<T>::type;
    size_t workspace_size_limit = CalcWorkspaceLimitInBytes(ctx);
    SetConvMathType(ctx, dtype, args.cdesc);

    if (!exhaustive_search && !deterministic) {
#if CUDNN_VERSION >= 7001
      int actual_perf_count;
      int best_algo_idx = 0;
      std::vector<PerfT> perf_results(kNUM_CUDNN_BWD_DATA_ALGS);
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm_v7(
              args.handle, args.wdesc.desc(), args.odesc.desc(),
              args.cdesc.desc(), args.idesc.desc(), kNUM_CUDNN_BWD_DATA_ALGS,
              &actual_perf_count, perf_results.data()));
      result.algo = perf_results[best_algo_idx].algo;

#if CUDNN_VERSION < 7500
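      // Workaround for pre-7.5 cuDNN: the FFT-based backward-data algorithms
      // are unreliable for strided convolutions, so fall back to algo 1.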
      int stride_dim = args.x->dims().size() - 2;
      bool blacklist = std::any_of(args.s.begin(), args.s.begin() + stride_dim,
                                   [=](int n) { return n != 1; });
      if (blacklist && (perf_results[best_algo_idx].algo ==
                            CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING ||
                        perf_results[best_algo_idx].algo ==
                            CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) {
        result.algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
      }
#endif
      result.workspace_size = GetWorkspaceSize(args, result.algo);
      if (result.workspace_size > workspace_size_limit) {
#if CUDNN_VERSION >= 8000
        // cudnnGetConvolutionBackwardDataAlgorithm is removed in CUDNN-8
        ChooseAlgoByWorkspace<PerfT, AlgoT>(perf_results, workspace_size_limit,
                                            &result);
#else
        VLOG(1) << "Fallback to non-v7 method to find conv algorithm because "
                   "the workspace size request("
                << result.workspace_size << ") exceeds the limit("
                << workspace_size_limit << ")";
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
                args.handle, args.wdesc.desc(), args.odesc.desc(),
                args.cdesc.desc(), args.idesc.desc(),
                CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
                workspace_size_limit, &(result.algo)));
#endif
      }
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
              args.handle, args.wdesc.desc(), args.odesc.desc(),
              args.cdesc.desc(), args.idesc.desc(),
              CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
              workspace_size_limit, &(result.algo)));
#endif
    } else if (deterministic) {
      result.algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
    } else {
      auto workspace_handle = ctx.cudnn_workspace_handle();
      auto x_dims = phi::vectorize(args.x->dims());
      auto w_dims = phi::vectorize(args.w->dims());
      VLOG(10) << "cudnnConvolutionBwdDataAlgoPerf_t:"
               << " x_dims:" << x_dims << ", w_dims:" << w_dims
               << ", args.s:" << args.s << ", args.p:" << args.p
               << ", args.d:" << args.d;

      AlgorithmsCache<AlgoT>& algo_cache =
          *(framework::ConvSearchCache::Instance().GetBackwardData());
      result.algo = algo_cache.GetAlgorithm(
          x_dims, w_dims, args.s, args.p, args.d, 0,
          static_cast<int64_t>(args.cudnn_dtype), [&]() {
            int returned_algo_count;
            std::vector<PerfT> perf_results(kNUM_CUDNN_BWD_DATA_ALGS);
            size_t max_workspace_size =
                FindMaxWorkspaceSize(args, workspace_size_limit);
            VLOG(3) << "max_workspace_size=" << ToMegaBytes(max_workspace_size)
                    << " MB";

            auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
              PADDLE_ENFORCE_GPU_SUCCESS(
                  platform::dynload::
                      cudnnFindConvolutionBackwardDataAlgorithmEx(
                          args.handle, args.wdesc.desc(), args.w->data<T>(),
                          args.odesc.desc(), args.o->data<T>(),
                          args.cdesc.desc(), args.idesc.desc(),
                          const_cast<T*>(args.x->data<T>()),
                          kNUM_CUDNN_BWD_DATA_ALGS, &returned_algo_count,
                          perf_results.data(), cudnn_workspace_ptr,
                          max_workspace_size));
            };
            workspace_handle.RunFuncSync(cudnn_find_func, max_workspace_size,
                                         UseFixedWorkspace());

            VLOG(3) << GetPerfResultString<PerfT>(
                "[Exhaustive Search] BwdDataAlgo Perf result", perf_results,
                returned_algo_count, workspace_size_limit);
            result.time = perf_results[0].time;
            return perf_results[0].algo;
          });
    }
    VLOG(3) << "[cuDNN Convolution] exhaustive_search=" << exhaustive_search
            << ", deterministic=" << deterministic
            << ", choose algo=" << result.algo << ", workspace="
            << ToMegaBytes(GetWorkspaceSize(args, result.algo)) << " MB";
    return result;
  }

  static size_t GetWorkspaceSize(const ConvArgs& args,
                                 cudnnConvolutionBwdDataAlgo_t algo) {
    size_t workspace_size = 0;
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
            args.handle, args.wdesc.desc(), args.odesc.desc(),
            args.cdesc.desc(), args.idesc.desc(), algo, &workspace_size));
    return workspace_size;
  }

 private:
  static size_t FindMaxWorkspaceSize(const ConvArgs& args,
                                     size_t workspace_size_limit) {
    if (!UseFixedWorkspace()) {
      size_t max_workspace_size = 0;
      for (size_t algo = 0; algo < kNUM_CUDNN_BWD_DATA_ALGS; ++algo) {
        size_t workspace_size = 0;
        auto status =
            platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
                args.handle, args.wdesc.desc(), args.odesc.desc(),
                args.cdesc.desc(), args.idesc.desc(),
                static_cast<cudnnConvolutionBwdDataAlgo_t>(algo),
                &workspace_size);
        if (status == CUDNN_STATUS_SUCCESS &&
            workspace_size <= workspace_size_limit) {
          max_workspace_size = std::max(workspace_size, max_workspace_size);
        }
      }
      return max_workspace_size;
    } else {
      return workspace_size_limit;
    }
  }
};

template <>
struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
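  // Backward-filter algorithm search. Like the forward pass, except that
  // FP16 takes a separate exhaustive-search path (see the notes in Find).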
  using PerfT = cudnnConvolutionBwdFilterAlgoPerf_t;
  using AlgoT = cudnnConvolutionBwdFilterAlgo_t;

  template <typename T>
  static SearchResult<AlgoT> Find(const ConvArgs& args, bool exhaustive_search,
                                  bool deterministic,
                                  const phi::GPUContext& ctx) {
    platform::CUDAGraphCaptureModeGuard guard;
    SearchResult<AlgoT> result;
    auto dtype = platform::CudnnDataType<T>::type;
    size_t workspace_size_limit = CalcWorkspaceLimitInBytes(ctx);
    SetConvMathType(ctx, dtype, args.cdesc);

    if (!exhaustive_search && !deterministic) {
#if CUDNN_VERSION >= 7001
      int actual_perf_count;
      int best_algo_idx = 0;
      std::vector<PerfT> perf_results(kNUM_CUDNN_BWD_FILTER_ALGS);
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm_v7(
              args.handle, args.idesc.desc(), args.odesc.desc(),
              args.cdesc.desc(), args.wdesc.desc(), kNUM_CUDNN_BWD_FILTER_ALGS,
              &actual_perf_count, perf_results.data()));
      result.algo = perf_results[best_algo_idx].algo;
      result.workspace_size = perf_results[best_algo_idx].memory;

      if (result.workspace_size > workspace_size_limit) {
#if CUDNN_VERSION >= 8000
        // cudnnGetConvolutionBackwardFilterAlgorithm is removed in CUDNN-8
        ChooseAlgoByWorkspace<PerfT, AlgoT>(perf_results, workspace_size_limit,
                                            &result);
#else
        VLOG(1) << "Fallback to non-v7 method to find conv algorithm because "
                   "the workspace size request("
                << result.workspace_size << ") exceeds the limit("
                << workspace_size_limit << ")";
        PADDLE_ENFORCE_GPU_SUCCESS(
            platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
                args.handle, args.idesc.desc(), args.odesc.desc(),
                args.cdesc.desc(), args.wdesc.desc(),
                CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
                workspace_size_limit, &(result.algo)));
#endif
      }
#else
      PADDLE_ENFORCE_GPU_SUCCESS(
          platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
              args.handle, args.idesc.desc(), args.odesc.desc(),
              args.cdesc.desc(), args.wdesc.desc(),
              CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
              workspace_size_limit, &(result.algo)));
#endif
    } else if (deterministic) {
      result.algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
    } else {
      auto workspace_handle = ctx.cudnn_workspace_handle();
      auto x_dims = phi::vectorize(args.x->dims());
      auto w_dims = phi::vectorize(args.w->dims());
      VLOG(10) << "cudnnConvolutionBwdFilterAlgoPerf_t:"
               << " x_dims:" << x_dims << ", w_dims:" << w_dims
               << ", args.s:" << args.s << ", args.p:" << args.p
               << ", args.d:" << args.d;

      AlgorithmsCache<AlgoT>& algo_cache =
          *(framework::ConvSearchCache::Instance().GetBackwardFilter());

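      // FP16 uses the plain Find API plus ChooseAlgo below, which can skip
      // Tensor Core entries that show no real speedup; other dtypes use the
      // workspace-bounded Ex search.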
      if (dtype != CUDNN_DATA_HALF) {
        result.algo = algo_cache.GetAlgorithm(
            x_dims, w_dims, args.s, args.p, args.d, 0,
            static_cast<int64_t>(args.cudnn_dtype), [&]() {
              int returned_algo_count;
              std::vector<PerfT> perf_results(kNUM_CUDNN_BWD_FILTER_ALGS);
              size_t max_workspace_size =
                  FindMaxWorkspaceSize(args, workspace_size_limit);
              VLOG(3) << "max_workspace_size="
                      << ToMegaBytes(max_workspace_size) << " MB";

              auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
                PADDLE_ENFORCE_GPU_SUCCESS(
                    platform::dynload::
                        cudnnFindConvolutionBackwardFilterAlgorithmEx(
                            args.handle, args.idesc.desc(), args.x->data<T>(),
                            args.odesc.desc(), args.o->data<T>(),
                            args.cdesc.desc(), args.wdesc.desc(),
                            const_cast<T*>(args.w->data<T>()),
                            kNUM_CUDNN_BWD_FILTER_ALGS, &returned_algo_count,
                            perf_results.data(), cudnn_workspace_ptr,
                            max_workspace_size));
              };
              workspace_handle.RunFuncSync(cudnn_find_func, max_workspace_size,
                                           UseFixedWorkspace());

              VLOG(3) << GetPerfResultString<PerfT>(
                  "[Exhaustive Search] BwdFilterAlgo Perf result", perf_results,
                  returned_algo_count, workspace_size_limit);
              result.time = perf_results[0].time;
              return perf_results[0].algo;
            });
      } else {
        result.algo = algo_cache.GetAlgorithm(
            x_dims, w_dims, args.s, args.p, args.d, 0,
            static_cast<int64_t>(args.cudnn_dtype), [&]() {
              SearchResult<AlgoT> algo_result;
              int actual_algos = 0;
              std::vector<PerfT> perf_results(kNUM_CUDNN_BWD_FILTER_ALGS);

              PADDLE_ENFORCE_GPU_SUCCESS(
                  platform::dynload::
                      cudnnFindConvolutionBackwardFilterAlgorithm(
                          args.handle, args.idesc.desc(), args.odesc.desc(),
                          args.cdesc.desc(), args.wdesc.desc(),
                          perf_results.size(), &actual_algos,
                          perf_results.data()));
              perf_results.resize(actual_algos);
              ChooseAlgo(perf_results, workspace_size_limit, &algo_result);
              result.time = algo_result.time;
              return algo_result.algo;
            });
      }
    }
    VLOG(3) << "[cuDNN Convolution] exhaustive_search=" << exhaustive_search
            << ", deterministic=" << deterministic
            << ", choose algo=" << result.algo << ", workspace="
            << ToMegaBytes(GetWorkspaceSize(args, result.algo)) << " MB";
    return result;
  }

  static size_t GetWorkspaceSize(const ConvArgs& args,
                                 cudnnConvolutionBwdFilterAlgo_t algo) {
    platform::CUDAGraphCaptureModeGuard guard;
    size_t workspace_size = 0;
    PADDLE_ENFORCE_GPU_SUCCESS(
        platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
            args.handle, args.idesc.desc(), args.odesc.desc(),
            args.cdesc.desc(), args.wdesc.desc(), algo, &workspace_size));
    return workspace_size;
  }

 private:
  static size_t FindMaxWorkspaceSize(const ConvArgs& args,
                                     size_t workspace_size_limit) {
    if (!UseFixedWorkspace()) {
      size_t max_workspace_size = 0;
      for (size_t algo = 0; algo < kNUM_CUDNN_BWD_FILTER_ALGS; ++algo) {
        size_t workspace_size = 0;
        auto status =
            platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
                args.handle, args.idesc.desc(), args.odesc.desc(),
                args.cdesc.desc(), args.wdesc.desc(),
                static_cast<cudnnConvolutionBwdFilterAlgo_t>(algo),
                &workspace_size);
        if (status == CUDNN_STATUS_SUCCESS &&
            workspace_size <= workspace_size_limit) {
          max_workspace_size = std::max(workspace_size, max_workspace_size);
        }
      }
      return max_workspace_size;
    } else {
      return workspace_size_limit;
    }
  }

  static void ChooseAlgo(const std::vector<PerfT>& perf_results,
                         size_t workspace_limit,
                         SearchResult<AlgoT>* algo_result) {
    VLOG(3) << GetPerfResultString<PerfT>(
        "[Exhaustive Search] BwdFilterAlgo Perf result", perf_results,
        perf_results.size(), workspace_limit);

    for (size_t i = 0; i != perf_results.size(); ++i) {
      const auto& result = perf_results[i];
      if (result.status == CUDNN_STATUS_SUCCESS &&
          (result.memory <= workspace_limit)) {
        if ((result.mathType == CUDNN_TENSOR_OP_MATH) &&
            (i != perf_results.size() - 1)) {
          const auto& next_result = perf_results[i + 1];
          if (next_result.status == CUDNN_STATUS_SUCCESS &&
              next_result.algo == result.algo &&
              next_result.memory == result.memory &&
              next_result.mathType != CUDNN_TENSOR_OP_MATH &&
              next_result.time < 1.01 * result.time) {
            // Skip this result: the next entry is the same algorithm without
            // Tensor Core math and within 1% of its time, so Tensor Cores
            // bring no real benefit here. Prefer the non-Tensor-Core entry.
            continue;
          }
        }
        algo_result->algo = result.algo;
        algo_result->time = result.time;
        auto math_type_str = "0";
        if (result.mathType == CUDNN_TENSOR_OP_MATH) {
          math_type_str = "1";
        }
        VLOG(3) << "    choose algo: " << result.algo
                << ", TC: " << math_type_str << ", time: " << result.time
                << " ms, wksp = " << result.memory
                << ", status = " << result.status;
        break;
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle