error.h 2.3 KB
Newer Older
L
liaogang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
#pragma once

#include <sstream>
#include <stdexcept>
#include <string>

#ifndef PADDLE_ONLY_CPU

#include <cublas_v2.h>
#include <cudnn.h>
#include <curand.h>
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>

#endif  // PADDLE_ONLY_CPU

namespace paddle {
namespace platform {

#ifndef PADDLE_ONLY_CPU

// Converts a failed CUDA runtime call into a thrust::system_error.
//
// A zero status is treated as success and returns silently; any other value
// is thrown as-is in thrust's CUDA error category, with `message` as context.
inline void throw_on_error(cudaError_t e, const char* message) {
  if (!e) {
    return;
  }
  throw thrust::system_error(e, thrust::cuda_category(), message);
}

// Throws thrust::system_error if a cuRAND call did not return
// CURAND_STATUS_SUCCESS.
//
// The exception is still tagged as cudaErrorLaunchFailure in
// thrust::cuda_category() so existing catch sites keep working. The actual
// curandStatus_t value is appended to the message; previously it was
// discarded, which made every cuRAND failure look identical.
//
// @param stat     status returned by a cuRAND API call.
// @param message  caller-supplied context; nullptr is treated as "".
// @throws thrust::system_error on any non-success status.
inline void throw_on_error(curandStatus_t stat, const char* message) {
  if (stat == CURAND_STATUS_SUCCESS) {
    return;
  }
  // Building a std::string from a null pointer is undefined behaviour.
  std::string what = (message != nullptr) ? message : "";
  what += ", curand stat = ";
  what += std::to_string(static_cast<int>(stat));
  throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(),
                             what);
}

// Throws std::runtime_error if a cuDNN call did not return
// CUDNN_STATUS_SUCCESS.
//
// The exception text is cuDNN's own description of `stat` (from
// cudnnGetErrorString), followed by ", " and the caller's `message`.
inline void throw_on_error(cudnnStatus_t stat, const char* message) {
  if (stat == CUDNN_STATUS_SUCCESS) {
    return;
  }
  std::stringstream description;
  description << cudnnGetErrorString(stat) << ", " << message;
  throw std::runtime_error(description.str());
}

// Throws std::runtime_error if a cuBLAS call did not return
// CUBLAS_STATUS_SUCCESS.
//
// The exception text is "CUBLAS: <description>", followed by
// ", <message>" when a non-empty message is supplied. Status values without
// a dedicated description (for example codes added in newer cuBLAS
// releases) are reported with their numeric value instead of an empty
// description.
//
// @param stat     status returned by a cuBLAS API call.
// @param message  caller-supplied context; nullptr or "" adds nothing.
// @throws std::runtime_error on any non-success status.
inline void throw_on_error(cublasStatus_t stat, const char* message) {
  if (stat == CUBLAS_STATUS_SUCCESS) {
    return;
  }
  std::stringstream ss;
  switch (stat) {
    case CUBLAS_STATUS_NOT_INITIALIZED:
      ss << "CUBLAS: not initialized";
      break;
    case CUBLAS_STATUS_ALLOC_FAILED:
      ss << "CUBLAS: alloc failed";
      break;
    case CUBLAS_STATUS_INVALID_VALUE:
      ss << "CUBLAS: invalid value";
      break;
    case CUBLAS_STATUS_ARCH_MISMATCH:
      ss << "CUBLAS: arch mismatch";
      break;
    case CUBLAS_STATUS_MAPPING_ERROR:
      ss << "CUBLAS: mapping error";
      break;
    case CUBLAS_STATUS_EXECUTION_FAILED:
      ss << "CUBLAS: execution failed";
      break;
    case CUBLAS_STATUS_INTERNAL_ERROR:
      ss << "CUBLAS: internal error";
      break;
    case CUBLAS_STATUS_NOT_SUPPORTED:
      ss << "CUBLAS: not supported";
      break;
    case CUBLAS_STATUS_LICENSE_ERROR:
      ss << "CUBLAS: license error";
      break;
    default:
      // Previously an unrecognized status yielded an empty description.
      ss << "CUBLAS: unknown error, stat = " << static_cast<int>(stat);
      break;
  }
  // Streaming a null const char* is undefined behaviour, and an empty
  // message would only leave a dangling ", " separator.
  if (message != nullptr && *message != '\0') {
    ss << ", " << message;
  }
  throw std::runtime_error(ss.str());
}

// Convenience overload for cuBLAS checks without caller context: forwards
// to the two-argument overload with an empty message.
inline void throw_on_error(cublasStatus_t stat) {
  throw_on_error(stat, "");
}

#endif  // PADDLE_ONLY_CPU

// Generic fallback for plain integer status codes (also available in
// CPU-only builds).
//
// @param stat     status code; zero means success and nothing is thrown.
// @param message  context placed at the start of the error text; nullptr is
//                 treated as an empty string.
// @throws std::runtime_error with text "<message>, stat = <stat>" when
//         `stat` is non-zero.
inline void throw_on_error(int stat, const char* message) {
  if (stat == 0) {
    return;
  }
  // `message + std::string` with a null pointer calls strlen(nullptr),
  // which is undefined behaviour; fall back to an empty prefix instead.
  std::string what = (message != nullptr) ? message : "";
  what += ", stat = ";
  what += std::to_string(stat);
  throw std::runtime_error(what);
}

}  // namespace platform
}  // namespace paddle