From db5e74ab95773497d4c8a24be21f96270df79f38 Mon Sep 17 00:00:00 2001 From: chengduo Date: Wed, 8 May 2019 19:25:28 +0800 Subject: [PATCH] update assert (#17282) test=develop --- paddle/fluid/operators/cross_entropy_op.h | 2 + paddle/fluid/operators/lookup_table_op.cu | 8 ++-- paddle/fluid/platform/assert.h | 54 ++++++++++------------- 3 files changed, 29 insertions(+), 35 deletions(-) diff --git a/paddle/fluid/operators/cross_entropy_op.h b/paddle/fluid/operators/cross_entropy_op.h index 1d625579052..89bacfc33ed 100644 --- a/paddle/fluid/operators/cross_entropy_op.h +++ b/paddle/fluid/operators/cross_entropy_op.h @@ -154,6 +154,8 @@ struct HardLabelCrossEntropyForwardFunctor { HOSTDEVICE void operator()(int64_t idx) const { auto label = label_[idx]; + PADDLE_ASSERT_MSG(label >= 0 && label < feature_size_, + "The label is out of the range.", label); if (label != ignore_index_) { auto match_x = x_[idx * feature_size_ + label]; y_[idx] = -math::TolerableValue()(real_log(match_x)); diff --git a/paddle/fluid/operators/lookup_table_op.cu b/paddle/fluid/operators/lookup_table_op.cu index a863af4af91..8716662f158 100644 --- a/paddle/fluid/operators/lookup_table_op.cu +++ b/paddle/fluid/operators/lookup_table_op.cu @@ -32,8 +32,8 @@ __global__ void LookupTable(T *output, const T *table, const int64_t *ids, while (idy < K) { int64_t id = ids[idy]; - PADDLE_ASSERT_MSG_CODE(id >= 0, "received id:", id); - PADDLE_ASSERT_MSG_CODE(id < N, "received id:", id); + PADDLE_ASSERT_MSG(id >= 0, "received id:", id); + PADDLE_ASSERT_MSG(id < N, "received id:", id); T *out = output + idy * D; const T *tab = table + id * D; for (int i = idx; i < D; i += BlockDimX) { @@ -59,8 +59,8 @@ __global__ void LookupTableGrad(T *table, const T *output, const int64_t *ids, while (idy < K) { int64_t id = ids[idy]; - PADDLE_ASSERT_MSG_CODE(id >= 0, "received id:", id); - PADDLE_ASSERT_MSG_CODE(id < N, "received id:", id); + PADDLE_ASSERT_MSG(id >= 0, "received id:", id); + PADDLE_ASSERT_MSG(id < N, "received id:", id); const T *out = output + idy * D; T *tab = table + id * D; for (int i = idx; i < D; i += BlockDimX) { diff --git a/paddle/fluid/platform/assert.h b/paddle/fluid/platform/assert.h index 497c7b3c87f..e3884a985e0 100644 --- a/paddle/fluid/platform/assert.h +++ b/paddle/fluid/platform/assert.h @@ -17,40 +17,32 @@ limitations under the License. */ #define STRINGIFY(x) #x #define TOSTRING(x) STRINGIFY(x) +// For cuda, the assertions can affect performance and it is therefore +// recommended to disable them in production code +// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#assertion #if defined(__CUDA_ARCH__) #include -#define PADDLE_ASSERT(e) \ - do { \ - if (!(e)) { \ - printf("%s:%d Assertion `%s` failed.\n", __FILE__, __LINE__, \ - TOSTRING(e)); \ - asm("trap;"); \ - } \ - } while (0) +#define EXIT() asm("trap;") +#else +#include +#define EXIT() throw std::runtime_error("Exception encounter.") +#endif -#define PADDLE_ASSERT_MSG(e, m) \ - do { \ - if (!(e)) { \ - printf("%s:%d Assertion `%s` failed (%s).\n", __FILE__, __LINE__, \ - TOSTRING(e), m); \ - asm("trap;"); \ - } \ +#define PADDLE_ASSERT(_IS_NOT_ERROR) \ + do { \ + if (!(_IS_NOT_ERROR)) { \ + printf("Exception: %s:%d Assertion `%s` failed.\n", __FILE__, __LINE__, \ + TOSTRING(_IS_NOT_ERROR)); \ + EXIT(); \ + } \ } while (0) -#define PADDLE_ASSERT_MSG_CODE(e, m, c) \ - do { \ - if (!(e)) { \ - printf("%s:%d Assertion `%s` failed (%s %ld).\n", __FILE__, __LINE__, \ - TOSTRING(e), m, c); \ - asm("trap;"); \ - } \ +// NOTE: PADDLE_ASSERT is mainly used in CUDA Kernel or HOSTDEVICE function. +#define PADDLE_ASSERT_MSG(_IS_NOT_ERROR, __MSG, __VAL) \ + do { \ + if (!(_IS_NOT_ERROR)) { \ + printf("Exception: %s:%d Assertion `%s` failed (%s %ld).\n", __FILE__, \ + __LINE__, TOSTRING(_IS_NOT_ERROR), __MSG, __VAL); \ + EXIT(); \ + } \ } while (0) -#else -#include -// For cuda, the assertions can affect performance and it is therefore -// recommended to disable them in production code -// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#assertion -#define PADDLE_ASSERT(e) assert((e)) -#define PADDLE_ASSERT_MSG(e, m) assert((e) && (m)) -#define PADDLE_ASSERT_MSG_CODE(e, m, c) assert((e) && (m) && (c || 1)) -#endif -- GitLab