// blas_impl.hip.h
//   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/dynload/rocblas.h"
#include "paddle/fluid/platform/gpu_info.h"

DECLARE_bool(enable_cublas_tensor_op_math);

namespace paddle {
namespace operators {
namespace math {

// Primary template is declared only. Each element type supported on HIP gets
// an explicit specialization below whose static members forward to the
// matching rocBLAS routine, or throw platform::errors::Unimplemented when no
// rocBLAS routine is wired up for that operation.
template <typename T>
struct CUBlas;

// Single-precision dispatch. Every supported routine forwards its argument
// pack untouched to the rocBLAS "s" entry point and checks the returned
// status with PADDLE_ENFORCE_CUDA_SUCCESS. Operations HIP does not provide
// throw platform::errors::Unimplemented; see
// https://github.com/ROCm-Developer-Tools/HIP/blob/roc-3.5.x/docs/markdown/CUBLAS_API_supported_by_HIP.md
template <>
struct CUBlas<float> {
  // --- Routines backed by rocBLAS -------------------------------------------

  template <typename... ARGS>
  static void AXPY(ARGS... params) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_saxpy(params...));
  }

  template <typename... ARGS>
  static void SCAL(ARGS... params) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_sscal(params...));
  }

  template <typename... ARGS>
  static void VCOPY(ARGS... params) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_scopy(params...));
  }

  template <typename... ARGS>
  static void GEMV(ARGS... params) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_sgemv(params...));
  }

  template <typename... ARGS>
  static void GEMM(ARGS... params) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_sgemm(params...));
  }

  template <typename... ARGS>
  static void GEMM_STRIDED_BATCH(ARGS... params) {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::rocblas_sgemm_strided_batched(params...));
  }

  template <typename... ARGS>
  static void TRSM(ARGS... params) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_strsm(params...));
  }

  // --- Not available on HIP: always throw -----------------------------------

  template <typename... ARGS>
  static void GEMM_EX(ARGS... params) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "cublasSgemmEx is not supported on HIP platform."));
  }

  template <typename... ARGS>
  static void GETRF_BATCH(ARGS... params) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "cublasSgetrfBatched is not supported on HIP platform."));
  }

  template <typename... ARGS>
  static void GETRI_BATCH(ARGS... params) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "cublasSgetriBatched is not supported on HIP platform."));
  }

  template <typename... ARGS>
  static void MATINV_BATCH(ARGS... params) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "cublasSmatinvBatched is not supported on HIP platform."));
  }
};

// Double-precision dispatch. Each supported member forwards its argument pack
// unchanged to the corresponding rocBLAS "d" routine and checks the returned
// status via PADDLE_ENFORCE_CUDA_SUCCESS. Operations with no HIP counterpart
// throw platform::errors::Unimplemented.
template <>
struct CUBlas<double> {
  template <typename... ARGS>
  static void GEMM(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_dgemm(args...));
  }

  template <typename... ARGS>
  static void AXPY(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_daxpy(args...));
  }

  template <typename... ARGS>
  static void SCAL(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_dscal(args...));
  }

  template <typename... ARGS>
  static void VCOPY(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_dcopy(args...));
  }

  template <typename... ARGS>
  static void GEMV(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_dgemv(args...));
  }

  template <typename... ARGS>
  static void GEMM_STRIDED_BATCH(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::rocblas_dgemm_strided_batched(args...));
  }

  // HIP has no cublasDgemmEx counterpart; see
  // https://github.com/ROCm-Developer-Tools/HIP/blob/roc-3.5.x/docs/markdown/CUBLAS_API_supported_by_HIP.md
  // The message matches the wording used by the other Unimplemented stubs.
  template <typename... ARGS>
  static void GEMM_EX(ARGS... args) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "cublasDgemmEx is not supported on HIP platform."));
  }

  template <typename... ARGS>
  static void TRSM(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_dtrsm(args...));
  }

  template <typename... ARGS>
  static void GETRF_BATCH(ARGS... args) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "cublasDgetrfBatched is not supported on HIP platform."));
  }

  template <typename... ARGS>
  static void GETRI_BATCH(ARGS... args) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "cublasDgetriBatched is not supported on HIP platform."));
  }

  template <typename... ARGS>
  static void MATINV_BATCH(ARGS... args) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "cublasDmatinvBatched is not supported on HIP platform."));
  }
};

// Half-precision dispatch. platform::float16 scalars and buffers are
// reinterpret_cast to rocblas_half before calling rocBLAS.
// NOTE(review): this relies on platform::float16 and rocblas_half sharing one
// 16-bit representation; nothing here static_asserts that.
template <>
struct CUBlas<platform::float16> {
  using float16 = platform::float16;

  // rocblas_hgemm: half-precision inputs, outputs and scalars, in rocBLAS
  // (column-major) convention.
  static void GEMM(rocblas_handle handle, rocblas_operation transa,
                   rocblas_operation transb, int m, int n, int k,
                   const float16 *alpha, const float16 *A, int lda,
                   const float16 *B, int ldb, const float16 *beta, float16 *C,
                   int ldc) {
    const auto *alpha_h = reinterpret_cast<const rocblas_half *>(alpha);
    const auto *beta_h = reinterpret_cast<const rocblas_half *>(beta);
    const auto *a_h = reinterpret_cast<const rocblas_half *>(A);
    const auto *b_h = reinterpret_cast<const rocblas_half *>(B);
    auto *c_h = reinterpret_cast<rocblas_half *>(C);
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_hgemm(
        handle, transa, transb, m, n, k, alpha_h, a_h, lda, b_h, ldb, beta_h,
        c_h, ldc));
  }

  // rocblas_hgemm_strided_batched: batchCount independent HGEMMs whose
  // operands sit strideA / strideB / strideC elements apart.
  static void GEMM_STRIDED_BATCH(rocblas_handle handle,
                                 rocblas_operation transa,
                                 rocblas_operation transb, int m, int n, int k,
                                 const float16 *alpha, const float16 *A,
                                 int lda, long long int strideA,  // NOLINT
                                 const float16 *B,                // NOLINT
                                 int ldb, long long int strideB,  // NOLINT
                                 const float16 *beta, float16 *C, int ldc,
                                 long long int strideC,  // NOLINT
                                 int batchCount) {
    const auto *alpha_h = reinterpret_cast<const rocblas_half *>(alpha);
    const auto *beta_h = reinterpret_cast<const rocblas_half *>(beta);
    const auto *a_h = reinterpret_cast<const rocblas_half *>(A);
    const auto *b_h = reinterpret_cast<const rocblas_half *>(B);
    auto *c_h = reinterpret_cast<rocblas_half *>(C);
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::rocblas_hgemm_strided_batched(
            handle, transa, transb, m, n, k, alpha_h, a_h, lda, strideA, b_h,
            ldb, strideB, beta_h, c_h, ldc, strideC, batchCount));
  }

  // Mixed-type GEMM through rocblas_gemm_ex, with caller-chosen element types
  // and compute type. C is passed as both the C input and the D output, so
  // the result overwrites C. The rocBLAS call runs inside
  // dev_ctx->TensorCoreCublasCallIfAvailable; how that selects a handle is
  // defined by CUDADeviceContext and is not visible here.
  template <typename... ARGS>
  static void GEMM_EX(platform::CUDADeviceContext *dev_ctx,
                      rocblas_operation transa, rocblas_operation transb, int m,
                      int n, int k, const void *alpha, const void *A,
                      rocblas_datatype Atype, int lda, const void *B,
                      rocblas_datatype Btype, int ldb, const void *beta,
                      void *C, rocblas_datatype Ctype, int ldc,
                      rocblas_datatype computeType) {
    const rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
    auto launch = [&](rocblas_handle handle) {
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_gemm_ex(
          handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
          beta, C, Ctype, ldc, C, Ctype, ldc, computeType, algo, 0, 0));
    };
    dev_ctx->TensorCoreCublasCallIfAvailable(launch);
  }
};

template <>
216
struct CUBlas<platform::complex<float>> {
217
  static void GEMV(rocblas_handle handle, rocblas_operation transa, int m,
218 219 220 221 222
                   int n, const platform::complex<float> *alpha,
                   const platform::complex<float> *A, int lda,
                   const platform::complex<float> *B, int ldb,
                   const platform::complex<float> *beta,
                   platform::complex<float> *C, int ldc) {
223 224 225 226 227 228 229 230 231
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_cgemv(
        handle, transa, m, n,
        reinterpret_cast<const rocblas_float_complex *>(alpha),
        reinterpret_cast<const rocblas_float_complex *>(A), lda,
        reinterpret_cast<const rocblas_float_complex *>(B), ldb,
        reinterpret_cast<const rocblas_float_complex *>(beta),
        reinterpret_cast<rocblas_float_complex *>(C), ldc));
  }

232 233 234 235
  static void AXPY(rocblas_handle handle, int n,
                   const platform::complex<float> *alpha,
                   const platform::complex<float> *X, const int incX,
                   platform::complex<float> *Y, const int incY) {
236 237 238 239 240 241 242 243 244
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_caxpy(
        handle, n, reinterpret_cast<const rocblas_float_complex *>(alpha),
        reinterpret_cast<const rocblas_float_complex *>(X), incX,
        reinterpret_cast<rocblas_float_complex *>(Y), incY));
  }

  static void GEMM_STRIDED_BATCH(rocblas_handle handle,
                                 rocblas_operation transa,
                                 rocblas_operation transb, int m, int n, int k,
245 246 247 248 249 250 251
                                 const platform::complex<float> *alpha,
                                 const platform::complex<float> *A, int lda,
                                 long long int strideA,              // NOLINT
                                 const platform::complex<float> *B,  // NOLINT
                                 int ldb, long long int strideB,     // NOLINT
                                 const platform::complex<float> *beta,
                                 platform::complex<float> *C, int ldc,
252 253 254 255 256 257 258 259 260 261 262 263 264 265 266
                                 long long int strideC,  // NOLINT
                                 int batchCount) {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::rocblas_cgemm_strided_batched(
            handle, transa, transb, m, n, k,
            reinterpret_cast<const rocblas_float_complex *>(alpha),
            reinterpret_cast<const rocblas_float_complex *>(A), lda, strideA,
            reinterpret_cast<const rocblas_float_complex *>(B), ldb, strideB,
            reinterpret_cast<const rocblas_float_complex *>(beta),
            reinterpret_cast<rocblas_float_complex *>(C), ldc, strideC,
            batchCount));
  }

  static void GEMM(rocblas_handle handle, rocblas_operation transa,
                   rocblas_operation transb, int m, int n, int k,
267 268 269 270 271
                   const platform::complex<float> *alpha,
                   const platform::complex<float> *A, int lda,
                   const platform::complex<float> *B, int ldb,
                   const platform::complex<float> *beta,
                   platform::complex<float> *C, int ldc) {
272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_cgemm(
        handle, transa, transb, m, n, k,
        reinterpret_cast<const rocblas_float_complex *>(alpha),
        reinterpret_cast<const rocblas_float_complex *>(A), lda,
        reinterpret_cast<const rocblas_float_complex *>(B), ldb,
        reinterpret_cast<const rocblas_float_complex *>(beta),
        reinterpret_cast<rocblas_float_complex *>(C), ldc));
  }

  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
  template <typename... ARGS>
  static void GEMM_EX(platform::CUDADeviceContext *dev_ctx,
                      rocblas_operation transa, rocblas_operation transb, int m,
                      int n, int k, const void *alpha, const void *A,
                      rocblas_datatype Atype, int lda, const void *B,
                      rocblas_datatype Btype, int ldb, const void *beta,
                      void *C, rocblas_datatype Ctype, int ldc,
                      rocblas_datatype computeType) {
    rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
    dev_ctx->TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_gemm_ex(
          handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
          beta, C, Ctype, ldc, C, Ctype, ldc, computeType, algo, 0, 0));
    });
  }
};

template <>
301
struct CUBlas<platform::complex<double>> {
302
  static void GEMV(rocblas_handle handle, rocblas_operation transa, int m,
303 304 305 306 307
                   int n, const platform::complex<double> *alpha,
                   const platform::complex<double> *A, int lda,
                   const platform::complex<double> *B, int ldb,
                   const platform::complex<double> *beta,
                   platform::complex<double> *C, int ldc) {
308 309 310 311 312 313 314 315 316
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_zgemv(
        handle, transa, m, n,
        reinterpret_cast<const rocblas_double_complex *>(alpha),
        reinterpret_cast<const rocblas_double_complex *>(A), lda,
        reinterpret_cast<const rocblas_double_complex *>(B), ldb,
        reinterpret_cast<const rocblas_double_complex *>(beta),
        reinterpret_cast<rocblas_double_complex *>(C), ldc));
  }

317 318 319 320
  static void AXPY(rocblas_handle handle, int n,
                   const platform::complex<double> *alpha,
                   const platform::complex<double> *X, const int incX,
                   platform::complex<double> *Y, const int incY) {
321 322 323 324 325 326 327 328 329
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_zaxpy(
        handle, n, reinterpret_cast<const rocblas_double_complex *>(alpha),
        reinterpret_cast<const rocblas_double_complex *>(X), incX,
        reinterpret_cast<rocblas_double_complex *>(Y), incY));
  }

  static void GEMM_STRIDED_BATCH(rocblas_handle handle,
                                 rocblas_operation transa,
                                 rocblas_operation transb, int m, int n, int k,
330 331 332 333 334 335 336
                                 const platform::complex<double> *alpha,
                                 const platform::complex<double> *A, int lda,
                                 long long int strideA,               // NOLINT
                                 const platform::complex<double> *B,  // NOLINT
                                 int ldb, long long int strideB,      // NOLINT
                                 const platform::complex<double> *beta,
                                 platform::complex<double> *C, int ldc,
337 338 339 340 341 342 343 344 345 346 347 348 349 350 351
                                 long long int strideC,  // NOLINT
                                 int batchCount) {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::rocblas_zgemm_strided_batched(
            handle, transa, transb, m, n, k,
            reinterpret_cast<const rocblas_double_complex *>(alpha),
            reinterpret_cast<const rocblas_double_complex *>(A), lda, strideA,
            reinterpret_cast<const rocblas_double_complex *>(B), ldb, strideB,
            reinterpret_cast<const rocblas_double_complex *>(beta),
            reinterpret_cast<rocblas_double_complex *>(C), ldc, strideC,
            batchCount));
  }

  static void GEMM(rocblas_handle handle, rocblas_operation transa,
                   rocblas_operation transb, int m, int n, int k,
352 353 354 355 356
                   const platform::complex<double> *alpha,
                   const platform::complex<double> *A, int lda,
                   const platform::complex<double> *B, int ldb,
                   const platform::complex<double> *beta,
                   platform::complex<double> *C, int ldc) {
357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_zgemm(
        handle, transa, transb, m, n, k,
        reinterpret_cast<const rocblas_double_complex *>(alpha),
        reinterpret_cast<const rocblas_double_complex *>(A), lda,
        reinterpret_cast<const rocblas_double_complex *>(B), ldb,
        reinterpret_cast<const rocblas_double_complex *>(beta),
        reinterpret_cast<rocblas_double_complex *>(C), ldc));
  }

  // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
  // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
  template <typename... ARGS>
  static void GEMM_EX(platform::CUDADeviceContext *dev_ctx,
                      rocblas_operation transa, rocblas_operation transb, int m,
                      int n, int k, const void *alpha, const void *A,
                      rocblas_datatype Atype, int lda, const void *B,
                      rocblas_datatype Btype, int ldb, const void *beta,
                      void *C, rocblas_datatype Ctype, int ldc,
                      rocblas_datatype computeType) {
    rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
    dev_ctx->TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_gemm_ex(
          handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb,
          beta, C, Ctype, ldc, C, Ctype, ldc, computeType, algo, 0, 0));
    });
  }
};

// Row-major C(M x N) = alpha * op(A) * op(B) + beta * C for densely packed
// operands (ldc = N). rocBLAS is column-major, so the product is issued as
// C^T = op(B)^T * op(A)^T: B is passed first and the M/N extents are swapped.
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
                                             CBLAS_TRANSPOSE transB, int M,
                                             int N, int K, T alpha, const T *A,
                                             const T *B, T beta, T *C) const {
  const bool a_plain = transA == CblasNoTrans;
  const bool b_plain = transB == CblasNoTrans;
  // Leading dimensions of the row-major inputs, i.e. their column counts.
  const int lda = a_plain ? K : M;
  const int ldb = b_plain ? N : K;
  const rocblas_operation op_a =
      a_plain ? rocblas_operation_none : rocblas_operation_transpose;
  const rocblas_operation op_b =
      b_plain ? rocblas_operation_none : rocblas_operation_transpose;
  auto launch = [&](rocblas_handle handle) {
    CUBlas<T>::GEMM(handle, op_b, op_a, N, M, K, &alpha, B, ldb, A, lda,
                    &beta, C, N);
  };
  context_.CublasCall(launch);
}

// Half-precision row-major GEMM for densely packed operands (ldc = N).
// Issued through CUBlas<float16>::GEMM_EX with f16 storage and f32 compute.
// Operands are swapped (B first, N/M exchanged) to map row-major onto
// rocBLAS's column-major convention.
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
    platform::float16 alpha, const platform::float16 *A,
    const platform::float16 *B, platform::float16 beta,
    platform::float16 *C) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  rocblas_operation cuTransA = (transA == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  rocblas_operation cuTransB = (transB == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;

  // TODO(kexinzhao): add processing code for compute capability < 53 case
  PADDLE_ENFORCE_GE(
      context_.GetComputeCapability(), 53,
      platform::errors::InvalidArgument(
          "cublas fp16 gemm requires GPU compute capability >= 53, "
          "but received %d",
          context_.GetComputeCapability()));

  // alpha/beta are widened to float to match the f32 compute type below.
  float h_alpha = static_cast<float>(alpha);
  float h_beta = static_cast<float>(beta);

  // GEMM_EX takes a non-const context pointer; this method is const.
  auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
  CUBlas<platform::float16>::GEMM_EX(
      &cuda_ctx, cuTransB, cuTransA, N, M, K, &h_alpha, B,
      rocblas_datatype_f16_r, ldb, A, rocblas_datatype_f16_r, lda, &h_beta, C,
      rocblas_datatype_f16_r, N, rocblas_datatype_f32_r);
}

// Single-precision complex row-major GEMM for densely packed operands
// (ldc = N), issued through CUBlas<complex<float>>::GEMM_EX with f32_c
// storage and compute. Operands are swapped (B first, N/M exchanged) to map
// row-major onto rocBLAS's column-major convention.
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
    platform::complex<float> alpha, const platform::complex<float> *A,
    const platform::complex<float> *B, platform::complex<float> beta,
    platform::complex<float> *C) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  rocblas_operation cuTransA = (transA == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  rocblas_operation cuTransB = (transB == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;

  // TODO(kexinzhao): add processing code for compute capability < 53 case
  // NOTE(review): this check mirrors the fp16 path; confirm complex64 GEMM
  // actually needs it.
  PADDLE_ENFORCE_GE(
      context_.GetComputeCapability(), 53,
      platform::errors::InvalidArgument(
          "cublas complex64 gemm requires GPU compute capability >= 53, "
          "but received %d",
          context_.GetComputeCapability()));

  // Scalars are repacked as thrust::complex<float> and passed as the
  // f32_c alpha/beta of rocblas_gemm_ex.
  thrust::complex<float> c_alpha =
      thrust::complex<float>(alpha.real, alpha.imag);
  thrust::complex<float> c_beta = thrust::complex<float>(beta.real, beta.imag);

  // GEMM_EX takes a non-const context pointer; this method is const.
  auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
  CUBlas<platform::complex<float>>::GEMM_EX(
      &cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B,
      rocblas_datatype_f32_c, ldb, A, rocblas_datatype_f32_c, lda, &c_beta, C,
      rocblas_datatype_f32_c, N, rocblas_datatype_f32_c);
}

// Double-precision complex row-major GEMM for densely packed operands
// (ldc = N), issued through CUBlas<complex<double>>::GEMM_EX with f64_c
// storage and compute. Operands are swapped (B first, N/M exchanged) to map
// row-major onto rocBLAS's column-major convention.
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
    platform::complex<double> alpha, const platform::complex<double> *A,
    const platform::complex<double> *B, platform::complex<double> beta,
    platform::complex<double> *C) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  rocblas_operation cuTransA = (transA == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  rocblas_operation cuTransB = (transB == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;

  // TODO(kexinzhao): add processing code for compute capability < 53 case
  // NOTE(review): this check mirrors the fp16 path; confirm complex128 GEMM
  // actually needs it.
  PADDLE_ENFORCE_GE(
      context_.GetComputeCapability(), 53,
      platform::errors::InvalidArgument(
          "cublas complex128 gemm requires GPU compute capability >= 53, "
          "but received %d",
          context_.GetComputeCapability()));

  // Scalars are repacked as thrust::complex<double> and passed as the
  // f64_c alpha/beta of rocblas_gemm_ex.
  thrust::complex<double> c_alpha =
      thrust::complex<double>(alpha.real, alpha.imag);
  thrust::complex<double> c_beta =
      thrust::complex<double>(beta.real, beta.imag);

  // GEMM_EX takes a non-const context pointer; this method is const.
  auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
  CUBlas<platform::complex<double>>::GEMM_EX(
      &cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B,
      rocblas_datatype_f64_c, ldb, A, rocblas_datatype_f64_c, lda, &c_beta, C,
      rocblas_datatype_f64_c, N, rocblas_datatype_f64_c);
}

// Row-major GEMM with caller-supplied leading dimensions:
// C = alpha * op(A) * op(B) + beta * C. The call is issued in column-major
// form (B first, M/N swapped) so rocBLAS computes C^T = op(B)^T * op(A)^T.
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
                                             int N, int K, T alpha, const T *A,
                                             int lda, const T *B, int ldb,
                                             T beta, T *C, int ldc) const {
  const rocblas_operation op_a =
      transA ? rocblas_operation_transpose : rocblas_operation_none;
  const rocblas_operation op_b =
      transB ? rocblas_operation_transpose : rocblas_operation_none;
  auto launch = [&](rocblas_handle handle) {
    CUBlas<T>::GEMM(handle, op_b, op_a, N, M, K, &alpha, B, ldb, A, lda,
                    &beta, C, ldc);
  };
  context_.CublasCall(launch);
}

// Half-precision GEMM with caller-supplied leading dimensions. This overload
// calls CUBlas<float16>::GEMM (rocblas_hgemm) directly, so alpha and beta are
// passed in half precision. The call is issued in column-major form (B first,
// M/N swapped).
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
    bool transA, bool transB, int M, int N, int K, platform::float16 alpha,
    const platform::float16 *A, int lda, const platform::float16 *B, int ldb,
    platform::float16 beta, platform::float16 *C, int ldc) const {
  const rocblas_operation op_a =
      transA ? rocblas_operation_transpose : rocblas_operation_none;
  const rocblas_operation op_b =
      transB ? rocblas_operation_transpose : rocblas_operation_none;
  auto launch = [&](rocblas_handle handle) {
    CUBlas<platform::float16>::GEMM(handle, op_b, op_a, N, M, K, &alpha, B,
                                    ldb, A, lda, &beta, C, ldc);
  };
  context_.CublasCall(launch);
}

// y = alpha * x + y over n elements, both vectors with unit stride.
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::AXPY(int n, T alpha, const T *x,
                                             T *y) const {
  auto launch = [&](rocblas_handle handle) {
    const int unit_stride = 1;
    CUBlas<T>::AXPY(handle, n, &alpha, x, unit_stride, y, unit_stride);
  };
  context_.CublasCall(launch);
}

// x = alpha * x over n elements with unit stride.
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::SCAL(int n, const T alpha, T *x) const {
  auto launch = [&](rocblas_handle handle) {
    const int unit_stride = 1;
    CUBlas<T>::SCAL(handle, n, &alpha, x, unit_stride);
  };
  context_.CublasCall(launch);
}

// Copies n elements from x to y, both with unit stride.
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::VCOPY(int n, const T *x, T *y) const {
  auto launch = [&](rocblas_handle handle) {
    const int unit_stride = 1;
    CUBlas<T>::VCOPY(handle, n, x, unit_stride, y, unit_stride);
  };
  context_.CublasCall(launch);
}

// C = alpha * op(A) * B + beta * C, with A stored row-major as M x N.
// rocBLAS sees that buffer as a column-major N x M matrix (leading dimension
// N), so the requested transposition is inverted before the call.
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
                                             T alpha, const T *A, const T *B,
                                             T beta, T *C) const {
  const rocblas_operation op =
      trans_a ? rocblas_operation_none : rocblas_operation_transpose;
  auto launch = [&](rocblas_handle handle) {
    CUBlas<T>::GEMV(handle, op, N, M, &alpha, A, N, B, 1, &beta, C, 1);
  };
  context_.CublasCall(launch);
}

// Half-precision GEMV. No half GEMV routine is used here, so the product is
// expressed as a GEMM with one unit-sized dimension.
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMV(
    bool trans_a, int M, int N, platform::float16 alpha,
    const platform::float16 *A, const platform::float16 *B,
    platform::float16 beta, platform::float16 *C) const {
  if (!trans_a) {
    // C (M x 1) = A (M x N) * B (N x 1)
    this->template GEMM<platform::float16>(CblasNoTrans, CblasNoTrans, M, 1, N,
                                           alpha, A, B, beta, C);
    return;
  }
  // C (1 x N) = B (1 x M) * A (M x N), i.e. A^T * B written as a row vector.
  this->template GEMM<platform::float16>(CblasNoTrans, CblasNoTrans, 1, N, M,
                                         alpha, B, A, beta, C);
}

// Strided batched row-major GEMM: for each of batchCount entries,
// C_i = alpha * op(A_i) * op(B_i) + beta * C_i, where A_i / B_i sit strideA /
// strideB elements apart and the M x N outputs are packed back to back.
// Issued in column-major form (B first, M/N swapped).
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM(
    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
    T alpha, const T *A, const T *B, T beta, T *C, int batchCount,
    int64_t strideA, int64_t strideB) const {
  // Note that cublas follows fortran order, so the order is different from
  // the cblas convention.
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  int ldc = N;
  rocblas_operation cuTransA = (transA == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  rocblas_operation cuTransB = (transB == CblasNoTrans)
                                   ? rocblas_operation_none
                                   : rocblas_operation_transpose;
  // Widen before multiplying: `M * N` evaluated in int overflows (undefined
  // behavior) once the output matrix exceeds INT_MAX elements, even though
  // the result is stored in an int64_t.
  const int64_t strideC = static_cast<int64_t>(M) * N;
  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<T>::GEMM_STRIDED_BATCH(handle, cuTransB, cuTransA, N, M, K, &alpha,
                                  B, ldb, strideB, A, lda, strideA, &beta, C,
                                  ldc, strideC, batchCount);
  });
}

// Pointer-array batched GEMM: each batch entry i computes
// C[i] = alpha * op(A[i]) * op(B[i]) + beta * C[i] as an independent GEMM
// call, in order.
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM(
    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
    T alpha, const T **A, const T **B, T beta, T **C, int batchCount) const {
  for (int i = 0; i < batchCount; i++) {
    const T *a_i = A[i];
    const T *b_i = B[i];
    T *c_i = C[i];
    this->template GEMM<T>(transA, transB, M, N, K, alpha, a_i, b_i, beta,
                           c_i);
  }
}

// Half-precision pointer-array batched GEMM: one GEMM<float16> call per batch
// entry, in order.
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::BatchedGEMM(
    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
    platform::float16 alpha, const platform::float16 **A,
    const platform::float16 **B, platform::float16 beta, platform::float16 **C,
    int batchCount) const {
  for (int i = 0; i < batchCount; i++) {
    const platform::float16 *a_i = A[i];
    const platform::float16 *b_i = B[i];
    platform::float16 *c_i = C[i];
    this->template GEMM<platform::float16>(transA, transB, M, N, K, alpha, a_i,
                                           b_i, beta, c_i);
  }
}

// Triangular solve for row-major B (M x N), result written in place into B.
// The row-major system `op(A) X = alpha B` is handed to column-major rocBLAS
// as `X' op(A') = alpha B'` (' = transpose): side and fill are mirrored and
// the M/N extents swapped, while the transpose flag carries over unchanged.
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,
                                             CBLAS_TRANSPOSE transA,
                                             CBLAS_DIAG diag, int M, int N,
                                             T alpha, const T *A, int lda, T *B,
                                             int ldb) const {
  const rocblas_side side_cm =
      (side == CblasLeft) ? rocblas_side_right : rocblas_side_left;
  const rocblas_fill fill_cm =
      (uplo == CblasLower) ? rocblas_fill_upper : rocblas_fill_lower;
  // NOTE: only plain transpose is mapped; complex element types would want
  // the conjugate transpose (CUBLAS_OP_C in cuBLAS terms).
  const rocblas_operation op_a = (transA == CblasNoTrans)
                                     ? rocblas_operation_none
                                     : rocblas_operation_transpose;
  const rocblas_diagonal diag_cm =
      (diag == CblasUnit) ? rocblas_diagonal_unit : rocblas_diagonal_non_unit;

  auto launch = [&](rocblas_handle handle) {
    CUBlas<T>::TRSM(handle, side_cm, fill_cm, op_a, diag_cm, N, M, &alpha, A,
                    lda, B, ldb);
  };
  context_.CublasCall(launch);
}

// Batched LU factorization of batch_size square n x n matrices (lda = n),
// forwarded to CUBlas<T>::GETRF_BATCH. For the float and double
// specializations in this file that routine throws Unimplemented on HIP.
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGETRF(int n, T **a, int *ipiv,
                                                     int *info,
                                                     int batch_size) const {
  auto launch = [&](rocblas_handle handle) {
    const int lda = n;
    CUBlas<T>::GETRF_BATCH(handle, n, a, lda, ipiv, info, batch_size);
  };
  context_.CublasCall(launch);
}

// Batched inversion from LU factors: reads batch_size n x n factorizations
// from `a` (with pivots `ipiv`) and writes the inverses to `a_inv`. Forwarded
// to CUBlas<T>::GETRI_BATCH, which for float/double in this file throws
// Unimplemented on HIP.
// Rejects a_inv == a up front because the operation cannot run in place.
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGETRI(int n, const T **a,
                                                     const int *ipiv, T **a_inv,
                                                     int *info,
                                                     int batch_size) const {
  // The message names the routine this path actually dispatches to
  // (getriBatched, not getrfBatched).
  PADDLE_ENFORCE_NE(
      a_inv, a,
      platform::errors::InvalidArgument(
          "cuBLAS function 'cublas<S/D>getriBatched' cannot be executed "
          "in-place. The memory space of output matrix (address: %p) cannot "
          "overlap memory space of input matrix (address: %p).",
          a_inv, a));
  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<T>::GETRI_BATCH(handle, n, a, n, ipiv, a_inv, n, info, batch_size);
  });
}

// Batched inversion of batch_size square n x n matrices (all leading
// dimensions n), forwarded to CUBlas<T>::MATINV_BATCH. For the float and
// double specializations in this file that routine throws Unimplemented.
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedMatInv(int n, const T **a,
                                                      T **a_inv, int *info,
                                                      int batch_size) const {
  auto launch = [&](rocblas_handle handle) {
    const int ld = n;
    CUBlas<T>::MATINV_BATCH(handle, n, a, ld, a_inv, ld, info, batch_size);
  };
  context_.CublasCall(launch);
}

}  // namespace math
}  // namespace operators
}  // namespace paddle