hl_cuda_cublas.cc 9.7 KB
Newer Older
Z
zhangjinchao01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */


#include <sys/time.h>
#include <mutex>
#include "hl_cuda_cublas.h"
#include "hl_thread.ph"
#include "hl_dso_loader.h"
#include "paddle/utils/Logging.h"

namespace dynload {

std::once_flag cublas_dso_flag;
void* cublas_dso_handle = nullptr;

/**
 * DYNAMIC_LOAD_CUBLAS_WRAP(__name) defines a functor struct DynLoad__##__name
 * and an object named `__name` in this namespace. Call sites can therefore
 * write dynload::__name(args...), and the arguments are forwarded unchanged
 * to the cuBLAS routine of the same name.
 *
 * With PADDLE_USE_DSO defined:
 *   - the cuBLAS shared library handle is obtained once via GetCublasDsoHandle
 *     (guarded by std::call_once), and the routine is looked up with dlsym;
 *   - the looked-up address is cached in a function-local static (one per
 *     operator() instantiation), so dlsym is not repeated on every call;
 *   - a failed lookup is reported through LOG(FATAL) instead of calling
 *     through a null function pointer.
 *
 * Without PADDLE_USE_DSO, the wrapper calls the routine directly, so the
 * cuBLAS symbols must resolve at link time.
 */
#ifdef PADDLE_USE_DSO
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name)                                  \
  struct DynLoad__##__name {                                              \
    template <typename... Args>                                           \
    cublasStatus_t operator()(Args... args) {                             \
      typedef cublasStatus_t (*cublasFunc)(Args...);                      \
      std::call_once(cublas_dso_flag, GetCublasDsoHandle,                 \
                     &cublas_dso_handle);                                 \
      static void* p_##__name = dlsym(cublas_dso_handle, #__name);        \
      if (p_##__name == nullptr) {                                        \
        LOG(FATAL) << "Failed to load cublas routine: " << #__name;       \
      }                                                                   \
      return reinterpret_cast<cublasFunc>(p_##__name)(args...);           \
    }                                                                     \
  } __name;  // struct DynLoad__##__name
#else
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name)                                  \
  struct DynLoad__##__name {                                              \
    template <typename... Args>                                           \
    cublasStatus_t operator()(Args... args) {                             \
      return __name(args...);                                             \
    }                                                                     \
  } __name;  // struct DynLoad__##__name
#endif

// Alias kept for the v2-API entry points; it wraps exactly like
// DYNAMIC_LOAD_CUBLAS_WRAP.
#define DYNAMIC_LOAD_CUBLAS_V2_WRAP(__name) \
  DYNAMIC_LOAD_CUBLAS_WRAP(__name)

// All BLAS routines (single and double precision) used in HPPL.
#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
  __macro(cublasSgemv)                    \
  __macro(cublasDgemv)                    \
  __macro(cublasSgemm)                    \
  __macro(cublasDgemm)                    \
  __macro(cublasSgeam)                    \
  __macro(cublasDgeam)                    \

DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasCreate)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasDestroy)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetStream)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetPointerMode)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasGetPointerMode)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmBatched)
CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP)

#undef DYNAMIC_LOAD_CUBLAS_WRAP
#undef DYNAMIC_LOAD_CUBLAS_V2_WRAP
#undef CUBLAS_BLAS_ROUTINE_EACH

} /* namespace dynload */


// Pick the cuBLAS routine family by precision: the D* (double) routines when
// HPPL_TYPE_DOUBLE is defined, the S* (single) routines otherwise.
// Presumably this tracks the `real` typedef used by the callers below.
#if defined(HPPL_TYPE_DOUBLE)
#define CUBLAS_GEAM dynload::cublasDgeam
#define CUBLAS_GEMV dynload::cublasDgemv
#define CUBLAS_GEMM dynload::cublasDgemm
#else
#define CUBLAS_GEAM dynload::cublasSgeam
#define CUBLAS_GEMV dynload::cublasSgemv
#define CUBLAS_GEMM dynload::cublasSgemm
#endif

/**
 * Maps a cuBLAS status code to a static, human-readable message.
 *
 * @param status  value returned by a cuBLAS call.
 * @return        pointer to a string literal; any status not listed below
 *                yields the "unknown error" message.
 */
const char* hl_cublas_get_error_string(cublasStatus_t status) {
  const char* message = "[cublas status]: unknown error";
  switch (status) {
    case CUBLAS_STATUS_SUCCESS:
      message = "[cublas status]: success";
      break;
    case CUBLAS_STATUS_NOT_INITIALIZED:
      message = "[cublas status]: not initialized";
      break;
    case CUBLAS_STATUS_ALLOC_FAILED:
      message = "[cublas status]: allocate failed";
      break;
    case CUBLAS_STATUS_INVALID_VALUE:
      message = "[cublas status]: invalid value";
      break;
    case CUBLAS_STATUS_ARCH_MISMATCH:
      message = "[cublas status]: arch mismatch";
      break;
    case CUBLAS_STATUS_MAPPING_ERROR:
      message = "[cublas status]: mapping error";
      break;
    case CUBLAS_STATUS_EXECUTION_FAILED:
      message = "[cublas status]: execution failed";
      break;
    case CUBLAS_STATUS_INTERNAL_ERROR:
      message = "[cublas status]: internal error";
      break;
    default:
      break;
  }
  return message;
}

/**
 * CHECK_CUBLAS(cublas_func) evaluates a cuBLAS call exactly once and fails
 * through CHECK_EQ if it does not return CUBLAS_STATUS_SUCCESS. Extra context
 * can be streamed after it, e.g.
 *   CHECK_CUBLAS(dynload::cublasCreate(&h)) << "create failed";
 *
 * The status is held in a variable scoped to this macro expansion rather than
 * in a shared global, so concurrent callers cannot overwrite each other's
 * status between the call and the check. The expansion is also a single
 * statement, so it is safe as the unbraced body of an if/else.
 */
// Retained only so that any external references still link; CHECK_CUBLAS no
// longer reads or writes it.
cublasStatus_t g_cublasStat;
#define CHECK_CUBLAS(cublas_func)                                \
  for (cublasStatus_t hl_cublas_stat_ = (cublas_func);          \
       hl_cublas_stat_ != CUBLAS_STATUS_SUCCESS;                \
       hl_cublas_stat_ = CUBLAS_STATUS_SUCCESS)                 \
    CHECK_EQ(CUBLAS_STATUS_SUCCESS, hl_cublas_stat_)            \
        << "Cublas Error: "                                     \
        << hl_cublas_get_error_string(hl_cublas_stat_)          \
        << " "

/**
 * Creates a cuBLAS handle and attaches it to a CUDA stream.
 *
 * @param cublas_handle  out-parameter that receives the new handle; must be
 *                       non-null.
 * @param stream         stream passed to cublasSetStream for the new handle.
 *
 * A failure of either cuBLAS call is fatal (reported through CHECK_CUBLAS).
 */
void hl_cublas_init(cublasHandle_t *cublas_handle, cudaStream_t stream) {
  // The handle is dereferenced for cublasSetStream below, so reject a null
  // out-parameter up front with a clear message.
  CHECK_NOTNULL(cublas_handle);

  CHECK_CUBLAS(dynload::cublasCreate(cublas_handle))
    << "[cublas init] Cublas create handle failed!";

  CHECK_CUBLAS(dynload::cublasSetStream(*cublas_handle, stream))
    << "[cublas init] Cublas set stream failed!";
}

/**
 * Writes the transpose of A_d into C_d using cuBLAS geam.
 *
 * Row-major view: A_d is dimM x dimN with row stride lda, and C_d receives
 * the dimN x dimM transpose with row stride ldc. geam computes
 * C = alpha * op(A) + beta * op(B); here alpha = 1, beta = 0 and op(A) is a
 * transpose, so the B operand contributes nothing and is passed as nullptr.
 *
 * Runs on t_resource.handle (presumably the calling thread's cuBLAS handle)
 * and then invokes CHECK_SYNC.
 */
void hl_matrix_transpose(real *A_d,
                         real *C_d,
                         int dimM,
                         int dimN,
                         int lda,
                         int ldc) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(C_d);

  real one = 1.0;
  real zero = 0.0;
  // B is unused (beta == 0); its leading dimension only has to be valid.
  const int unused_ldb = dimM;

  CHECK_CUBLAS(CUBLAS_GEAM(t_resource.handle,
                           CUBLAS_OP_T, CUBLAS_OP_N,
                           dimM, dimN,
                           &one, A_d, lda,
                           &zero, nullptr, unused_ldb,
                           C_d, ldc));
  CHECK_SYNC("hl_matrix_transpose failed");
}

/**
 * Transposes a densely packed row-major dimM x dimN matrix A_d into C_d
 * (dimN x dimM). Each leading dimension is simply the row length.
 */
void hl_matrix_transpose(real *A_d, real *C_d, int dimM, int dimN) {
  const int packed_lda = dimN;  // each row of A holds dimN elements
  const int packed_ldc = dimM;  // each row of C holds dimM elements
  hl_matrix_transpose(A_d, C_d, dimM, dimN, packed_lda, packed_ldc);
}

/**
 * C_d = alpha * op(A_d) * op(B_d) + beta * C_d for row-major matrices.
 *
 * op(A) is dimM x dimK, op(B) is dimK x dimN and C is dimM x dimN. lda, ldb
 * and ldc are the row strides of A_d, B_d and C_d as stored. transa and
 * transb must each be HPPL_OP_N or HPPL_OP_T; any other value is fatal. All
 * four combinations, including T/T, are supported.
 *
 * Degenerate shapes are routed to hl_matrix_mul_vector (gemv):
 *   - C is a single column (dimN == 1, dimM != 1, dimK != 1, B not
 *     transposed);
 *   - C is a single row (dimM == 1, dimN != 1, dimK != 1, A not transposed).
 * Otherwise a single cuBLAS gemm is issued on t_resource.handle, followed by
 * CHECK_SYNC. A non-success cuBLAS status is fatal.
 */
void hl_matrix_mul(real *A_d, hl_trans_op_t transa,
                   real *B_d, hl_trans_op_t transb,
                   real *C_d,
                   int dimM, int dimN, int dimK,
                   real alpha, real beta,
                   int lda, int ldb, int ldc) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(B_d);
  CHECK_NOTNULL(C_d);

  // Validate up front. Otherwise an invalid transb on the row-vector fast
  // path below would be silently treated as HPPL_OP_T.
  if (transa != HPPL_OP_N && transa != HPPL_OP_T) {
    LOG(FATAL) << "parameter transa error!";
  }
  if (transb != HPPL_OP_N && transb != HPPL_OP_T) {
    LOG(FATAL) << "parameter transb error!";
  }

  // C is a column: C = op(A) * b, where b is read from B_d with stride ldb
  // and C is written with stride ldc.
  if (dimN == 1 && dimM != 1 && dimK != 1 && transb == HPPL_OP_N) {
    int m = (transa == HPPL_OP_N) ? dimM : dimK;
    int n = (transa == HPPL_OP_N) ? dimK : dimM;
    hl_matrix_mul_vector(A_d, transa, B_d, C_d, m, n,
                         alpha, beta, lda, ldb, ldc);
    return;
  }

  // C is a row: C^T = op(B)^T * a^T, where a is the contiguous row A_d.
  if (dimM == 1 && dimN != 1 && dimK != 1 && transa == HPPL_OP_N) {
    int m = (transb == HPPL_OP_N) ? dimK : dimN;
    int n = (transb == HPPL_OP_N) ? dimN : dimK;
    hl_trans_op_t trans = (transb == HPPL_OP_N) ? HPPL_OP_T : HPPL_OP_N;
    hl_matrix_mul_vector(B_d, trans, A_d, C_d, m, n,
                         alpha, beta, ldb, 1, 1);
    return;
  }

  // cuBLAS is column-major, and a row-major matrix is the column-major
  // storage of its transpose. Asking cuBLAS for C^T = op(B)^T * op(A)^T
  // therefore produces row-major C. That is why B is passed first and the
  // m/n extents are swapped. Each HPPL flag maps directly onto the cuBLAS
  // flag of its own operand.
  const cublasOperation_t opA =
      (HPPL_OP_N == transa) ? CUBLAS_OP_N : CUBLAS_OP_T;
  const cublasOperation_t opB =
      (HPPL_OP_N == transb) ? CUBLAS_OP_N : CUBLAS_OP_T;

  cublasStatus_t stat = CUBLAS_GEMM(t_resource.handle,
                                    opB,
                                    opA,
                                    dimN, dimM, dimK,
                                    &alpha, B_d, ldb,
                                    A_d, lda,
                                    &beta, C_d, ldc);
  CHECK_EQ(stat, CUBLAS_STATUS_SUCCESS) << hl_cublas_get_error_string(stat);
  CHECK_SYNC("hl_matrix_mul failed");
}

/**
 * hl_matrix_mul for densely packed row-major operands: each leading
 * dimension is the row length of that operand as stored.
 *
 * Stored A is dimM x dimK (dimK x dimM when transa is not HPPL_OP_N), stored
 * B is dimK x dimN (dimN x dimK when transb is not HPPL_OP_N), and C is
 * dimM x dimN.
 */
void hl_matrix_mul(real *A_d, hl_trans_op_t transa,
                   real *B_d, hl_trans_op_t transb,
                   real *C_d,
                   int dimM, int dimN, int dimK,
                   real alpha, real beta) {
  int lda = dimM;
  if (transa == HPPL_OP_N) {
    lda = dimK;
  }

  int ldb = dimK;
  if (transb == HPPL_OP_N) {
    ldb = dimN;
  }

  const int ldc = dimN;

  hl_matrix_mul(A_d, transa, B_d, transb, C_d,
                dimM, dimN, dimK,
                alpha, beta,
                lda, ldb, ldc);
}

/**
 * C_d = alpha * op(A_d) * B_d + beta * C_d using cuBLAS gemv.
 *
 * A_d is a row-major dimM x dimN matrix with row stride lda. op(A) is A for
 * HPPL_OP_N and A^T for HPPL_OP_T; any other trans value is fatal. B_d and
 * C_d are vectors accessed with strides incb and incc.
 * Runs on t_resource.handle and then invokes CHECK_SYNC. A non-success
 * cuBLAS status is fatal.
 */
void hl_matrix_mul_vector(real *A_d, hl_trans_op_t trans,
                          real *B_d, real *C_d,
                          int dimM, int dimN,
                          real alpha, real beta,
                          int lda, int incb, int incc) {
  CHECK_NOTNULL(A_d);
  CHECK_NOTNULL(B_d);
  CHECK_NOTNULL(C_d);

  if (trans != HPPL_OP_N && trans != HPPL_OP_T) {
    LOG(FATAL) << "parameter transa error!";
  }

  // Row-major A is the column-major storage of A^T (dimN x dimM), so the
  // cuBLAS operation is the inverse of the requested one.
  const cublasOperation_t op =
      (HPPL_OP_N == trans) ? CUBLAS_OP_T : CUBLAS_OP_N;

  cublasStatus_t stat = CUBLAS_GEMV(t_resource.handle,
                                    op,
                                    dimN, dimM,
                                    &alpha,
                                    A_d, lda,
                                    B_d, incb,
                                    &beta,
                                    C_d, incc);
  CHECK_EQ(stat, CUBLAS_STATUS_SUCCESS) << hl_cublas_get_error_string(stat);
  CHECK_SYNC("hl_matrix_mul_vector");
}

/**
 * hl_matrix_mul_vector for a densely packed row-major A_d (row stride dimN)
 * and contiguous vectors B_d and C_d (stride 1).
 */
void hl_matrix_mul_vector(real *A_d, hl_trans_op_t trans,
                          real *B_d, real *C_d,
                          int dimM, int dimN,
                          real alpha, real beta) {
  const int packed_lda = dimN;
  const int unit_stride = 1;
  hl_matrix_mul_vector(A_d, trans, B_d, C_d,
                       dimM, dimN,
                       alpha, beta,
                       packed_lda, unit_stride, unit_stride);
}