// curand.h: wrapper objects for calling cuRAND functions, optionally resolved
// at run time through dlsym when PADDLE_USE_DSO is defined.
#include <curand.h>
#include <dlfcn.h>
#include <mutex>
#include "paddle/platform/dynamic_loader.h"

namespace paddle {
namespace dyload {
// Flag passed to std::call_once by the PADDLE_USE_DSO variant of
// DYNAMIC_LOAD_CURAND_WRAP, so GetCurandDsoHandle runs at most once.
// NOTE(review): both variables are non-inline namespace-scope definitions; if
// this file is a header included from more than one translation unit that
// would presumably cause duplicate-symbol link errors — confirm, and consider
// `extern` declarations here with the definitions in a single .cc file.
std::once_flag curand_dso_flag;
// Passed by address to GetCurandDsoHandle and then used as the handle argument
// of dlsym; presumably filled in by GetCurandDsoHandle — verify in its source.
void *curand_dso_handle = nullptr;
// DYNAMIC_LOAD_CURAND_WRAP(__name) defines a functor type DynLoad__<name> and
// an object called <name> of that type, so that `<name>(args...)` inside this
// namespace goes through the wrapper instead of the cuRAND function directly.
//
// With PADDLE_USE_DSO defined the wrapper:
//   1. runs GetCurandDsoHandle(&curand_dso_handle) once (std::call_once),
//   2. resolves the symbol "<name>" with dlsym, once per instantiation
//      (function-local static, so later calls skip the lookup),
//   3. returns CURAND_STATUS_INITIALIZATION_FAILED if the symbol could not be
//      resolved, rather than calling through a null pointer,
//   4. otherwise calls the resolved function and returns its status.
// The pointer type is taken from the real declaration (decltype(&::<name>)),
// not from the deduced argument types, so arguments are converted to the
// parameter types cuRAND actually expects before the call.
//
// Without PADDLE_USE_DSO the wrapper simply forwards to the global function.
#ifdef PADDLE_USE_DSO
#define DYNAMIC_LOAD_CURAND_WRAP(__name)                                       \
  struct DynLoad__##__name {                                                   \
    template <typename... Args>                                                \
    curandStatus_t operator()(Args... args) {                                  \
      using curandFunc = decltype(&::__name);                                  \
      std::call_once(curand_dso_flag, GetCurandDsoHandle, &curand_dso_handle); \
      static void *p_##__name = dlsym(curand_dso_handle, #__name);             \
      if (p_##__name == nullptr) {                                             \
        return CURAND_STATUS_INITIALIZATION_FAILED;                            \
      }                                                                        \
      return reinterpret_cast<curandFunc>(p_##__name)(args...);                \
    }                                                                          \
  } __name; /* struct DynLoad__##__name */
#else
/* `::` pins the call to the global cuRAND function, never the wrapper object
   of the same name declared in this namespace. */
#define DYNAMIC_LOAD_CURAND_WRAP(__name)      \
  struct DynLoad__##__name {                  \
    template <typename... Args>               \
    curandStatus_t operator()(Args... args) { \
      return ::__name(args...);               \
    }                                         \
  } __name; /* struct DynLoad__##__name */
#endif

/* include all needed curand functions in HPPL */
// clang-format off
// X-macro list of the cuRAND entry points wrapped by this file: expands to
// __macro(name) once for each function name below, in the order listed.
#define CURAND_RAND_ROUTINE_EACH(__macro)    \
  __macro(curandCreateGenerator)             \
  __macro(curandSetStream)                   \
  __macro(curandSetPseudoRandomGeneratorSeed)\
  __macro(curandGenerateUniform)             \
  __macro(curandGenerateUniformDouble)       \
  __macro(curandDestroyGenerator)
// clang-format on

// Expand DYNAMIC_LOAD_CURAND_WRAP once per listed cuRAND function, defining a
// DynLoad__<name> struct and a same-named wrapper object for each of them.
CURAND_RAND_ROUTINE_EACH(DYNAMIC_LOAD_CURAND_WRAP)

// The helper macros are only needed for the expansion above; undefine them so
// they do not leak into files that include this one.
#undef CURAND_RAND_ROUTINE_EACH
#undef DYNAMIC_LOAD_CURAND_WRAP
}  // namespace dyload
}  // namespace paddle