conv_cudnn_helper.h 23.1 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#include <algorithm>
18
#include <array>
19
#include <memory>
Z
zhangting2020 已提交
20
#include <string>
Q
qingqing01 已提交
21
#include <vector>
22
#include "paddle/fluid/framework/conv_search_cache.h"
Q
qingqing01 已提交
23 24 25 26 27 28
#include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/platform/cudnn_desc.h"
namespace paddle {
namespace operators {

29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
// Names pulled into paddle::operators for the cuDNN convolution helpers.
using framework::AlgorithmsCache;
using DataLayout = platform::DataLayout;
using Tensor = framework::Tensor;
// Scaling-parameter type that platform::CudnnDataType<T> declares for T.
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
// Splits a 4-D or 5-D shape into N, C, D, H, W.
// kNCHW layouts keep C at index 1; any other layout keeps C last. For a 4-D
// shape there is no depth axis and *D is set to 1.
static inline void GetNCDHW(const framework::DDim& dims,
                            const DataLayout& layout, int* N, int* C, int* D,
                            int* H, int* W) {
  const bool channel_first = (layout == DataLayout::kNCHW);
  // Index of the first spatial axis: right after C (channel-first) or right
  // after N (channel-last).
  const int spatial_begin = channel_first ? 2 : 1;

  *N = dims[0];
  *C = channel_first ? dims[1] : dims[dims.size() - 1];
  if (dims.size() != 5) {
    *D = 1;
    *H = dims[spatial_begin];
    *W = dims[spatial_begin + 1];
    return;
  }
  *D = dims[spatial_begin];
  *H = dims[spatial_begin + 1];
  *W = dims[spatial_begin + 2];
}

// Copies into `out` the D-dimensional window of `input` whose extent equals
// out->dims(). Along each axis axes[k] the window begins at starts[k]; every
// other axis begins at 0. A negative start is taken relative to the size of
// that input axis and is then clamped at 0 (no upper clamp is applied here).
template <typename DeviceContext, typename T, size_t D>
static void RemovePaddingSlice(const framework::ExecutionContext& context,
                               const Tensor* input, Tensor* out,
                               const std::vector<int>& starts,
                               const std::vector<int>& axes) {
  using EigenT =
      framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>;
  auto& place =
      *context.template device_context<DeviceContext>().eigen_device();
  auto in_dims = input->dims();
  auto new_out_dims = out->dims();

  Eigen::array<int, D> offsets;
  Eigen::array<int, D> extents;
  for (size_t dim = 0; dim < D; ++dim) {
    offsets[dim] = 0;
    extents[dim] = new_out_dims[dim];
  }

  for (size_t k = 0; k < axes.size(); ++k) {
    const int axis = axes[k];
    int start_idx = starts[k];
    if (start_idx < 0) {
      start_idx = start_idx + in_dims[axis];
    }
    offsets[axis] = std::max(start_idx, 0);
  }

  auto in_t = EigenT::From(*input);
  auto out_t = EigenT::From(*out, new_out_dims);
  out_t.device(place) = in_t.slice(offsets, extents);
}

86 87 88 89 90 91 92 93
// Streams a vector as "[e0,e1,...,]" (every element is followed by a comma).
// Used for the VLOG diagnostics in this file.
template <typename T>
std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
  out << "[";
  std::for_each(v.begin(), v.end(),
                [&out](const T& elem) { out << elem << ","; });
  return out << "]";
}

Z
zhangting2020 已提交
94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162
// Returns how many backward-filter algorithms cuDNN can report for
// `cudnn_handle`. When built against cuDNN older than 7.0.1 the query is not
// available and 0 is returned.
inline int MaxBackwardFilterAlgos(cudnnHandle_t cudnn_handle) {
  int algo_count = 0;
#if CUDNN_VERSION_MIN(7, 0, 1)
  auto* count_out = &algo_count;
  PADDLE_ENFORCE_CUDA_SUCCESS(
      platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(
          cudnn_handle, count_out));
#endif
  return algo_count;
}

// Picks the first acceptable algorithm from a list of cuDNN perf results.
//
// perf_results    : candidates, scanned in the order given (presumably the
//                   fastest-first order cuDNN's Find* routines report).
// kernel_name     : label used only in the VLOG output.
// algo_preference : -1 accepts any algorithm, otherwise only that algo id.
// workspace_byte  : largest workspace an accepted candidate may require.
// algo            : receives the chosen algorithm; left untouched when no
//                   candidate qualifies, so callers should pre-initialize it.
// deterministic   : when true, only CUDNN_DETERMINISTIC candidates qualify.
//
// Returns true if a candidate was selected, false otherwise.
//
// A Tensor Core (CUDNN_TENSOR_OP_MATH) entry is skipped when the very next
// entry is the same algorithm with the same workspace, not using Tensor Core
// math, successful, and less than 1% slower; that entry is preferred instead.
template <typename PerfType, typename AlgoType>
bool AlgoFinalSelect(const std::vector<PerfType>& perf_results,
                     const std::string& kernel_name, int32_t algo_preference,
                     size_t workspace_byte, AlgoType* algo,
                     bool deterministic) {
  VLOG(3) << "=========Full results of algo=========" << kernel_name << ":";
  for (const auto& result : perf_results) {
    const char* math_type_str =
        result.mathType == CUDNN_TENSOR_OP_MATH ? "+" : "-";
    VLOG(3) << "    algo: " << result.algo << ", TC" << math_type_str
            << ", time: " << result.time << " ms"
            << ", wksp = " << result.memory << ", status = " << result.status;
  }

  for (size_t i = 0; i < perf_results.size(); ++i) {
    const auto& result = perf_results[i];
    const bool acceptable =
        result.status == CUDNN_STATUS_SUCCESS &&
        (!deterministic ||
         result.determinism == cudnnDeterminism_t::CUDNN_DETERMINISTIC) &&
        result.memory <= workspace_byte &&
        (algo_preference == -1 || algo_preference == result.algo);
    if (!acceptable) {
      continue;
    }

    const bool is_tensor_core = result.mathType == CUDNN_TENSOR_OP_MATH;
    if (is_tensor_core && i + 1 < perf_results.size()) {
      const auto& next_result = perf_results[i + 1];
      if (next_result.status == CUDNN_STATUS_SUCCESS &&
          next_result.algo == result.algo &&
          next_result.memory == result.memory &&
          next_result.mathType != CUDNN_TENSOR_OP_MATH &&
          next_result.time < 1.01 * result.time) {
        // Skip over this result- it's not really a Tensor Core algo.
        // Prefer instead the next equivalent non-Tensor Core algo.
        continue;
      }
    }

    *algo = result.algo;
    const char* math_type_str = is_tensor_core ? "+" : "-";
    VLOG(3) << "    choose algo: " << result.algo << ", TC" << math_type_str
            << ", time: " << result.time << " ms"
            << ", wksp = " << result.memory << ", status = " << result.status;
    return true;
  }

  VLOG(3) << "    no acceptable algo found for " << kernel_name
          << " within a workspace of " << workspace_byte << " bytes";
  return false;
}

163
using framework::ConvSearchCache;
Q
qingqing01 已提交
164 165 166 167 168 169 170

struct ConvArgs {
  cudnnHandle_t handle;
  platform::TensorDescriptor idesc, odesc;
  platform::FilterDescriptor wdesc;
  platform::ConvolutionDescriptor cdesc;
  const framework::Tensor *x, *w, *o;
171
  cudnnDataType_t cudnn_dtype;
Q
qingqing01 已提交
172 173 174 175 176 177 178 179 180 181

  // strides
  std::vector<int> s;
  // paddings
  std::vector<int> p;
  // dilations
  std::vector<int> d;

  ConvArgs(const framework::Tensor* x, const framework::Tensor* w,
           const framework::Tensor* o, const std::vector<int> s,
182 183 184
           const std::vector<int> p, const std::vector<int> d,
           cudnnDataType_t dtype)
      : x(x), w(w), o(o), s(s), p(p), d(d), cudnn_dtype(dtype) {}
Q
qingqing01 已提交
185 186 187 188 189 190 191 192 193 194 195 196
};

// Primary template, intentionally empty. The specializations for the cuDNN
// forward, backward-data and backward-filter perf types provide Find() and
// GetWorkspaceSize().
template <typename PerfT>
struct SearchAlgorithm {};

template <>
struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
  using perf_t = cudnnConvolutionFwdAlgoPerf_t;
  using algo_t = cudnnConvolutionFwdAlgo_t;

  template <typename T>
  static algo_t Find(const ConvArgs& args, bool exhaustive_search,
197
                     bool deterministic,
Q
qingqing01 已提交
198 199
                     const framework::ExecutionContext& ctx) {
    auto dtype = platform::CudnnDataType<T>::type;
200
    bool has_got_workspace_size = true;
Q
qingqing01 已提交
201 202
    bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
    size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
203
    size_t workspace_size = 0;
Q
qingqing01 已提交
204
    algo_t algo;
205 206 207 208

#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
209 210 211
      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
                                                         CUDNN_TENSOR_OP_MATH));
212 213
      VLOG(5) << "use cudnn_tensor_op_math";
    } else {
214 215 216
      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
                                                         CUDNN_DEFAULT_MATH));
217 218 219 220
      VLOG(5) << "NOT use cudnn_tensor_op_math";
    }
#endif

221
    if (!exhaustive && !deterministic) {
222 223 224 225
#if CUDNN_VERSION >= 7001
      int perf_count;
      int best_algo_idx = 0;
      std::unique_ptr<perf_t[]> perf_results(new perf_t[kNUM_CUDNN_FWD_ALGS]);
226 227 228 229 230
      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnGetConvolutionForwardAlgorithm_v7(
              args.handle, args.idesc.desc(), args.wdesc.desc(),
              args.cdesc.desc(), args.odesc.desc(), kNUM_CUDNN_FWD_ALGS,
              &perf_count, perf_results.get()));
231 232 233 234
      algo = (perf_results.get())[best_algo_idx].algo;
      workspace_size = GetWorkspaceSize(args, algo);

      if (workspace_size > workspace_size_limit) {
235
        workspace_size_limit = workspace_size;
236 237
      }
#else
238 239 240 241 242 243
      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnGetConvolutionForwardAlgorithm(
              args.handle, args.idesc.desc(), args.wdesc.desc(),
              args.cdesc.desc(), args.odesc.desc(),
              CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
              workspace_size_limit, &algo));
244
#endif
Q
qingqing01 已提交
245
      VLOG(3) << "choose algo " << algo;
246 247
    } else if (deterministic) {
      algo = static_cast<cudnnConvolutionFwdAlgo_t>(1);
Q
qingqing01 已提交
248 249 250 251 252
    } else {
      auto& dev_ctx =
          ctx.template device_context<platform::CUDADeviceContext>();
      auto workspace_handle = dev_ctx.cudnn_workspace_handle();

253 254
      auto& temp = ctx.cuda_device_context();
      AlgorithmsCache<algo_t>& algo_cache =
255
          *(framework::ConvSearchCache::Instance().GetForward());
256

Q
qingqing01 已提交
257 258 259
      auto x_dims = framework::vectorize(args.x->dims());
      auto w_dims = framework::vectorize(args.w->dims());

260 261 262
      VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t:"
               << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
               << args.s << ", args.p" << args.p << ", args.d" << args.d;
263

Q
qingqing01 已提交
264
      algo = algo_cache.GetAlgorithm(
265 266
          x_dims, w_dims, args.s, args.p, args.d, 0,
          static_cast<int64_t>(args.cudnn_dtype), [&]() {
Q
qingqing01 已提交
267 268 269 270
            int returned_algo_count;
            std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;

            auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
271
              PADDLE_ENFORCE_CUDA_SUCCESS(
Q
qingqing01 已提交
272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296
                  platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
                      args.handle, args.idesc.desc(), args.x->data<T>(),
                      args.wdesc.desc(), args.w->data<T>(), args.cdesc.desc(),
                      args.odesc.desc(), const_cast<T*>(args.o->data<T>()),
                      kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
                      perf_stat.data(), cudnn_workspace_ptr,
                      workspace_size_limit));
            };
            workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit);

            VLOG(3) << "FwdAlgo Perf result: (algo: stat, time, memory)";
            for (int i = 0; i < returned_algo_count; ++i) {
              const auto& stat = perf_stat[i];
              VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time
                      << " " << stat.memory;
            }
            return perf_stat[0].algo;
          });
    }
    VLOG(3) << "choose algo " << algo;
    return algo;
  }

  static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) {
    size_t workspace_size = 0;
297 298 299 300
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
            args.handle, args.idesc.desc(), args.wdesc.desc(),
            args.cdesc.desc(), args.odesc.desc(), algo, &workspace_size));
Q
qingqing01 已提交
301 302 303 304 305 306 307 308 309 310 311
    return workspace_size;
  }
};

template <>
struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
  using perf_t = cudnnConvolutionBwdDataAlgoPerf_t;
  using algo_t = cudnnConvolutionBwdDataAlgo_t;

  template <typename T>
  static algo_t Find(const ConvArgs& args, bool exhaustive_search,
312
                     bool deterministic,
Q
qingqing01 已提交
313 314 315 316
                     const framework::ExecutionContext& ctx) {
    auto dtype = platform::CudnnDataType<T>::type;
    bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
    size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
317 318
    size_t workspace_size = 0;
    bool has_got_workspace_size = true;
Q
qingqing01 已提交
319
    algo_t algo;
320 321 322 323

#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
324 325 326
      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
                                                         CUDNN_TENSOR_OP_MATH));
327 328
      VLOG(5) << "use cudnn_tensor_op_math";
    } else {
329 330 331
      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
                                                         CUDNN_DEFAULT_MATH));
332 333 334 335
      VLOG(5) << "NOT use cudnn_tensor_op_math";
    }
#endif

Q
qingqing01 已提交
336
    if (!exhaustive && !deterministic) {
337 338 339 340 341
#if CUDNN_VERSION >= 7001
      int perf_count;
      int best_algo_idx = 0;
      std::unique_ptr<perf_t[]> perf_results(
          new perf_t[kNUM_CUDNN_BWD_DATA_ALGS]);
342
      PADDLE_ENFORCE_CUDA_SUCCESS(
343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363
          platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm_v7(
              args.handle, args.wdesc.desc(), args.odesc.desc(),
              args.cdesc.desc(), args.idesc.desc(), kNUM_CUDNN_BWD_DATA_ALGS,
              &perf_count, perf_results.get()));
      algo = (perf_results.get())[best_algo_idx].algo;

#if CUDNN_VERSION < 7500
      int stride_dim = args.x->dims().size() - 2;
      bool blacklist = std::any_of(args.s.begin(), args.s.begin() + stride_dim,
                                   [=](int n) { return n != 1; });
      if (blacklist && (static_cast<cudnnConvolutionBwdDataAlgo_t>(
                            perf_results[best_algo_idx].algo) ==
                            CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING ||
                        static_cast<cudnnConvolutionBwdDataAlgo_t>(
                            perf_results[best_algo_idx].algo) ==
                            CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) {
        algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
      }
#endif
      workspace_size = GetWorkspaceSize(args, algo);
      if (workspace_size > workspace_size_limit) {
364
        workspace_size_limit = workspace_size;
365 366 367
        has_got_workspace_size = false;
      }
#else
368 369 370 371 372 373
      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
              args.handle, args.wdesc.desc(), args.odesc.desc(),
              args.cdesc.desc(), args.idesc.desc(),
              CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
              workspace_size_limit, &algo));
374
#endif
Q
qingqing01 已提交
375 376 377 378 379 380 381
    } else if (deterministic) {
      return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
    } else {
      auto& dev_ctx =
          ctx.template device_context<platform::CUDADeviceContext>();
      auto workspace_handle = dev_ctx.cudnn_workspace_handle();

382
      AlgorithmsCache<algo_t>& algo_cache =
383
          *(framework::ConvSearchCache::Instance().GetBackwardData());
384

Q
qingqing01 已提交
385 386 387
      auto x_dims = framework::vectorize(args.x->dims());
      auto w_dims = framework::vectorize(args.w->dims());

388 389 390
      VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t"
               << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
               << args.s << ", args.p" << args.p << ", args.d" << args.d;
391

Q
qingqing01 已提交
392
      algo = algo_cache.GetAlgorithm(
393 394
          x_dims, w_dims, args.s, args.p, args.d, 0,
          static_cast<int64_t>(args.cudnn_dtype), [&]() {
Q
qingqing01 已提交
395 396 397 398
            int returned_algo_count;
            std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;

            auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
399
              PADDLE_ENFORCE_CUDA_SUCCESS(
Q
qingqing01 已提交
400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427
                  platform::dynload::
                      cudnnFindConvolutionBackwardDataAlgorithmEx(
                          args.handle, args.wdesc.desc(), args.w->data<T>(),
                          args.odesc.desc(), args.o->data<T>(),
                          args.cdesc.desc(), args.idesc.desc(),
                          const_cast<T*>(args.x->data<T>()),
                          kNUM_CUDNN_BWD_DATA_ALGS, &returned_algo_count,
                          perf_stat.data(), cudnn_workspace_ptr,
                          workspace_size_limit));
            };
            workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit);

            VLOG(3) << "BwdDataAlgo Perf result: (algo: stat, time, memory)";
            for (int i = 0; i < returned_algo_count; ++i) {
              const auto& stat = perf_stat[i];
              VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time
                      << " " << stat.memory;
            }

            return perf_stat[0].algo;
          });
    }
    VLOG(3) << "choose algo " << algo;
    return algo;
  }

  static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) {
    size_t workspace_size = 0;
428
    PADDLE_ENFORCE_CUDA_SUCCESS(
Q
qingqing01 已提交
429
        platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
430 431
            args.handle, args.wdesc.desc(), args.odesc.desc(),
            args.cdesc.desc(), args.idesc.desc(), algo, &workspace_size));
Q
qingqing01 已提交
432 433 434 435 436 437 438 439 440 441 442
    return workspace_size;
  }
};

template <>
struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
  using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
  using algo_t = cudnnConvolutionBwdFilterAlgo_t;

  template <typename T>
  static algo_t Find(const ConvArgs& args, bool exhaustive_search,
443
                     bool deterministic,
Q
qingqing01 已提交
444 445
                     const framework::ExecutionContext& ctx) {
    auto dtype = platform::CudnnDataType<T>::type;
446 447
    // bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
    bool exhaustive = exhaustive_search;
Q
qingqing01 已提交
448
    size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
449 450 451 452 453 454
    size_t workspace_size = 0;
    bool has_got_workspace_size = true;

#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
455 456 457
      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
                                                         CUDNN_TENSOR_OP_MATH));
458 459
      VLOG(5) << "use cudnn_tensor_op_math";
    } else {
460 461 462
      PADDLE_ENFORCE_CUDA_SUCCESS(
          platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
                                                         CUDNN_DEFAULT_MATH));
463 464 465
      VLOG(5) << "NOT use cudnn_tensor_op_math";
    }
#endif
Q
qingqing01 已提交
466 467 468

    algo_t algo;
    if (!exhaustive && !deterministic) {
469
#if CUDNN_VERSION >= 7001
Z
zhangting2020 已提交
470
      VLOG(3) << "=====Not exhaustive=====";
471 472 473 474 475
      using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
      int perf_count;
      int best_algo_idx = 0;
      std::unique_ptr<perf_t[]> perf_results(
          new perf_t[kNUM_CUDNN_BWD_FILTER_ALGS]);
476
      PADDLE_ENFORCE_CUDA_SUCCESS(
477 478 479 480 481 482 483
          platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm_v7(
              args.handle, args.idesc.desc(), args.odesc.desc(),
              args.cdesc.desc(), args.wdesc.desc(), kNUM_CUDNN_BWD_FILTER_ALGS,
              &perf_count, perf_results.get()));
      algo = (perf_results.get())[best_algo_idx].algo;
      workspace_size = GetWorkspaceSize(args, algo);
      if (workspace_size > workspace_size_limit) {
484
        workspace_size = workspace_size_limit;
485 486
      }
#else
487
      PADDLE_ENFORCE_CUDA_SUCCESS(
Q
qingqing01 已提交
488 489 490 491 492
          platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
              args.handle, args.idesc.desc(), args.odesc.desc(),
              args.cdesc.desc(), args.wdesc.desc(),
              CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
              workspace_size_limit, &algo));
493
#endif
Q
qingqing01 已提交
494 495 496
    } else if (deterministic) {
      return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
    } else {
Z
zhangting2020 已提交
497
      VLOG(3) << "=======exhaustive=======: " << exhaustive;
Q
qingqing01 已提交
498 499 500
      auto& dev_ctx =
          ctx.template device_context<platform::CUDADeviceContext>();
      auto workspace_handle = dev_ctx.cudnn_workspace_handle();
501
      AlgorithmsCache<algo_t>& algo_cache =
502
          *(framework::ConvSearchCache::Instance().GetBackwardFilter());
Q
qingqing01 已提交
503 504 505 506

      auto x_dims = framework::vectorize(args.x->dims());
      auto w_dims = framework::vectorize(args.w->dims());

507 508 509
      VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t:"
               << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
               << args.s << ", args.p" << args.p << ", args.d" << args.d;
510
      /*
Q
qingqing01 已提交
511
      algo = algo_cache.GetAlgorithm(
512 513
          x_dims, w_dims, args.s, args.p, args.d, 0,
          static_cast<int64_t>(args.cudnn_dtype), [&]() {
Q
qingqing01 已提交
514 515 516
            int returned_algo_count;
            std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
            auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
517
              PADDLE_ENFORCE_CUDA_SUCCESS(
Q
qingqing01 已提交
518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537
                  platform::dynload::
                      cudnnFindConvolutionBackwardFilterAlgorithmEx(
                          args.handle, args.idesc.desc(), args.x->data<T>(),
                          args.odesc.desc(), args.o->data<T>(),
                          args.cdesc.desc(), args.wdesc.desc(),
                          const_cast<T*>(args.w->data<T>()),
                          kNUM_CUDNN_BWD_FILTER_ALGS, &returned_algo_count,
                          perf_stat.data(), cudnn_workspace_ptr,
                          workspace_size_limit));
            };
            workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit);

            VLOG(3) << "BwdFilterAlgo Perf result: (algo: stat, time, memory)";
            for (int i = 0; i < returned_algo_count; ++i) {
              const auto& stat = perf_stat[i];
              VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time
                      << " " << stat.memory;
            }
            return perf_stat[0].algo;
          });
538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563
      */
      algo = algo_cache.GetAlgorithm(
          x_dims, w_dims, args.s, args.p, args.d, 0,
          static_cast<int64_t>(args.cudnn_dtype), [&]() {
            algo_t sel_algo;
            auto max_bwd_filt_algos = MaxBackwardFilterAlgos(args.handle);
            std::vector<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filt_results(
                max_bwd_filt_algos);
            int actual_bwd_filter_algos = 0;
            PADDLE_ENFORCE_CUDA_SUCCESS(
                platform::dynload::cudnnFindConvolutionBackwardFilterAlgorithm(
                    args.handle, args.idesc.desc(), args.odesc.desc(),
                    args.cdesc.desc(), args.wdesc.desc(),
                    bwd_filt_results.size(), &actual_bwd_filter_algos,
                    bwd_filt_results.data()));
            bwd_filt_results.resize(actual_bwd_filter_algos);
            AlgoFinalSelect<cudnnConvolutionBwdFilterAlgoPerf_t,
                            cudnnConvolutionBwdFilterAlgo_t>(
                bwd_filt_results, "backprop-to-filter", -1,
                workspace_size_limit, &sel_algo, deterministic);
            workspace_size = GetWorkspaceSize(args, sel_algo);
            if (workspace_size > workspace_size_limit) {
              workspace_size = workspace_size_limit;
            }
            return sel_algo;
          });
Q
qingqing01 已提交
564
    }
565

Q
qingqing01 已提交
566 567 568 569 570 571
    VLOG(3) << "choose algo " << algo;
    return algo;
  }

  static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) {
    size_t workspace_size = 0;
572
    PADDLE_ENFORCE_CUDA_SUCCESS(
Q
qingqing01 已提交
573 574 575 576 577 578 579 580 581
        platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
            args.handle, args.idesc.desc(), args.odesc.desc(),
            args.cdesc.desc(), args.wdesc.desc(), algo, &workspace_size));
    return workspace_size;
  }
};

}  // namespace operators
}  // namespace paddle