/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/conv_cudnn_helper.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/cudnn_workspace_helper.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler.h"

DEFINE_bool(cudnn_deterministic, false,
            "Whether to allow the use of an autotuning algorithm for the "
            "convolution operator. The autotuning algorithm may be "
            "non-deterministic. If true, the algorithm is deterministic.");
DEFINE_uint64(conv_workspace_size_limit,
              paddle::platform::kDefaultConvWorkspaceSizeLimitMB,
              "cuDNN convolution workspace limit in MB unit.");
DEFINE_bool(cudnn_exhaustive_search, false,
            "Whether to enable exhaustive search for cuDNN convolution, "
            "default is False.");

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using DataLayout = platform::DataLayout;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
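// Note: cuDNN expects float (rather than half) scaling factors for float16
// tensors, so ScalingParamType<platform::float16> resolves to float.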
using framework::AlgorithmsCache;

static inline void GetNCDHW(const framework::DDim& dims,
                            const DataLayout& layout, int* N, int* C, int* D,
                            int* H, int* W) {
  *N = dims[0];
  *C = layout == DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1];
  int i = layout == DataLayout::kNCHW ? 0 : 1;
  if (dims.size() == 5) {
    *D = dims[2 - i];
    *H = dims[3 - i];
    *W = dims[4 - i];
  } else {
    *D = 1;
    *H = dims[2 - i];
    *W = dims[3 - i];
  }
}
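// For example (an illustrative sketch, not exercised elsewhere in this
// file): an NCHW input with dims {8, 3, 32, 32} yields N=8, C=3, D=1, H=32,
// W=32, while an NCDHW input with dims {8, 3, 16, 32, 32} yields N=8, C=3,
// D=16, H=32, W=32.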

template <typename T>
class CUDNNConvOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use CUDAPlace.");
    auto* input = ctx.Input<Tensor>("Input");
    auto* filter = ctx.Input<Tensor>("Filter");
    auto* output = ctx.Output<Tensor>("Output");

    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
    int groups = ctx.Attr<int>("groups");
    int64_t user_workspace_size =
        static_cast<int64_t>(ctx.Attr<int>("workspace_size_MB"));
    bool exhaustive_search =
        FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");

    const T* input_data = input->data<T>();
    const T* filter_data = filter->data<T>();
    T* output_data = output->mutable_data<T>(ctx.GetPlace());

    // ------------------- cudnn descriptors ---------------------
    ScopedTensorDescriptor input_desc;
    ScopedTensorDescriptor output_desc;
    ScopedFilterDescriptor filter_desc;
    ScopedConvolutionDescriptor conv_desc;
    DataLayout layout = DataLayout::kNCHW;
    if (input->dims().size() == 5) {
      layout = DataLayout::kNCDHW;
    }

    cudnnConvolutionDescriptor_t cudnn_conv_desc =
        conv_desc.descriptor<T>(paddings, strides, dilations);

#if CUDNN_VERSION_MIN(7, 0, 1)
武毅 已提交
105 106 107
    // cuDNN 7 supports groups natively, so there is no need to handle them
    // manually.
    // FIXME(typhoonzero): find a better way to disable groups
    // rather than setting it to 1.
    CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount(
        cudnn_conv_desc, groups));
    groups = 1;
#endif

    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
        layout, framework::vectorize2int(input->dims()), groups);
    cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
        layout, framework::vectorize2int(output->dims()), groups);
    cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
        layout, framework::vectorize2int(filter->dims()), groups);

    int i_n, i_c, i_d, i_h, i_w;
    GetNCDHW(input->dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w);
    int o_n, o_c, o_d, o_h, o_w;
    GetNCDHW(output->dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d, &o_h, &o_w);

    int group_offset_in = i_c / groups * i_h * i_w * i_d;
    int group_offset_out = o_c / groups * o_h * o_w * o_d;
    int group_offset_filter = filter->numel() / groups;
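    // Illustrative example (assumed shapes): with groups == 2 and an NCHW
    // input of [N, 64, H, W], each group convolves 32 channels, so the input
    // pointer of group i is advanced by i * 32 * H * W elements
    // (group_offset_in); the filter offset is simply filter->numel() / groups.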
    // ------------------- cudnn conv workspace ---------------------
    size_t workspace_size_in_bytes;  // final workspace to allocate.
    size_t workspace_size_limit = 0;
    if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) {
      int64_t max_user_size =
          std::min(static_cast<int64_t>(FLAGS_conv_workspace_size_limit),
                   user_workspace_size);
      workspace_size_limit = max_user_size * 1024 * 1024;
    }
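    // For example (assumed values): FLAGS_conv_workspace_size_limit == 4096
    // and workspace_size_MB == 512 give
    // workspace_size_limit = min(4096, 512) * 1024 * 1024 bytes, i.e. 512 MB.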

    // ------------------- cudnn conv algorithm ---------------------
    cudnnConvolutionFwdAlgo_t algo;
    bool half_float = false;

#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
    // Tensor Cores are available since the Volta GPUs and are only enabled
    // when the input and filter data are float16.
    if (dev_ctx.GetComputeCapability() >= 70 &&
        std::type_index(typeid(T)) ==
            std::type_index(typeid(platform::float16))) {
      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
          cudnn_conv_desc, CUDNN_TENSOR_OP_MATH));
      // Currently Tensor Cores are only enabled for this algorithm.
      algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
      half_float = true;
      VLOG(5) << "use cudnn_tensor_op_math";
    } else {
      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
          cudnn_conv_desc, CUDNN_DEFAULT_MATH));
      VLOG(5) << "NOT use cudnn_tensor_op_math";
    }
#endif
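    // When half_float is set above, Tensor Cores are enabled and the
    // algorithm is pinned to IMPLICIT_PRECOMP_GEMM, so both algorithm-search
    // branches below are skipped.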

    auto handle = dev_ctx.cudnn_handle();
    auto workspace_handle = dev_ctx.cudnn_workspace_handle();
    auto x_dims = framework::vectorize(input->dims());
    auto f_dims = framework::vectorize(filter->dims());

    // TODO(dangqingqing) simplify the following code by SearchAlgorithm in
    // conv_cudnn_helper.h
    if ((!exhaustive_search) && (!half_float)) {
#if CUDNN_VERSION >= 7001
      using perf_t = cudnnConvolutionFwdAlgoPerf_t;
      int perf_count;
      int best_algo_idx = 0;
      std::unique_ptr<perf_t[]> perf_results(new perf_t[kNUM_CUDNN_FWD_ALGS]);
      CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm_v7(
          handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
          cudnn_output_desc, kNUM_CUDNN_FWD_ALGS, &perf_count,
          perf_results.get()));
      algo = (perf_results.get())[best_algo_idx].algo;
#else
      CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
          handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
          cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
          workspace_size_limit, &algo));
#endif

      VLOG(3) << "cuDNN forward algo " << algo;
    } else if (exhaustive_search && (!half_float)) {
      AlgorithmsCache<cudnnConvolutionFwdAlgo_t>& algo_cache =
          ctx.GetKernelConfig<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>(0);

      algo = algo_cache.GetAlgorithm(
          x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
            int returned_algo_count;
            std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
                fwd_perf_stat;

            auto cudnn_find_func = [&](void* cudnn_workspace) {
              CUDNN_ENFORCE(
                  platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
                      handle, cudnn_input_desc, input_data, cudnn_filter_desc,
                      filter_data, cudnn_conv_desc, cudnn_output_desc,
                      output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
                      fwd_perf_stat.data(), cudnn_workspace,
                      workspace_size_limit));
            };
            workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit);

            VLOG(3) << "Perf result: (algo: stat, time, memory)";
            for (int i = 0; i < returned_algo_count; ++i) {
              const auto& stat = fwd_perf_stat[i];
              VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time
                      << " " << stat.memory;
            }
            return fwd_perf_stat[0].algo;
          });
      VLOG(3) << "choose algo " << algo;
    } else {
      PADDLE_ENFORCE(half_float,
                     "cuDNN exhaustive search doesn't support half float.");
    }
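    // The cache above keys on the input/filter dims, strides, paddings and
    // dilations, so the expensive cudnnFindConvolutionForwardAlgorithmEx
    // search runs once per unique shape configuration and its result is
    // reused on later calls.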

    // Get the workspace size required by the chosen algorithm.
    CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
        handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
        cudnn_output_desc, algo, &workspace_size_in_bytes));
    // It is possible for float16 on Volta GPUs to require more memory than
    // the limit because the algorithm is overridden to use Tensor Cores.
    PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
                      "workspace_size to be allocated exceeds the limit");

    // Allocate the workspace in GPU memory.
    Tensor cudnn_workspace =
        ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
            framework::make_ddim(
                {static_cast<int64_t>(workspace_size_in_bytes)}),
            dev_ctx);
    void* cudnn_workspace_ptr =
        static_cast<void*>(cudnn_workspace.data<int8_t>());
    VLOG(2) << "Cudnn workspace size fwd: "
            << static_cast<double>(workspace_size_in_bytes) / (1 << 20)
            << " MB";
    // ------------------- cudnn conv forward ---------------------
    ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
    for (int i = 0; i < groups; i++) {
      CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
          handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
          cudnn_filter_desc, filter_data + i * group_offset_filter,
          cudnn_conv_desc, algo, cudnn_workspace_ptr, workspace_size_in_bytes,
          &beta, cudnn_output_desc, output_data + i * group_offset_out));
    }
  }
};

template <typename T>
class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use CUDAPlace.");
    auto input = ctx.Input<Tensor>("Input");
    auto filter = ctx.Input<Tensor>("Filter");
    auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
    auto input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
    auto filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));

    const T* input_data = input->data<T>();
    const T* output_grad_data = output_grad->data<T>();
    const T* filter_data = filter->data<T>();

    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
    int groups = ctx.Attr<int>("groups");
    int64_t user_workspace_size =
        static_cast<int64_t>(ctx.Attr<int>("workspace_size_MB"));
    bool exhaustive_search =
        FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
    if (exhaustive_search && FLAGS_cudnn_deterministic) {
      PADDLE_THROW(
          "Cann't set exhaustive_search True and "
          "FLAGS_cudnn_deterministic True at same time.");
    }

    // ------------------- cudnn descriptors ---------------------
    ScopedTensorDescriptor input_desc;
    ScopedTensorDescriptor output_grad_desc;

    ScopedFilterDescriptor filter_desc;
    ScopedFilterDescriptor filter_grad_desc;
    ScopedConvolutionDescriptor conv_desc;
    DataLayout layout = DataLayout::kNCHW;
    if (input->dims().size() == 5) {
      layout = DataLayout::kNCDHW;
    }

    cudnnConvolutionDescriptor_t cudnn_conv_desc =
        conv_desc.descriptor<T>(paddings, strides, dilations);

#if CUDNN_VERSION_MIN(7, 0, 1)
武毅 已提交
301 302 303
    // cuDNN 7 supports groups natively, so there is no need to handle them
    // manually.
    // FIXME(typhoonzero): find a better way to disable groups
    // rather than setting it to 1.
    CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount(
        cudnn_conv_desc, groups));
    groups = 1;
#endif

    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
        layout, framework::vectorize2int(input->dims()), groups);
    cudnnTensorDescriptor_t cudnn_output_grad_desc =
        output_grad_desc.descriptor<T>(
            layout, framework::vectorize2int(output_grad->dims()), groups);
    cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
        layout, framework::vectorize2int(filter->dims()), groups);

#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
    // Enable Tensor Core for cudnn backward
    if (dev_ctx.GetComputeCapability() >= 70 &&
        std::type_index(typeid(T)) ==
            std::type_index(typeid(platform::float16))) {
      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
          cudnn_conv_desc, CUDNN_TENSOR_OP_MATH));
      VLOG(5) << "use cudnn_tensor_op_math for backward";
    } else {
      CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
          cudnn_conv_desc, CUDNN_DEFAULT_MATH));
      VLOG(5) << "NOT use cudnn_tensor_op_math for backward";
    }
#endif

    int i_n, i_c, i_d, i_h, i_w;
    GetNCDHW(input->dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w);
    int o_n, o_c, o_d, o_h, o_w;
    GetNCDHW(output_grad->dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d, &o_h,
             &o_w);

    int group_offset_in = i_c / groups * i_h * i_w * i_d;
    int group_offset_out = o_c / groups * o_h * o_w * o_d;
    int group_offset_filter = filter->numel() / groups;
    // ------------------- cudnn backward algorithm ---------------------
    cudnnConvolutionBwdDataAlgo_t data_algo;
    cudnnConvolutionBwdFilterAlgo_t filter_algo;
    size_t workspace_size_in_bytes = 0, tmp_size = 0;
    size_t workspace_size_limit = 0;
    if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) {
      int64_t max_user_size =
          std::min(static_cast<int64_t>(FLAGS_conv_workspace_size_limit),
                   user_workspace_size);
      workspace_size_limit = max_user_size * 1024 * 1024;
    }

    Tensor cudnn_workspace;
    void* cudnn_workspace_ptr = nullptr;
    if ((input_data || filter_data) && exhaustive_search) {
      cudnn_workspace =
          ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
              framework::make_ddim(
                  {static_cast<int64_t>(workspace_size_limit)}),
              dev_ctx);
      cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
    }
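    // The cudnnFind*AlgorithmEx calls below benchmark candidate algorithms
    // inside a caller-provided workspace, so for exhaustive search the full
    // workspace_size_limit is pre-allocated here; the heuristic path defers
    // allocation until the chosen algorithms report their required sizes.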

    // TODO(dangqingqing) simplify the following code by SearchAlgorithm in
    // conv_cudnn_helper.h
    auto x_dims = framework::vectorize(input->dims());
    auto f_dims = framework::vectorize(filter->dims());
    auto handle = dev_ctx.cudnn_handle();
    if (input_grad) {
      T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
      if (exhaustive_search) {
        AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>& data_algo_cache =
            ctx.GetKernelConfig<AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>(
                0);

        data_algo = data_algo_cache.GetAlgorithm(
            x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
              int returned_algo_count;
              std::array<cudnnConvolutionBwdDataAlgoPerf_t,
                         kNUM_CUDNN_BWD_DATA_ALGS>
                  data_perf_stat;

              CUDNN_ENFORCE(platform::dynload::
                                cudnnFindConvolutionBackwardDataAlgorithmEx(
                                    handle, cudnn_filter_desc, filter_data,
                                    cudnn_output_grad_desc, output_grad_data,
                                    cudnn_conv_desc, cudnn_input_desc,
                                    input_grad_data, kNUM_CUDNN_BWD_DATA_ALGS,
                                    &returned_algo_count, data_perf_stat.data(),
                                    cudnn_workspace_ptr, workspace_size_limit));

              VLOG(3) << "Perf result: (algo: stat, time, memory)";
              for (int i = 0; i < returned_algo_count; ++i) {
                const auto& stat = data_perf_stat[i];
                VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time
                        << " " << stat.memory;
              }
              return data_perf_stat[0].algo;
            });
        VLOG(3) << "cuDNN backward data algo " << data_algo;
      } else if (FLAGS_cudnn_deterministic) {
        data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
      } else {
#if CUDNN_VERSION >= 7001
        using perf_t = cudnnConvolutionBwdDataAlgoPerf_t;
        int perf_count;
        int best_algo_idx = 0;
        std::unique_ptr<perf_t[]> perf_results(
            new perf_t[kNUM_CUDNN_BWD_DATA_ALGS]);
        CUDNN_ENFORCE(
            platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm_v7(
                handle, cudnn_filter_desc,
                // dyDesc: Handle to the previously initialized input
                // differential
                // tensor descriptor.
                cudnn_output_grad_desc, cudnn_conv_desc,
                // dxDesc: Handle to the previously initialized output tensor
                // descriptor.
                cudnn_input_desc, kNUM_CUDNN_BWD_DATA_ALGS, &perf_count,
                perf_results.get()));
        data_algo = (perf_results.get())[best_algo_idx].algo;
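        // The fallback below rejects FFT-based algorithms when any spatial
        // stride exceeds 1, presumably because some cuDNN releases mis-handle
        // strided convolutions with them; this is inferred from the blacklist
        // itself rather than from cuDNN documentation.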
        int stride_dim = input->dims().size() - 2;
        bool blacklist =
            std::any_of(strides.begin(), strides.begin() + stride_dim,
                        [=](int n) { return n != 1; });
        if (blacklist && (static_cast<cudnnConvolutionBwdDataAlgo_t>(
                              perf_results[best_algo_idx].algo) ==
                              CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING ||
                          static_cast<cudnnConvolutionBwdDataAlgo_t>(
                              perf_results[best_algo_idx].algo) ==
                              CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) {
          data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
        }
#else
        CUDNN_ENFORCE(
            platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
                handle, cudnn_filter_desc,
                // dyDesc: Handle to the previously initialized input
                // differential
                // tensor descriptor.
                cudnn_output_grad_desc, cudnn_conv_desc,
                // dxDesc: Handle to the previously initialized output tensor
                // descriptor.
                cudnn_input_desc,
                CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
                workspace_size_limit, &data_algo));
#endif
      }
      CUDNN_ENFORCE(
          platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
              handle, cudnn_filter_desc, cudnn_output_grad_desc,
              cudnn_conv_desc, cudnn_input_desc, data_algo, &tmp_size));
      workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
    }

    if (filter_grad) {
      T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
      if (exhaustive_search) {
        AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>& f_algo_cache =
            ctx.GetKernelConfig<
                AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>(1);

        filter_algo = f_algo_cache.GetAlgorithm(
            x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
              int returned_algo_count;
              std::array<cudnnConvolutionBwdFilterAlgoPerf_t,
                         kNUM_CUDNN_BWD_FILTER_ALGS>
                  filter_perf_stat;

              CUDNN_ENFORCE(
                  platform::dynload::
                      cudnnFindConvolutionBackwardFilterAlgorithmEx(
                          handle, cudnn_input_desc, input_data,
                          cudnn_output_grad_desc, output_grad_data,
                          cudnn_conv_desc, cudnn_filter_desc, filter_grad_data,
                          kNUM_CUDNN_BWD_FILTER_ALGS, &returned_algo_count,
                          filter_perf_stat.data(), cudnn_workspace_ptr,
                          workspace_size_limit));
              return filter_perf_stat[0].algo;
            });
        VLOG(3) << "cuDNN backward filter algo " << filter_algo;
      } else if (FLAGS_cudnn_deterministic) {
        filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
      } else {
#if CUDNN_VERSION >= 7001
        using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
        int perf_count;
        int best_algo_idx = 0;
        std::unique_ptr<perf_t[]> perf_results(
            new perf_t[kNUM_CUDNN_BWD_FILTER_ALGS]);

        CUDNN_ENFORCE(
            platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm_v7(
                handle, cudnn_input_desc, cudnn_output_grad_desc,
                cudnn_conv_desc, cudnn_filter_desc, kNUM_CUDNN_BWD_FILTER_ALGS,
                &perf_count, perf_results.get()));
        filter_algo = (perf_results.get())[best_algo_idx].algo;
#else
        CUDNN_ENFORCE(
            platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
                handle, cudnn_input_desc, cudnn_output_grad_desc,
                cudnn_conv_desc, cudnn_filter_desc,
                CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
                workspace_size_limit, &filter_algo));
#endif
      }
      CUDNN_ENFORCE(
          platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
              handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc,
              cudnn_filter_desc, filter_algo, &tmp_size));
      workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
    }

    // ------------------- cudnn conv workspace ---------------------
    if (!cudnn_workspace_ptr) {
      cudnn_workspace =
          ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
              framework::make_ddim(
                  {static_cast<int64_t>(workspace_size_in_bytes)}),
              dev_ctx);
      cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
      VLOG(2) << "Cudnn workspace size bwd: "
              << static_cast<double>(workspace_size_in_bytes) / (1 << 20)
              << " MB";
    }

    // ------------------- cudnn conv backward data ---------------------
    ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
    if (input_grad) {
      T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
      // Because beta is zero, it is unnecessary to reset input_grad.

      for (int i = 0; i < groups; i++) {
        CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
            handle, &alpha, cudnn_filter_desc,
            filter_data + i * group_offset_filter, cudnn_output_grad_desc,
            output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo,
            cudnn_workspace_ptr, workspace_size_in_bytes, &beta,
            cudnn_input_desc, input_grad_data + i * group_offset_in));
      }
    }
    // ------------------- cudnn conv backward filter ---------------------
    if (filter_grad) {
      T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
      // Because beta is zero, it is unnecessary to reset filter_grad.
      for (int i = 0; i < groups; i++) {
        CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
            handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
            cudnn_output_grad_desc, output_grad_data + i * group_offset_out,
            cudnn_conv_desc, filter_algo, cudnn_workspace_ptr,
            workspace_size_in_bytes, &beta, cudnn_filter_desc,
            filter_grad_data + i * group_offset_filter));
      }
    }
  }
};

/*
 * Inputs:  I, W, dO, ddI, ddW
 * Outputs: ddO, dW, dI
 * Since O = conv(I, W) is bilinear in I and W, perturbing I by ddI and W by
 * ddW perturbs the output by
 *   ddO = conv(ddI, W) + conv(I, ddW)
 * and the remaining gradients follow from the standard backward passes:
 *   dW = conv_bp_filter(ddI, dO)
 *   dI = conv_bp_data(ddW, dO)
 */
template <typename T>
class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use CUDAPlace.");
    auto X = ctx.Input<Tensor>("Input");
    auto W = ctx.Input<Tensor>("Filter");
    auto dO = ctx.Input<Tensor>("DOutput");
    auto ddX = ctx.Input<Tensor>("DDInput");
    auto ddW = ctx.Input<Tensor>("DDFilter");

    auto ddO = ctx.Output<Tensor>("DDOutput");
    auto dW = ctx.Output<Tensor>("DFilter");
    auto dX = ctx.Output<Tensor>("DInput");

    const T* x = X->data<T>();
    const T* dy = dO->data<T>();
    const T* w = W->data<T>();

    const T* ddx = nullptr;
    const T* ddw = nullptr;
    T *dw, *dx, *ddy;
    dw = dx = ddy = nullptr;

    const std::vector<int>& strides = ctx.Attr<std::vector<int>>("strides");
    const std::vector<int>& paddings = ctx.Attr<std::vector<int>>("paddings");
    const std::vector<int>& dilations = ctx.Attr<std::vector<int>>("dilations");
    int groups = ctx.Attr<int>("groups");
    bool exhaustive_search =
        FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
    bool deterministic = FLAGS_cudnn_deterministic;
    if (exhaustive_search && deterministic) {
      PADDLE_THROW(
          "Cann't set exhaustive_search True and "
          "FLAGS_cudnn_deterministic True at same time.");
    }

    int iwo_group = groups;
    int c_group = 1;
#if CUDNN_VERSION_MIN(7, 0, 1)
    iwo_group = 1;
    c_group = groups;
#endif
    auto dtype = platform::CudnnDataType<T>::type;

    auto handle = dev_ctx.cudnn_handle();

    ConvArgs args1{ddX, W, ddO, strides, paddings, dilations};
    ConvArgs args2{X, ddW, ddO, strides, paddings, dilations};
    ConvArgs args3{ddX, dW, dO, strides, paddings, dilations};
    ConvArgs args4{dX, ddW, dO, strides, paddings, dilations};

    cudnnConvolutionFwdAlgo_t fwd_algo1 =
        static_cast<cudnnConvolutionFwdAlgo_t>(0);
    cudnnConvolutionFwdAlgo_t fwd_algo2 =
        static_cast<cudnnConvolutionFwdAlgo_t>(0);
    cudnnConvolutionBwdDataAlgo_t data_algo =
        static_cast<cudnnConvolutionBwdDataAlgo_t>(0);
    cudnnConvolutionBwdFilterAlgo_t filter_algo =
        static_cast<cudnnConvolutionBwdFilterAlgo_t>(0);

    auto layout = GetCudnnTensorFormat(DataLayout::kNCHW);

    // ddo = conv(ddI, W) + conv(I, ddW)
    size_t workspace_size = 0;
    if (ddO) {
      ddy = ddO->mutable_data<T>(ctx.GetPlace());
      args1.handle = handle;
      args1.idesc.set(*ddX, iwo_group);
      args1.wdesc.set(*W, layout, iwo_group);
      args1.odesc.set(*ddO, iwo_group);
      args1.cdesc.set(dtype, paddings, strides, dilations, c_group);

      using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
      fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, 0, ctx);
      workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1);

      if (ddW) {
        ddw = ddW->data<T>();
        args2.handle = handle;
        args2.idesc.set(*X, iwo_group);
        args2.wdesc.set(*ddW, layout, iwo_group);
        args2.odesc.set(*ddO, iwo_group);
        args2.cdesc.set(dtype, paddings, strides, dilations, c_group);

        using search2 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
        fwd_algo2 = search2::Find<T>(args2, exhaustive_search, false, 0, ctx);
        workspace_size = std::max(workspace_size,
                                  search2::GetWorkspaceSize(args2, fwd_algo2));
      }
    }

    if (dW) {
      dw = dW->mutable_data<T>(ctx.GetPlace());
      args3.handle = handle;
      args3.idesc.set(*ddX, iwo_group);
      args3.wdesc.set(*dW, layout, iwo_group);
      args3.odesc.set(*dO, iwo_group);
      args3.cdesc.set(dtype, paddings, strides, dilations, c_group);

      using search3 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
      filter_algo =
          search3::Find<T>(args3, exhaustive_search, deterministic, 1, ctx);
      workspace_size = std::max(workspace_size,
                                search3::GetWorkspaceSize(args3, filter_algo));
    }

    if (ddW && dX) {
      dx = dX->mutable_data<T>(ctx.GetPlace());
      args4.handle = handle;
      args4.idesc.set(*dX, iwo_group);
      args4.wdesc.set(*ddW, layout, iwo_group);
      args4.odesc.set(*dO, iwo_group);
      args4.cdesc.set(dtype, paddings, strides, dilations, c_group);

      using search4 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
      data_algo =
          search4::Find<T>(args4, exhaustive_search, deterministic, 2, ctx);
      workspace_size =
          std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo));
    }

    int i_n, i_c, i_d, i_h, i_w;
    GetNCDHW(X->dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w);
    int o_n, o_c, o_d, o_h, o_w;
    GetNCDHW(dO->dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d, &o_h, &o_w);

    int group_offset_in = i_c / groups * i_h * i_w * i_d;
    int group_offset_out = o_c / groups * o_h * o_w * o_d;
    int group_offset_filter = W->numel() / groups;

    ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
    auto wkspace_handle = dev_ctx.cudnn_workspace_handle();

    if (ddO) {
      ddx = ddX->data<T>();
      for (int i = 0; i < groups; i++) {
        wkspace_handle.RunFunc(
            [&](void* workspace_ptr) {
              CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
                  handle, &alpha, args1.idesc.desc(), ddx + i * group_offset_in,
                  args1.wdesc.desc(), w + i * group_offset_filter,
                  args1.cdesc.desc(), fwd_algo1, workspace_ptr, workspace_size,
                  &beta, args1.odesc.desc(), ddy + i * group_offset_out));
            },
            workspace_size);
      }
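      // Note: the second convolution below passes alpha (not beta) as the
      // output scaling factor, so its result accumulates into ddy, realizing
      // ddO = conv(ddI, W) + conv(I, ddW).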
      if (ddW) {
        for (int i = 0; i < groups; i++) {
          wkspace_handle.RunFunc(
              [&](void* workspace_ptr) {
                CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
                    handle, &alpha, args2.idesc.desc(), x + i * group_offset_in,
                    args2.wdesc.desc(), ddw + i * group_offset_filter,
                    args2.cdesc.desc(), fwd_algo2, workspace_ptr,
                    workspace_size, &alpha, args2.odesc.desc(),
                    ddy + i * group_offset_out));
              },
              workspace_size);
        }
      }
    }

    if (dW) {
      ddx = ddX->data<T>();
      for (int i = 0; i < groups; i++) {
        wkspace_handle.RunFunc(
            [&](void* workspace_ptr) {
              CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
                  handle, &alpha, args3.idesc.desc(), ddx + i * group_offset_in,
                  args3.odesc.desc(), dy + i * group_offset_out,
                  args3.cdesc.desc(), filter_algo, workspace_ptr,
                  workspace_size, &beta, args3.wdesc.desc(),
                  dw + i * group_offset_filter));
            },
            workspace_size);
      }
    }

    if (dX && ddW) {
      ddw = ddW->data<T>();
      for (int i = 0; i < groups; i++) {
        wkspace_handle.RunFunc(
            [&](void* workspace_ptr) {
              CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
                  handle, &alpha, args4.wdesc.desc(),
                  ddw + i * group_offset_filter, args4.odesc.desc(),
                  dy + i * group_offset_out, args4.cdesc.desc(), data_algo,
                  workspace_ptr, workspace_size, &beta, args4.idesc.desc(),
                  dx + i * group_offset_in));
            },
            workspace_size);
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace plat = paddle::platform;
REGISTER_OP_KERNEL(conv2d, CUDNN, plat::CUDAPlace,
                   paddle::operators::CUDNNConvOpKernel<float>,
                   paddle::operators::CUDNNConvOpKernel<double>,
                   paddle::operators::CUDNNConvOpKernel<plat::float16>);
REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace,
                   paddle::operators::CUDNNConvGradOpKernel<float>,
                   paddle::operators::CUDNNConvGradOpKernel<double>,
                   paddle::operators::CUDNNConvGradOpKernel<plat::float16>);
REGISTER_OP_KERNEL(
    conv2d_grad_grad, CUDNN, plat::CUDAPlace,
    paddle::operators::CUDNNConvDoubleGradOpKernel<float>,
    paddle::operators::CUDNNConvDoubleGradOpKernel<double>,
    paddle::operators::CUDNNConvDoubleGradOpKernel<plat::float16>);

REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace,
                   paddle::operators::CUDNNConvOpKernel<float>,
                   paddle::operators::CUDNNConvOpKernel<double>,
                   paddle::operators::CUDNNConvOpKernel<plat::float16>);
REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace,
                   paddle::operators::CUDNNConvGradOpKernel<float>,
                   paddle::operators::CUDNNConvGradOpKernel<double>);