// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/runtime/cuda/cuda_util.h"

#include <absl/container/flat_hash_map.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <curand.h>
#include <cusolverDn.h>
#include <glog/logging.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

#include <algorithm>
#include <string>
#ifdef CINN_WITH_CUDNN
#include <cudnn.h>
#endif

#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/backends/extern_func_jit_register.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/runtime/cuda/cublas_util.h"
#include "paddle/cinn/runtime/custom_function.h"
#include "paddle/cinn/runtime/flags.h"
#include "paddle/cinn/utils/profiler.h"
#include "paddle/cinn/utils/timer.h"

namespace cinn {
namespace runtime {
namespace cuda {

class CublasHandle {
 public:
  CublasHandle(const CublasHandle &) = delete;
  CublasHandle &operator=(const CublasHandle &) = delete;
  ~CublasHandle() {
    CUBLAS_CALL(cublasDestroy(cuhandle));
    CUDA_CALL(cudaStreamDestroy(custream));
  }
  static CublasHandle &GetInstance() {
    static CublasHandle instance;
    return instance;
  }
  cudaStream_t GetCuStream() { return custream; }
  cublasHandle_t &GetCublasHandle() { return cuhandle; }

 private:
  CublasHandle() {
    CUDA_CALL(cudaStreamCreate(&custream));
    CUBLAS_CALL(cublasCreate(&cuhandle));
    cudaMemPool_t mem_pool;
    CUDA_CALL(cudaDeviceGetMemPool(&mem_pool, 0));

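    // Keep freed blocks cached in the default mem pool (up to the threshold
    // below) instead of returning them to the OS at each synchronization.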
    uint64_t threshold = UINT32_MAX;
    CUDA_CALL(cudaMemPoolSetAttribute(
        mem_pool, cudaMemPoolAttrReleaseThreshold, &threshold));

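    // Allow the pool to reuse blocks across streams whenever ordering is
    // already guaranteed by event or internal dependencies.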
    int enable = 1;
    CUDA_CALL(cudaMemPoolSetAttribute(
        mem_pool, cudaMemPoolReuseFollowEventDependencies, &enable));
    CUDA_CALL(cudaMemPoolSetAttribute(
        mem_pool, cudaMemPoolReuseAllowInternalDependencies, &enable));
  }
  cudaStream_t custream;
  cublasHandle_t cuhandle;
};

void cinn_call_cuda_kernel(void *kernel_fn,
                           void *v_args,
                           int num_args,
                           int grid_x,
                           int grid_y,
                           int grid_z,
                           int block_x,
                           int block_y,
                           int block_z,
                           void *stream) {
  VLOG(3) << "cinn_call_cuda_kernel, grid_dim={" << grid_x << ", " << grid_y
          << ", " << grid_z << "}, block_dim={" << block_x << ", " << block_y
          << ", " << block_z << "}, num_args=" << num_args
          << ", stream=" << stream;

  std::vector<void *> kernel_args;
  {
    cinn::utils::RecordEvent record_run("prepare_args",
                                        cinn::utils::EventType::kInstruction);
    kernel_args.reserve(num_args);
    cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
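    // cuLaunchKernel takes an array of pointers to each argument's value:
    // buffer arguments contribute the address of their device pointer,
    // scalar arguments the address of their payload.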
    for (int idx = 0; idx < num_args; ++idx) {
      if (args[idx].type_code() == ::cinn_type_code<cinn_buffer_t *>()) {
        kernel_args.emplace_back(
            &((cinn_buffer_t *)(args[idx]))->memory);  // NOLINT
      } else {
        kernel_args.emplace_back(args[idx].data_addr());
      }
    }
  }

  {
    cinn::utils::RecordEvent record_run("cuLaunchKernel",
                                        cinn::utils::EventType::kInstruction);
    CUDA_DRIVER_CALL(cuLaunchKernel(static_cast<CUfunction>(kernel_fn),
                                    grid_x,
                                    grid_y,
                                    grid_z,
                                    block_x,
                                    block_y,
                                    block_z,
                                    0,  // shared memory
                                    static_cast<CUstream>(stream),
                                    kernel_args.data(),
                                    nullptr))
  }
}

void cinn_call_cublas(void *v_args,
                      int num_args,
                      bool trans_a,
                      bool trans_b,
                      bool trans_o,
                      float alpha,
                      float beta,
                      int a1,
                      int a2,
                      int a3,
                      int a4,
                      int b1,
                      int b2,
                      int b3,
                      int b4,
                      void *stream) {
  cinn::utils::RecordEvent record_run("cinn_call_cublas",
                                      cinn::utils::EventType::kInstruction);
  CHECK_EQ(num_args, 3);
  cublasHandle_t &cuhandle = CublasHandle::GetInstance().GetCublasHandle();
  cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
  cudaStream_t custream = static_cast<cudaStream_t>(stream);
  CUBLAS_CALL(cublasSetStream(cuhandle, custream));
  VLOG(3) << "a1 ~ a4: " << a1 << " " << a2 << " " << a3 << " " << a4;
  VLOG(3) << "b1 ~ b4: " << b1 << " " << b2 << " " << b3 << " " << b4;
  VLOG(3) << "trans_a: " << trans_a << ", trans_b: " << trans_b
          << ", trans_o: " << trans_o;

  void *A = args[0].operator cinn_buffer_t *()->memory;
  void *B = args[1].operator cinn_buffer_t *()->memory;
  void *C = args[2].operator cinn_buffer_t *()->memory;

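  // cuBLAS works on column-major matrices while CINN buffers are row-major,
  // so the row-major product C = A * B is computed as C^T = B^T * A^T by
  // swapping the operand roles below; trans_o flips which input ends up as
  // the left-hand operand.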
  int m = trans_o ? (trans_a ? a4 : a3) : (trans_b ? b3 : b4);
  int n = trans_o ? (trans_b ? b3 : b4) : (trans_a ? a4 : a3);
  int k = trans_a ? a3 : a4;

  cublasOperation_t trans_op_l = trans_o
                                     ? (trans_a ? CUBLAS_OP_N : CUBLAS_OP_T)
                                     : (trans_b ? CUBLAS_OP_T : CUBLAS_OP_N);
  cublasOperation_t trans_op_r = trans_o
                                     ? (trans_b ? CUBLAS_OP_N : CUBLAS_OP_T)
                                     : (trans_a ? CUBLAS_OP_T : CUBLAS_OP_N);
  int ldl = trans_op_l == CUBLAS_OP_N
                ? m
                : k;  // trans_o ? (trans_a ? k : m) : (trans_b ? k : m);
  int ldr = trans_op_r == CUBLAS_OP_N
                ? k
                : n;  // trans_o ? (trans_b ? n : k) : (trans_a ? n : k);
  int ldc = m;

  void *lhs = trans_o ? A : B;
  void *rhs = trans_o ? B : A;

  cudaDataType_t cuda_dtype;
  auto type_code = args[0].operator cinn_buffer_t *()->type.code;
  bool is_float = type_code == cinn_type_float;
  bool is_bfloat16 = type_code == cinn_type_bfloat;
  int bytes = args[0].operator cinn_buffer_t *()->type.bits / CHAR_BIT;
  if (is_float && bytes == sizeof(common::float16)) {
    cuda_dtype = CUDA_R_16F;
  } else if (is_float && bytes == sizeof(float)) {
    cuda_dtype = CUDA_R_32F;
  } else if (is_float && bytes == sizeof(double)) {
    cuda_dtype = CUDA_R_64F;
  } else if (is_bfloat16) {
    cuda_dtype = CUDA_R_16BF;
  } else {
    LOG(FATAL) << "unsupported cublas data type: "
               << static_cast<int>(type_code) << ", bytes = " << bytes;
  }

  if (a1 * a2 * b1 * b2 == 1) {
    VLOG(3) << "call cublasGemm for a1 * a2 * b1 * b2 == 1";
    cinn::utils::RecordEvent record_run("Call cublasGemm",
                                        cinn::utils::EventType::kInstruction);
    CUBLAS_CALL(cublasGemm(cuda_dtype,
                           cuhandle,
                           trans_op_l,
                           trans_op_r,
                           m,
                           n,
                           k,
                           alpha,
                           lhs,
                           ldl,
                           rhs,
                           ldr,
                           beta,
                           C,
                           ldc));
  } else if (a1 * b1 == 1) {
    CHECK(a2 == b2 || a2 == 1 || b2 == 1);
    if (b2 == 1 && trans_op_r == CUBLAS_OP_N) {
      // In case of [1, bs, M, K] * [1, 1, K, N]
      VLOG(3) << "call cublasGemm for a1 * b1 = 1, b2 = 1, trans_op_r:"
              << trans_op_r;
      cinn::utils::RecordEvent record_run("Call cublasGemm",
                                          cinn::utils::EventType::kInstruction);
      CUBLAS_CALL(cublasGemm(cuda_dtype,
                             cuhandle,
                             trans_op_l,
                             trans_op_r,
                             m,
                             a2 * n,
                             k,
                             alpha,
                             lhs,
                             ldl,
                             A,
                             ldr,
                             beta,
                             C,
                             ldc));
    } else {
      int stride_l = trans_o ? (a2 > 1 ? a3 * a4 : 0) : (b2 > 1 ? b3 * b4 : 0);
      int stride_r = trans_o ? (b2 > 1 ? b3 * b4 : 0) : (a2 > 1 ? a3 * a4 : 0);
      int batch = std::max(a2, b2);
      VLOG(3) << "call cublasGemmStridedBatched with a1*b1 = 1, stride_l = "
              << stride_l << ", stride_r = " << stride_r
              << ", batch = " << batch;
      cinn::utils::RecordEvent record_run("Call cublasGemmStridedBatched",
                                          cinn::utils::EventType::kInstruction);
      CUBLAS_CALL(cublasGemmStridedBatched(cuda_dtype,
                                           cuhandle,
                                           trans_op_l,
                                           trans_op_r,
                                           m,
                                           n,
                                           k,
                                           alpha,
                                           lhs,
                                           ldl,
                                           stride_l,
                                           rhs,
                                           ldr,
                                           stride_r,
                                           beta,
                                           C,
                                           ldc,
                                           m * n,
                                           batch));
    }
  } else {
    int l1 = trans_o ? a1 : b1, l2 = trans_o ? a2 : b2, l3 = trans_o ? a3 : b3,
        l4 = trans_o ? a4 : b4;
    int r1 = trans_o ? b1 : a1, r2 = trans_o ? b2 : a2, r3 = trans_o ? b3 : a3,
        r4 = trans_o ? b4 : a4;

    if ((l1 == r1 && l2 == r2) || (l1 == 1 && l2 == 1) ||
        (r1 == 1 && r2 == 1)) {
      int stride_l = (l1 == 1 && l2 == 1) ? 0 : l3 * l4;
      int stride_r = (r1 == 1 && r2 == 1) ? 0 : r3 * r4;

      // four matmul types:
      // (N, L) * (N, L) , (N, 1) * (N, 1)
      // (N, L) * (1, 1) , (1, 1) * (N, L)
      VLOG(3) << "call cublasGemmStridedBatched for stride_l = " << stride_l
              << ", stride_r = " << stride_r
              << ", batch = " << std::max(l1, r1) * std::max(l2, r2);
      cinn::utils::RecordEvent record_run("Call cublasGemmStridedBatched",
                                          cinn::utils::EventType::kInstruction);
      CUBLAS_CALL(
          cublasGemmStridedBatched(cuda_dtype,
                                   cuhandle,
                                   trans_op_l,
                                   trans_op_r,
                                   m,
                                   n,
                                   k,
                                   alpha,
                                   lhs,
                                   ldl,
                                   stride_l,
                                   rhs,
                                   ldr,
                                   stride_r,
                                   beta,
                                   C,
                                   ldc,
                                   m * n,
                                   std::max(l1, r1) * std::max(l2, r2)));
    } else {
      cinn::utils::RecordEvent record_run("Call cublasGemmBatched",
                                          cinn::utils::EventType::kInstruction);
      // (N, L) / (N, 1) / (1, L)
      int bstride_l =
          (l1 != 1 && l2 != 1) ? (l2 * m * k) : ((l1 != 1) ? m * k : 0);
      // (N, L) / (N, 1) / (1, L)
      int bstride_r =
          (r1 != 1 && r2 != 1) ? (r2 * k * n) : ((r1 != 1) ? k * n : 0);
      int bstride_c = std::max(l2, r2) * m * n;

      int stride_l = l2 == 1 ? 0 : l3 * l4;
      int stride_r = r2 == 1 ? 0 : r3 * r4;
      // six matmul types:
      // (N, L) * (N, 1) , (N, L) * (1, L)
      // (N, 1) * (N, L) , (1, L) * (N, L)
      // (N, 1) * (1, L) , (1, L) * (N, 1)

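      // cublasGemmBatched consumes device-side arrays of A/B/C pointers:
      // build the three pointer tables on the host, then copy them into a
      // single device allocation.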
      void **ptr_arr = nullptr;
      cudaStream_t g_stream = CublasHandle::GetInstance().GetCuStream();
      CUDA_CALL(cudaMallocAsync(
          &ptr_arr,
          sizeof(void *) * 3 * std::max(l1, r1) * std::max(l2, r2),
          g_stream));

      std::vector<void *> ptr(3 * std::max(l1, r1) * std::max(l2, r2));
      void **ptr_a = ptr.data();
      void **ptr_b = ptr.data() + std::max(l1, r1) * std::max(l2, r2);
      void **ptr_c = ptr.data() + std::max(l1, r1) * std::max(l2, r2) * 2;

      for (int idx = 0, index = 0; idx < std::max(l1, r1); ++idx) {
        for (int idy = 0; idy < std::max(l2, r2); ++idy) {
          ptr_a[index] = reinterpret_cast<uint8_t *>(lhs) +
                         (idx * bstride_l + idy * stride_l) * bytes;
          ptr_b[index] = reinterpret_cast<uint8_t *>(rhs) +
                         (idx * bstride_r + idy * stride_r) * bytes;
          ptr_c[index] = reinterpret_cast<uint8_t *>(C) +
                         (idx * bstride_c + idy * m * n) * bytes;
          ++index;
        }
      }
      CUDA_CALL(cudaMemcpyAsync(ptr_arr,
                                ptr.data(),
                                ptr.size() * sizeof(void *),
                                cudaMemcpyHostToDevice,
                                g_stream));
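      // Make sure the pointer tables have landed on the device before cuBLAS,
      // which may run on a different stream, dereferences them.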
      CUDA_CALL(cudaStreamSynchronize(g_stream));

      CUBLAS_CALL(
          cublasGemmBatched(cuda_dtype,
                            cuhandle,
                            trans_op_l,
                            trans_op_r,
                            m,
                            n,
                            k,
                            alpha,
                            ptr_arr,
                            ldl,
                            ptr_arr + std::max(l1, r1) * std::max(l2, r2),
                            ldr,
                            beta,
                            ptr_arr + std::max(l1, r1) * std::max(l2, r2) * 2,
                            ldc,
                            std::max(l1, r1) * std::max(l2, r2)));
      CUDA_CALL(cudaFreeAsync(ptr_arr, custream));
    }
  }
}

void cinn_call_batched_cublas(void *v_args,
                              int num_args,
                              int opside,
                              bool trans_a,
                              bool trans_b,
                              bool trans_o,
                              float alpha,
                              float beta,
                              int a1,
                              int a2,
                              int a3,
                              int a4,
                              int b1,
                              int b2,
                              int b3,
                              int b4,
                              void *stream) {
  // A * [B, C, D, ...] or [B, C, D, ...] * A
  CHECK_EQ((num_args - 1) % 2, 0);
  cublasHandle_t &cuhandle = CublasHandle::GetInstance().GetCublasHandle();
  cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
  cudaStream_t custream = static_cast<cudaStream_t>(stream);
  CUBLAS_CALL(cublasSetStream(cuhandle, custream));

  cudaDataType_t cuda_dtype;
  auto type_code = args[0].operator cinn_buffer_t *()->type.code;
  bool is_float = type_code == cinn_type_float;
  bool is_bfloat16 = type_code == cinn_type_bfloat;
  int bytes = args[0].operator cinn_buffer_t *()->type.bits / CHAR_BIT;
  if (is_float && bytes == sizeof(common::float16)) {
    cuda_dtype = CUDA_R_16F;
  } else if (is_float && bytes == sizeof(float)) {
    cuda_dtype = CUDA_R_32F;
  } else if (is_float && bytes == sizeof(double)) {
    cuda_dtype = CUDA_R_64F;
  } else if (is_bfloat16) {
    cuda_dtype = CUDA_R_16BF;
  } else {
    LOG(FATAL) << "unsupported cublas data type: "
               << static_cast<int>(type_code) << ", bytes = " << bytes;
  }

  int m = trans_o ? (trans_a ? a4 : a3) : (trans_b ? b3 : b4);
  int n = trans_o ? (trans_b ? b3 : b4) : (trans_a ? a4 : a3);
  int k = trans_a ? a3 : a4;

  cublasOperation_t trans_op_l = trans_o
                                     ? (trans_a ? CUBLAS_OP_N : CUBLAS_OP_T)
                                     : (trans_b ? CUBLAS_OP_T : CUBLAS_OP_N);
  cublasOperation_t trans_op_r = trans_o
                                     ? (trans_b ? CUBLAS_OP_N : CUBLAS_OP_T)
                                     : (trans_a ? CUBLAS_OP_T : CUBLAS_OP_N);
  int ldl = trans_op_l == CUBLAS_OP_N
                ? m
                : k;  // trans_o ? (trans_a ? k : m) : (trans_b ? k : m);
  int ldr = trans_op_r == CUBLAS_OP_N
                ? k
                : n;  // trans_o ? (trans_b ? n : k) : (trans_a ? n : k);
  int ldc = m;

  int l1 = trans_o ? a1 : b1, l2 = trans_o ? a2 : b2, l3 = trans_o ? a3 : b3,
      l4 = trans_o ? a4 : b4;
  int r1 = trans_o ? b1 : a1, r2 = trans_o ? b2 : a2, r3 = trans_o ? b3 : a3,
      r4 = trans_o ? b4 : a4;

  // (N, L): L * M * K
  // (N, 1): 1 * M * K
  // (1, L): 0
  // (1, 1): 0
  int bstride_l = (l1 != 1 && l2 != 1) ? (l2 * m * k) : ((l1 != 1) ? m * k : 0);
  int bstride_r = (r1 != 1 && r2 != 1) ? (r2 * k * n) : ((r1 != 1) ? k * n : 0);
  int bstride_c = std::max(l2, r2) * m * n;
  // (N, L): K * N
  // (N, 1): 0
  // (1, L): K * N
  // (1, 1): 0
  int stride_l = l2 == 1 ? 0 : l3 * l4;
  int stride_r = r2 == 1 ? 0 : r3 * r4;

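  // One (A, B, C) pointer triple per sub-matmul of every GEMM in the batch;
  // all triples live in one host vector mirrored by a single device array.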
  int num_gemm = ((num_args - 1) / 2);
  std::vector<void *> ptr(3 * std::max(l1, r1) * std::max(l2, r2) * num_gemm);
  void **ptr_a = ptr.data();
  void **ptr_b = ptr.data() + std::max(l1, r1) * std::max(l2, r2) * num_gemm;
  void **ptr_c =
      ptr.data() + std::max(l1, r1) * std::max(l2, r2) * num_gemm * 2;

  void **ptr_arr = nullptr;
  cudaStream_t g_stream = CublasHandle::GetInstance().GetCuStream();
  CUDA_CALL(cudaMallocAsync(&ptr_arr, sizeof(void *) * ptr.size(), g_stream));

  for (int g = 0, index = 0; g < num_gemm; ++g) {
    void *A = args[0].operator cinn_buffer_t *()->memory;
    void *B = args[1 + g].operator cinn_buffer_t *()->memory;
    void *C = args[1 + num_gemm + g].operator cinn_buffer_t *()->memory;

    // If opside is 1, exchange A and B.
    if (opside) {
      auto tmp = A;
      A = B;
      B = tmp;
    }

    void *lhs = trans_o ? A : B;
    void *rhs = trans_o ? B : A;

    for (int idx = 0; idx < std::max(l1, r1); ++idx) {
      for (int idy = 0; idy < std::max(l2, r2); ++idy) {
        ptr_a[index] = reinterpret_cast<uint8_t *>(lhs) +
                       (idx * bstride_l + idy * stride_l) * bytes;
        ptr_b[index] = reinterpret_cast<uint8_t *>(rhs) +
                       (idx * bstride_r + idy * stride_r) * bytes;
        ptr_c[index] = reinterpret_cast<uint8_t *>(C) +
                       (idx * bstride_c + idy * m * n) * bytes;
        ++index;
      }
    }
  }

  CUDA_CALL(cudaMemcpyAsync(ptr_arr,
                            ptr.data(),
                            ptr.size() * sizeof(void *),
                            cudaMemcpyHostToDevice,
                            g_stream));
  CUDA_CALL(cudaStreamSynchronize(g_stream));

  CUBLAS_CALL(cublasGemmBatched(
      cuda_dtype,
      cuhandle,
      trans_op_l,
      trans_op_r,
      m,
      n,
      k,
      alpha,
      ptr_arr,
      ldl,
      ptr_arr + std::max(l1, r1) * std::max(l2, r2) * num_gemm,
      ldr,
      beta,
      ptr_arr + std::max(l1, r1) * std::max(l2, r2) * 2 * num_gemm,
      ldc,
      std::max(l1, r1) * std::max(l2, r2) * num_gemm));
  CUDA_CALL(cudaFreeAsync(ptr_arr, custream));
}

void cinn_call_cuda_memset(
    void *v_args, int num_args, int value, size_t count, void *stream) {
  CHECK_EQ(num_args, 1) << "The cinn_call_cuda_memset only accepts one output";
  VLOG(4) << "call cinn_call_cuda_memset with value=" << value
          << ", count=" << count;

  cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
  void *output = args[0].operator cinn_buffer_t *()->memory;

  cudaStream_t custream = static_cast<cudaStream_t>(stream);

  CUDA_CALL(cudaMemsetAsync(output, value, count, custream));
}

void cinn_call_cuda_memcpy(void *v_args,
                           int num_args,
                           size_t count,
                           void *stream) {
  CHECK_EQ(num_args, 2)
      << "The cinn_call_cuda_memcpy only accepts one input and one output";
  VLOG(4) << "call cinn_call_cuda_memcpy with count=" << count;

  cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
  void *input = args[0].operator cinn_buffer_t *()->memory;
  void *output = args[1].operator cinn_buffer_t *()->memory;

  cudaStream_t custream = static_cast<cudaStream_t>(stream);

  CUDA_CALL(cudaMemcpyAsync(
      output, input, count, cudaMemcpyDeviceToDevice, custream));
}

#ifdef CINN_WITH_CUDNN
class CudnnHandle {
 public:
  CudnnHandle(const CudnnHandle &) = delete;
  CudnnHandle &operator=(const CudnnHandle &) = delete;
  ~CudnnHandle() {
    CUDNN_CALL(cudnnDestroy(cuhandle_));
    if (workspace_) {
      CUDA_CALL(cudaFree(workspace_));
    }
  }
  static CudnnHandle &GetInstance() {
    static CudnnHandle instance;
    return instance;
  }
  cudnnHandle_t &GetCudnnHandle() { return cuhandle_; }
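  // Grow-only workspace cache shared by all cuDNN calls on this handle; the
  // buffer is reallocated only when a larger size is requested.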
  void *GetWorkSpace(size_t size) {
    if (size_ >= size) {
      return workspace_;
    } else {
      if (workspace_) {
        CUDA_CALL(cudaFree(workspace_));
      }
      size_ = size;
      CUDA_CALL(cudaMalloc(&workspace_, size_));
      return workspace_;
    }
  }

 private:
  CudnnHandle() : workspace_(nullptr), size_(0) {
    CUDNN_CALL(cudnnCreate(&cuhandle_));
  }
  cudnnHandle_t cuhandle_;
  void *workspace_;
  size_t size_;
};

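// Process-wide cache from a convolution configuration string to the algo
// chosen by cudnnFind*Algorithm, so the costly search runs once per shape.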
class ConvAlgoMap {
 public:
  ConvAlgoMap(const ConvAlgoMap &) = delete;
  ConvAlgoMap &operator=(const ConvAlgoMap &) = delete;
  static ConvAlgoMap &GetInstance() {
    static ConvAlgoMap instance;
    return instance;
  }
  void InsertAlgo(const std::string &key, const int algo) {
    algo_map_[key] = algo;
  }
  int GetAlgo(const std::string &key) {
    return algo_map_.count(key) ? algo_map_[key] : -1;
  }

 private:
  ConvAlgoMap() {}
  absl::flat_hash_map<std::string, int> algo_map_;
};

cudnnDataType_t convert_to_cudnn_dtype(void *v_args, int num_args) {
  CHECK_GT(num_args, 0) << "the number of arguments must be larger than zero";
  cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
  auto type_code = args[0].operator cinn_buffer_t *()->type.code;
  int bits = args[0].operator cinn_buffer_t *()->type.bits;
  for (int i = 1; i < num_args; ++i) {
    auto t = args[i].operator cinn_buffer_t *()->type.code;
    int b = args[i].operator cinn_buffer_t *()->type.bits;
    if (t != type_code || bits != b) {
      LOG(FATAL) << "The types of all arguments need to be consistent.";
    }
  }
  cudnnDataType_t data_type;
  bool is_float = type_code == cinn_type_float;
  bool is_bfloat16 = type_code == cinn_type_bfloat;
  if (is_float && bits == 16) {
    data_type = CUDNN_DATA_HALF;
  } else if (is_float && bits == 32) {
    data_type = CUDNN_DATA_FLOAT;
  } else if (is_bfloat16) {
    data_type = CUDNN_DATA_BFLOAT16;
  } else if (is_float && bits == 64) {
    data_type = CUDNN_DATA_DOUBLE;
  } else {
    LOG(FATAL) << "unsupported cudnn data type: " << static_cast<int>(type_code)
               << ", bits = " << bits;
  }
  return data_type;
}

cudnnDataType_t get_cudnn_compute_dtype(cudnnDataType_t data_type) {
  switch (data_type) {
    case CUDNN_DATA_FLOAT:
    case CUDNN_DATA_HALF:
    case CUDNN_DATA_BFLOAT16:
      return CUDNN_DATA_FLOAT;
    case CUDNN_DATA_DOUBLE:
      return CUDNN_DATA_DOUBLE;
    default:
      LOG(FATAL) << "unsupported cudnn data type, only support "
                    "float16/bfloat16/float32/float64 now!";
  }
  return CUDNN_DATA_FLOAT;
}

std::string debug_cudnn_tensor_format(cudnnTensorFormat_t tensor_format) {
  switch (tensor_format) {
    case CUDNN_TENSOR_NCHW:
      return "NCHW";
    case CUDNN_TENSOR_NHWC:
      return "NHWC";
    default:
      LOG(FATAL) << "Only support NCHW and NHWC data layouts\n";
  };
  return "";
}

std::string debug_cudnn_tensor_dtype(cudnnDataType_t tensor_dtype) {
  switch (tensor_dtype) {
    case CUDNN_DATA_FLOAT:
      return "float32";
    case CUDNN_DATA_HALF:
      return "float16";
    case CUDNN_DATA_BFLOAT16:
      return "bfloat16";
    case CUDNN_DATA_DOUBLE:
      return "float64";
    default:
      LOG(FATAL) << "Only support float16/bfloat16/float32/float64 now!";
  };
  return "";
}

std::string debug_cudnn_pool_mode(cudnnPoolingMode_t pool_mode) {
  switch (pool_mode) {
    case CUDNN_POOLING_MAX:
      return "max";
    case CUDNN_POOLING_MAX_DETERMINISTIC:
      return "max_deterministic";
    case CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING:
      return "avg_include_padding";
    case CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING:
      return "avg_exclude_padding";
    default:
      LOG(FATAL) << "Pool only supports max and avg now!";
  };
  return "";
}

void cinn_call_cudnn_conv2d_forward(void *v_args,
                                    int num_args,
                                    int format,
                                    float alpha,
                                    float beta,
                                    int input_n,
                                    int input_c,
                                    int input_h,
                                    int input_w,
                                    int filter_n,
                                    int filter_c,
                                    int filter_h,
                                    int filter_w,
                                    int pad_h,
                                    int pad_w,
                                    int stride_h,
                                    int stride_w,
                                    int dilation_h,
                                    int dilation_w,
                                    int groups,
                                    int output_n,
                                    int output_c,
                                    int output_h,
                                    int output_w,
                                    void *stream) {
  CHECK_EQ(num_args, 3);
  cudnnHandle_t &handle = CudnnHandle::GetInstance().GetCudnnHandle();
  CUDNN_CALL(cudnnSetStream(handle, static_cast<cudaStream_t>(stream)));
  cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
  void *_x = args[0].operator cinn_buffer_t *()->memory;
  void *_w = args[1].operator cinn_buffer_t *()->memory;
  void *_y = args[2].operator cinn_buffer_t *()->memory;

  cudnnTensorFormat_t tensor_format = static_cast<cudnnTensorFormat_t>(format);
  cudnnDataType_t data_type = convert_to_cudnn_dtype(v_args, num_args);

  cudnnTensorDescriptor_t x_desc;
  CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(
      x_desc, tensor_format, data_type, input_n, input_c, input_h, input_w));

  cudnnFilterDescriptor_t w_desc;
  CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc));
  CUDNN_CALL(cudnnSetFilter4dDescriptor(w_desc,
                                        data_type,
                                        tensor_format,
                                        filter_n,
                                        filter_c,
                                        filter_h,
                                        filter_w));

  cudnnConvolutionDescriptor_t conv_desc;
  CUDNN_CALL(cudnnCreateConvolutionDescriptor(&conv_desc));
  CUDNN_CALL(
      cudnnSetConvolution2dDescriptor(conv_desc,
                                      pad_h,
                                      pad_w,
                                      stride_h,
                                      stride_w,
                                      dilation_h,
                                      dilation_w,
                                      CUDNN_CROSS_CORRELATION,
                                      get_cudnn_compute_dtype(data_type)));
  CUDNN_CALL(cudnnSetConvolutionGroupCount(conv_desc, groups));
  CUDNN_CALL(cudnnSetConvolutionMathType(conv_desc, CUDNN_DEFAULT_MATH));

  cudnnTensorDescriptor_t y_desc;
  CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc,
                                        tensor_format,
                                        data_type,
                                        output_n,
                                        output_c,
                                        output_h,
                                        output_w));

  auto &conv_algo_map = ConvAlgoMap::GetInstance();
  std::string hash_key =
      "conv2d forward, layout=" + debug_cudnn_tensor_format(tensor_format) +
      ", dtype=" + debug_cudnn_tensor_dtype(data_type) + ", input_nchw={" +
      std::to_string(input_n) + "," + std::to_string(input_c) + "," +
      std::to_string(input_h) + "," + std::to_string(input_w) +
      "}, filter_nchw={" + std::to_string(filter_n) + "," +
      std::to_string(filter_c) + "," + std::to_string(filter_h) + "," +
      std::to_string(filter_w) + "}, output_nchw={" + std::to_string(output_n) +
      "," + std::to_string(output_c) + "," + std::to_string(output_h) + "," +
      std::to_string(output_w) + "}";
  VLOG(4) << hash_key;
  cudnnConvolutionFwdAlgo_t algo;
  int algo_int = conv_algo_map.GetAlgo(hash_key);
  if (algo_int >= 0) {
    algo = cudnnConvolutionFwdAlgo_t(algo_int);
  } else {
    int count = 0;
    cudnnConvolutionFwdAlgoPerf_t algo_perf;
    CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(
        handle, x_desc, w_desc, conv_desc, y_desc, 1, &count, &algo_perf));

    algo = algo_perf.algo;
    conv_algo_map.InsertAlgo(hash_key, static_cast<int>(algo_perf.algo));
  }

  if (GetCinnCudnnDeterministic()) {
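    // 1 == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; pin the algo so
    // runs are reproducible.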
    algo = static_cast<cudnnConvolutionFwdAlgo_t>(1);
  }

  size_t workspace_size = 0;
  CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(
      handle, x_desc, w_desc, conv_desc, y_desc, algo, &workspace_size));

  void *workspace_data =
      CudnnHandle::GetInstance().GetWorkSpace(workspace_size);
  if (data_type == CUDNN_DATA_DOUBLE) {
    const double alpha_fp64 = static_cast<double>(alpha);
    const double beta_fp64 = static_cast<double>(beta);
    CUDNN_CALL(cudnnConvolutionForward(handle,
                                       &alpha_fp64,
                                       x_desc,
                                       _x,
                                       w_desc,
                                       _w,
                                       conv_desc,
                                       algo,
                                       workspace_data,
                                       workspace_size,
                                       &beta_fp64,
                                       y_desc,
                                       _y));
  } else {
    CUDNN_CALL(cudnnConvolutionForward(handle,
                                       &alpha,
                                       x_desc,
                                       _x,
                                       w_desc,
                                       _w,
                                       conv_desc,
                                       algo,
                                       workspace_data,
                                       workspace_size,
                                       &beta,
                                       y_desc,
                                       _y));
  }

  CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc));
  CUDNN_CALL(cudnnDestroyFilterDescriptor(w_desc));
  CUDNN_CALL(cudnnDestroyConvolutionDescriptor(conv_desc));
  CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc));
}

void cinn_call_cudnn_conv2d_backward_data(void *v_args,
                                          int num_args,
                                          int format,
                                          float alpha,
                                          float beta,
                                          int input_n,
                                          int input_c,
                                          int input_h,
                                          int input_w,
                                          int filter_n,
                                          int filter_c,
                                          int filter_h,
                                          int filter_w,
                                          int pad_h,
                                          int pad_w,
                                          int stride_h,
                                          int stride_w,
                                          int dilation_h,
                                          int dilation_w,
                                          int groups,
                                          int output_n,
                                          int output_c,
                                          int output_h,
                                          int output_w,
                                          void *stream) {
  CHECK_EQ(num_args, 3);
  cudnnHandle_t &handle = CudnnHandle::GetInstance().GetCudnnHandle();
  CUDNN_CALL(cudnnSetStream(handle, static_cast<cudaStream_t>(stream)));
  cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
  void *_w = args[0].operator cinn_buffer_t *()->memory;
  void *_dy = args[1].operator cinn_buffer_t *()->memory;
  void *_dx = args[2].operator cinn_buffer_t *()->memory;

  cudnnTensorFormat_t tensor_format = static_cast<cudnnTensorFormat_t>(format);
  cudnnDataType_t data_type = convert_to_cudnn_dtype(v_args, num_args);

  cudnnTensorDescriptor_t x_desc;
  CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(
      x_desc, tensor_format, data_type, input_n, input_c, input_h, input_w));

  cudnnFilterDescriptor_t w_desc;
  CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc));
  CUDNN_CALL(cudnnSetFilter4dDescriptor(w_desc,
                                        data_type,
                                        tensor_format,
                                        filter_n,
                                        filter_c,
                                        filter_h,
                                        filter_w));

  cudnnConvolutionDescriptor_t conv_desc;
  CUDNN_CALL(cudnnCreateConvolutionDescriptor(&conv_desc));
  CUDNN_CALL(
      cudnnSetConvolution2dDescriptor(conv_desc,
                                      pad_h,
                                      pad_w,
                                      stride_h,
                                      stride_w,
                                      dilation_h,
                                      dilation_w,
                                      CUDNN_CROSS_CORRELATION,
                                      get_cudnn_compute_dtype(data_type)));
  CUDNN_CALL(cudnnSetConvolutionGroupCount(conv_desc, groups));
  CUDNN_CALL(cudnnSetConvolutionMathType(conv_desc, CUDNN_DEFAULT_MATH));

  cudnnTensorDescriptor_t y_desc;
  CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc,
                                        tensor_format,
                                        data_type,
                                        output_n,
                                        output_c,
                                        output_h,
                                        output_w));

  auto &conv_algo_map = ConvAlgoMap::GetInstance();
  std::string hash_key =
      "conv2d backward data, layout=" +
      debug_cudnn_tensor_format(tensor_format) +
      ", dtype=" + debug_cudnn_tensor_dtype(data_type) + ", input_nchw={" +
      std::to_string(input_n) + "," + std::to_string(input_c) + "," +
      std::to_string(input_h) + "," + std::to_string(input_w) +
      "}, filter_nchw={" + std::to_string(filter_n) + "," +
      std::to_string(filter_c) + "," + std::to_string(filter_h) + "," +
      std::to_string(filter_w) + "}, output_nchw={" + std::to_string(output_n) +
      "," + std::to_string(output_c) + "," + std::to_string(output_h) + "," +
      std::to_string(output_w) + "}";

  VLOG(4) << hash_key;

  int algo_int = conv_algo_map.GetAlgo(hash_key);
  cudnnConvolutionBwdDataAlgo_t algo;
  if (algo_int >= 0) {
    algo = cudnnConvolutionBwdDataAlgo_t(algo_int);
  } else {
    int count = 0;
    cudnnConvolutionBwdDataAlgoPerf_t algo_perf;
    CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(
        handle, w_desc, y_desc, conv_desc, x_desc, 1, &count, &algo_perf));

    algo = algo_perf.algo;
    conv_algo_map.InsertAlgo(hash_key, static_cast<int>(algo_perf.algo));
  }

  if (GetCinnCudnnDeterministic()) {
    algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
  }

  size_t workspace_size = 0;
  CUDNN_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize(
      handle, w_desc, y_desc, conv_desc, x_desc, algo, &workspace_size));

  void *workspace_data =
      CudnnHandle::GetInstance().GetWorkSpace(workspace_size);
  if (data_type == CUDNN_DATA_DOUBLE) {
    const double alpha_fp64 = static_cast<double>(alpha);
    const double beta_fp64 = static_cast<double>(beta);
    CUDNN_CALL(cudnnConvolutionBackwardData(handle,
                                            &alpha_fp64,
                                            w_desc,
                                            _w,
                                            y_desc,
                                            _dy,
                                            conv_desc,
                                            algo,
                                            workspace_data,
                                            workspace_size,
                                            &beta_fp64,
                                            x_desc,
                                            _dx));
  } else {
    CUDNN_CALL(cudnnConvolutionBackwardData(handle,
                                            &alpha,
                                            w_desc,
                                            _w,
                                            y_desc,
                                            _dy,
                                            conv_desc,
                                            algo,
                                            workspace_data,
                                            workspace_size,
                                            &beta,
                                            x_desc,
                                            _dx));
  }

  CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc));
  CUDNN_CALL(cudnnDestroyFilterDescriptor(w_desc));
  CUDNN_CALL(cudnnDestroyConvolutionDescriptor(conv_desc));
  CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc));
}

void cinn_call_cudnn_conv2d_backward_filter(void *v_args,
                                            int num_args,
                                            int format,
                                            float alpha,
                                            float beta,
                                            int input_n,
                                            int input_c,
                                            int input_h,
                                            int input_w,
                                            int filter_n,
                                            int filter_c,
                                            int filter_h,
                                            int filter_w,
                                            int pad_h,
                                            int pad_w,
                                            int stride_h,
                                            int stride_w,
                                            int dilation_h,
                                            int dilation_w,
                                            int groups,
                                            int output_n,
                                            int output_c,
                                            int output_h,
                                            int output_w,
                                            void *stream) {
  CHECK_EQ(num_args, 3);
  cudnnHandle_t &handle = CudnnHandle::GetInstance().GetCudnnHandle();
  CUDNN_CALL(cudnnSetStream(handle, static_cast<cudaStream_t>(stream)));
  cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);

  void *_x = args[0].operator cinn_buffer_t *()->memory;
  void *_dy = args[1].operator cinn_buffer_t *()->memory;
  void *_dw = args[2].operator cinn_buffer_t *()->memory;

  cudnnTensorFormat_t tensor_format = static_cast<cudnnTensorFormat_t>(format);
  cudnnDataType_t data_type = convert_to_cudnn_dtype(v_args, num_args);

  cudnnTensorDescriptor_t x_desc;
  CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(
      x_desc, tensor_format, data_type, input_n, input_c, input_h, input_w));

  cudnnFilterDescriptor_t w_desc;
  CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc));
  CUDNN_CALL(cudnnSetFilter4dDescriptor(w_desc,
                                        data_type,
                                        tensor_format,
                                        filter_n,
                                        filter_c,
                                        filter_h,
                                        filter_w));

  cudnnConvolutionDescriptor_t conv_desc;
  CUDNN_CALL(cudnnCreateConvolutionDescriptor(&conv_desc));
  CUDNN_CALL(
      cudnnSetConvolution2dDescriptor(conv_desc,
                                      pad_h,
                                      pad_w,
                                      stride_h,
                                      stride_w,
                                      dilation_h,
                                      dilation_w,
                                      CUDNN_CROSS_CORRELATION,
                                      get_cudnn_compute_dtype(data_type)));
  CUDNN_CALL(cudnnSetConvolutionGroupCount(conv_desc, groups));
  CUDNN_CALL(cudnnSetConvolutionMathType(conv_desc, CUDNN_DEFAULT_MATH));

  cudnnTensorDescriptor_t y_desc;
  CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc,
                                        tensor_format,
                                        data_type,
                                        output_n,
                                        output_c,
                                        output_h,
                                        output_w));

  auto &algo_map = ConvAlgoMap::GetInstance();
  std::string hash_key =
      "conv2d backward filter, layout=" +
      debug_cudnn_tensor_format(tensor_format) +
      ", dtype=" + debug_cudnn_tensor_dtype(data_type) + ", input_nchw={" +
      std::to_string(input_n) + "," + std::to_string(input_c) + "," +
      std::to_string(input_h) + "," + std::to_string(input_w) +
      "}, filter_nchw={" + std::to_string(filter_n) + "," +
      std::to_string(filter_c) + "," + std::to_string(filter_h) + "," +
      std::to_string(filter_w) + "}, output_nchw={" + std::to_string(output_n) +
      "," + std::to_string(output_c) + "," + std::to_string(output_h) + "," +
      std::to_string(output_w) + "}";

  VLOG(4) << hash_key;

  int algo_int = algo_map.GetAlgo(hash_key);
  cudnnConvolutionBwdFilterAlgo_t algo;
  if (algo_int >= 0) {
    algo = cudnnConvolutionBwdFilterAlgo_t(algo_int);
  } else {
    int count = 0;
    cudnnConvolutionBwdFilterAlgoPerf_t algo_perf;
    CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(
        handle, x_desc, y_desc, conv_desc, w_desc, 1, &count, &algo_perf));

    algo = algo_perf.algo;
    algo_map.InsertAlgo(hash_key, static_cast<int>(algo_perf.algo));
  }

  if (GetCinnCudnnDeterministic()) {
    algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
  }

  size_t workspace_size = 0;
  CUDNN_CALL(cudnnGetConvolutionBackwardFilterWorkspaceSize(
      handle, x_desc, y_desc, conv_desc, w_desc, algo, &workspace_size));

  void *workspace_data =
      CudnnHandle::GetInstance().GetWorkSpace(workspace_size);
  if (data_type == CUDNN_DATA_DOUBLE) {
    const double alpha_fp64 = static_cast<double>(alpha);
    const double beta_fp64 = static_cast<double>(beta);
    CUDNN_CALL(cudnnConvolutionBackwardFilter(handle,
                                              &alpha_fp64,
                                              x_desc,
                                              _x,
                                              y_desc,
                                              _dy,
                                              conv_desc,
                                              algo,
                                              workspace_data,
                                              workspace_size,
                                              &beta_fp64,
                                              w_desc,
                                              _dw));
  } else {
    CUDNN_CALL(cudnnConvolutionBackwardFilter(handle,
                                              &alpha,
                                              x_desc,
                                              _x,
                                              y_desc,
                                              _dy,
                                              conv_desc,
                                              algo,
                                              workspace_data,
                                              workspace_size,
                                              &beta,
                                              w_desc,
                                              _dw));
  }

  CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc));
  CUDNN_CALL(cudnnDestroyFilterDescriptor(w_desc));
  CUDNN_CALL(cudnnDestroyConvolutionDescriptor(conv_desc));
  CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc));
}

void cinn_call_cudnn_pool2d_forward(void *v_args,
                                    int num_args,
                                    int mode,
                                    int format,
                                    float alpha,
                                    float beta,
                                    int input_n,
                                    int input_c,
                                    int input_h,
                                    int input_w,
                                    int kernel_h,
                                    int kernel_w,
                                    int pad_h,
                                    int pad_w,
                                    int stride_h,
                                    int stride_w,
                                    int output_n,
                                    int output_c,
                                    int output_h,
                                    int output_w,
                                    void *stream) {
  CHECK_EQ(num_args, 2);
  cudnnHandle_t &handle = CudnnHandle::GetInstance().GetCudnnHandle();
  CUDNN_CALL(cudnnSetStream(handle, static_cast<cudaStream_t>(stream)));
  cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);

  void *_x = args[0].operator cinn_buffer_t *()->memory;
  void *_y = args[1].operator cinn_buffer_t *()->memory;

  cudnnPoolingMode_t pool_mode = static_cast<cudnnPoolingMode_t>(mode);
  cudnnTensorFormat_t tensor_format = static_cast<cudnnTensorFormat_t>(format);
  cudnnDataType_t data_type = convert_to_cudnn_dtype(v_args, num_args);

  if (GetCinnCudnnDeterministic() && pool_mode == CUDNN_POOLING_MAX) {
    pool_mode = CUDNN_POOLING_MAX_DETERMINISTIC;
  }

  std::string hash_key =
      "pool2d forward, layout=" + debug_cudnn_tensor_format(tensor_format) +
      ", pool_type=" + debug_cudnn_pool_mode(pool_mode) +
      ", dtype=" + debug_cudnn_tensor_dtype(data_type) + ", input_nchw={" +
      std::to_string(input_n) + "," + std::to_string(input_c) + "," +
      std::to_string(input_h) + "," + std::to_string(input_w) +
      "}, kernel_hw={" + std::to_string(kernel_h) + "," +
      std::to_string(kernel_w) + "}, pad_hw={" + std::to_string(pad_h) + "," +
      std::to_string(pad_w) + "}, stride_hw={" + std::to_string(stride_h) +
      "," + std::to_string(stride_w) + "}, output_nchw={" +
      std::to_string(output_n) + "," + std::to_string(output_c) + "," +
      std::to_string(output_h) + "," + std::to_string(output_w) + "}";

  VLOG(4) << hash_key;

  cudnnPoolingDescriptor_t pool_desc;
  CUDNN_CALL(cudnnCreatePoolingDescriptor(&pool_desc));
  CUDNN_CALL(cudnnSetPooling2dDescriptor(pool_desc,
                                         pool_mode,
                                         CUDNN_NOT_PROPAGATE_NAN,
                                         kernel_h,
                                         kernel_w,
                                         pad_h,
                                         pad_w,
                                         stride_h,
                                         stride_w));

  cudnnTensorDescriptor_t x_desc;
  CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(
      x_desc, tensor_format, data_type, input_n, input_c, input_h, input_w));

  cudnnTensorDescriptor_t y_desc;
  CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc,
                                        tensor_format,
                                        data_type,
                                        output_n,
                                        output_c,
                                        output_h,
                                        output_w));

  if (data_type == CUDNN_DATA_DOUBLE) {
    const double alpha_fp64 = static_cast<double>(alpha);
    const double beta_fp64 = static_cast<double>(beta);
    CUDNN_CALL(cudnnPoolingForward(
        handle, pool_desc, &alpha_fp64, x_desc, _x, &beta_fp64, y_desc, _y));
  } else {
    CUDNN_CALL(cudnnPoolingForward(
        handle, pool_desc, &alpha, x_desc, _x, &beta, y_desc, _y));
  }

  CUDNN_CALL(cudnnDestroyPoolingDescriptor(pool_desc));
  CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc));
  CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc));
}

void cinn_call_cudnn_pool2d_backward(void *v_args,
                                     int num_args,
                                     int mode,
                                     int format,
                                     float alpha,
                                     float beta,
                                     int input_n,
                                     int input_c,
                                     int input_h,
                                     int input_w,
                                     int kernel_h,
                                     int kernel_w,
                                     int pad_h,
                                     int pad_w,
                                     int stride_h,
                                     int stride_w,
                                     int output_n,
                                     int output_c,
                                     int output_h,
                                     int output_w,
                                     void *stream) {
  CHECK_EQ(num_args, 4);
  cudnnHandle_t &handle = CudnnHandle::GetInstance().GetCudnnHandle();
  CUDNN_CALL(cudnnSetStream(handle, static_cast<cudaStream_t>(stream)));
  cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);

  void *_x = args[0].operator cinn_buffer_t *()->memory;
  void *_y = args[1].operator cinn_buffer_t *()->memory;
  void *_dy = args[2].operator cinn_buffer_t *()->memory;
  void *_dx = args[3].operator cinn_buffer_t *()->memory;

  cudnnPoolingMode_t pool_mode = static_cast<cudnnPoolingMode_t>(mode);
  cudnnTensorFormat_t tensor_format = static_cast<cudnnTensorFormat_t>(format);
  cudnnDataType_t data_type = convert_to_cudnn_dtype(v_args, num_args);

  if (GetCinnCudnnDeterministic() && pool_mode == CUDNN_POOLING_MAX) {
    pool_mode = CUDNN_POOLING_MAX_DETERMINISTIC;
  }

  std::string hash_key =
      "pool2d backward, layout=" + debug_cudnn_tensor_format(tensor_format) +
      ", pool_type=" + debug_cudnn_pool_mode(pool_mode) +
      ", dtype=" + debug_cudnn_tensor_dtype(data_type) + ", input_nchw={" +
      std::to_string(input_n) + "," + std::to_string(input_c) + "," +
      std::to_string(input_h) + "," + std::to_string(input_w) +
      "}, kernel_hw={" + std::to_string(kernel_h) + "," +
      std::to_string(kernel_w) + "}, pad_hw={" + std::to_string(pad_h) + "," +
      std::to_string(pad_w) + "}, stride_hw={" + std::to_string(stride_h) +
      "," + std::to_string(stride_w) + "}, output_nchw={" +
      std::to_string(output_n) + "," + std::to_string(output_c) + "," +
      std::to_string(output_h) + "," + std::to_string(output_w) + "}";

  VLOG(4) << hash_key;

  cudnnPoolingDescriptor_t pool_desc;
  CUDNN_CALL(cudnnCreatePoolingDescriptor(&pool_desc));
  CUDNN_CALL(cudnnSetPooling2dDescriptor(pool_desc,
                                         pool_mode,
                                         CUDNN_NOT_PROPAGATE_NAN,
                                         kernel_h,
                                         kernel_w,
                                         pad_h,
                                         pad_w,
                                         stride_h,
                                         stride_w));

  cudnnTensorDescriptor_t x_desc;
  CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(
      x_desc, tensor_format, data_type, input_n, input_c, input_h, input_w));

  cudnnTensorDescriptor_t y_desc;
  CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc,
                                        tensor_format,
                                        data_type,
                                        output_n,
                                        output_c,
                                        output_h,
                                        output_w));

  if (data_type == CUDNN_DATA_DOUBLE) {
    const double alpha_fp64 = static_cast<double>(alpha);
    const double beta_fp64 = static_cast<double>(beta);
    CUDNN_CALL(cudnnPoolingBackward(handle,
                                    pool_desc,
                                    &alpha_fp64,
                                    y_desc,
                                    _y,
                                    y_desc,
                                    _dy,
                                    x_desc,
                                    _x,
                                    &beta_fp64,
                                    x_desc,
                                    _dx));
  } else {
    CUDNN_CALL(cudnnPoolingBackward(handle,
                                    pool_desc,
                                    &alpha,
                                    y_desc,
                                    _y,
                                    y_desc,
                                    _dy,
                                    x_desc,
                                    _x,
                                    &beta,
                                    x_desc,
                                    _dx));
  }

  CUDNN_CALL(cudnnDestroyPoolingDescriptor(pool_desc));
  CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc));
  CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc));
}

void cinn_call_cudnn_softmax_forward(void *v_args,
                                     int num_args,
                                     int mode,
                                     int format,
                                     float alpha,
                                     float beta,
                                     int input_n,
                                     int input_c,
                                     int input_h,
                                     int input_w,
                                     int output_n,
                                     int output_c,
                                     int output_h,
                                     int output_w,
                                     void *stream) {
  CHECK_EQ(num_args, 2);
  cudnnHandle_t &handle = CudnnHandle::GetInstance().GetCudnnHandle();
  CUDNN_CALL(cudnnSetStream(handle, static_cast<cudaStream_t>(stream)));
  cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);

  void *_x = args[0].operator cinn_buffer_t *()->memory;
  void *_y = args[1].operator cinn_buffer_t *()->memory;

  cudnnSoftmaxMode_t softmax_mode = static_cast<cudnnSoftmaxMode_t>(mode);
  cudnnTensorFormat_t tensor_format = static_cast<cudnnTensorFormat_t>(format);
  cudnnDataType_t data_type = convert_to_cudnn_dtype(v_args, num_args);

  cudnnTensorDescriptor_t x_desc;
  CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(
      x_desc, tensor_format, data_type, input_n, input_c, input_h, input_w));

  cudnnTensorDescriptor_t y_desc;
  CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc,
                                        tensor_format,
                                        data_type,
                                        output_n,
                                        output_c,
                                        output_h,
                                        output_w));
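  // Note: the softmax algorithm below is hard-coded to CUDNN_SOFTMAX_LOG;
  // only the mode (instance vs. channel) comes from the caller.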

  if (data_type == CUDNN_DATA_DOUBLE) {
    const double alpha_fp64 = static_cast<double>(alpha);
    const double beta_fp64 = static_cast<double>(beta);
    CUDNN_CALL(cudnnSoftmaxForward(handle,
                                   CUDNN_SOFTMAX_LOG,
                                   softmax_mode,
                                   &alpha_fp64,
                                   x_desc,
                                   _x,
                                   &beta_fp64,
                                   y_desc,
                                   _y));
  } else {
    CUDNN_CALL(cudnnSoftmaxForward(handle,
                                   CUDNN_SOFTMAX_LOG,
                                   softmax_mode,
                                   &alpha,
                                   x_desc,
                                   _x,
                                   &beta,
                                   y_desc,
                                   _y));
  }

  CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc));
  CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc));
}

void cinn_call_cudnn_softmax_backward(void *v_args,
                                      int num_args,
                                      int mode,
                                      int format,
                                      float alpha,
                                      float beta,
                                      int input_n,
                                      int input_c,
                                      int input_h,
                                      int input_w,
                                      int output_n,
                                      int output_c,
                                      int output_h,
                                      int output_w,
                                      void *stream) {
  CHECK_EQ(num_args, 3);
  cudnnHandle_t &handle = CudnnHandle::GetInstance().GetCudnnHandle();
  CUDNN_CALL(cudnnSetStream(handle, static_cast<cudaStream_t>(stream)));
  cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);

  void *_y = args[0].operator cinn_buffer_t *()->memory;
  void *_dy = args[1].operator cinn_buffer_t *()->memory;
  void *_dx = args[2].operator cinn_buffer_t *()->memory;

  cudnnSoftmaxMode_t softmax_mode = static_cast<cudnnSoftmaxMode_t>(mode);
  cudnnTensorFormat_t tensor_format = static_cast<cudnnTensorFormat_t>(format);
  cudnnDataType_t data_type = convert_to_cudnn_dtype(v_args, num_args);

  cudnnTensorDescriptor_t x_desc;
  CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(
      x_desc, tensor_format, data_type, input_n, input_c, input_h, input_w));

  cudnnTensorDescriptor_t y_desc;
  CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc,
                                        tensor_format,
                                        data_type,
                                        output_n,
                                        output_c,
                                        output_h,
                                        output_w));

  if (data_type == CUDNN_DATA_DOUBLE) {
    const double alpha_fp64 = static_cast<double>(alpha);
    const double beta_fp64 = static_cast<double>(beta);
    CUDNN_CALL(cudnnSoftmaxBackward(handle,
                                    CUDNN_SOFTMAX_LOG,
                                    softmax_mode,
                                    &alpha_fp64,
                                    y_desc,
                                    _y,
                                    y_desc,
                                    _dy,
                                    &beta_fp64,
                                    x_desc,
                                    _dx));
  } else {
    CUDNN_CALL(cudnnSoftmaxBackward(handle,
                                    CUDNN_SOFTMAX_LOG,
                                    softmax_mode,
                                    &alpha,
                                    y_desc,
                                    _y,
                                    y_desc,
                                    _dy,
                                    &beta,
                                    x_desc,
                                    _dx));
  }

  CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc));
  CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc));
}

#endif  // CINN_WITH_CUDNN

/********************to be removed in future***********************/

namespace details {

void Gemm(const cublasHandle_t &cublas,
          bool lhs_trans,
          bool rhs_trans,
          const float alpha,
          const float *lhs_data,
          const std::vector<int> &lhs_shape,
          const float *rhs_data,
          const std::vector<int> &rhs_shape,
          const float *bias_data,
          const float beta,
          float *output_data,
          const std::vector<int> &output_shape,
          cudaStream_t stream) {
  int lhs_row = lhs_shape[0];
  int lhs_col = lhs_shape[1];
  int rhs_row = rhs_shape[0];
  int rhs_col = rhs_shape[1];
  int output_row = output_shape[0];
  int output_col = output_shape[1];

  // Copy the bias values into output_data first; the caller passes beta = 1
  // in that case, so the GEMM accumulates onto the bias.
  if (bias_data != nullptr) {
    CUDA_CALL(cudaMemcpyAsync(output_data,
                              bias_data,
                              output_row * output_col * sizeof(float),
                              cudaMemcpyDeviceToDevice,
                              stream));
  }

  int contracting_size = lhs_trans ? lhs_row : lhs_col;
  CHECK_EQ(contracting_size, (rhs_trans ? rhs_col : rhs_row))
      << "The contracting dimension of the lhs matrix must equal that of the "
         "rhs matrix.";
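  // cuBLAS expects column-major data, so the row-major product C = A * B is
  // evaluated as C^T = B^T * A^T: the operands are passed in swapped order
  // and the transpose flags follow them.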
  auto trans_a = rhs_trans ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto trans_b = lhs_trans ? CUBLAS_OP_T : CUBLAS_OP_N;
  CUBLAS_CALL(cublasSgemm(cublas,
                          trans_a,
                          trans_b,
                          output_col,
                          output_row,
                          contracting_size,
                          &alpha,
                          rhs_data,
                          rhs_col,
                          lhs_data,
                          lhs_col,
                          &beta,
                          output_data,
                          output_col));
}

void GemmStridedBatched(const cublasHandle_t &cublas,
                        bool lhs_trans,
                        bool rhs_trans,
                        const float alpha,
                        const float *lhs_data,
                        const std::vector<int> &lhs_shape,
                        const float *rhs_data,
                        const std::vector<int> &rhs_shape,
                        const float *bias_data,
                        const float beta,
                        float *output_data,
                        const std::vector<int> &output_shape,
                        cudaStream_t stream) {
  int lhs_bs = lhs_shape[0];
  int lhs_row = lhs_shape[1];
  int lhs_col = lhs_shape[2];
  int rhs_bs = rhs_shape[0];
  int rhs_row = rhs_shape[1];
  int rhs_col = rhs_shape[2];
  int output_bs = output_shape[0];
  int output_row = output_shape[1];
  int output_col = output_shape[2];
  CHECK_EQ(lhs_bs, rhs_bs);
  CHECK_EQ(lhs_bs, output_bs);

  // Copy the bias values into output_data first; the caller passes beta = 1
  // in that case, so the batched GEMM accumulates onto the bias.
  if (bias_data != nullptr) {
    CUDA_CALL(
        cudaMemcpyAsync(output_data,
                        bias_data,
                        output_bs * output_row * output_col * sizeof(float),
                        cudaMemcpyDeviceToDevice,
                        stream));
  }

  int contracting_size = lhs_trans ? lhs_row : lhs_col;
  CHECK_EQ(contracting_size, (rhs_trans ? rhs_col : rhs_row))
      << "The contracting dimension of the lhs matrix must equal that of the "
         "rhs matrix.";
  auto trans_a = rhs_trans ? CUBLAS_OP_T : CUBLAS_OP_N;
  auto trans_b = lhs_trans ? CUBLAS_OP_T : CUBLAS_OP_N;
  int64_t lhs_stride = lhs_row * lhs_col;
  int64_t rhs_stride = rhs_row * rhs_col;
  int64_t output_stride = output_row * output_col;
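  // Same column-major trick as details::Gemm, applied per batch with a fixed
  // stride between consecutive matrices.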
  CUBLAS_CALL(cublasSgemmStridedBatched(cublas,
                                        trans_a,
                                        trans_b,
                                        output_col,
                                        output_row,
                                        contracting_size,
                                        &alpha,
                                        rhs_data,
                                        rhs_col,
                                        rhs_stride,
                                        lhs_data,
                                        lhs_col,
                                        lhs_stride,
                                        &beta,
                                        output_data,
                                        output_col,
                                        output_stride,
                                        output_bs));
}

}  // namespace details

class CusolverHandle {
 public:
  CusolverHandle(const CusolverHandle &) = delete;
  CusolverHandle &operator=(const CusolverHandle &) = delete;
  ~CusolverHandle() { CUSOLVER_CALL(cusolverDnDestroy(handle_)); }
  static CusolverHandle &GetInstance() {
    static CusolverHandle instance;
    return instance;
  }
  cusolverDnHandle_t &GetHandle() { return handle_; }

 private:
  CusolverHandle() { CUSOLVER_CALL(cusolverDnCreate(&handle_)); }
  cusolverDnHandle_t handle_;
};

void cinn_call_cholesky_nvgpu(void *v_args,
                              int num_args,
                              int batch_size,
                              int m,
                              bool upper,
                              void *stream) {
  cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
  cinn_buffer_t *x = args[0].operator cinn_buffer_t *();
  cinn_buffer_t *out = args[1].operator cinn_buffer_t *();
  // cuSOLVER stores dense matrices in column-major order, so the FILL_MODE
  // must be flipped. See also:
  // https://docs.nvidia.com/cuda/cusolver/index.html#matrix-dense-format
  cublasFillMode_t uplo =
      upper ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
  size_t numel = x->num_elements();
  uint8_t bits = x->type.bits;
  uint8_t bytes = bits / 8;
  CHECK_EQ(x->type.code, cinn_type_code_t::cinn_type_float);
  CHECK(bits == 32 || bits == 64)
      << "Unsupported float bit width " << bits << " for cholesky";

  auto cuda_stream = static_cast<cudaStream_t>(stream);

  // potrfBatched factorizes in place, so copy the input x into out first
  void *x_ptr = reinterpret_cast<void *>(x->memory);
  void *out_ptr = reinterpret_cast<void *>(out->memory);
  CUDA_CALL(cudaMemcpyAsync(
      out_ptr, x_ptr, numel * bytes, cudaMemcpyDeviceToDevice, cuda_stream));
  // Generate pointer array
  thrust::host_vector<void *> host_out_ptr(batch_size, nullptr);
  for (int i = 0; i < batch_size; ++i) {
    host_out_ptr[i] = reinterpret_cast<char *>(out_ptr) + i * m * m * bytes;
  }
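  // The batched cuSOLVER API expects the array of per-matrix pointers itself
  // to reside in device memory.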
  thrust::device_vector<void *> dev_out_ptr(host_out_ptr.begin(),
                                            host_out_ptr.end());
  // Store the factorization status returned for each matrix
  thrust::host_vector<int> host_info(batch_size, 0);
  thrust::device_vector<int> dev_info(host_info.begin(), host_info.end());

  cusolverDnHandle_t handler = CusolverHandle::GetInstance().GetHandle();
  CUSOLVER_CALL(cusolverDnSetStream(handler, cuda_stream));
  if (bits == 32) {
    CUSOLVER_CALL(cusolverDnSpotrfBatched(
        handler,
        uplo,
        m,
        reinterpret_cast<float **>(dev_out_ptr.data().get()),
        m,
        thrust::raw_pointer_cast(dev_info.data()),
        batch_size));
  } else if (bits == 64) {
    CUSOLVER_CALL(cusolverDnDpotrfBatched(
        handler,
        uplo,
        m,
        reinterpret_cast<double **>(dev_out_ptr.data().get()),
        m,
        thrust::raw_pointer_cast(dev_info.data()),
        batch_size));
  }

  // Check result
  thrust::copy(dev_info.begin(), dev_info.end(), host_info.begin());
  for (int i = 0; i < host_info.size(); i++) {
    CHECK_EQ(host_info[i], 0)
        << "Cholesky decomposition failed; please check the " << i + 1
        << "-th input matrix.";
  }
}

void cinn_call_triangular_solve_nvgpu(void *v_args,
                                      int num_args,
                                      int batch_size,
                                      int m,
                                      int k,
                                      bool left_side,
                                      bool upper,
                                      bool transpose_a,
                                      bool unit_diagonal,
                                      void *stream) {
  cublasHandle_t &handle = CublasHandle::GetInstance().GetCublasHandle();
  cudaStream_t custream = static_cast<cudaStream_t>(stream);
  CUBLAS_CALL(cublasSetStream(handle, custream));

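  // cuBLAS assumes column-major storage while CINN buffers are row-major, so
  // cuBLAS sees the transpose of our data; flip the side and fill modes to
  // solve the equivalent transposed system.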
  int b_rows = left_side ? k : m;
  int b_cols = left_side ? m : k;
  int lda = m;
  int ldb = b_rows;
  cublasSideMode_t side = left_side ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT;
  cublasFillMode_t uplo =
      upper ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
  cublasOperation_t transa = transpose_a ? CUBLAS_OP_T : CUBLAS_OP_N;
  cublasDiagType_t diag =
      unit_diagonal ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;

  cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
  cinn_buffer_t *input1 = args[0].operator cinn_buffer_t *();
  cinn_buffer_t *input2 = args[1].operator cinn_buffer_t *();
  cinn_buffer_t *output = args[2].operator cinn_buffer_t *();

  CHECK_EQ(input1->type.code, cinn_type_code_t::cinn_type_float);
  CHECK_EQ(input2->type.code, cinn_type_code_t::cinn_type_float);
  CHECK_EQ(input1->type.bits, input2->type.bits);
  uint8_t bits = input1->type.bits;
  uint8_t bytes = bits / 8;
  CHECK(bits == 32 || bits == 64)
      << "Unsupported float bit width " << bits << " for triangular solve";

  std::string debug_info =
      "triangular solve op: left_side=" + std::to_string(left_side) +
      ", upper=" + std::to_string(upper) +
      ", transpose_a=" + std::to_string(transpose_a) +
      ", unit_diagonal=" + std::to_string(unit_diagonal) +
      ", batch_size=" + std::to_string(batch_size) +
      ", m=" + std::to_string(m) + ", k=" + std::to_string(k) +
      ", input1_dtype={code: " + std::to_string(input1->type.code) +
      ", bits: " + std::to_string(input1->type.bits) + "}" +
      ", input2_dtype={code: " + std::to_string(input2->type.code) +
      ", bits: " + std::to_string(input2->type.bits) + "}";
  VLOG(4) << debug_info;

  void *a_ptr = reinterpret_cast<void *>(input1->memory);
  void *b_ptr = reinterpret_cast<void *>(input2->memory);
  void *x_ptr = reinterpret_cast<void *>(output->memory);

  // The API cublasStrsmBatched overwrites the right-hand sides, so the
  // right-hand sides should be copied to the output. The output can then be
  // used directly for the calculation.
  size_t numel = input2->num_elements();
  CUDA_CALL(cudaMemcpyAsync(
      x_ptr, b_ptr, numel * bytes, cudaMemcpyDeviceToDevice, custream));

  std::vector<void *> a_array(batch_size, nullptr);
  std::vector<void *> x_array(batch_size, nullptr);
  for (int i = 0; i < batch_size; ++i) {
    a_array[i] = reinterpret_cast<char *>(a_ptr) + i * m * m * bytes;
    x_array[i] = reinterpret_cast<char *>(x_ptr) + i * m * k * bytes;
  }
  thrust::device_vector<void *> dev_a_array(a_array.begin(), a_array.end());
  thrust::device_vector<void *> dev_x_array(x_array.begin(), x_array.end());

  if (bits == 32) {
    std::vector<float> alpha(batch_size, 1.0f);
    CUBLAS_CALL(
        cublasStrsmBatched(handle,
                           side,
                           uplo,
                           transa,
                           diag,
                           b_rows,
                           b_cols,
                           alpha.data(),
                           reinterpret_cast<float **>(dev_a_array.data().get()),
                           lda,
                           reinterpret_cast<float **>(dev_x_array.data().get()),
                           ldb,
                           batch_size));
  } else if (bits == 64) {
    std::vector<double> alpha(batch_size, 1.0);
    CUBLAS_CALL(cublasDtrsmBatched(
        handle,
        side,
        uplo,
        transa,
        diag,
        b_rows,
        b_cols,
        alpha.data(),
        reinterpret_cast<double **>(dev_a_array.data().get()),
        lda,
        reinterpret_cast<double **>(dev_x_array.data().get()),
        ldb,
        batch_size));
  }
}

void cinn_assert_true_nvgpu(
    void *v_args, int num_args, int msg, bool only_warning, void *stream) {
  cinn_assert_true(v_args,
                   num_args,
                   msg,
                   only_warning,
                   stream,
                   common::DefaultNVGPUTarget());
}

void cinn_gpu_cublas_mul(const std::vector<int> &attrs,
                         cinn_buffer_t *input1,
                         cinn_buffer_t *input2,
                         cinn_buffer_t *output,
                         cudaStream_t stream) {
  cublasHandle_t &handle = CublasHandle::GetInstance().GetCublasHandle();
  CHECK_EQ(input1->type.code, cinn_type_code_t::cinn_type_float);
  cudaStream_t custream = static_cast<cudaStream_t>(stream);
  CUBLAS_CALL(cublasSetStream(handle, custream));
  float *x_data = reinterpret_cast<float *>(input1->memory);
  float *y_data = reinterpret_cast<float *>(input2->memory);
  float *out_data = reinterpret_cast<float *>(output->memory);
  int M = 1;
  CHECK_GE(attrs.size(), 6);
  for (int i = 0; i < attrs[attrs.size() - 2]; i++) {
    M *= attrs[i];
  }
  int N = attrs[attrs.size() - 3];
  int K = attrs[attrs.size() - 4];
  float alpha = 1.f;
  float beta = 0.f;
  // Row-major out[M, K] = x[M, N] * y[N, K], issued to column-major cuBLAS
  // as out^T = y^T * x^T.
  CUBLAS_CALL(cublasSgemm(handle,
                          CUBLAS_OP_N,
                          CUBLAS_OP_N,
                          K,
                          M,
                          N,
                          &alpha,
                          y_data,
                          K,
                          x_data,
                          N,
                          &beta,
                          out_data,
                          K));
}

void cinn_gpu_cublas_gemm(const std::vector<int> &attrs,
                          cinn_buffer_t *lhs,
                          cinn_buffer_t *rhs,
                          cinn_buffer_t *bias,
                          cinn_buffer_t *output,
                          cudaStream_t stream) {
  cublasHandle_t &handle = CublasHandle::GetInstance().GetCublasHandle();
  cudaStream_t custream = static_cast<cudaStream_t>(stream);
  CUBLAS_CALL(cublasSetStream(handle, custream));

  CHECK_EQ(lhs->type.code, cinn_type_code_t::cinn_type_float);
  const float *lhs_data = reinterpret_cast<const float *>(lhs->memory);
  const float *rhs_data = reinterpret_cast<const float *>(rhs->memory);
  const float *bias_data =
      bias ? reinterpret_cast<const float *>(bias->memory) : nullptr;
  float *output_data = reinterpret_cast<float *>(output->memory);

  CHECK_GE(attrs.size(), 13);
  int lhs_dim_size = attrs[attrs.size() - 7];
  int rhs_dim_size = attrs[attrs.size() - 6];
  int out_dim_size = attrs[attrs.size() - 5];
  bool lhs_trans = static_cast<bool>(attrs[attrs.size() - 4]);
  bool rhs_trans = static_cast<bool>(attrs[attrs.size() - 3]);
  bool out_trans = static_cast<bool>(attrs[attrs.size() - 2]);
  // 1)C = A^T * B    -->  C^T = B^T * A
  // 2)C = A * B^T    -->  C^T = B * A^T
  // 3)C = A^T * B^T  -->  C^T = B * A
  // 4)C = A * B      -->  C^T = B^T * A^T
  if (out_trans) {
    lhs_trans = static_cast<bool>(attrs[attrs.size() - 3]) ^ out_trans;
    rhs_trans = static_cast<bool>(attrs[attrs.size() - 4]) ^ out_trans;
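    // The matching operand swap (lhs_data <-> rhs_data) happens below when
    // the shapes are assembled.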
  }
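  // alpha is stored bit-for-bit inside an int attribute, so reinterpret the
  // bits as a float instead of converting the integer value.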
  const float alpha =
      *reinterpret_cast<const float *>(&attrs[attrs.size() - 1]);
  const float beta = bias ? 1.f : 0.f;
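  // When a bias is present it is pre-copied into the output buffer and
  // beta = 1 makes the GEMM accumulate onto it.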
  VLOG(4) << "The lhs_trans value used by cinn_gpu_cublas_gemm: " << lhs_trans;
  VLOG(4) << "The rhs_trans value used by cinn_gpu_cublas_gemm: " << rhs_trans;
  VLOG(4) << "The out_trans value used by cinn_gpu_cublas_gemm: " << out_trans;
  VLOG(4) << "The alpha value used by cinn_gpu_cublas_gemm: " << alpha;
  VLOG(4) << "The beta value used by cinn_gpu_cublas_gemm: " << beta;
  CHECK_EQ(lhs_dim_size, rhs_dim_size);
  CHECK_EQ(lhs_dim_size, out_dim_size);
  CHECK((lhs_dim_size == 2 || lhs_dim_size == 3));

  if (lhs_dim_size == 2) {
    // [row, col]
    std::vector<int> lhs_shape{attrs[0], attrs[1]};
    std::vector<int> rhs_shape{attrs[2], attrs[3]};
    std::vector<int> output_shape{attrs[4], attrs[5]};
    if (out_trans) {
      std::swap(lhs_shape, rhs_shape);
      std::swap(lhs_data, rhs_data);
    }
    details::Gemm(handle,
                  lhs_trans,
                  rhs_trans,
                  alpha,
                  lhs_data,
                  lhs_shape,
                  rhs_data,
                  rhs_shape,
                  bias_data,
                  beta,
                  output_data,
                  output_shape,
                  stream);
  } else {
    // [batch, row, col]
    std::vector<int> lhs_shape{attrs[0], attrs[1], attrs[2]};
    std::vector<int> rhs_shape{attrs[3], attrs[4], attrs[5]};
    std::vector<int> output_shape{attrs[6], attrs[7], attrs[8]};
    if (out_trans) {
      std::swap(lhs_shape, rhs_shape);
      std::swap(lhs_data, rhs_data);
    }
    details::GemmStridedBatched(handle,
                                lhs_trans,
                                rhs_trans,
                                alpha,
                                lhs_data,
                                lhs_shape,
                                rhs_data,
                                rhs_shape,
                                bias_data,
                                beta,
                                output_data,
                                output_shape,
                                stream);
  }
}

class CurandGenerator {
 public:
  CurandGenerator() {
    CURAND_CALL(curandCreateGenerator(&generator_, CURAND_RNG_PSEUDO_DEFAULT));
  }

  CurandGenerator(curandRngType rng_type) {
    CURAND_CALL(curandCreateGenerator(&generator_, rng_type));
  }

  ~CurandGenerator() { CURAND_CALL(curandDestroyGenerator(generator_)); }

  curandGenerator_t &GetGenerator() { return generator_; }

  CurandGenerator &SetOffset(uint64_t offset = 0ULL) {
    CURAND_CALL(curandSetGeneratorOffset(generator_, offset));
    VLOG(4) << "Set curand generator offset to: " << offset;
    return *this;
  }

  CurandGenerator &SetSeed(uint64_t seed = 0ULL) {
    // fall back to the global random seed when the given seed is zero
    auto rand_seed = (seed == 0ULL) ? RandomSeed::GetOrSet() : seed;
    if (rand_seed != 0ULL && rand_seed != seed_) {
      CURAND_CALL(curandSetPseudoRandomGeneratorSeed(generator_, rand_seed));
      VLOG(4) << "Change curand random seed from: " << seed_
              << " to: " << rand_seed;
      seed_ = rand_seed;
    }
    return *this;
  }

  CurandGenerator &SetStream(cudaStream_t stream) {
    if (stream != nullptr && stream != stream_) {
      CURAND_CALL(curandSetStream(generator_, stream));
      VLOG(4) << "Change curand generator stream from: " << stream_
              << " to: " << stream;
      stream_ = stream;
    }
    return *this;
  }

 private:
  curandGenerator_t generator_;
  uint64_t seed_ = 0ULL;
  cudaStream_t stream_ = nullptr;
};

class CurandGeneratorFactory {
 public:
  enum class CurandGeneratorType {
    GENERATOR_DEFAULT,
    GENERATOR_GAUSSIAN,
    GENERATOR_UNIFORM,
    GENERATOR_RANDINT,
  };

  static CurandGenerator &Get(CurandGeneratorType type) {
    switch (type) {
      case CurandGeneratorType::GENERATOR_GAUSSIAN:
        static CurandGenerator gaussian_generator(
            CURAND_RNG_PSEUDO_PHILOX4_32_10);
        return gaussian_generator;
      case CurandGeneratorType::GENERATOR_UNIFORM:
        static CurandGenerator uniform_generator(
            CURAND_RNG_PSEUDO_PHILOX4_32_10);
        return uniform_generator;
      case CurandGeneratorType::GENERATOR_RANDINT:
        static CurandGenerator randint_generator(CURAND_RNG_PSEUDO_MT19937);
        return randint_generator;
      default:
        static CurandGenerator default_generator;
        return default_generator;
    }
  }
};

void cinn_call_gaussian_random(
    void *v_args, int num_args, float mean, float std, int seed, void *stream) {
  cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
  cinn_buffer_t *output = args[0].operator cinn_buffer_t *();
  cinn_type_t dtype = output->type;
  size_t numel = output->num_elements();

  curandGenerator_t generator =
      CurandGeneratorFactory::Get(
          CurandGeneratorFactory::CurandGeneratorType::GENERATOR_GAUSSIAN)
          .SetStream(static_cast<cudaStream_t>(stream))
          .SetSeed(seed)
          .GetGenerator();

  VLOG(4) << "cinn_call_gaussian_random: output_size=" << numel
          << ", mean=" << mean << ", std=" << std << ", seed=" << seed;

  if (dtype == cinn_float32_t()) {
    float *ptr = reinterpret_cast<float *>(output->memory);
    CURAND_CALL(curandGenerateNormal(generator, ptr, numel, mean, std));
  } else if (dtype == cinn_float64_t()) {
    double *ptr = reinterpret_cast<double *>(output->memory);
    CURAND_CALL(curandGenerateNormalDouble(generator, ptr, numel, mean, std));
  } else {
    LOG(FATAL)
        << "gaussian_random only supports float32 and float64! Please check.";
  }
}

void cinn_call_uniform_random(
    void *v_args, int num_args, float min, float max, int seed, void *stream) {
  cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
  cinn_buffer_t *output = args[0].operator cinn_buffer_t *();
  cinn_type_t dtype = output->type;
  size_t numel = output->num_elements();

  curandGenerator_t generator =
      CurandGeneratorFactory::Get(
          CurandGeneratorFactory::CurandGeneratorType::GENERATOR_UNIFORM)
          .SetStream(static_cast<cudaStream_t>(stream))
          .SetSeed(seed)
          .GetGenerator();

  VLOG(4) << "cinn_call_uniform_random: output_size=" << numel
          << ", min=" << min << ", max=" << max << ", seed=" << seed;

  if (dtype == cinn_float32_t()) {
    float *ptr = reinterpret_cast<float *>(output->memory);
    CURAND_CALL(curandGenerateUniform(generator, ptr, numel));
  } else if (dtype == cinn_float64_t()) {
    double *ptr = reinterpret_cast<double *>(output->memory);
    CURAND_CALL(curandGenerateUniformDouble(generator, ptr, numel));
  } else {
    LOG(FATAL)
        << "uniform_random only supports float32 and float64! Please check.";
  }
}

void cinn_call_randint(void *v_args, int num_args, int seed, void *stream) {
  cinn_pod_value_t *args = static_cast<cinn_pod_value_t *>(v_args);
  cinn_buffer_t *output = args[0].operator cinn_buffer_t *();
  cinn_type_t dtype = output->type;
  size_t numel = output->num_elements();

  VLOG(4) << "cinn_call_randint: output_size=" << numel << ", seed=" << seed;

  curandGenerator_t generator =
      CurandGeneratorFactory::Get(
          CurandGeneratorFactory::CurandGeneratorType::GENERATOR_RANDINT)
          .SetStream(static_cast<cudaStream_t>(stream))
          .SetSeed(seed)
          .GetGenerator();

  if (dtype == cinn_int32_t()) {
    unsigned int *ptr = reinterpret_cast<unsigned int *>(output->memory);
    CURAND_CALL(curandGenerate(generator, ptr, numel));
  } else {
    LOG(FATAL) << "randint only supports int32! Please check.";
  }
}

#ifdef CINN_WITH_CUDNN

namespace {
cudnnDataType_t convert_to_cudnn_dtype(cinn_buffer_t *input) {
  CHECK(input) << "the input pointer is null";
  auto type_code = input->type.code;
  int bits = input->type.bits;
  cudnnDataType_t data_type;
  bool is_float = type_code == cinn_type_float;
  bool is_bfloat16 = type_code == cinn_type_bfloat;
  if (is_float && bits == 16) {
    data_type = CUDNN_DATA_HALF;
  } else if (is_float && bits == 32) {
    data_type = CUDNN_DATA_FLOAT;
  } else if (is_bfloat16) {
    data_type = CUDNN_DATA_BFLOAT16;
  } else if (is_float && bits == 64) {
    data_type = CUDNN_DATA_DOUBLE;
  } else {
    LOG(FATAL) << "unsupported cudnn data type: " << static_cast<int>(type_code)
               << ", bits = " << bits;
  }
  return data_type;
}
}  // namespace

#define GetAttrValue(attr_map, key_name, default_value)      \
  int key_name = 0;                                          \
  if (attr_map.count(#key_name) != 0) {                      \
    key_name = attr_map.find(#key_name)->second;             \
  } else if (default_value >= 0) {                           \
    key_name = default_value;                                \
  } else {                                                   \
    LOG(FATAL) << #key_name << " does not exist in attr_map!"; \
  }
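// Example: GetAttrValue(attr, pad_h, 0) declares a local `int pad_h`, taking
// its value from attr["pad_h"] when present and from the (non-negative)
// default otherwise.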

void cinn_gpu_cudnn_conv2d(const absl::flat_hash_map<std::string, int> &attr,
                           cinn_buffer_t *x,
                           cinn_buffer_t *w,
                           cinn_buffer_t *y,
                           cudaStream_t stream,
                           common::Layout target) {
  cudnnTensorFormat_t cudnn_tensor_format;
  if (target == common::Layout::kNCHW) {
    cudnn_tensor_format = CUDNN_TENSOR_NCHW;
  } else if (target == common::Layout::kNHWC) {
    cudnn_tensor_format = CUDNN_TENSOR_NHWC;
  } else {
    CINN_NOT_IMPLEMENTED
  }

  GetAttrValue(attr, input_n, -1);
  GetAttrValue(attr, input_c, -1);
  GetAttrValue(attr, input_h, -1);
  GetAttrValue(attr, input_w, -1);
  GetAttrValue(attr, weights_n, -1);
  GetAttrValue(attr, weights_c, -1);
  GetAttrValue(attr, weights_h, -1);
  GetAttrValue(attr, weights_w, -1);
  GetAttrValue(attr, pad_h, 0);
  GetAttrValue(attr, pad_w, 0);
  GetAttrValue(attr, stride_h, 1);
  GetAttrValue(attr, stride_w, 1);
  GetAttrValue(attr, dilation_h, 1);
  GetAttrValue(attr, dilation_w, 1);
  GetAttrValue(attr, groups, 1);
  GetAttrValue(attr, output_n, -1);
  GetAttrValue(attr, output_c, -1);
  GetAttrValue(attr, output_h, -1);
  GetAttrValue(attr, output_w, -1);

  cudnnHandle_t &handle = CudnnHandle::GetInstance().GetCudnnHandle();
  CUDNN_CALL(cudnnSetStream(handle, static_cast<cudaStream_t>(stream)));
  void *_x = x->memory;
  void *_w = w->memory;
  void *_y = y->memory;

  auto data_type = convert_to_cudnn_dtype(x);

  cudnnTensorDescriptor_t x_desc;
  CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(x_desc,
                                        cudnn_tensor_format,
                                        data_type,
                                        input_n,
                                        input_c,
                                        input_h,
                                        input_w));

  cudnnFilterDescriptor_t w_desc;
  CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc));
  CUDNN_CALL(cudnnSetFilter4dDescriptor(w_desc,
                                        data_type,
                                        cudnn_tensor_format,
                                        weights_n,
                                        weights_c,
                                        weights_h,
                                        weights_w));

  cudnnConvolutionDescriptor_t conv_desc;
  CUDNN_CALL(cudnnCreateConvolutionDescriptor(&conv_desc));
  CUDNN_CALL(
      cudnnSetConvolution2dDescriptor(conv_desc,
                                      pad_h,
                                      pad_w,
                                      stride_h,
                                      stride_w,
                                      dilation_h,
                                      dilation_w,
                                      CUDNN_CROSS_CORRELATION,
                                      get_cudnn_compute_dtype(data_type)));
  CUDNN_CALL(cudnnSetConvolutionGroupCount(conv_desc, groups));
  CUDNN_CALL(cudnnSetConvolutionMathType(conv_desc, CUDNN_DEFAULT_MATH));

  cudnnTensorDescriptor_t y_desc;
  CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc,
                                        cudnn_tensor_format,
                                        data_type,
                                        output_n,
                                        output_c,
                                        output_h,
                                        output_w));

  auto &conv_algo_map = ConvAlgoMap::GetInstance();
  std::string hash_key =
      "conv2d forward, layout=" + debug_cudnn_tensor_format(CUDNN_TENSOR_NCHW) +
      ", dtype=" + debug_cudnn_tensor_dtype(data_type) + ", input_nchw={" +
      std::to_string(input_n) + "," + std::to_string(input_c) + "," +
      std::to_string(input_h) + "," + std::to_string(input_w) +
      "}, filter_nchw={" + std::to_string(weights_n) + "," +
      std::to_string(weights_c) + "," + std::to_string(weights_h) + "," +
      std::to_string(weights_w) + "}, output_nchw={" +
      std::to_string(output_n) + "," + std::to_string(output_c) + "," +
      std::to_string(output_h) + "," + std::to_string(output_w) + "}";

  cudnnConvolutionFwdAlgo_t algo;
  int algo_int = conv_algo_map.GetAlgo(hash_key);
  if (algo_int >= 0) {
    algo = cudnnConvolutionFwdAlgo_t(algo_int);
  } else {
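    // No cached choice yet: let cuDNN benchmark candidate algorithms and
    // memoize the fastest one under this shape/dtype key.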
    int count = 0;
    cudnnConvolutionFwdAlgoPerf_t algo_perf;
    CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(
        handle, x_desc, w_desc, conv_desc, y_desc, 1, &count, &algo_perf));

    algo = algo_perf.algo;
    conv_algo_map.InsertAlgo(hash_key, static_cast<int>(algo_perf.algo));
  }

  if (GetCinnCudnnDeterministic()) {
    // CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM (enum value 1) is a
    // deterministic choice.
    algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
  }

  size_t ws_size = 0;
  CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(
      handle, x_desc, w_desc, conv_desc, y_desc, algo, &ws_size));

  void *ws_data = CudnnHandle::GetInstance().GetWorkSpace(ws_size);
  if (data_type == CUDNN_DATA_DOUBLE) {
    double alpha[] = {1.0}, beta[] = {0.0};
    CUDNN_CALL(cudnnConvolutionForward(handle,
                                       alpha,
                                       x_desc,
                                       _x,
                                       w_desc,
                                       _w,
                                       conv_desc,
                                       algo,
                                       ws_data,
                                       ws_size,
                                       beta,
                                       y_desc,
                                       _y));
  } else {
    float alpha[] = {1.f}, beta[] = {0.f};
    CUDNN_CALL(cudnnConvolutionForward(handle,
                                       alpha,
                                       x_desc,
                                       _x,
                                       w_desc,
                                       _w,
                                       conv_desc,
                                       algo,
                                       ws_data,
                                       ws_size,
                                       beta,
                                       y_desc,
                                       _y));
  }

  CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc));
  CUDNN_CALL(cudnnDestroyFilterDescriptor(w_desc));
  CUDNN_CALL(cudnnDestroyConvolutionDescriptor(conv_desc));
  CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc));
}

void cinn_gpu_cudnn_conv2d_backward_data(
    const absl::flat_hash_map<std::string, int> &attr,
    cinn_buffer_t *w,
    cinn_buffer_t *dy,
    cinn_buffer_t *dx,
    cudaStream_t stream) {
  GetAttrValue(attr, input_n, -1);
  GetAttrValue(attr, input_c, -1);
  GetAttrValue(attr, input_h, -1);
  GetAttrValue(attr, input_w, -1);
  GetAttrValue(attr, weights_n, -1);
  GetAttrValue(attr, weights_c, -1);
  GetAttrValue(attr, weights_h, -1);
  GetAttrValue(attr, weights_w, -1);
  GetAttrValue(attr, pad_h, 0);
  GetAttrValue(attr, pad_w, 0);
  GetAttrValue(attr, stride_h, 1);
  GetAttrValue(attr, stride_w, 1);
  GetAttrValue(attr, dilation_h, 1);
  GetAttrValue(attr, dilation_w, 1);
  GetAttrValue(attr, groups, 1);
  GetAttrValue(attr, output_n, -1);
  GetAttrValue(attr, output_c, -1);
  GetAttrValue(attr, output_h, -1);
  GetAttrValue(attr, output_w, -1);

  cudnnHandle_t &handle = CudnnHandle::GetInstance().GetCudnnHandle();
  CUDNN_CALL(cudnnSetStream(handle, static_cast<cudaStream_t>(stream)));
  void *_w = w->memory;
  void *_dy = dy->memory;
  void *_dx = dx->memory;

  auto data_type = convert_to_cudnn_dtype(w);

  cudnnTensorDescriptor_t x_desc;
  CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(x_desc,
                                        CUDNN_TENSOR_NCHW,
                                        data_type,
                                        input_n,
                                        input_c,
                                        input_h,
                                        input_w));

  cudnnFilterDescriptor_t w_desc;
  CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc));
  CUDNN_CALL(cudnnSetFilter4dDescriptor(w_desc,
                                        data_type,
                                        CUDNN_TENSOR_NCHW,
                                        weights_n,
                                        weights_c,
                                        weights_h,
                                        weights_w));

  cudnnConvolutionDescriptor_t conv_desc;
  CUDNN_CALL(cudnnCreateConvolutionDescriptor(&conv_desc));
  CUDNN_CALL(
      cudnnSetConvolution2dDescriptor(conv_desc,
                                      pad_h,
                                      pad_w,
                                      stride_h,
                                      stride_w,
                                      dilation_h,
                                      dilation_w,
                                      CUDNN_CROSS_CORRELATION,
                                      get_cudnn_compute_dtype(data_type)));
  CUDNN_CALL(cudnnSetConvolutionGroupCount(conv_desc, groups));
  CUDNN_CALL(cudnnSetConvolutionMathType(conv_desc, CUDNN_DEFAULT_MATH));

  cudnnTensorDescriptor_t y_desc;
  CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc,
                                        CUDNN_TENSOR_NCHW,
                                        data_type,
                                        output_n,
                                        output_c,
                                        output_h,
                                        output_w));

  auto &conv_algo_map = ConvAlgoMap::GetInstance();
  std::string hash_key =
      "conv2d backward data, layout=" +
      debug_cudnn_tensor_format(CUDNN_TENSOR_NCHW) +
      ", dtype=" + debug_cudnn_tensor_dtype(data_type) + ", input_nchw={" +
      std::to_string(input_n) + "," + std::to_string(input_c) + "," +
      std::to_string(input_h) + "," + std::to_string(input_w) +
      "}, filter_nchw={" + std::to_string(weights_n) + "," +
      std::to_string(weights_c) + "," + std::to_string(weights_h) + "," +
      std::to_string(weights_w) + "}, output_nchw={" +
      std::to_string(output_n) + "," + std::to_string(output_c) + "," +
      std::to_string(output_h) + "," + std::to_string(output_w) + "}";

  int algo_int = conv_algo_map.GetAlgo(hash_key);
  cudnnConvolutionBwdDataAlgo_t algo;
  if (algo_int >= 0) {
    algo = cudnnConvolutionBwdDataAlgo_t(algo_int);
  } else {
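    // Benchmark once and cache the fastest backward-data algorithm for this
    // shape/dtype key.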
    int count = 0;
    cudnnConvolutionBwdDataAlgoPerf_t algo_perf;

    CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(
        handle, w_desc, y_desc, conv_desc, x_desc, 1, &count, &algo_perf));

    algo = algo_perf.algo;
    conv_algo_map.InsertAlgo(hash_key, static_cast<int>(algo_perf.algo));
  }

  if (GetCinnCudnnDeterministic()) {
    algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
  }

  size_t ws_size = 0;
  CUDNN_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize(
      handle, w_desc, y_desc, conv_desc, x_desc, algo, &ws_size));

  void *ws_data = CudnnHandle::GetInstance().GetWorkSpace(ws_size);
  if (data_type == CUDNN_DATA_DOUBLE) {
    double alpha[] = {1.0}, beta[] = {0.0};
    CUDNN_CALL(cudnnConvolutionBackwardData(handle,
                                            alpha,
                                            w_desc,
                                            _w,
                                            y_desc,
                                            _dy,
                                            conv_desc,
                                            algo,
                                            ws_data,
                                            ws_size,
                                            beta,
                                            x_desc,
                                            _dx));
  } else {
    float alpha[] = {1.0f}, beta[] = {0.0f};
    CUDNN_CALL(cudnnConvolutionBackwardData(handle,
                                            alpha,
                                            w_desc,
                                            _w,
                                            y_desc,
                                            _dy,
                                            conv_desc,
                                            algo,
                                            ws_data,
                                            ws_size,
                                            beta,
                                            x_desc,
                                            _dx));
  }

  CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc));
  CUDNN_CALL(cudnnDestroyFilterDescriptor(w_desc));
  CUDNN_CALL(cudnnDestroyConvolutionDescriptor(conv_desc));
  CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc));
}

void cinn_gpu_cudnn_conv2d_backward_filter(
    const absl::flat_hash_map<std::string, int> &attr,
    cinn_buffer_t *x,
    cinn_buffer_t *dy,
    cinn_buffer_t *dw,
    cudaStream_t stream) {
  GetAttrValue(attr, input_n, -1);
  GetAttrValue(attr, input_c, -1);
  GetAttrValue(attr, input_h, -1);
  GetAttrValue(attr, input_w, -1);
  GetAttrValue(attr, weights_n, -1);
  GetAttrValue(attr, weights_c, -1);
  GetAttrValue(attr, weights_h, -1);
  GetAttrValue(attr, weights_w, -1);
  GetAttrValue(attr, pad_h, 0);
  GetAttrValue(attr, pad_w, 0);
  GetAttrValue(attr, stride_h, 1);
  GetAttrValue(attr, stride_w, 1);
  GetAttrValue(attr, dilation_h, 1);
  GetAttrValue(attr, dilation_w, 1);
  GetAttrValue(attr, groups, 1);
  GetAttrValue(attr, output_n, -1);
  GetAttrValue(attr, output_c, -1);
  GetAttrValue(attr, output_h, -1);
  GetAttrValue(attr, output_w, -1);

  cudnnHandle_t &handle = CudnnHandle::GetInstance().GetCudnnHandle();
  CUDNN_CALL(cudnnSetStream(handle, static_cast<cudaStream_t>(stream)));

  void *_x = x->memory;
  void *_dy = dy->memory;
  void *_dw = dw->memory;

  auto data_type = convert_to_cudnn_dtype(x);

  cudnnTensorDescriptor_t x_desc;
  CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(x_desc,
                                        CUDNN_TENSOR_NCHW,
                                        data_type,
                                        input_n,
                                        input_c,
                                        input_h,
                                        input_w));

  cudnnFilterDescriptor_t w_desc;
  CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc));
  CUDNN_CALL(cudnnSetFilter4dDescriptor(w_desc,
                                        data_type,
                                        CUDNN_TENSOR_NCHW,
                                        weights_n,
                                        weights_c,
                                        weights_h,
                                        weights_w));

  cudnnConvolutionDescriptor_t conv_desc;
  CUDNN_CALL(cudnnCreateConvolutionDescriptor(&conv_desc));
  CUDNN_CALL(
      cudnnSetConvolution2dDescriptor(conv_desc,
                                      pad_h,
                                      pad_w,
                                      stride_h,
                                      stride_w,
                                      dilation_h,
                                      dilation_w,
                                      CUDNN_CROSS_CORRELATION,
                                      get_cudnn_compute_dtype(data_type)));
  CUDNN_CALL(cudnnSetConvolutionGroupCount(conv_desc, groups));
  CUDNN_CALL(cudnnSetConvolutionMathType(conv_desc, CUDNN_DEFAULT_MATH));

  cudnnTensorDescriptor_t y_desc;
  CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(y_desc,
                                        CUDNN_TENSOR_NCHW,
                                        data_type,
                                        output_n,
                                        output_c,
                                        output_h,
                                        output_w));

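  // The cache key encodes everything that affects the algorithm choice:
  // tensor layout, dtype, and the input/filter/output shapes.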
  auto &algo_map = ConvAlgoMap::GetInstance();
  std::string hash_key =
      "conv2d backward filter, layout=" +
      debug_cudnn_tensor_format(CUDNN_TENSOR_NCHW) +
      ", dtype=" + debug_cudnn_tensor_dtype(data_type) + ", input_nchw={" +
      std::to_string(input_n) + "," + std::to_string(input_c) + "," +
      std::to_string(input_h) + "," + std::to_string(input_w) +
      "}, filter_nchw={" + std::to_string(weights_n) + "," +
      std::to_string(weights_c) + "," + std::to_string(weights_h) + "," +
      std::to_string(weights_w) + "}, output_nchw={" +
      std::to_string(output_n) + "," + std::to_string(output_c) + "," +
      std::to_string(output_h) + "," + std::to_string(output_w) + "}";

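  // Benchmark an algorithm only on the first call for this shape signature;
  // later calls that hit the same hash_key reuse the cached winner.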
  int algo_int = algo_map.GetAlgo(hash_key);
  cudnnConvolutionBwdFilterAlgo_t algo;
  if (algo_int >= 0) {
    algo = cudnnConvolutionBwdFilterAlgo_t(algo_int);
  } else {
    int count = 0;
    cudnnConvolutionBwdFilterAlgoPerf_t algo_perf;
    CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(
        handle, x_desc, y_desc, conv_desc, w_desc, 1, &count, &algo_perf));

    algo = algo_perf.algo;
    algo_map.InsertAlgo(hash_key, static_cast<int>(algo_perf.algo));
  }

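  // CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 is deterministic (no atomics), so it
  // overrides the benchmarked choice when reproducible results are requested.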
  if (GetCinnCudnnDeterministic()) {
    algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
  }

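  // Query the scratch space the chosen algorithm needs and serve it from the
  // workspace cached inside CudnnHandle.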
  size_t ws_size = 0;
  CUDNN_CALL(cudnnGetConvolutionBackwardFilterWorkspaceSize(
      handle, x_desc, y_desc, conv_desc, w_desc, algo, &ws_size));

  void *ws_data = CudnnHandle::GetInstance().GetWorkSpace(ws_size);
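  // cuDNN computes dw = alpha * conv_bwd_filter(x, dy) + beta * dw, and the
  // scaling factors must match the tensor precision, hence the two branches.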
  if (data_type == CUDNN_DATA_DOUBLE) {
    double alpha[] = {1.0}, beta[] = {0.0};
    CUDNN_CALL(cudnnConvolutionBackwardFilter(handle,
                                              alpha,
                                              x_desc,
                                              _x,
                                              y_desc,
                                              _dy,
                                              conv_desc,
                                              algo,
                                              ws_data,
                                              ws_size,
                                              beta,
                                              w_desc,
                                              _dw));
  } else {
    float alpha[] = {1.0}, beta[] = {0.0};
    CUDNN_CALL(cudnnConvolutionBackwardFilter(handle,
                                              alpha,
                                              x_desc,
                                              _x,
                                              y_desc,
                                              _dy,
                                              conv_desc,
                                              algo,
                                              ws_data,
                                              ws_size,
                                              beta,
                                              w_desc,
                                              _dw));
  }

  CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc));
  CUDNN_CALL(cudnnDestroyFilterDescriptor(w_desc));
  CUDNN_CALL(cudnnDestroyConvolutionDescriptor(conv_desc));
  CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc));
}
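
// A minimal sketch of how a caller might drive the backward-filter kernel;
// the buffers and the 1x3x32x32 problem size below are hypothetical:
//
//   absl::flat_hash_map<std::string, int> attr = {
//       {"input_n", 1},    {"input_c", 3},    {"input_h", 32}, {"input_w", 32},
//       {"weights_n", 8},  {"weights_c", 3},  {"weights_h", 3}, {"weights_w", 3},
//       {"pad_h", 1},      {"pad_w", 1},      {"stride_h", 1}, {"stride_w", 1},
//       {"dilation_h", 1}, {"dilation_w", 1}, {"groups", 1},
//       {"output_n", 1},   {"output_c", 8},   {"output_h", 32}, {"output_w", 32}};
//   cinn_gpu_cudnn_conv2d_backward_filter(attr, x_buf, dy_buf, dw_buf, stream);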

void cinn_gpu_cudnn_pool2d(const std::vector<int> &attrs,
                           const std::vector<std::string> &str_attrs,
                           cinn_buffer_t *input,
                           cinn_buffer_t *output,
                           cudaStream_t stream) {
  cudnnHandle_t &handle = CudnnHandle::GetInstance().GetCudnnHandle();
  CUDNN_CALL(cudnnSetStream(handle, static_cast<cudaStream_t>(stream)));
  CHECK_EQ(attrs.size(), 17);
  // Here the input paddings are pad_top, pad_bottom, pad_left, pad_right.
  // Since pad_top == pad_bottom and pad_left == pad_right, we only take
  // pad_top and pad_left.
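  // attrs layout (17 ints): {input_n, input_c, input_h, input_w, kernel_h,
  // kernel_w, pad_top, pad_bottom, pad_left, pad_right, stride_h, stride_w,
  // output_n, output_c, output_h, output_w, adaptive}.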
  int input_n = attrs[0];
  int input_c = attrs[1];
  int input_h = attrs[2];
  int input_w = attrs[3];
  int kernel_h = attrs[4];
  int kernel_w = attrs[5];
  int pad_h = attrs[6];
  int pad_w = attrs[8];
  int stride_h = attrs[10];
  int stride_w = attrs[11];
  int output_n = attrs[12];
  int output_c = attrs[13];
  int output_h = attrs[14];
  int output_w = attrs[15];
  int adaptive = attrs[16];
  std::string pool_type = str_attrs[0];
  cudnnPoolingDescriptor_t pooling_desc;
  CUDNN_CALL(cudnnCreatePoolingDescriptor(&pooling_desc));
  cudnnPoolingMode_t pool_mode;
  if (pool_type == "max") {
    pool_mode = CUDNN_POOLING_MAX;
  } else if (pool_type == "avg") {
    pool_mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
  } else {
    LOG(FATAL) << "Unrecognized pool_type: " << pool_type;
  }
  if (adaptive == 1) {
    stride_h = input_h / output_h;
    stride_w = input_w / output_w;
    kernel_h = input_h - (output_h - 1) * stride_h;
    kernel_w = input_w - (output_w - 1) * stride_w;
  }
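  // e.g. input_h = 7, output_h = 3: stride_h = 2, kernel_h = 7 - 2 * 2 = 3,
  // so the windows cover rows [0,3), [2,5), [4,7). This matches true adaptive
  // pooling exactly only when input_h is divisible by output_h.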

  auto data_type = convert_to_cudnn_dtype(input);

  CUDNN_CALL(cudnnSetPooling2dDescriptor(pooling_desc,
                                         pool_mode,
                                         CUDNN_NOT_PROPAGATE_NAN,
                                         kernel_h,
                                         kernel_w,
                                         pad_h,
                                         pad_w,
                                         stride_h,
                                         stride_w));

  cudnnTensorDescriptor_t in_desc;

  CUDNN_CALL(cudnnCreateTensorDescriptor(&in_desc));

  CUDNN_CALL(cudnnSetTensor4dDescriptor(in_desc,
                                        CUDNN_TENSOR_NCHW,
                                        data_type,
                                        input_n,
                                        input_c,
                                        input_h,
                                        input_w));

  cudnnTensorDescriptor_t out_desc;

  CUDNN_CALL(cudnnCreateTensorDescriptor(&out_desc));

  CUDNN_CALL(cudnnSetTensor4dDescriptor(out_desc,
                                        CUDNN_TENSOR_NCHW,
                                        data_type,
                                        output_n,
                                        output_c,
                                        output_h,
                                        output_w));

  void *in_data = input->memory;
  void *out_data = output->memory;

  if (data_type == CUDNN_DATA_DOUBLE) {
    double alpha = 1.0;
    double beta = 0.0;
    CUDNN_CALL(cudnnPoolingForward(handle,
                                   pooling_desc,
                                   &alpha,
                                   in_desc,
                                   in_data,
                                   &beta,
                                   out_desc,
                                   out_data));
  } else {
    float alpha = 1.0f;
    float beta = 0.0f;
    CUDNN_CALL(cudnnPoolingForward(handle,
                                   pooling_desc,
                                   &alpha,
                                   in_desc,
                                   in_data,
                                   &beta,
                                   out_desc,
                                   out_data));
  }

  CUDNN_CALL(cudnnDestroyTensorDescriptor(in_desc));
  CUDNN_CALL(cudnnDestroyTensorDescriptor(out_desc));
  CUDNN_CALL(cudnnDestroyPoolingDescriptor(pooling_desc));
}

void cinn_gpu_cudnn_softmax(const std::vector<int> &attrs,
                            cinn_buffer_t *input,
                            cinn_buffer_t *output,
                            cudaStream_t stream) {
  std::vector<int> shape;
  int rank = attrs.size() - 1;
  for (int i = 0; i < rank; i++) {
    shape.push_back(attrs[i]);
  }
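  // A negative axis counts from the end, e.g. -1 means the last dimension.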
  int axis = attrs.back();
  axis = axis < 0 ? rank + axis : axis;
  int inner_num = 1;
  int outer_num = 1;
  for (int i = 0; i < shape.size(); i++) {
    if (i < axis)
      outer_num *= shape[i];
    else if (i > axis)
      inner_num *= shape[i];
  }
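  // e.g. shape = {2, 3, 4, 5}, axis = 1: outer_num = 2, inner_num = 20; the
  // tensor is presented to cuDNN as NCHW = (2, 3, 20, 1) so that
  // CUDNN_SOFTMAX_MODE_CHANNEL normalizes over the chosen axis.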
  rank = shape.size();

  auto data_type = convert_to_cudnn_dtype(input);

  cudnnHandle_t &handle = CudnnHandle::GetInstance().GetCudnnHandle();
  CUDNN_CALL(cudnnSetStream(handle, static_cast<cudaStream_t>(stream)));
  void *in_data = input->memory;
  void *out_data = output->memory;

  cudnnTensorDescriptor_t in_desc;
  CUDNN_CALL(cudnnCreateTensorDescriptor(&in_desc));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(in_desc,
                                        CUDNN_TENSOR_NCHW,
                                        data_type,
                                        outer_num,
                                        shape[axis],
                                        inner_num,
                                        1));

  cudnnTensorDescriptor_t out_desc;
  CUDNN_CALL(cudnnCreateTensorDescriptor(&out_desc));
  CUDNN_CALL(cudnnSetTensor4dDescriptor(out_desc,
                                        CUDNN_TENSOR_NCHW,
                                        data_type,
                                        outer_num,
                                        shape[axis],
                                        inner_num,
                                        1));

  if (data_type == CUDNN_DATA_DOUBLE) {
    double alpha = 1.0;
    double beta = 0.0;
    CUDNN_CALL(cudnnSoftmaxForward(handle,
                                   CUDNN_SOFTMAX_ACCURATE,
                                   CUDNN_SOFTMAX_MODE_CHANNEL,
                                   &alpha,
                                   in_desc,
                                   in_data,
                                   &beta,
                                   out_desc,
                                   out_data));
  } else {
    float alpha = 1.f;
    float beta = 0.f;
    CUDNN_CALL(cudnnSoftmaxForward(handle,
                                   CUDNN_SOFTMAX_ACCURATE,
                                   CUDNN_SOFTMAX_MODE_CHANNEL,
                                   &alpha,
                                   in_desc,
                                   in_data,
                                   &beta,
                                   out_desc,
                                   out_data));
  }

  CUDNN_CALL(cudnnDestroyTensorDescriptor(in_desc));
  CUDNN_CALL(cudnnDestroyTensorDescriptor(out_desc));
}
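
// A minimal sketch of the softmax attrs convention (buffers hypothetical):
// the first rank entries give the shape, the last entry the softmax axis.
//
//   std::vector<int> attrs = {2, 3, 4, 5, -1};  // shape {2,3,4,5}, axis -1
//   cinn_gpu_cudnn_softmax(attrs, in_buf, out_buf, stream);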

#endif  // CINN_WITH_CUDNN

}  // namespace cuda
}  // namespace runtime
}  // namespace cinn