/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17 18 19 20
#ifdef __GNUC__
#include <cxxabi.h>  // for __cxa_demangle
#endif               // __GNUC__

21 22 23 24 25 26 27
#if !defined(_WIN32)
#include <dlfcn.h>    // dladdr
#else                 // _WIN32
#define NOMINMAX      // msvc max/min macro conflict with std::min/max
#include <windows.h>  // GetModuleFileName
#endif

28 29 30 31 32 33 34 35
#ifdef PADDLE_WITH_CUDA
#include <cublas_v2.h>
#include <cudnn.h>
#include <curand.h>
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
#endif  // PADDLE_WITH_CUDA

36
#include <fstream>
Y
Yu Yang 已提交
37
#include <iomanip>
38
#include <iostream>
L
liaogang 已提交
39
#include <memory>
40 41 42
#include <sstream>
#include <stdexcept>
#include <string>
S
sneaxiy 已提交
43 44
#include <type_traits>
#include <utility>
45

46
#define GLOG_NO_ABBREVIATED_SEVERITIES  // msvc conflict logging with windows.h
47
#include "glog/logging.h"
48
#include "paddle/fluid/platform/cuda_error.pb.h"
49
#include "paddle/fluid/platform/errors.h"
Y
Yi Wang 已提交
50
#include "paddle/fluid/platform/macros.h"
D
dzhwinter 已提交
51
#include "paddle/fluid/platform/port.h"
52 53
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/to_string.h"
54

55
#ifdef PADDLE_WITH_CUDA
Y
Yi Wang 已提交
56 57 58
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/dynload/curand.h"
G
Guo Sheng 已提交
59
#include "paddle/fluid/platform/dynload/cusolver.h"
60
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
Y
Yi Wang 已提交
61
#include "paddle/fluid/platform/dynload/nccl.h"
Y
Yi Wang 已提交
62 63
#endif  // __APPLE__
#endif  // PADDLE_WITH_CUDA
64 65 66 67

namespace paddle {
namespace platform {

68 69
/** HELPER MACROS AND FUNCTIONS **/

Z
Zeng Jinle 已提交
70 71 72 73
// Exception specification for functions that are allowed to throw. The build
// may predefine PADDLE_MAY_THROW to override this default.
#ifndef PADDLE_MAY_THROW
#define PADDLE_MAY_THROW noexcept(false)
#endif

// Branch-prediction hints. Most enforce conditions evaluate to true, so the
// failure path is marked unlikely and the compiler can lay out the success
// path as the fall-through. __builtin_expect is a GCC/Clang extension (it is
// not part of standard C++). This includes MinGW on Windows. Other compilers
// (e.g. MSVC, which has no equivalent intrinsic) get a plain conversion to
// bool instead.
// For more details, please check https://stackoverflow.com/a/43870188/724872.
#if defined(__GNUC__) || defined(__clang__)
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#define LIKELY(condition) __builtin_expect(static_cast<bool>(condition), 1)
#else
#define UNLIKELY(condition) (static_cast<bool>(condition))
#define LIKELY(condition) (static_cast<bool>(condition))
#endif

93 94 95 96 97 98 99 100 101 102 103 104 105
// HANDLE_THE_ERROR / END_HANDLE_THE_ERROR bracket a throw statement and must
// always be used as a pair.
// - When _WIN32, PADDLE_ON_INFERENCE and PADDLE_NO_PYTHON are all defined,
//   they open and close a try block. That block prints the exception's
//   what() text to std::cout and then rethrows it unchanged. NOTE(review):
//   presumably this makes the message visible when nothing upstream reports
//   it; confirm intent.
// - In every other configuration both expand to nothing.
#if defined _WIN32 && defined PADDLE_ON_INFERENCE && defined PADDLE_NO_PYTHON
#define HANDLE_THE_ERROR try {
#define END_HANDLE_THE_ERROR            \
  }                                     \
  catch (const std::exception& e) {     \
    std::cout << e.what() << std::endl; \
    throw;                              \
  }
#else
#define HANDLE_THE_ERROR
#define END_HANDLE_THE_ERROR
#endif

L
liaogang 已提交
106 107 108 109 110 111 112 113 114 115 116
#ifdef __GNUC__
// Converts a compiler-mangled name (e.g. from typeid(...).name() or dladdr)
// into its human-readable form using the C++ ABI demangler. Returns `name`
// unchanged when it cannot be demangled.
inline std::string demangle(std::string name) {
  int status = -4;  // overwritten by __cxa_demangle; set to silence warnings
  // __cxa_demangle returns a malloc'ed buffer (or null on failure) that must
  // be released with free().
  std::unique_ptr<char, void (*)(void*)> res{
      abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free};
  if (status == 0 && res) {
    return std::string(res.get());
  }
  return name;  // by-value parameter: moved out, not copied
}
#else
// Without the GNU ABI demangler the name is returned unchanged.
inline std::string demangle(std::string name) { return name; }
#endif

117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192
namespace details {
// Compile-time wrapper around std::is_arithmetic<T>::value.
template <typename T>
inline constexpr bool IsArithmetic() {
  return std::is_arithmetic<T>::value;
}

// Chooses the types to which the two operands of a PADDLE_ENFORCE_*
// comparison are cast before comparing. This primary template is selected
// when both T1 and T2 are arithmetic. Both sides are then converted to
// std::common_type<T1, T2>, so an int compared with a double is compared as
// double. Mixed signed/unsigned operands follow the usual arithmetic
// conversions: int and unsigned compare as unsigned.
template <typename T1, typename T2, bool kIsArithmetic /* = true */>
struct TypeConverterImpl {
  using Type1 = typename std::common_type<T1, T2>::type;
  using Type2 = Type1;
};

// If either operand is non-arithmetic, both keep their own types. The
// comparison then uses whatever operator those types provide.
template <typename T1, typename T2>
struct TypeConverterImpl<T1, T2, false> {
  using Type1 = T1;
  using Type2 = T2;
};

// Selects the TypeConverterImpl specialization depending on whether both
// T1 and T2 are arithmetic.
template <typename T1, typename T2>
struct TypeConverter {
 private:
  static constexpr bool kIsArithmetic =
      IsArithmetic<T1>() && IsArithmetic<T2>();

 public:
  using Type1 = typename TypeConverterImpl<T1, T2, kIsArithmetic>::Type1;
  using Type2 = typename TypeConverterImpl<T1, T2, kIsArithmetic>::Type2;
};

// `const TypeConverter<T1, T2>::Type1&`: the type the left-hand operand is
// static_cast to in __PADDLE_BINARY_COMPARE.
template <typename T1, typename T2>
using CommonType1 = typename std::add_lvalue_reference<
    typename std::add_const<typename TypeConverter<T1, T2>::Type1>::type>::type;

// `const TypeConverter<T1, T2>::Type2&`: the type the right-hand operand is
// static_cast to in __PADDLE_BINARY_COMPARE.
template <typename T1, typename T2>
using CommonType2 = typename std::add_lvalue_reference<
    typename std::add_const<typename TypeConverter<T1, T2>::Type2>::type>::type;

// Here, we use SFINAE to check whether a value of type T can be written with
// `std::cout << value`. Check<U> with the YesType return is viable only when
// that expression is well-formed; otherwise the variadic NoType overload is
// chosen. kValue reports which overload won. NOTE(review): presumably this is
// the operation string::to_string relies on to convert T to std::string; it
// is not verified here.
template <typename T>
struct CanToString {
 private:
  using YesType = uint8_t;
  using NoType = uint16_t;

  template <typename U>
  static YesType Check(decltype(std::cout << std::declval<U>())) {
    return 0;
  }

  template <typename U>
  static NoType Check(...) {
    return 0;
  }

 public:
  static constexpr bool kValue =
      std::is_same<YesType, decltype(Check<T>(std::cout))>::value;
};

// Formats one operand of a failed comparison for the error hint as
// "<expression>:<value>". This primary template is used when the operand
// types are streamable (see CanToString).
template <bool kCanToString /* = true */>
struct BinaryCompareMessageConverter {
  template <typename T>
  static std::string Convert(const char* expression, const T& value) {
    return expression + std::string(":") + string::to_string(value);
  }
};

// Non-streamable operands: only the expression text is reported and `value`
// is ignored.
template <>
struct BinaryCompareMessageConverter<false> {
  template <typename T>
  static const char* Convert(const char* expression, const T& value) {
    return expression;
  }
};
}  // namespace details

193 194 195 196 197 198
template <typename StrType>
inline std::string GetTraceBackString(StrType&& what, const char* file,
                                      int line) {
  static constexpr int TRACE_STACK_LIMIT = 100;
  std::ostringstream sout;

199 200 201
  sout << "\n\n--------------------------------------------\n";
  sout << "C++ Call Stacks (More useful to developers):";
  sout << "\n--------------------------------------------\n";
202 203 204 205 206
#if !defined(_WIN32)
  void* call_stack[TRACE_STACK_LIMIT];
  auto size = backtrace(call_stack, TRACE_STACK_LIMIT);
  auto symbols = backtrace_symbols(call_stack, size);
  Dl_info info;
207
  int idx = 0;
208 209 210
  for (int i = 0; i < size; ++i) {
    if (dladdr(call_stack[i], &info) && info.dli_sname) {
      auto demangled = demangle(info.dli_sname);
211 212 213 214 215
      std::string path(info.dli_fname);
      // C++ traceback info are from core.so
      if (path.substr(path.length() - 3).compare(".so") == 0) {
        sout << string::Sprintf("%-3d %s\n", idx++, demangled);
      }
216 217 218 219
    }
  }
  free(symbols);
#else
220
  sout << "Windows not support stack backtrace yet.\n";
221
#endif
222 223
  sout << "\n----------------------\nError Message "
          "Summary:\n----------------------\n";
224 225
  sout << string::Sprintf("%s at (%s:%d)", std::forward<StrType>(what), file,
                          line)
226
       << std::endl;
227 228 229
  return sout.str();
}

230 231 232 233 234 235 236 237 238 239
// Status check used by PADDLE_ENFORCE for plain bool conditions: the
// condition signals an error exactly when it is false.
inline bool is_error(bool stat) { return stat == false; }

// Reports a failed bool condition. `stat` is not inspected; this overload
// exists so PADDLE_ENFORCE can dispatch on the condition's type. Logs `msg`
// at FATAL severity when REPLACE_ENFORCE_GLOG is defined, otherwise throws
// std::runtime_error(msg).
inline void throw_on_error(bool stat, const std::string& msg) {
#if defined(REPLACE_ENFORCE_GLOG)
  LOG(FATAL) << msg;
#else
  throw std::runtime_error(msg);
#endif
}

240 241 242
// Note: This Macro can only be used within enforce.h
// Throws EnforceNotMet carrying a printf-style message formatted by
// ::paddle::string::Sprintf(__VA_ARGS__), plus the call site's
// __FILE__/__LINE__.
#define __THROW_ERROR_INTERNAL__(...)                                \
  do {                                                               \
    HANDLE_THE_ERROR                                                 \
    throw ::paddle::platform::EnforceNotMet(                         \
        ::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \
    END_HANDLE_THE_ERROR                                             \
  } while (0)

249 250
/** ENFORCE EXCEPTION AND MACROS **/

251
struct EnforceNotMet : public std::exception {
252
  EnforceNotMet(std::exception_ptr e, const char* file, int line) {
253
    try {
Y
Yu Yang 已提交
254 255
      std::rethrow_exception(e);
    } catch (std::exception& e) {
256
      err_str_ = GetTraceBackString(e.what(), file, line);
Y
Yu Yang 已提交
257 258
    }
  }
259

260
  EnforceNotMet(const std::string& str, const char* file, int line)
261
      : err_str_(GetTraceBackString(str, file, line)) {}
Y
Yu Yang 已提交
262

263 264 265
  EnforceNotMet(const platform::ErrorSummary& error, const char* file, int line)
      : err_str_(GetTraceBackString(error.ToString(), file, line)) {}

Y
Yu Yang 已提交
266
  const char* what() const noexcept override { return err_str_.c_str(); }
267 268

  std::string err_str_;
269 270
};

271 272
// Unconditionally throws EnforceNotMet. The arguments are forwarded to
// ::paddle::platform::ErrorSummary's constructor, and the call site's
// __FILE__/__LINE__ are recorded.
#define PADDLE_THROW(...)                                                   \
  do {                                                                      \
    HANDLE_THE_ERROR                                                        \
    throw ::paddle::platform::EnforceNotMet(                                \
        ::paddle::platform::ErrorSummary(__VA_ARGS__), __FILE__, __LINE__); \
    END_HANDLE_THE_ERROR                                                    \
  } while (0)

279 280 281 282
#if defined(__CUDA_ARCH__)
// For cuda, the assertions can affect performance and it is therefore
// recommended to disable them in production code
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#assertion
//
// Device-code variant. If _IS_NOT_ERROR is false, it prints the file, line,
// stringified condition and the printf-style __FORMAT message with its
// arguments. It then executes the `trap` instruction to abort the kernel.
#define PADDLE_ENFORCE(_IS_NOT_ERROR, __FORMAT, ...)                         \
  do {                                                                       \
    if (!(_IS_NOT_ERROR)) {                                                  \
      printf("Error: %s:%d Assertion `%s` failed. " __FORMAT "\n", __FILE__, \
             __LINE__, #_IS_NOT_ERROR, ##__VA_ARGS__);                       \
      asm("trap;");                                                          \
    }                                                                        \
  } while (0)
#else
// Host-code variant. It evaluates COND exactly once. When
// ::paddle::platform::is_error(COND) is true:
// 1. The throw_on_error overload for COND's type raises an exception carrying
//    ErrorSummary(__VA_ARGS__). With REPLACE_ENFORCE_GLOG defined it logs
//    fatally instead.
// 2. That exception is caught immediately and rethrown as EnforceNotMet,
//    which adds the call stack and this call site's __FILE__/__LINE__.
#define PADDLE_ENFORCE(COND, ...)                                         \
  do {                                                                    \
    auto __cond__ = (COND);                                               \
    if (UNLIKELY(::paddle::platform::is_error(__cond__))) {               \
      try {                                                               \
        ::paddle::platform::throw_on_error(                               \
            __cond__,                                                     \
            ::paddle::platform::ErrorSummary(__VA_ARGS__).ToString());    \
      } catch (...) {                                                     \
        HANDLE_THE_ERROR                                                  \
        throw ::paddle::platform::EnforceNotMet(std::current_exception(), \
                                                __FILE__, __LINE__);      \
        END_HANDLE_THE_ERROR                                              \
      }                                                                   \
    }                                                                     \
  } while (0)
#endif

/*
 * Some enforce helpers here, usage:
 *    int a = 1;
 *    int b = 2;
 *    PADDLE_ENFORCE_EQ(a, b);
 *
 *    will raise an exception whose message contains a hint like
 *    "[Hint: Expected a == b, but received a:1 != b:2.]",
 *      with detailed stack information.
 *
 *    extra messages are also supported, for example:
 *    PADDLE_ENFORCE_EQ(a, b, "some simple enforce failed between %d numbers", 2)
 */

324 325 326
// Throws EnforceNotMet if __VAL compares equal to nullptr. The message is
// ErrorSummary(__VA_ARGS__) followed by a hint naming the null expression.
#define PADDLE_ENFORCE_NOT_NULL(__VAL, ...)                          \
  do {                                                               \
    if (UNLIKELY(nullptr == (__VAL))) {                              \
      __THROW_ERROR_INTERNAL__(                                      \
          "%s\n  [Hint: " #__VAL " should not be null.]",            \
          ::paddle::platform::ErrorSummary(__VA_ARGS__).ToString()); \
    }                                                                \
  } while (0)

// Core of PADDLE_ENFORCE_EQ/NE/GT/GE/LT/LE.
// - Each operand is evaluated exactly once.
// - Both are cast to details::CommonType1/CommonType2 (their common arithmetic
//   type when both are arithmetic) and then compared with __CMP.
// - On failure it throws EnforceNotMet. The message is
//   ErrorSummary(__VA_ARGS__) followed by a hint such as
//     [Hint: Expected a == b, but received a:1 != b:2.]
// - Operand values appear in the hint only when both operand types are
//   streamable (details::CanToString); otherwise only the expression text is
//   shown.
#define __PADDLE_BINARY_COMPARE(__VAL1, __VAL2, __CMP, __INV_CMP, ...)         \
  do {                                                                         \
    auto __val1 = (__VAL1);                                                    \
    auto __val2 = (__VAL2);                                                    \
    using __TYPE1__ = decltype(__val1);                                        \
    using __TYPE2__ = decltype(__val2);                                        \
    using __COMMON_TYPE1__ =                                                   \
        ::paddle::platform::details::CommonType1<__TYPE1__, __TYPE2__>;        \
    using __COMMON_TYPE2__ =                                                   \
        ::paddle::platform::details::CommonType2<__TYPE1__, __TYPE2__>;        \
    bool __is_not_error = (static_cast<__COMMON_TYPE1__>(__val1))__CMP(        \
        static_cast<__COMMON_TYPE2__>(__val2));                                \
    if (UNLIKELY(!__is_not_error)) {                                           \
      constexpr bool __kCanToString__ =                                        \
          ::paddle::platform::details::CanToString<__TYPE1__>::kValue &&       \
          ::paddle::platform::details::CanToString<__TYPE2__>::kValue;         \
      __THROW_ERROR_INTERNAL__(                                                \
          "%s\n  [Hint: Expected %s " #__CMP                                   \
          " %s, but received %s " #__INV_CMP " %s.]",                          \
          ::paddle::platform::ErrorSummary(__VA_ARGS__).ToString(), #__VAL1,   \
          #__VAL2, ::paddle::platform::details::BinaryCompareMessageConverter< \
                       __kCanToString__>::Convert(#__VAL1, __val1),            \
          ::paddle::platform::details::BinaryCompareMessageConverter<          \
              __kCanToString__>::Convert(#__VAL2, __val2));                    \
    }                                                                          \
  } while (0)

// PADDLE_ENFORCE_<OP>(a, b, ...) throws EnforceNotMet unless `a <OP> b`.
// The trailing arguments build the ErrorSummary.
#define PADDLE_ENFORCE_EQ(__VAL0, __VAL1, ...) \
  __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, ==, !=, __VA_ARGS__)
#define PADDLE_ENFORCE_NE(__VAL0, __VAL1, ...) \
  __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, !=, ==, __VA_ARGS__)
#define PADDLE_ENFORCE_GT(__VAL0, __VAL1, ...) \
  __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, >, <=, __VA_ARGS__)
#define PADDLE_ENFORCE_GE(__VAL0, __VAL1, ...) \
  __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, >=, <, __VA_ARGS__)
#define PADDLE_ENFORCE_LT(__VAL0, __VAL1, ...) \
  __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <, >=, __VA_ARGS__)
#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \
  __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__)

373 374
/** EXTENDED TOOL FUNCTIONS WITH CHECKING **/

375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399
/*
 * Summary: This macro is used to get Variable or internal type
 *   data (such as LoDTensor or SelectedRows) of the Input and
 *   Output in op, generally used when call scope.FindVar(Input/
 *   Output("Name")) or ctx.Input<LoDTensor>().
 *   Firstly this macro check whether the obtained pointer is null,
 *   and then return data if it is not null.
 *
 * Note: This macro is only suitable for specific scenarios and
 *   is not intended to be widely used. If it cannot meet the
 *   requirements, please use other PADDLE_ENFORCE** check macro.
 *
 * Parameters:
 *     __PTR: pointer (evaluated exactly once)
 *     __ROLE: (string), Input or Output
 *     __NAME: (string), Input or Output name
 *     __OP_TYPE: (string), the op type
 *
 * Return: An lvalue reference to the data pointed to by the pointer.
 *
 * Throws: EnforceNotMet with an errors::NotFound summary if the pointer
 *   is null.
 *
 * Examples:
 *    GET_DATA_SAFELY(ctx.Input<LoDTensor>("X"), "Input", "X", "Mul");
 */
#define GET_DATA_SAFELY(__PTR, __ROLE, __NAME, __OP_TYPE)                   \
  (([&]() -> std::add_lvalue_reference<decltype(*(__PTR))>::type {          \
    auto* __ptr = (__PTR);                                                  \
    if (UNLIKELY(nullptr == __ptr)) {                                       \
      __THROW_ERROR_INTERNAL__(                                             \
          "%s\n  [Hint: pointer " #__PTR " should not be null.]",           \
          ::paddle::platform::errors::NotFound(                             \
              "Unable to get %s data of %s %s in operator %s. "             \
              "Possible reasons are:\n"                                     \
              "  1. The %s is not the %s of operator %s;\n"                 \
              "  2. The %s has no corresponding variable passed in;\n"      \
              "  3. The %s corresponding variable is not initialized.",     \
              ::paddle::platform::demangle(                                 \
                  typeid(std::add_lvalue_reference<decltype(*__ptr)>::type) \
                      .name()),                                             \
              __ROLE, __NAME, __OP_TYPE, __NAME, __ROLE, __OP_TYPE, __NAME, \
              __NAME)                                                       \
              .ToString());                                                 \
    }                                                                       \
    return *__ptr;                                                          \
  })())

420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440
/*
 * Summary: This macro is used to check whether op has specified
 * Input or Output Variables. Because op's Input and Output
 * checking are written similarly, so abstract this macro.
 * On failure it throws EnforceNotMet with an errors::NotFound summary
 * "No <role>(<name>) found for <op_type> operator.".
 *
 * Parameters:
 *     __EXPR: (bool), the bool expression
 *     __ROLE: (string), Input or Output
 *     __NAME: (string), Input or Output name
 *     __OP_TYPE: (string), the op type
 *
 * Examples:
 *    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Mul");
 */
#define OP_INOUT_CHECK(__EXPR, __ROLE, __NAME, __OP_TYPE)                   \
  do {                                                                      \
    PADDLE_ENFORCE_EQ(__EXPR, true, ::paddle::platform::errors::NotFound(   \
                                        "No %s(%s) found for %s operator.", \
                                        __ROLE, __NAME, __OP_TYPE));        \
  } while (0)

441 442
/** OTHER EXCEPTION AND ENFORCE **/

443 444
struct EOFException : public std::exception {
  std::string err_str_;
445 446
  EOFException(const char* err_msg, const char* file, int line) {
    err_str_ = string::Sprintf("%s at [%s:%d]", err_msg, file, line);
447 448
  }

449
  const char* what() const noexcept override { return err_str_.c_str(); }
450 451
};

452 453
// Throws EOFException("There is no next data.") tagged with the call site's
// __FILE__/__LINE__.
#define PADDLE_THROW_EOF()                                                     \
  do {                                                                         \
    HANDLE_THE_ERROR                                                           \
    throw ::paddle::platform::EOFException("There is no next data.", __FILE__, \
                                           __LINE__);                          \
    END_HANDLE_THE_ERROR                                                       \
  } while (0)
459

460 461
// Throws ::paddle::memory::allocation::BadAlloc with the text of
// ErrorSummary(__VA_ARGS__) and the call site's __FILE__/__LINE__.
// NOTE(review): BadAlloc is not declared in this header; the expanding
// translation unit must include its definition.
#define PADDLE_THROW_BAD_ALLOC(...)                                         \
  do {                                                                      \
    HANDLE_THE_ERROR                                                        \
    throw ::paddle::memory::allocation::BadAlloc(                           \
        ::paddle::platform::ErrorSummary(__VA_ARGS__).ToString(), __FILE__, \
        __LINE__);                                                          \
    END_HANDLE_THE_ERROR                                                    \
  } while (0)
M
minqiyang 已提交
468

469
/** CUDA PADDLE ENFORCE FUNCTIONS AND MACROS **/
470
#ifdef PADDLE_WITH_CUDA
471

472
/***** CUDA ERROR *****/
S
sneaxiy 已提交
473
// A CUDA runtime call failed exactly when its result is not cudaSuccess.
inline bool is_error(cudaError_t e) { return !(e == cudaSuccess); }
M
minqiyang 已提交
474

475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560
// Returns the URL of NVIDIA's CUDA runtime API page describing cudaError_t.
// A positive `cuda_version` (major * 10 + minor, e.g. 90 -> 9.0, 102 -> 10.2)
// selects the archived documentation for that release. A non-positive value
// (callers pass -1 for "unknown") links to the current documentation.
inline std::string GetCudaErrorWebsite(int32_t cuda_version) {
  std::ostringstream webstr;
  webstr << "https://docs.nvidia.com/cuda/";
  if (cuda_version > 0) {
    // Floating-point division keeps the minor version (102 -> 10.2); integer
    // division would silently print 10.0.
    const double version = static_cast<double>(cuda_version) / 10;
    webstr << "archive/" << std::fixed << std::setprecision(1) << version
           << "/";
  }
  webstr << "cuda-runtime-api/group__CUDART__TYPES.html"
            "#group__CUDART__TYPES_1g3f51e3575c2178246db0a94a430e0038";
  return webstr.str();
}

// Builds the message for a failed CUDA runtime call:
//   " Cuda error(<code>), <cudaGetErrorString text>." + "[Advise: ...]" hint.
//
// The hint is looked up in cudaErrorMessage.pb, a serialized
// platform::proto::cudaerrorDesc. The file is located relative to the module
// (shared library / .pyd) that contains this code. If the file cannot be
// loaded, or has no entry for this error code and CUDA version, the hint
// points to NVIDIA's online documentation (GetCudaErrorWebsite) instead.
inline std::string build_nvidia_error_msg(cudaError_t e) {
#if CUDA_VERSION >= 10000 && CUDA_VERSION < 11000
  int32_t cuda_version = 100;
#elif CUDA_VERSION >= 9000
  int32_t cuda_version = 90;
#else
  int32_t cuda_version = -1;
#endif
  // Suffix test that cannot underflow when `str` is shorter than `suffix`.
  // substr(length() - n) would throw std::out_of_range in that case.
  auto ends_with = [](const std::string& str, const std::string& suffix) {
    return str.size() >= suffix.size() &&
           str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
  };
  std::ostringstream sout;
  sout << " Cuda error(" << e << "), " << cudaGetErrorString(e) << ".";
  // Loaded lazily on first use. An empty message means "not loaded yet" or
  // "previous load failed", so loading is retried on the next error.
  // NOTE(review): these statics are read and written without synchronization.
  static platform::proto::cudaerrorDesc cudaerror;
  static bool _initSucceed = false;
  if (cudaerror.ByteSizeLong() == 0) {
    std::string filePath;
#if !defined(_WIN32)
    Dl_info info;
    // Locate the module containing this code through a function defined in
    // this header.
    if (dladdr(reinterpret_cast<void*>(GetCudaErrorWebsite), &info) &&
        info.dli_fname != nullptr) {
      std::string strModule(info.dli_fname);
      const bool is_avx_module = ends_with(strModule, "avx.so");
      const size_t last_slash_idx = strModule.find_last_of("/");
      if (std::string::npos != last_slash_idx) {
        strModule.erase(last_slash_idx, std::string::npos);
      }
      if (is_avx_module) {
        filePath = strModule +
                   "/../include/third_party/cudaerror/data/cudaErrorMessage.pb";
      } else {
        filePath =
            strModule + "/../../third_party/cudaerror/data/cudaErrorMessage.pb";
      }
    }
#else
    char buf[MAX_PATH] = {0};
    MEMORY_BASIC_INFORMATION mbi;
    HMODULE h_module =
        (::VirtualQuery(GetCudaErrorWebsite, &mbi, sizeof(mbi)) != 0)
            ? (HMODULE)mbi.AllocationBase
            : NULL;
    const DWORD buf_len = GetModuleFileName(h_module, buf, MAX_PATH);
    // A return of 0 signals failure and MAX_PATH signals a truncated path. In
    // both cases filePath stays empty and the website fallback is used.
    if (buf_len > 0 && buf_len < MAX_PATH) {
      std::string strModule(buf, buf_len);
      const bool is_avx_module = ends_with(strModule, "avx.pyd");
      const size_t last_slash_idx = strModule.find_last_of("\\");
      if (std::string::npos != last_slash_idx) {
        strModule.erase(last_slash_idx, std::string::npos);
      }
      if (is_avx_module) {
        filePath =
            strModule +
            "\\..\\include\\third_party\\cudaerror\\data\\cudaErrorMessage.pb";
      } else {
        filePath = strModule +
                   "\\..\\third_party\\cudaerror\\data\\cudaErrorMessage.pb";
      }
    }
#endif
    std::ifstream fin(filePath, std::ios::in | std::ios::binary);
    _initSucceed = fin.is_open() && cudaerror.ParseFromIstream(&fin);
  }
  if (_initSucceed) {
    for (int i = 0; i < cudaerror.allmessages_size(); ++i) {
      const auto& version_entry = cudaerror.allmessages(i);
      if (cuda_version != version_entry.version()) {
        continue;
      }
      for (int j = 0; j < version_entry.messages_size(); ++j) {
        const auto& message = version_entry.messages(j);
        if (e == message.errorcode()) {
          sout << "\n  [Advise: " << message.errormessage() << "]";
          return sout.str();
        }
      }
    }
  }
  sout << "\n  [Advise: Please search for the error code(" << e
       << ") on website( " << GetCudaErrorWebsite(cuda_version)
       << " ) to get Nvidia's official solution about CUDA Error.]";
  return sout.str();
}

S
sneaxiy 已提交
563
// Reports a failed CUDA runtime call. Throws std::runtime_error(msg), or logs
// msg at FATAL severity when REPLACE_ENFORCE_GLOG is defined. `e` is not
// inspected; it only selects this overload.
inline void throw_on_error(cudaError_t e, const std::string& msg) {
#ifndef REPLACE_ENFORCE_GLOG
  throw std::runtime_error(msg);
#else
  LOG(FATAL) << msg;
#endif
}

571
/** curand ERROR **/
M
minqiyang 已提交
572 573
// A cuRAND call failed exactly when its status is not CURAND_STATUS_SUCCESS.
inline bool is_error(curandStatus_t stat) {
  return stat != CURAND_STATUS_SUCCESS;
}

576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611
// Maps a cuRAND status code to the spelling of its enumerator, e.g.
// CURAND_STATUS_SUCCESS -> "CURAND_STATUS_SUCCESS". Codes not listed here
// yield "Unknown curand status".
inline const char* curandGetErrorString(curandStatus_t stat) {
// Stringizing the case label keeps each returned text identical to the
// enumerator it names.
#define PADDLE_CURAND_STATUS_CASE(status) \
  case status:                            \
    return #status
  switch (stat) {
    PADDLE_CURAND_STATUS_CASE(CURAND_STATUS_SUCCESS);
    PADDLE_CURAND_STATUS_CASE(CURAND_STATUS_VERSION_MISMATCH);
    PADDLE_CURAND_STATUS_CASE(CURAND_STATUS_NOT_INITIALIZED);
    PADDLE_CURAND_STATUS_CASE(CURAND_STATUS_ALLOCATION_FAILED);
    PADDLE_CURAND_STATUS_CASE(CURAND_STATUS_TYPE_ERROR);
    PADDLE_CURAND_STATUS_CASE(CURAND_STATUS_OUT_OF_RANGE);
    PADDLE_CURAND_STATUS_CASE(CURAND_STATUS_LENGTH_NOT_MULTIPLE);
    PADDLE_CURAND_STATUS_CASE(CURAND_STATUS_DOUBLE_PRECISION_REQUIRED);
    PADDLE_CURAND_STATUS_CASE(CURAND_STATUS_LAUNCH_FAILURE);
    PADDLE_CURAND_STATUS_CASE(CURAND_STATUS_PREEXISTING_FAILURE);
    PADDLE_CURAND_STATUS_CASE(CURAND_STATUS_INITIALIZATION_FAILED);
    PADDLE_CURAND_STATUS_CASE(CURAND_STATUS_ARCH_MISMATCH);
    PADDLE_CURAND_STATUS_CASE(CURAND_STATUS_INTERNAL_ERROR);
    default:
      return "Unknown curand status";
  }
#undef PADDLE_CURAND_STATUS_CASE
}

// Message fragment for a failed cuRAND call, e.g.
// " Curand error, CURAND_STATUS_ALLOCATION_FAILED ".
inline std::string build_nvidia_error_msg(curandStatus_t stat) {
  std::string msg(" Curand error, ");
  return msg + curandGetErrorString(stat) + " ";
}

S
sneaxiy 已提交
614
// Reports a failed cuRAND call. Unlike the other overloads it throws
// thrust::system_error, with code cudaErrorLaunchFailure in thrust's CUDA
// error category and `msg` as the text. When REPLACE_ENFORCE_GLOG is defined
// it logs msg at FATAL severity instead. `stat` is not inspected.
inline void throw_on_error(curandStatus_t stat, const std::string& msg) {
#ifndef REPLACE_ENFORCE_GLOG
  throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(),
                             msg);
#else
  LOG(FATAL) << msg;
#endif
}

623
/***** CUDNN ERROR *****/
M
minqiyang 已提交
624 625
// A cuDNN call failed exactly when its status is not CUDNN_STATUS_SUCCESS.
inline bool is_error(cudnnStatus_t stat) {
  return stat != CUDNN_STATUS_SUCCESS;
}

628 629 630
// Message fragment for a failed cuDNN call: " Cudnn error, <text> ". The text
// comes from the dynamically loaded cudnnGetErrorString.
inline std::string build_nvidia_error_msg(cudnnStatus_t stat) {
  std::string msg(" Cudnn error, ");
  return msg + platform::dynload::cudnnGetErrorString(stat) + " ";
}

S
sneaxiy 已提交
633
// Reports a failed cuDNN call. Throws std::runtime_error(msg), or logs msg at
// FATAL severity when REPLACE_ENFORCE_GLOG is defined. `stat` only selects
// this overload.
inline void throw_on_error(cudnnStatus_t stat, const std::string& msg) {
#ifndef REPLACE_ENFORCE_GLOG
  throw std::runtime_error(msg);
#else
  LOG(FATAL) << msg;
#endif
}

641
/***** CUBLAS ERROR *****/
M
minqiyang 已提交
642 643
// A cuBLAS call failed exactly when its status is not CUBLAS_STATUS_SUCCESS.
inline bool is_error(cublasStatus_t stat) {
  return stat != CUBLAS_STATUS_SUCCESS;
}

646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667
// Maps a cuBLAS status code to the spelling of its enumerator. Codes not
// listed here yield "Unknown cublas status". CUBLAS_STATUS_SUCCESS is named
// explicitly, consistent with curandGetErrorString.
inline const char* cublasGetErrorString(cublasStatus_t stat) {
  switch (stat) {
    case CUBLAS_STATUS_SUCCESS:
      return "CUBLAS_STATUS_SUCCESS";
    case CUBLAS_STATUS_NOT_INITIALIZED:
      return "CUBLAS_STATUS_NOT_INITIALIZED";
    case CUBLAS_STATUS_ALLOC_FAILED:
      return "CUBLAS_STATUS_ALLOC_FAILED";
    case CUBLAS_STATUS_INVALID_VALUE:
      return "CUBLAS_STATUS_INVALID_VALUE";
    case CUBLAS_STATUS_ARCH_MISMATCH:
      return "CUBLAS_STATUS_ARCH_MISMATCH";
    case CUBLAS_STATUS_MAPPING_ERROR:
      return "CUBLAS_STATUS_MAPPING_ERROR";
    case CUBLAS_STATUS_EXECUTION_FAILED:
      return "CUBLAS_STATUS_EXECUTION_FAILED";
    case CUBLAS_STATUS_INTERNAL_ERROR:
      return "CUBLAS_STATUS_INTERNAL_ERROR";
    case CUBLAS_STATUS_NOT_SUPPORTED:
      return "CUBLAS_STATUS_NOT_SUPPORTED";
    case CUBLAS_STATUS_LICENSE_ERROR:
      return "CUBLAS_STATUS_LICENSE_ERROR";
    default:
      return "Unknown cublas status";
  }
}

// Message fragment for a failed cuBLAS call, e.g.
// " Cublas error, CUBLAS_STATUS_ALLOC_FAILED ".
inline std::string build_nvidia_error_msg(cublasStatus_t stat) {
  std::string msg(" Cublas error, ");
  return msg + cublasGetErrorString(stat) + " ";
}

// Reports a failed cuBLAS call. Throws std::runtime_error(msg), or logs msg at
// FATAL severity when REPLACE_ENFORCE_GLOG is defined. `stat` only selects
// this overload.
inline void throw_on_error(cublasStatus_t stat, const std::string& msg) {
#ifndef REPLACE_ENFORCE_GLOG
  throw std::runtime_error(msg);
#else
  LOG(FATAL) << msg;
#endif
}

G
Guo Sheng 已提交
684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721
/***** CUSOLVER ERROR *****/
// A cuSOLVER call failed exactly when its status is not
// CUSOLVER_STATUS_SUCCESS.
inline bool is_error(cusolverStatus_t stat) {
  const bool succeeded = (stat == CUSOLVER_STATUS_SUCCESS);
  return !succeeded;
}

// Maps a cuSOLVER status code to the spelling of its enumerator. Codes not
// listed here yield "Unknown cusolver status". CUSOLVER_STATUS_SUCCESS is
// named explicitly, consistent with curandGetErrorString.
inline const char* cusolverGetErrorString(cusolverStatus_t stat) {
  switch (stat) {
    case CUSOLVER_STATUS_SUCCESS:
      return "CUSOLVER_STATUS_SUCCESS";
    case CUSOLVER_STATUS_NOT_INITIALIZED:
      return "CUSOLVER_STATUS_NOT_INITIALIZED";
    case CUSOLVER_STATUS_ALLOC_FAILED:
      return "CUSOLVER_STATUS_ALLOC_FAILED";
    case CUSOLVER_STATUS_INVALID_VALUE:
      return "CUSOLVER_STATUS_INVALID_VALUE";
    case CUSOLVER_STATUS_ARCH_MISMATCH:
      return "CUSOLVER_STATUS_ARCH_MISMATCH";
    case CUSOLVER_STATUS_EXECUTION_FAILED:
      return "CUSOLVER_STATUS_EXECUTION_FAILED";
    case CUSOLVER_STATUS_INTERNAL_ERROR:
      return "CUSOLVER_STATUS_INTERNAL_ERROR";
    case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED:
      return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED";
    default:
      return "Unknown cusolver status";
  }
}
// Message fragment for a failed cuSOLVER call, e.g.
// " Cusolver error, CUSOLVER_STATUS_ALLOC_FAILED ". The label was previously
// copy-pasted as "Cublas", which misattributed cuSOLVER failures.
inline std::string build_nvidia_error_msg(cusolverStatus_t stat) {
  std::string msg(" Cusolver error, ");
  return msg + cusolverGetErrorString(stat) + " ";
}

// Reports a failed cuSOLVER call. Logs msg at FATAL severity when
// REPLACE_ENFORCE_GLOG is defined, otherwise throws std::runtime_error(msg).
// `stat` only selects this overload.
inline void throw_on_error(cusolverStatus_t stat, const std::string& msg) {
#if defined(REPLACE_ENFORCE_GLOG)
  LOG(FATAL) << msg;
#else
  throw std::runtime_error(msg);
#endif
}

722
/****** NCCL ERROR ******/
723
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
// An NCCL call failed exactly when its result is not ncclSuccess.
inline bool is_error(ncclResult_t nccl_result) {
  return nccl_result != ncclSuccess;
}

// Message fragment for a failed NCCL call: " Nccl error, <text> ". The text
// comes from the dynamically loaded ncclGetErrorString.
inline std::string build_nvidia_error_msg(ncclResult_t nccl_result) {
  std::string msg(" Nccl error, ");
  return msg + platform::dynload::ncclGetErrorString(nccl_result) + " ";
}

// Reports a failed NCCL call. Throws std::runtime_error(msg), or logs msg at
// FATAL severity when REPLACE_ENFORCE_GLOG is defined.
inline void throw_on_error(ncclResult_t nccl_result, const std::string& msg) {
#ifndef REPLACE_ENFORCE_GLOG
  throw std::runtime_error(msg);
#else
  LOG(FATAL) << msg;
#endif
}
#endif  // not(__APPLE__) and PADDLE_WITH_NCCL
741

742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757
namespace details {

template <typename T>
struct CudaStatusType {};

#define DEFINE_CUDA_STATUS_TYPE(type, success_value) \
  template <>                                        \
  struct CudaStatusType<type> {                      \
    using Type = type;                               \
    static constexpr Type kSuccess = success_value;  \
  }

DEFINE_CUDA_STATUS_TYPE(cudaError_t, cudaSuccess);
DEFINE_CUDA_STATUS_TYPE(curandStatus_t, CURAND_STATUS_SUCCESS);
DEFINE_CUDA_STATUS_TYPE(cudnnStatus_t, CUDNN_STATUS_SUCCESS);
DEFINE_CUDA_STATUS_TYPE(cublasStatus_t, CUBLAS_STATUS_SUCCESS);
G
Guo Sheng 已提交
758
DEFINE_CUDA_STATUS_TYPE(cusolverStatus_t, CUSOLVER_STATUS_SUCCESS);
759

760
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
761 762 763 764
DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess);
#endif

}  // namespace details
M
minqiyang 已提交
765

766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786
// Checks the status returned by a CUDA-library call.
// - COND is evaluated once, and its type must have a
//   details::CudaStatusType specialization.
// - If COND differs from that type's kSuccess, the matching throw_on_error
//   overload raises an exception. Its message is an errors::External summary
//   built from build_nvidia_error_msg(COND).
// - That exception is rethrown as EnforceNotMet with this call site's
//   __FILE__/__LINE__ and a call stack.
#define PADDLE_ENFORCE_CUDA_SUCCESS(COND)                                 \
  do {                                                                    \
    auto __cond__ = (COND);                                               \
    using __CUDA_STATUS_TYPE__ = decltype(__cond__);                      \
    constexpr auto __success_type__ =                                     \
        ::paddle::platform::details::CudaStatusType<                      \
            __CUDA_STATUS_TYPE__>::kSuccess;                              \
    if (UNLIKELY(__cond__ != __success_type__)) {                         \
      try {                                                               \
        ::paddle::platform::throw_on_error(                               \
            __cond__,                                                     \
            ::paddle::platform::errors::External(                         \
                ::paddle::platform::build_nvidia_error_msg(__cond__))     \
                .ToString());                                             \
      } catch (...) {                                                     \
        HANDLE_THE_ERROR                                                  \
        throw ::paddle::platform::EnforceNotMet(std::current_exception(), \
                                                __FILE__, __LINE__);      \
        END_HANDLE_THE_ERROR                                              \
      }                                                                   \
    }                                                                     \
  } while (0)

#undef DEFINE_CUDA_STATUS_TYPE
790
#endif  // PADDLE_WITH_CUDA
S
add EQ  
Superjom 已提交
791

792 793
}  // namespace platform
}  // namespace paddle