// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/backends/cuda/math/cudnn_conv.h"
#include "lite/backends/cuda/math/activation.h"
#include "lite/backends/cuda/math/conv_op_cache_cudnn.h"
#include "lite/backends/cuda/math/cudnn_helper.h"
#include "lite/backends/cuda/math/scale.h"
#include "lite/backends/cuda/math/type_trans.h"

namespace paddle {
namespace lite {
namespace cuda {
namespace math {

27 28 29
template <typename T, PrecisionType Ptype_out>
bool CudnnConv2D<T, Ptype_out>::create(const operators::ConvParam& param,
                                       Context<TARGET(kCUDA)>* ctx) {
30 31 32 33 34
  auto x_dims = param.x->dims();
  auto w_dims = param.filter->dims();
  auto o_dims = param.output->dims();
  int batch = x_dims[0];

H
HappyAngel 已提交
35 36 37
  auto paddings = *param.paddings;
  auto dilations = *param.dilations;

38 39 40 41 42 43 44 45 46 47
  int iw = x_dims[3];  // nchw
  int ih = x_dims[2];
  int ic = x_dims[1];
  int ow = o_dims[3];
  int oh = o_dims[2];
  int oc = o_dims[1];
  int kw = w_dims[3];
  int kh = w_dims[2];
  int sw = param.strides[1];
  int sh = param.strides[0];
H
HappyAngel 已提交
48 49 50 51
  int pw = paddings[2];
  int ph = paddings[0];
  int dw = dilations[1];
  int dh = dilations[0];
52 53 54 55 56 57

  CHECK(ic % param.groups == 0)
      << "The conv input channel shoud be divide group number.";

  CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->input_desc_,
                                         CUDNN_TENSOR_NCHW,
W
Wilber 已提交
58
                                         cudnn::cudnnTypeWrapper<T>::type,
59 60 61 62 63
                                         batch,
                                         ic,
                                         ih,
                                         iw));
  CUDNN_CHECK(cudnnSetFilter4dDescriptor(this->filter_desc_,
W
Wilber 已提交
64
                                         cudnn::cudnnTypeWrapper<T>::type,
65 66 67 68 69
                                         CUDNN_TENSOR_NCHW,
                                         oc,
                                         ic / param.groups,
                                         kh,
                                         kw));
W
Wilber 已提交
70 71 72 73 74 75 76 77 78 79
  CUDNN_CHECK(
      cudnnSetConvolution2dDescriptor(this->conv_desc_,
                                      ph,
                                      pw,
                                      sh,
                                      sw,
                                      dh,
                                      dw,
                                      CUDNN_CROSS_CORRELATION,
                                      cudnn::cudnnTypeWrapper<T>::type));
80 81 82
  CUDNN_CHECK(cudnnSetConvolutionGroupCount(this->conv_desc_, param.groups));
  CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->output_desc_,
                                         CUDNN_TENSOR_NCHW,
W
Wilber 已提交
83
                                         cudnn::cudnnTypeWrapper<T>::type,
84 85 86 87 88
                                         batch,
                                         oc,
                                         oh,
                                         ow));

89
  if (param.activation_param.has_active && this->with_relu_act_) {
90 91 92 93
    CUDNN_CHECK(cudnnSetActivationDescriptor(
        this->act_desc_, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0.0));
  }

W
Wilber 已提交
94 95
#if CUDNN_VERSION_MIN(7, 0, 0)
  cudnnMathType_t math_type =
96
      this->use_tensor_core_ ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH;
W
Wilber 已提交
97 98 99
  CUDNN_CHECK(cudnnSetConvolutionMathType(this->conv_desc_, math_type));
#endif

100 101
  if (ic == param.groups && ic == oc && ic != 1) {
    this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
W
Wilber 已提交
102
  } else if (!param.var_length) {
103 104 105
    const auto* i_data = param.x->data<T>();
    const auto* w_data = param.filter->data<T>();
    auto* o_data = param.output->mutable_data<T>(TARGET(kCUDA));
106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
    int workspace_size_limit = 256 * 1024 * 1024;

    auto search_func = [&]() {
      int returned_algo_count;
      std::array<cudnnConvolutionFwdAlgoPerf_t,
                 CUDNN_CONVOLUTION_FWD_ALGO_COUNT>
          fwd_perf_stat;
      auto cudnn_find_func = [&](void* cudnn_workspace) {
        CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithmEx(
            this->handle_,
            this->input_desc_,
            i_data,
            this->filter_desc_,
            w_data,
            this->conv_desc_,
            this->output_desc_,
            o_data,
            CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
            &returned_algo_count,
            fwd_perf_stat.data(),
            cudnn_workspace,
            workspace_size_limit));
      };

130
      this->ResetWorkSpace();
131 132
      CUDA_CALL(cudaMalloc(&this->workspace_data_, workspace_size_limit));
      cudnn_find_func(this->workspace_data_);
133
      this->ResetWorkSpace();
134 135 136 137 138 139 140 141 142 143 144 145 146

      VLOG(2) << "Perf result: (algo: stat, time, memory)";
      for (int i = 0; i < returned_algo_count; ++i) {
        const auto& stat = fwd_perf_stat[i];
        VLOG(2) << stat.algo << ": " << stat.status << " " << stat.time << " "
                << stat.memory;
      }
      return fwd_perf_stat[0].algo;
    };
    AlgorithmsCache<cudnnConvolutionFwdAlgo_t> algo_cache;
    this->fwd_algo_ = algo_cache.GetAlgorithm(x_dims.Vectorize(),
                                              w_dims.Vectorize(),
                                              param.strides,
H
HappyAngel 已提交
147 148
                                              *param.paddings,
                                              *param.dilations,
149 150 151
                                              0,
                                              search_func);

152
  } else {
153 154 155 156 157 158 159 160 161 162 163
    int requestedAlgoCount = 1;
    int returnedAlgoCount;
    CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm_v7(this->handle_,
                                                       this->input_desc_,
                                                       this->filter_desc_,
                                                       this->conv_desc_,
                                                       this->output_desc_,
                                                       requestedAlgoCount,
                                                       &returnedAlgoCount,
                                                       &this->algo_perf_));
    this->fwd_algo_ = this->algo_perf_.algo;
164 165 166 167 168 169 170 171 172 173 174
  }
  CUDNN_CHECK(
      cudnnGetConvolutionForwardWorkspaceSize(this->handle_,
                                              this->input_desc_,
                                              this->filter_desc_,
                                              this->conv_desc_,
                                              this->output_desc_,
                                              this->fwd_algo_,
                                              &this->workspace_fwd_sizes_));
  if (this->workspace_fwd_sizes_ > this->workspace_size_inbytes_) {
    this->workspace_size_inbytes_ = this->workspace_fwd_sizes_;
175
    this->ResetWorkSpace();
176 177 178 179 180 181
    cudaMalloc(&this->workspace_data_, this->workspace_size_inbytes_);
    this->workspace_ = reinterpret_cast<char*>(this->workspace_data_);
  }
  if (param.bias) {
    int dim_bias[] = {1, oc, 1, 1};
    int stride_bias[] = {oc, 1, 1, 1};
W
Wilber 已提交
182
    cudnnSetTensorNdDescriptor(this->bias_desc_,
W
Wilber 已提交
183
                               cudnn::cudnnTypeWrapper<T>::type,
W
Wilber 已提交
184 185 186
                               4,
                               dim_bias,
                               stride_bias);
187 188 189 190
  }
  return true;
}

191 192 193
template <typename T, PrecisionType Ptype_out>
bool CudnnConv2D<T, Ptype_out>::init(const operators::ConvParam& param,
                                     Context<TARGET(kCUDA)>* ctx) {
194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219
  this->workspace_size_inbytes_ = 0;
  this->workspace_data_ = NULL;
  this->workspace_fwd_sizes_ = 0;

  this->stream_ = ctx->exec_stream();
  CUDNN_CHECK(cudnnCreate(&this->handle_));
  CUDNN_CHECK(cudnnSetStream(this->handle_, this->stream_));

  this->workspace_ = NULL;

  cudnnCreateTensorDescriptor(&this->input_desc_);
  cudnnCreateTensorDescriptor(&this->output_desc_);
  cudnnCreateFilterDescriptor(&this->filter_desc_);
  cudnnCreateConvolutionDescriptor(&this->conv_desc_);
  cudnnCreateTensorDescriptor(&this->bias_desc_);

  if (param.activation_param.has_active) {
    if (param.activation_param.active_type == lite_api::ActivationType::kRelu) {
      cudnnCreateActivationDescriptor(&this->act_desc_);
    } else {
      this->with_relu_act_ = false;
    }
  }
  return create(param, ctx);
}

220 221 222 223 224 225
template <typename T, PrecisionType Ptype_out>
bool CudnnConv2D<T, Ptype_out>::run(const operators::ConvParam& param) {
  const auto* i_data = param.x->data<T>();
  const auto* w_data = param.filter->data<T>();
  const auto* b_data = param.bias ? param.bias->data<T>() : nullptr;
  auto* o_data = param.output->mutable_data<T>(TARGET(kCUDA));
226

227
  if (param.activation_param.has_active && this->with_relu_act_) {
228 229 230
    if (b_data) {
      float alpha = 1.0f;
      float beta = 0.0f;
231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249
      CUDNN_CHECK(
          cudnnConvolutionBiasActivationForward(this->handle_,
                                                &alpha,
                                                this->input_desc_,
                                                i_data,
                                                this->filter_desc_,
                                                w_data,
                                                this->conv_desc_,
                                                this->fwd_algo_,
                                                this->workspace_,
                                                this->workspace_fwd_sizes_,
                                                &beta,
                                                this->output_desc_,
                                                o_data,
                                                this->bias_desc_,
                                                b_data,
                                                this->act_desc_,
                                                this->output_desc_,
                                                o_data));
250 251 252
    } else {
      float alpha = 1.0f;
      float beta = 0.0f;
253
      CUDNN_CHECK(cudnnConvolutionForward(this->handle_,
254
                                          &alpha,
255
                                          this->input_desc_,
256
                                          i_data,
257
                                          this->filter_desc_,
258
                                          w_data,
259 260 261 262
                                          this->conv_desc_,
                                          this->fwd_algo_,
                                          this->workspace_,
                                          this->workspace_fwd_sizes_,
263
                                          &beta,
264
                                          this->output_desc_,
265 266
                                          o_data));

267 268
      CUDNN_CHECK(cudnnActivationForward(this->handle_,
                                         this->act_desc_,
269
                                         &alpha,
270
                                         this->output_desc_,
271 272
                                         o_data,
                                         &beta,
273
                                         this->output_desc_,
274 275 276 277 278
                                         o_data));
    }
  } else {
    float alpha = 1.0f;
    float beta = 0.0f;
279
    CUDNN_CHECK(cudnnConvolutionForward(this->handle_,
280
                                        &alpha,
281
                                        this->input_desc_,
282
                                        i_data,
283
                                        this->filter_desc_,
284
                                        w_data,
285 286 287 288
                                        this->conv_desc_,
                                        this->fwd_algo_,
                                        this->workspace_,
                                        this->workspace_fwd_sizes_,
289
                                        &beta,
290
                                        this->output_desc_,
291 292
                                        o_data));
    if (b_data) {
293 294 295 296 297 298 299
      CUDNN_CHECK(cudnnAddTensor(this->handle_,
                                 &alpha,
                                 this->bias_desc_,
                                 b_data,
                                 &alpha,
                                 this->output_desc_,
                                 o_data));
300 301 302
    }
  }

303
  if (!this->with_relu_act_) {
304 305 306 307 308 309 310 311 312 313 314 315 316
    CHECK(param.activation_param.active_type ==
          lite_api::ActivationType::kLeakyRelu)
        << "Only support leaky relu now.";
    auto out_dims = param.output->dims();
    int n = out_dims[0], c = out_dims[1], h = out_dims[2], w = out_dims[3];
    int num = n * h * w * c;
    float alpha = param.activation_param.Leaky_relu_alpha;

    relu(num, o_data, o_data, alpha, this->stream_);
  }
  return true;
}

// Explicit instantiations for fp32 (float) and fp16 (half) element types.
template class CudnnConv2D<float, PRECISION(kFloat)>;
template class CudnnConv2D<half, PRECISION(kFP16)>;

320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338
template <PrecisionType Ptype_out>
bool CudnnConv2DInt8<Ptype_out>::create(const operators::ConvParam& param,
                                        Context<TARGET(kCUDA)>* ctx) {
  auto x_dims = param.x->dims();
  auto w_dims = param.filter->dims();
  auto o_dims = param.output->dims();

  int batch = x_dims[0];

  int iw = x_dims[2];  // nchw
  int ih = x_dims[1];
  int ic = x_dims[3];
  int ow = o_dims[2];
  int oh = o_dims[1];
  int oc = o_dims[3];

  int kw = w_dims[2];
  int kh = w_dims[1];

H
HappyAngel 已提交
339 340 341
  auto paddings = *param.paddings;
  auto dilations = *param.dilations;

342 343
  int sw = param.strides[1];
  int sh = param.strides[0];
H
HappyAngel 已提交
344 345 346 347
  int pw = paddings[2];
  int ph = paddings[0];
  int dw = dilations[1];
  int dh = dilations[0];
348 349 350 351

  std::vector<float> weight_scale = param.weight_scale;
  float input_scale = param.input_scale;
  float output_scale = param.output_scale;
352
  CHECK(weight_scale.size() == static_cast<size_t>(oc))
353 354 355 356
      << "the num of the weight_scale should be equals to the output channel.";
  if (Ptype_out == PRECISION(kInt8)) {
    this->temp_tensor_.Resize(o_dims);
    this->temp_tensor_.template mutable_data<float>(TARGET(kCUDA));
357
    for (size_t i = 0; i < weight_scale.size(); i++) {
358 359
      weight_scale[i] = (weight_scale[i] * input_scale) / output_scale;
    }
360 361 362 363 364

    auto* b_data = param.bias ? param.bias->mutable_data<float>() : nullptr;
    if (b_data) {
      scale(param.bias->numel(), b_data, b_data, 1.f / output_scale);
    }
365
  } else {
366
    for (size_t i = 0; i < weight_scale.size(); i++) {
367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406
      weight_scale[i] = (weight_scale[i] * input_scale);
    }
  }
  this->scale_.Resize({oc});
  this->scale_.template Assign<float, lite::DDim, TARGET(kCUDA)>(
      weight_scale.data(), this->scale_.dims());

  CHECK(ic % param.groups == 0)
      << "The conv input channel shoud be divide group number.";
  CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->input_desc_,
                                         CUDNN_TENSOR_NHWC,
                                         CUDNN_DATA_INT8,
                                         batch,
                                         ic,
                                         ih,
                                         iw));
  CUDNN_CHECK(cudnnSetFilter4dDescriptor(this->filter_desc_,
                                         CUDNN_DATA_INT8,
                                         CUDNN_TENSOR_NHWC,
                                         oc,
                                         ic / param.groups,
                                         kh,
                                         kw));
  CUDNN_CHECK(cudnnSetConvolution2dDescriptor(this->conv_desc_,
                                              ph,
                                              pw,
                                              sh,
                                              sw,
                                              dh,
                                              dw,
                                              CUDNN_CROSS_CORRELATION,
                                              CUDNN_DATA_INT32));

  CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->output_desc_,
                                         CUDNN_TENSOR_NHWC,
                                         CUDNN_DATA_FLOAT,
                                         batch,
                                         oc,
                                         oh,
                                         ow));
407 408 409 410 411
  if (ic % 4 == 0 && oc % 4 == 0) {
    this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
  } else {
    this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
  }
412 413 414 415 416 417 418
  CUDNN_CHECK(
      cudnnGetConvolutionForwardWorkspaceSize(this->handle_,
                                              this->input_desc_,
                                              this->filter_desc_,
                                              this->conv_desc_,
                                              this->output_desc_,
                                              this->fwd_algo_,
419
                                              &this->workspace_fwd_sizes_));
420 421 422 423

  if (this->workspace_fwd_sizes_ > this->workspace_size_inbytes_) {
    this->workspace_size_inbytes_ = this->workspace_fwd_sizes_;
    if (this->workspace_data_ != NULL) {
424
      CUDA_CALL(cudaFree(this->workspace_data_));
425
    }
426 427
    CUDA_CALL(
        cudaMalloc(&this->workspace_data_, this->workspace_size_inbytes_));
428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471
    this->workspace_ = reinterpret_cast<char*>(this->workspace_data_);
  }

  return true;
}

// One-time setup for the INT8 convolution.
// - Creates a cuDNN handle bound to the context's execution stream.
// - Creates the tensor, filter and convolution descriptors.
// - Records whether the activation is plain ReLU, then configures everything
//   via create().
// No cuDNN activation descriptor is used here: ReLU and leaky ReLU are both
// applied by the epilogue kernels in run().
template <PrecisionType Ptype_out>
bool CudnnConv2DInt8<Ptype_out>::init(const operators::ConvParam& param,
                                      Context<TARGET(kCUDA)>* ctx) {
  // No workspace yet; create() allocates one sized for the chosen algorithm.
  this->workspace_size_inbytes_ = 0;
  this->workspace_data_ = NULL;
  this->workspace_fwd_sizes_ = 0;

  this->stream_ = ctx->exec_stream();
  CUDNN_CHECK(cudnnCreate(&this->handle_));
  CUDNN_CHECK(cudnnSetStream(this->handle_, this->stream_));

  this->workspace_ = NULL;

  CUDNN_CHECK(cudnnCreateTensorDescriptor(&this->input_desc_));
  CUDNN_CHECK(cudnnCreateTensorDescriptor(&this->output_desc_));
  CUDNN_CHECK(cudnnCreateFilterDescriptor(&this->filter_desc_));
  CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&this->conv_desc_));
  CUDNN_CHECK(cudnnCreateTensorDescriptor(&this->bias_desc_));

  // with_relu_act_ == false tells run() to use the leaky-ReLU slope.
  if (param.activation_param.has_active) {
    if (!(param.activation_param.active_type ==
          lite_api::ActivationType::kRelu)) {
      this->with_relu_act_ = false;
    }
  }
  return create(param, ctx);
}

// Runs the INT8 convolution followed by the fused epilogue.
//
// cuDNN writes the unscaled float result (alpha = 1, beta = 0) into temp_out.
// For INT8 output, temp_out is the fp32 staging tensor temp_tensor_; for
// FLOAT output, it is param.output itself. The epilogue kernels then apply
// the per-channel factors in scale_:
//   * no bias, no activation -> fp32_to_int8_nhwc (INT8) / fp32_scale_nhwc;
//   * bias, no activation    -> bias_int8_nhwc;
//   * bias + ReLU/leaky ReLU -> bias_relu_int8_nhwc.
// An activation without bias is not supported and aborts via CHECK.
template <PrecisionType Ptype_out>
bool CudnnConv2DInt8<Ptype_out>::run(const operators::ConvParam& param) {
  const auto* i_data = param.x->data<int8_t>();
  const auto* w_data = param.filter->data<int8_t>();
  const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
  float* temp_out;
  float* scale = this->scale_.template mutable_data<float>(TARGET(kCUDA));
  if (Ptype_out == PRECISION(kInt8)) {
    temp_out = this->temp_tensor_.template mutable_data<float>(TARGET(kCUDA));
  } else {
    temp_out = param.output->mutable_data<float>(TARGET(kCUDA));
  }

  float alpha = 1.0f;
  float beta = 0.0f;
  CUDNN_CHECK(cudnnConvolutionForward(this->handle_,
                                      &alpha,
                                      this->input_desc_,
                                      i_data,
                                      this->filter_desc_,
                                      w_data,
                                      this->conv_desc_,
                                      this->fwd_algo_,
                                      this->workspace_,
                                      this->workspace_fwd_sizes_,
                                      &beta,
                                      this->output_desc_,
                                      temp_out));

  // Output dims are NHWC.
  auto out_dims = param.output->dims();
  int n = out_dims[0], h = out_dims[1], w = out_dims[2], c = out_dims[3];
  int num = n * h * w * c;

  if (!param.activation_param.has_active && !b_data) {
    if (Ptype_out == PRECISION(kInt8)) {
      auto* out = param.output->mutable_data<int8_t>(TARGET(kCUDA));
      fp32_to_int8_nhwc(num,
                        static_cast<const void*>(temp_out),
                        static_cast<void*>(out),
                        static_cast<const void*>(scale),
                        n,
                        c,
                        h,
                        w,
                        this->stream_);
    } else {
      fp32_scale_nhwc(num,
                      static_cast<const void*>(temp_out),
                      static_cast<void*>(temp_out),
                      static_cast<const void*>(scale),
                      n,
                      c,
                      h,
                      w,
                      this->stream_);
    }
    return true;
  }

  if (b_data) {
    if (param.activation_param.has_active) {
      // Activation parameter passed to bias_relu_int8_nhwc: 0 for ReLU, the
      // leaky slope for leaky ReLU.
      float act_alpha = 0.0f;
      if (!this->with_relu_act_) {
        // init() clears with_relu_act_ for every non-ReLU activation. As in
        // the floating-point path, only leaky ReLU is supported. Reject
        // anything else instead of silently computing a leaky ReLU.
        CHECK(param.activation_param.active_type ==
              lite_api::ActivationType::kLeakyRelu)
            << "Only support leaky relu now.";
        act_alpha = param.activation_param.Leaky_relu_alpha;
      }
      if (Ptype_out == PRECISION(kInt8)) {
        auto* out = param.output->mutable_data<int8_t>(TARGET(kCUDA));
        bias_relu_int8_nhwc<int8_t>(num,
                                    static_cast<const void*>(temp_out),
                                    static_cast<const void*>(b_data),
                                    static_cast<void*>(out),
                                    n,
                                    c,
                                    h,
                                    w,
                                    static_cast<const void*>(scale),
                                    act_alpha,
                                    this->stream_);
      } else {
        bias_relu_int8_nhwc<float>(num,
                                   static_cast<const void*>(temp_out),
                                   static_cast<const void*>(b_data),
                                   static_cast<void*>(temp_out),
                                   n,
                                   c,
                                   h,
                                   w,
                                   static_cast<const void*>(scale),
                                   act_alpha,
                                   this->stream_);
      }
      return true;
    } else {
      if (Ptype_out == PRECISION(kInt8)) {
        auto* out = param.output->mutable_data<int8_t>(TARGET(kCUDA));
        bias_int8_nhwc<int8_t>(num,
                               static_cast<const void*>(temp_out),
                               static_cast<const void*>(b_data),
                               static_cast<void*>(out),
                               n,
                               c,
                               h,
                               w,
                               static_cast<const void*>(scale),
                               this->stream_);
      } else {
        bias_int8_nhwc<float>(num,
                              static_cast<const void*>(temp_out),
                              static_cast<const void*>(b_data),
                              static_cast<void*>(temp_out),
                              n,
                              c,
                              h,
                              w,
                              static_cast<const void*>(scale),
                              this->stream_);
      }
      return true;
    }
  }

  CHECK(false)
      << "Conv Int8 support Conv, Conv + bias + relu, Conv + bias + leaky_relu";
  // Not reached when CHECK aborts; keeps every path of this bool function
  // returning a value.
  return false;
}

// Explicit instantiations: INT8 convolution producing INT8 or FLOAT output.
template class CudnnConv2DInt8<PRECISION(kInt8)>;
template class CudnnConv2DInt8<PRECISION(kFloat)>;

}  // namespace math
}  // namespace cuda
}  // namespace lite
}  // namespace paddle