hl_cuda_cudnn.cc 44.3 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Z
zhangjinchao01 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yu Yang 已提交
15
#include "hl_cuda_cudnn.h"
Z
zhangjinchao01 已提交
16
#include <cudnn.h>
L
liaogang 已提交
17
#include <gflags/gflags.h>
Z
zhangjinchao01 已提交
18
#include "hl_cuda_cudnn.ph"
Y
Yu Yang 已提交
19
#include "hl_thread.ph"
L
liaogang 已提交
20
#include "paddle/utils/DynamicLoader.h"
Y
Yu Yang 已提交
21
#include "paddle/utils/Logging.h"
22

23 24 25 26
// Upper bound (in MB) on the cuDNN convolution workspace. hl_conv_workspace
// converts it to bytes and passes it as the limit for the
// *_SPECIFY_WORKSPACE_LIMIT algorithm-selection queries.
DEFINE_int32(cudnn_conv_workspace_limit_in_mb,
             4096,
             "Specify cuDNN max workspace limit, in units MB, "
             "4096MB=4GB by default.");
Z
zhangjinchao01 已提交
27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42

namespace dynload {

std::once_flag cudnn_dso_flag;
void* cudnn_dso_handle = nullptr;

/**
 * The following macro definitions generate one functor struct (plus one
 * global instance named after the routine) per cuDNN function. Invoking the
 * instance through operator() forwards to the real cuDNN routine.
 *
 * With PADDLE_USE_DSO, libcudnn is opened once (std::call_once with
 * GetCudnnDsoHandle) and every call resolves its symbol with dlsym. A symbol
 * that cannot be resolved aborts with a descriptive CHECK instead of calling
 * through a null function pointer. Without PADDLE_USE_DSO the functor simply
 * calls the directly linked routine.
 **/

#ifdef PADDLE_USE_DSO

#define DYNAMIC_LOAD_CUDNN_WRAP(__name)                                     \
  struct DynLoad__##__name {                                                \
    template <typename... Args>                                             \
    auto operator()(Args... args) -> decltype(__name(args...)) {            \
      using cudnn_func = decltype(__name(args...)) (*)(Args...);            \
      std::call_once(cudnn_dso_flag, GetCudnnDsoHandle, &cudnn_dso_handle); \
      void* p_##__name = dlsym(cudnn_dso_handle, #__name);                  \
      CHECK(p_##__name != nullptr)                                          \
          << "Failed to find cuDNN routine " #__name " in libcudnn";        \
      return reinterpret_cast<cudnn_func>(p_##__name)(args...);             \
    }                                                                       \
  } __name; /* struct DynLoad__##__name */

#else

#define DYNAMIC_LOAD_CUDNN_WRAP(__name)                          \
  struct DynLoad__##__name {                                     \
    template <typename... Args>                                  \
    auto operator()(Args... args) -> decltype(__name(args...)) { \
      return __name(args...);                                    \
    }                                                            \
  } __name; /* struct DynLoad__##__name */

#endif

/**
 * include all needed cudnn functions in HPPL
 * different cudnn version has different interfaces
 **/
// clang-format off
#define CUDNN_DNN_ROUTINE_EACH(__macro)                   \
  __macro(cudnnSetTensor4dDescriptor)                     \
  __macro(cudnnSetTensor4dDescriptorEx)                   \
  __macro(cudnnGetConvolutionNdForwardOutputDim)          \
  __macro(cudnnGetConvolutionForwardAlgorithm)            \
  __macro(cudnnCreateTensorDescriptor)                    \
  __macro(cudnnDestroyTensorDescriptor)                   \
  __macro(cudnnCreateFilterDescriptor)                    \
  __macro(cudnnSetFilter4dDescriptor)                     \
  __macro(cudnnSetPooling2dDescriptor)                    \
  __macro(cudnnDestroyFilterDescriptor)                   \
  __macro(cudnnCreateConvolutionDescriptor)               \
  __macro(cudnnCreatePoolingDescriptor)                   \
  __macro(cudnnDestroyPoolingDescriptor)                  \
  __macro(cudnnSetConvolution2dDescriptor)                \
  __macro(cudnnDestroyConvolutionDescriptor)              \
  __macro(cudnnCreate)                                    \
  __macro(cudnnDestroy)                                   \
  __macro(cudnnSetStream)                                 \
  __macro(cudnnActivationForward)                         \
  __macro(cudnnConvolutionForward)                        \
  __macro(cudnnConvolutionBackwardBias)                   \
  __macro(cudnnGetConvolutionForwardWorkspaceSize)        \
  __macro(cudnnTransformTensor)                           \
  __macro(cudnnPoolingForward)                            \
  __macro(cudnnPoolingBackward)                           \
  __macro(cudnnSoftmaxBackward)                           \
  __macro(cudnnSoftmaxForward)                            \
  __macro(cudnnGetVersion)                                \
  __macro(cudnnGetErrorString)
CUDNN_DNN_ROUTINE_EACH(DYNAMIC_LOAD_CUDNN_WRAP)

#define CUDNN_DNN_ROUTINE_EACH_R2(__macro)                \
  __macro(cudnnAddTensor)                                 \
  __macro(cudnnConvolutionBackwardData)                   \
  __macro(cudnnConvolutionBackwardFilter)
CUDNN_DNN_ROUTINE_EACH_R2(DYNAMIC_LOAD_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH_R2

// APIs available after R3:
#if CUDNN_VERSION >= 3000
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R3(__macro)              \
  __macro(cudnnGetConvolutionBackwardFilterWorkspaceSize)     \
  __macro(cudnnGetConvolutionBackwardDataAlgorithm)           \
  __macro(cudnnGetConvolutionBackwardFilterAlgorithm)         \
  __macro(cudnnGetConvolutionBackwardDataWorkspaceSize)
CUDNN_DNN_ROUTINE_EACH_AFTER_R3(DYNAMIC_LOAD_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R3
#endif


// APIs available after R4:
#if CUDNN_VERSION >= 4007
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R4(__macro)             \
  __macro(cudnnBatchNormalizationForwardTraining)            \
  __macro(cudnnBatchNormalizationForwardInference)           \
  __macro(cudnnBatchNormalizationBackward)
CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DYNAMIC_LOAD_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R4
#endif

// APIs in R5
#if CUDNN_VERSION >= 5000
#define CUDNN_DNN_ROUTINE_EACH_R5(__macro)                    \
  __macro(cudnnCreateActivationDescriptor)                    \
  __macro(cudnnSetActivationDescriptor)                       \
  __macro(cudnnGetActivationDescriptor)                       \
  __macro(cudnnDestroyActivationDescriptor)
CUDNN_DNN_ROUTINE_EACH_R5(DYNAMIC_LOAD_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH_R5
#endif

#undef CUDNN_DNN_ROUTINE_EACH
// clang-format on
} /* namespace dynload */

/**
 * Evaluate a cuDNN call and abort through glog's CHECK_EQ unless it returns
 * CUDNN_STATUS_SUCCESS. The failure message includes the text reported by
 * cudnnGetErrorString. The expansion is wrapped in do { } while (0), so
 * callers cannot append extra `<<` context to it.
 */
#define CHECK_CUDNN(cudnnFunc)                                \
  do {                                                        \
    cudnnStatus_t cudnn_check_status_ = (cudnnFunc);          \
    CHECK_EQ(CUDNN_STATUS_SUCCESS, cudnn_check_status_)       \
        << "Cudnn Error: "                                    \
        << dynload::cudnnGetErrorString(cudnn_check_status_); \
  } while (0)
Z
zhangjinchao01 已提交
156 157 158 159

// Set to true by hl_cudnn_init after the cuDNN handle has been created and
// bound to its stream.
bool g_is_libcudnn_init = false;
// Version reported by dynload::cudnnGetVersion() during hl_cudnn_init and
// returned by hl_get_cudnn_lib_version(); stays 0 until initialization.
int g_cudnn_lib_version = 0;

160 161
/**
 * Create a raw cuDNN tensor descriptor and store it through `cudnn_desc`.
 * Any cuDNN failure is fatal (CHECK_CUDNN).
 */
void hl_cudnn_desc_init(cudnnTensorDescriptor_t* cudnn_desc) {
  cudnnTensorDescriptor_t* const target = cudnn_desc;
  CHECK_CUDNN(dynload::cudnnCreateTensorDescriptor(target));
}

164 165 166 167 168 169 170 171 172 173 174 175 176
/**
 * Initialize cuDNN: verify header/library compatibility, then create a
 * cuDNN handle bound to `stream`.
 *
 * The major version of the loaded libcudnn (cudnnGetVersion() / 1000) must
 * equal the major version of the cudnn.h used at build time. Majors below 4
 * are treated as mutually compatible. The CUDA toolkit must also satisfy the
 * minimum required by the cuDNN header (v5 needs CUDA >= 7.5, v6+ needs
 * CUDA >= 8.0). Every violation is fatal.
 *
 * @param cudnn_handle  output; receives the newly created cuDNN handle.
 * @param stream        CUDA stream the handle is bound to.
 */
void hl_cudnn_init(cudnnHandle_t* cudnn_handle, cudaStream_t stream) {
  CHECK_NOTNULL(cudnn_handle);

  size_t cudnn_dso_ver = dynload::cudnnGetVersion();
  size_t cudnn_dso_major = cudnn_dso_ver / 1000;
  size_t cudnn_cuh_major = CUDNN_VERSION / 1000;

  // Compare cudnn header version with that of cudnn.so.
  CHECK((cudnn_cuh_major < 4 && cudnn_dso_major < 4) ||
        (cudnn_cuh_major == cudnn_dso_major))
      << "[cudnn init] libcudnn v" << cudnn_dso_major << " with header v"
      << cudnn_cuh_major << " unmatched!\n"
      << "PaddlePaddle Requirement: "
      << "(header v[2-3] with libcudnn v[2-3]) Or "
      << "(header v4 with libcudnn v4) Or "
      << "(header v5 with libcudnn v5) Or "
      << "(header v6 with libcudnn v6).";

  CHECK(!(CUDNN_VERSION < 6000 && CUDNN_VERSION >= 5000 && CUDA_VERSION < 7050))
      << "cudnn v5 requires cuda version >= 7.5";

  CHECK(!(CUDNN_VERSION >= 6000 && CUDA_VERSION < 8000))
      << "cudnn v6 requires cuda version >= 8.0";

  CHECK_CUDNN(dynload::cudnnCreate(cudnn_handle));
  CHECK_CUDNN(dynload::cudnnSetStream(*cudnn_handle, stream));

  g_is_libcudnn_init = true;
  // cudnnGetVersion() returns size_t; the stored version is an int.
  g_cudnn_lib_version = static_cast<int>(cudnn_dso_ver);
}

193
/** Return the libcudnn version recorded by hl_cudnn_init (0 before init). */
int hl_get_cudnn_lib_version() {
  const int recorded_version = g_cudnn_lib_version;
  return recorded_version;
}
Z
zhangjinchao01 已提交
194 195 196 197 198 199 200 201 202 203

void hl_conv_workspace(hl_tensor_descriptor input,
                       hl_tensor_descriptor output,
                       hl_filter_descriptor filter,
                       hl_convolution_descriptor conv,
                       int* convFwdAlgo,
                       size_t* fwdLimitBytes,
                       int* convBwdDataAlgo,
                       size_t* bwdDataLimitBytes,
                       int* convBwdFilterAlgo,
204 205
                       size_t* bwdFilterLimitBytes,
                       bool useDilation) {
Z
zhangjinchao01 已提交
206 207
#if CUDNN_VERSION >= 4000

208 209 210 211 212 213 214 215 216
  CHECK_NOTNULL(input);
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(filter);
  CHECK_NOTNULL(conv);

  // Specify workspace limit directly
  size_t memoryLimitBytes =
      (1LL << 20) * FLAGS_cudnn_conv_workspace_limit_in_mb;

217 218 219
  // For dilation
  int algo = 0;

220 221 222 223 224
  // cudnn convolution forward configuration
  cudnnTensorDescriptor_t fwd_src_desc = GET_TENSOR_DESCRIPTOR(input);
  cudnnTensorDescriptor_t fwd_dest_desc = GET_TENSOR_DESCRIPTOR(output);
  cudnnFilterDescriptor_t fwd_filter_desc = GET_FILTER_DESCRIPTOR(filter);
  cudnnConvolutionDescriptor_t fwd_conv_desc = GET_CONVOLUTION_DESCRIPTOR(conv);
225 226 227 228 229 230 231 232 233 234 235 236
  // cudnn convolution backward data configuration
  cudnnFilterDescriptor_t bwd_data_filter_desc = GET_FILTER_DESCRIPTOR(filter);
  cudnnTensorDescriptor_t bwd_data_diff_desc = GET_TENSOR_DESCRIPTOR(output);
  cudnnTensorDescriptor_t bwd_data_grad_desc = GET_TENSOR_DESCRIPTOR(input);
  cudnnConvolutionDescriptor_t bwd_data_conv_desc =
      GET_CONVOLUTION_DESCRIPTOR(conv);
  // cudnn convolution backward filter configuration
  cudnnTensorDescriptor_t bwd_filter_src_desc = GET_TENSOR_DESCRIPTOR(input);
  cudnnTensorDescriptor_t bwd_filter_diff_desc = GET_TENSOR_DESCRIPTOR(output);
  cudnnConvolutionDescriptor_t bwd_filter_conv_desc =
      GET_CONVOLUTION_DESCRIPTOR(conv);
  cudnnFilterDescriptor_t bwd_filter_grad_desc = GET_FILTER_DESCRIPTOR(filter);
237

238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270
  if (useDilation) {
    convFwdAlgo = &algo;
    convBwdDataAlgo = &algo;
    convBwdFilterAlgo = &algo;
  } else {
    CHECK_CUDNN(dynload::cudnnGetConvolutionForwardAlgorithm(
        t_resource.cudnn_handle,
        fwd_src_desc,
        fwd_filter_desc,
        fwd_conv_desc,
        fwd_dest_desc,
        CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
        memoryLimitBytes,
        reinterpret_cast<cudnnConvolutionFwdAlgo_t*>(convFwdAlgo)));
    CHECK_CUDNN(dynload::cudnnGetConvolutionBackwardDataAlgorithm(
        t_resource.cudnn_handle,
        bwd_data_filter_desc,
        bwd_data_diff_desc,
        bwd_data_conv_desc,
        bwd_data_grad_desc,
        CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
        memoryLimitBytes,
        reinterpret_cast<cudnnConvolutionBwdDataAlgo_t*>(convBwdDataAlgo)));
    CHECK_CUDNN(dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
        t_resource.cudnn_handle,
        bwd_filter_src_desc,
        bwd_filter_diff_desc,
        bwd_filter_conv_desc,
        bwd_filter_grad_desc,
        CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
        memoryLimitBytes,
        reinterpret_cast<cudnnConvolutionBwdFilterAlgo_t*>(convBwdFilterAlgo)));
  }
271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297

  CHECK_CUDNN(dynload::cudnnGetConvolutionForwardWorkspaceSize(
      t_resource.cudnn_handle,
      fwd_src_desc,
      fwd_filter_desc,
      fwd_conv_desc,
      fwd_dest_desc,
      static_cast<cudnnConvolutionFwdAlgo_t>(*convFwdAlgo),
      fwdLimitBytes));

  CHECK_CUDNN(dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
      t_resource.cudnn_handle,
      bwd_data_filter_desc,
      bwd_data_diff_desc,
      bwd_data_conv_desc,
      bwd_data_grad_desc,
      static_cast<cudnnConvolutionBwdDataAlgo_t>(*convBwdDataAlgo),
      bwdDataLimitBytes));

  CHECK_CUDNN(dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
      t_resource.cudnn_handle,
      bwd_filter_src_desc,
      bwd_filter_diff_desc,
      bwd_filter_conv_desc,
      bwd_filter_grad_desc,
      static_cast<cudnnConvolutionBwdFilterAlgo_t>(*convBwdFilterAlgo),
      bwdFilterLimitBytes));
Z
zhangjinchao01 已提交
298 299 300 301 302 303 304 305

#endif
}

/**
 * Allocate an hl tensor descriptor for a packed NCHW tensor of shape
 * (batch_size, feature_maps, height, width).
 *
 * The element type is CUDNN_DATA_FLOAT, or CUDNN_DATA_DOUBLE when
 * PADDLE_TYPE_DOUBLE is defined. The format, data type and dimensions given
 * to cuDNN are also cached in the hl struct. Release the result with
 * hl_destroy_tensor_descriptor.
 */
void hl_create_tensor_descriptor(hl_tensor_descriptor* image_desc,
                                 int batch_size,
                                 int feature_maps,
                                 int height,
                                 int width) {
  CHECK_NOTNULL(image_desc);

  cudnn_tensor_descriptor desc_impl =
      (cudnn_tensor_descriptor)malloc(sizeof(_cudnn_tensor_descriptor));
  CHECK_NOTNULL(desc_impl);

#ifdef PADDLE_TYPE_DOUBLE
  const cudnnDataType_t elem_type = CUDNN_DATA_DOUBLE;
#else
  const cudnnDataType_t elem_type = CUDNN_DATA_FLOAT;
#endif
  const cudnnTensorFormat_t layout = CUDNN_TENSOR_NCHW;

  CHECK_CUDNN(dynload::cudnnCreateTensorDescriptor(&desc_impl->desc));
  CHECK_CUDNN(dynload::cudnnSetTensor4dDescriptor(desc_impl->desc,
                                                  layout,
                                                  elem_type,
                                                  batch_size,
                                                  feature_maps,
                                                  height,
                                                  width));

  // Mirror the cuDNN configuration in the host-side struct.
  desc_impl->format = layout;
  desc_impl->data_type = elem_type;
  desc_impl->batch_size = batch_size;
  desc_impl->feature_maps = feature_maps;
  desc_impl->height = height;
  desc_impl->width = width;

  *image_desc = (hl_tensor_descriptor)desc_impl;
}

/**
 * Allocate an hl tensor descriptor without a shape.
 *
 * Only the cuDNN descriptor handle and the element type (CUDNN_DATA_FLOAT,
 * or CUDNN_DATA_DOUBLE when PADDLE_TYPE_DOUBLE is defined) are set here. The
 * shape fields are filled in by hl_tensor_reshape.
 *
 * The host struct is zero-initialized and its format recorded as
 * CUDNN_TENSOR_NCHW, so no field is left uninitialized: hl_tensor_reshape
 * never writes `format`. Release the result with
 * hl_destroy_tensor_descriptor.
 */
void hl_create_tensor_descriptor(hl_tensor_descriptor* image_desc) {
  CHECK_NOTNULL(image_desc);

  // calloc (not malloc) so fields not written below start out zeroed.
  cudnn_tensor_descriptor hl_desc =
      (cudnn_tensor_descriptor)calloc(1, sizeof(_cudnn_tensor_descriptor));
  CHECK_NOTNULL(hl_desc);

#ifndef PADDLE_TYPE_DOUBLE
  cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
#else
  cudnnDataType_t data_type = CUDNN_DATA_DOUBLE;
#endif
  CHECK_CUDNN(dynload::cudnnCreateTensorDescriptor(&hl_desc->desc));

  hl_desc->format = CUDNN_TENSOR_NCHW;
  hl_desc->data_type = data_type;

  *image_desc = (hl_tensor_descriptor)hl_desc;
}

/**
 * Reshape `image_desc` to a fully packed NCHW layout.
 *
 * Derives contiguous element strides from the dimensions, then delegates to
 * the strided overload of hl_tensor_reshape.
 */
void hl_tensor_reshape(hl_tensor_descriptor image_desc,
                       int batch_size,
                       int feature_maps,
                       int height,
                       int width) {
  // Packed NCHW: each stride is the product of all inner dimension sizes.
  const int w_step = 1;
  const int h_step = width * w_step;
  const int c_step = height * h_step;
  const int n_step = feature_maps * c_step;

  hl_tensor_reshape(image_desc,
                    batch_size,
                    feature_maps,
                    height,
                    width,
                    n_step,
                    c_step,
                    h_step,
                    w_step);
}

/**
 * Set the dimensions and element strides of an existing tensor descriptor.
 *
 * The cuDNN descriptor is reconfigured with cudnnSetTensor4dDescriptorEx
 * (keeping the data type already cached in the hl struct), and the new
 * dimensions are cached. The cached `format` field is not modified.
 */
void hl_tensor_reshape(hl_tensor_descriptor image_desc,
                       int batch_size,
                       int feature_maps,
                       int height,
                       int width,
                       int nStride,
                       int cStride,
                       int hStride,
                       int wStride) {
  CHECK_NOTNULL(image_desc);

  cudnn_tensor_descriptor impl = (cudnn_tensor_descriptor)image_desc;
  cudnnTensorDescriptor_t raw_desc = impl->desc;
  CHECK_NOTNULL(raw_desc);

  CHECK_CUDNN(dynload::cudnnSetTensor4dDescriptorEx(raw_desc,
                                                    impl->data_type,
                                                    batch_size,
                                                    feature_maps,
                                                    height,
                                                    width,
                                                    nStride,
                                                    cStride,
                                                    hStride,
                                                    wStride));

  impl->batch_size = batch_size;
  impl->feature_maps = feature_maps;
  impl->height = height;
  impl->width = width;
}

408
/**
 * Destroy the cuDNN descriptor owned by `image_desc` and free the hl struct.
 * `image_desc` must not be used afterwards.
 */
void hl_destroy_tensor_descriptor(hl_tensor_descriptor image_desc) {
  CHECK_NOTNULL(image_desc);
  cudnn_tensor_descriptor impl = (cudnn_tensor_descriptor)image_desc;

  CHECK_NOTNULL(impl->desc);
  CHECK_CUDNN(dynload::cudnnDestroyTensorDescriptor(impl->desc));

  // Clear the handle before the struct itself is released.
  impl->desc = NULL;
  free(impl);
}

/**
 * Create a 2-D pooling descriptor.
 *
 * Maps the hl pooling mode to cuDNN:
 *  - HL_POOLING_MAX -> CUDNN_POOLING_MAX;
 *  - HL_POOLING_AVERAGE -> average that excludes padded cells;
 *  - HL_POOLING_AVERAGE_INCLUDE_PADDING -> average that counts them.
 * Any other mode is fatal.
 *
 * The window size, padding and stride are then configured; with cuDNN >= 5,
 * NaNs are propagated. The mode, window size and stride are cached in the
 * returned hl struct. Release it with hl_destroy_pooling_descriptor.
 */
void hl_create_pooling_descriptor(hl_pooling_descriptor* pooling_desc,
                                  hl_pooling_mode_t mode,
                                  int height,
                                  int width,
                                  int height_padding,
                                  int width_padding,
                                  int stride_height,
                                  int stride_width) {
  cudnnPoolingMode_t cudnn_mode;
  switch (mode) {
    case HL_POOLING_MAX:
      cudnn_mode = CUDNN_POOLING_MAX;
      break;
    case HL_POOLING_AVERAGE:
      cudnn_mode = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
      break;
    case HL_POOLING_AVERAGE_INCLUDE_PADDING:
      cudnn_mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
      break;
    default:
      LOG(FATAL) << "parameter mode error";
  }

  CHECK_NOTNULL(pooling_desc);

  cudnn_pooling_descriptor hl_pooling_desc =
      (cudnn_pooling_descriptor)malloc(sizeof(_cudnn_pooling_descriptor));
  CHECK_NOTNULL(hl_pooling_desc);

  CHECK_CUDNN(dynload::cudnnCreatePoolingDescriptor(&hl_pooling_desc->desc));

  // Keep the version conditional outside CHECK_CUDNN's argument list:
  // preprocessing directives inside macro arguments are undefined behavior
  // in C++ and are not portable across compilers.
#if CUDNN_VERSION >= 5000
  CHECK_CUDNN(dynload::cudnnSetPooling2dDescriptor(hl_pooling_desc->desc,
                                                   cudnn_mode,
                                                   CUDNN_PROPAGATE_NAN,
                                                   height,
                                                   width,
                                                   height_padding,
                                                   width_padding,
                                                   stride_height,
                                                   stride_width));
#else
  CHECK_CUDNN(dynload::cudnnSetPooling2dDescriptor(hl_pooling_desc->desc,
                                                   cudnn_mode,
                                                   height,
                                                   width,
                                                   height_padding,
                                                   width_padding,
                                                   stride_height,
                                                   stride_width));
#endif

  hl_pooling_desc->mode = cudnn_mode;
  hl_pooling_desc->window_height = height;
  hl_pooling_desc->window_width = width;
  hl_pooling_desc->stride_height = stride_height;
  hl_pooling_desc->stride_width = stride_width;

  *pooling_desc = (hl_pooling_descriptor)hl_pooling_desc;
}

473
/**
 * Destroy the cuDNN pooling descriptor held by `pooling_desc` and free the
 * hl struct. `pooling_desc` must not be used afterwards.
 */
void hl_destroy_pooling_descriptor(hl_pooling_descriptor pooling_desc) {
  CHECK_NOTNULL(pooling_desc);
  cudnn_pooling_descriptor impl = (cudnn_pooling_descriptor)pooling_desc;

  CHECK_NOTNULL(impl->desc);
  CHECK_CUDNN(dynload::cudnnDestroyPoolingDescriptor(impl->desc));

  // Clear the handle before the struct itself is released.
  impl->desc = NULL;
  free(impl);
}

void hl_pooling_forward(hl_tensor_descriptor input,
                        real* input_image,
                        hl_tensor_descriptor output,
                        real* output_image,
490
                        hl_pooling_descriptor pooling) {
491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514
  cudnnPoolingDescriptor_t pooling_desc;
  cudnnTensorDescriptor_t input_desc;
  cudnnTensorDescriptor_t output_desc;

  CHECK_NOTNULL(input);
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(pooling);
  CHECK_NOTNULL(input_image);
  CHECK_NOTNULL(output_image);

  real alpha = 1.0f;
  real beta = 1.0f;
  input_desc = ((cudnn_tensor_descriptor)input)->desc;
  output_desc = ((cudnn_tensor_descriptor)output)->desc;
  pooling_desc = ((cudnn_pooling_descriptor)pooling)->desc;
  CHECK_CUDNN(dynload::cudnnPoolingForward(t_resource.cudnn_handle,
                                           pooling_desc,
                                           &alpha,
                                           input_desc,
                                           input_image,
                                           &beta,
                                           output_desc,
                                           output_image));
  CHECK_SYNC("hl_pooling_forward failed");
Z
zhangjinchao01 已提交
515 516 517 518 519 520 521 522
}

void hl_pooling_backward(hl_tensor_descriptor input,
                         real* input_image,
                         real* input_image_grad,
                         hl_tensor_descriptor output,
                         real* output_image,
                         real* output_image_grad,
523
                         hl_pooling_descriptor pooling) {
524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552
  cudnnPoolingDescriptor_t pooling_desc;
  cudnnTensorDescriptor_t input_desc;
  cudnnTensorDescriptor_t output_desc;

  CHECK_NOTNULL(input);
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(pooling);
  CHECK_NOTNULL(input_image);
  CHECK_NOTNULL(input_image_grad);
  CHECK_NOTNULL(output_image);
  CHECK_NOTNULL(output_image_grad);

  real alpha = 1.0f;
  real beta = 1.0f;
  input_desc = ((cudnn_tensor_descriptor)input)->desc;
  output_desc = ((cudnn_tensor_descriptor)output)->desc;
  pooling_desc = ((cudnn_pooling_descriptor)pooling)->desc;
  CHECK_CUDNN(dynload::cudnnPoolingBackward(t_resource.cudnn_handle,
                                            pooling_desc,
                                            &alpha,
                                            output_desc,
                                            output_image,
                                            output_desc,
                                            output_image_grad,
                                            input_desc,
                                            input_image,
                                            &beta,
                                            input_desc,
                                            input_image_grad));
Z
zhangjinchao01 已提交
553 554 555 556 557 558 559
  CHECK_SYNC("hl_pooling_backward failed");
}

/**
 * Create a 4-D convolution filter descriptor.
 *
 * cuDNN receives the dimensions in the order (output_feature_maps,
 * input_feature_maps, height, width). The element type is CUDNN_DATA_FLOAT,
 * or CUDNN_DATA_DOUBLE when PADDLE_TYPE_DOUBLE is defined. With cuDNN >= 5
 * the layout is explicitly CUDNN_TENSOR_NCHW.
 *
 * The data type and dimensions are cached in the returned hl struct.
 * Release it with hl_destroy_filter_descriptor.
 */
void hl_create_filter_descriptor(hl_filter_descriptor* filter,
                                 int input_feature_maps,
                                 int output_feature_maps,
                                 int height,
                                 int width) {
  CHECK_NOTNULL(filter);

  cudnn_filter_descriptor hl_filter =
      (cudnn_filter_descriptor)malloc(sizeof(_cudnn_filter_descriptor));
  CHECK_NOTNULL(hl_filter);

  CHECK_CUDNN(dynload::cudnnCreateFilterDescriptor(&hl_filter->desc));

#ifndef PADDLE_TYPE_DOUBLE
  cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
#else
  cudnnDataType_t data_type = CUDNN_DATA_DOUBLE;
#endif

  // Keep the version conditional outside CHECK_CUDNN's argument list:
  // preprocessing directives inside macro arguments are undefined behavior
  // in C++ and are not portable across compilers.
#if CUDNN_VERSION >= 5000
  CHECK_CUDNN(dynload::cudnnSetFilter4dDescriptor(hl_filter->desc,
                                                  data_type,
                                                  CUDNN_TENSOR_NCHW,
                                                  output_feature_maps,
                                                  input_feature_maps,
                                                  height,
                                                  width));
#else
  CHECK_CUDNN(dynload::cudnnSetFilter4dDescriptor(hl_filter->desc,
                                                  data_type,
                                                  output_feature_maps,
                                                  input_feature_maps,
                                                  height,
                                                  width));
#endif

  hl_filter->data_type = data_type;
  hl_filter->output_feature_maps = output_feature_maps;
  hl_filter->input_feature_maps = input_feature_maps;
  hl_filter->filter_height = height;
  hl_filter->filter_width = width;

  *filter = (hl_filter_descriptor)hl_filter;
}

593
/**
 * Destroy the cuDNN filter descriptor held by `filter` and free the hl
 * struct. `filter` must not be used afterwards.
 */
void hl_destroy_filter_descriptor(hl_filter_descriptor filter) {
  CHECK_NOTNULL(filter);
  cudnn_filter_descriptor impl = (cudnn_filter_descriptor)filter;

  CHECK_NOTNULL(impl->desc);
  CHECK_CUDNN(dynload::cudnnDestroyFilterDescriptor(impl->desc));

  // Clear the handle before the struct itself is released.
  impl->desc = NULL;
  free(impl);
}

/**
 * Allocate and configure a 2-D cross-correlation convolution descriptor.
 *
 * With cuDNN >= 6, the padding, stride and dilation are passed together with
 * the compute type (CUDNN_DATA_FLOAT, or CUDNN_DATA_DOUBLE when
 * PADDLE_TYPE_DOUBLE is defined). Older cuDNN has no dilation support: a
 * dilation greater than 1 is fatal, and otherwise the dilation values occupy
 * the two upscale slots of the older API.
 *
 * The image/filter handles, padding, stride and mode are cached in the hl
 * struct. upscalex/upscaley are always recorded as 1. Release the result
 * with hl_destroy_convolution_descriptor.
 */
void hl_create_convolution_descriptor(hl_convolution_descriptor* conv,
                                      hl_tensor_descriptor image,
                                      hl_filter_descriptor filter,
                                      int padding_height,
                                      int padding_width,
                                      int stride_height,
                                      int stride_width,
                                      int dilation_h,
                                      int dilation_w) {
  CHECK_NOTNULL(conv);

  cudnn_convolution_descriptor impl = (cudnn_convolution_descriptor)malloc(
      sizeof(_cudnn_convolution_descriptor));
  CHECK_NOTNULL(impl);
  CHECK_CUDNN(dynload::cudnnCreateConvolutionDescriptor(&impl->desc));

  const cudnnConvolutionMode_t conv_mode = CUDNN_CROSS_CORRELATION;

#if CUDNN_VERSION >= 6000
#ifdef PADDLE_TYPE_DOUBLE
  const cudnnDataType_t compute_type = CUDNN_DATA_DOUBLE;
#else
  const cudnnDataType_t compute_type = CUDNN_DATA_FLOAT;
#endif
  CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(impl->desc,
                                                       padding_height,
                                                       padding_width,
                                                       stride_height,
                                                       stride_width,
                                                       dilation_h,
                                                       dilation_w,
                                                       conv_mode,
                                                       compute_type));
#else
  const bool wants_dilation = dilation_h > 1 || dilation_w > 1;
  if (wants_dilation) {
    LOG(FATAL)
        << "Current cuDNN version does't support for dilation convolution. "
        << "The dilation convolution requires cuDNN >= v6.0.";
  }

  CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(impl->desc,
                                                       padding_height,
                                                       padding_width,
                                                       stride_height,
                                                       stride_width,
                                                       dilation_h,
                                                       dilation_w,
                                                       conv_mode));
#endif

  impl->input_image = image;
  impl->filter = filter;
  impl->padding_height = padding_height;
  impl->padding_width = padding_width;
  impl->stride_height = stride_height;
  impl->stride_width = stride_width;
  impl->upscalex = 1;
  impl->upscaley = 1;
  impl->mode = conv_mode;

  *conv = (hl_convolution_descriptor)impl;
}

/**
 * Reconfigure an existing convolution descriptor in place.
 *
 * Follows the same rules as hl_create_convolution_descriptor. With cuDNN >= 6,
 * padding, stride, dilation and the compute type (CUDNN_DATA_FLOAT, or
 * CUDNN_DATA_DOUBLE when PADDLE_TYPE_DOUBLE is defined) are set. With older
 * cuDNN, a dilation greater than 1 is rejected up front with an explicit
 * fatal message, and otherwise the dilation values occupy the two upscale
 * slots of the older API.
 *
 * The cached image/filter handles, padding, stride and mode are updated;
 * upscalex/upscaley are recorded as 1.
 */
void hl_reset_convolution_descriptor(hl_convolution_descriptor conv,
                                     hl_tensor_descriptor image,
                                     hl_filter_descriptor filter,
                                     int padding_height,
                                     int padding_width,
                                     int stride_height,
                                     int stride_width,
                                     int dilation_h,
                                     int dilation_w) {
  CHECK_NOTNULL(conv);
  CHECK_NOTNULL(image);
  CHECK_NOTNULL(filter);

  cudnnConvolutionDescriptor_t conv_desc = GET_CONVOLUTION_DESCRIPTOR(conv);
  cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION;

#if CUDNN_VERSION >= 6000
#ifndef PADDLE_TYPE_DOUBLE
  cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
#else
  cudnnDataType_t data_type = CUDNN_DATA_DOUBLE;
#endif
  CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(conv_desc,
                                                       padding_height,
                                                       padding_width,
                                                       stride_height,
                                                       stride_width,
                                                       dilation_h,
                                                       dilation_w,
                                                       mode,
                                                       data_type));
#else
  // Pre-v6 cuDNN interprets these two arguments as upscale factors, not
  // dilation, so reject dilated convolutions explicitly.
  if (dilation_h > 1 || dilation_w > 1) {
    LOG(FATAL)
        << "Current cuDNN version does't support for dilation convolution. "
        << "The dilation convolution requires cuDNN >= v6.0.";
  }

  CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(conv_desc,
                                                       padding_height,
                                                       padding_width,
                                                       stride_height,
                                                       stride_width,
                                                       dilation_h,
                                                       dilation_w,
                                                       mode));
#endif

  cudnn_convolution_descriptor hl_conv = (cudnn_convolution_descriptor)conv;
  hl_conv->input_image = image;
  hl_conv->filter = filter;
  hl_conv->padding_height = padding_height;
  hl_conv->padding_width = padding_width;
  hl_conv->stride_height = stride_height;
  hl_conv->stride_width = stride_width;
  hl_conv->upscalex = 1;
  hl_conv->upscaley = 1;
  hl_conv->mode = mode;
}

724
/**
 * Release a convolution descriptor.
 *
 * Destroys the wrapped cuDNN convolution descriptor, clears the stored
 * handle, and releases the wrapper's memory with free().
 *
 * @param conv  descriptor to release; must be non-null and must still hold
 *              a non-null cuDNN descriptor.
 */
void hl_destroy_convolution_descriptor(hl_convolution_descriptor conv) {
  CHECK_NOTNULL(conv);
  cudnn_convolution_descriptor wrapper = (cudnn_convolution_descriptor)conv;

  // Refuse to destroy a descriptor that was never created (or already freed).
  CHECK_NOTNULL(wrapper->desc);
  CHECK_CUDNN(dynload::cudnnDestroyConvolutionDescriptor(wrapper->desc));
  wrapper->desc = nullptr;

  free(conv);
}

/**
 * Forward convolution through cuDNN.
 *
 * Calls cudnnConvolutionForward with alpha = 1 and beta = 1, so the
 * convolution result is added to whatever output_data already contains.
 * CHECK_SYNC runs afterwards.
 *
 * @param input         descriptor of the input tensor.
 * @param input_data    input buffer.
 * @param output        descriptor of the output tensor.
 * @param output_data   output buffer (accumulated into).
 * @param filter        descriptor of the filter.
 * @param filter_data   filter weights.
 * @param conv          convolution descriptor.
 * @param gpuWorkSpace  workspace buffer handed to cuDNN.
 * @param sizeInBytes   size of gpuWorkSpace in bytes.
 * @param convFwdAlgo   integer value of a cudnnConvolutionFwdAlgo_t.
 */
void hl_convolution_forward(hl_tensor_descriptor input,
                            real* input_data,
                            hl_tensor_descriptor output,
                            real* output_data,
                            hl_filter_descriptor filter,
                            real* filter_data,
                            hl_convolution_descriptor conv,
                            void* gpuWorkSpace,
                            size_t sizeInBytes,
                            int convFwdAlgo) {
  // Handles first, then the raw buffers.
  CHECK_NOTNULL(input);
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(filter);
  CHECK_NOTNULL(conv);
  CHECK_NOTNULL(input_data);
  CHECK_NOTNULL(output_data);
  CHECK_NOTNULL(filter_data);

  cudnnTensorDescriptor_t x_desc = GET_TENSOR_DESCRIPTOR(input);
  cudnnTensorDescriptor_t y_desc = GET_TENSOR_DESCRIPTOR(output);
  cudnnFilterDescriptor_t w_desc = GET_FILTER_DESCRIPTOR(filter);
  cudnnConvolutionDescriptor_t c_desc = GET_CONVOLUTION_DESCRIPTOR(conv);
  cudnnConvolutionFwdAlgo_t algo =
      static_cast<cudnnConvolutionFwdAlgo_t>(convFwdAlgo);

  real result_scale = 1.0f;    // alpha
  real existing_scale = 1.0f;  // beta: keep and add to current output

  CHECK_CUDNN(dynload::cudnnConvolutionForward(t_resource.cudnn_handle,
                                               &result_scale,
                                               x_desc,
                                               input_data,
                                               w_desc,
                                               filter_data,
                                               c_desc,
                                               algo,
                                               gpuWorkSpace,
                                               sizeInBytes,
                                               &existing_scale,
                                               y_desc,
                                               output_data));

  CHECK_SYNC("hl_convolution_forward failed");
}

void hl_convolution_forward_add_bias(hl_tensor_descriptor bias,
                                     real* bias_data,
                                     hl_tensor_descriptor output,
779
                                     real* output_data) {
780 781 782 783 784 785 786 787 788 789 790
  CHECK_NOTNULL(bias);
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(bias_data);
  CHECK_NOTNULL(output_data);

  cudnnTensorDescriptor_t output_desc = GET_TENSOR_DESCRIPTOR(output);
  cudnnTensorDescriptor_t bias_desc = GET_TENSOR_DESCRIPTOR(bias);
  real alpha = 1.0f;
  real beta = 1.0f;

  CHECK_CUDNN(dynload::cudnnAddTensor(t_resource.cudnn_handle,
Z
zhangjinchao01 已提交
791
#if CUDNN_VERSION < 4000
792
                                      CUDNN_ADD_SAME_C,
Z
zhangjinchao01 已提交
793
#endif
794 795 796 797 798 799
                                      &alpha,
                                      bias_desc,
                                      bias_data,
                                      &beta,
                                      output_desc,
                                      output_data));
Z
zhangjinchao01 已提交
800 801 802 803 804 805
  CHECK_SYNC("hl_convolution_forward_add_bias failed");
}

void hl_convolution_backward_bias(hl_tensor_descriptor bias,
                                  real* bias_grad_data,
                                  hl_tensor_descriptor output,
806
                                  real* output_grad_data) {
807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822
  CHECK_NOTNULL(bias);
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(bias_grad_data);
  CHECK_NOTNULL(output_grad_data);

  real alpha = 1.0f;
  real beta = 1.0f;
  cudnnTensorDescriptor_t diff_desc = GET_TENSOR_DESCRIPTOR(output);
  cudnnTensorDescriptor_t bias_desc = GET_TENSOR_DESCRIPTOR(bias);
  CHECK_CUDNN(dynload::cudnnConvolutionBackwardBias(t_resource.cudnn_handle,
                                                    &alpha,
                                                    diff_desc,
                                                    output_grad_data,
                                                    &beta,
                                                    bias_desc,
                                                    bias_grad_data));
Z
zhangjinchao01 已提交
823 824 825 826 827 828 829 830 831 832 833 834 835
  CHECK_SYNC("hl_convolution_backward_bias failed");
}

/**
 * Compute the filter gradient of a convolution via
 * cudnnConvolutionBackwardFilter.
 *
 * Uses alpha = 1 and beta = 1, so the result is added to the existing contents
 * of filter_grad_data. With cuDNN >= 4000 the algorithm and workspace are
 * forwarded to cuDNN. With older versions they are not passed at all, and
 * convBwdFilterAlgo / gpuWorkSpace / sizeInBytes are ignored. CHECK_SYNC runs
 * afterwards.
 *
 * @param input              descriptor of the forward input tensor.
 * @param input_data         forward input buffer.
 * @param output             descriptor of the output gradient tensor.
 * @param output_grad_data   gradient w.r.t. the convolution output.
 * @param filter             descriptor of the filter (gradient).
 * @param filter_grad_data   filter gradient buffer (accumulated into).
 * @param conv               convolution descriptor.
 * @param gpuWorkSpace       workspace buffer (cuDNN >= 4000 only).
 * @param sizeInBytes        size of gpuWorkSpace in bytes.
 * @param convBwdFilterAlgo  integer value of a
 *                           cudnnConvolutionBwdFilterAlgo_t.
 */
void hl_convolution_backward_filter(hl_tensor_descriptor input,
                                    real* input_data,
                                    hl_tensor_descriptor output,
                                    real* output_grad_data,
                                    hl_filter_descriptor filter,
                                    real* filter_grad_data,
                                    hl_convolution_descriptor conv,
                                    void* gpuWorkSpace,
                                    size_t sizeInBytes,
                                    int convBwdFilterAlgo) {
  CHECK_NOTNULL(input);
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(filter);
  CHECK_NOTNULL(conv);
  CHECK_NOTNULL(input_data);
  CHECK_NOTNULL(output_grad_data);
  CHECK_NOTNULL(filter_grad_data);

  real grad_scale = 1.0f;   // alpha
  real accum_scale = 1.0f;  // beta: accumulate into filter_grad_data
  cudnnTensorDescriptor_t x_desc = GET_TENSOR_DESCRIPTOR(input);
  cudnnTensorDescriptor_t dy_desc = GET_TENSOR_DESCRIPTOR(output);
  cudnnConvolutionDescriptor_t c_desc = GET_CONVOLUTION_DESCRIPTOR(conv);
  cudnnFilterDescriptor_t dw_desc = GET_FILTER_DESCRIPTOR(filter);

#if CUDNN_VERSION >= 4000
  CHECK_CUDNN(dynload::cudnnConvolutionBackwardFilter(
      t_resource.cudnn_handle,
      &grad_scale,
      x_desc,
      input_data,
      dy_desc,
      output_grad_data,
      c_desc,
      static_cast<cudnnConvolutionBwdFilterAlgo_t>(convBwdFilterAlgo),
      gpuWorkSpace,
      sizeInBytes,
      &accum_scale,
      dw_desc,
      filter_grad_data));
#else
  CHECK_CUDNN(dynload::cudnnConvolutionBackwardFilter(t_resource.cudnn_handle,
                                                      &grad_scale,
                                                      x_desc,
                                                      input_data,
                                                      dy_desc,
                                                      output_grad_data,
                                                      c_desc,
                                                      &accum_scale,
                                                      dw_desc,
                                                      filter_grad_data));
#endif

  CHECK_SYNC("hl_convolution_backward_filter failed");
}

/**
 * Compute the gradient of a convolution with respect to its input via
 * cudnnConvolutionBackwardData.
 *
 * Uses alpha = 1 and beta = 1, so the computed gradient is added to the
 * existing contents of input_data_grad. With cuDNN >= 4000 the algorithm and
 * workspace are forwarded to cuDNN. With older versions they are not passed,
 * and convBwdDataAlgo / gpuWorkSpace / sizeInBytes are ignored. CHECK_SYNC
 * runs afterwards.
 *
 * @param input             descriptor of the input (gradient) tensor.
 * @param input_data_grad   input gradient buffer (accumulated into).
 * @param output            descriptor of the output gradient tensor.
 * @param output_grad_data  gradient w.r.t. the convolution output.
 * @param filter            descriptor of the filter.
 * @param filter_data       filter weights.
 * @param conv              convolution descriptor.
 * @param gpuWorkSpace      workspace buffer (cuDNN >= 4000 only).
 * @param sizeInBytes       size of gpuWorkSpace in bytes.
 * @param convBwdDataAlgo   integer value of a cudnnConvolutionBwdDataAlgo_t.
 */
void hl_convolution_backward_data(hl_tensor_descriptor input,
                                  real* input_data_grad,
                                  hl_tensor_descriptor output,
                                  real* output_grad_data,
                                  hl_filter_descriptor filter,
                                  real* filter_data,
                                  hl_convolution_descriptor conv,
                                  void* gpuWorkSpace,
                                  size_t sizeInBytes,
                                  int convBwdDataAlgo) {
  // Validate arguments the same way the other convolution entry points do.
  // The GET_*_DESCRIPTOR macros below presumably dereference these handles,
  // so a null handle should fail here with a clear message rather than crash.
  CHECK_NOTNULL(input);
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(filter);
  CHECK_NOTNULL(conv);
  CHECK_NOTNULL(input_data_grad);
  CHECK_NOTNULL(output_grad_data);
  CHECK_NOTNULL(filter_data);

  real alpha = 1.0f;
  real beta = 1.0f;  // accumulate into input_data_grad
  cudnnFilterDescriptor_t filter_desc = GET_FILTER_DESCRIPTOR(filter);
  cudnnTensorDescriptor_t diff_desc = GET_TENSOR_DESCRIPTOR(output);
  cudnnTensorDescriptor_t grad_desc = GET_TENSOR_DESCRIPTOR(input);
  cudnnConvolutionDescriptor_t conv_desc = GET_CONVOLUTION_DESCRIPTOR(conv);

  CHECK_CUDNN(dynload::cudnnConvolutionBackwardData(
      t_resource.cudnn_handle,
      &alpha,
      filter_desc,
      filter_data,
      diff_desc,
      output_grad_data,
      conv_desc,
#if CUDNN_VERSION >= 4000
      static_cast<cudnnConvolutionBwdDataAlgo_t>(convBwdDataAlgo),
      gpuWorkSpace,
      sizeInBytes,
#endif
      &beta,
      grad_desc,
      input_data_grad));

  CHECK_SYNC("hl_convolution_backward_data failed");
}

906
/**
 * Row-wise softmax via cudnnSoftmaxForward.
 *
 * The data is described as an NCHW tensor with N = height, C = width,
 * H = W = 1. CUDNN_SOFTMAX_MODE_CHANNEL therefore normalizes across the
 * `width` entries for each of the `height` rows (presumably a row-major
 * height x width matrix — confirm at call sites). The result overwrites
 * output (beta = 0).
 *
 * Side effect: reconfigures the shared scratch descriptor
 * t_resource.cudnn_desc.
 *
 * @param input   input values.
 * @param output  destination for the softmax result.
 * @param height  number of rows.
 * @param width   number of entries per row.
 */
void hl_softmax_forward(real* input, real* output, int height, int width) {
  // The cuDNN element type must match the compile-time `real` type.
#ifdef PADDLE_TYPE_DOUBLE
  cudnnDataType_t dtype = CUDNN_DATA_DOUBLE;
#else
  cudnnDataType_t dtype = CUDNN_DATA_FLOAT;
#endif
  cudnnTensorDescriptor_t desc = t_resource.cudnn_desc;

  CHECK_CUDNN(dynload::cudnnSetTensor4dDescriptor(
      desc, CUDNN_TENSOR_NCHW, dtype, height, width, 1, 1));

  real out_scale = 1.0f;  // alpha
  real old_scale = 0.0f;  // beta: discard previous output contents
  CHECK_CUDNN(dynload::cudnnSoftmaxForward(t_resource.cudnn_handle,
                                           CUDNN_SOFTMAX_ACCURATE,
                                           CUDNN_SOFTMAX_MODE_CHANNEL,
                                           &out_scale,
                                           desc,
                                           input,
                                           &old_scale,
                                           desc,
                                           output));

  CHECK_SYNC("hl_softmax_forward failed");
}

934 935
void hl_softmax_backward(real* output_value,
                         real* output_grad,
Z
zhangjinchao01 已提交
936
                         int height,
937
                         int width) {
938
#ifndef PADDLE_TYPE_DOUBLE
939
  cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
Z
zhangjinchao01 已提交
940
#else
941
  cudnnDataType_t data_type = CUDNN_DATA_DOUBLE;
Z
zhangjinchao01 已提交
942
#endif
943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963
  CHECK_CUDNN(dynload::cudnnSetTensor4dDescriptor(t_resource.cudnn_desc,
                                                  CUDNN_TENSOR_NCHW,
                                                  data_type,
                                                  height,
                                                  width,
                                                  1,
                                                  1));

  real alpha = 1.0f;
  real beta = 0.0f;
  CHECK_CUDNN(dynload::cudnnSoftmaxBackward(t_resource.cudnn_handle,
                                            CUDNN_SOFTMAX_ACCURATE,
                                            CUDNN_SOFTMAX_MODE_CHANNEL,
                                            &alpha,
                                            t_resource.cudnn_desc,
                                            output_value,
                                            t_resource.cudnn_desc,
                                            output_grad,
                                            &beta,
                                            t_resource.cudnn_desc,
                                            output_grad));
Z
zhangjinchao01 已提交
964 965 966 967
  CHECK_SYNC("hl_softmax_backward failed");
}

void hl_batch_norm_forward_training(hl_tensor_descriptor inputDesc,
968
                                    real* input,
Z
zhangjinchao01 已提交
969
                                    hl_tensor_descriptor outputDesc,
970
                                    real* output,
Z
zhangjinchao01 已提交
971
                                    hl_tensor_descriptor bnParamDesc,
972 973
                                    real* scale,
                                    real* bias,
Z
zhangjinchao01 已提交
974
                                    double factor,
975 976
                                    real* runningMean,
                                    real* runningInvVar,
Z
zhangjinchao01 已提交
977
                                    double epsilon,
978 979
                                    real* savedMean,
                                    real* savedVar) {
980
#if CUDNN_VERSION >= 4007
Z
zhangjinchao01 已提交
981 982 983
  if ((NULL != runningMean && NULL == runningInvVar) ||
      (NULL == runningMean && NULL != runningInvVar)) {
    LOG(FATAL) << "runningMean and runningInvVar can be NULL "
984
               << "but only at the same time.";
Z
zhangjinchao01 已提交
985 986 987 988 989 990 991 992 993 994 995 996 997
  }
  if ((NULL != savedMean && NULL == savedVar) ||
      (NULL == savedMean && NULL != savedVar)) {
    LOG(FATAL) << "savedMean and savedVar can be NULL "
               << "but only at the same time.";
  }

  cudnnTensorDescriptor_t xDesc = GET_TENSOR_DESCRIPTOR(inputDesc);
  cudnnTensorDescriptor_t yDesc = GET_TENSOR_DESCRIPTOR(outputDesc);
  cudnnTensorDescriptor_t bnDesc = GET_TENSOR_DESCRIPTOR(bnParamDesc);
  real alpha = 1.0f;
  real beta = 1.0f;
  cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015
  CHECK_CUDNN(
      dynload::cudnnBatchNormalizationForwardTraining(t_resource.cudnn_handle,
                                                      mode,
                                                      &alpha,
                                                      &beta,
                                                      xDesc,
                                                      input,
                                                      yDesc,
                                                      output,
                                                      bnDesc,
                                                      scale,
                                                      bias,
                                                      factor,
                                                      runningMean,
                                                      runningInvVar,
                                                      epsilon,
                                                      savedMean,
                                                      savedVar));
Z
zhangjinchao01 已提交
1016 1017 1018

  CHECK_SYNC("hl_batch_norm_forward_training failed");
#else
1019
  LOG(FATAL) << "CudnnBatchNorm requires cudnn version >= 4007. "
Z
zhangjinchao01 已提交
1020 1021 1022 1023 1024
             << "But cudnn lib version is " << g_cudnn_lib_version;
#endif
}

void hl_batch_norm_forward_inference(hl_tensor_descriptor inputDesc,
1025 1026 1027 1028 1029 1030 1031 1032 1033
                                     real* input,
                                     hl_tensor_descriptor outputDesc,
                                     real* output,
                                     hl_tensor_descriptor bnParamDesc,
                                     real* scale,
                                     real* bias,
                                     real* estimatedMean,
                                     real* estimatedInvVar,
                                     double epsilon) {
1034
#if CUDNN_VERSION >= 4007
Z
zhangjinchao01 已提交
1035 1036 1037 1038 1039 1040
  cudnnTensorDescriptor_t xDesc = GET_TENSOR_DESCRIPTOR(inputDesc);
  cudnnTensorDescriptor_t yDesc = GET_TENSOR_DESCRIPTOR(outputDesc);
  cudnnTensorDescriptor_t bnDesc = GET_TENSOR_DESCRIPTOR(bnParamDesc);
  real alpha = 1.0f;
  real beta = 1.0f;
  cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
1041

1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056
  CHECK_CUDNN(
      dynload::cudnnBatchNormalizationForwardInference(t_resource.cudnn_handle,
                                                       mode,
                                                       &alpha,
                                                       &beta,
                                                       xDesc,
                                                       input,
                                                       yDesc,
                                                       output,
                                                       bnDesc,
                                                       scale,
                                                       bias,
                                                       estimatedMean,
                                                       estimatedInvVar,
                                                       epsilon));
Z
zhangjinchao01 已提交
1057 1058 1059

  CHECK_SYNC("hl_batch_norm_forward_inference failed");
#else
1060
  LOG(FATAL) << "CudnnBatchNorm requires cudnn version >= 4007. "
Z
zhangjinchao01 已提交
1061 1062 1063 1064 1065
             << "But cudnn lib version is " << g_cudnn_lib_version;
#endif
}

void hl_batch_norm_backward(hl_tensor_descriptor inputDesc,
1066
                            real* input,
Z
zhangjinchao01 已提交
1067
                            hl_tensor_descriptor outGradDesc,
1068
                            real* outGrad,
Z
zhangjinchao01 已提交
1069
                            hl_tensor_descriptor inGradDesc,
1070
                            real* inGrad,
Z
zhangjinchao01 已提交
1071
                            hl_tensor_descriptor dBnParamDesc,
1072 1073 1074
                            real* scale,
                            real* scaleGrad,
                            real* biasGrad,
Z
zhangjinchao01 已提交
1075
                            double epsilon,
1076 1077
                            real* savedMean,
                            real* savedInvVar) {
1078
#if CUDNN_VERSION >= 4007
Z
zhangjinchao01 已提交
1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091
  if ((NULL != savedMean && NULL == savedInvVar) ||
      (NULL == savedMean && NULL != savedInvVar)) {
    LOG(FATAL) << "savedMean and savedVar can be NULL "
               << "but only at the same time.";
  }

  cudnnTensorDescriptor_t xDesc = GET_TENSOR_DESCRIPTOR(inputDesc);
  cudnnTensorDescriptor_t dyDesc = GET_TENSOR_DESCRIPTOR(outGradDesc);
  cudnnTensorDescriptor_t dxDesc = GET_TENSOR_DESCRIPTOR(inGradDesc);
  cudnnTensorDescriptor_t bnDesc = GET_TENSOR_DESCRIPTOR(dBnParamDesc);
  real alpha = 1.0f;
  real beta = 1.0f;
  cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110
  CHECK_CUDNN(dynload::cudnnBatchNormalizationBackward(t_resource.cudnn_handle,
                                                       mode,
                                                       &alpha,
                                                       &beta,
                                                       &alpha,
                                                       &beta,
                                                       xDesc,
                                                       input,
                                                       dyDesc,
                                                       outGrad,
                                                       dxDesc,
                                                       inGrad,
                                                       bnDesc,
                                                       scale,
                                                       scaleGrad,
                                                       biasGrad,
                                                       epsilon,
                                                       savedMean,
                                                       savedInvVar));
Z
zhangjinchao01 已提交
1111 1112 1113

  CHECK_SYNC("hl_batch_norm_backward failed");
#else
1114
  LOG(FATAL) << "CudnnBatchNorm requires cudnn version >= 4007. "
Z
zhangjinchao01 已提交
1115 1116 1117
             << "But cudnn lib version is " << g_cudnn_lib_version;
#endif
}