Commit f403f69a, authored by minqiyang

Accelerate PADDLE_ENFORCE

test=release/1.2

Parent: 847cbdce
@@ -120,6 +120,7 @@ FeedFetchList FastThreadedSSAGraphExecutor::Run(
   ClearFetchOp(graph_.get(), &fetch_ops);
   return fetches;
 }
+
 void FastThreadedSSAGraphExecutor::RunOpAsync(
     std::unordered_map<OpHandleBase *, std::atomic<int>> *op_deps,
     OpHandleBase *op,
......
@@ -163,11 +163,7 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
 }

 bool OperatorBase::HasInputs(const std::string& name) const {
-  if (inputs_.find(name) != inputs_.end()) {
-    return true;
-  } else {
-    return false;
-  }
+  return inputs_.find(name) != inputs_.end();
 }

 std::string OperatorBase::Input(const std::string& name) const {
......
@@ -49,6 +49,8 @@ constexpr char kTempVarName[] = "@TEMP@";
 /// e.g. Variable "x@GRAD" is the gradient of variable "x".
 constexpr char kGradVarSuffix[] = "@GRAD";
+constexpr size_t kGradVarSuffixSize = 5U;
+
 /// Variables with this suffix are supposed to be filled up with zeros.
 constexpr char kZeroVarSuffix[] = "@ZERO";
@@ -60,7 +62,11 @@ constexpr char kNewGradSuffix[] = "@NEWGRAD@";
 extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;

 inline std::string GradVarName(const std::string& var_name) {
-  return var_name + kGradVarSuffix;
+  std::string result;
+  result.reserve(var_name.size() + kGradVarSuffixSize);
+  result += var_name;
+  result += kGradVarSuffix;
+  return result;
 }

 proto::VarType::Type GetDataTypeOfVar(const Variable* var);
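The GradVarName rewrite is a classic single-allocation string build for a frequently-called helper. A minimal standalone sketch of the difference (function names here are illustrative), assuming a typical standard-library operator+ that copies the left operand and then appends:

#include <string>

// Naive form: operator+ usually copies var_name into a temporary sized to
// var_name alone, so appending "@GRAD" can force a second allocation.
std::string GradNameNaive(const std::string& var_name) {
  return var_name + "@GRAD";
}

// Committed form: reserving the exact final length up front means the two
// appends never reallocate, for at most one allocation per call.
std::string GradNameReserved(const std::string& var_name) {
  std::string result;
  result.reserve(var_name.size() + 5U);  // 5 == strlen("@GRAD")
  result += var_name;
  result += "@GRAD";
  return result;
}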
@@ -101,8 +107,8 @@ class OperatorBase {
   bool HasAttr(const std::string& name) const { return attrs_.count(name); }
   template <typename T>
   inline const T& Attr(const std::string& name) const {
-    PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
-                   name);
+    PADDLE_ENFORCE(attrs_.find(name) != attrs_.end(),
+                   "%s should be in AttributeMap", name);
     return boost::get<T>(attrs_.at(name));
   }
   const AttributeMap& Attrs() const { return attrs_; }
......
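For std::unordered_map, count(name) != 0 and find(name) != end() cost the same (count is implemented via find), so this hunk is mostly about intent. Note that the method still performs a second hash lookup through attrs_.at(name); a single-lookup variant would reuse the iterator. A simplified sketch of that alternative, not what the commit does, with the variant-typed AttributeMap reduced to a plain map:

#include <stdexcept>
#include <string>
#include <unordered_map>

// Alternative sketch: one hash lookup instead of two (find + at).
template <typename T>
const T& GetAttr(const std::unordered_map<std::string, T>& attrs,
                 const std::string& name) {
  auto it = attrs.find(name);
  if (it == attrs.end()) {
    throw std::out_of_range(name + " should be in AttributeMap");
  }
  return it->second;
}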
@@ -69,17 +69,17 @@ void TestWord2vecPrediction(const std::string& model_path) {
   std::vector<PaddleTensor> outputs;
   CHECK(predictor->Run(slots, &outputs));

-  PADDLE_ENFORCE(outputs.size(), 1UL);
+  PADDLE_ENFORCE_EQ(outputs.size(), 1UL);
   // Check the output buffer size and result of each tid.
-  PADDLE_ENFORCE(outputs.front().data.length(), 33168UL);
+  PADDLE_ENFORCE_EQ(outputs.front().data.length(), 33168UL);
   float result[5] = {0.00129761, 0.00151112, 0.000423564, 0.00108815,
                      0.000932706};
   const size_t num_elements = outputs.front().data.length() / sizeof(float);

   // The outputs' buffers are in CPU memory.
   for (size_t i = 0; i < std::min(static_cast<size_t>(5UL), num_elements);
        i++) {
-    LOG(INFO) << "data: "
-              << static_cast<float*>(outputs.front().data.data())[i];
+    LOG(INFO) << "data: " << static_cast<float*>(outputs.front().data.data())[i]
+              << " result: " << result[i];
     PADDLE_ENFORCE(static_cast<float*>(outputs.front().data.data())[i],
                    result[i]);
   }
......
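This test fix matters because the two-argument PADDLE_ENFORCE(a, b) does not compare a and b: it checks a and treats b as the error message (see the rewritten macro in the enforce.h hunk below), so PADDLE_ENFORCE(outputs.size(), 1UL) would pass for any nonzero size. A standalone sketch of the distinction, with illustrative macro bodies that are not Paddle's:

#include <cstddef>
#include <cstdio>
#include <cstdlib>

// Illustrative stand-ins: ENFORCE checks one condition and treats the rest
// as the message; ENFORCE_EQ actually compares its two operands.
#define MY_ENFORCE(cond, msg) \
  do { if (!(cond)) { std::fprintf(stderr, "%s\n", msg); std::abort(); } } while (0)
#define MY_ENFORCE_EQ(a, b) MY_ENFORCE((a) == (b), #a " != " #b)

int main() {
  size_t n = 3;
  MY_ENFORCE(n, "n must be nonzero");  // passes for ANY nonzero n
  MY_ENFORCE_EQ(n, 3UL);               // passes only when n == 3
  return 0;
}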
@@ -25,7 +25,7 @@ namespace detail {
  */
 template <typename T, typename... ARGS>
 inline T& Ref(T* ptr, ARGS&&... args) {
-  PADDLE_ENFORCE(ptr != nullptr, args...);
+  PADDLE_ENFORCE(ptr != nullptr, ::paddle::string::Sprintf(args...));
   return *ptr;
 }
......
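With PADDLE_ENFORCE now counting and splitting its own arguments, Ref folds the caller's format arguments into one pre-formatted string, so the macro always receives exactly two arguments (the null check and a finished message) regardless of the parameter pack. A self-contained sketch of the call pattern, using simplified stand-ins rather than Paddle's code:

#include <stdexcept>

namespace detail {
// Simplified stand-in for Ref: the real one forwards its message arguments
// through ::paddle::string::Sprintf before PADDLE_ENFORCE ever sees them.
template <typename T, typename... ARGS>
inline T& Ref(T* ptr, ARGS&&... /*args*/) {
  if (ptr == nullptr) throw std::invalid_argument("null pointer");
  return *ptr;
}
}  // namespace detail

int main() {
  int value = 42;
  int* maybe = &value;
  // However many message arguments the caller passes, they reach the macro
  // as a single Sprintf'd string in the real implementation.
  int& ref = detail::Ref(maybe, "Cannot find variable %s", "x");
  return ref == 42 ? 0 : 1;
}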
@@ -84,7 +84,9 @@ class ProtoEncodeHelper {
   ~ProtoEncodeHelper() {
 #define REPLACE_ENFORCE_GLOG 1
     // Make sure callers didn't do operations that went over max_size promised
-    paddle::platform::throw_on_error(p_ <= limit_);
+    if (paddle::platform::is_error(p_ <= limit_)) {
+      paddle::platform::throw_on_error(p_ <= limit_);
+    }
 #undef REPLACE_ENFORCE_GLOG
   }
......
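This guard tracks the new division of labor in enforce.h below: is_error decides whether something went wrong, and throw_on_error only reports, so direct callers of throw_on_error outside the PADDLE_ENFORCE macro must now test first. A self-contained sketch of the resulting pattern, with simplified stand-in functions:

#include <stdexcept>

// Simplified stand-ins: is_error() is the cheap hot-path test, and
// throw_on_error() only builds and throws the report once failure is known.
inline bool is_error(bool stat) { return !stat; }
inline void throw_on_error(bool /*stat*/) {
  throw std::runtime_error("buffer overran the promised max_size");
}

void check_bounds(const char* p, const char* limit) {
  if (is_error(p <= limit)) {    // hot path: a single comparison
    throw_on_error(p <= limit);  // cold path only
  }
}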
@@ -50,8 +50,8 @@ template <typename T>
 class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  public:
   void Compute(const paddle::framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(std::is_same<T, float>::value,
-                   "MKLDNN LRN must use float data.");
+    const bool is_float_type = std::is_same<T, float>::value;
+    PADDLE_ENFORCE(is_float_type, "MKLDNN LRN must use float data.");
     PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                    "MKLDNN LRN must use CPUPlace.");
......
@@ -132,8 +132,8 @@ template <typename T>
 class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
  public:
   void Compute(const paddle::framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(std::is_same<T, float>::value,
-                   "MKLDNN LRN must use float data.");
+    const bool is_float_type = std::is_same<T, float>::value;
+    PADDLE_ENFORCE(is_float_type, "MKLDNN LRN must use float data.");
     PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                    "MKLDNN LRN must use CPUPlace.");
     PADDLE_ENFORCE(
......
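Hoisting std::is_same<T, float>::value into a named bool is not a style tweak: PADDLE_ENFORCE is now a macro that splits its argument list itself, and the preprocessor would cut the condition at the comma inside the template argument list. A standalone sketch of the pitfall, using an illustrative two-parameter macro shaped like the rewritten PADDLE_ENFORCE's (COND, ...) interface:

#include <cstdio>
#include <type_traits>

#define CHECK_MSG(cond, msg) \
  do { if (!(cond)) { std::fprintf(stderr, "%s\n", msg); } } while (0)

template <typename T>
void MustBeFloat() {
  // CHECK_MSG(std::is_same<T, float>::value, "float only");
  // ^ fails to preprocess: the comma inside <T, float> turns this into a
  //   three-argument invocation of a two-parameter macro.
  const bool is_float_type = std::is_same<T, float>::value;  // hoist first
  CHECK_MSG(is_float_type, "float only");
}

Wrapping the condition in an extra pair of parentheses is the other common workaround; the commit chooses the named bool, which also reads better.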
@@ -131,68 +131,72 @@ struct EOFException : public std::exception {
 #define LIKELY(condition) (condition)
 #endif

+inline bool is_error(bool stat) { return !stat; }
+
 template <typename... Args>
 inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
     bool stat, const Args&... args) {
-  if (UNLIKELY(!(stat))) {
 #ifndef REPLACE_ENFORCE_GLOG
   throw std::runtime_error(string::Sprintf(args...));
 #else
   LOG(FATAL) << string::Sprintf(args...);
 #endif
-  }
 }

 #ifdef PADDLE_WITH_CUDA

+inline bool is_error(cudaError_t e) { return UNLIKELY(e); }
+
 template <typename... Args>
 inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
     cudaError_t e, const Args&... args) {
-  if (UNLIKELY(e)) {
 #ifndef REPLACE_ENFORCE_GLOG
   throw thrust::system_error(e, thrust::cuda_category(),
                              string::Sprintf(args...));
 #else
   LOG(FATAL) << string::Sprintf(args...);
 #endif
-  }
 }

+inline bool is_error(curandStatus_t stat) {
+  return stat != CURAND_STATUS_SUCCESS;
+}
+
 template <typename... Args>
 inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
     curandStatus_t stat, const Args&... args) {
-  if (stat != CURAND_STATUS_SUCCESS) {
 #ifndef REPLACE_ENFORCE_GLOG
   throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(),
                              string::Sprintf(args...));
 #else
   LOG(FATAL) << string::Sprintf(args...);
 #endif
-  }
 }

+inline bool is_error(cudnnStatus_t stat) {
+  return stat != CUDNN_STATUS_SUCCESS;
+}
+
 template <typename... Args>
 inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
     cudnnStatus_t stat, const Args&... args) {
-  if (stat == CUDNN_STATUS_SUCCESS) {
-    return;
-  } else {
 #ifndef REPLACE_ENFORCE_GLOG
   throw std::runtime_error(platform::dynload::cudnnGetErrorString(stat) +
                            string::Sprintf(args...));
 #else
   LOG(FATAL) << string::Sprintf(args...);
 #endif
-  }
 }

+inline bool is_error(cublasStatus_t stat) {
+  return stat != CUBLAS_STATUS_SUCCESS;
+}
+
 template <typename... Args>
 inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
     cublasStatus_t stat, const Args&... args) {
   std::string err;
-  if (stat == CUBLAS_STATUS_SUCCESS) {
-    return;
-  } else if (stat == CUBLAS_STATUS_NOT_INITIALIZED) {
+  if (stat == CUBLAS_STATUS_NOT_INITIALIZED) {
     err = "CUBLAS: not initialized, ";
   } else if (stat == CUBLAS_STATUS_ALLOC_FAILED) {
     err = "CUBLAS: alloc failed, ";
......
@@ -219,12 +223,11 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
 }

 #if !defined(__APPLE__) && !defined(_WIN32)
+inline bool is_error(ncclResult_t stat) { return stat != ncclSuccess; }
+
 template <typename... Args>
 inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
     ncclResult_t stat, const Args&... args) {
-  if (stat == ncclSuccess) {
-    return;
-  } else {
 #ifndef REPLACE_ENFORCE_GLOG
   throw std::runtime_error(platform::dynload::ncclGetErrorString(stat) +
                            string::Sprintf(args...));
......
@@ -232,7 +235,6 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
   LOG(FATAL) << platform::dynload::ncclGetErrorString(stat)
              << string::Sprintf(args...);
 #endif
-  }
 }
 #endif  // __APPLE__ and windows
 #endif  // PADDLE_WITH_CUDA
@@ -250,21 +252,49 @@ inline void throw_on_error(T e) {
                                         __FILE__, __LINE__); \
   } while (false)

+#define __PADDLE_THROW_ERROR_I(_, _9, _8, _7, _6, _5, _4, _3, _2, X_, ...) X_;
+
+#define __THROW_ON_ERROR_ONE_ARG(COND, ARG) \
+  ::paddle::platform::throw_on_error(COND, ::paddle::string::Sprintf(ARG));
+
+#define __PADDLE_THROW_ON_ERROR(COND, ...)                                 \
+  __PADDLE_THROW_ERROR_I(                                                  \
+      __VA_ARGS__, ::paddle::platform::throw_on_error(COND, __VA_ARGS__), \
+      ::paddle::platform::throw_on_error(COND, __VA_ARGS__),              \
+      ::paddle::platform::throw_on_error(COND, __VA_ARGS__),              \
+      ::paddle::platform::throw_on_error(COND, __VA_ARGS__),              \
+      ::paddle::platform::throw_on_error(COND, __VA_ARGS__),              \
+      ::paddle::platform::throw_on_error(COND, __VA_ARGS__),              \
+      ::paddle::platform::throw_on_error(COND, __VA_ARGS__),              \
+      ::paddle::platform::throw_on_error(COND, __VA_ARGS__),              \
+      __THROW_ON_ERROR_ONE_ARG(COND, __VA_ARGS__))
+
+#define __PADDLE_UNARY_COMPARE(COND, ...)                 \
+  do {                                                    \
+    auto __cond = COND;                                   \
+    if (UNLIKELY(::paddle::platform::is_error(__cond))) { \
+      __PADDLE_THROW_ON_ERROR(__cond, __VA_ARGS__);       \
+    }                                                     \
+  } while (0)
+
 #ifndef REPLACE_ENFORCE_GLOG
-#define PADDLE_ENFORCE(...)                                               \
+#define __PADDLE_ENFORCE_I(COND, ...)                                     \
   do {                                                                    \
     try {                                                                 \
-      ::paddle::platform::throw_on_error(__VA_ARGS__);                    \
+      __PADDLE_UNARY_COMPARE(COND, __VA_ARGS__);                          \
     } catch (...) {                                                       \
       throw ::paddle::platform::EnforceNotMet(std::current_exception(),   \
                                               __FILE__, __LINE__);        \
     }                                                                     \
-  } while (false)
+  } while (0)
 #else
-#define PADDLE_ENFORCE(...) ::paddle::platform::throw_on_error(__VA_ARGS__);
+#define __PADDLE_ENFORCE_I(COND, ...) __PADDLE_UNARY_COMPARE(COND, __VA_ARGS__);
 #endif  // REPLACE_ENFORCE_GLOG

+#define __PADDLE_ENFORCE(__args) __PADDLE_ENFORCE_I __args
+#define PADDLE_ENFORCE(...) __PADDLE_ENFORCE((__VA_ARGS__))
+
 #define PADDLE_THROW_EOF()                                                    \
   do {                                                                        \
     throw ::paddle::platform::EOFException("There is no next data.", __FILE__, \
......
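This macro rewrite is where the acceleration happens, in two ways. First, __PADDLE_UNARY_COMPARE evaluates the condition into a local __cond and tests UNLIKELY(is_error(__cond)), so on the success path PADDLE_ENFORCE costs a single predicted branch: no throw_on_error call, no Sprintf, no exception machinery; all of that moves to the cold branch. Second, __PADDLE_THROW_ERROR_I dispatches on the number of trailing arguments by picking the tenth element of the expanded list: with two to nine message arguments one of the throw_on_error(COND, __VA_ARGS__) copies lands in the X_ slot, while a single message shifts __THROW_ON_ERROR_ONE_ARG into place, wrapping the lone argument in Sprintf once. The final hop, PADDLE_ENFORCE(...) to __PADDLE_ENFORCE((__VA_ARGS__)) to __PADDLE_ENFORCE_I __args, re-invokes __PADDLE_ENFORCE_I on the parenthesized list so COND is split off from the message arguments. A standalone demo of the pick-the-Nth-argument idiom (names are illustrative, not Paddle's):

#include <cstdio>

// Select the 10th argument of whatever list the expansion produces.
#define PICK_10TH(_1, _2, _3, _4, _5, _6, _7, _8, _9, X_, ...) X_

// Evaluates to 0 for one argument and 1 for two to nine arguments: every
// extra user argument shifts the fixed tail one slot to the right.
#define HAS_FORMAT_ARGS(...) \
  PICK_10TH(__VA_ARGS__, 1, 1, 1, 1, 1, 1, 1, 1, 0, ~)

int main() {
  std::printf("%d\n", HAS_FORMAT_ARGS("plain message"));    // prints 0
  std::printf("%d\n", HAS_FORMAT_ARGS("value is %d", 42));  // prints 1
  return 0;
}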
@@ -37,6 +37,25 @@ TEST(ENFORCE, FAILED) {
         HasPrefix(StringPiece(error.what()), "Enforce is not ok 123 at all"));
   }
   EXPECT_TRUE(caught_exception);
+
+  caught_exception = false;
+  try {
+    PADDLE_ENFORCE(false, "Enforce is not ok at all");
+  } catch (paddle::platform::EnforceNotMet error) {
+    caught_exception = true;
+    EXPECT_TRUE(
+        HasPrefix(StringPiece(error.what()), "Enforce is not ok at all"));
+  }
+  EXPECT_TRUE(caught_exception);
+
+  caught_exception = false;
+  try {
+    PADDLE_ENFORCE(false);
+  } catch (paddle::platform::EnforceNotMet error) {
+    caught_exception = true;
+    EXPECT_NE(std::string(error.what()).find(" at "), 0);
+  }
+  EXPECT_TRUE(caught_exception);
 }

 TEST(ENFORCE, NO_ARG_OK) {
......
@@ -87,7 +87,7 @@ void Fprintf(std::ostream& out, const char* fmt, const Args&... args) {
 template <typename... Args>
 std::string Sprintf(const Args&... args) {
   std::ostringstream oss;
-  Fprintf(oss, "");
+  Fprintf(oss, "%s", args...);
   return oss.str();
 }
......
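The new body forwards the arguments through an explicit "%s" format, so a pre-built message is treated as data: a stray '%' inside it (from a user-supplied string) cannot be re-interpreted as a conversion specification. The same idiom in plain printf terms:

#include <cstdio>
#include <string>

int main() {
  std::string msg = "accuracy reached 99% today";  // contains a literal '%'
  std::printf("%s\n", msg.c_str());  // safe: msg is passed as data
  // std::printf(msg.c_str());       // wrong: "% t" would be parsed as a
  //                                 // conversion specification (undefined
  //                                 // behavior)
  return 0;
}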