// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/backends/cuda/math/cudnn_conv.h"

#include <array>
#include <vector>

#include "lite/backends/cuda/math/activation.h"
#include "lite/backends/cuda/math/conv_op_cache_cudnn.h"
#include "lite/backends/cuda/math/scale.h"
#include "lite/backends/cuda/math/type_trans.h"

namespace paddle {
namespace lite {
namespace cuda {
namespace math {

// Maps a Paddle-Lite precision tag to the cuDNN element type used when
// filling tensor / filter / convolution descriptors.
// Only the specializations below are defined in this file; no generic
// definition is provided for other precisions.
template <PrecisionType PType>
cudnnDataType_t GetDataType();

// kFloat -> CUDNN_DATA_FLOAT.
template <>
cudnnDataType_t GetDataType<PRECISION(kFloat)>() {
  return CUDNN_DATA_FLOAT;
}

// kFP16 -> CUDNN_DATA_HALF.
template <>
cudnnDataType_t GetDataType<PRECISION(kFP16)>() {
  return CUDNN_DATA_HALF;
}

// Fills the cuDNN descriptors for the convolution described by `param`,
// selects a forward algorithm and makes sure the workspace is large enough
// for it.
//
// param: convolution parameters. x / filter / output dims are indexed as
//        NCHW in this function.
// ctx:   CUDA context; not referenced in this function.
// Returns true. cuDNN / CUDA failures are reported through CUDNN_CHECK /
// CUDA_CALL.
template <typename T, PrecisionType Ptype_out>
bool CudnnConv2D<T, Ptype_out>::create(const operators::ConvParam& param,
                                       Context<TARGET(kCUDA)>* ctx) {
  auto x_dims = param.x->dims();
  auto w_dims = param.filter->dims();
  auto o_dims = param.output->dims();
  int batch = x_dims[0];

  auto paddings = *param.paddings;
  auto dilations = *param.dilations;

  int iw = x_dims[3];  // nchw
  int ih = x_dims[2];
  int ic = x_dims[1];
  int ow = o_dims[3];
  int oh = o_dims[2];
  int oc = o_dims[1];
  int kw = w_dims[3];
  int kh = w_dims[2];
  int sw = param.strides[1];
  int sh = param.strides[0];
  // Index 0 is used for the height padding and index 2 for the width padding
  // (presumably the layout is {top, bottom, left, right} -- confirm).
  int pw = paddings[2];
  int ph = paddings[0];
  int dw = dilations[1];
  int dh = dilations[0];

  CHECK(ic % param.groups == 0)
      << "The conv input channel shoud be divide group number.";

  CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->input_desc_,
                                         CUDNN_TENSOR_NCHW,
                                         GetDataType<Ptype_out>(),
                                         batch,
                                         ic,
                                         ih,
                                         iw));
  CUDNN_CHECK(cudnnSetFilter4dDescriptor(this->filter_desc_,
                                         GetDataType<Ptype_out>(),
                                         CUDNN_TENSOR_NCHW,
                                         oc,
                                         ic / param.groups,
                                         kh,
                                         kw));
  CUDNN_CHECK(cudnnSetConvolution2dDescriptor(this->conv_desc_,
                                              ph,
                                              pw,
                                              sh,
                                              sw,
                                              dh,
                                              dw,
                                              CUDNN_CROSS_CORRELATION,
                                              GetDataType<Ptype_out>()));
  CUDNN_CHECK(cudnnSetConvolutionGroupCount(this->conv_desc_, param.groups));
  CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->output_desc_,
                                         CUDNN_TENSOR_NCHW,
                                         GetDataType<Ptype_out>(),
                                         batch,
                                         oc,
                                         oh,
                                         ow));

  // The activation descriptor is only configured for the ReLU case, matching
  // the condition under which run() uses it.
  if (param.activation_param.has_active && this->with_relu_act_) {
    CUDNN_CHECK(cudnnSetActivationDescriptor(
        this->act_desc_, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0.0));
  }

#if CUDNN_VERSION_MIN(7, 0, 0)
  cudnnMathType_t math_type =
      this->use_tensor_core_ ? CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH;
  CUDNN_CHECK(cudnnSetConvolutionMathType(this->conv_desc_, math_type));
#endif

  if (ic == param.groups && ic == oc && ic != 1) {
    // One input and one output channel per group: use a fixed algorithm
    // instead of searching.
    this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
  } else if (!param.var_length) {
    const auto* i_data = param.x->data<T>();
    const auto* w_data = param.filter->data<T>();
    auto* o_data = param.output->mutable_data<T>(TARGET(kCUDA));
    // Upper bound on the scratch memory offered to the algorithm search.
    const size_t workspace_size_limit = 256 * 1024 * 1024;

    // Benchmarks the forward algorithms on the real buffers and returns the
    // first entry of the result list.
    auto search_func = [&]() {
      int returned_algo_count = 0;
      std::array<cudnnConvolutionFwdAlgoPerf_t,
                 CUDNN_CONVOLUTION_FWD_ALGO_COUNT>
          fwd_perf_stat;
      auto cudnn_find_func = [&](void* cudnn_workspace) {
        CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithmEx(
            this->handle_,
            this->input_desc_,
            i_data,
            this->filter_desc_,
            w_data,
            this->conv_desc_,
            this->output_desc_,
            o_data,
            CUDNN_CONVOLUTION_FWD_ALGO_COUNT,
            &returned_algo_count,
            fwd_perf_stat.data(),
            cudnn_workspace,
            workspace_size_limit));
      };

      // Swap the regular workspace for a temporary search buffer, then
      // release it again once the search is done.
      this->ResetWorkSpace();
      CUDA_CALL(cudaMalloc(&this->workspace_data_, workspace_size_limit));
      cudnn_find_func(this->workspace_data_);
      this->ResetWorkSpace();
      // The workspace was just released, so the recorded capacity and the
      // cached pointer no longer describe a live buffer. Clearing them makes
      // the size check at the end of create() allocate again instead of
      // keeping a stale pointer when create() is invoked more than once.
      // NOTE(review): assumes ResetWorkSpace() frees workspace_data_ --
      // confirm against the class definition.
      this->workspace_size_inbytes_ = 0;
      this->workspace_ = nullptr;

      CHECK(returned_algo_count > 0)
          << "cudnnFindConvolutionForwardAlgorithmEx returned no algorithm.";

      VLOG(2) << "Perf result: (algo: stat, time, memory)";
      for (int i = 0; i < returned_algo_count; ++i) {
        const auto& stat = fwd_perf_stat[i];
        VLOG(2) << stat.algo << ": " << stat.status << " " << stat.time << " "
                << stat.memory;
      }
      return fwd_perf_stat[0].algo;
    };
    // NOTE(review): algo_cache is a local object, so anything it stores does
    // not outlive this call -- confirm whether it was meant to be shared.
    AlgorithmsCache<cudnnConvolutionFwdAlgo_t> algo_cache;
    this->fwd_algo_ = algo_cache.GetAlgorithm(x_dims.Vectorize(),
                                              w_dims.Vectorize(),
                                              param.strides,
                                              *param.paddings,
                                              *param.dilations,
                                              0,
                                              search_func);

  } else {
    // Variable-length input: ask cuDNN for an algorithm without benchmarking.
    CUDNN_CHECK(
        cudnnGetConvolutionForwardAlgorithm(this->handle_,
                                            this->input_desc_,
                                            this->filter_desc_,
                                            this->conv_desc_,
                                            this->output_desc_,
                                            this->preference_,
                                            this->workspace_limit_bytes_,
                                            &this->fwd_algo_));
  }
  CUDNN_CHECK(
      cudnnGetConvolutionForwardWorkspaceSize(this->handle_,
                                              this->input_desc_,
                                              this->filter_desc_,
                                              this->conv_desc_,
                                              this->output_desc_,
                                              this->fwd_algo_,
                                              &this->workspace_fwd_sizes_));
  // Grow-only workspace: reallocate only when the chosen algorithm needs more
  // than what is currently recorded.
  if (this->workspace_fwd_sizes_ > this->workspace_size_inbytes_) {
    this->workspace_size_inbytes_ = this->workspace_fwd_sizes_;
    this->ResetWorkSpace();
    CUDA_CALL(
        cudaMalloc(&this->workspace_data_, this->workspace_size_inbytes_));
    this->workspace_ = reinterpret_cast<char*>(this->workspace_data_);
  }
  if (param.bias) {
    // Bias is described as a 1 x oc x 1 x 1 tensor.
    int dim_bias[] = {1, oc, 1, 1};
    int stride_bias[] = {oc, 1, 1, 1};
    CUDNN_CHECK(cudnnSetTensorNdDescriptor(
        this->bias_desc_, GetDataType<Ptype_out>(), 4, dim_bias, stride_bias));
  }
  return true;
}

// One-time setup: clears the workspace bookkeeping, creates the cuDNN handle
// bound to the context's execution stream, creates the descriptors, and then
// delegates to create() to configure them.
//
// param: convolution parameters; activation_param decides whether an
//        activation descriptor is created.
// ctx:   CUDA context supplying the execution stream.
// Returns the result of create(param, ctx).
template <typename T, PrecisionType Ptype_out>
bool CudnnConv2D<T, Ptype_out>::init(const operators::ConvParam& param,
                                     Context<TARGET(kCUDA)>* ctx) {
  this->workspace_size_inbytes_ = 0;
  this->workspace_data_ = NULL;
  this->workspace_fwd_sizes_ = 0;

  this->stream_ = ctx->exec_stream();
  CUDNN_CHECK(cudnnCreate(&this->handle_));
  CUDNN_CHECK(cudnnSetStream(this->handle_, this->stream_));

  this->workspace_ = NULL;

  // Descriptor creation can fail (e.g. allocation failure); surface that here
  // instead of at a later, unrelated cuDNN call.
  CUDNN_CHECK(cudnnCreateTensorDescriptor(&this->input_desc_));
  CUDNN_CHECK(cudnnCreateTensorDescriptor(&this->output_desc_));
  CUDNN_CHECK(cudnnCreateFilterDescriptor(&this->filter_desc_));
  CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&this->conv_desc_));
  CUDNN_CHECK(cudnnCreateTensorDescriptor(&this->bias_desc_));

  if (param.activation_param.has_active) {
    if (param.activation_param.active_type == lite_api::ActivationType::kRelu) {
      CUDNN_CHECK(cudnnCreateActivationDescriptor(&this->act_desc_));
    } else {
      // Any non-ReLU activation is applied by run() after the convolution.
      this->with_relu_act_ = false;
    }
  }
  return create(param, ctx);
}

// Executes the convolution configured by create().
//
// Paths:
//   * ReLU + bias    -> single fused cudnnConvolutionBiasActivationForward.
//   * ReLU, no bias  -> convolution followed by an in-place activation.
//   * otherwise      -> convolution, then an in-place bias add if bias exists.
// When the activation is not ReLU (with_relu_act_ is false), leaky ReLU is
// applied in place afterwards through relu().
//
// Returns true; cuDNN failures are reported through CUDNN_CHECK.
template <typename T, PrecisionType Ptype_out>
bool CudnnConv2D<T, Ptype_out>::run(const operators::ConvParam& param) {
  const auto* i_data = param.x->data<T>();
  const auto* w_data = param.filter->data<T>();
  const auto* b_data = param.bias ? param.bias->data<T>() : nullptr;
  auto* o_data = param.output->mutable_data<T>(TARGET(kCUDA));

  // cuDNN blend factors: result = alpha * op(...) + beta * prior_output.
  const float alpha = 1.0f;
  const float beta = 0.0f;
  const bool fuse_relu =
      param.activation_param.has_active && this->with_relu_act_;

  if (fuse_relu && b_data) {
    CUDNN_CHECK(
        cudnnConvolutionBiasActivationForward(this->handle_,
                                              &alpha,
                                              this->input_desc_,
                                              i_data,
                                              this->filter_desc_,
                                              w_data,
                                              this->conv_desc_,
                                              this->fwd_algo_,
                                              this->workspace_,
                                              this->workspace_fwd_sizes_,
                                              &beta,
                                              this->output_desc_,
                                              o_data,
                                              this->bias_desc_,
                                              b_data,
                                              this->act_desc_,
                                              this->output_desc_,
                                              o_data));
  } else {
    CUDNN_CHECK(cudnnConvolutionForward(this->handle_,
                                        &alpha,
                                        this->input_desc_,
                                        i_data,
                                        this->filter_desc_,
                                        w_data,
                                        this->conv_desc_,
                                        this->fwd_algo_,
                                        this->workspace_,
                                        this->workspace_fwd_sizes_,
                                        &beta,
                                        this->output_desc_,
                                        o_data));
    if (fuse_relu) {
      // No bias: apply ReLU in place on the convolution result.
      CUDNN_CHECK(cudnnActivationForward(this->handle_,
                                         this->act_desc_,
                                         &alpha,
                                         this->output_desc_,
                                         o_data,
                                         &beta,
                                         this->output_desc_,
                                         o_data));
    } else if (b_data) {
      // Both factors are 1, so the bias is accumulated onto the output.
      CUDNN_CHECK(cudnnAddTensor(this->handle_,
                                 &alpha,
                                 this->bias_desc_,
                                 b_data,
                                 &alpha,
                                 this->output_desc_,
                                 o_data));
    }
  }

  if (!this->with_relu_act_) {
    CHECK(param.activation_param.active_type ==
          lite_api::ActivationType::kLeakyRelu)
        << "Only support leaky relu now.";
    auto out_dims = param.output->dims();
    int n = out_dims[0], c = out_dims[1], h = out_dims[2], w = out_dims[3];
    int num = n * h * w * c;
    float leaky_alpha = param.activation_param.Leaky_relu_alpha;

    relu(num, o_data, o_data, leaky_alpha, this->stream_);
  }
  return true;
}

// Explicit instantiations of CudnnConv2D for the two element-type /
// precision pairs that have a GetDataType specialization in this file.
template class CudnnConv2D<float, PRECISION(kFloat)>;
template class CudnnConv2D<half, PRECISION(kFP16)>;

// Configures the int8 convolution: prepares the per-output-channel scale
// tensor, fills the cuDNN descriptors (NHWC, int8 input/filter, float output),
// chooses a forward algorithm and sizes the workspace.
//
// param: convolution parameters. x / filter / output dims are indexed as
//        NHWC in this function.
// ctx:   CUDA context; not referenced in this function.
// Returns true; cuDNN / CUDA failures are reported through CUDNN_CHECK /
// CUDA_CALL.
template <PrecisionType Ptype_out>
bool CudnnConv2DInt8<Ptype_out>::create(const operators::ConvParam& param,
                                        Context<TARGET(kCUDA)>* ctx) {
  auto x_dims = param.x->dims();
  auto w_dims = param.filter->dims();
  auto o_dims = param.output->dims();

  int batch = x_dims[0];

  int iw = x_dims[2];  // nhwc
  int ih = x_dims[1];
  int ic = x_dims[3];
  int ow = o_dims[2];
  int oh = o_dims[1];
  int oc = o_dims[3];

  int kw = w_dims[2];
  int kh = w_dims[1];

  auto paddings = *param.paddings;
  auto dilations = *param.dilations;

  int sw = param.strides[1];
  int sh = param.strides[0];
  // Index 0 is used for the height padding and index 2 for the width padding
  // (presumably the layout is {top, bottom, left, right} -- confirm).
  int pw = paddings[2];
  int ph = paddings[0];
  int dw = dilations[1];
  int dh = dilations[0];

  // Local copy: the scales are rewritten below before being uploaded.
  std::vector<float> weight_scale = param.weight_scale;
  float input_scale = param.input_scale;
  float output_scale = param.output_scale;
  CHECK(weight_scale.size() == static_cast<size_t>(oc))
      << "the num of the weight_scale should be equals to the output channel.";
  if (Ptype_out == PRECISION(kInt8)) {
    // Int8 output: the float convolution result lands in temp_tensor_ first,
    // so the combined scale also folds in 1 / output_scale.
    this->temp_tensor_.Resize(o_dims);
    this->temp_tensor_.template mutable_data<float>(TARGET(kCUDA));
    for (size_t i = 0; i < weight_scale.size(); i++) {
      weight_scale[i] = (weight_scale[i] * input_scale) / output_scale;
    }

    // NOTE(review): the bias buffer is rescaled in place, so invoking
    // create() again on the same param would apply 1 / output_scale twice --
    // confirm that create() runs once per bias tensor.
    auto* b_data = param.bias ? param.bias->mutable_data<float>() : nullptr;
    if (b_data) {
      scale(param.bias->numel(), b_data, b_data, 1.f / output_scale);
    }
  } else {
    for (size_t i = 0; i < weight_scale.size(); i++) {
      weight_scale[i] = (weight_scale[i] * input_scale);
    }
  }
  this->scale_.Resize({oc});
  this->scale_.template Assign<float, lite::DDim, TARGET(kCUDA)>(
      weight_scale.data(), this->scale_.dims());

  CHECK(ic % param.groups == 0)
      << "The conv input channel shoud be divide group number.";
  // cudnnSetTensor4dDescriptor takes (n, c, h, w) regardless of the format
  // flag; CUDNN_TENSOR_NHWC only describes the memory layout.
  CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->input_desc_,
                                         CUDNN_TENSOR_NHWC,
                                         CUDNN_DATA_INT8,
                                         batch,
                                         ic,
                                         ih,
                                         iw));
  CUDNN_CHECK(cudnnSetFilter4dDescriptor(this->filter_desc_,
                                         CUDNN_DATA_INT8,
                                         CUDNN_TENSOR_NHWC,
                                         oc,
                                         ic / param.groups,
                                         kh,
                                         kw));
  CUDNN_CHECK(cudnnSetConvolution2dDescriptor(this->conv_desc_,
                                              ph,
                                              pw,
                                              sh,
                                              sw,
                                              dh,
                                              dw,
                                              CUDNN_CROSS_CORRELATION,
                                              CUDNN_DATA_INT32));
  // The filter descriptor above uses ic / groups input channels, so the
  // convolution descriptor must carry the same group count (a no-op when
  // groups == 1, which is cuDNN's default).
  CUDNN_CHECK(cudnnSetConvolutionGroupCount(this->conv_desc_, param.groups));

  CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->output_desc_,
                                         CUDNN_TENSOR_NHWC,
                                         CUDNN_DATA_FLOAT,
                                         batch,
                                         oc,
                                         oh,
                                         ow));
  // Algorithm is chosen from channel alignment alone; no search is run.
  if (ic % 4 == 0 && oc % 4 == 0) {
    this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
  } else {
    this->fwd_algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
  }
  CUDNN_CHECK(
      cudnnGetConvolutionForwardWorkspaceSize(this->handle_,
                                              this->input_desc_,
                                              this->filter_desc_,
                                              this->conv_desc_,
                                              this->output_desc_,
                                              this->fwd_algo_,
                                              &this->workspace_fwd_sizes_));

  // Grow-only workspace: free and reallocate only when more room is needed.
  if (this->workspace_fwd_sizes_ > this->workspace_size_inbytes_) {
    this->workspace_size_inbytes_ = this->workspace_fwd_sizes_;
    if (this->workspace_data_ != NULL) {
      CUDA_CALL(cudaFree(this->workspace_data_));
    }
    CUDA_CALL(
        cudaMalloc(&this->workspace_data_, this->workspace_size_inbytes_));
    this->workspace_ = reinterpret_cast<char*>(this->workspace_data_);
  }

  return true;
}

// One-time setup for the int8 convolution: clears the workspace bookkeeping,
// creates the cuDNN handle bound to the context's execution stream, creates
// the descriptors, and then delegates to create() to configure them.
//
// param: convolution parameters; activation_param decides with_relu_act_.
// ctx:   CUDA context supplying the execution stream.
// Returns the result of create(param, ctx).
template <PrecisionType Ptype_out>
bool CudnnConv2DInt8<Ptype_out>::init(const operators::ConvParam& param,
                                      Context<TARGET(kCUDA)>* ctx) {
  // No workspace is held yet; create() allocates on demand.
  this->workspace_size_inbytes_ = 0;
  this->workspace_data_ = NULL;
  this->workspace_fwd_sizes_ = 0;

  this->stream_ = ctx->exec_stream();
  CUDNN_CHECK(cudnnCreate(&this->handle_));
  CUDNN_CHECK(cudnnSetStream(this->handle_, this->stream_));

  this->workspace_ = NULL;

  // Descriptor creation can fail (e.g. allocation failure); surface that here
  // instead of at a later, unrelated cuDNN call.
  CUDNN_CHECK(cudnnCreateTensorDescriptor(&this->input_desc_));
  CUDNN_CHECK(cudnnCreateTensorDescriptor(&this->output_desc_));
  CUDNN_CHECK(cudnnCreateFilterDescriptor(&this->filter_desc_));
  CUDNN_CHECK(cudnnCreateConvolutionDescriptor(&this->conv_desc_));
  CUDNN_CHECK(cudnnCreateTensorDescriptor(&this->bias_desc_));

  if (param.activation_param.has_active) {
    // run() reads this flag to decide whether to pass Leaky_relu_alpha or 0
    // as the activation slope.
    if (!(param.activation_param.active_type ==
          lite_api::ActivationType::kRelu)) {
      this->with_relu_act_ = false;
    }
  }
  return create(param, ctx);
}

// Executes the int8 convolution configured by create().
//
// The convolution writes float results either to temp_tensor_ (int8 output)
// or directly to param.output (float output). A follow-up helper then applies
// the per-channel scale_ and, when present, bias and activation.
//
// Supported combinations: conv, conv + bias, conv + bias + relu,
// conv + bias + leaky_relu. An activation without bias hits the final CHECK.
//
// Returns true on the supported paths.
template <PrecisionType Ptype_out>
bool CudnnConv2DInt8<Ptype_out>::run(const operators::ConvParam& param) {
  const auto* i_data = param.x->data<int8_t>();
  const auto* w_data = param.filter->data<int8_t>();
  const auto* b_data = param.bias ? param.bias->data<float>() : nullptr;
  float* temp_out;
  float* scale = this->scale_.template mutable_data<float>(TARGET(kCUDA));
  if (Ptype_out == PRECISION(kInt8)) {
    temp_out = this->temp_tensor_.template mutable_data<float>(TARGET(kCUDA));
  } else {
    temp_out = param.output->mutable_data<float>(TARGET(kCUDA));
  }

  // cuDNN blend factors: result = alpha * conv(...) + beta * prior_output.
  float alpha = 1.0f;
  float beta = 0.0f;
  CUDNN_CHECK(cudnnConvolutionForward(this->handle_,
                                      &alpha,
                                      this->input_desc_,
                                      i_data,
                                      this->filter_desc_,
                                      w_data,
                                      this->conv_desc_,
                                      this->fwd_algo_,
                                      this->workspace_,
                                      this->workspace_fwd_sizes_,
                                      &beta,
                                      this->output_desc_,
                                      temp_out));

  auto out_dims = param.output->dims();
  int n = out_dims[0], h = out_dims[1], w = out_dims[2], c = out_dims[3];
  int num = n * h * w * c;

  // Plain convolution: only the per-channel scale is applied.
  if (!param.activation_param.has_active && !b_data) {
    if (Ptype_out == PRECISION(kInt8)) {
      auto* out = param.output->mutable_data<int8_t>(TARGET(kCUDA));
      fp32_to_int8_nhwc(num,
                        static_cast<const void*>(temp_out),
                        static_cast<void*>(out),
                        static_cast<const void*>(scale),
                        n,
                        c,
                        h,
                        w,
                        this->stream_);
    } else {
      // Float output is scaled in place.
      fp32_scale_nhwc(num,
                      static_cast<const void*>(temp_out),
                      static_cast<void*>(temp_out),
                      static_cast<const void*>(scale),
                      n,
                      c,
                      h,
                      w,
                      this->stream_);
    }
    return true;
  }

  if (b_data) {
    if (param.activation_param.has_active) {
      // Slope handed to the bias+activation helper: 0 for ReLU, the
      // configured leaky slope otherwise.
      float act_alpha = 0.0;
      if (!this->with_relu_act_)
        act_alpha = param.activation_param.Leaky_relu_alpha;
      if (Ptype_out == PRECISION(kInt8)) {
        auto* out = param.output->mutable_data<int8_t>(TARGET(kCUDA));
        bias_relu_int8_nhwc<int8_t>(num,
                                    static_cast<const void*>(temp_out),
                                    static_cast<const void*>(b_data),
                                    static_cast<void*>(out),
                                    n,
                                    c,
                                    h,
                                    w,
                                    static_cast<const void*>(scale),
                                    act_alpha,
                                    this->stream_);
      } else {
        bias_relu_int8_nhwc<float>(num,
                                   static_cast<const void*>(temp_out),
                                   static_cast<const void*>(b_data),
                                   static_cast<void*>(temp_out),
                                   n,
                                   c,
                                   h,
                                   w,
                                   static_cast<const void*>(scale),
                                   act_alpha,
                                   this->stream_);
      }
      return true;
    } else {
      if (Ptype_out == PRECISION(kInt8)) {
        auto* out = param.output->mutable_data<int8_t>(TARGET(kCUDA));
        bias_int8_nhwc<int8_t>(num,
                               static_cast<const void*>(temp_out),
                               static_cast<const void*>(b_data),
                               static_cast<void*>(out),
                               n,
                               c,
                               h,
                               w,
                               static_cast<const void*>(scale),
                               this->stream_);
      } else {
        bias_int8_nhwc<float>(num,
                              static_cast<const void*>(temp_out),
                              static_cast<const void*>(b_data),
                              static_cast<void*>(temp_out),
                              n,
                              c,
                              h,
                              w,
                              static_cast<const void*>(scale),
                              this->stream_);
      }
      return true;
    }
  }

  // Reached only for activation without bias, which is not implemented.
  CHECK(false)
      << "Conv Int8 support Conv, Conv + bias + relu, Conv + bias + leaky_relu";
  // Keeps every control path of this non-void function returning a value.
  return false;
}

// Explicit instantiations of CudnnConv2DInt8 for the two output precisions
// handled by create() / run(): int8 output and float output.
template class CudnnConv2DInt8<PRECISION(kInt8)>;
template class CudnnConv2DInt8<PRECISION(kFloat)>;

}  // namespace math
}  // namespace cuda
}  // namespace lite
}  // namespace paddle