提交 4cb84f33 编写于 作者: P Peng LI 提交者: GitHub

Merge pull request #1432 from pengli09/fix-warpctc-with-double-precision-bug

Fix macro bug in hl_warpctc_warp.cc to support double precision
......@@ -54,22 +54,26 @@ DYNAMIC_LOAD_WARPCTC_WRAP(get_workspace_size)
#define WARPCTC_GET_VERSION dynload::get_warpctc_version
#define WARPCTC_GET_STATUS_STRING dynload::ctcGetStatusString
static int g_warpctcVersion = -1;
#ifndef PADDLE_TYPE_DOUBLE
#define WARPCTC_COMPUTE_LOSS dynload::compute_ctc_loss
#define WARPCTC_GET_WORKSPACE_SIZE dynload::get_workspace_size
#else
#define WARPCTC_LOG_FATAL \
LOG(FATAL) << "warp-ctc [version " << g_warpctcVersion \
<< "] Error: not support double precision."
#define WARPCTC_COMPUTE_LOSS(...) WARPCTC_LOG_FATAL(__VA_ARGS__)
#define WARPCTC_GET_WORKSPACE_SIZE(...) WARPCTC_LOG_FATAL(__VA_ARGS__)
hl_warpctc_status_t fatal(...) {
LOG(FATAL) << "warp-ctc [version " << g_warpctcVersion
<< "] Error: not support double precision.";
// both of get_warpctc_version() and get_workspace_size() return an ctcStatus
// type value
return CTC_STATUS_EXECUTION_FAILED;
}
#define WARPCTC_COMPUTE_LOSS fatal
#define WARPCTC_GET_WORKSPACE_SIZE fatal
#endif
/**
* Check build-in warp-ctc function using glog and it also
* support << operator for more details error info.
*/
static int g_warpctcVersion = -1;
#define CHECK_WARPCTC(warpctcStat) \
CHECK_EQ(CTC_STATUS_SUCCESS, warpctcStat) \
<< "warp-ctc [version " << g_warpctcVersion \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册