未验证 提交 db5e74ab 编写于 作者: C chengduo 提交者: GitHub

update assert (#17282)

test=develop
上级 c3195de5
......@@ -154,6 +154,8 @@ struct HardLabelCrossEntropyForwardFunctor {
HOSTDEVICE void operator()(int64_t idx) const {
auto label = label_[idx];
PADDLE_ASSERT_MSG(label >= 0 && label < feature_size_,
"The label is out of the range.", label);
if (label != ignore_index_) {
auto match_x = x_[idx * feature_size_ + label];
y_[idx] = -math::TolerableValue<T>()(real_log(match_x));
......
......@@ -32,8 +32,8 @@ __global__ void LookupTable(T *output, const T *table, const int64_t *ids,
while (idy < K) {
int64_t id = ids[idy];
PADDLE_ASSERT_MSG_CODE(id >= 0, "received id:", id);
PADDLE_ASSERT_MSG_CODE(id < N, "received id:", id);
PADDLE_ASSERT_MSG(id >= 0, "received id:", id);
PADDLE_ASSERT_MSG(id < N, "received id:", id);
T *out = output + idy * D;
const T *tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) {
......@@ -59,8 +59,8 @@ __global__ void LookupTableGrad(T *table, const T *output, const int64_t *ids,
while (idy < K) {
int64_t id = ids[idy];
PADDLE_ASSERT_MSG_CODE(id >= 0, "received id:", id);
PADDLE_ASSERT_MSG_CODE(id < N, "received id:", id);
PADDLE_ASSERT_MSG(id >= 0, "received id:", id);
PADDLE_ASSERT_MSG(id < N, "received id:", id);
const T *out = output + idy * D;
T *tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) {
......
......@@ -17,40 +17,32 @@ limitations under the License. */
#define STRINGIFY(x) #x
#define TOSTRING(x) STRINGIFY(x)
// For cuda, the assertions can affect performance and it is therefore
// recommended to disable them in production code
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#assertion
#if defined(__CUDA_ARCH__)
#include <stdio.h>
#define PADDLE_ASSERT(e) \
do { \
if (!(e)) { \
printf("%s:%d Assertion `%s` failed.\n", __FILE__, __LINE__, \
TOSTRING(e)); \
asm("trap;"); \
} \
} while (0)
#define EXIT() asm("trap;")
#else
#include <assert.h>
#define EXIT() throw std::runtime_error("Exception encounter.")
#endif
#define PADDLE_ASSERT_MSG(e, m) \
#define PADDLE_ASSERT(_IS_NOT_ERROR) \
do { \
if (!(e)) { \
printf("%s:%d Assertion `%s` failed (%s).\n", __FILE__, __LINE__, \
TOSTRING(e), m); \
asm("trap;"); \
if (!(_IS_NOT_ERROR)) { \
printf("Exception: %s:%d Assertion `%s` failed.\n", __FILE__, __LINE__, \
TOSTRING(_IS_NOT_ERROR)); \
EXIT(); \
} \
} while (0)
#define PADDLE_ASSERT_MSG_CODE(e, m, c) \
// NOTE: PADDLE_ASSERT is mainly used in CUDA Kernel or HOSTDEVICE function.
#define PADDLE_ASSERT_MSG(_IS_NOT_ERROR, __MSG, __VAL) \
do { \
if (!(e)) { \
printf("%s:%d Assertion `%s` failed (%s %ld).\n", __FILE__, __LINE__, \
TOSTRING(e), m, c); \
asm("trap;"); \
if (!(_IS_NOT_ERROR)) { \
printf("Exception: %s:%d Assertion `%s` failed (%s %ld).\n", __FILE__, \
__LINE__, TOSTRING(_IS_NOT_ERROR), __MSG, __VAL); \
EXIT(); \
} \
} while (0)
#else
#include <assert.h>
// For cuda, the assertions can affect performance and it is therefore
// recommended to disable them in production code
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#assertion
#define PADDLE_ASSERT(e) assert((e))
#define PADDLE_ASSERT_MSG(e, m) assert((e) && (m))
#define PADDLE_ASSERT_MSG_CODE(e, m, c) assert((e) && (m) && (c || 1))
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册