提交 aa41ee75 编写于 作者: M minqiyang

Accelerate PADDLE_ENFORCE

上级 728e7e88
......@@ -49,6 +49,8 @@ constexpr char kTempVarName[] = "@TEMP@";
/// e.g. Variable "x@GRAD" is the gradient of varibale "x".
constexpr char kGradVarSuffix[] = "@GRAD";
constexpr size_t kGradVarSuffixSize = 5U;
/// Variables with this suffix are supposed to be filled up with zeros.
constexpr char kZeroVarSuffix[] = "@ZERO";
......@@ -60,7 +62,11 @@ constexpr char kNewGradSuffix[] = "@NEWGRAD@";
extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;
inline std::string GradVarName(const std::string& var_name) {
return var_name + kGradVarSuffix;
std::string result;
result.reserve(var_name.size() + kGradVarSuffixSize);
result += var_name;
result += kGradVarSuffix;
return result;
}
proto::VarType::Type GetDataTypeOfVar(const Variable* var);
......@@ -101,8 +107,8 @@ class OperatorBase {
bool HasAttr(const std::string& name) const { return attrs_.count(name); }
template <typename T>
inline const T& Attr(const std::string& name) const {
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
name);
PADDLE_ENFORCE(attrs_.find(name) != attrs_.end(),
"%s should be in AttributeMap", name);
return boost::get<T>(attrs_.at(name));
}
const AttributeMap& Attrs() const { return attrs_; }
......
......@@ -140,68 +140,72 @@ struct EOFException : public std::exception {
#define LIKELY(condition) (condition)
#endif
inline bool is_error(bool stat) { return !stat; }
template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
bool stat, const Args&... args) {
if (UNLIKELY(!(stat))) {
#ifndef REPLACE_ENFORCE_GLOG
throw std::runtime_error(string::Sprintf(args...));
#else
LOG(FATAL) << string::Sprintf(args...);
#endif
}
}
#ifdef PADDLE_WITH_CUDA
inline bool is_error(cudaError_t e) { return UNLIKELY(e); }
template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
cudaError_t e, const Args&... args) {
if (UNLIKELY(e)) {
#ifndef REPLACE_ENFORCE_GLOG
throw thrust::system_error(e, thrust::cuda_category(),
string::Sprintf(args...));
#else
LOG(FATAL) << string::Sprintf(args...);
#endif
}
}
inline bool is_error(curandStatus_t stat) {
return stat != CURAND_STATUS_SUCCESS;
}
template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
curandStatus_t stat, const Args&... args) {
if (stat != CURAND_STATUS_SUCCESS) {
#ifndef REPLACE_ENFORCE_GLOG
throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(),
string::Sprintf(args...));
#else
LOG(FATAL) << string::Sprintf(args...);
#endif
}
}
inline bool is_error(cudnnStatus_t stat) {
return stat != CUDNN_STATUS_SUCCESS;
}
template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
cudnnStatus_t stat, const Args&... args) {
if (stat == CUDNN_STATUS_SUCCESS) {
return;
} else {
#ifndef REPLACE_ENFORCE_GLOG
throw std::runtime_error(platform::dynload::cudnnGetErrorString(stat) +
string::Sprintf(args...));
#else
LOG(FATAL) << string::Sprintf(args...);
#endif
}
}
inline bool is_error(cublasStatus_t stat) {
return stat != CUBLAS_STATUS_SUCCESS;
}
template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
cublasStatus_t stat, const Args&... args) {
std::string err;
if (stat == CUBLAS_STATUS_SUCCESS) {
return;
} else if (stat == CUBLAS_STATUS_NOT_INITIALIZED) {
if (stat == CUBLAS_STATUS_NOT_INITIALIZED) {
err = "CUBLAS: not initialized, ";
} else if (stat == CUBLAS_STATUS_ALLOC_FAILED) {
err = "CUBLAS: alloc failed, ";
......@@ -254,11 +258,21 @@ inline void throw_on_error(T e) {
#define PADDLE_THROW(...) \
throw ::paddle::platform::EnforceNotMet(__FILE__, __LINE__, __VA_ARGS__)
#define PADDLE_JUDGE
#define __PADDLE_UNARY_COMPARE(COND, ...) \
do { \
auto cond = COND; \
if (UNLIKELY(::paddle::platform::is_error(cond))) { \
::paddle::platform::throw_on_error(cond, ##__VA_ARGS__); \
} \
} while (0)
#ifndef REPLACE_ENFORCE_GLOG
#define PADDLE_ENFORCE(...) \
#define PADDLE_ENFORCE(COND, ...) \
do { \
try { \
::paddle::platform::throw_on_error(__VA_ARGS__); \
__PADDLE_UNARY_COMPARE(COND, ##__VA_ARGS__); \
} catch (...) { \
throw ::paddle::platform::EnforceNotMet(std::current_exception(), \
__FILE__, __LINE__); \
......@@ -266,7 +280,7 @@ inline void throw_on_error(T e) {
} while (false)
#else
#define PADDLE_ENFORCE(...) ::paddle::platform::throw_on_error(__VA_ARGS__);
#define PADDLE_ENFORCE(COND, ...) __PADDLE_UNARY_COMPARE(COND, ##__VA_ARGS__);
#endif // REPLACE_ENFORCE_GLOG
#define PADDLE_THROW_EOF() \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册