enforce.h 28.7 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17 18 19 20
#ifdef __GNUC__
#include <cxxabi.h>  // for __cxa_demangle
#endif               // __GNUC__

21
#if !defined(_WIN32)
22
#include <dlfcn.h>   // dladdr
23
#include <unistd.h>  // sleep, usleep
24
#else                // _WIN32
25 26 27
#ifndef NOMINMAX
#define NOMINMAX  // msvc max/min macro conflict with std::min/max
#endif
28
#include <windows.h>  // GetModuleFileName, Sleep
29 30
#endif

31 32 33
#ifdef PADDLE_WITH_CUDA
#include <cublas_v2.h>
#include <cudnn.h>
34
#include <cufft.h>
35
#include <curand.h>
Z
zhangkaihuo 已提交
36
#include <cusparse.h>
37 38
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
39
#include "paddle/fluid/platform/external_error.pb.h"
40 41
#endif  // PADDLE_WITH_CUDA

42 43 44 45 46
#ifdef PADDLE_WITH_HIP
#include <hiprand.h>
#include <miopen/miopen.h>
#include <rocblas.h>
#include <thrust/system/hip/error.h>
47
#include <thrust/system_error.h>  // NOLINT
48 49
#endif

50
#include <fstream>
Y
Yu Yang 已提交
51
#include <iomanip>
L
liaogang 已提交
52
#include <memory>
53 54 55
#include <sstream>
#include <stdexcept>
#include <string>
S
sneaxiy 已提交
56 57
#include <type_traits>
#include <utility>
58

chen.zhiyu's avatar
chen.zhiyu 已提交
59 60 61 62
#if !defined(_WIN32) && !defined(PADDLE_WITH_MUSL)
#include <execinfo.h>
#endif

63
#define GLOG_NO_ABBREVIATED_SEVERITIES  // msvc conflict logging with windows.h
64
#include "gflags/gflags.h"
65
#include "glog/logging.h"
66
#include "paddle/fluid/platform/errors.h"
Y
Yi Wang 已提交
67
#include "paddle/fluid/platform/macros.h"
68
#include "paddle/fluid/platform/variant.h"
69 70
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/to_string.h"
71
#include "paddle/phi/backends/dynload/port.h"
72

73
#ifdef PADDLE_WITH_CUDA
74 75 76 77
#include "paddle/phi/backends/dynload/cublas.h"
#include "paddle/phi/backends/dynload/cudnn.h"
#include "paddle/phi/backends/dynload/curand.h"
#include "paddle/phi/backends/dynload/cusolver.h"
78
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
L
lilong12 已提交
79
#include <error.h>
80
#include "paddle/phi/backends/dynload/nccl.h"
Y
Yi Wang 已提交
81 82
#endif  // __APPLE__
#endif  // PADDLE_WITH_CUDA
83

84
#ifdef PADDLE_WITH_HIP
85 86 87 88
#include "paddle/phi/backends/dynload/hipfft.h"
#include "paddle/phi/backends/dynload/hiprand.h"
#include "paddle/phi/backends/dynload/miopen.h"
#include "paddle/phi/backends/dynload/rocblas.h"
89 90
#if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL)
#include <error.h>  // NOLINT
91
#include "paddle/phi/backends/dynload/rccl.h"
92 93 94
#endif  // __APPLE__
#endif  // PADDLE_WITH_HIP

95 96 97
// Note: these headers for simplify demangle type string
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/imperative/type_defs.h"
98
#include "paddle/phi/core/enforce.h"
99 100
// Note: this header for simplify HIP and CUDA type string
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
101
#include "paddle/fluid/platform/device/gpu/gpu_types.h"
102
#endif
Z
Zeng Jinle 已提交
103
#include "paddle/fluid/platform/flags.h"
W
wanghuancoder 已提交
104

105
namespace phi {
W
wanghuancoder 已提交
106
class ErrorSummary;
107
}  // namespace phi
108

109 110
DECLARE_int32(call_stack_level);

111 112
namespace paddle {
namespace platform {
113
using namespace ::phi::enforce;  // NOLINT
114

115 116
/** HELPER MACROS AND FUNCTIONS **/

Z
Zeng Jinle 已提交
117 118 119 120
#ifndef PADDLE_MAY_THROW
#define PADDLE_MAY_THROW noexcept(false)
#endif

121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
/*
 * Summary: This BOOST_GET(_**) series macros are used to call boost::get
 *   safely. boost::get is not a completely safe api, although it will not
 *   go wrong in most cases, but in extreme cases, it may fail and directly
 *   throw a boost::bad_get exception, without any stack information.
 *   This kind of problems is difficult to debug, so add these macros to
 *   enrich boost::get error information. At the same time, we restrict
 *   the direct use of boost::get by CI rule.
 *
 * Parameters:
 *     __TYPE: the target variable type
 *     __VALUE: the target variable to get
 *
 * Examples:
 *     - unsafe writing: int x = boost::get<int>(y);
 *     - safe writing: int x = BOOST_GET(int, y);
 *
 * Note: GCC 4.8 cannot select the right overloaded function here, so we
 *    need to define different functions and macros. After we upgrade the
 *    CI gcc version, we can define a single BOOST_GET macro.
 */
142 143
namespace details {

144
using namespace phi::enforce::details;  // NOLINT
145

146 147 148 149 150 151 152 153 154 155 156
#define DEFINE_SAFE_BOOST_GET(__InputType, __OutputType, __OutputTypePtr,      \
                              __FuncName)                                      \
  template <typename OutputType, typename InputType>                           \
  auto __FuncName(__InputType input, const char* expression, const char* file, \
                  int line)                                                    \
      ->typename std::conditional<std::is_pointer<InputType>::value,           \
                                  __OutputTypePtr, __OutputType>::type {       \
    try {                                                                      \
      return boost::get<OutputType>(input);                                    \
    } catch (boost::bad_get&) {                                                \
      HANDLE_THE_ERROR                                                         \
157 158
      throw ::phi::enforce::EnforceNotMet(                                     \
          phi::errors::InvalidArgument(                                        \
159 160
              "boost::get failed, cannot get value "                           \
              "(%s) by type %s, its type is %s.",                              \
161 162
              expression, phi::enforce::demangle(typeid(OutputType).name()),   \
              phi::enforce::demangle(input.type().name())),                    \
163 164 165 166 167 168 169 170 171 172 173 174 175
          file, line);                                                         \
      END_HANDLE_THE_ERROR                                                     \
    }                                                                          \
  }

DEFINE_SAFE_BOOST_GET(InputType&, OutputType&, OutputType*, SafeBoostGet);
DEFINE_SAFE_BOOST_GET(const InputType&, const OutputType&, const OutputType*,
                      SafeBoostGetConst);
DEFINE_SAFE_BOOST_GET(InputType&&, OutputType, OutputType*,
                      SafeBoostGetMutable);

}  // namespace details

176 177 178 179 180 181 182 183
// Checked replacements for boost::get<__TYPE>(__VALUE). Each macro forwards
// the value, its stringified expression and the call site (__FILE__,
// __LINE__) to the matching details::SafeBoostGet* wrapper.
#define BOOST_GET(__TYPE, __VALUE)                                             \
  paddle::platform::details::SafeBoostGet<__TYPE>(__VALUE, #__VALUE, __FILE__, \
                                                  __LINE__)
#define BOOST_GET_CONST(__TYPE, __VALUE)                                  \
  paddle::platform::details::SafeBoostGetConst<__TYPE>(__VALUE, #__VALUE, \
                                                       __FILE__, __LINE__)
#define BOOST_GET_MUTABLE(__TYPE, __VALUE)                                  \
  paddle::platform::details::SafeBoostGetMutable<__TYPE>(__VALUE, #__VALUE, \
                                                         __FILE__, __LINE__)

186 187
/** OTHER EXCEPTION AND ENFORCE **/

188 189
struct EOFException : public std::exception {
  std::string err_str_;
190
  EOFException(const char* err_msg, const char* file, int line) {
191
    err_str_ = paddle::string::Sprintf("%s at [%s:%d]", err_msg, file, line);
192 193
  }

194
  const char* what() const noexcept override { return err_str_.c_str(); }
195 196
};

197
// Throws paddle::platform::EOFException with the fixed message
// "There is no next data." and the location of the macro expansion.
#define PADDLE_THROW_EOF()                                                   \
  do {                                                                       \
    HANDLE_THE_ERROR                                                         \
    throw paddle::platform::EOFException("There is no next data.", __FILE__, \
                                         __LINE__);                          \
    END_HANDLE_THE_ERROR                                                     \
  } while (0)
M
minqiyang 已提交
204

205 206 207 208 209 210
// Throws ::paddle::memory::allocation::BadAlloc. The variadic arguments are
// used to construct a phi::ErrorSummary whose to_string() becomes the
// message; the location is that of the macro expansion.
// NOTE(review): BadAlloc is not declared in this header; the caller is
// presumably expected to include its definition.
#define PADDLE_THROW_BAD_ALLOC(...)                                      \
  do {                                                                   \
    HANDLE_THE_ERROR                                                     \
    throw ::paddle::memory::allocation::BadAlloc(                        \
        phi::ErrorSummary(__VA_ARGS__).to_string(), __FILE__, __LINE__); \
    END_HANDLE_THE_ERROR                                                 \
  } while (0)

213 214
/**************************************************************************/
/**************************** NVIDIA ERROR ********************************/
215
#ifdef PADDLE_WITH_CUDA
216

217
namespace details {
M
minqiyang 已提交
218

219 220 221 222 223 224 225 226 227 228 229
template <typename T>
struct ExternalApiType {};

#define DEFINE_EXTERNAL_API_TYPE(type, success_value, proto_type) \
  template <>                                                     \
  struct ExternalApiType<type> {                                  \
    using Type = type;                                            \
    static constexpr Type kSuccess = success_value;               \
    static constexpr const char* kTypeString = #proto_type;       \
    static constexpr platform::proto::ApiType kProtoType =        \
        platform::proto::ApiType::proto_type;                     \
230 231
  }

232 233 234 235
DEFINE_EXTERNAL_API_TYPE(cudaError_t, cudaSuccess, CUDA);
DEFINE_EXTERNAL_API_TYPE(curandStatus_t, CURAND_STATUS_SUCCESS, CURAND);
DEFINE_EXTERNAL_API_TYPE(cudnnStatus_t, CUDNN_STATUS_SUCCESS, CUDNN);
DEFINE_EXTERNAL_API_TYPE(cublasStatus_t, CUBLAS_STATUS_SUCCESS, CUBLAS);
Z
zhangkaihuo 已提交
236
DEFINE_EXTERNAL_API_TYPE(cusparseStatus_t, CUSPARSE_STATUS_SUCCESS, CUSPARSE);
237
DEFINE_EXTERNAL_API_TYPE(cusolverStatus_t, CUSOLVER_STATUS_SUCCESS, CUSOLVER);
238
DEFINE_EXTERNAL_API_TYPE(cufftResult_t, CUFFT_SUCCESS, CUFFT);
239
DEFINE_EXTERNAL_API_TYPE(CUresult, CUDA_SUCCESS, CU);
240 241 242

#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
DEFINE_EXTERNAL_API_TYPE(ncclResult_t, ncclSuccess, NCCL);
243
#endif
244 245 246 247 248 249 250 251 252 253

}  // namespace details

// Returns the URL of the vendor documentation page that lists the error
// codes of the API family T belongs to. The family is taken from
// details::ExternalApiType<T>::kProtoType; `status` itself is only used for
// type deduction. For an unrecognised family a plain explanatory sentence is
// returned instead of a URL.
template <typename T>
inline const char* GetErrorMsgUrl(T status) {
  using __CUDA_STATUS_TYPE__ = decltype(status);
  platform::proto::ApiType proto_type =
      details::ExternalApiType<__CUDA_STATUS_TYPE__>::kProtoType;
  switch (proto_type) {
    case platform::proto::ApiType::CUDA:
    case platform::proto::ApiType::CU:
      return "https://docs.nvidia.com/cuda/cuda-runtime-api/"
             "group__CUDART__TYPES.html#group__CUDART__TYPES_"
             "1g3f51e3575c2178246db0a94a430e0038";
    case platform::proto::ApiType::CURAND:
      return "https://docs.nvidia.com/cuda/curand/"
             "group__HOST.html#group__HOST_1gb94a31d5c165858c96b6c18b70644437";
    case platform::proto::ApiType::CUDNN:
      return "https://docs.nvidia.com/deeplearning/cudnn/api/"
             "index.html#cudnnStatus_t";
    case platform::proto::ApiType::CUBLAS:
      return "https://docs.nvidia.com/cuda/cublas/index.html#cublasstatus_t";
    case platform::proto::ApiType::CUSOLVER:
      return "https://docs.nvidia.com/cuda/cusolver/"
             "index.html#cuSolverSPstatus";
    case platform::proto::ApiType::NCCL:
      return "https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/"
             "types.html#ncclresult-t";
    case platform::proto::ApiType::CUFFT:
      return "https://docs.nvidia.com/cuda/cufft/index.html#cufftresult";
    case platform::proto::ApiType::CUSPARSE:
      return "https://docs.nvidia.com/cuda/cusparse/"
             "index.html#cusparseStatus_t";
    default:
      return "Unknown type of External API, can't get error message URL!";
  }
}

template <typename T>
inline std::string GetExternalErrorMsg(T status) {
292
  std::ostringstream sout;
293 294 295
  bool _initSucceed = false;
  platform::proto::ExternalErrorDesc externalError;
  if (externalError.ByteSizeLong() == 0) {
296 297 298
    std::string filePath;
#if !defined(_WIN32)
    Dl_info info;
299
    if (dladdr(reinterpret_cast<void*>(GetCurrentTraceBackString), &info)) {
300 301 302 303 304 305 306 307
      std::string strModule(info.dli_fname);
      const size_t last_slash_idx = strModule.find_last_of("/");
      std::string compare_path = strModule.substr(strModule.length() - 6);
      if (std::string::npos != last_slash_idx) {
        strModule.erase(last_slash_idx, std::string::npos);
      }
      if (compare_path.compare("avx.so") == 0) {
        filePath =
308 309 310 311 312
            strModule +
            "/../include/third_party/externalError/data/externalErrorMsg.pb";
      } else {
        filePath = strModule +
                   "/../../third_party/externalError/data/externalErrorMsg.pb";
313 314 315
      }
    }
#else
316
    char buf[512];
317 318
    MEMORY_BASIC_INFORMATION mbi;
    HMODULE h_module =
319
        (::VirtualQuery(GetCurrentTraceBackString, &mbi, sizeof(mbi)) != 0)
320 321
            ? (HMODULE)mbi.AllocationBase
            : NULL;
322
    GetModuleFileName(h_module, buf, 512);
323 324 325 326 327 328 329
    std::string strModule(buf);
    const size_t last_slash_idx = strModule.find_last_of("\\");
    std::string compare_path = strModule.substr(strModule.length() - 7);
    if (std::string::npos != last_slash_idx) {
      strModule.erase(last_slash_idx, std::string::npos);
    }
    if (compare_path.compare("avx.pyd") == 0) {
330 331 332
      filePath = strModule +
                 "\\..\\include\\third_"
                 "party\\externalerror\\data\\externalErrorMsg.pb";
333 334
    } else {
      filePath =
335 336
          strModule +
          "\\..\\..\\third_party\\externalerror\\data\\externalErrorMsg.pb";
337 338 339
    }
#endif
    std::ifstream fin(filePath, std::ios::in | std::ios::binary);
340
    _initSucceed = externalError.ParseFromIstream(&fin);
341
  }
342 343 344
  using __CUDA_STATUS_TYPE__ = decltype(status);
  platform::proto::ApiType proto_type =
      details::ExternalApiType<__CUDA_STATUS_TYPE__>::kProtoType;
345
  if (_initSucceed) {
346 347 348 349 350 351
    for (int i = 0; i < externalError.errors_size(); ++i) {
      if (proto_type == externalError.errors(i).type()) {
        for (int j = 0; j < externalError.errors(i).messages_size(); ++j) {
          if (status == externalError.errors(i).messages(j).code()) {
            sout << "\n  [Hint: "
                 << externalError.errors(i).messages(j).message() << "]";
352 353 354 355 356 357
            return sout.str();
          }
        }
      }
    }
  }
358 359 360 361 362 363

  sout << "\n  [Hint: Please search for the error code(" << status
       << ") on website (" << GetErrorMsgUrl(status)
       << ") to get Nvidia's official solution and advice about "
       << details::ExternalApiType<__CUDA_STATUS_TYPE__>::kTypeString
       << " Error.]";
364
  return sout.str();
365 366
}

367 368 369 370
// Explicit instantiations of GetExternalErrorMsg, one per status type that
// has a details::ExternalApiType specialization.
template std::string GetExternalErrorMsg<cudaError_t>(cudaError_t);
template std::string GetExternalErrorMsg<curandStatus_t>(curandStatus_t);
template std::string GetExternalErrorMsg<cudnnStatus_t>(cudnnStatus_t);
template std::string GetExternalErrorMsg<cublasStatus_t>(cublasStatus_t);
template std::string GetExternalErrorMsg<cusparseStatus_t>(cusparseStatus_t);
template std::string GetExternalErrorMsg<cusolverStatus_t>(cusolverStatus_t);
template std::string GetExternalErrorMsg<cufftResult_t>(cufftResult_t);
template std::string GetExternalErrorMsg<CUresult>(CUresult);
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
template std::string GetExternalErrorMsg<ncclResult_t>(ncclResult_t);
#endif

/*************** CUDA ERROR ***************/
// True when the CUDA runtime call did not return cudaSuccess.
inline bool is_error(cudaError_t e) { return e != cudaSuccess; }

// Formats "CUDA error(<code>), <cudaGetErrorString text>. " followed by the
// hint produced by GetExternalErrorMsg.
inline std::string build_nvidia_error_msg(cudaError_t e) {
  std::ostringstream sout;
  sout << "CUDA error(" << e << "), " << cudaGetErrorString(e) << ". "
       << GetExternalErrorMsg(e);
  return sout.str();
}

389 390 391
/*************** CURAND ERROR ***************/
// True when the cuRAND call did not return CURAND_STATUS_SUCCESS.
inline bool is_error(curandStatus_t stat) {
  return stat != CURAND_STATUS_SUCCESS;
}

// Formats "CURAND error(<code>). " followed by the hint produced by
// GetExternalErrorMsg.
inline std::string build_nvidia_error_msg(curandStatus_t stat) {
  std::ostringstream sout;
  sout << "CURAND error(" << stat << "). " << GetExternalErrorMsg(stat);
  return sout.str();
}

400
/*************** CUDNN ERROR ***************/
M
minqiyang 已提交
401 402
inline bool is_error(cudnnStatus_t stat) {
  return stat != CUDNN_STATUS_SUCCESS;
403 404
}

405
inline std::string build_nvidia_error_msg(cudnnStatus_t stat) {
406 407
  std::ostringstream sout;
  sout << "CUDNN error(" << stat << "), "
408
       << phi::dynload::cudnnGetErrorString(stat) << ". "
409 410
       << GetExternalErrorMsg(stat);
  return sout.str();
411 412
}

413
/*************** CUBLAS ERROR ***************/
M
minqiyang 已提交
414 415
inline bool is_error(cublasStatus_t stat) {
  return stat != CUBLAS_STATUS_SUCCESS;
416 417
}

418
inline std::string build_nvidia_error_msg(cublasStatus_t stat) {
419 420 421
  std::ostringstream sout;
  sout << "CUBLAS error(" << stat << "). " << GetExternalErrorMsg(stat);
  return sout.str();
422 423
}

Z
zhangkaihuo 已提交
424 425 426 427 428 429 430 431 432 433 434
/*************** CUSPARSE ERROR ***************/
// Reports whether a cuSPARSE status is anything other than success.
inline bool is_error(cusparseStatus_t stat) {
  const bool succeeded = (stat == CUSPARSE_STATUS_SUCCESS);
  return !succeeded;
}

// Produces "CUSparse error(<code>). " plus the GetExternalErrorMsg hint.
inline std::string build_nvidia_error_msg(cusparseStatus_t stat) {
  std::ostringstream text;
  text << "CUSparse error(" << stat << "). ";
  text << GetExternalErrorMsg(stat);
  return text.str();
}

435
/*************** CUSOLVER ERROR ***************/
G
Guo Sheng 已提交
436 437 438 439 440
// True when the cuSOLVER call did not return CUSOLVER_STATUS_SUCCESS.
inline bool is_error(cusolverStatus_t stat) {
  return stat != CUSOLVER_STATUS_SUCCESS;
}

// Formats "CUSOLVER error(<code>). " followed by the hint produced by
// GetExternalErrorMsg.
inline std::string build_nvidia_error_msg(cusolverStatus_t stat) {
  std::ostringstream sout;
  sout << "CUSOLVER error(" << stat << "). " << GetExternalErrorMsg(stat);
  return sout.str();
}

446 447 448 449 450 451 452 453 454
/*************** CUFFT ERROR ***************/
// Reports whether a cuFFT result is anything other than CUFFT_SUCCESS.
inline bool is_error(cufftResult_t stat) {
  const bool succeeded = (stat == CUFFT_SUCCESS);
  return !succeeded;
}

// Produces "CUFFT error(<code>). " plus the GetExternalErrorMsg hint.
inline std::string build_nvidia_error_msg(cufftResult_t stat) {
  std::ostringstream text;
  text << "CUFFT error(" << stat << "). ";
  text << GetExternalErrorMsg(stat);
  return text.str();
}

455 456 457 458 459 460 461 462 463
/*************** CUresult ERROR ***************/
// Reports whether a CUDA driver result is anything other than CUDA_SUCCESS.
inline bool is_error(CUresult stat) {
  const bool succeeded = (stat == CUDA_SUCCESS);
  return !succeeded;
}

// Produces "CU error(<code>). " plus the GetExternalErrorMsg hint.
inline std::string build_nvidia_error_msg(CUresult stat) {
  std::ostringstream text;
  text << "CU error(" << stat << "). ";
  text << GetExternalErrorMsg(stat);
  return text.str();
}

464
/**************** NCCL ERROR ****************/
465
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
S
sneaxiy 已提交
466 467 468 469
inline bool is_error(ncclResult_t nccl_result) {
  return nccl_result != ncclSuccess;
}

470
inline std::string build_nvidia_error_msg(ncclResult_t nccl_result) {
471 472
  std::ostringstream sout;
  sout << "NCCL error(" << nccl_result << "), "
473
       << phi::dynload::ncclGetErrorString(nccl_result) << ". ";
L
lilong12 已提交
474 475 476 477 478 479 480 481 482
  if (errno == ENOSPC || errno == EAGAIN) {
    std::string detail(strerror(errno));
    detail += "\nPlease try one of the following solutions:";
    detail += "\n1. export NCCL_SHM_DISABLE=1;";
    detail += "\n2. export NCCL_P2P_LEVEL=SYS;";
    detail +=
        "\n3. Increase shared memory by setting the -shm-size "
        "option when starting docker container, e.g., setting "
        " -shm-size=2g.\n";
483
    sout << " Detail: " + detail;
L
lilong12 已提交
484
  }
485 486
  sout << GetExternalErrorMsg(nccl_result);
  return sout.str();
487
}
488
#endif  // not(__APPLE__) and PADDLE_WITH_NCCL
489

490
// Evaluates COND once and compares the result with the kSuccess value
// registered in details::ExternalApiType for its type. On mismatch an
// External error summary is built from build_nvidia_error_msg and thrown
// through __THROW_ERROR_INTERNAL__.
#define PADDLE_ENFORCE_GPU_SUCCESS(COND)                         \
  do {                                                           \
    auto __cond__ = (COND);                                      \
    using __CUDA_STATUS_TYPE__ = decltype(__cond__);             \
    constexpr auto __success_type__ =                            \
        ::paddle::platform::details::ExternalApiType<            \
            __CUDA_STATUS_TYPE__>::kSuccess;                     \
    if (UNLIKELY(__cond__ != __success_type__)) {                \
      auto __summary__ = phi::errors::External(                  \
          ::paddle::platform::build_nvidia_error_msg(__cond__)); \
      __THROW_ERROR_INTERNAL__(__summary__);                     \
    }                                                            \
  } while (0)

504 505 506 507 508 509 510 511 512 513
// Checks cudaGetLastError() and, if it is not cudaSuccess, throws a Fatal
// error via PADDLE_THROW whose text is
// "CUDA error after kernel (<OP>): <build_nvidia_error_msg text>".
// OP is passed to a "%s" format slot.
// NOTE(review): `platform::errors::Fatal` is not fully qualified, so the
// macro presumably only expands correctly where `platform` resolves to
// paddle::platform - confirm at call sites.
// NOTE(review): the locals `res` and `msg` are visible while OP is
// evaluated and would shadow same-named variables used in OP.
#define PADDLE_ENFORCE_CUDA_LAUNCH_SUCCESS(OP)                                 \
  do {                                                                         \
    auto res = cudaGetLastError();                                             \
    if (UNLIKELY(res != cudaSuccess)) {                                        \
      auto msg = ::paddle::platform::build_nvidia_error_msg(res);              \
      PADDLE_THROW(platform::errors::Fatal("CUDA error after kernel (%s): %s", \
                                           OP, msg));                          \
    }                                                                          \
  } while (0)

514
// Blocks the calling thread for roughly `milliseconds` ms.
// On Windows this is Sleep(milliseconds). Elsewhere, values below 1000 use
// usleep; larger values are truncated to whole seconds and use sleep.
inline void retry_sleep(unsigned milliseconds) {
#ifdef _WIN32
  Sleep(milliseconds);
#else
  if (milliseconds < 1000) {
    // usleep argument must be less than 1,000,000. Reference:
    // https://pubs.opengroup.org/onlinepubs/7908799/xsh/usleep.html
    usleep(milliseconds * 1000);
  } else {
    // clip to sleep in seconds because we can not and don't have to
    // sleep for exact milliseconds
    sleep(milliseconds / 1000);
  }
#endif
}

530 531 532 533 534 535
// Like PADDLE_ENFORCE_GPU_SUCCESS, but retries first. COND is evaluated up
// to 5 times in total (so it must be safe to re-evaluate), with
// retry_sleep(10000) between attempts. If the last attempt still differs
// from kSuccess, an External error built from build_nvidia_error_msg is
// thrown through __THROW_ERROR_INTERNAL__.
#define PADDLE_RETRY_CUDA_SUCCESS(COND)                                 \
  do {                                                                  \
    auto __cond__ = (COND);                                             \
    int retry_count = 1;                                                \
    using __CUDA_STATUS_TYPE__ = decltype(__cond__);                    \
    constexpr auto __success_type__ =                                   \
        ::paddle::platform::details::ExternalApiType<                   \
            __CUDA_STATUS_TYPE__>::kSuccess;                            \
    while (UNLIKELY(__cond__ != __success_type__) && retry_count < 5) { \
      ::paddle::platform::retry_sleep(10000);                           \
      __cond__ = (COND);                                                \
      ++retry_count;                                                    \
    }                                                                   \
    if (UNLIKELY(__cond__ != __success_type__)) {                       \
      auto __summary__ = phi::errors::External(                         \
          ::paddle::platform::build_nvidia_error_msg(__cond__));        \
      __THROW_ERROR_INTERNAL__(__summary__);                            \
    }                                                                   \
  } while (0)

550
#undef DEFINE_EXTERNAL_API_TYPE
551
#endif  // PADDLE_WITH_CUDA
S
add EQ  
Superjom 已提交
552

553 554
/**************************************************************************/
/***************************** HIP ERROR **********************************/
555 556 557 558 559 560 561 562 563 564 565
#ifdef PADDLE_WITH_HIP

/***** HIP ERROR *****/
// Reports whether a HIP runtime result is anything other than hipSuccess.
inline bool is_error(hipError_t e) {
  const bool succeeded = (e == hipSuccess);
  return !succeeded;
}

// Produces " Hip error(<code>), <hipGetErrorString text>.".
inline std::string build_rocm_error_msg(hipError_t e) {
  std::ostringstream text;
  text << " Hip error(" << e << "), ";
  text << hipGetErrorString(e) << ".";
  return text.str();
}

566
/***** HIPRAND ERROR *****/
567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617
// Reports whether a hipRAND status is anything other than success.
inline bool is_error(hiprandStatus_t stat) {
  const bool succeeded = (stat == HIPRAND_STATUS_SUCCESS);
  return !succeeded;
}

// Maps a hipRAND status to the spelling of its enumerator. Values without a
// case below yield "Unknown hiprand status".
inline const char* hiprandGetErrorString(hiprandStatus_t stat) {
  const char* name = "Unknown hiprand status";
  switch (stat) {
    case HIPRAND_STATUS_SUCCESS:
      name = "HIPRAND_STATUS_SUCCESS";
      break;
    case HIPRAND_STATUS_VERSION_MISMATCH:
      name = "HIPRAND_STATUS_VERSION_MISMATCH";
      break;
    case HIPRAND_STATUS_NOT_INITIALIZED:
      name = "HIPRAND_STATUS_NOT_INITIALIZED";
      break;
    case HIPRAND_STATUS_ALLOCATION_FAILED:
      name = "HIPRAND_STATUS_ALLOCATION_FAILED";
      break;
    case HIPRAND_STATUS_TYPE_ERROR:
      name = "HIPRAND_STATUS_TYPE_ERROR";
      break;
    case HIPRAND_STATUS_OUT_OF_RANGE:
      name = "HIPRAND_STATUS_OUT_OF_RANGE";
      break;
    case HIPRAND_STATUS_LENGTH_NOT_MULTIPLE:
      name = "HIPRAND_STATUS_LENGTH_NOT_MULTIPLE";
      break;
    case HIPRAND_STATUS_DOUBLE_PRECISION_REQUIRED:
      name = "HIPRAND_STATUS_DOUBLE_PRECISION_REQUIRED";
      break;
    case HIPRAND_STATUS_LAUNCH_FAILURE:
      name = "HIPRAND_STATUS_LAUNCH_FAILURE";
      break;
    case HIPRAND_STATUS_PREEXISTING_FAILURE:
      name = "HIPRAND_STATUS_PREEXISTING_FAILURE";
      break;
    case HIPRAND_STATUS_INITIALIZATION_FAILED:
      name = "HIPRAND_STATUS_INITIALIZATION_FAILED";
      break;
    case HIPRAND_STATUS_ARCH_MISMATCH:
      name = "HIPRAND_STATUS_ARCH_MISMATCH";
      break;
    case HIPRAND_STATUS_INTERNAL_ERROR:
      name = "HIPRAND_STATUS_INTERNAL_ERROR";
      break;
    case HIPRAND_STATUS_NOT_IMPLEMENTED:
      name = "HIPRAND_STATUS_NOT_IMPLEMENTED";
      break;
    default:
      break;
  }
  return name;
}

// Produces " Hiprand error, <status name> ".
inline std::string build_rocm_error_msg(hiprandStatus_t stat) {
  std::string text = " Hiprand error, ";
  text += hiprandGetErrorString(stat);
  text += " ";
  return text;
}

/***** MIOPEN ERROR *****/
// True when the MIOpen call did not return miopenStatusSuccess.
inline bool is_error(miopenStatus_t stat) {
  return stat != miopenStatusSuccess;
}

// Formats " Miopen error, <miopenGetErrorString text> ".
inline std::string build_rocm_error_msg(miopenStatus_t stat) {
  std::string msg(" Miopen error, ");
  return msg + phi::dynload::miopenGetErrorString(stat) + " ";
}

/***** ROCBLAS ERROR *****/
// Reports whether a rocBLAS status is anything other than success.
inline bool is_error(rocblas_status stat) {
  const bool succeeded = (stat == rocblas_status_success);
  return !succeeded;
}

// Maps a rocBLAS status to the spelling of its enumerator. Values without a
// case below yield "Unknown rocblas status".
inline const char* rocblasGetErrorString(rocblas_status stat) {
  switch (stat) {
    case rocblas_status_success:
      return "rocblas_status_success";
    case rocblas_status_invalid_handle:
      return "rocblas_status_invalid_handle";
    case rocblas_status_memory_error:
      return "rocblas_status_memory_error";
    case rocblas_status_invalid_value:
      return "rocblas_status_invalid_value";
    case rocblas_status_not_implemented:
      return "rocblas_status_not_implemented";
    case rocblas_status_invalid_pointer:
      return "rocblas_status_invalid_pointer";
    case rocblas_status_invalid_size:
      return "rocblas_status_invalid_size";
    case rocblas_status_internal_error:
      return "rocblas_status_internal_error";
    default:
      // Previously said "cublas", a copy-paste slip from the CUDA path.
      return "Unknown rocblas status";
  }
}

// Produces " Rocblas error, <status name> ".
inline std::string build_rocm_error_msg(rocblas_status stat) {
  std::string text = " Rocblas error, ";
  text += rocblasGetErrorString(stat);
  text += " ";
  return text;
}

/****** RCCL ERROR ******/
#if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL)
// True when the RCCL call did not return ncclSuccess.
inline bool is_error(ncclResult_t nccl_result) {
  return nccl_result != ncclSuccess;
}

// Formats " Rccl error, <ncclGetErrorString text> ".
inline std::string build_rocm_error_msg(ncclResult_t nccl_result) {
  std::string msg(" Rccl error, ");
  return msg + phi::dynload::ncclGetErrorString(nccl_result) + " ";
}
#endif  // not(__APPLE__) and PADDLE_WITH_RCCL

664 665 666 667 668
/***** HIPFFT ERROR *****/
// True when the hipFFT call did not return HIPFFT_SUCCESS.
inline bool is_error(hipfftResult_t stat) { return stat != HIPFFT_SUCCESS; }

// Formats " HIPFFT error, <hipfftGetErrorString text> ".
inline std::string build_rocm_error_msg(hipfftResult_t stat) {
  std::string msg(" HIPFFT error, ");
  return msg + phi::dynload::hipfftGetErrorString(stat) + " ";
}

672 673 674
namespace details {

template <typename T>
675
struct ExternalApiType {};
676

677 678 679 680 681
#define DEFINE_EXTERNAL_API_TYPE(type, success_value) \
  template <>                                         \
  struct ExternalApiType<type> {                      \
    using Type = type;                                \
    static constexpr Type kSuccess = success_value;   \
682 683
  }

684 685 686 687
DEFINE_EXTERNAL_API_TYPE(hipError_t, hipSuccess);
DEFINE_EXTERNAL_API_TYPE(hiprandStatus_t, HIPRAND_STATUS_SUCCESS);
DEFINE_EXTERNAL_API_TYPE(miopenStatus_t, miopenStatusSuccess);
DEFINE_EXTERNAL_API_TYPE(rocblas_status, rocblas_status_success);
688
DEFINE_EXTERNAL_API_TYPE(hipfftResult_t, HIPFFT_SUCCESS);
689 690

#if !defined(__APPLE__) && defined(PADDLE_WITH_RCCL)
691
DEFINE_EXTERNAL_API_TYPE(ncclResult_t, ncclSuccess);
692 693 694 695
#endif

}  // namespace details

696
// Evaluates COND once and compares the result with the kSuccess value
// registered in details::ExternalApiType for its type. On mismatch an
// External error summary is built from build_rocm_error_msg and thrown
// through __THROW_ERROR_INTERNAL__.
#define PADDLE_ENFORCE_GPU_SUCCESS(COND)                       \
  do {                                                         \
    auto __cond__ = (COND);                                    \
    using __CUDA_STATUS_TYPE__ = decltype(__cond__);           \
    constexpr auto __success_type__ =                          \
        ::paddle::platform::details::ExternalApiType<          \
            __CUDA_STATUS_TYPE__>::kSuccess;                   \
    if (UNLIKELY(__cond__ != __success_type__)) {              \
      auto __summary__ = phi::errors::External(                \
          ::paddle::platform::build_rocm_error_msg(__cond__)); \
      __THROW_ERROR_INTERNAL__(__summary__);                   \
    }                                                          \
  } while (0)

// Blocks the calling thread for roughly `millisecond` ms.
// On Windows this is Sleep(millisecond). Elsewhere, values below 1000 use
// usleep; larger values are truncated to whole seconds and use sleep.
// (POSIX sleep() takes seconds, so the argument must not be passed to it
// unconverted.)
inline void retry_sleep(unsigned millisecond) {
#ifdef _WIN32
  Sleep(millisecond);
#else
  if (millisecond < 1000) {
    // usleep's argument must stay below 1,000,000 microseconds.
    usleep(millisecond * 1000);
  } else {
    // Whole-second precision is sufficient for retry back-off.
    sleep(millisecond / 1000);
  }
#endif
}

// Like PADDLE_ENFORCE_GPU_SUCCESS, but retries first. COND is evaluated up
// to 5 times in total (so it must be safe to re-evaluate), with
// retry_sleep(10000) between attempts. If the last attempt still differs
// from kSuccess, an External error built from build_rocm_error_msg is
// thrown through __THROW_ERROR_INTERNAL__.
#define PADDLE_RETRY_CUDA_SUCCESS(COND)                                 \
  do {                                                                  \
    auto __cond__ = (COND);                                             \
    int retry_count = 1;                                                \
    using __CUDA_STATUS_TYPE__ = decltype(__cond__);                    \
    constexpr auto __success_type__ =                                   \
        ::paddle::platform::details::ExternalApiType<                   \
            __CUDA_STATUS_TYPE__>::kSuccess;                            \
    while (UNLIKELY(__cond__ != __success_type__) && retry_count < 5) { \
      ::paddle::platform::retry_sleep(10000);                           \
      __cond__ = (COND);                                                \
      ++retry_count;                                                    \
    }                                                                   \
    if (UNLIKELY(__cond__ != __success_type__)) {                       \
      auto __summary__ = phi::errors::External(                         \
          ::paddle::platform::build_rocm_error_msg(__cond__));          \
      __THROW_ERROR_INTERNAL__(__summary__);                            \
    }                                                                   \
  } while (0)

738
#undef DEFINE_EXTERNAL_API_TYPE
739 740
#endif  // PADDLE_WITH_HIP

741 742
}  // namespace platform
}  // namespace paddle