未验证 提交 c7382df8 编写于 作者: Y Yibing Liu 提交者: GitHub

Print assert failure id in lookup_table_op (#14698)

上级 566a3259
......@@ -31,8 +31,8 @@ __global__ void LookupTable(T *output, const T *table, const int64_t *ids,
while (idy < K) {
int64_t id = ids[idy];
PADDLE_ASSERT(id >= 0);
PADDLE_ASSERT(id < N);
PADDLE_ASSERT_MSG_CODE(id >= 0, "received id:", id);
PADDLE_ASSERT_MSG_CODE(id < N, "received id:", id);
T *out = output + idy * D;
const T *tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) {
......@@ -57,9 +57,9 @@ __global__ void LookupTableGrad(T *table, const T *output, const int64_t *ids,
int idy = blockIdx.x + threadIdx.y * GridDimX;
while (idy < K) {
int id = ids[idy];
PADDLE_ASSERT(id >= 0);
PADDLE_ASSERT(id < N);
int64_t id = ids[idy];
PADDLE_ASSERT_MSG_CODE(id >= 0, "received id:", id);
PADDLE_ASSERT_MSG_CODE(id < N, "received id:", id);
const T *out = output + idy * D;
T *tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) {
......
......@@ -36,6 +36,15 @@ limitations under the License. */
asm("trap;"); \
} \
} while (0)
#define PADDLE_ASSERT_MSG_CODE(e, m, c) \
do { \
if (!(e)) { \
printf("%s:%d Assertion `%s` failed (%s %d).\n", __FILE__, __LINE__, \
TOSTRING(e), m, c); \
asm("trap;"); \
} \
} while (0)
#else
#include <assert.h>
// For cuda, the assertions can affect performance and it is therefore
......@@ -43,4 +52,5 @@ limitations under the License. */
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#assertion
#define PADDLE_ASSERT(e) assert((e))
#define PADDLE_ASSERT_MSG(e, m) assert((e) && (m))
#define PADDLE_ASSERT_MSG_CODE(e, m, c) assert((e) && (m) && (c || 1))
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册