hl_cuda_cudnn.h 21.3 KB
Newer Older
Z
zhangjinchao01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef HL_CUDA_CUDNN_H_
#define HL_CUDA_CUDNN_H_

#include "hl_base.h"

/**
 * @brief   hppl pooling mode.
 */
typedef enum {
  // max pooling over the window
  HL_POOLING_MAX = 0,
  // average includes padded values
  HL_POOLING_AVERAGE = 1,
  // average does not include padded values
  HL_POOLING_AVERAGE_EXCLUDE_PADDING = 2,
  // end-of-enumeration marker (equals the number of modes listed above)
  HL_POOLING_END
} hl_pooling_mode_t;

/**
 * @brief   Query the cuDNN library version.
 *
 * @return  the cuDNN library version number.
 */
extern int hl_get_cudnn_lib_version(void);

/**
 * @brief   hppl tensor (image) descriptor.
 *
 * Opaque handle: struct _hl_tensor_descriptor is not defined in this header.
 */
typedef struct _hl_tensor_descriptor* hl_tensor_descriptor;

/**
 * @brief   hppl pooling descriptor.
 *
 * Opaque handle: struct _hl_pooling_descriptor is not defined in this header.
 */
typedef struct _hl_pooling_descriptor* hl_pooling_descriptor;

/**
 * @brief   hppl filter descriptor.
 *
 * Opaque handle: struct _hl_filter_descriptor is not defined in this header.
 */
typedef struct _hl_filter_descriptor* hl_filter_descriptor;

/**
 * @brief   hppl convolution descriptor.
 *
 * Opaque handle: struct _hl_convolution_descriptor is not defined in this
 * header.
 */
typedef struct _hl_convolution_descriptor* hl_convolution_descriptor;

/**
 * @brief   Create a new image (tensor) descriptor.
 *
 * @param[out]   image_desc     receives the created descriptor handle.
 */
extern void hl_create_tensor_descriptor(
    hl_tensor_descriptor* image_desc);

/**
 * @brief   Set the dimensions of an image descriptor.
 *
 * A strided overload with the same name is also declared in this header.
 *
 * @param[in,out]   image_desc    descriptor to update.
 * @param[in]       batch_size    number of images in the batch.
 * @param[in]       feature_maps  number of feature maps per image.
 * @param[in]       height        image height.
 * @param[in]       width         image width.
 */
extern void hl_tensor_reshape(
    hl_tensor_descriptor image_desc,
    int batch_size,
    int feature_maps,
    int height,
    int width);

/**
 * @brief   Set the dimensions and explicit strides of an image descriptor.
 *
 * @param[in,out]   image_desc    descriptor to update.
 * @param[in]       batch_size    number of images in the batch.
 * @param[in]       feature_maps  number of feature maps per image.
 * @param[in]       height        image height.
 * @param[in]       width         image width.
 * @param[in]       nStride       distance between two consecutive images.
 * @param[in]       cStride       distance between two consecutive feature maps.
 * @param[in]       hStride       distance between two consecutive rows.
 * @param[in]       wStride       distance between two consecutive columns.
 */
extern void hl_tensor_reshape(
    hl_tensor_descriptor image_desc,
    int batch_size,
    int feature_maps,
    int height,
    int width,
    int nStride,
    int cStride,
    int hStride,
    int wStride);

/**
 * @brief   Release an image (tensor) descriptor.
 *
 * @param[in]   image_desc  descriptor to destroy.
 */
extern void hl_destroy_tensor_descriptor(
    hl_tensor_descriptor image_desc);

/**
 * @brief   Create a pooling descriptor.
 *
 * @param[out]  pooling_desc    receives the created descriptor handle.
 * @param[in]   mode            pooling mode (see hl_pooling_mode_t).
 * @param[in]   height          pooling window height.
 * @param[in]   width           pooling window width.
 * @param[in]   height_padding  vertical padding.
 * @param[in]   width_padding   horizontal padding.
 * @param[in]   stride_height   vertical stride of the window.
 * @param[in]   stride_width    horizontal stride of the window.
 */
extern void hl_create_pooling_descriptor(
    hl_pooling_descriptor* pooling_desc,
    hl_pooling_mode_t mode,
    int height,
    int width,
    int height_padding,
    int width_padding,
    int stride_height,
    int stride_width);

/**
 * @brief   Release a pooling descriptor.
 *
 * @param[in]   pooling_desc  descriptor to destroy.
 */
extern void hl_destroy_pooling_descriptor(
    hl_pooling_descriptor pooling_desc);

/**
 * @brief   Pooling forward pass: compute the output image.
 *
 * @param[in]   input           descriptor of the input image.
 * @param[in]   input_image     input image data.
 * @param[in]   output          descriptor of the output image.
 * @param[out]  output_image    output image data.
 * @param[in]   pooling         pooling descriptor.
 */
extern void hl_pooling_forward(
    hl_tensor_descriptor input,
    real* input_image,
    hl_tensor_descriptor output,
    real* output_image,
    hl_pooling_descriptor pooling);

/**
 * @brief   pooling backward(calculate input image gradient).
 *
 * Uses the forward input/output and the gradient w.r.t. the output to
 * compute the gradient w.r.t. the input.
 *
 * @param[in]   input               input image descriptor.
 * @param[in]   input_image         input image data (forward input).
 * @param[out]  input_image_grad    input image gradient data (the result).
 * @param[in]   output              output image descriptor.
 * @param[in]   output_image        output image data (forward output).
 * @param[in]   output_image_grad   output image gradient data.
 * @param[in]   pooling             pooling descriptor.
 *
 */
extern void hl_pooling_backward(hl_tensor_descriptor input,
                                real* input_image,
                                real* input_image_grad,
                                hl_tensor_descriptor output,
                                real* output_image,
                                real* output_image_grad,
                                hl_pooling_descriptor pooling);

/**
 * @brief   Create a filter descriptor.
 *
 * @param[out]  filter                  receives the created descriptor handle.
 * @param[in]   input_feature_maps      number of input feature maps.
 * @param[in]   output_feature_maps     number of output feature maps.
 * @param[in]   height                  filter height.
 * @param[in]   width                   filter width.
 */
extern void hl_create_filter_descriptor(
    hl_filter_descriptor* filter,
    int input_feature_maps,
    int output_feature_maps,
    int height,
    int width);

/**
 * @brief    Choose convolution algorithms and report their workspace sizes.
 *
 * @param[in]    input                input image descriptor.
 * @param[in]    output               output image descriptor.
 * @param[in]    filter               filter descriptor.
 * @param[in]    conv                 convolution descriptor.
 * @param[out]   convFwdAlgo          selected forward algorithm.
 * @param[out]   fwdLimitBytes        workspace size for the forward pass.
 * @param[out]   convBwdDataAlgo      selected backward-data algorithm.
 * @param[out]   bwdDataLimitBytes    workspace size for backward data.
 * @param[out]   convBwdFilterAlgo    selected backward-filter algorithm.
 * @param[out]   bwdFilterLimitBytes  workspace size for backward filter.
 */
extern void hl_conv_workspace(
    hl_tensor_descriptor input,
    hl_tensor_descriptor output,
    hl_filter_descriptor filter,
    hl_convolution_descriptor conv,
    int* convFwdAlgo,
    size_t* fwdLimitBytes,
    int* convBwdDataAlgo,
    size_t* bwdDataLimitBytes,
    int* convBwdFilterAlgo,
    size_t* bwdFilterLimitBytes);

/**
 * @brief   Release a filter descriptor.
 *
 * @param[in]   filter  descriptor to destroy.
 */
extern void hl_destroy_filter_descriptor(
    hl_filter_descriptor filter);

/**
 * @brief   Create a convolution descriptor.
 *
 * @param[out]  conv                    receives the created descriptor handle.
 * @param[in]   image                   input image descriptor.
 * @param[in]   filter                  filter descriptor.
 * @param[in]   padding_height          vertical padding.
 * @param[in]   padding_width           horizontal padding.
 * @param[in]   stride_height           vertical stride.
 * @param[in]   stride_width            horizontal stride.
 */
extern void hl_create_convolution_descriptor(
    hl_convolution_descriptor* conv,
    hl_tensor_descriptor image,
    hl_filter_descriptor filter,
    int padding_height,
    int padding_width,
    int stride_height,
    int stride_width);

/**
 * @brief   Reconfigure an existing convolution descriptor.
 *
 * @param[in,out]   conv                descriptor to update.
 * @param[in]       image               input image descriptor.
 * @param[in]       filter              filter descriptor.
 * @param[in]       padding_height      vertical padding.
 * @param[in]       padding_width       horizontal padding.
 * @param[in]       stride_height       vertical stride.
 * @param[in]       stride_width        horizontal stride.
 */
extern void hl_reset_convolution_descriptor(
    hl_convolution_descriptor conv,
    hl_tensor_descriptor image,
    hl_filter_descriptor filter,
    int padding_height,
    int padding_width,
    int stride_height,
    int stride_width);

/**
 * @brief   Release a convolution descriptor.
 *
 * @param[in]   conv  descriptor to destroy.
 */
extern void hl_destroy_convolution_descriptor(
    hl_convolution_descriptor conv);

/**
 * @brief   Convolution forward pass: compute the output image.
 *
 * @param[in]   input           descriptor of the input image.
 * @param[in]   input_data      input image data.
 * @param[in]   output          descriptor of the output image.
 * @param[out]  output_data     output image data.
 * @param[in]   filter          filter descriptor.
 * @param[in]   filter_data     filter weights.
 * @param[in]   conv            convolution descriptor.
 * @param[in]   gpuWorkSpace    scratch workspace on the GPU.
 * @param[in]   sizeInBytes     size of gpuWorkSpace in bytes.
 * @param[in]   convFwdAlgo     forward algorithm to run.
 */
extern void hl_convolution_forward(
    hl_tensor_descriptor input,
    real* input_data,
    hl_tensor_descriptor output,
    real* output_data,
    hl_filter_descriptor filter,
    real* filter_data,
    hl_convolution_descriptor conv,
    void* gpuWorkSpace,
    size_t sizeInBytes,
    int convFwdAlgo);

/**
 * @brief   convolution forward add bias(calculate output add bias).
 *
 * Adds the bias to the existing output values, so output_data is both read
 * and written.
 *
 * @param[in]       bias                bias descriptor.
 * @param[in]       bias_data           bias data.
 * @param[in]       output              output image descriptor.
 * @param[in,out]   output_data         output image data; on return it holds
 *                                      the original output plus the bias.
 */
extern void hl_convolution_forward_add_bias(hl_tensor_descriptor bias,
                                            real* bias_data,
                                            hl_tensor_descriptor output,
                                            real* output_data);

/**
 * @brief   convolution backward filter(calculate filter grad data).
 *
 * @param[in]   input               input image descriptor.
 * @param[in]   input_data          input image data.
 * @param[in]   output              output image descriptor.
 * @param[in]   output_grad_data    output image grad data.
 * @param[in]   filter              filter descriptor.
 * @param[out]  filter_grad_data    filter grad data.
 * @param[in]   conv                convolution descriptor.
 * @param[in]   gpuWorkSpace        limited gpu workspace.
 * @param[in]   sizeInBytes         gpu workspace size (bytes).
 * @param[in]   convBwdFilterAlgo   backward filter algorithm.
 */
326 327 328 329 330 331 332 333 334 335
extern void hl_convolution_backward_filter(hl_tensor_descriptor input,
                                           real* input_data,
                                           hl_tensor_descriptor output,
                                           real* output_grad_data,
                                           hl_filter_descriptor filter,
                                           real* filter_grad_data,
                                           hl_convolution_descriptor conv,
                                           void* gpuWorkSpace,
                                           size_t sizeInBytes,
                                           int convBwdFilterAlgo);
Z
zhangjinchao01 已提交
336 337 338 339 340 341 342 343 344 345 346 347 348 349 350

/**
 * @brief   convolution backward data(calculate input image grad data).
 *
 * @param[in]   input               input image descriptor.
 * @param[out]  input_data_grad     input image grad data.
 * @param[in]   output              output image descriptor.
 * @param[in]   output_grad_data    output image grad data.
 * @param[in]   filter              filter descriptor.
 * @param[in]   filter_data         filter data.
 * @param[in]   conv                convolution descriptor.
 * @param[in]   gpuWorkSpace        limited gpu workspace.
 * @param[in]   sizeInBytes         gpu workspace size (bytes).
 * @param[in]   convBwdDataAlgo     backward data algorithm.
 */
351 352 353 354 355 356 357 358 359 360
extern void hl_convolution_backward_data(hl_tensor_descriptor input,
                                         real* input_data_grad,
                                         hl_tensor_descriptor output,
                                         real* output_grad_data,
                                         hl_filter_descriptor filter,
                                         real* filter_data,
                                         hl_convolution_descriptor conv,
                                         void* gpuWorkSpace,
                                         size_t sizeInBytes,
                                         int convBwdDataAlgo);
Z
zhangjinchao01 已提交
361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382

/**
 * @brief   Convolution backward pass for the bias: compute the bias gradient.
 *
 * @param[in]   bias                descriptor of the bias.
 * @param[out]  bias_grad_data      bias gradient data.
 * @param[in]   output              descriptor of the output image.
 * @param[in]   output_grad_data    output image gradient data.
 */
extern void hl_convolution_backward_bias(
    hl_tensor_descriptor bias,
    real* bias_grad_data,
    hl_tensor_descriptor output,
    real* output_grad_data);

/**
 * @brief   softmax forward.
 *
 * @param[in]   input               input value.
 * @param[out]  output              output value.
 * @param[in]   height              matrix height.
 * @param[in]   width               matrix width.
 */
383 384
extern void hl_softmax_forward(real* input,
                               real* output,
Z
zhangjinchao01 已提交
385 386 387 388 389 390 391 392 393 394 395
                               int height,
                               int width);

/**
 * @brief   softmax backward.
 *
 * @param[in]   output_value        output value data.
 * @param[out]  output_grad         output grad data.
 * @param[in]   height              matrix height.
 * @param[in]   width               matrix width.
 */
396 397
extern void hl_softmax_backward(real* output_value,
                                real* output_grad,
Z
zhangjinchao01 已提交
398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425
                                int height,
                                int width);

/**
 * @brief   cudnn batch norm forward.
 *
 * @param[in]   inputDesc     input tensor descriptor desc.
 * @param[in]   input         input data.
 * @param[in]   outputDesc    output tensor descriptor desc.
 * @param[out]  output        output data.
 * @param[in]   bnParamDesc   tensor descriptor desc.
 *                            bnScale, bnBias, running mean/var, save_mean/var.
 * @param[in]   scale         batch normalization scale parameter (in original
 *                            paper scale is referred to as gamma).
 * @param[in]   bias          batch normalization bias parameter (in original
 *                            paper scale is referred to as beta).
 * @param[in]   factor        Factor used in the moving average computation.
 *                            runningMean = newMean * factor
 *                                         + runningMean * (1 - factor)
 * @param[in]   runningMean   running mean.
 * @param[in]   runningInvVar running variance.
 * @param[in]   epsilon       Epsilon value used in the batch normalization
 *                            formula.
 * @param[out]  savedMean     optional cache to save intermediate results.
 * @param[out]  savedVar      optional cache to save intermediate results.
 *
 */
extern void hl_batch_norm_forward_training(hl_tensor_descriptor inputDesc,
426
                                           real* input,
Z
zhangjinchao01 已提交
427
                                           hl_tensor_descriptor outputDesc,
428
                                           real* output,
Z
zhangjinchao01 已提交
429
                                           hl_tensor_descriptor bnParamDesc,
430 431
                                           real* scale,
                                           real* bias,
Z
zhangjinchao01 已提交
432
                                           double factor,
433 434
                                           real* runningMean,
                                           real* runningInvVar,
Z
zhangjinchao01 已提交
435
                                           double epsilon,
436 437
                                           real* savedMean,
                                           real* savedVar);
Z
zhangjinchao01 已提交
438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462

/**
 * @brief   cudnn batch norm forward.
 *
 * @param[in]   inputDesc    input tensor descriptor desc.
 * @param[in]   input        input data.
 * @param[in]   outputDesc   output tensor descriptor desc.
 * @param[out]  output       output data.
 * @param[in]   bnParamDesc  tensor descriptor desc.
 *                           bnScale, bnBias, running mean/var, save_mean/var.
 * @param[in]   scale        batch normalization scale parameter (in original
 *                           paper scale is referred to as gamma).
 * @param[in]   bias         batch normalization bias parameter (in original
 *                           paper scale is referred to as beta).
 * @param[in]   estimatedMean
 * @param[in]   estimatedVar It is suggested that resultRunningMean,
 *                           resultRunningVariance from the
 *                           cudnnBatchNormalizationForwardTraining call
 *                           accumulated during the training phase are passed
 *                           as inputs here.
 * @param[in]   epsilon      Epsilon value used in the batch
 *                           normalization formula.
 *
 */
extern void hl_batch_norm_forward_inference(hl_tensor_descriptor inputDesc,
463
                                            real* input,
Z
zhangjinchao01 已提交
464
                                            hl_tensor_descriptor outputDesc,
465
                                            real* output,
Z
zhangjinchao01 已提交
466
                                            hl_tensor_descriptor bnParamDesc,
467 468 469 470
                                            real* scale,
                                            real* bias,
                                            real* estimatedMean,
                                            real* estimatedVar,
Z
zhangjinchao01 已提交
471 472 473 474 475 476 477 478 479 480 481 482
                                            double epsilon);

/**
 * @brief   cudnn batch norm forward.
 *
 * @param[in]   inputDesc       input tensor descriptor desc.
 * @param[in]   input           input data.
 * @param[in]   outGradDesc     output tensor descriptor desc.
 * @param[out]  outGrad         output data.
 * @param[in]   inGradDesc      input tensor descriptor desc.
 * @param[in]   inGrad          input data.
 * @param[in]   dBnParamDesc    tensor descriptor desc.
483 484
 *                              bnScale, bnBias, running mean/var,
 * save_mean/var.
Z
zhangjinchao01 已提交
485 486 487 488 489 490 491 492 493 494 495 496 497
 * @param[in]   scale           batch normalization scale parameter (in original
 *                              paper scale is referred to as gamma).
 * @param[in]   scaleGrad       batch normalization scale parameter (in original
 *                              paper scale is referred to as gamma) gradient.
 * @param[in]   biasGrad        batch normalization bias parameter (in original
 *                              paper scale is referred to as beta) gradient.
 * @param[in]   epsilon         Epsilon value used in the batch
 *                              normalization formula.
 * @param[out]  savedMean       optional cache to save intermediate results.
 * @param[out]  savedInvVar     optional cache to save intermediate results.
 *
 */
extern void hl_batch_norm_backward(hl_tensor_descriptor inputDesc,
498
                                   real* input,
Z
zhangjinchao01 已提交
499
                                   hl_tensor_descriptor outGradDesc,
500
                                   real* outGrad,
Z
zhangjinchao01 已提交
501
                                   hl_tensor_descriptor inGradDesc,
502
                                   real* inGrad,
Z
zhangjinchao01 已提交
503
                                   hl_tensor_descriptor dBnParamDesc,
504 505 506
                                   real* scale,
                                   real* scaleGrad,
                                   real* biasGrad,
Z
zhangjinchao01 已提交
507
                                   double epsilon,
508 509
                                   real* savedMean,
                                   real* savedInvVar);
Z
zhangjinchao01 已提交
510 511

#endif  // HL_CUDA_CUDNN_H_