/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "hl_cuda_cublas.h"

#include <dlfcn.h>
#include <sys/time.h>
#include <mutex>

#include "hl_cuda.h"
#include "hl_thread.ph"
#include "paddle/utils/DynamicLoader.h"
#include "paddle/utils/Logging.h"

namespace dynload {

std::once_flag cublas_dso_flag;
25
void *cublas_dso_handle = nullptr;
Z
zhangjinchao01 已提交
26 27 28 29 30 31 32 33 34

/**
 * The following macro definition can generate structs
 * (for each function) to dynamic load cublas routine
 * via operator overloading.
 *
 * note: default dynamic linked libs
 */
#ifdef PADDLE_USE_DSO
35 36 37 38 39 40 41 42 43
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name)                                       \
  struct DynLoad__##__name {                                                   \
    template <typename... Args>                                                \
    cublasStatus_t operator()(Args... args) {                                  \
      typedef cublasStatus_t (*cublasFunc)(Args...);                           \
      std::call_once(cublas_dso_flag, GetCublasDsoHandle, &cublas_dso_handle); \
      void *p_##__name = dlsym(cublas_dso_handle, #__name);                    \
      return reinterpret_cast<cublasFunc>(p_##__name)(args...);                \
    }                                                                          \
Z
zhangjinchao01 已提交
44 45
  } __name;  // struct DynLoad__##__name
#else
46 47 48 49 50 51
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name)      \
  struct DynLoad__##__name {                  \
    template <typename... Args>               \
    cublasStatus_t operator()(Args... args) { \
      return __name(args...);                 \
    }                                         \
Z
zhangjinchao01 已提交
52 53 54
  } __name;  // struct DynLoad__##__name
#endif

55
#define DYNAMIC_LOAD_CUBLAS_V2_WRAP(__name) DYNAMIC_LOAD_CUBLAS_WRAP(__name)
Z
zhangjinchao01 已提交
56 57

// include all needed cublas functions in HPPL
L
Luo Tao 已提交
58 59 60 61 62 63 64 65
// clang-format off
#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
  __macro(cublasSgemv)                    \
  __macro(cublasDgemv)                    \
  __macro(cublasSgemm)                    \
  __macro(cublasDgemm)                    \
  __macro(cublasSgeam)                    \
  __macro(cublasDgeam)                    \
Z
zhangjinchao01 已提交
66 67 68 69 70 71 72 73 74 75

DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasCreate)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasDestroy)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetStream)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetPointerMode)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasGetPointerMode)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmBatched)
L
lzhao4ever 已提交
76 77
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetrfBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetriBatched)
78 79
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetrfBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetriBatched)
Z
zhangjinchao01 已提交
80 81 82 83 84 85 86 87
CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP)

#undef DYNAMIC_LOAD_CUBLAS_WRAP
#undef DYNAMIC_LOAD_CUBLAS_V2_WRAP
#undef CUBLAS_BLAS_ROUTINE_EACH

} /* namespace dynload */

L
Luo Tao 已提交
88
// clang-format on
89
// Precision selection for the cuBLAS routines used below: PADDLE_TYPE_DOUBLE
// picks the double-precision (D*) variants, otherwise the single-precision
// (S*) variants are used.
#ifdef PADDLE_TYPE_DOUBLE
#define CUBLAS_GEAM dynload::cublasDgeam
#define CUBLAS_GEMV dynload::cublasDgemv
#define CUBLAS_GEMM dynload::cublasDgemm
#define CUBLAS_GETRF dynload::cublasDgetrfBatched
#define CUBLAS_GETRI dynload::cublasDgetriBatched
#else
#define CUBLAS_GEAM dynload::cublasSgeam
#define CUBLAS_GEMV dynload::cublasSgemv
#define CUBLAS_GEMM dynload::cublasSgemm
#define CUBLAS_GETRF dynload::cublasSgetrfBatched
#define CUBLAS_GETRI dynload::cublasSgetriBatched
#endif  // PADDLE_TYPE_DOUBLE

/**
 * Translate a cuBLAS status code into a human-readable message for logging.
 *
 * @param status  value returned by a cuBLAS routine.
 * @return        a static string, never null. Any code not listed below maps
 *                to "[cublas status]: unknown error".
 */
const char *hl_cublas_get_error_string(cublasStatus_t status) {
  switch (status) {
    case CUBLAS_STATUS_SUCCESS:
      return "[cublas status]: success";
    case CUBLAS_STATUS_NOT_INITIALIZED:
      return "[cublas status]: not initialized";
    case CUBLAS_STATUS_ALLOC_FAILED:
      return "[cublas status]: allocate failed";
    case CUBLAS_STATUS_INVALID_VALUE:
      return "[cublas status]: invalid value";
    case CUBLAS_STATUS_ARCH_MISMATCH:
      return "[cublas status]: arch mismatch";
    case CUBLAS_STATUS_MAPPING_ERROR:
      return "[cublas status]: mapping error";
    case CUBLAS_STATUS_EXECUTION_FAILED:
      return "[cublas status]: execution failed";
    case CUBLAS_STATUS_INTERNAL_ERROR:
      return "[cublas status]: internal error";
// The following codes only exist in newer toolkits; guard them so older
// toolkits still compile (they then fall through to "unknown error").
#if defined(CUDART_VERSION) && CUDART_VERSION >= 6000
    case CUBLAS_STATUS_NOT_SUPPORTED:
      return "[cublas status]: not supported";
#endif
#if defined(CUDART_VERSION) && CUDART_VERSION >= 6050
    case CUBLAS_STATUS_LICENSE_ERROR:
      return "[cublas status]: license error";
#endif
    default:
      return "[cublas status]: unknown error";
  }
}

/**
 * Check the status returned by a built-in cuBLAS call. On failure it aborts
 * through CHECK_EQ with the decoded status string. More details can be
 * appended with <<, e.g.  CHECK_CUBLAS(f(...)) << "while doing X";
 *
 * The status is staged in g_cublasStat, which is thread_local. These helpers
 * operate on t_resource, which comes from hl_thread.ph and appears to be
 * per-thread state, so they can run on several threads at once. With a single
 * shared global, one thread could overwrite another thread's status between
 * the assignment and the check.
 *
 * NOTE: the macro expands to two statements. Do not use it as the unbraced
 * body of an if/else/for/while.
 */
thread_local cublasStatus_t g_cublasStat;
#define CHECK_CUBLAS(cublas_func)               \
  g_cublasStat = cublas_func;                   \
  CHECK_EQ(CUBLAS_STATUS_SUCCESS, g_cublasStat) \
      << "Cublas Error: " << hl_cublas_get_error_string(g_cublasStat) << " "

/**
 * Create a cuBLAS handle into *cublas_handle and bind it to `stream`.
 * Aborts (via CHECK_CUBLAS) if either cuBLAS call fails.
 *
 * @param cublas_handle  out-parameter receiving the new handle; must not be
 *                       null.
 * @param stream         CUDA stream the handle's work is issued on.
 */
void hl_cublas_init(cublasHandle_t *cublas_handle, cudaStream_t stream) {
  // The handle is written by cublasCreate and dereferenced below.
  CHECK_NOTNULL(cublas_handle);

  CHECK_CUBLAS(dynload::cublasCreate(cublas_handle))
      << "[cublas init] Cublas create handle failed!";

  CHECK_CUBLAS(dynload::cublasSetStream(*cublas_handle, stream))
      << "[cublas init] Cublas set stream failed!";
}

/**
 * Transpose a row-major dimM x dimN matrix A (row stride lda) into the
 * row-major dimN x dimM matrix C (row stride ldc) with one cuBLAS geam call.
 */
void hl_matrix_transpose(
    real *A_d, real *C_d, int dimM, int dimN, int lda, int ldc) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(C_d);

  // geam computes C = one * op(A) + zero * op(B). cuBLAS is column-major, so
  // the row-major input is seen as A^T and op(A) = T restores A... transposed
  // relative to the row-major view. That is exactly the transpose we want.
  // B is passed as nullptr with ld = dimM. This relies on cuBLAS not reading
  // B when the beta scalar is 0.
  real one = 1.0;
  real zero = 0.0;
  CHECK_CUBLAS(CUBLAS_GEAM(t_resource.handle,
                           CUBLAS_OP_T,
                           CUBLAS_OP_N,
                           dimM,
                           dimN,
                           &one,
                           A_d,
                           lda,
                           &zero,
                           nullptr,
                           dimM,
                           C_d,
                           ldc));
  CHECK_SYNC("hl_matrix_transpose failed");
}

/**
 * Transpose a densely packed row-major dimM x dimN matrix A into C
 * (dimN x dimM). Each matrix's row stride equals its width.
 */
void hl_matrix_transpose(real *A_d, real *C_d, int dimM, int dimN) {
  const int packed_lda = dimN;  // A rows hold dimN elements
  const int packed_ldc = dimM;  // C rows hold dimM elements
  hl_matrix_transpose(A_d, C_d, dimM, dimN, packed_lda, packed_ldc);
}

/**
 * Compute C = inverse(A) for a dimN x dimN matrix by solving A * X = I with
 * the cuBLAS batched LU routines, using a batch size of one.
 *
 * The getrf step runs in place on A_d, so A_d's contents are overwritten by
 * the factorization. The inverse is written to C_d. Aborts via LOG(FATAL) if
 * either step reports a nonzero info value.
 *
 * @param lda  leading dimension of A_d.
 * @param ldc  leading dimension of C_d.
 */
void hl_matrix_inverse(real *A_d, real *C_d, int dimN, int lda, int ldc) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(C_d);

  // The batched routines take a device-resident array of matrix pointers, so
  // the single pointer A_d is itself copied into device memory.
  real **lu_array_d = (real **)hl_malloc_device(sizeof(real *));
  hl_memcpy(lu_array_d, &A_d, sizeof(real *));

  int *pivots_d = (int *)hl_malloc_device(dimN * sizeof(int));
  // Per-thread scratch buffer reused for the one-int info output.
  // Assumes gpu_mem is device memory with room for at least one int -- TODO
  // confirm.
  int *status_d = (int *)t_resource.gpu_mem;
  int status_h;

  /* Step 1: LU-factorize A.
     Note: the batched getrf is designed for many small matrices. A
     non-batched path might perform better for a single large matrix. */
  CHECK_CUBLAS(CUBLAS_GETRF(
      t_resource.handle, dimN, lu_array_d, lda, pivots_d, status_d, 1));

  hl_memcpy(&status_h, status_d, sizeof(int));
  if (status_h != 0) {
    LOG(FATAL) << "Factorization of matrix failed: matrix may be singular.\n";
  }

  /* Step 2: Invert A from its LU factors and pivots into C. */
  real **inv_array_d = (real **)hl_malloc_device(sizeof(real *));
  hl_memcpy(inv_array_d, &C_d, sizeof(real *));

  CHECK_CUBLAS(CUBLAS_GETRI(t_resource.handle,
                            dimN,
                            (const real **)lu_array_d,
                            lda,
                            pivots_d,
                            inv_array_d,
                            ldc,
                            status_d,
                            1));

  hl_memcpy(&status_h, status_d, sizeof(int));
  if (status_h != 0) {
    LOG(FATAL) << "Inversion of matrix failed: matrix may be singular.\n";
  }

  hl_free_mem_device(lu_array_d);
  hl_free_mem_device(pivots_d);
  hl_free_mem_device(inv_array_d);

  CHECK_SYNC("hl_matrix_inverse failed");
}

void hl_matrix_mul(real *A_d,
                   hl_trans_op_t transa,
                   real *B_d,
                   hl_trans_op_t transb,
Z
zhangjinchao01 已提交
229
                   real *C_d,
230 231 232 233 234 235 236 237
                   int dimM,
                   int dimN,
                   int dimK,
                   real alpha,
                   real beta,
                   int lda,
                   int ldb,
                   int ldc) {
Z
zhangjinchao01 已提交
238 239 240 241 242 243 244
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(B_d);
  CHECK_NOTNULL(C_d);

  if (dimN == 1 && dimM != 1 && dimK != 1 && transb == HPPL_OP_N) {
    int m = (transa == HPPL_OP_N) ? dimM : dimK;
    int n = (transa == HPPL_OP_N) ? dimK : dimM;
245 246
    hl_matrix_mul_vector(
        A_d, transa, B_d, C_d, m, n, alpha, beta, lda, ldb, ldc);
Z
zhangjinchao01 已提交
247 248 249 250 251 252 253
    return;
  }

  if (dimM == 1 && dimN != 1 && dimK != 1 && transa == HPPL_OP_N) {
    int m = (transb == HPPL_OP_N) ? dimK : dimN;
    int n = (transb == HPPL_OP_N) ? dimN : dimK;
    hl_trans_op_t trans = (transb == HPPL_OP_N) ? HPPL_OP_T : HPPL_OP_N;
254
    hl_matrix_mul_vector(B_d, trans, A_d, C_d, m, n, alpha, beta, ldb, 1, 1);
Z
zhangjinchao01 已提交
255 256 257 258 259 260 261 262
    return;
  }

  cublasStatus_t stat;
  if ((HPPL_OP_N == transa) && (HPPL_OP_N == transb)) {
    stat = CUBLAS_GEMM(t_resource.handle,
                       CUBLAS_OP_N,
                       CUBLAS_OP_N,
263 264 265 266 267 268 269 270 271 272 273
                       dimN,
                       dimM,
                       dimK,
                       &alpha,
                       B_d,
                       ldb,
                       A_d,
                       lda,
                       &beta,
                       C_d,
                       ldc);
Z
zhangjinchao01 已提交
274 275 276 277
  } else if ((HPPL_OP_T == transa) && (HPPL_OP_N == transb)) {
    stat = CUBLAS_GEMM(t_resource.handle,
                       CUBLAS_OP_N,
                       CUBLAS_OP_T,
278 279 280 281 282 283 284 285 286 287 288
                       dimN,
                       dimM,
                       dimK,
                       &alpha,
                       B_d,
                       ldb,
                       A_d,
                       lda,
                       &beta,
                       C_d,
                       ldc);
Z
zhangjinchao01 已提交
289 290 291 292
  } else if ((HPPL_OP_N == transa) && (HPPL_OP_T == transb)) {
    stat = CUBLAS_GEMM(t_resource.handle,
                       CUBLAS_OP_T,
                       CUBLAS_OP_N,
293 294 295 296 297 298 299 300 301 302 303
                       dimN,
                       dimM,
                       dimK,
                       &alpha,
                       B_d,
                       ldb,
                       A_d,
                       lda,
                       &beta,
                       C_d,
                       ldc);
Z
zhangjinchao01 已提交
304 305 306
  } else {
    LOG(FATAL) << "parameter transa error!";
  }
307
  CHECK_EQ(stat, CUBLAS_STATUS_SUCCESS) << hl_cublas_get_error_string(stat);
Z
zhangjinchao01 已提交
308 309 310
  CHECK_SYNC("hl_matrix_mul failed");
}

311 312 313 314
void hl_matrix_mul(real *A_d,
                   hl_trans_op_t transa,
                   real *B_d,
                   hl_trans_op_t transb,
Z
zhangjinchao01 已提交
315
                   real *C_d,
316 317 318 319 320
                   int dimM,
                   int dimN,
                   int dimK,
                   real alpha,
                   real beta) {
Z
zhangjinchao01 已提交
321 322 323 324
  int lda = (HPPL_OP_N == transa) ? dimK : dimM;
  int ldb = (HPPL_OP_N == transb) ? dimN : dimK;
  int ldc = dimN;

325 326 327 328 329 330 331 332 333 334 335 336 337
  hl_matrix_mul(A_d,
                transa,
                B_d,
                transb,
                C_d,
                dimM,
                dimN,
                dimK,
                alpha,
                beta,
                lda,
                ldb,
                ldc);
Z
zhangjinchao01 已提交
338 339
}

340 341 342 343 344 345 346 347 348 349 350
/**
 * Matrix-vector product C = alpha * op(A) * B + beta * C.
 *
 * A is stored row-major as dimM x dimN with leading dimension lda. B and C
 * are vectors with element strides incb and incc. trans must be HPPL_OP_N or
 * HPPL_OP_T; any other value aborts.
 */
void hl_matrix_mul_vector(real *A_d,
                          hl_trans_op_t trans,
                          real *B_d,
                          real *C_d,
                          int dimM,
                          int dimN,
                          real alpha,
                          real beta,
                          int lda,
                          int incb,
                          int incc) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(B_d);
  CHECK_NOTNULL(C_d);

  // cuBLAS sees the row-major dimM x dimN matrix as a column-major
  // dimN x dimM matrix, i.e. A^T. The requested flag is therefore inverted.
  auto gemv_op = CUBLAS_OP_N;
  switch (trans) {
    case HPPL_OP_N:
      gemv_op = CUBLAS_OP_T;
      break;
    case HPPL_OP_T:
      gemv_op = CUBLAS_OP_N;
      break;
    default:
      LOG(FATAL) << "parameter transa error!";
  }

  cublasStatus_t status = CUBLAS_GEMV(t_resource.handle,
                                      gemv_op,
                                      dimN,
                                      dimM,
                                      &alpha,
                                      A_d,
                                      lda,
                                      B_d,
                                      incb,
                                      &beta,
                                      C_d,
                                      incc);
  CHECK_EQ(status, CUBLAS_STATUS_SUCCESS)
      << hl_cublas_get_error_string(status);
  CHECK_SYNC("hl_matrix_mul_vector");
}

/**
 * Convenience overload of hl_matrix_mul_vector for a densely packed
 * row-major A (dimM x dimN, so its leading dimension is dimN) and
 * contiguous B and C vectors.
 */
void hl_matrix_mul_vector(real *A_d,
                          hl_trans_op_t trans,
                          real *B_d,
                          real *C_d,
                          int dimM,
                          int dimN,
                          real alpha,
                          real beta) {
  const int packed_lda = dimN;
  const int unit_stride = 1;
  hl_matrix_mul_vector(A_d,
                       trans,
                       B_d,
                       C_d,
                       dimM,
                       dimN,
                       alpha,
                       beta,
                       packed_lda,
                       unit_stride,
                       unit_stride);
}