hl_cuda_cudnn.cc 38.2 KB
Newer Older
Z
zhangjinchao01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */


#include <cudnn.h>
#include <mutex>
#include "hl_cuda_cudnn.h"
#include "hl_cuda_cudnn.ph"
#include "hl_thread.ph"
#include "hl_dso_loader.h"
#include "paddle/utils/Logging.h"
23 24 25 26 27
#include "paddle/utils/CommandLineParser.h"

// Command-line flag, in megabytes, that caps the workspace limit passed to the
// cuDNN convolution algorithm queries made in this file.
P_DEFINE_int32(cudnn_conv_workspace_limit_in_mb, 4096,
                "Specify cuDNN max workspace limit, in units MB, "
                "4096MB=4GB by default.");
Z
zhangjinchao01 已提交
28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43

namespace dynload {

// Ensures GetCudnnDsoHandle is invoked at most once (see std::call_once in
// the DSO variant of DYNAMIC_LOAD_CUDNN_WRAP).
std::once_flag cudnn_dso_flag;
// Library handle passed to GetCudnnDsoHandle and then used for dlsym lookups.
void* cudnn_dso_handle = nullptr;

/**
 * The following macro definition generates one struct per cuDNN
 * function; the struct's overloaded operator() dynamically loads the
 * cuDNN routine and forwards the call to it.
 *
 * note: the libraries are dynamically linked by default
 **/

#ifdef PADDLE_USE_DSO

/**
 * DSO build: DYNAMIC_LOAD_CUDNN_WRAP(name) defines a functor object called
 * `name` whose operator() forwards its arguments to the cuDNN routine of the
 * same name, resolved at run time with dlsym.
 *
 * GetCudnnDsoHandle is run once (std::call_once) to fill cudnn_dso_handle.
 * The resolved symbol is cached in a function-local static so dlsym is not
 * repeated on every call, and a missing symbol is reported through CHECK
 * instead of jumping through a null function pointer.
 *
 * The function-pointer type is rebuilt from the deduced argument types, so
 * callers must pass arguments whose types match the real cuDNN signature.
 */
#define DYNAMIC_LOAD_CUDNN_WRAP(__name)                                  \
  struct DynLoad__##__name {                                             \
    template <typename... Args>                                          \
    auto operator()(Args... args) -> decltype(__name(args...)) {         \
      using cudnn_func = decltype(__name(args...))(*)(Args...);          \
      std::call_once(cudnn_dso_flag, GetCudnnDsoHandle,                  \
                     &cudnn_dso_handle);                                 \
      static void* p_##__name = dlsym(cudnn_dso_handle, #__name);        \
      CHECK(p_##__name != nullptr)                                       \
          << "Cannot find symbol " #__name " in the cuDNN library.";     \
      return reinterpret_cast<cudnn_func>(p_##__name)(args...);          \
    }                                                                    \
  } __name; /* struct DynLoad__##__name */

#else

/**
 * Non-DSO build: DYNAMIC_LOAD_CUDNN_WRAP(name) defines a functor object
 * called `name` whose operator() simply calls the cuDNN routine of the same
 * name directly.
 */
#define DYNAMIC_LOAD_CUDNN_WRAP(__name)                          \
  struct DynLoad__##__name {                                     \
    template <typename... Args>                                  \
    auto operator()(Args... args) -> decltype(__name(args...)) { \
      return __name(args...);                                    \
    }                                                            \
  } __name; /* struct DynLoad__##__name */

#endif

/**
 * List all cuDNN functions needed by HPPL and instantiate a
 * DYNAMIC_LOAD_CUDNN_WRAP wrapper for each of them.
 * Different cuDNN versions expose different interfaces, so the
 * version-specific routines live in tables guarded by CUDNN_VERSION.
 **/
#define CUDNN_DNN_ROUTINE_EACH(__macro)                   \
  __macro(cudnnSetTensor4dDescriptor)                     \
  __macro(cudnnSetTensor4dDescriptorEx)                   \
  __macro(cudnnGetConvolutionNdForwardOutputDim)          \
  __macro(cudnnGetConvolutionForwardAlgorithm)            \
  __macro(cudnnCreateTensorDescriptor)                    \
  __macro(cudnnDestroyTensorDescriptor)                   \
  __macro(cudnnCreateFilterDescriptor)                    \
  __macro(cudnnSetFilter4dDescriptor)                     \
  __macro(cudnnSetPooling2dDescriptor)                    \
  __macro(cudnnDestroyFilterDescriptor)                   \
  __macro(cudnnCreateConvolutionDescriptor)               \
  __macro(cudnnCreatePoolingDescriptor)                   \
  __macro(cudnnDestroyPoolingDescriptor)                  \
  __macro(cudnnSetConvolution2dDescriptor)                \
  __macro(cudnnDestroyConvolutionDescriptor)              \
  __macro(cudnnCreate)                                    \
  __macro(cudnnDestroy)                                   \
  __macro(cudnnSetStream)                                 \
  __macro(cudnnActivationForward)                         \
  __macro(cudnnConvolutionForward)                        \
  __macro(cudnnConvolutionBackwardBias)                   \
  __macro(cudnnGetConvolutionForwardWorkspaceSize)        \
  __macro(cudnnTransformTensor)                           \
  __macro(cudnnPoolingForward)                            \
  __macro(cudnnPoolingBackward)                           \
  __macro(cudnnSoftmaxBackward)                           \
  __macro(cudnnSoftmaxForward)                            \
  __macro(cudnnGetVersion)                                \
  __macro(cudnnGetErrorString)
CUDNN_DNN_ROUTINE_EACH(DYNAMIC_LOAD_CUDNN_WRAP)

// Routines kept in their own "R2" table; wrapped unconditionally.
#define CUDNN_DNN_ROUTINE_EACH_R2(__macro)                \
  __macro(cudnnAddTensor)                                 \
  __macro(cudnnConvolutionBackwardData)                   \
  __macro(cudnnConvolutionBackwardFilter)
CUDNN_DNN_ROUTINE_EACH_R2(DYNAMIC_LOAD_CUDNN_WRAP)
// Drop the helper table once it has been expanded, like the tables below.
#undef CUDNN_DNN_ROUTINE_EACH_R2

// APIs available after R3:
#if CUDNN_VERSION >= 3000
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R3(__macro)              \
  __macro(cudnnGetConvolutionBackwardFilterWorkspaceSize)     \
  __macro(cudnnGetConvolutionBackwardDataAlgorithm)           \
  __macro(cudnnGetConvolutionBackwardFilterAlgorithm)         \
  __macro(cudnnGetConvolutionBackwardDataWorkspaceSize)
CUDNN_DNN_ROUTINE_EACH_AFTER_R3(DYNAMIC_LOAD_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R3
#endif


// APIs available after R4:
#if CUDNN_VERSION >= 4007
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R4(__macro)             \
  __macro(cudnnBatchNormalizationForwardTraining)            \
  __macro(cudnnBatchNormalizationForwardInference)           \
  __macro(cudnnBatchNormalizationBackward)
CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DYNAMIC_LOAD_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R4
#endif

// APIs in R5
#if CUDNN_VERSION >= 5000
#define CUDNN_DNN_ROUTINE_EACH_R5(__macro)                    \
  __macro(cudnnCreateActivationDescriptor)                    \
  __macro(cudnnSetActivationDescriptor)                       \
  __macro(cudnnGetActivationDescriptor)                       \
  __macro(cudnnDestroyActivationDescriptor)
CUDNN_DNN_ROUTINE_EACH_R5(DYNAMIC_LOAD_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH_R5
#endif

#undef CUDNN_DNN_ROUTINE_EACH

} /* namespace dynload */

/**
 * Evaluate a cuDNN call and check, using glog's CHECK_EQ, that it returned
 * CUDNN_STATUS_SUCCESS; on failure the cuDNN error string is logged.
 *
 * The macro expands to a complete do { ... } while (0) statement, so it
 * **does not** support the << operator for more detailed error info.
 *
 * The argument is parenthesized so that any expression yielding a
 * cudnnStatus_t can be passed safely.
 */
#define CHECK_CUDNN(cudnnFunc)                               \
  do {                                                       \
    cudnnStatus_t cudnnStat = (cudnnFunc);                   \
    CHECK_EQ(CUDNN_STATUS_SUCCESS, cudnnStat)                \
        << "Cudnn Error: "                                   \
        << dynload::cudnnGetErrorString(cudnnStat);          \
  } while (0)
Z
zhangjinchao01 已提交
158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214

// Set to true at the end of hl_cudnn_init.
bool g_is_libcudnn_init = false;
// Version number returned by cudnnGetVersion; assigned in hl_cudnn_init.
int g_cudnn_lib_version = 0;

/**
 * Create a cuDNN tensor descriptor.
 *
 * @param[out] cudnn_desc  location that receives the new descriptor.
 */
void hl_cudnn_desc_init(cudnnTensorDescriptor_t* cudnn_desc) {
  CHECK_CUDNN(dynload::cudnnCreateTensorDescriptor(cudnn_desc));
}

/**
 * Create a cuDNN handle and attach it to a stream.
 *
 * Before creating the handle, the major version of the loaded library is
 * compared with the major version of the header used at compile time, and
 * the cuDNN v5 / CUDA version pairing is checked.
 *
 * @param[out] cudnn_handle  receives the created handle.
 * @param[in]  stream        stream set on the new handle.
 */
void hl_cudnn_init(cudnnHandle_t* cudnn_handle, cudaStream_t stream) {
  // Major version = version number / 1000, for both the library and header.
  const size_t cudnn_dso_ver = dynload::cudnnGetVersion();
  const size_t cudnn_dso_major = cudnn_dso_ver / 1000;
  const size_t cudnn_cuh_major = CUDNN_VERSION / 1000;

  // Compare cudnn header version with that of cudnn.so.
  // Majors below 4 may be mixed freely; otherwise they must be equal.
  CHECK((cudnn_cuh_major < 4 && cudnn_dso_major < 4) ||
        (cudnn_cuh_major == cudnn_dso_major))
      << "[cudnn init] libcudnn v" << cudnn_dso_major
      << " with header v" << cudnn_cuh_major << " unmatched!\n"
      << "PaddlePaddle Requirement: "
      << "(header v[2-3] with libcudnn v[2-3]) Or "
      << "(header v4 with libcudnn v4) Or "
      << "(header v5 with libcudnn v5).";

  CHECK(!(CUDNN_VERSION >= 5000 && CUDA_VERSION < 7050))
      << "cudnn v5 requires cuda version >= 7.5";

  CHECK_CUDNN(dynload::cudnnCreate(cudnn_handle));
  CHECK_CUDNN(dynload::cudnnSetStream(*cudnn_handle, stream));

  // Publish the library version and mark initialization as done.
  g_cudnn_lib_version = cudnn_dso_ver;
  g_is_libcudnn_init = true;
}

/**
 * @return the value currently stored in g_cudnn_lib_version.
 */
int hl_get_cudnn_lib_version() { return g_cudnn_lib_version; }

/**
 * Select convolution algorithms under the configured workspace limit and
 * report the workspace size each selected algorithm requires.
 *
 * For the forward, backward-data and backward-filter passes this asks cuDNN
 * for an algorithm using the *_SPECIFY_WORKSPACE_LIMIT preference with a
 * limit of FLAGS_cudnn_conv_workspace_limit_in_mb megabytes, then queries
 * the workspace size of the chosen algorithm.
 *
 * @param[in]  input                input tensor descriptor.
 * @param[in]  output               output tensor descriptor.
 * @param[in]  filter               filter descriptor.
 * @param[in]  conv                 convolution descriptor.
 * @param[out] convFwdAlgo          forward algorithm (cuDNN enum as int).
 * @param[out] fwdLimitBytes        workspace bytes for the forward algorithm.
 * @param[out] convBwdDataAlgo      backward-data algorithm (enum as int).
 * @param[out] bwdDataLimitBytes    workspace bytes for backward-data.
 * @param[out] convBwdFilterAlgo    backward-filter algorithm (enum as int).
 * @param[out] bwdFilterLimitBytes  workspace bytes for backward-filter.
 *
 * @note When CUDNN_VERSION < 4000 the body is compiled out and none of the
 *       output parameters are written.
 */
void hl_conv_workspace(hl_tensor_descriptor input,
                       hl_tensor_descriptor output,
                       hl_filter_descriptor filter,
                       hl_convolution_descriptor conv,
                       int* convFwdAlgo,
                       size_t* fwdLimitBytes,
                       int* convBwdDataAlgo,
                       size_t* bwdDataLimitBytes,
                       int* convBwdFilterAlgo,
                       size_t* bwdFilterLimitBytes) {
#if CUDNN_VERSION >= 4000

    CHECK_NOTNULL(input);
    CHECK_NOTNULL(output);
    CHECK_NOTNULL(filter);
    CHECK_NOTNULL(conv);
    // Every output pointer is written below; fail loudly instead of
    // dereferencing a null pointer.
    CHECK_NOTNULL(convFwdAlgo);
    CHECK_NOTNULL(fwdLimitBytes);
    CHECK_NOTNULL(convBwdDataAlgo);
    CHECK_NOTNULL(bwdDataLimitBytes);
    CHECK_NOTNULL(convBwdFilterAlgo);
    CHECK_NOTNULL(bwdFilterLimitBytes);

    // Specify workspace limit directly (flag is in MB, cuDNN wants bytes).
    size_t memoryLimitBytes =
        (1LL << 20) * FLAGS_cudnn_conv_workspace_limit_in_mb;

    // The same four descriptors serve all three passes.
    cudnnTensorDescriptor_t       input_desc = GET_TENSOR_DESCRIPTOR(input);
    cudnnTensorDescriptor_t       output_desc = GET_TENSOR_DESCRIPTOR(output);
    cudnnFilterDescriptor_t       filter_desc = GET_FILTER_DESCRIPTOR(filter);
    cudnnConvolutionDescriptor_t  conv_desc = GET_CONVOLUTION_DESCRIPTOR(conv);

    // cudnn convolution forward configuration.
    // The algorithm is received in a properly typed local and then stored
    // into the caller's int, rather than reinterpreting the int* as an
    // enum pointer.
    cudnnConvolutionFwdAlgo_t fwd_algo;
    CHECK_CUDNN(dynload::cudnnGetConvolutionForwardAlgorithm(
             t_resource.cudnn_handle,
             input_desc,
             filter_desc,
             conv_desc,
             output_desc,
             CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
             memoryLimitBytes,
             &fwd_algo));
    *convFwdAlgo = static_cast<int>(fwd_algo);

    CHECK_CUDNN(dynload::cudnnGetConvolutionForwardWorkspaceSize(
             t_resource.cudnn_handle,
             input_desc,
             filter_desc,
             conv_desc,
             output_desc,
             fwd_algo,
             fwdLimitBytes));

    // cudnn convolution backward data configuration
    // (filter, output diff, conv, input grad).
    cudnnConvolutionBwdDataAlgo_t bwd_data_algo;
    CHECK_CUDNN(dynload::cudnnGetConvolutionBackwardDataAlgorithm(
             t_resource.cudnn_handle,
             filter_desc,
             output_desc,
             conv_desc,
             input_desc,
             CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
             memoryLimitBytes,
             &bwd_data_algo));
    *convBwdDataAlgo = static_cast<int>(bwd_data_algo);

    CHECK_CUDNN(dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
             t_resource.cudnn_handle,
             filter_desc,
             output_desc,
             conv_desc,
             input_desc,
             bwd_data_algo,
             bwdDataLimitBytes));

    // cudnn convolution backward filter configuration
    // (input, output diff, conv, filter grad).
    cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo;
    CHECK_CUDNN(dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
             t_resource.cudnn_handle,
             input_desc,
             output_desc,
             conv_desc,
             filter_desc,
             CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
             memoryLimitBytes,
             &bwd_filter_algo));
    *convBwdFilterAlgo = static_cast<int>(bwd_filter_algo);

    CHECK_CUDNN(dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
             t_resource.cudnn_handle,
             input_desc,
             output_desc,
             conv_desc,
             filter_desc,
             bwd_filter_algo,
             bwdFilterLimitBytes));

#endif
}

/**
 * Allocate a tensor descriptor wrapper and configure it as a 4-D NCHW
 * tensor of the given shape.
 *
 * The element type is CUDNN_DATA_FLOAT unless PADDLE_TYPE_DOUBLE is
 * defined, in which case it is CUDNN_DATA_DOUBLE.
 *
 * @param[out] image_desc    receives the new descriptor; release it with
 *                           hl_destroy_tensor_descriptor.
 * @param[in]  batch_size    N dimension.
 * @param[in]  feature_maps  C dimension.
 * @param[in]  height        H dimension.
 * @param[in]  width         W dimension.
 */
void hl_create_tensor_descriptor(hl_tensor_descriptor* image_desc,
                                 int batch_size,
                                 int feature_maps,
                                 int height,
                                 int width)
{
    CHECK_NOTNULL(image_desc);

    cudnn_tensor_descriptor hl_desc =
        (cudnn_tensor_descriptor)malloc(sizeof(_cudnn_tensor_descriptor));
    CHECK_NOTNULL(hl_desc);

#ifndef PADDLE_TYPE_DOUBLE
    cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
#else
    cudnnDataType_t data_type = CUDNN_DATA_DOUBLE;
#endif
    CHECK_CUDNN(dynload::cudnnCreateTensorDescriptor(&hl_desc->desc));

    CHECK_CUDNN(dynload::cudnnSetTensor4dDescriptor(
                hl_desc->desc,
                CUDNN_TENSOR_NCHW,
                data_type,
                batch_size,
                feature_maps,
                height,
                width));

    // Mirror the configuration in the wrapper struct.
    hl_desc->format = CUDNN_TENSOR_NCHW;
    hl_desc->data_type = data_type;
    hl_desc->batch_size = batch_size;
    hl_desc->feature_maps = feature_maps;
    hl_desc->height = height;
    hl_desc->width = width;

    *image_desc = (hl_tensor_descriptor)hl_desc;
}

/**
 * Allocate a tensor descriptor wrapper whose shape is not yet set.
 *
 * Only the cuDNN descriptor object and the element type are established
 * here; the shape is expected to be supplied later (e.g. through
 * hl_tensor_reshape). The element type is CUDNN_DATA_FLOAT unless
 * PADDLE_TYPE_DOUBLE is defined.
 *
 * @param[out] image_desc  receives the new descriptor; release it with
 *                         hl_destroy_tensor_descriptor.
 */
void hl_create_tensor_descriptor(hl_tensor_descriptor* image_desc) {
    CHECK_NOTNULL(image_desc);

    cudnn_tensor_descriptor hl_desc =
        (cudnn_tensor_descriptor)malloc(sizeof(_cudnn_tensor_descriptor));
    CHECK_NOTNULL(hl_desc);

#ifndef PADDLE_TYPE_DOUBLE
    cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
#else
    cudnnDataType_t data_type = CUDNN_DATA_DOUBLE;
#endif
    CHECK_CUDNN(dynload::cudnnCreateTensorDescriptor(&hl_desc->desc));

    hl_desc->data_type = data_type;
    // malloc leaves the remaining bookkeeping fields indeterminate; give
    // them defined values so reading them before a reshape is not undefined.
    hl_desc->format = CUDNN_TENSOR_NCHW;
    hl_desc->batch_size = 0;
    hl_desc->feature_maps = 0;
    hl_desc->height = 0;
    hl_desc->width = 0;

    *image_desc = (hl_tensor_descriptor)hl_desc;
}

/**
 * Reshape a tensor descriptor using contiguous strides derived from the
 * shape: width varies fastest (stride 1), then height, feature maps, batch.
 *
 * Delegates to the stride-taking overload of hl_tensor_reshape.
 */
void hl_tensor_reshape(hl_tensor_descriptor image_desc,
                       int batch_size,
                       int feature_maps,
                       int height,
                       int width) {
  // Each stride is the extent of the next-inner dimension times its stride.
  const int w_step = 1;
  const int h_step = width * w_step;
  const int c_step = height * h_step;
  const int n_step = feature_maps * c_step;

  hl_tensor_reshape(image_desc,
                    batch_size, feature_maps, height, width,
                    n_step, c_step, h_step, w_step);
}

/**
 * Reshape a tensor descriptor with explicit per-dimension strides.
 *
 * Reconfigures the underlying cuDNN descriptor (keeping its stored data
 * type) and records the new shape in the wrapper struct. The strides are
 * passed to cuDNN but are not stored in the wrapper.
 */
void hl_tensor_reshape(hl_tensor_descriptor image_desc,
                       int batch_size,
                       int feature_maps,
                       int height,
                       int width,
                       int nStride,
                       int cStride,
                       int hStride,
                       int wStride) {
  CHECK_NOTNULL(image_desc);

  cudnn_tensor_descriptor hl_desc = (cudnn_tensor_descriptor)image_desc;
  CHECK_NOTNULL(hl_desc->desc);

  CHECK_CUDNN(dynload::cudnnSetTensor4dDescriptorEx(
      hl_desc->desc, hl_desc->data_type,
      batch_size, feature_maps, height, width,
      nStride, cStride, hStride, wStride));

  // Keep the cached shape in sync with the cuDNN descriptor.
  hl_desc->width = width;
  hl_desc->height = height;
  hl_desc->feature_maps = feature_maps;
  hl_desc->batch_size = batch_size;
}

/**
 * Destroy the cuDNN tensor descriptor held by image_desc and free the
 * wrapper struct itself.
 */
void hl_destroy_tensor_descriptor(hl_tensor_descriptor image_desc) {
  CHECK_NOTNULL(image_desc);

  cudnn_tensor_descriptor hl_desc = (cudnn_tensor_descriptor)image_desc;
  CHECK_NOTNULL(hl_desc->desc);

  CHECK_CUDNN(dynload::cudnnDestroyTensorDescriptor(hl_desc->desc));
  hl_desc->desc = NULL;  // cleared before the wrapper memory is released

  free(image_desc);
}


/**
 * Allocate and configure a 2-D pooling descriptor.
 *
 * The hl pooling mode is mapped to the matching cuDNN pooling mode; any
 * other value is reported with LOG(FATAL). With cuDNN >= 5000,
 * CUDNN_PROPAGATE_NAN is passed as the NaN option.
 *
 * @param[out] pooling_desc    receives the new descriptor; release it with
 *                             hl_destroy_pooling_descriptor.
 * @param[in]  mode            pooling mode.
 * @param[in]  height          pooling window height.
 * @param[in]  width           pooling window width.
 * @param[in]  height_padding  padding along the height.
 * @param[in]  width_padding   padding along the width.
 * @param[in]  stride_height   stride along the height.
 * @param[in]  stride_width    stride along the width.
 */
void hl_create_pooling_descriptor(hl_pooling_descriptor* pooling_desc,
                                  hl_pooling_mode_t mode,
                                  int height,
                                  int width,
                                  int height_padding,
                                  int width_padding,
                                  int stride_height,
                                  int stride_width) {
  // Translate the hl mode into its cuDNN counterpart.
  cudnnPoolingMode_t cudnn_mode;
  if (mode == HL_POOLING_MAX) {
    cudnn_mode = CUDNN_POOLING_MAX;
  } else if (mode == HL_POOLING_AVERAGE) {
    cudnn_mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
  } else if (mode == HL_POOLING_AVERAGE_EXCLUDE_PADDING) {
    cudnn_mode = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
  } else {
    LOG(FATAL) << "parameter mode error";
  }

  CHECK_NOTNULL(pooling_desc);

  cudnn_pooling_descriptor hl_pooling_desc =
      (cudnn_pooling_descriptor)malloc(sizeof(_cudnn_pooling_descriptor));
  CHECK_NOTNULL(hl_pooling_desc);

  CHECK_CUDNN(dynload::cudnnCreatePoolingDescriptor(&hl_pooling_desc->desc));

  CHECK_CUDNN(dynload::cudnnSetPooling2dDescriptor(
      hl_pooling_desc->desc,
      cudnn_mode,
#if CUDNN_VERSION >= 5000
      CUDNN_PROPAGATE_NAN,
#endif
      height, width,
      height_padding, width_padding,
      stride_height, stride_width));

  // Cache the configuration in the wrapper (padding is not stored).
  hl_pooling_desc->mode = cudnn_mode;
  hl_pooling_desc->window_height = height;
  hl_pooling_desc->window_width = width;
  hl_pooling_desc->stride_height = stride_height;
  hl_pooling_desc->stride_width = stride_width;

  *pooling_desc = (hl_pooling_descriptor)hl_pooling_desc;
}

/**
 * Destroy the cuDNN pooling descriptor held by pooling_desc and free the
 * wrapper struct itself.
 */
void hl_destroy_pooling_descriptor(hl_pooling_descriptor pooling_desc) {
  CHECK_NOTNULL(pooling_desc);

  cudnn_pooling_descriptor hl_pooling = (cudnn_pooling_descriptor)pooling_desc;
  CHECK_NOTNULL(hl_pooling->desc);

  CHECK_CUDNN(dynload::cudnnDestroyPoolingDescriptor(hl_pooling->desc));
  hl_pooling->desc = NULL;  // cleared before the wrapper memory is released

  free(pooling_desc);
}

/**
 * Run cuDNN pooling forward on the thread's cuDNN handle.
 *
 * alpha and beta are both passed as 1.
 *
 * @param[in]     input         input tensor descriptor.
 * @param[in]     input_image   input data.
 * @param[in]     output        output tensor descriptor.
 * @param[in,out] output_image  output data.
 * @param[in]     pooling       pooling descriptor.
 */
void hl_pooling_forward(hl_tensor_descriptor input,
                        real* input_image,
                        hl_tensor_descriptor output,
                        real* output_image,
                        hl_pooling_descriptor pooling) {
  CHECK_NOTNULL(input);
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(pooling);
  CHECK_NOTNULL(input_image);
  CHECK_NOTNULL(output_image);

  // Unwrap the raw cuDNN descriptors.
  cudnnTensorDescriptor_t src = ((cudnn_tensor_descriptor)input)->desc;
  cudnnTensorDescriptor_t dst = ((cudnn_tensor_descriptor)output)->desc;
  cudnnPoolingDescriptor_t pool = ((cudnn_pooling_descriptor)pooling)->desc;

  real alpha = 1.0f;
  real beta = 1.0f;
  CHECK_CUDNN(dynload::cudnnPoolingForward(t_resource.cudnn_handle,
                                           pool,
                                           &alpha,
                                           src,
                                           input_image,
                                           &beta,
                                           dst,
                                           output_image));
  CHECK_SYNC("hl_pooling_forward failed");
}

/**
 * Run cuDNN pooling backward on the thread's cuDNN handle.
 *
 * alpha and beta are both passed as 1.
 *
 * @param[in]     input              input tensor descriptor.
 * @param[in]     input_image        input data.
 * @param[in,out] input_image_grad   gradient with respect to the input.
 * @param[in]     output             output tensor descriptor.
 * @param[in]     output_image       output data.
 * @param[in]     output_image_grad  gradient with respect to the output.
 * @param[in]     pooling            pooling descriptor.
 */
void hl_pooling_backward(hl_tensor_descriptor input,
                         real* input_image,
                         real* input_image_grad,
                         hl_tensor_descriptor output,
                         real* output_image,
                         real* output_image_grad,
                         hl_pooling_descriptor pooling) {
  CHECK_NOTNULL(input);
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(pooling);
  CHECK_NOTNULL(input_image);
  CHECK_NOTNULL(input_image_grad);
  CHECK_NOTNULL(output_image);
  CHECK_NOTNULL(output_image_grad);

  // Unwrap the raw cuDNN descriptors.
  cudnnTensorDescriptor_t in_desc = ((cudnn_tensor_descriptor)input)->desc;
  cudnnTensorDescriptor_t out_desc = ((cudnn_tensor_descriptor)output)->desc;
  cudnnPoolingDescriptor_t pool = ((cudnn_pooling_descriptor)pooling)->desc;

  real alpha = 1.0f;
  real beta = 1.0f;
  // Argument order: (output, output grad, input) -> input grad.
  CHECK_CUDNN(dynload::cudnnPoolingBackward(t_resource.cudnn_handle,
                                            pool,
                                            &alpha,
                                            out_desc,
                                            output_image,
                                            out_desc,
                                            output_image_grad,
                                            in_desc,
                                            input_image,
                                            &beta,
                                            in_desc,
                                            input_image_grad));
  CHECK_SYNC("hl_pooling_backward failed");
}


/**
 * Allocate a filter descriptor wrapper and configure it as a 4-D filter.
 *
 * The element type is CUDNN_DATA_FLOAT unless PADDLE_TYPE_DOUBLE is
 * defined. With cuDNN >= 5000 the layout CUDNN_TENSOR_NCHW is passed
 * explicitly. Note that cuDNN receives the output feature-map count before
 * the input feature-map count.
 *
 * @param[out] filter               receives the new descriptor; release it
 *                                  with hl_destroy_filter_descriptor.
 * @param[in]  input_feature_maps   number of input feature maps.
 * @param[in]  output_feature_maps  number of output feature maps.
 * @param[in]  height               filter height.
 * @param[in]  width                filter width.
 */
void hl_create_filter_descriptor(hl_filter_descriptor* filter,
                                 int input_feature_maps,
                                 int output_feature_maps,
                                 int height,
                                 int width)
{
    CHECK_NOTNULL(filter);

    cudnn_filter_descriptor hl_filter =
        (cudnn_filter_descriptor)malloc(sizeof(_cudnn_filter_descriptor));
    CHECK_NOTNULL(hl_filter);

    CHECK_CUDNN(dynload::cudnnCreateFilterDescriptor(&hl_filter->desc));

#ifndef PADDLE_TYPE_DOUBLE
    cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
#else
    cudnnDataType_t data_type = CUDNN_DATA_DOUBLE;
#endif
    CHECK_CUDNN(dynload::cudnnSetFilter4dDescriptor(
             hl_filter->desc,
             data_type,
#if CUDNN_VERSION >= 5000
             CUDNN_TENSOR_NCHW,
#endif
             output_feature_maps,
             input_feature_maps,
             height,
             width));

    // Mirror the configuration in the wrapper struct.
    hl_filter->data_type = data_type;
    hl_filter->output_feature_maps = output_feature_maps;
    hl_filter->input_feature_maps = input_feature_maps;
    hl_filter->filter_height = height;
    hl_filter->filter_width = width;

    *filter = (hl_filter_descriptor)hl_filter;
}


/**
 * Destroy the cuDNN filter descriptor held by filter and free the wrapper
 * struct itself.
 */
void hl_destroy_filter_descriptor(hl_filter_descriptor filter) {
  CHECK_NOTNULL(filter);

  cudnn_filter_descriptor hl_filter = (cudnn_filter_descriptor)filter;
  CHECK_NOTNULL(hl_filter->desc);

  CHECK_CUDNN(dynload::cudnnDestroyFilterDescriptor(hl_filter->desc));
  hl_filter->desc = NULL;  // cleared before the wrapper memory is released

  free(filter);
}

/**
 * Allocate and configure a 2-D convolution descriptor.
 *
 * The convolution is set up in CUDNN_CROSS_CORRELATION mode with both
 * upscale factors fixed to 1. The image and filter descriptors are only
 * recorded in the wrapper struct; they are not passed to cuDNN here.
 *
 * @param[out] conv            receives the new descriptor; release it with
 *                             hl_destroy_convolution_descriptor.
 * @param[in]  image           input image descriptor to record.
 * @param[in]  filter          filter descriptor to record.
 * @param[in]  padding_height  padding along the height.
 * @param[in]  padding_width   padding along the width.
 * @param[in]  stride_height   stride along the height.
 * @param[in]  stride_width    stride along the width.
 */
void hl_create_convolution_descriptor(hl_convolution_descriptor* conv,
                                      hl_tensor_descriptor image,
                                      hl_filter_descriptor filter,
                                      int padding_height,
                                      int padding_width,
                                      int stride_height,
                                      int stride_width) {
  CHECK_NOTNULL(conv);

  cudnn_convolution_descriptor hl_conv =
      (cudnn_convolution_descriptor)malloc(
          sizeof(_cudnn_convolution_descriptor));
  CHECK_NOTNULL(hl_conv);

  CHECK_CUDNN(dynload::cudnnCreateConvolutionDescriptor(&hl_conv->desc));

  cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION;
  CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(
      hl_conv->desc,
      padding_height, padding_width,
      stride_height, stride_width,
      1, 1,  // upscale x / y
      mode));

  // Record the full configuration in the wrapper.
  hl_conv->mode = mode;
  hl_conv->upscalex = 1;
  hl_conv->upscaley = 1;
  hl_conv->stride_height = stride_height;
  hl_conv->stride_width = stride_width;
  hl_conv->padding_height = padding_height;
  hl_conv->padding_width = padding_width;
  hl_conv->input_image = image;
  hl_conv->filter = filter;

  *conv = (hl_convolution_descriptor)hl_conv;
}

/**
 * Reconfigure an existing convolution descriptor.
 *
 * Applies new padding and stride values to the underlying cuDNN descriptor
 * (CUDNN_CROSS_CORRELATION mode, upscale factors 1) and updates every
 * cached field of the wrapper, including the recorded image and filter.
 */
void hl_reset_convolution_descriptor(hl_convolution_descriptor conv,
                                     hl_tensor_descriptor image,
                                     hl_filter_descriptor filter,
                                     int padding_height,
                                     int padding_width,
                                     int stride_height,
                                     int stride_width) {
  CHECK_NOTNULL(conv);
  CHECK_NOTNULL(image);
  CHECK_NOTNULL(filter);

  cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION;
  cudnnConvolutionDescriptor_t raw_desc = GET_CONVOLUTION_DESCRIPTOR(conv);
  CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(
      raw_desc,
      padding_height, padding_width,
      stride_height, stride_width,
      1, 1,  // upscale x / y
      mode));

  // Bring the wrapper's cached fields in line with the new settings.
  cudnn_convolution_descriptor hl_conv = (cudnn_convolution_descriptor)conv;
  hl_conv->mode = mode;
  hl_conv->upscalex = 1;
  hl_conv->upscaley = 1;
  hl_conv->stride_height = stride_height;
  hl_conv->stride_width = stride_width;
  hl_conv->padding_height = padding_height;
  hl_conv->padding_width = padding_width;
  hl_conv->input_image = image;
  hl_conv->filter = filter;
}

/**
 * Destroy the cuDNN convolution descriptor held by conv and free the
 * wrapper struct itself.
 */
void hl_destroy_convolution_descriptor(hl_convolution_descriptor conv) {
  CHECK_NOTNULL(conv);

  cudnn_convolution_descriptor hl_conv = (cudnn_convolution_descriptor)conv;
  CHECK_NOTNULL(hl_conv->desc);

  CHECK_CUDNN(dynload::cudnnDestroyConvolutionDescriptor(hl_conv->desc));
  hl_conv->desc = NULL;  // cleared before the wrapper memory is released

  free(conv);
}

/**
 * Run cuDNN convolution forward on the thread's cuDNN handle.
 *
 * alpha and beta are both passed as 1.
 *
 * @param[in]     input         input tensor descriptor.
 * @param[in]     input_data    input data.
 * @param[in]     output        output tensor descriptor.
 * @param[in,out] output_data   output data.
 * @param[in]     filter        filter descriptor.
 * @param[in]     filter_data   filter weights.
 * @param[in]     conv          convolution descriptor.
 * @param[in]     gpuWorkSpace  workspace buffer handed to cuDNN.
 * @param[in]     sizeInBytes   size of the workspace buffer in bytes.
 * @param[in]     convFwdAlgo   forward algorithm, a cudnnConvolutionFwdAlgo_t
 *                              value carried as int.
 */
void hl_convolution_forward(hl_tensor_descriptor input,
                            real* input_data,
                            hl_tensor_descriptor output,
                            real* output_data,
                            hl_filter_descriptor filter,
                            real* filter_data,
                            hl_convolution_descriptor conv,
                            void* gpuWorkSpace,
                            size_t sizeInBytes,
                            int convFwdAlgo) {
  CHECK_NOTNULL(input);
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(filter);
  CHECK_NOTNULL(conv);
  CHECK_NOTNULL(input_data);
  CHECK_NOTNULL(output_data);
  CHECK_NOTNULL(filter_data);

  // Unwrap the raw cuDNN descriptors.
  cudnnTensorDescriptor_t in_desc = GET_TENSOR_DESCRIPTOR(input);
  cudnnTensorDescriptor_t out_desc = GET_TENSOR_DESCRIPTOR(output);
  cudnnFilterDescriptor_t weight_desc = GET_FILTER_DESCRIPTOR(filter);
  cudnnConvolutionDescriptor_t cudnn_conv = GET_CONVOLUTION_DESCRIPTOR(conv);

  real alpha = 1.0f;
  real beta = 1.0f;
  CHECK_CUDNN(dynload::cudnnConvolutionForward(
      t_resource.cudnn_handle,
      &alpha,
      in_desc, input_data,
      weight_desc, filter_data,
      cudnn_conv,
      static_cast<cudnnConvolutionFwdAlgo_t>(convFwdAlgo),
      gpuWorkSpace, sizeInBytes,
      &beta,
      out_desc, output_data));
  CHECK_SYNC("hl_convolution_forward failed");
}

/**
 * Add a bias tensor to the convolution output through cudnnAddTensor.
 *
 * alpha and beta are both passed as 1. For cuDNN versions below 4000 the
 * extra CUDNN_ADD_SAME_C mode argument is supplied.
 *
 * @param[in]     bias         bias tensor descriptor.
 * @param[in]     bias_data    bias data.
 * @param[in]     output       output tensor descriptor.
 * @param[in,out] output_data  output data.
 */
void hl_convolution_forward_add_bias(hl_tensor_descriptor bias,
                                     real* bias_data,
                                     hl_tensor_descriptor output,
                                     real* output_data) {
  CHECK_NOTNULL(bias);
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(bias_data);
  CHECK_NOTNULL(output_data);

  cudnnTensorDescriptor_t dst_desc = GET_TENSOR_DESCRIPTOR(output);
  cudnnTensorDescriptor_t src_desc = GET_TENSOR_DESCRIPTOR(bias);

  real alpha = 1.0f;
  real beta = 1.0f;
  CHECK_CUDNN(dynload::cudnnAddTensor(
      t_resource.cudnn_handle,
#if CUDNN_VERSION < 4000
      CUDNN_ADD_SAME_C,
#endif
      &alpha,
      src_desc, bias_data,
      &beta,
      dst_desc, output_data));
  CHECK_SYNC("hl_convolution_forward_add_bias failed");
}

/**
 * Compute the bias gradient through cudnnConvolutionBackwardBias.
 *
 * alpha and beta are both passed as 1.
 *
 * @param[in]     bias              bias tensor descriptor.
 * @param[in,out] bias_grad_data    bias gradient data.
 * @param[in]     output            output tensor descriptor.
 * @param[in]     output_grad_data  gradient with respect to the output.
 */
void hl_convolution_backward_bias(hl_tensor_descriptor bias,
                                  real* bias_grad_data,
                                  hl_tensor_descriptor output,
                                  real* output_grad_data) {
  CHECK_NOTNULL(bias);
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(bias_grad_data);
  CHECK_NOTNULL(output_grad_data);

  cudnnTensorDescriptor_t out_grad_desc = GET_TENSOR_DESCRIPTOR(output);
  cudnnTensorDescriptor_t bias_grad_desc = GET_TENSOR_DESCRIPTOR(bias);

  real alpha = 1.0f;
  real beta = 1.0f;
  CHECK_CUDNN(dynload::cudnnConvolutionBackwardBias(
      t_resource.cudnn_handle,
      &alpha,
      out_grad_desc, output_grad_data,
      &beta,
      bias_grad_desc, bias_grad_data));
  CHECK_SYNC("hl_convolution_backward_bias failed");
}

/**
 * @brief   Compute the filter (weight) gradient of a convolution with
 *          cudnnConvolutionBackwardFilter.
 *
 * Both scaling factors handed to cuDNN are 1, so under cuDNN's
 * alpha/beta convention the result is accumulated onto the existing
 * contents of filter_grad_data.
 *
 * @param[in]       input              input tensor descriptor
 *                                     (must not be NULL).
 * @param[in]       input_data         input values (must not be NULL).
 * @param[in]       output             descriptor of the output gradient
 *                                     tensor (must not be NULL).
 * @param[in]       output_grad_data   output gradient values
 *                                     (must not be NULL).
 * @param[in]       filter             filter descriptor (must not be NULL).
 * @param[in,out]   filter_grad_data   filter gradient buffer
 *                                     (must not be NULL).
 * @param[in]       conv               convolution descriptor
 *                                     (must not be NULL).
 * @param[in]       gpuWorkSpace       workspace buffer; only forwarded when
 *                                     CUDNN_VERSION >= 4000.
 * @param[in]       sizeInBytes        workspace size; only forwarded when
 *                                     CUDNN_VERSION >= 4000.
 * @param[in]       convBwdFilterAlgo  algorithm id, cast to
 *                                     cudnnConvolutionBwdFilterAlgo_t; only
 *                                     forwarded when CUDNN_VERSION >= 4000.
 */
void hl_convolution_backward_filter(hl_tensor_descriptor input,
                                    real* input_data,
                                    hl_tensor_descriptor output,
                                    real* output_grad_data,
                                    hl_filter_descriptor filter,
                                    real* filter_grad_data,
                                    hl_convolution_descriptor conv,
                                    void* gpuWorkSpace,
                                    size_t sizeInBytes,
                                    int convBwdFilterAlgo) {
  CHECK_NOTNULL(input);
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(filter);
  CHECK_NOTNULL(conv);
  CHECK_NOTNULL(input_data);
  CHECK_NOTNULL(output_grad_data);
  CHECK_NOTNULL(filter_grad_data);

  cudnnFilterDescriptor_t dw_desc = GET_FILTER_DESCRIPTOR(filter);
  cudnnConvolutionDescriptor_t cudnn_conv = GET_CONVOLUTION_DESCRIPTOR(conv);
  cudnnTensorDescriptor_t dy_desc = GET_TENSOR_DESCRIPTOR(output);
  cudnnTensorDescriptor_t x_desc = GET_TENSOR_DESCRIPTOR(input);
  real alpha = 1.0f;
  real beta = 1.0f;

  CHECK_CUDNN(dynload::cudnnConvolutionBackwardFilter(
      t_resource.cudnn_handle,
      &alpha, x_desc, input_data,
      dy_desc, output_grad_data,
      cudnn_conv,
#if CUDNN_VERSION >= 4000
      static_cast<cudnnConvolutionBwdFilterAlgo_t>(convBwdFilterAlgo),
      gpuWorkSpace, sizeInBytes,
#endif
      &beta, dw_desc, filter_grad_data));
  CHECK_SYNC("hl_convolution_backward_filter failed");
}

/**
 * @brief   Compute the input (data) gradient of a convolution with
 *          cudnnConvolutionBackwardData.
 *
 * Both scaling factors handed to cuDNN are 1, so under cuDNN's
 * alpha/beta convention the result is accumulated onto the existing
 * contents of input_data_grad.
 *
 * @param[in]       input             descriptor of the input gradient
 *                                    tensor (must not be NULL).
 * @param[in,out]   input_data_grad   input gradient buffer
 *                                    (must not be NULL).
 * @param[in]       output            descriptor of the output gradient
 *                                    tensor (must not be NULL).
 * @param[in]       output_grad_data  output gradient values
 *                                    (must not be NULL).
 * @param[in]       filter            filter descriptor (must not be NULL).
 * @param[in]       filter_data       filter values (must not be NULL).
 * @param[in]       conv              convolution descriptor
 *                                    (must not be NULL).
 * @param[in]       gpuWorkSpace      workspace buffer; only forwarded when
 *                                    CUDNN_VERSION >= 4000.
 * @param[in]       sizeInBytes       workspace size; only forwarded when
 *                                    CUDNN_VERSION >= 4000.
 * @param[in]       convBwdDataAlgo   algorithm id, cast to
 *                                    cudnnConvolutionBwdDataAlgo_t; only
 *                                    forwarded when CUDNN_VERSION >= 4000.
 */
void hl_convolution_backward_data(hl_tensor_descriptor input,
                                  real* input_data_grad,
                                  hl_tensor_descriptor output,
                                  real* output_grad_data,
                                  hl_filter_descriptor filter,
                                  real* filter_data,
                                  hl_convolution_descriptor conv,
                                  void* gpuWorkSpace,
                                  size_t sizeInBytes,
                                  int convBwdDataAlgo) {
  // Validate arguments up front, mirroring hl_convolution_backward_filter,
  // so a bad caller gets a diagnostic instead of a raw null dereference in
  // the descriptor accessors below. gpuWorkSpace is intentionally not
  // checked, matching hl_convolution_backward_filter.
  CHECK_NOTNULL(input);
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(filter);
  CHECK_NOTNULL(conv);
  CHECK_NOTNULL(input_data_grad);
  CHECK_NOTNULL(output_grad_data);
  CHECK_NOTNULL(filter_data);

  real alpha = 1.0f;
  real beta = 1.0f;
  cudnnFilterDescriptor_t filter_desc = GET_FILTER_DESCRIPTOR(filter);
  cudnnTensorDescriptor_t diff_desc = GET_TENSOR_DESCRIPTOR(output);
  cudnnTensorDescriptor_t grad_desc = GET_TENSOR_DESCRIPTOR(input);
  cudnnConvolutionDescriptor_t conv_desc = GET_CONVOLUTION_DESCRIPTOR(conv);

  CHECK_CUDNN(dynload::cudnnConvolutionBackwardData(
      t_resource.cudnn_handle,
      &alpha,
      filter_desc,
      filter_data,
      diff_desc,
      output_grad_data,
      conv_desc,
#if CUDNN_VERSION >= 4000
      static_cast<cudnnConvolutionBwdDataAlgo_t>(convBwdDataAlgo),
      gpuWorkSpace,
      sizeInBytes,
#endif
      &beta,
      grad_desc,
      input_data_grad));
  CHECK_SYNC("hl_convolution_backward_data failed");
}


/**
 * @brief   Softmax forward pass via cudnnSoftmaxForward.
 *
 * The shared descriptor t_resource.cudnn_desc is (re)configured on every
 * call as an NCHW tensor with dimensions (height, width, 1, 1), and the
 * softmax is run with CUDNN_SOFTMAX_ACCURATE in CUDNN_SOFTMAX_MODE_CHANNEL.
 * The same descriptor is used for both input and output.
 *
 * @param[in]   input   input values.
 * @param[out]  output  softmax result; beta is 0, so under cuDNN's
 *                      alpha/beta convention prior contents are
 *                      overwritten rather than blended.
 * @param[in]   height  passed as the first (N) tensor dimension.
 * @param[in]   width   passed as the second (C) tensor dimension.
 */
void hl_softmax_forward(real *input,
                        real *output,
                        int height,
                        int width)
{
  // The cuDNN element type follows the build-time precision of `real`.
#ifndef PADDLE_TYPE_DOUBLE
  cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
#else
  cudnnDataType_t data_type = CUDNN_DATA_DOUBLE;
#endif
  CHECK_CUDNN(dynload::cudnnSetTensor4dDescriptor(
      t_resource.cudnn_desc,
      CUDNN_TENSOR_NCHW,
      data_type,
      height,
      width,
      1,
      1));

  real alpha = 1.0f;
  real beta = 0.0f;
  CHECK_CUDNN(dynload::cudnnSoftmaxForward(
      t_resource.cudnn_handle,
      CUDNN_SOFTMAX_ACCURATE,
      CUDNN_SOFTMAX_MODE_CHANNEL,
      &alpha,
      t_resource.cudnn_desc,
      input,
      &beta,
      t_resource.cudnn_desc,
      output));
  CHECK_SYNC("hl_softmax_forward failed");
}

/**
 * @brief   Softmax backward pass via cudnnSoftmaxBackward, computed in place.
 *
 * The shared descriptor t_resource.cudnn_desc is (re)configured on every
 * call as an NCHW tensor with dimensions (height, width, 1, 1) and is used
 * for all three tensor arguments. output_grad is passed both as the
 * incoming gradient and as the destination, so it is overwritten with the
 * result.
 *
 * @param[in]       output_value  softmax forward output.
 * @param[in,out]   output_grad   on entry the gradient w.r.t. the softmax
 *                                output; on return holds the computed
 *                                gradient (beta is 0, so under cuDNN's
 *                                alpha/beta convention nothing is blended).
 * @param[in]       height        passed as the first (N) tensor dimension.
 * @param[in]       width         passed as the second (C) tensor dimension.
 */
void hl_softmax_backward(real *output_value,
                         real *output_grad,
                         int height,
                         int width)
{
  // The cuDNN element type follows the build-time precision of `real`.
#ifndef PADDLE_TYPE_DOUBLE
  cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
#else
  cudnnDataType_t data_type = CUDNN_DATA_DOUBLE;
#endif
  CHECK_CUDNN(dynload::cudnnSetTensor4dDescriptor(
      t_resource.cudnn_desc,
      CUDNN_TENSOR_NCHW,
      data_type,
      height,
      width,
      1,
      1));

  real alpha = 1.0f;
  real beta = 0.0f;
  CHECK_CUDNN(dynload::cudnnSoftmaxBackward(
      t_resource.cudnn_handle,
      CUDNN_SOFTMAX_ACCURATE,
      CUDNN_SOFTMAX_MODE_CHANNEL,
      &alpha,
      t_resource.cudnn_desc,
      output_value,
      t_resource.cudnn_desc,
      output_grad,
      &beta,
      t_resource.cudnn_desc,
      output_grad));
  CHECK_SYNC("hl_softmax_backward failed");
}

/**
 * @brief   Batch-normalization forward pass for training, via
 *          cudnnBatchNormalizationForwardTraining in
 *          CUDNN_BATCHNORM_SPATIAL mode.
 *
 * Only available when compiled with CUDNN_VERSION >= 4007; otherwise the
 * call aborts with LOG(FATAL).
 *
 * @param[in]       inputDesc      input tensor descriptor.
 * @param[in]       input          input values.
 * @param[in]       outputDesc     output tensor descriptor.
 * @param[in,out]   output         output values.
 * @param[in]       bnParamDesc    descriptor shared by scale and bias (and
 *                                 passed to cuDNN for the statistics).
 * @param[in]       scale          batch-norm scale parameter.
 * @param[in]       bias           batch-norm bias parameter.
 * @param[in]       factor         forwarded as cuDNN's running-average
 *                                 factor.
 * @param[in,out]   runningMean    running mean; may be NULL only together
 *                                 with runningInvVar.
 * @param[in,out]   runningInvVar  running statistic paired with
 *                                 runningMean; may be NULL only together
 *                                 with runningMean.
 * @param[in]       epsilon        forwarded to cuDNN unchanged.
 * @param[out]      savedMean      saved batch mean; may be NULL only
 *                                 together with savedVar.
 * @param[out]      savedVar       saved batch statistic paired with
 *                                 savedMean; may be NULL only together
 *                                 with savedMean.
 */
void hl_batch_norm_forward_training(hl_tensor_descriptor inputDesc,
                                    real *input,
                                    hl_tensor_descriptor outputDesc,
                                    real *output,
                                    hl_tensor_descriptor bnParamDesc,
                                    real *scale,
                                    real *bias,
                                    double factor,
                                    real *runningMean,
                                    real *runningInvVar,
                                    double epsilon,
                                    real *savedMean,
                                    real *savedVar) {
#if CUDNN_VERSION >= 4007
  // Each optional pair must be supplied or omitted as a whole.
  if ((NULL != runningMean && NULL == runningInvVar) ||
      (NULL == runningMean && NULL != runningInvVar)) {
    LOG(FATAL) << "runningMean and runningInvVar can be NULL "
               << "but only at the same time.";
  }
  if ((NULL != savedMean && NULL == savedVar) ||
      (NULL == savedMean && NULL != savedVar)) {
    LOG(FATAL) << "savedMean and savedVar can be NULL "
               << "but only at the same time.";
  }

  cudnnTensorDescriptor_t xDesc = GET_TENSOR_DESCRIPTOR(inputDesc);
  cudnnTensorDescriptor_t yDesc = GET_TENSOR_DESCRIPTOR(outputDesc);
  cudnnTensorDescriptor_t bnDesc = GET_TENSOR_DESCRIPTOR(bnParamDesc);
  // NOTE(review): beta = 1 means, under cuDNN's alpha/beta convention, that
  // the result is added onto the existing contents of `output` -- presumably
  // callers zero the output beforehand; confirm against the callers.
  real alpha = 1.0f;
  real beta = 1.0f;
  cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
  CHECK_CUDNN(dynload::cudnnBatchNormalizationForwardTraining(
              t_resource.cudnn_handle, mode, &alpha, &beta, xDesc,
              input, yDesc, output, bnDesc, scale, bias, factor,
              runningMean, runningInvVar, epsilon, savedMean, savedVar));

  CHECK_SYNC("hl_batch_norm_forward_training failed");
#else
  LOG(FATAL) << "CudnnBatchNorm requires cudnn version >= 4007. "
             << "But cudnn lib version is " << g_cudnn_lib_version;
#endif
}

/**
 * @brief   Batch-normalization forward pass for inference, via
 *          cudnnBatchNormalizationForwardInference in
 *          CUDNN_BATCHNORM_SPATIAL mode.
 *
 * Only available when compiled with CUDNN_VERSION >= 4007; otherwise the
 * call aborts with LOG(FATAL).
 *
 * @param[in]       inputDesc        input tensor descriptor.
 * @param[in]       input            input values.
 * @param[in]       outputDesc       output tensor descriptor.
 * @param[in,out]   output           output values.
 * @param[in]       bnParamDesc      descriptor shared by scale, bias and
 *                                   the estimated statistics.
 * @param[in]       scale            batch-norm scale parameter.
 * @param[in]       bias             batch-norm bias parameter.
 * @param[in]       estimatedMean    forwarded as cuDNN's estimated mean.
 * @param[in]       estimatedInvVar  forwarded to cuDNN as the statistic
 *                                   paired with estimatedMean.
 * @param[in]       epsilon          forwarded to cuDNN unchanged.
 */
void hl_batch_norm_forward_inference(hl_tensor_descriptor inputDesc,
                                    real *input,
                                    hl_tensor_descriptor outputDesc,
                                    real *output,
                                    hl_tensor_descriptor bnParamDesc,
                                    real *scale,
                                    real *bias,
                                    real *estimatedMean,
                                    real *estimatedInvVar,
                                    double epsilon) {
#if CUDNN_VERSION >= 4007
  cudnnTensorDescriptor_t xDesc = GET_TENSOR_DESCRIPTOR(inputDesc);
  cudnnTensorDescriptor_t yDesc = GET_TENSOR_DESCRIPTOR(outputDesc);
  cudnnTensorDescriptor_t bnDesc = GET_TENSOR_DESCRIPTOR(bnParamDesc);
  // NOTE(review): beta = 1 means, under cuDNN's alpha/beta convention, that
  // the result is added onto the existing contents of `output` -- presumably
  // callers zero the output beforehand; confirm against the callers.
  real alpha = 1.0f;
  real beta = 1.0f;
  cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
  CHECK_CUDNN(dynload::cudnnBatchNormalizationForwardInference(
              t_resource.cudnn_handle, mode, &alpha, &beta, xDesc,
              input, yDesc, output, bnDesc, scale, bias,
              estimatedMean, estimatedInvVar, epsilon));

  CHECK_SYNC("hl_batch_norm_forward_inference failed");
#else
  LOG(FATAL) << "CudnnBatchNorm requires cudnn version >= 4007. "
             << "But cudnn lib version is " << g_cudnn_lib_version;
#endif
}

/**
 * @brief   Batch-normalization backward pass, via
 *          cudnnBatchNormalizationBackward in CUDNN_BATCHNORM_SPATIAL mode.
 *
 * Only available when compiled with CUDNN_VERSION >= 4007; otherwise the
 * call aborts with LOG(FATAL). The same alpha (1) and beta (1) are passed
 * for both the data-gradient and the parameter-gradient scaling pairs.
 *
 * @param[in]       inputDesc     input tensor descriptor.
 * @param[in]       input         input values of the forward pass.
 * @param[in]       outGradDesc   descriptor of the output gradient.
 * @param[in]       outGrad       output gradient values.
 * @param[in]       inGradDesc    descriptor of the input gradient.
 * @param[in,out]   inGrad        input gradient buffer.
 * @param[in]       dBnParamDesc  descriptor shared by scale, scaleGrad,
 *                                biasGrad and the saved statistics.
 * @param[in]       scale         batch-norm scale parameter.
 * @param[in,out]   scaleGrad     gradient of the scale parameter.
 * @param[in,out]   biasGrad      gradient of the bias parameter.
 * @param[in]       epsilon       forwarded to cuDNN unchanged.
 * @param[in]       savedMean     saved batch mean; may be NULL only
 *                                together with savedInvVar.
 * @param[in]       savedInvVar   saved statistic paired with savedMean; may
 *                                be NULL only together with savedMean.
 */
void hl_batch_norm_backward(hl_tensor_descriptor inputDesc,
                            real *input,
                            hl_tensor_descriptor outGradDesc,
                            real *outGrad,
                            hl_tensor_descriptor inGradDesc,
                            real *inGrad,
                            hl_tensor_descriptor dBnParamDesc,
                            real *scale,
                            real *scaleGrad,
                            real *biasGrad,
                            double epsilon,
                            real *savedMean,
                            real *savedInvVar) {
#if CUDNN_VERSION >= 4007
  // The saved statistics must be supplied or omitted as a pair.
  if ((NULL != savedMean && NULL == savedInvVar) ||
      (NULL == savedMean && NULL != savedInvVar)) {
    LOG(FATAL) << "savedMean and savedInvVar can be NULL "
               << "but only at the same time.";
  }

  cudnnTensorDescriptor_t xDesc = GET_TENSOR_DESCRIPTOR(inputDesc);
  cudnnTensorDescriptor_t dyDesc = GET_TENSOR_DESCRIPTOR(outGradDesc);
  cudnnTensorDescriptor_t dxDesc = GET_TENSOR_DESCRIPTOR(inGradDesc);
  cudnnTensorDescriptor_t bnDesc = GET_TENSOR_DESCRIPTOR(dBnParamDesc);
  // NOTE(review): beta = 1 means, under cuDNN's alpha/beta convention, that
  // results are added onto the existing contents of inGrad, scaleGrad and
  // biasGrad -- presumably intentional gradient accumulation; confirm
  // against the callers.
  real alpha = 1.0f;
  real beta = 1.0f;
  cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
  CHECK_CUDNN(dynload::cudnnBatchNormalizationBackward(
              t_resource.cudnn_handle, mode, &alpha, &beta,
              &alpha, &beta,
              xDesc, input, dyDesc, outGrad, dxDesc, inGrad,
              bnDesc, scale, scaleGrad, biasGrad, epsilon,
              savedMean, savedInvVar));

  CHECK_SYNC("hl_batch_norm_backward failed");
#else
  LOG(FATAL) << "CudnnBatchNorm requires cudnn version >= 4007. "
             << "But cudnn lib version is " << g_cudnn_lib_version;
#endif
}