未验证 提交 db5e74ab 编写于 作者: C chengduo 提交者: GitHub

update assert (#17282)

test=develop
上级 c3195de5
...@@ -154,6 +154,8 @@ struct HardLabelCrossEntropyForwardFunctor { ...@@ -154,6 +154,8 @@ struct HardLabelCrossEntropyForwardFunctor {
HOSTDEVICE void operator()(int64_t idx) const { HOSTDEVICE void operator()(int64_t idx) const {
auto label = label_[idx]; auto label = label_[idx];
PADDLE_ASSERT_MSG(label >= 0 && label < feature_size_,
"The label is out of the range.", label);
if (label != ignore_index_) { if (label != ignore_index_) {
auto match_x = x_[idx * feature_size_ + label]; auto match_x = x_[idx * feature_size_ + label];
y_[idx] = -math::TolerableValue<T>()(real_log(match_x)); y_[idx] = -math::TolerableValue<T>()(real_log(match_x));
......
...@@ -32,8 +32,8 @@ __global__ void LookupTable(T *output, const T *table, const int64_t *ids, ...@@ -32,8 +32,8 @@ __global__ void LookupTable(T *output, const T *table, const int64_t *ids,
while (idy < K) { while (idy < K) {
int64_t id = ids[idy]; int64_t id = ids[idy];
PADDLE_ASSERT_MSG_CODE(id >= 0, "received id:", id); PADDLE_ASSERT_MSG(id >= 0, "received id:", id);
PADDLE_ASSERT_MSG_CODE(id < N, "received id:", id); PADDLE_ASSERT_MSG(id < N, "received id:", id);
T *out = output + idy * D; T *out = output + idy * D;
const T *tab = table + id * D; const T *tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) { for (int i = idx; i < D; i += BlockDimX) {
...@@ -59,8 +59,8 @@ __global__ void LookupTableGrad(T *table, const T *output, const int64_t *ids, ...@@ -59,8 +59,8 @@ __global__ void LookupTableGrad(T *table, const T *output, const int64_t *ids,
while (idy < K) { while (idy < K) {
int64_t id = ids[idy]; int64_t id = ids[idy];
PADDLE_ASSERT_MSG_CODE(id >= 0, "received id:", id); PADDLE_ASSERT_MSG(id >= 0, "received id:", id);
PADDLE_ASSERT_MSG_CODE(id < N, "received id:", id); PADDLE_ASSERT_MSG(id < N, "received id:", id);
const T *out = output + idy * D; const T *out = output + idy * D;
T *tab = table + id * D; T *tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) { for (int i = idx; i < D; i += BlockDimX) {
......
...@@ -17,40 +17,32 @@ limitations under the License. */ ...@@ -17,40 +17,32 @@ limitations under the License. */
#define STRINGIFY(x) #x #define STRINGIFY(x) #x
#define TOSTRING(x) STRINGIFY(x) #define TOSTRING(x) STRINGIFY(x)
// For cuda, the assertions can affect performance and it is therefore
// recommended to disable them in production code
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#assertion
#if defined(__CUDA_ARCH__) #if defined(__CUDA_ARCH__)
#include <stdio.h> #include <stdio.h>
#define PADDLE_ASSERT(e) \ #define EXIT() asm("trap;")
do { \ #else
if (!(e)) { \ #include <assert.h>
printf("%s:%d Assertion `%s` failed.\n", __FILE__, __LINE__, \ #define EXIT() throw std::runtime_error("Exception encounter.")
TOSTRING(e)); \ #endif
asm("trap;"); \
} \
} while (0)
#define PADDLE_ASSERT_MSG(e, m) \ #define PADDLE_ASSERT(_IS_NOT_ERROR) \
do { \ do { \
if (!(e)) { \ if (!(_IS_NOT_ERROR)) { \
printf("%s:%d Assertion `%s` failed (%s).\n", __FILE__, __LINE__, \ printf("Exception: %s:%d Assertion `%s` failed.\n", __FILE__, __LINE__, \
TOSTRING(e), m); \ TOSTRING(_IS_NOT_ERROR)); \
asm("trap;"); \ EXIT(); \
} \ } \
} while (0) } while (0)
#define PADDLE_ASSERT_MSG_CODE(e, m, c) \ // NOTE: PADDLE_ASSERT is mainly used in CUDA Kernel or HOSTDEVICE function.
do { \ #define PADDLE_ASSERT_MSG(_IS_NOT_ERROR, __MSG, __VAL) \
if (!(e)) { \ do { \
printf("%s:%d Assertion `%s` failed (%s %ld).\n", __FILE__, __LINE__, \ if (!(_IS_NOT_ERROR)) { \
TOSTRING(e), m, c); \ printf("Exception: %s:%d Assertion `%s` failed (%s %ld).\n", __FILE__, \
asm("trap;"); \ __LINE__, TOSTRING(_IS_NOT_ERROR), __MSG, __VAL); \
} \ EXIT(); \
} \
} while (0) } while (0)
#else
#include <assert.h>
// For cuda, the assertions can affect performance and it is therefore
// recommended to disable them in production code
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#assertion
#define PADDLE_ASSERT(e) assert((e))
#define PADDLE_ASSERT_MSG(e, m) assert((e) && (m))
#define PADDLE_ASSERT_MSG_CODE(e, m, c) assert((e) && (m) && (c || 1))
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册