/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "hl_cuda.h"
#include <cuda_profiler_api.h>
#include <dlfcn.h>
#include <string.h>
#include <sys/syscall.h>
#include <sys/time.h>
#include <unistd.h>
#include <mutex>
#include "hl_cuda.ph"
#include "hl_dso_loader.h"
#include "hl_thread.ph"
#include "paddle/utils/Logging.h"

namespace dynload {

std::once_flag curand_dso_flag;
void *curand_dso_handle = nullptr;

/**
 * The following macro definition can generate structs
 * (for each function) to dynamic load curand routine
 * via operator overloading.
 *
 * With PADDLE_USE_DSO the curand library handle is obtained once (guarded by
 * curand_dso_flag) and each call resolves the routine with dlsym(); a symbol
 * that cannot be resolved fails a CHECK instead of being called through a
 * null function pointer. Without PADDLE_USE_DSO the wrapper forwards to the
 * statically linked routine.
 *
 * note: default dynamic linked libs
 */
#ifdef PADDLE_USE_DSO
/**
 * Return the pending dlerror() message, or a placeholder when dlerror()
 * reports none, so the result is always safe to stream into a log message.
 */
inline const char *DsoLastError() {
  const char *msg = dlerror();
  return msg != nullptr ? msg : "unknown dynamic loading error";
}

#define DYNAMIC_LOAD_CURAND_WRAP(__name)                                       \
  struct DynLoad__##__name {                                                   \
    template <typename... Args>                                                \
    curandStatus_t operator()(Args... args) {                                  \
      typedef curandStatus_t (*curandFunc)(Args...);                           \
      std::call_once(curand_dso_flag, GetCurandDsoHandle, &curand_dso_handle); \
      void *p_##__name = dlsym(curand_dso_handle, #__name);                    \
      CHECK(p_##__name != nullptr)                                             \
          << "Cannot load curand routine " #__name ": " << DsoLastError();     \
      return reinterpret_cast<curandFunc>(p_##__name)(args...);                \
    }                                                                          \
  } __name; /* struct DynLoad__##__name */
#else
#define DYNAMIC_LOAD_CURAND_WRAP(__name)      \
  struct DynLoad__##__name {                  \
    template <typename... Args>               \
    curandStatus_t operator()(Args... args) { \
      return __name(args...);                 \
    }                                         \
  } __name; /* struct DynLoad__##__name */
#endif

/* include all needed curand functions in HPPL */
// clang-format off
#define CURAND_RAND_ROUTINE_EACH(__macro)    \
  __macro(curandCreateGenerator)             \
  __macro(curandSetStream)                   \
  __macro(curandSetPseudoRandomGeneratorSeed)\
  __macro(curandGenerateUniform)             \
  __macro(curandGenerateUniformDouble)
// clang-format on

CURAND_RAND_ROUTINE_EACH(DYNAMIC_LOAD_CURAND_WRAP)

#undef CURAND_RAND_ROUTINE_EACH
#undef DYNAMIC_LOAD_CURAND_WRAP

std::once_flag cudart_dso_flag;
void *cudart_dso_handle = nullptr;

/**
 * The following macro definition can generate structs
 * (for each function) to dynamic load cuda routine
 * via operator overloading. The return type of each wrapper is deduced from
 * the corresponding CUDA runtime declaration.
 *
 * With PADDLE_USE_DSO the cudart handle is obtained once (guarded by
 * cudart_dso_flag) and each call resolves the routine with dlsym(); an
 * unresolved symbol fails a CHECK.
 *
 * note: default dynamic linked libs
 */
#ifdef PADDLE_USE_DSO
#define DYNAMIC_LOAD_CUDART_WRAP(__name)                                       \
  struct DynLoad__##__name {                                                   \
    template <typename... Args>                                                \
    auto operator()(Args... args) -> decltype(__name(args...)) {               \
      using cudart_func = decltype(__name(args...)) (*)(Args...);              \
      std::call_once(cudart_dso_flag, GetCudartDsoHandle, &cudart_dso_handle); \
      void *p_##__name = dlsym(cudart_dso_handle, #__name);                    \
      CHECK(p_##__name != nullptr)                                             \
          << "Cannot load cudart routine " #__name ": " << DsoLastError();     \
      return reinterpret_cast<cudart_func>(p_##__name)(args...);               \
    }                                                                          \
  } __name; /* struct DynLoad__##__name */
#else
#define DYNAMIC_LOAD_CUDART_WRAP(__name)                         \
  struct DynLoad__##__name {                                     \
    template <typename... Args>                                  \
    auto operator()(Args... args) -> decltype(__name(args...)) { \
      return __name(args...);                                    \
    }                                                            \
  } __name; /* struct DynLoad__##__name */
#endif

/* include all needed cuda functions in HPPL */
// clang-format off
#define CUDA_ROUTINE_EACH(__macro)        \
  __macro(cudaMalloc)                     \
  __macro(cudaHostAlloc)                  \
  __macro(cudaFree)                       \
  __macro(cudaFreeHost)                   \
  __macro(cudaMemcpy)                     \
  __macro(cudaMemset)                     \
  __macro(cudaMemcpyAsync)                \
  __macro(cudaSetDevice)                  \
  __macro(cudaGetDevice)                  \
  __macro(cudaGetDeviceCount)             \
  __macro(cudaGetDeviceProperties)        \
  __macro(cudaDeviceSynchronize)          \
  __macro(cudaDeviceCanAccessPeer)        \
  __macro(cudaDeviceEnablePeerAccess)     \
  __macro(cudaStreamCreate)               \
  __macro(cudaStreamDestroy)              \
  __macro(cudaStreamSynchronize)          \
  __macro(cudaStreamWaitEvent)            \
  __macro(cudaEventCreate)                \
  __macro(cudaEventRecord)                \
  __macro(cudaEventQuery)                 \
  __macro(cudaEventDestroy)               \
  __macro(cudaEventSynchronize)           \
  __macro(cudaEventElapsedTime)           \
  __macro(cudaSetDeviceFlags)             \
  __macro(cudaGetLastError)               \
  __macro(cudaFuncSetCacheConfig)         \
  __macro(cudaRuntimeGetVersion)          \
  __macro(cudaGetErrorString)             \
  __macro(cudaProfilerStart)              \
  __macro(cudaProfilerStop)
// clang-format on

CUDA_ROUTINE_EACH(DYNAMIC_LOAD_CUDART_WRAP)

// Spelled correctly so the helper macro does not leak past this namespace.
#undef CUDA_ROUTINE_EACH
#undef DYNAMIC_LOAD_CUDART_WRAP

} /* namespace dynload */

/**
 * @brief   global resource.
 *
 * Process-wide device bookkeeping, filled in by hl_specify_devices_start().
 */
int g_system_device_num = 0;     /* devices reported by the CUDA runtime */
int device_num = 0;              /* devices selected for use */
hl_device_prop *g_device = NULL; /* per-system-device info table */
/* per-thread, per-device resources table (built by hl_init) */
__thread thread_device_resources *t_device = NULL;

/* CUDA runtime version, recorded via cudaRuntimeGetVersion */
int g_cuda_lib_version = 0;

/* number of streams shared by every thread on a device */
#define NUMBER_OF_GLOBAL_STREAM (HPPL_THREAD_STREAM_1)
/* number of streams owned by each thread on a device */
#define NUMBER_OF_THREAD_STREAM (HPPL_STREAM_END - HPPL_THREAD_STREAM_1)
/* size in bytes of each per-thread scratch buffer (gpu_mem / cpu_mem) */
#define HPPL_GPU_MEMORY_SIZE (256 * 4)

/**
 * Check the cudaError_t returned by a built-in CUDA runtime call with glog.
 * A non-success status fails CHECK_EQ and logs the CUDA error string.
 * Extra context **cannot** be appended with the << operator.
 */
#define CHECK_CUDA(cudaFunc)                                                   \
  do {                                                                         \
    const cudaError_t cudaStat = (cudaFunc);                                   \
    CHECK_EQ(cudaSuccess, cudaStat)                                            \
        << "Cuda Error: " << dynload::cudaGetErrorString(cudaStat);            \
  } while (0)

/**
 * @brief   thread resource.
 */
179 180 181 182 183 184 185 186 187 188 189 190
__thread _hl_thread_resource t_resource = {{0},    /* stream */
                                           0,      /* handle */
                                           0,      /* gen */
                                           0,      /* cudnn_handle */
                                           0,      /* cudnn_desc */
                                           NULL,   /* gen_mutex */
                                           NULL,   /* gpu_mem */
                                           NULL,   /* cpu_mem */
                                           0,      /* event */
                                           -1,     /* device */
                                           0,      /* major */
                                           false}; /* is_init */
Z
zhangjinchao01 已提交
191 192 193 194 195

__thread cudaStream_t default_stream = 0;
__thread bool g_sync_flag = true;
bool hl_start_flag = false;

L
liaogang 已提交
196 197
inline pid_t gettid() {
#if defined(__APPLE__) || defined(__OSX__)
G
gangliao 已提交
198 199 200 201 202
  // syscall is deprecated: first deprecated in macOS 10.12.
  // syscall is unsupported;
  // syscall pid_t tid = syscall(SYS_thread_selfid);
  uint64_t tid;
  pthread_threadid_np(NULL, &tid);
L
liaogang 已提交
203
#else
204 205 206
#ifndef __NR_gettid
#define __NR_gettid 224
#endif
L
liaogang 已提交
207 208
  pid_t tid = syscall(__NR_gettid);
#endif
209 210
  CHECK_NE((int)tid, -1);
  return tid;
L
liaogang 已提交
211
}
Z
zhangjinchao01 已提交
212 213

void hl_init(int device) {
214
  CHECK(hl_start_flag) << "[Init failed] hl_start() did not succeed.";
Z
zhangjinchao01 已提交
215 216 217 218 219 220 221 222 223 224

  /* thread has been initialized */
  if (true == t_resource.is_init) {
    hl_set_device(device);
    return;
  }

  /* create thread devcie resources */
  char *tmp;
  thread_device_resources device_res;
225 226
  tmp = (char *)malloc(g_system_device_num * sizeof(thread_device_resources *) +
                       device_num * sizeof(_thread_device_resources));
Z
zhangjinchao01 已提交
227
  CHECK_NOTNULL(tmp);
228 229 230 231
  t_device = (thread_device_resources *)tmp;
  device_res = (thread_device_resources)(
      (char *)tmp + g_system_device_num * sizeof(thread_device_resources *));
  memset(t_device, 0, g_system_device_num * sizeof(thread_device_resources *));
Z
zhangjinchao01 已提交
232

233 234
  char *tmp_stream = (char *)malloc(device_num * NUMBER_OF_THREAD_STREAM *
                                    sizeof(cudaStream_t));
Z
zhangjinchao01 已提交
235 236 237 238 239 240 241 242 243
  CHECK_NOTNULL(tmp_stream);

  int num = 0;
  for (int dev = 0; dev < g_system_device_num; dev++) {
    if (!g_device[dev]) {
      continue;
    }

    t_device[dev] = &device_res[num];
244 245 246
    t_device[dev]->stream =
        (cudaStream_t *)(tmp_stream +
                         num * NUMBER_OF_THREAD_STREAM * sizeof(cudaStream_t));
Z
zhangjinchao01 已提交
247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271

    hl_create_thread_resources(dev, t_device[dev]);
    num++;
  }

  hl_cudnn_desc_init(&t_resource.cudnn_desc);

  /* thread initialization is complete */
  t_resource.is_init = true;
  /* set device */
  t_resource.device = -1;
  hl_set_device(device);
}

void hl_fini() {
  if (false == t_resource.is_init) {
    return;
  }

  /* hppl stream fini */
  t_resource.device = -1;
  for (int i = NUMBER_OF_GLOBAL_STREAM; i < HPPL_STREAM_END; i++) {
    t_resource.stream[i] = 0;
  }

272 273
  char *tmp = (char *)t_device;
  char *tmp_stream = NULL;
Z
zhangjinchao01 已提交
274 275 276 277 278
  for (int dev = 0; dev < g_system_device_num; dev++) {
    if (!t_device[dev]) {
      continue;
    }
    if (!tmp_stream) {
279
      tmp_stream = (char *)t_device[dev]->stream;
Z
zhangjinchao01 已提交
280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295
    }
    for (int j = 0; j < NUMBER_OF_THREAD_STREAM; j++) {
      CHECK_CUDA(dynload::cudaStreamDestroy(t_device[dev]->stream[j]));
    }

    /* free device memory */
    hl_free_mem_device(t_device[dev]->gpu_mem);
    hl_free_mem_host(t_device[dev]->cpu_mem);
    CHECK_CUDA(dynload::cudaEventDestroy(t_device[dev]->mem_event));
  }

  free(tmp);
  free(tmp_stream);
  t_resource.is_init = false;
}

296
/* Number of devices selected at startup (not the system device count). */
int hl_get_device_count() {
  return device_num;
}

void hl_set_device(int device) {
  if (device == t_resource.device) {
    return;
  }

  CHECK(device >= 0 && device < g_system_device_num && g_device[device])
304
      << "Device: " << device << " is not specified in startup.";
Z
zhangjinchao01 已提交
305 306 307 308 309 310 311 312 313 314 315

  CHECK_CUDA(dynload::cudaSetDevice(device));

  /* switch thread stream */
  for (int i = 0; i < NUMBER_OF_GLOBAL_STREAM; i++) {
    t_resource.stream[i] = g_device[device]->device_resources->stream[i];
  }

  if (true == t_resource.is_init) {
    for (int i = NUMBER_OF_GLOBAL_STREAM; i < HPPL_STREAM_END; i++) {
      t_resource.stream[i] =
316
          t_device[device]->stream[i - NUMBER_OF_GLOBAL_STREAM];
Z
zhangjinchao01 已提交
317 318 319
    }
    t_resource.gpu_mem = t_device[device]->gpu_mem;
    t_resource.cpu_mem = t_device[device]->cpu_mem;
320
    t_resource.event = t_device[device]->mem_event;
Z
zhangjinchao01 已提交
321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337
  }

  t_resource.handle = g_device[device]->device_resources->handle;
  t_resource.gen = g_device[device]->device_resources->gen;
  t_resource.cudnn_handle = g_device[device]->device_resources->cudnn_handle;
  t_resource.gen_mutex = g_device[device]->device_resources->gen_mutex;
  t_resource.device = device;
  t_resource.major = g_device[device]->major;
  default_stream = t_resource.stream[0];
}

/* Ask the CUDA runtime which device is current on the calling thread. */
int hl_get_device() {
  int current_device;
  CHECK_CUDA(dynload::cudaGetDevice(&current_device));
  return current_device;
}

void *hl_malloc_device(size_t size) {
Z
zhangjinchao01 已提交
339 340 341
  void *dest_d;

  CHECK(size) << __func__ << ": the size for device memory is 0, please check.";
342
  CHECK_CUDA(dynload::cudaMalloc((void **)&dest_d, size));
Z
zhangjinchao01 已提交
343 344 345 346 347 348 349 350 351

  return dest_d;
}

void hl_free_mem_device(void *dest_d) {
  CHECK_NOTNULL(dest_d);

  cudaError_t err = dynload::cudaFree(dest_d);
  CHECK(cudaSuccess == err || cudaErrorCudartUnloading == err)
352
      << hl_get_device_error_string();
Z
zhangjinchao01 已提交
353 354
}

355
void *hl_malloc_host(size_t size) {
Z
zhangjinchao01 已提交
356 357 358
  void *dest_h;

  CHECK(size) << __func__ << ": the size for device memory is 0, please check.";
359 360
  CHECK_CUDA(
      dynload::cudaHostAlloc((void **)&dest_h, size, cudaHostAllocDefault));
Z
zhangjinchao01 已提交
361 362 363 364 365 366 367 368

  return dest_h;
}

void hl_free_mem_host(void *dest_h) {
  CHECK_NOTNULL(dest_h);

  cudaError_t err = dynload::cudaFreeHost(dest_h);
369
  CHECK(cudaSuccess == err || cudaErrorCudartUnloading == err)
370
      << hl_get_device_error_string();
Z
zhangjinchao01 已提交
371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391
}

/**
 * Blocking copy of `size` bytes whose direction is inferred by the runtime
 * (cudaMemcpyDefault). A zero-size copy is a no-op.
 */
void hl_memcpy(void *dst, void *src, size_t size) {
  if (size == 0) return;

  CHECK_NOTNULL(dst);
  CHECK_NOTNULL(src);
  CHECK_CUDA(dynload::cudaMemcpy(dst, src, size, cudaMemcpyDefault));
}

void hl_memset_device(void *dest_d, int value, size_t size) {
  CHECK_CUDA(dynload::cudaMemset(dest_d, value, size));
}

void hl_memcpy_host2device(void *dest_d, void *src_h, size_t size) {
  if (0 == size) {
    return;
  }
  CHECK_NOTNULL(src_h);
  CHECK_NOTNULL(dest_d);
392
  CHECK_CUDA(dynload::cudaMemcpy(dest_d, src_h, size, cudaMemcpyHostToDevice));
Z
zhangjinchao01 已提交
393 394 395 396 397 398 399 400
}

void hl_memcpy_device2host(void *dest_h, void *src_d, size_t size) {
  if (0 == size) {
    return;
  }
  CHECK_NOTNULL(dest_h);
  CHECK_NOTNULL(src_d);
401
  CHECK_CUDA(dynload::cudaMemcpy(dest_h, src_d, size, cudaMemcpyDeviceToHost));
Z
zhangjinchao01 已提交
402 403 404 405 406 407 408 409
}

void hl_memcpy_device2device(void *dest_d, void *src_d, size_t size) {
  if (0 == size) {
    return;
  }
  CHECK_NOTNULL(dest_d);
  CHECK_NOTNULL(src_d);
410 411
  CHECK_CUDA(
      dynload::cudaMemcpy(dest_d, src_d, size, cudaMemcpyDeviceToDevice));
Z
zhangjinchao01 已提交
412 413 414 415 416 417 418 419 420 421 422 423 424
}

void hl_memcpy_async(void *dst, void *src, size_t size, hl_stream_t stream) {
  cudaStream_t cu_stream;

  if (0 == size) {
    return;
  }
  CHECK_NOTNULL(dst);
  CHECK_NOTNULL(src);
  CHECK_LT(stream, HPPL_STREAM_END);
  cu_stream = t_resource.stream[stream];

425 426
  CHECK_CUDA(
      dynload::cudaMemcpyAsync(dst, src, size, cudaMemcpyDefault, cu_stream));
Z
zhangjinchao01 已提交
427 428 429 430 431 432 433 434 435 436
}

/* Start on every system device, then make device 0 current. */
void hl_start() {
  hl_specify_devices_start(NULL, 0);
  hl_set_device(0);  // default device
}

bool hl_device_can_access_peer(int device, int peerDevice) {
  int canAccessPeer;
437 438
  CHECK_CUDA(
      dynload::cudaDeviceCanAccessPeer(&canAccessPeer, device, peerDevice));
Z
zhangjinchao01 已提交
439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479

  if (canAccessPeer == 1) {
    return true;
  } else {
    return false;
  }
}

/**
 * Enable access from the current device to `peerDevice`. Finding access
 * already enabled is not an error.
 */
void hl_device_enable_peer_access(int peerDevice) {
  cudaError_t err = dynload::cudaDeviceEnablePeerAccess(peerDevice, 0);
  if (err != cudaErrorPeerAccessAlreadyEnabled) {
    CHECK_CUDA(err);
    return;
  }
  // Benign case: reset the runtime's last-error state so later
  // cudaGetLastError() calls do not report it.
  dynload::cudaGetLastError();
}

/**
 * Create the process-wide resources of one device:
 *   - query its properties into device_prop;
 *   - create NUMBER_OF_GLOBAL_STREAM streams;
 *   - create cuBLAS/curand/cuDNN handles bound to stream 0;
 *   - seed curand;
 *   - allocate the mutex that hl_rand()/hl_srand() hold around generator use.
 * Also records the CUDA runtime version in g_cuda_lib_version.
 */
void hl_create_global_resources(hl_device_prop device_prop) {
  struct cudaDeviceProp cu_prop;
  int device = device_prop->device;
  global_device_resources device_res = device_prop->device_resources;

  CHECK_CUDA(dynload::cudaSetDevice(device));
  /* device properties */
  CHECK_CUDA(dynload::cudaGetDeviceProperties(&cu_prop, device));

  device_prop->major = cu_prop.major;
  device_prop->minor = cu_prop.minor;
  // strncpy does not terminate when the source fills all 256 bytes.
  strncpy(device_prop->device_name, cu_prop.name, 256);
  device_prop->device_name[255] = '\0';
  device_prop->device_mem = cu_prop.totalGlobalMem;

  /* create device stream */
  for (int j = 0; j < NUMBER_OF_GLOBAL_STREAM; j++) {
    CHECK_CUDA(dynload::cudaStreamCreate(&device_res->stream[j]));
  }

  /* cublas init */
  hl_cublas_init(&device_res->handle, device_res->stream[0]);

  /* create curand gen */
  CHECK_EQ(dynload::curandCreateGenerator(&device_res->gen,
                                          CURAND_RNG_PSEUDO_DEFAULT),
           CURAND_STATUS_SUCCESS)
      << "[Start failed] Curand init failed.";

  CHECK_EQ(dynload::curandSetStream(device_res->gen, device_res->stream[0]),
           CURAND_STATUS_SUCCESS)
      << "[Start failed] Curand set stream failed!";

  /* create cudnn handle */
  hl_cudnn_init(&device_res->cudnn_handle, device_res->stream[0]);

  // Seed from the thread id, offset by the device id so devices differ.
  int seed = gettid();
  CHECK_EQ(dynload::curandSetPseudoRandomGeneratorSeed(device_res->gen,
                                                       seed + device),
           CURAND_STATUS_SUCCESS)
      << "[Start failed] Curand set seed failed!";

  device_res->gen_mutex = (pthread_mutex_t *)(malloc(sizeof(pthread_mutex_t)));
  CHECK_NOTNULL(device_res->gen_mutex);
  CHECK_EQ(pthread_mutex_init(device_res->gen_mutex, NULL), 0)
      << "[Start failed] Curand mutex init failed!";

  CHECK_CUDA(dynload::cudaRuntimeGetVersion(&g_cuda_lib_version));
}

/* CUDA runtime version recorded by hl_create_global_resources(). */
int hl_get_cuda_version() {
  return g_cuda_lib_version;
}

void hl_create_thread_resources(int device,
505
                                thread_device_resources device_res) {
Z
zhangjinchao01 已提交
506 507 508 509 510 511 512 513
  CHECK_CUDA(dynload::cudaSetDevice(device));

  /* create thread stream */
  for (int j = 0; j < NUMBER_OF_THREAD_STREAM; j++) {
    CHECK_CUDA(dynload::cudaStreamCreate(&device_res->stream[j]));
  }

  /* allocation device memory */
514
  device_res->gpu_mem = (real *)hl_malloc_device(HPPL_GPU_MEMORY_SIZE);
Z
zhangjinchao01 已提交
515 516

  /* allocation host memory */
517
  device_res->cpu_mem = (real *)hl_malloc_host(HPPL_GPU_MEMORY_SIZE);
Z
zhangjinchao01 已提交
518 519 520 521

  CHECK_CUDA(dynload::cudaEventCreate(&device_res->mem_event));
}

522
void hl_specify_devices_start(int *device, int number) {
Z
zhangjinchao01 已提交
523 524 525 526 527 528 529 530 531 532 533
  if (hl_start_flag) return;

  /* 1. get the number of devices */
  CHECK_CUDA(dynload::cudaGetDeviceCount(&g_system_device_num));
  CHECK_NE(g_system_device_num, 0) << "[Start failed] there is no GPU device";
  if (device == NULL) {
    number = g_system_device_num;
  }

  /* 2. check device & create device property table */
  CHECK_LE(number, g_system_device_num)
534 535
      << "[Start failed] System does not have enough device. "
      << "Device number: " << g_system_device_num << "Input number: " << number;
Z
zhangjinchao01 已提交
536 537 538

  char *tmp;
  hl_device_prop device_prop;
539 540
  tmp = (char *)malloc(g_system_device_num * sizeof(hl_device_prop *) +
                       number * sizeof(_hl_device_prop));
Z
zhangjinchao01 已提交
541 542
  CHECK(tmp) << "[Start failed] System memory is not enough.";

543 544 545 546
  g_device = (hl_device_prop *)tmp;
  device_prop = (hl_device_prop)(
      (char *)tmp + g_system_device_num * sizeof(hl_device_prop *));
  memset(g_device, 0, g_system_device_num * sizeof(hl_device_prop *));
Z
zhangjinchao01 已提交
547 548 549 550 551 552 553 554 555 556
  int num = 0;
  for (int i = 0; i < number; i++) {
    int dev;
    if (device == NULL) {
      dev = i;
    } else {
      dev = device[i];
    }

    CHECK_LT(dev, g_system_device_num)
557 558 559
        << "[Start failed] The specified device number is "
        << "out of range. Max device number: " << g_system_device_num - 1
        << " Specified devcie number: " << dev;
Z
zhangjinchao01 已提交
560 561 562

    if (g_device[dev]) {
      /* Warning */
563
      LOG(WARNING) << "[Warning] Repeat specify device: " << dev;
Z
zhangjinchao01 已提交
564 565 566 567 568 569 570 571 572 573
      continue;
    }

    g_device[dev] = &device_prop[num];
    g_device[dev]->device = dev;
    num++;
  }
  device_num = num;

  /* 3.  create global device resources */
574
  char *tmp_res = (char *)malloc(device_num * sizeof(_global_device_resources));
Z
zhangjinchao01 已提交
575 576
  CHECK_NOTNULL(tmp_res);

577 578
  char *tmp_stream = (char *)malloc(device_num * NUMBER_OF_GLOBAL_STREAM *
                                    sizeof(cudaStream_t));
Z
zhangjinchao01 已提交
579 580 581 582 583 584 585 586
  CHECK_NOTNULL(tmp_stream);

  num = 0;
  for (int i = 0; i < g_system_device_num; i++) {
    if (!g_device[i]) {
      continue;
    }

587 588 589 590 591
    g_device[i]->device_resources = (global_device_resources)(
        tmp_res + num * sizeof(_global_device_resources));
    g_device[i]->device_resources->stream =
        (cudaStream_t *)(tmp_stream +
                         num * NUMBER_OF_GLOBAL_STREAM * sizeof(cudaStream_t));
Z
zhangjinchao01 已提交
592 593 594 595 596 597 598 599 600

    hl_create_global_resources(g_device[i]);
    num++;
  }

  /* hl_start() is ok */
  hl_start_flag = true;
  /* set default device */
  if (device == NULL) {
601
    hl_set_device(0);
Z
zhangjinchao01 已提交
602
  } else {
603
    hl_set_device(device[0]);
Z
zhangjinchao01 已提交
604 605 606 607 608 609
  }
}

/**
 * Fill `num` elements at `dest_d` with uniform random values from the
 * current device's curand generator. Single or double precision is chosen
 * by PADDLE_TYPE_DOUBLE. The generator comes from the device's global
 * resources, so its use is serialized with gen_mutex.
 */
void hl_rand(real *dest_d, size_t num) {
  pthread_mutex_lock(t_resource.gen_mutex);
#ifndef PADDLE_TYPE_DOUBLE
  CHECK_EQ(dynload::curandGenerateUniform(t_resource.gen, dest_d, num),
           CURAND_STATUS_SUCCESS);
#else
  CHECK_EQ(dynload::curandGenerateUniformDouble(t_resource.gen, dest_d, num),
           CURAND_STATUS_SUCCESS);
#endif
  pthread_mutex_unlock(t_resource.gen_mutex);
  CHECK_SYNC("hl_rand failed");
}

void hl_srand(unsigned int seed) {
  pthread_mutex_lock(t_resource.gen_mutex);
622 623
  CHECK_EQ(dynload::curandSetPseudoRandomGeneratorSeed(t_resource.gen, seed),
           CURAND_STATUS_SUCCESS);
Z
zhangjinchao01 已提交
624 625 626
  pthread_mutex_unlock(t_resource.gen_mutex);
}

627
/* Set the calling thread's (thread-local) synchronization flag. */
void hl_set_sync_flag(bool flag) {
  g_sync_flag = flag;
}

/* Read the calling thread's (thread-local) synchronization flag. */
bool hl_get_sync_flag() {
  return g_sync_flag;
}

/* Block until all work queued on HPPL stream `stream` has completed. */
void hl_stream_synchronize(hl_stream_t stream) {
  CHECK_LT(stream, HPPL_STREAM_END) << __func__
                                    << ": the parameter stream is error.";

  const cudaStream_t target_stream = t_resource.stream[stream];
  CHECK_CUDA(dynload::cudaStreamSynchronize(target_stream));
}

/**
 * Allocate an hl_event_t wrapping a newly created CUDA event and store it
 * in *event. Release it with hl_destroy_event().
 */
void hl_create_event(hl_event_t *event) {
  CHECK_NOTNULL(event);

  struct _hl_event_st *st_event =
      (struct _hl_event_st *)malloc(sizeof(struct _hl_event_st));
  // Fail clearly instead of dereferencing a null allocation below.
  CHECK_NOTNULL(st_event);

  CHECK_CUDA(dynload::cudaEventCreate(&st_event->cu_event));

  *event = st_event;
}

/* Milliseconds elapsed between two recorded events (cudaEventElapsedTime). */
float hl_event_elapsed_time(hl_event_t start, hl_event_t end) {
  CHECK_NOTNULL(start);
  CHECK_NOTNULL(end);

  float elapsed_ms;
  CHECK_CUDA(dynload::cudaEventElapsedTime(
      &elapsed_ms, start->cu_event, end->cu_event));
  return elapsed_ms;
}

/* Record `event` on HPPL stream `stream` of the calling thread. */
void hl_stream_record_event(hl_stream_t stream, hl_event_t event) {
  CHECK_NOTNULL(event);
  CHECK_LT(stream, HPPL_STREAM_END) << __func__
                                    << ": the parameter stream is error.";

  const cudaStream_t recording_stream = t_resource.stream[stream];
  CHECK_CUDA(dynload::cudaEventRecord(event->cu_event, recording_stream));
}

/* Make HPPL stream `stream` wait on `event` before running later work. */
void hl_stream_wait_event(hl_stream_t stream, hl_event_t event) {
  CHECK_NOTNULL(event);
  CHECK_LT(stream, HPPL_STREAM_END) << __func__
                                    << ": the parameter stream is error.";

  const cudaStream_t waiting_stream = t_resource.stream[stream];
  CHECK_CUDA(dynload::cudaStreamWaitEvent(waiting_stream, event->cu_event, 0));
}

/**
 * Destroy the CUDA event and free the handle created by hl_create_event().
 * The caller's copy of `event` is not modified and must not be used again.
 */
void hl_destroy_event(hl_event_t event) {
  CHECK_NOTNULL(event);
  CHECK_CUDA(dynload::cudaEventDestroy(event->cu_event));

  // `event` is passed by value, so assigning NULL here (as before) had no
  // effect on the caller; just release the storage.
  free(event);
}

/* Block the host until `event` has completed. */
void hl_event_synchronize(hl_event_t event) {
  CHECK_NOTNULL(event);
  const cudaEvent_t cu_event = event->cu_event;
  CHECK_CUDA(dynload::cudaEventSynchronize(cu_event));
}

void hl_get_device_name(char *name, int len, int device) {
  CHECK_NOTNULL(name);
  CHECK(device >= 0 && device < g_system_device_num && g_device[device])
700
      << "Device(" << device << ") is not specified in startup.";
Z
zhangjinchao01 已提交
701

702
  strncpy(name, g_device[device]->device_name, len);
Z
zhangjinchao01 已提交
703 704 705 706 707
}

void hl_get_device_memory(size_t *mem_size, int device) {
  CHECK_NOTNULL(mem_size);
  CHECK(device >= 0 && device < g_system_device_num && g_device[device])
708
      << "Device(" << device << ") is not specified in startup.";
Z
zhangjinchao01 已提交
709 710 711 712 713 714 715 716

  *mem_size = g_device[device]->device_mem;
}

void hl_get_device_compute_capability(int *major, int *minor, int device) {
  CHECK_NOTNULL(major);
  CHECK_NOTNULL(minor);
  CHECK(device >= 0 && device < g_system_device_num && g_device[device])
717
      << "Device(" << device << ") is not specified in startup.";
Z
zhangjinchao01 已提交
718 719 720 721 722

  *major = g_device[device]->major;
  *minor = g_device[device]->minor;
}

723
int hl_get_device_last_error() { return (int)dynload::cudaGetLastError(); }
Z
zhangjinchao01 已提交
724

725
/* Fetch (and reset) the last CUDA runtime error and return its description. */
const char *hl_get_device_error_string() {
  return dynload::cudaGetErrorString(dynload::cudaGetLastError());
}

const char *hl_get_device_error_string(size_t err) {
Z
zhangjinchao01 已提交
731 732 733
  return dynload::cudaGetErrorString((cudaError_t)err);
}

734
/* Block until all work queued on the current device has finished. */
void hl_device_synchronize() {
  CHECK_CUDA(dynload::cudaDeviceSynchronize());
}

void hl_set_device_flags_block() {
736
  CHECK_CUDA(dynload::cudaSetDeviceFlags(cudaDeviceScheduleBlockingSync));
Z
zhangjinchao01 已提交
737 738
}

L
liaogang 已提交
739
/**
 * Non-blocking check of whether the work captured by `event` has completed.
 * Returns false while the event is pending (cudaErrorNotReady) and true once
 * it has completed. Any other status fails a CHECK and logs the CUDA error.
 */
bool hl_cuda_event_is_ready(hl_event_t event) {
  // Consistent with the other hl_*event* helpers.
  CHECK_NOTNULL(event);

  cudaError_t err = dynload::cudaEventQuery(event->cu_event);
  CHECK(cudaSuccess == err || cudaErrorNotReady == err)
      << "Cuda Error: " << dynload::cudaGetErrorString(err);

  return cudaErrorNotReady != err;
}

/* Begin CUDA profiler data collection (cudaProfilerStart). */
void hl_profiler_start() {
  CHECK_CUDA(dynload::cudaProfilerStart());
}

/* End CUDA profiler data collection (cudaProfilerStop). */
void hl_profiler_end() {
  CHECK_CUDA(dynload::cudaProfilerStop());
}