enforce.h 43.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#ifdef __GNUC__
#include <cxxabi.h>  // for __cxa_demangle
#endif               // __GNUC__

#if !defined(_WIN32)
#include <dlfcn.h>   // dladdr
#include <unistd.h>  // sleep, usleep
#else                // _WIN32
#ifndef NOMINMAX
#define NOMINMAX  // msvc max/min macro conflict with std::min/max
#endif
#include <windows.h>  // GetModuleFileName, Sleep
#endif

28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
#ifdef PADDLE_WITH_CUDA
#include <cublas_v2.h>
#include <cudnn.h>
#include <cufft.h>
#include <curand.h>
#include <cusparse.h>
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>

#include "paddle/phi/core/external_error.pb.h"
#endif  // PADDLE_WITH_CUDA

#ifdef PADDLE_WITH_HIP
#include <hiprand.h>
#include <miopen/miopen.h>
#include <rocblas.h>
#include <thrust/system/hip/error.h>
#include <thrust/system_error.h>  // NOLINT
#endif

#include <fstream>
#include <iomanip>
#include <memory>
51 52 53 54
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
55
#include <utility>
56 57 58 59 60

#if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL)
#include <execinfo.h>
#endif

61
#define GLOG_NO_ABBREVIATED_SEVERITIES  // msvc conflict logging with windows.h
62 63
#include "gflags/gflags.h"
#include "glog/logging.h"
64
#include "paddle/phi/core/errors.h"
65 66

#include "paddle/phi/backends/dynload/port.h"
67 68
#include "paddle/utils/string/printf.h"
#include "paddle/utils/string/to_string.h"
69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101

#ifdef PADDLE_WITH_CUDA
#include "paddle/phi/backends/dynload/cublas.h"
#include "paddle/phi/backends/dynload/cudnn.h"
#include "paddle/phi/backends/dynload/curand.h"
#include "paddle/phi/backends/dynload/cusolver.h"
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
#include <error.h>

#include "paddle/phi/backends/dynload/nccl.h"
#endif  // __APPLE__
#endif  // PADDLE_WITH_CUDA

#ifdef PADDLE_WITH_HIP
#include "paddle/phi/backends/dynload/hipfft.h"
#include "paddle/phi/backends/dynload/hiprand.h"
#include "paddle/phi/backends/dynload/miopen.h"
#include "paddle/phi/backends/dynload/rocblas.h"
#if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL)
#include <error.h>  // NOLINT

#include "paddle/phi/backends/dynload/rccl.h"
#endif  // __APPLE__
#endif  // PADDLE_WITH_HIP

// Note: these headers are included to simplify demangled type strings
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/type_defs.h"
// Note: this header is included to simplify HIP and CUDA type strings
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/phi/backends/gpu/gpu_types.h"
#endif

102
#include "paddle/utils/variant.h"
103

104
// Forward declaration so EnforceNotMet can name phi::ErrorSummary in its
// constructor signature. NOTE(review): the full definition is presumably
// provided by "paddle/phi/core/errors.h" (included above) — confirm.
namespace phi {
class ErrorSummary;
}  // namespace phi

// Declares the phi::proto namespace up front. NOTE(review): its contents
// (e.g. phi::proto::ApiType) presumably come from the generated
// external_error.pb.h included in CUDA builds — confirm.
namespace phi {
namespace proto {}  // namespace proto
}  // namespace phi

112
namespace phi {
113 114
namespace enforce {

115
/** HELPER MACROS AND FUNCTIONS **/
116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235
// Exception specification for functions that are explicitly allowed to throw.
// Guarded so that a build may supply its own definition beforehand.
#if !defined(PADDLE_MAY_THROW)
#define PADDLE_MAY_THROW noexcept(false)
#endif

// Branch-prediction hints. Enforce conditions almost always hold, so telling
// the optimizer which outcome to expect lets it lay out the common path as
// the fall-through. On non-Windows builds the hints use __builtin_expect, a
// GCC/Clang compiler extension (not a standard C++ feature); the condition is
// converted to bool first so any contextually-convertible expression works.
// For more details, please check https://stackoverflow.com/a/43870188/724872.
#if defined(_WIN32)
// No equivalent intrinsic is used on Windows: the hints collapse to the bare
// condition.
#define UNLIKELY(condition) (condition)
#define LIKELY(condition) (condition)
#else
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#define LIKELY(condition) __builtin_expect(static_cast<bool>(condition), 1)
#endif

// HANDLE_THE_ERROR / END_HANDLE_THE_ERROR bracket a throwing statement in the
// enforce macros. They must always be used as a pair: in the configuration
// below HANDLE_THE_ERROR opens a `try {` that END_HANDLE_THE_ERROR closes.
//
// Windows inference builds without Python (_WIN32 + PADDLE_ON_INFERENCE +
// PADDLE_NO_PYTHON): the bracketed statement runs inside a try block; any
// std::exception has its what() printed to std::cout and is then rethrown
// unchanged via `throw;`. NOTE(review): presumably so the message is visible
// even when no Python layer reports it — confirm.
#if defined _WIN32 && defined PADDLE_ON_INFERENCE && defined PADDLE_NO_PYTHON
#define HANDLE_THE_ERROR try {
#define END_HANDLE_THE_ERROR            \
  }                                     \
  catch (const std::exception& e) {     \
    std::cout << e.what() << std::endl; \
    throw;                              \
  }
#else
// All other configurations: both macros expand to nothing.
#define HANDLE_THE_ERROR
#define END_HANDLE_THE_ERROR
#endif

#ifdef __GNUC__
/// Demangles a compiler-generated symbol/type name, such as the string
/// returned by `typeid(T).name()`, using `abi::__cxa_demangle`.
///
/// @param name  The possibly-mangled name.
/// @return The human-readable name on success; otherwise `name` unchanged.
inline std::string demangle(std::string name) {
  // Arbitrary initial value (silences uninitialized-use warnings);
  // __cxa_demangle overwrites it with its status code.
  int status = -4;
  // __cxa_demangle allocates its result with malloc, so it is released with
  // free when `res` goes out of scope.
  std::unique_ptr<char, void (*)(void*)> res{
      abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status),
      std::free};
  if (status == 0 && res != nullptr) {
    return std::string(res.get());
  }
  // Demangling failed (invalid name, allocation failure, ...): hand back the
  // input; returning the by-value parameter moves instead of copying.
  return name;
}
#else
/// Without the GNU ABI demangler the name is returned unchanged.
inline std::string demangle(std::string name) { return name; }
#endif

namespace details {
// Compile-time predicate: true iff T is an arithmetic (integral or
// floating-point) type.
template <typename T>
inline constexpr bool IsArithmetic() {
  return std::is_arithmetic<T>::value;
}

// Chooses the types the two operands of a binary enforce comparison are cast
// to. Primary template (both operands arithmetic): both sides are converted
// to std::common_type<T1, T2>, so the comparison happens in a single type.
template <typename T1, typename T2, bool kIsArithmetic /* = true */>
struct TypeConverterImpl {
  using Type1 = typename std::common_type<T1, T2>::type;
  using Type2 = Type1;
};

// At least one operand is non-arithmetic: each side keeps its own type.
template <typename T1, typename T2>
struct TypeConverterImpl<T1, T2, false> {
  using Type1 = T1;
  using Type2 = T2;
};

// Dispatches to TypeConverterImpl depending on whether both types are
// arithmetic.
template <typename T1, typename T2>
struct TypeConverter {
  static constexpr bool kIsArithmetic =
      IsArithmetic<T1>() && IsArithmetic<T2>();
  using Type1 = typename TypeConverterImpl<T1, T2, kIsArithmetic>::Type1;
  using Type2 = typename TypeConverterImpl<T1, T2, kIsArithmetic>::Type2;
};

// `const TypeConverter<T1, T2>::Type1&`: cast target for the left operand.
template <typename T1, typename T2>
using CommonType1 = typename std::add_lvalue_reference<
    typename std::add_const<typename TypeConverter<T1, T2>::Type1>::type>::type;

// `const TypeConverter<T1, T2>::Type2&`: cast target for the right operand.
template <typename T1, typename T2>
using CommonType2 = typename std::add_lvalue_reference<
    typename std::add_const<typename TypeConverter<T1, T2>::Type2>::type>::type;

// Here, we use SFINAE to check whether T can be converted to std::string:
// kValue is true iff `std::ostream& << T` is a well-formed expression.
template <typename T>
struct CanToString {
 private:
  using YesType = uint8_t;
  using NoType = uint16_t;

  // Probe with std::declval<std::ostream&>() instead of the global std::cout:
  // same expression type, but no dependency on <iostream>'s stream objects.
  // Both overloads are only ever named inside decltype (never called).
  template <typename U>
  static YesType Check(
      decltype(std::declval<std::ostream&>() << std::declval<U>())) {
    return 0;
  }

  template <typename U>
  static NoType Check(...) {
    return 0;
  }

 public:
  static constexpr bool kValue = std::is_same<
      YesType,
      decltype(Check<T>(std::declval<std::ostream&>()))>::value;
};

// Renders one operand of a failed comparison for the error hint.
// Streamable operands are rendered as "<expression>:<value>".
template <bool kCanToString /* = true */>
struct BinaryCompareMessageConverter {
  template <typename T>
  static std::string Convert(const char* expression, const T& value) {
    return expression + std::string(":") + paddle::string::to_string(value);
  }
};

// Non-streamable operands: only the expression text is reported.
template <>
struct BinaryCompareMessageConverter<false> {
  template <typename T>
  static const char* Convert(const char* expression, const T& /*value*/) {
    return expression;
  }
};
}  // namespace details

236
// Returns the call-stack verbosity level consulted by the helpers below: a
// value greater than 1 selects the full C++ call stack in error messages,
// otherwise only the error summary is shown. Declaration only; defined out of
// line. NOTE(review): presumably reflects FLAGS_call_stack_level — confirm.
int GetCallStackLevel();
237 238
// Returns the current C++ call stack as a string; the traceback helpers below
// prepend it to the error summary. Declaration only; defined out of line.
// NOTE(review): the meaning of `for_signal` is not visible here — presumably
// it selects a variant suitable for signal handlers; confirm.
std::string GetCurrentTraceBackString(bool for_signal = false);
// Returns a simplified form of a complete error string. Declaration only.
// NOTE(review): presumably shortens e.g. "InvalidArgumentError: ***" to
// "(InvalidArgument) ***" — confirm against the definition.
std::string SimplifyErrorTypeFormat(const std::string& str);
239 240

template <typename StrType>
241
static std::string GetErrorSumaryString(StrType&& what,
242 243 244
                                        const char* file,
                                        int line) {
  std::ostringstream sout;
245
  if (GetCallStackLevel() > 1) {
246 247 248 249 250 251 252 253 254 255
    sout << "\n----------------------\nError Message "
            "Summary:\n----------------------\n";
  }
  sout << paddle::string::Sprintf(
              "%s (at %s:%d)", std::forward<StrType>(what), file, line)
       << std::endl;
  return sout.str();
}

template <typename StrType>
256 257 258 259 260 261 262 263 264 265 266 267 268 269
std::string GetCompleteTraceBackString(StrType&& what,
                                       const char* file,
                                       int line) {
  std::ostringstream sout;
  sout << "\n----------------------\nError Message "
          "Summary:\n----------------------\n";
  sout << paddle::string::Sprintf(
              "%s (at %s:%d)", std::forward<StrType>(what), file, line)
       << std::endl;
  return GetCurrentTraceBackString() + sout.str();
}

template <typename StrType>
static std::string GetTraceBackString(StrType&& what,
270 271
                                      const char* file,
                                      int line) {
272
  if (GetCallStackLevel() > 1) {
273 274 275 276 277 278 279 280 281 282
    // FLAGS_call_stack_level>1 means showing c++ call stack
    return GetCurrentTraceBackString() + GetErrorSumaryString(what, file, line);
  } else {
    return GetErrorSumaryString(what, file, line);
  }
}

// An enforce condition describes the "not error" state, so a condition that
// evaluated to false is an error.
inline bool is_error(bool stat) {
  const bool condition_held = stat;
  return !condition_held;
}

// Throws ::phi::enforce::EnforceNotMet constructed from __ERROR_SUMMARY and
// the __FILE__/__LINE__ of the expansion site, bracketed by
// HANDLE_THE_ERROR / END_HANDLE_THE_ERROR.
// Note: This Macro can only be used within enforce.h
#define __THROW_ERROR_INTERNAL__(__ERROR_SUMMARY)                             \
  do {                                                                        \
    HANDLE_THE_ERROR                                                          \
    throw ::phi::enforce::EnforceNotMet(__ERROR_SUMMARY, __FILE__, __LINE__); \
    END_HANDLE_THE_ERROR                                                      \
  } while (0)

/** ENFORCE EXCEPTION AND MACROS **/

struct EnforceNotMet : public std::exception {
 public:
  EnforceNotMet(std::exception_ptr e, const char* file, int line) {
    try {
      std::rethrow_exception(e);
    } catch (EnforceNotMet& e) {
      code_ = e.code();
      err_str_ = GetTraceBackString(e.what(), file, line);
      simple_err_str_ = SimplifyErrorTypeFormat(err_str_);
    } catch (std::exception& e) {
      err_str_ = GetTraceBackString(e.what(), file, line);
      simple_err_str_ = SimplifyErrorTypeFormat(err_str_);
    }
  }

  EnforceNotMet(const std::string& str, const char* file, int line)
      : err_str_(GetTraceBackString(str, file, line)) {
    simple_err_str_ = SimplifyErrorTypeFormat(err_str_);
  }

312
  EnforceNotMet(const phi::ErrorSummary& error, const char* file, int line)
313 314 315 316 317 318
      : code_(error.code()),
        err_str_(GetTraceBackString(error.to_string(), file, line)) {
    simple_err_str_ = SimplifyErrorTypeFormat(err_str_);
  }

  const char* what() const noexcept override {
319
    if (GetCallStackLevel() > 1) {
320 321 322 323 324 325
      return err_str_.c_str();
    } else {
      return simple_err_str_.c_str();
    }
  }

326
  phi::ErrorCode code() const { return code_; }
327 328 329 330 331 332

  const std::string& error_str() const { return err_str_; }

  const std::string& simple_error_str() const { return simple_err_str_; }

  void set_error_str(std::string str) {
333
    if (GetCallStackLevel() > 1) {
334 335 336 337 338 339
      err_str_ = str;
    } else {
      simple_err_str_ = str;
    }
  }

340 341
  ~EnforceNotMet() override = default;

342 343
 private:
  // Used to determine the final type of exception thrown
344
  phi::ErrorCode code_ = phi::ErrorCode::LEGACY;
345 346 347 348 349 350 351 352
  // Complete error message
  // e.g. InvalidArgumentError: ***
  std::string err_str_;
  // Simple errror message used when no C++ stack and python compile stack
  // e.g. (InvalidArgument) ***
  std::string simple_err_str_;
};

353 354 355 356 357 358
// Unconditionally throws ::phi::enforce::EnforceNotMet. The arguments are
// forwarded to ::phi::ErrorSummary's constructor; __FILE__/__LINE__ record
// the expansion site.
#define PADDLE_THROW(...)                                      \
  do {                                                         \
    HANDLE_THE_ERROR                                           \
    throw ::phi::enforce::EnforceNotMet(                       \
        ::phi::ErrorSummary(__VA_ARGS__), __FILE__, __LINE__); \
    END_HANDLE_THE_ERROR                                       \
  } while (0)

#if defined(__CUDA_ARCH__)
// For cuda, the assertions can affect performance and it is therefore
// recommended to disable them in production code
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#assertion
// Device-side variant: __FORMAT must be a string literal because it is
// concatenated with the fixed prefix; the GNU `##__VA_ARGS__` form allows a
// format with no arguments. On failure the message is printed with printf and
// the PTX `trap` instruction is executed.
#define PADDLE_ENFORCE(_IS_NOT_ERROR, __FORMAT, ...)               \
  do {                                                             \
    if (!(_IS_NOT_ERROR)) {                                        \
      printf("Error: %s:%d Assertion `%s` failed. " __FORMAT "\n", \
             __FILE__,                                             \
             __LINE__,                                             \
             #_IS_NOT_ERROR,                                       \
             ##__VA_ARGS__);                                       \
      asm("trap;");                                                \
    }                                                              \
  } while (0)
#elif defined(__HIPCC__)
// HIP variant: same message format as above, but stops via abort().
#define PADDLE_ENFORCE(_IS_NOT_ERROR, __FORMAT, ...)               \
  do {                                                             \
    if (!(_IS_NOT_ERROR)) {                                        \
      printf("Error: %s:%d Assertion `%s` failed. " __FORMAT "\n", \
             __FILE__,                                             \
             __LINE__,                                             \
             #_IS_NOT_ERROR,                                       \
             ##__VA_ARGS__);                                       \
      abort();                                                     \
    }                                                              \
  } while (0)
#else
// Host variant: evaluates COND exactly once; when ::phi::is_error(COND) is
// true (for a bool condition: when COND is false) it throws EnforceNotMet
// built from phi::ErrorSummary(__VA_ARGS__).
// NOTE(review): is_error is declared in phi::enforce in this header; the
// ::phi::is_error spelling presumably relies on those names being made
// visible in namespace phi elsewhere in the file — confirm.
#define PADDLE_ENFORCE(COND, ...)                               \
  do {                                                          \
    auto __cond__ = (COND);                                     \
    if (UNLIKELY(::phi::is_error(__cond__))) {                  \
      __THROW_ERROR_INTERNAL__(phi::ErrorSummary(__VA_ARGS__)); \
    }                                                           \
  } while (0)
#endif

/*
 * Some enforce helpers here, usage:
 *    int a = 1;
 *    int b = 2;
 *    PADDLE_ENFORCE_EQ(a, b);
 *
 *    will raise an exception whose message contains the hint
 *    "[Hint: Expected a == b, but received a:1 != b:2.]"
 *      with detailed stack information.
 *
 *    Extra messages are also supported, for example:
 *    PADDLE_ENFORCE_EQ(a, b,
 *                      "some simple enforce failed between %d numbers", 2)
 */

412 413 414 415 416 417 418 419 420 421
// Throws EnforceNotMet when __VAL compares equal to nullptr (__VAL is
// evaluated once). The error code and message come from
// phi::ErrorSummary(__VA_ARGS__); the message is extended with
// "[Hint: <__VAL> should not be null.]".
// The stringified expression is passed as a "%s" argument rather than pasted
// into the format string, so expressions containing '%' (e.g. `p[i % n]`)
// cannot be misread as format directives.
#define PADDLE_ENFORCE_NOT_NULL(__VAL, ...)                    \
  do {                                                         \
    if (UNLIKELY(nullptr == (__VAL))) {                        \
      auto __summary__ = phi::ErrorSummary(__VA_ARGS__);       \
      auto __message__ = ::paddle::string::Sprintf(            \
          "%s\n  [Hint: %s should not be null.]",              \
          __summary__.error_message(),                         \
          #__VAL);                                             \
      __THROW_ERROR_INTERNAL__(                                \
          phi::ErrorSummary(__summary__.code(), __message__)); \
    }                                                          \
  } while (0)

// Shared implementation of PADDLE_ENFORCE_EQ/NE/GT/GE/LT/LE.
// Evaluates each operand exactly once and checks
//   static_cast<CommonType1>(val1) __CMP static_cast<CommonType2>(val2)
// (both arithmetic: their std::common_type; otherwise each operand's own
// type). On failure it throws EnforceNotMet whose message is the user's
// ErrorSummary text followed by
//   "[Hint: Expected <VAL1> __CMP <VAL2>, but received <VAL1>:<v1> __INV_CMP
//   <VAL2>:<v2>.]"
// Values are printed only when both operand types are streamable
// (details::CanToString); otherwise only the expression text is shown.
// __INV_CMP is only stringified, never evaluated.
// NOTE(review): ::phi::details presumably resolves to phi::enforce::details
// through a using-directive elsewhere in the file — confirm.
#define __PADDLE_BINARY_COMPARE(__VAL1, __VAL2, __CMP, __INV_CMP, ...)  \
  do {                                                                  \
    auto __val1 = (__VAL1);                                             \
    auto __val2 = (__VAL2);                                             \
    using __TYPE1__ = decltype(__val1);                                 \
    using __TYPE2__ = decltype(__val2);                                 \
    using __COMMON_TYPE1__ =                                            \
        ::phi::details::CommonType1<__TYPE1__, __TYPE2__>;              \
    using __COMMON_TYPE2__ =                                            \
        ::phi::details::CommonType2<__TYPE1__, __TYPE2__>;              \
    bool __is_not_error = (static_cast<__COMMON_TYPE1__>(__val1))__CMP( \
        static_cast<__COMMON_TYPE2__>(__val2));                         \
    if (UNLIKELY(!__is_not_error)) {                                    \
      auto __summary__ = phi::ErrorSummary(__VA_ARGS__);                \
      constexpr bool __kCanToString__ =                                 \
          ::phi::details::CanToString<__TYPE1__>::kValue &&             \
          ::phi::details::CanToString<__TYPE2__>::kValue;               \
      auto __message__ = ::paddle::string::Sprintf(                     \
          "%s\n  [Hint: Expected %s " #__CMP                            \
          " %s, but received %s " #__INV_CMP " %s.]",                   \
          __summary__.error_message(),                                  \
          #__VAL1,                                                      \
          #__VAL2,                                                      \
          ::phi::details::BinaryCompareMessageConverter<                \
              __kCanToString__>::Convert(#__VAL1, __val1),              \
          ::phi::details::BinaryCompareMessageConverter<                \
              __kCanToString__>::Convert(#__VAL2, __val2));             \
      __THROW_ERROR_INTERNAL__(                                         \
          phi::ErrorSummary(__summary__.code(), __message__));          \
    }                                                                   \
  } while (0)

// Binary comparison checks: each throws EnforceNotMet unless
// `__VAL0 <op> __VAL1` holds. The trailing arguments build the
// phi::ErrorSummary used for the error message. The second operator handed to
// __PADDLE_BINARY_COMPARE is the negated comparison, used only in the hint
// text describing the values actually received.
#define PADDLE_ENFORCE_EQ(__VAL0, __VAL1, ...) \
  __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, ==, !=, __VA_ARGS__)
#define PADDLE_ENFORCE_NE(__VAL0, __VAL1, ...) \
  __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, !=, ==, __VA_ARGS__)
#define PADDLE_ENFORCE_GT(__VAL0, __VAL1, ...) \
  __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, >, <=, __VA_ARGS__)
#define PADDLE_ENFORCE_GE(__VAL0, __VAL1, ...) \
  __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, >=, <, __VA_ARGS__)
#define PADDLE_ENFORCE_LT(__VAL0, __VAL1, ...) \
  __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <, >=, __VA_ARGS__)
#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \
  __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__)

/** EXTENDED TOOL FUNCTIONS WITH CHECKING **/

471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523
/*
 * Summary: This macro is used to get Variable or internal type
 *   data (such as LoDTensor or SelectedRows) of the Input and
 *   Output in op, generally used when calling scope.FindVar(Input/
 *   Output("Name")) or ctx.Input<LoDTensor>().
 *   First, this macro checks whether the obtained pointer is null;
 *   if it is not null, it returns a reference to the pointed-to data.
 *   The pointer expression __PTR is evaluated exactly once.
 *
 * Note: This macro is only suitable for specific scenarios and
 *   is not intended to be widely used. If it cannot meet the
 *   requirements, please use other PADDLE_ENFORCE** check macros.
 *
 * Parameters:
 *     __PTR: pointer
 *     __ROLE: (string), Input or Output
 *     __NAME: (string), Input or Output name
 *     __OP_TYPE: (string), the op type
 *
 * Return: An lvalue reference to the data pointed to by the pointer.
 *
 * Throws: EnforceNotMet carrying a NotFound error (listing possible causes)
 *   when the pointer is null.
 *
 * Examples:
 *    GET_DATA_SAFELY(ctx.Input<LoDTensor>("X"), "Input", "X", "Mul");
 *
 * Implementation note: the stringified pointer expression is passed as a "%s"
 *   argument rather than pasted into the format string, so expressions that
 *   contain '%' cannot be misread as format directives.
 */
#define GET_DATA_SAFELY(__PTR, __ROLE, __NAME, __OP_TYPE)               \
  (([&]() -> std::add_lvalue_reference<decltype(*(__PTR))>::type {      \
    auto* __ptr = (__PTR);                                              \
    if (UNLIKELY(nullptr == __ptr)) {                                   \
      auto __summary__ = phi::errors::NotFound(                         \
          "Unable to get %s data of %s %s in operator %s. "             \
          "Possible reasons are:\n"                                     \
          "  1. The %s is not the %s of operator %s;\n"                 \
          "  2. The %s has no corresponding variable passed in;\n"      \
          "  3. The %s corresponding variable is not initialized.",     \
          phi::demangle(                                                \
              typeid(std::add_lvalue_reference<decltype(*__ptr)>::type) \
                  .name()),                                             \
          __ROLE,                                                       \
          __NAME,                                                       \
          __OP_TYPE,                                                    \
          __NAME,                                                       \
          __ROLE,                                                       \
          __OP_TYPE,                                                    \
          __NAME,                                                       \
          __NAME);                                                      \
      auto __message__ = ::paddle::string::Sprintf(                     \
          "%s\n  [Hint: pointer %s should not be null.]",               \
          __summary__.error_message(),                                  \
          #__PTR);                                                      \
      __THROW_ERROR_INTERNAL__(                                         \
          phi::ErrorSummary(__summary__.code(), __message__));          \
    }                                                                   \
    return *__ptr;                                                      \
  })())

524
/*
 * Summary: The PADDLE_GET(_**) series of macros calls paddle::get safely.
 *   paddle::get is not a completely safe API: although it works in most
 *   cases, in extreme cases it may fail and directly throw a
 *   paddle::bad_variant_access exception without any stack information.
 *   Such problems are difficult to debug, so these macros enrich the
 *   paddle::get error information. At the same time, we restrict the
 *   direct use of paddle::get by CI rule.
 *
 * Parameters:
 *     __TYPE: the target variable type
 *     __VALUE: the target variable to get
 *
 * Examples:
 *     - unsafe writing: int x = paddle::get<int>(y);
 *     - safe writing: int x = PADDLE_GET(int, y);
 *
 * Note: GCC 4.8 cannot select the right overloaded function here, so we need
 *    to define different functions and macros here; after we upgrade the CI
 *    gcc version, we can define only one PADDLE_GET macro.
 */
546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593
namespace details {

// Defines a function template __FuncName<OutputType, InputType>(input,
// expression, file, line) that returns paddle::get<OutputType>(input). If
// paddle::get throws paddle::bad_variant_access, it instead throws an
// EnforceNotMet (InvalidArgument) naming the source expression, the requested
// type and the variant's actual type (from input.type()), reported at
// file:line. The return type is __OutputTypePtr when InputType is a pointer
// type and __OutputType otherwise.
#define DEFINE_SAFE_PADDLE_GET(                                              \
    __InputType, __OutputType, __OutputTypePtr, __FuncName)                  \
  template <typename OutputType, typename InputType>                         \
  auto __FuncName(                                                           \
      __InputType input, const char* expression, const char* file, int line) \
      ->typename std::conditional<std::is_pointer<InputType>::value,         \
                                  __OutputTypePtr,                           \
                                  __OutputType>::type {                      \
    try {                                                                    \
      return paddle::get<OutputType>(input);                                 \
    } catch (paddle::bad_variant_access const&) {                            \
      HANDLE_THE_ERROR                                                       \
      throw ::phi::enforce::EnforceNotMet(                                   \
          phi::errors::InvalidArgument(                                      \
              "paddle::get failed, cannot get value "                        \
              "(%s) by type %s, its type is %s.",                            \
              expression,                                                    \
              phi::enforce::demangle(typeid(OutputType).name()),             \
              phi::enforce::demangle(input.type().name())),                  \
          file,                                                              \
          line);                                                             \
      END_HANDLE_THE_ERROR                                                   \
    }                                                                        \
  }

// Non-const lvalue input: yields OutputType& (OutputType* for pointer inputs).
DEFINE_SAFE_PADDLE_GET(InputType&, OutputType&, OutputType*, SafeBoostGet);
// Const lvalue input: yields const OutputType& (const OutputType* for
// pointer inputs).
DEFINE_SAFE_PADDLE_GET(const InputType&,
                       const OutputType&,
                       const OutputType*,
                       SafeBoostGetConst);
// Forwarding-reference input: the result is returned by value (OutputType),
// or as OutputType* for pointer inputs.
DEFINE_SAFE_PADDLE_GET(InputType&&,
                       OutputType,
                       OutputType*,
                       SafeBoostGetMutable);

}  // namespace details

// Checked replacements for paddle::get<__TYPE>(__VALUE). On a type mismatch
// they throw EnforceNotMet reporting the stringified __VALUE expression and
// the __FILE__/__LINE__ of the call site.
//   PADDLE_GET         - via SafeBoostGet (InputType& parameter)
//   PADDLE_GET_CONST   - via SafeBoostGetConst (const InputType& parameter)
//   PADDLE_GET_MUTABLE - via SafeBoostGetMutable (InputType&& parameter;
//                        result returned by value)
#define PADDLE_GET(__TYPE, __VALUE)            \
  phi::enforce::details::SafeBoostGet<__TYPE>( \
      __VALUE, #__VALUE, __FILE__, __LINE__)
#define PADDLE_GET_CONST(__TYPE, __VALUE)           \
  phi::enforce::details::SafeBoostGetConst<__TYPE>( \
      __VALUE, #__VALUE, __FILE__, __LINE__)
#define PADDLE_GET_MUTABLE(__TYPE, __VALUE)           \
  phi::enforce::details::SafeBoostGetMutable<__TYPE>( \
      __VALUE, #__VALUE, __FILE__, __LINE__)
594

595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 
1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122
/**************************************************************************/
/**************************** NVIDIA ERROR ********************************/
#ifdef PADDLE_WITH_CUDA

namespace details {

// Trait mapping an external (CUDA-family) status/result type to metadata
// about that API. The primary template is empty, so accessing its members for
// an unregistered type fails to compile.
template <typename T>
struct ExternalApiType {};

// Registers `type`: kSuccess is its success value, kTypeString is the
// stringified proto_type, and kProtoType is the phi::proto::ApiType
// enumerator of the same name. NOTE(review): phi::proto::ApiType is
// presumably declared in the generated external_error.pb.h — confirm.
#define DEFINE_EXTERNAL_API_TYPE(type, success_value, proto_type) \
  template <>                                                     \
  struct ExternalApiType<type> {                                  \
    using Type = type;                                            \
    static constexpr Type kSuccess = success_value;               \
    static constexpr const char* kTypeString = #proto_type;       \
    static constexpr phi::proto::ApiType kProtoType =             \
        phi::proto::ApiType::proto_type;                          \
  }

DEFINE_EXTERNAL_API_TYPE(cudaError_t, cudaSuccess, CUDA);
DEFINE_EXTERNAL_API_TYPE(curandStatus_t, CURAND_STATUS_SUCCESS, CURAND);
DEFINE_EXTERNAL_API_TYPE(cudnnStatus_t, CUDNN_STATUS_SUCCESS, CUDNN);
DEFINE_EXTERNAL_API_TYPE(cublasStatus_t, CUBLAS_STATUS_SUCCESS, CUBLAS);
DEFINE_EXTERNAL_API_TYPE(cusparseStatus_t, CUSPARSE_STATUS_SUCCESS, CUSPARSE);
DEFINE_EXTERNAL_API_TYPE(cusolverStatus_t, CUSOLVER_STATUS_SUCCESS, CUSOLVER);
DEFINE_EXTERNAL_API_TYPE(cufftResult_t, CUFFT_SUCCESS, CUFFT);
DEFINE_EXTERNAL_API_TYPE(CUresult, CUDA_SUCCESS, CU);

// NCCL results are registered only when NCCL is enabled (and not on Apple).
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
DEFINE_EXTERNAL_API_TYPE(ncclResult_t, ncclSuccess, NCCL);
#endif

}  // namespace details

template <typename T>
inline const char* GetErrorMsgUrl(T status) {
  using __CUDA_STATUS_TYPE__ = decltype(status);
  phi::proto::ApiType proto_type =
      details::ExternalApiType<__CUDA_STATUS_TYPE__>::kProtoType;
  switch (proto_type) {
    case phi::proto::ApiType::CUDA:
    case phi::proto::ApiType::CU:
      return "https://docs.nvidia.com/cuda/cuda-runtime-api/"
             "group__CUDART__TYPES.html#group__CUDART__TYPES_"
             "1g3f51e3575c2178246db0a94a430e0038";
      break;
    case phi::proto::ApiType::CURAND:
      return "https://docs.nvidia.com/cuda/curand/"
             "group__HOST.html#group__HOST_1gb94a31d5c165858c96b6c18b70644437";
      break;
    case phi::proto::ApiType::CUDNN:
      return "https://docs.nvidia.com/deeplearning/cudnn/api/"
             "index.html#cudnnStatus_t";
      break;
    case phi::proto::ApiType::CUBLAS:
      return "https://docs.nvidia.com/cuda/cublas/index.html#cublasstatus_t";
      break;
    case phi::proto::ApiType::CUSOLVER:
      return "https://docs.nvidia.com/cuda/cusolver/"
             "index.html#cuSolverSPstatus";
      break;
    case phi::proto::ApiType::NCCL:
      return "https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/"
             "types.html#ncclresult-t";
      break;
    case phi::proto::ApiType::CUFFT:
      return "https://docs.nvidia.com/cuda/cufft/index.html#cufftresult";
    case phi::proto::ApiType::CUSPARSE:
      return "https://docs.nvidia.com/cuda/cusparse/"
             "index.html#cusparseStatus_t";
      break;
    default:
      return "Unknown type of External API, can't get error message URL!";
      break;
  }
}

template <typename T>
inline std::string GetExternalErrorMsg(T status) {
  std::ostringstream sout;
  bool _initSucceed = false;
  phi::proto::ExternalErrorDesc externalError;
  if (externalError.ByteSizeLong() == 0) {
    std::string filePath;
#if !defined(_WIN32)
    Dl_info info;
    if (dladdr(reinterpret_cast<void*>(GetCurrentTraceBackString), &info)) {
      std::string strModule(info.dli_fname);
      const size_t last_slash_idx = strModule.find_last_of("/");
      std::string compare_path = strModule.substr(strModule.length() - 6);
      if (std::string::npos != last_slash_idx) {
        strModule.erase(last_slash_idx, std::string::npos);
      }
      if (compare_path.compare("avx.so") == 0) {
        filePath =
            strModule +
            "/../include/third_party/externalError/data/externalErrorMsg.pb";
      } else {
        filePath = strModule +
                   "/../../third_party/externalError/data/externalErrorMsg.pb";
      }
    }
#else
    char buf[512];
    MEMORY_BASIC_INFORMATION mbi;
    HMODULE h_module =
        (::VirtualQuery(GetCurrentTraceBackString, &mbi, sizeof(mbi)) != 0)
            ? (HMODULE)mbi.AllocationBase
            : NULL;
    GetModuleFileName(h_module, buf, 512);
    std::string strModule(buf);
    const size_t last_slash_idx = strModule.find_last_of("\\");
    std::string compare_path = strModule.substr(strModule.length() - 7);
    if (std::string::npos != last_slash_idx) {
      strModule.erase(last_slash_idx, std::string::npos);
    }
    if (compare_path.compare("avx.pyd") == 0) {
      filePath = strModule +
                 "\\..\\include\\third_"
                 "party\\externalerror\\data\\externalErrorMsg.pb";
    } else {
      filePath =
          strModule +
          "\\..\\..\\third_party\\externalerror\\data\\externalErrorMsg.pb";
    }
#endif
    std::ifstream fin(filePath, std::ios::in | std::ios::binary);
    _initSucceed = externalError.ParseFromIstream(&fin);
  }
  using __CUDA_STATUS_TYPE__ = decltype(status);
  phi::proto::ApiType proto_type =
      details::ExternalApiType<__CUDA_STATUS_TYPE__>::kProtoType;
  if (_initSucceed) {
    for (int i = 0; i < externalError.errors_size(); ++i) {
      if (proto_type == externalError.errors(i).type()) {
        for (int j = 0; j < externalError.errors(i).messages_size(); ++j) {
          if (status == externalError.errors(i).messages(j).code()) {
            sout << "\n  [Hint: "
                 << externalError.errors(i).messages(j).message() << "]";
            return sout.str();
          }
        }
      }
    }
  }

  sout << "\n  [Hint: Please search for the error code(" << status
       << ") on website (" << GetErrorMsgUrl(status)
       << ") to get Nvidia's official solution and advice about "
       << details::ExternalApiType<__CUDA_STATUS_TYPE__>::kTypeString
       << " Error.]";
  return sout.str();
}

// Explicit instantiations of GetExternalErrorMsg for each NVIDIA status type
// handled below (ncclResult_t only when NCCL is enabled and not on Apple).
// NOTE(review): these are explicit instantiation *definitions* placed in a
// header. The standard ([temp.spec]) allows at most one such definition per
// program (no diagnostic required), so every TU including this header
// technically violates that rule. Toolchains merge the inline copies in
// practice. The template is fully defined above, so implicit instantiation
// would suffice; consider removing these or moving them into a single .cc.
template std::string GetExternalErrorMsg<cudaError_t>(cudaError_t);
template std::string GetExternalErrorMsg<curandStatus_t>(curandStatus_t);
template std::string GetExternalErrorMsg<cudnnStatus_t>(cudnnStatus_t);
template std::string GetExternalErrorMsg<cublasStatus_t>(cublasStatus_t);
template std::string GetExternalErrorMsg<cusparseStatus_t>(cusparseStatus_t);
template std::string GetExternalErrorMsg<cusolverStatus_t>(cusolverStatus_t);
template std::string GetExternalErrorMsg<cufftResult_t>(cufftResult_t);
template std::string GetExternalErrorMsg<CUresult>(CUresult);
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
template std::string GetExternalErrorMsg<ncclResult_t>(ncclResult_t);
#endif

/*************** CUDA ERROR ***************/
// A CUDA runtime call failed iff it returned something other than cudaSuccess.
inline bool is_error(cudaError_t e) { return !(e == cudaSuccess); }

// Formats "CUDA error(<code>), <cudaGetErrorString text>. " followed by the
// GetExternalErrorMsg() hint.
inline std::string build_nvidia_error_msg(cudaError_t e) {
  std::ostringstream msg;
  msg << "CUDA error(" << e << "), ";
  msg << cudaGetErrorString(e) << ". ";
  msg << GetExternalErrorMsg(e);
  return msg.str();
}

/*************** CURAND ERROR ***************/
// cuRAND reports success as CURAND_STATUS_SUCCESS; any other value is an
// error.
inline bool is_error(curandStatus_t stat) {
  return !(stat == CURAND_STATUS_SUCCESS);
}

// Formats "CURAND error(<code>). " followed by the GetExternalErrorMsg() hint.
inline std::string build_nvidia_error_msg(curandStatus_t stat) {
  std::ostringstream out;
  out << "CURAND error(" << stat << "). ";
  out << GetExternalErrorMsg(stat);
  return out.str();
}

/*************** CUDNN ERROR ***************/
// cuDNN reports success as CUDNN_STATUS_SUCCESS; any other value is an error.
inline bool is_error(cudnnStatus_t stat) {
  return !(stat == CUDNN_STATUS_SUCCESS);
}

// Formats "CUDNN error(<code>), <cudnnGetErrorString text>. " followed by the
// GetExternalErrorMsg() hint. The description comes from the dynamically
// loaded cuDNN library.
inline std::string build_nvidia_error_msg(cudnnStatus_t stat) {
  const auto description = phi::dynload::cudnnGetErrorString(stat);
  std::ostringstream out;
  out << "CUDNN error(" << stat << "), " << description << ". ";
  out << GetExternalErrorMsg(stat);
  return out.str();
}

/*************** CUBLAS ERROR ***************/
// cuBLAS reports success as CUBLAS_STATUS_SUCCESS; any other value is an
// error.
inline bool is_error(cublasStatus_t stat) {
  return !(stat == CUBLAS_STATUS_SUCCESS);
}

// Formats "CUBLAS error(<code>). " followed by the GetExternalErrorMsg() hint.
inline std::string build_nvidia_error_msg(cublasStatus_t stat) {
  std::ostringstream out;
  out << "CUBLAS error(" << stat << "). ";
  out << GetExternalErrorMsg(stat);
  return out.str();
}

/*************** CUSPARSE ERROR ***************/
// cuSPARSE reports success as CUSPARSE_STATUS_SUCCESS; any other value is an
// error.
inline bool is_error(cusparseStatus_t stat) {
  return !(stat == CUSPARSE_STATUS_SUCCESS);
}

// Formats "CUSparse error(<code>). " followed by the GetExternalErrorMsg()
// hint.
inline std::string build_nvidia_error_msg(cusparseStatus_t stat) {
  std::ostringstream out;
  out << "CUSparse error(" << stat << "). ";
  out << GetExternalErrorMsg(stat);
  return out.str();
}

/*************** CUSOLVER ERROR ***************/
// cuSOLVER reports success as CUSOLVER_STATUS_SUCCESS; any other value is an
// error.
inline bool is_error(cusolverStatus_t stat) {
  return !(stat == CUSOLVER_STATUS_SUCCESS);
}

// Formats "CUSOLVER error(<code>). " followed by the GetExternalErrorMsg()
// hint.
inline std::string build_nvidia_error_msg(cusolverStatus_t stat) {
  std::ostringstream out;
  out << "CUSOLVER error(" << stat << "). ";
  out << GetExternalErrorMsg(stat);
  return out.str();
}

/*************** CUFFT ERROR ***************/
// cuFFT reports success as CUFFT_SUCCESS; any other value is an error.
inline bool is_error(cufftResult_t stat) { return !(stat == CUFFT_SUCCESS); }

// Formats "CUFFT error(<code>). " followed by the GetExternalErrorMsg() hint.
inline std::string build_nvidia_error_msg(cufftResult_t stat) {
  std::ostringstream out;
  out << "CUFFT error(" << stat << "). ";
  out << GetExternalErrorMsg(stat);
  return out.str();
}

/*************** CUresult ERROR ***************/
// CUDA driver API calls report success as CUDA_SUCCESS; any other value is an
// error.
inline bool is_error(CUresult stat) { return !(stat == CUDA_SUCCESS); }

// Formats "CU error(<code>). " followed by the GetExternalErrorMsg() hint.
inline std::string build_nvidia_error_msg(CUresult stat) {
  std::ostringstream out;
  out << "CU error(" << stat << "). ";
  out << GetExternalErrorMsg(stat);
  return out.str();
}

/**************** NCCL ERROR ****************/
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
// NCCL reports success as ncclSuccess; any other value is an error.
inline bool is_error(ncclResult_t nccl_result) {
  return nccl_result != ncclSuccess;
}

/**
 * Formats an NCCL failure as:
 *   - "NCCL error(<code>), <ncclGetErrorString text>. ";
 *   - when errno is ENOSPC or EAGAIN, a "Detail:" section with strerror()
 *     and shared-memory workarounds;
 *   - the GetExternalErrorMsg() hint.
 */
inline std::string build_nvidia_error_msg(ncclResult_t nccl_result) {
  // Snapshot errno first: the dynload call and the stream/string operations
  // below are library calls that may overwrite errno even when they succeed.
  const int saved_errno = errno;
  std::ostringstream sout;
  sout << "NCCL error(" << nccl_result << "), "
       << phi::dynload::ncclGetErrorString(nccl_result) << ". ";
  if (saved_errno == ENOSPC || saved_errno == EAGAIN) {
    std::string detail(strerror(saved_errno));
    detail += "\nPlease try one of the following solutions:";
    detail += "\n1. export NCCL_SHM_DISABLE=1;";
    detail += "\n2. export NCCL_P2P_LEVEL=SYS;";
    detail +=
        "\n3. Increase shared memory by setting the -shm-size "
        "option when starting docker container, e.g., setting "
        " -shm-size=2g.\n";
    sout << " Detail: " << detail;
  }
  sout << GetExternalErrorMsg(nccl_result);
  return sout.str();
}
#endif  // not(__APPLE__) and PADDLE_WITH_NCCL

// Evaluates COND exactly once and compares the result with the success value
// registered for its type in ::phi::enforce::details::ExternalApiType, so only
// registered status types compile. On mismatch, wraps
// ::phi::enforce::build_nvidia_error_msg(result) in phi::errors::External and
// throws it via __THROW_ERROR_INTERNAL__.
// Comments cannot go inside the body: a trailing `\` would splice the next
// line into a `//` comment.
#define PADDLE_ENFORCE_GPU_SUCCESS(COND)                     \
  do {                                                       \
    auto __cond__ = (COND);                                  \
    using __CUDA_STATUS_TYPE__ = decltype(__cond__);         \
    constexpr auto __success_type__ =                        \
        ::phi::enforce::details::ExternalApiType<            \
            __CUDA_STATUS_TYPE__>::kSuccess;                 \
    if (UNLIKELY(__cond__ != __success_type__)) {            \
      auto __summary__ = phi::errors::External(              \
          ::phi::enforce::build_nvidia_error_msg(__cond__)); \
      __THROW_ERROR_INTERNAL__(__summary__);                 \
    }                                                        \
  } while (0)

// Checks cudaGetLastError() (typically right after a kernel launch). On
// failure, throws phi::errors::Fatal naming the operation `OP` (formatted via
// %s) together with ::phi::enforce::build_nvidia_error_msg().
// The locals use reserved-style names, like __cond__ in the neighbouring
// macros. `OP` is expanded while they are in scope, so ordinary names such as
// `res` or `msg` would silently shadow caller variables used in `OP`.
#define PADDLE_ENFORCE_CUDA_LAUNCH_SUCCESS(OP)                        \
  do {                                                                \
    auto __launch_error__ = cudaGetLastError();                       \
    if (UNLIKELY(__launch_error__ != cudaSuccess)) {                  \
      auto __launch_msg__ =                                           \
          ::phi::enforce::build_nvidia_error_msg(__launch_error__);   \
      PADDLE_THROW(phi::errors::Fatal(                                \
          "CUDA error after kernel (%s): %s", OP, __launch_msg__));   \
    }                                                                 \
  } while (0)

/**
 * Blocks the calling thread for roughly `milliseconds` ms.
 *
 * On POSIX, waits of 1000 ms or more are rounded down to whole seconds.
 * Exact precision is not needed here, and usleep() only accepts arguments
 * below 1,000,000 (https://pubs.opengroup.org/onlinepubs/7908799/xsh/usleep.html).
 */
inline void retry_sleep(unsigned milliseconds) {
#ifdef _WIN32
  Sleep(milliseconds);
#else
  if (milliseconds >= 1000) {
    sleep(milliseconds / 1000);
    return;
  }
  usleep(milliseconds * 1000);
#endif
}

// Evaluates COND; while the result differs from the success value registered
// in ::phi::enforce::details::ExternalApiType, calls
// ::phi::enforce::retry_sleep(10000) and evaluates COND again, for at most 5
// evaluations in total. If the final result is still not a success, throws
// phi::errors::External built from ::phi::enforce::build_nvidia_error_msg().
// COND is expanded twice, so it must be safe to re-evaluate. The counter
// uses a reserved-style name so a caller variable called `retry_count` used
// inside COND is not captured by the macro's local.
#define PADDLE_RETRY_CUDA_SUCCESS(COND)                                     \
  do {                                                                      \
    auto __cond__ = (COND);                                                 \
    int __retry_count__ = 1;                                                \
    using __CUDA_STATUS_TYPE__ = decltype(__cond__);                        \
    constexpr auto __success_type__ =                                       \
        ::phi::enforce::details::ExternalApiType<                           \
            __CUDA_STATUS_TYPE__>::kSuccess;                                \
    while (UNLIKELY(__cond__ != __success_type__) && __retry_count__ < 5) { \
      ::phi::enforce::retry_sleep(10000);                                   \
      __cond__ = (COND);                                                    \
      ++__retry_count__;                                                    \
    }                                                                       \
    if (UNLIKELY(__cond__ != __success_type__)) {                           \
      auto __summary__ = phi::errors::External(                             \
          ::phi::enforce::build_nvidia_error_msg(__cond__));                \
      __THROW_ERROR_INTERNAL__(__summary__);                                \
    }                                                                       \
  } while (0)

#undef DEFINE_EXTERNAL_API_TYPE
#endif  // PADDLE_WITH_CUDA

/**************************************************************************/
/***************************** HIP ERROR **********************************/
#ifdef PADDLE_WITH_HIP

/***** HIP ERROR *****/
// HIP runtime calls succeed only when they return hipSuccess.
inline bool is_error(hipError_t e) { return !(e == hipSuccess); }

// Formats " Hip error(<code>), <hipGetErrorString text>."
inline std::string build_rocm_error_msg(hipError_t e) {
  std::ostringstream out;
  out << " Hip error(" << e << "), ";
  out << hipGetErrorString(e) << ".";
  return out.str();
}

/***** HIPRAND ERROR *****/
// hipRAND reports success as HIPRAND_STATUS_SUCCESS; any other value is an
// error.
inline bool is_error(hiprandStatus_t stat) {
  return !(stat == HIPRAND_STATUS_SUCCESS);
}

// Maps a hipRAND status to the spelling of its enumerator (e.g.
// "HIPRAND_STATUS_SUCCESS"), or "Unknown hiprand status" for anything else.
// The helper macro stringizes its argument, so each returned string is
// exactly the enumerator name written in the case label.
inline const char* hiprandGetErrorString(hiprandStatus_t stat) {
#define PADDLE_HIPRAND_STATUS_CASE(status) \
  case status:                             \
    return #status
  switch (stat) {
    PADDLE_HIPRAND_STATUS_CASE(HIPRAND_STATUS_SUCCESS);
    PADDLE_HIPRAND_STATUS_CASE(HIPRAND_STATUS_VERSION_MISMATCH);
    PADDLE_HIPRAND_STATUS_CASE(HIPRAND_STATUS_NOT_INITIALIZED);
    PADDLE_HIPRAND_STATUS_CASE(HIPRAND_STATUS_ALLOCATION_FAILED);
    PADDLE_HIPRAND_STATUS_CASE(HIPRAND_STATUS_TYPE_ERROR);
    PADDLE_HIPRAND_STATUS_CASE(HIPRAND_STATUS_OUT_OF_RANGE);
    PADDLE_HIPRAND_STATUS_CASE(HIPRAND_STATUS_LENGTH_NOT_MULTIPLE);
    PADDLE_HIPRAND_STATUS_CASE(HIPRAND_STATUS_DOUBLE_PRECISION_REQUIRED);
    PADDLE_HIPRAND_STATUS_CASE(HIPRAND_STATUS_LAUNCH_FAILURE);
    PADDLE_HIPRAND_STATUS_CASE(HIPRAND_STATUS_PREEXISTING_FAILURE);
    PADDLE_HIPRAND_STATUS_CASE(HIPRAND_STATUS_INITIALIZATION_FAILED);
    PADDLE_HIPRAND_STATUS_CASE(HIPRAND_STATUS_ARCH_MISMATCH);
    PADDLE_HIPRAND_STATUS_CASE(HIPRAND_STATUS_INTERNAL_ERROR);
    PADDLE_HIPRAND_STATUS_CASE(HIPRAND_STATUS_NOT_IMPLEMENTED);
    default:
      return "Unknown hiprand status";
  }
#undef PADDLE_HIPRAND_STATUS_CASE
}

// Formats " Hiprand error, <status name> ".
inline std::string build_rocm_error_msg(hiprandStatus_t stat) {
  std::string msg = " Hiprand error, ";
  msg += hiprandGetErrorString(stat);
  msg += ' ';
  return msg;
}

/***** MIOPEN ERROR *****/
// MIOpen reports success as miopenStatusSuccess; any other value is an error.
inline bool is_error(miopenStatus_t stat) {
  return !(stat == miopenStatusSuccess);
}

// Formats " Miopen error, <miopenGetErrorString text> ". The text comes from
// the dynamically loaded MIOpen library.
inline std::string build_rocm_error_msg(miopenStatus_t stat) {
  std::string msg = " Miopen error, ";
  msg += phi::dynload::miopenGetErrorString(stat);
  msg += ' ';
  return msg;
}

/***** ROCBLAS ERROR *****/
// rocBLAS reports success as rocblas_status_success; any other value is an
// error.
inline bool is_error(rocblas_status stat) {
  return !(stat == rocblas_status_success);
}

/**
 * Maps a rocBLAS status to the spelling of its enumerator.
 *
 * @param stat  Status returned by a rocBLAS call.
 * @return A static string naming the status, or "Unknown rocblas status" for
 *         values not listed here. The fallback previously said "cublas", a
 *         copy-paste slip from the CUDA code path.
 */
inline const char* rocblasGetErrorString(rocblas_status stat) {
  switch (stat) {
    case rocblas_status_success:
      return "rocblas_status_success";
    case rocblas_status_invalid_handle:
      return "rocblas_status_invalid_handle";
    case rocblas_status_memory_error:
      return "rocblas_status_memory_error";
    case rocblas_status_invalid_value:
      return "rocblas_status_invalid_value";
    case rocblas_status_not_implemented:
      return "rocblas_status_not_implemented";
    case rocblas_status_invalid_pointer:
      return "rocblas_status_invalid_pointer";
    case rocblas_status_invalid_size:
      return "rocblas_status_invalid_size";
    case rocblas_status_internal_error:
      return "rocblas_status_internal_error";
    default:
      return "Unknown rocblas status";
  }
}

// Formats " Rocblas error, <status name> ".
inline std::string build_rocm_error_msg(rocblas_status stat) {
  std::string msg = " Rocblas error, ";
  msg += rocblasGetErrorString(stat);
  msg += ' ';
  return msg;
}

/****** RCCL ERROR ******/
#if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL)
// RCCL (the ROCm build of the NCCL API) reports success as ncclSuccess; any
// other value is an error.
inline bool is_error(ncclResult_t nccl_result) {
  return nccl_result != ncclSuccess;
}

// Formats " Rccl error, <ncclGetErrorString text> ". The text comes from the
// dynamically loaded library.
inline std::string build_rocm_error_msg(ncclResult_t nccl_result) {
  std::string msg(" Rccl error, ");
  return msg + phi::dynload::ncclGetErrorString(nccl_result) + " ";
}
#endif  // not(__APPLE__) and PADDLE_WITH_RCCL

/***** HIPFFT ERROR *****/
// hipFFT reports success as HIPFFT_SUCCESS; any other value is an error.
inline bool is_error(hipfftResult_t stat) { return !(stat == HIPFFT_SUCCESS); }

// Formats " HIPFFT error, <hipfftGetErrorString text> ". The text comes from
// the dynload wrapper.
inline std::string build_rocm_error_msg(hipfftResult_t stat) {
  std::string msg = " HIPFFT error, ";
  msg += phi::dynload::hipfftGetErrorString(stat);
  msg += ' ';
  return msg;
}

namespace details {

// Maps a ROCm external API status type to its success value. The primary
// template is deliberately empty, so the PADDLE_*_SUCCESS macros only compile
// for status types registered below.
template <typename T>
struct ExternalApiType {};

// Registers `api_type` by specializing ExternalApiType with `Type` and a
// compile-time `kSuccess` equal to `ok_value`.
#define DEFINE_EXTERNAL_API_TYPE(api_type, ok_value) \
  template <>                                        \
  struct ExternalApiType<api_type> {                 \
    using Type = api_type;                           \
    static constexpr Type kSuccess = ok_value;       \
  }

DEFINE_EXTERNAL_API_TYPE(hipError_t, hipSuccess);
DEFINE_EXTERNAL_API_TYPE(hiprandStatus_t, HIPRAND_STATUS_SUCCESS);
DEFINE_EXTERNAL_API_TYPE(miopenStatus_t, miopenStatusSuccess);
DEFINE_EXTERNAL_API_TYPE(rocblas_status, rocblas_status_success);
DEFINE_EXTERNAL_API_TYPE(hipfftResult_t, HIPFFT_SUCCESS);

#if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL)
DEFINE_EXTERNAL_API_TYPE(ncclResult_t, ncclSuccess);
#endif

}  // namespace details

// ROCm variant. Evaluates COND exactly once and compares the result with the
// success value registered for its type in
// ::phi::enforce::details::ExternalApiType, so only registered status types
// compile. On mismatch, wraps ::phi::enforce::build_rocm_error_msg(result) in
// phi::errors::External and throws it via __THROW_ERROR_INTERNAL__.
#define PADDLE_ENFORCE_GPU_SUCCESS(COND)                   \
  do {                                                     \
    auto __cond__ = (COND);                                \
    using __CUDA_STATUS_TYPE__ = decltype(__cond__);       \
    constexpr auto __success_type__ =                      \
        ::phi::enforce::details::ExternalApiType<          \
            __CUDA_STATUS_TYPE__>::kSuccess;               \
    if (UNLIKELY(__cond__ != __success_type__)) {          \
      auto __summary__ = phi::errors::External(            \
          ::phi::enforce::build_rocm_error_msg(__cond__)); \
      __THROW_ERROR_INTERNAL__(__summary__);               \
    }                                                      \
  } while (0)

/**
 * Blocks the calling thread for roughly `millisecond` milliseconds.
 *
 * The POSIX branch previously passed the value straight to sleep(), which
 * takes *seconds*, so retry_sleep(10000) slept ~2.8 hours instead of 10 s.
 * It now matches the CUDA implementation:
 *   - short waits use usleep(), whose argument must stay below 1,000,000
 *     (https://pubs.opengroup.org/onlinepubs/7908799/xsh/usleep.html);
 *   - waits of 1000 ms or more are rounded down to whole seconds.
 */
inline void retry_sleep(unsigned millisecond) {
#ifdef _WIN32
  Sleep(millisecond);
#else
  if (millisecond < 1000) {
    usleep(millisecond * 1000);
  } else {
    // Exact millisecond precision is not needed for retry back-off.
    sleep(millisecond / 1000);
  }
#endif
}

// ROCm variant. Evaluates COND; while the result differs from the success
// value registered in ::phi::enforce::details::ExternalApiType, calls
// ::phi::enforce::retry_sleep(10000) and evaluates COND again, for at most 5
// evaluations in total. If the final result is still not a success, throws
// phi::errors::External built from ::phi::enforce::build_rocm_error_msg().
// COND is expanded twice, so it must be safe to re-evaluate. The counter
// uses a reserved-style name so a caller variable called `retry_count` used
// inside COND is not captured by the macro's local.
#define PADDLE_RETRY_CUDA_SUCCESS(COND)                                     \
  do {                                                                      \
    auto __cond__ = (COND);                                                 \
    int __retry_count__ = 1;                                                \
    using __CUDA_STATUS_TYPE__ = decltype(__cond__);                        \
    constexpr auto __success_type__ =                                       \
        ::phi::enforce::details::ExternalApiType<                           \
            __CUDA_STATUS_TYPE__>::kSuccess;                                \
    while (UNLIKELY(__cond__ != __success_type__) && __retry_count__ < 5) { \
      ::phi::enforce::retry_sleep(10000);                                   \
      __cond__ = (COND);                                                    \
      ++__retry_count__;                                                    \
    }                                                                       \
    if (UNLIKELY(__cond__ != __success_type__)) {                           \
      auto __summary__ = phi::errors::External(                             \
          ::phi::enforce::build_rocm_error_msg(__cond__));                  \
      __THROW_ERROR_INTERNAL__(__summary__);                                \
    }                                                                       \
  } while (0)

#undef DEFINE_EXTERNAL_API_TYPE
#endif  // PADDLE_WITH_HIP

}  // namespace enforce
using namespace enforce;  // NOLINT

}  // namespace phi