提交 aa41ee75 编写于 作者: M minqiyang

Accelerate PADDLE_ENFORCE

上级 728e7e88
...@@ -49,6 +49,8 @@ constexpr char kTempVarName[] = "@TEMP@"; ...@@ -49,6 +49,8 @@ constexpr char kTempVarName[] = "@TEMP@";
/// e.g. Variable "x@GRAD" is the gradient of varibale "x". /// e.g. Variable "x@GRAD" is the gradient of varibale "x".
constexpr char kGradVarSuffix[] = "@GRAD"; constexpr char kGradVarSuffix[] = "@GRAD";
constexpr size_t kGradVarSuffixSize = 5U;
/// Variables with this suffix are supposed to be filled up with zeros. /// Variables with this suffix are supposed to be filled up with zeros.
constexpr char kZeroVarSuffix[] = "@ZERO"; constexpr char kZeroVarSuffix[] = "@ZERO";
...@@ -60,7 +62,11 @@ constexpr char kNewGradSuffix[] = "@NEWGRAD@"; ...@@ -60,7 +62,11 @@ constexpr char kNewGradSuffix[] = "@NEWGRAD@";
extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority; extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;
inline std::string GradVarName(const std::string& var_name) { inline std::string GradVarName(const std::string& var_name) {
return var_name + kGradVarSuffix; std::string result;
result.reserve(var_name.size() + kGradVarSuffixSize);
result += var_name;
result += kGradVarSuffix;
return result;
} }
proto::VarType::Type GetDataTypeOfVar(const Variable* var); proto::VarType::Type GetDataTypeOfVar(const Variable* var);
...@@ -101,8 +107,8 @@ class OperatorBase { ...@@ -101,8 +107,8 @@ class OperatorBase {
bool HasAttr(const std::string& name) const { return attrs_.count(name); } bool HasAttr(const std::string& name) const { return attrs_.count(name); }
template <typename T> template <typename T>
inline const T& Attr(const std::string& name) const { inline const T& Attr(const std::string& name) const {
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap", PADDLE_ENFORCE(attrs_.find(name) != attrs_.end(),
name); "%s should be in AttributeMap", name);
return boost::get<T>(attrs_.at(name)); return boost::get<T>(attrs_.at(name));
} }
const AttributeMap& Attrs() const { return attrs_; } const AttributeMap& Attrs() const { return attrs_; }
......
...@@ -140,68 +140,72 @@ struct EOFException : public std::exception { ...@@ -140,68 +140,72 @@ struct EOFException : public std::exception {
#define LIKELY(condition) (condition) #define LIKELY(condition) (condition)
#endif #endif
inline bool is_error(bool stat) { return !stat; }
template <typename... Args> template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
bool stat, const Args&... args) { bool stat, const Args&... args) {
if (UNLIKELY(!(stat))) {
#ifndef REPLACE_ENFORCE_GLOG #ifndef REPLACE_ENFORCE_GLOG
throw std::runtime_error(string::Sprintf(args...)); throw std::runtime_error(string::Sprintf(args...));
#else #else
LOG(FATAL) << string::Sprintf(args...); LOG(FATAL) << string::Sprintf(args...);
#endif #endif
}
} }
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
inline bool is_error(cudaError_t e) { return UNLIKELY(e); }
template <typename... Args> template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
cudaError_t e, const Args&... args) { cudaError_t e, const Args&... args) {
if (UNLIKELY(e)) {
#ifndef REPLACE_ENFORCE_GLOG #ifndef REPLACE_ENFORCE_GLOG
throw thrust::system_error(e, thrust::cuda_category(), throw thrust::system_error(e, thrust::cuda_category(),
string::Sprintf(args...)); string::Sprintf(args...));
#else #else
LOG(FATAL) << string::Sprintf(args...); LOG(FATAL) << string::Sprintf(args...);
#endif #endif
} }
inline bool is_error(curandStatus_t stat) {
return stat != CURAND_STATUS_SUCCESS;
} }
template <typename... Args> template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
curandStatus_t stat, const Args&... args) { curandStatus_t stat, const Args&... args) {
if (stat != CURAND_STATUS_SUCCESS) {
#ifndef REPLACE_ENFORCE_GLOG #ifndef REPLACE_ENFORCE_GLOG
throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(), throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(),
string::Sprintf(args...)); string::Sprintf(args...));
#else #else
LOG(FATAL) << string::Sprintf(args...); LOG(FATAL) << string::Sprintf(args...);
#endif #endif
} }
inline bool is_error(cudnnStatus_t stat) {
return stat != CUDNN_STATUS_SUCCESS;
} }
template <typename... Args> template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
cudnnStatus_t stat, const Args&... args) { cudnnStatus_t stat, const Args&... args) {
if (stat == CUDNN_STATUS_SUCCESS) {
return;
} else {
#ifndef REPLACE_ENFORCE_GLOG #ifndef REPLACE_ENFORCE_GLOG
throw std::runtime_error(platform::dynload::cudnnGetErrorString(stat) + throw std::runtime_error(platform::dynload::cudnnGetErrorString(stat) +
string::Sprintf(args...)); string::Sprintf(args...));
#else #else
LOG(FATAL) << string::Sprintf(args...); LOG(FATAL) << string::Sprintf(args...);
#endif #endif
} }
inline bool is_error(cublasStatus_t stat) {
return stat != CUBLAS_STATUS_SUCCESS;
} }
template <typename... Args> template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error( inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
cublasStatus_t stat, const Args&... args) { cublasStatus_t stat, const Args&... args) {
std::string err; std::string err;
if (stat == CUBLAS_STATUS_SUCCESS) { if (stat == CUBLAS_STATUS_NOT_INITIALIZED) {
return;
} else if (stat == CUBLAS_STATUS_NOT_INITIALIZED) {
err = "CUBLAS: not initialized, "; err = "CUBLAS: not initialized, ";
} else if (stat == CUBLAS_STATUS_ALLOC_FAILED) { } else if (stat == CUBLAS_STATUS_ALLOC_FAILED) {
err = "CUBLAS: alloc failed, "; err = "CUBLAS: alloc failed, ";
...@@ -254,11 +258,21 @@ inline void throw_on_error(T e) { ...@@ -254,11 +258,21 @@ inline void throw_on_error(T e) {
#define PADDLE_THROW(...) \ #define PADDLE_THROW(...) \
throw ::paddle::platform::EnforceNotMet(__FILE__, __LINE__, __VA_ARGS__) throw ::paddle::platform::EnforceNotMet(__FILE__, __LINE__, __VA_ARGS__)
#define PADDLE_JUDGE
#define __PADDLE_UNARY_COMPARE(COND, ...) \
do { \
auto cond = COND; \
if (UNLIKELY(::paddle::platform::is_error(cond))) { \
::paddle::platform::throw_on_error(cond, ##__VA_ARGS__); \
} \
} while (0)
#ifndef REPLACE_ENFORCE_GLOG #ifndef REPLACE_ENFORCE_GLOG
#define PADDLE_ENFORCE(...) \ #define PADDLE_ENFORCE(COND, ...) \
do { \ do { \
try { \ try { \
::paddle::platform::throw_on_error(__VA_ARGS__); \ __PADDLE_UNARY_COMPARE(COND, ##__VA_ARGS__); \
} catch (...) { \ } catch (...) { \
throw ::paddle::platform::EnforceNotMet(std::current_exception(), \ throw ::paddle::platform::EnforceNotMet(std::current_exception(), \
__FILE__, __LINE__); \ __FILE__, __LINE__); \
...@@ -266,7 +280,7 @@ inline void throw_on_error(T e) { ...@@ -266,7 +280,7 @@ inline void throw_on_error(T e) {
} while (false) } while (false)
#else #else
#define PADDLE_ENFORCE(...) ::paddle::platform::throw_on_error(__VA_ARGS__); #define PADDLE_ENFORCE(COND, ...) __PADDLE_UNARY_COMPARE(COND, ##__VA_ARGS__);
#endif // REPLACE_ENFORCE_GLOG #endif // REPLACE_ENFORCE_GLOG
#define PADDLE_THROW_EOF() \ #define PADDLE_THROW_EOF() \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册